diff --git a/extension/src/cli/dvc/contract.ts b/extension/src/cli/dvc/contract.ts index 0e4799d17b..7e274854b7 100644 --- a/extension/src/cli/dvc/contract.ts +++ b/extension/src/cli/dvc/contract.ts @@ -233,10 +233,16 @@ export type TemplatePlotOutput = { type: PlotsType } +export type BoundingBox = { + box: { left: number; right: number; top: number; bottom: number } + score: number +} + export type ImagePlotOutput = { revisions: string[] type: PlotsType url: string + annotations?: { [label: string]: BoundingBox[] } } export type PlotOutput = TemplatePlotOutput | ImagePlotOutput diff --git a/extension/src/common/colors.ts b/extension/src/common/colors.ts new file mode 100644 index 0000000000..c7cb3aeda9 --- /dev/null +++ b/extension/src/common/colors.ts @@ -0,0 +1,41 @@ +const colorsList = [ + '#945dd6', + '#13adc7', + '#f46837', + '#48bb78', + '#4299e1', + '#ed8936', + '#f56565' +] as const + +export type Color = (typeof colorsList)[number] + +export const copyOriginalColors = (): Color[] => [...colorsList] + +const boundingBoxColorsList = [ + '#ff3838', + '#ff9d97', + '#ff701f', + '#ffb21d', + '#cfd231', + '#48f90a', + '#92cc17', + '#3ddb86', + '#1a9334', + '#00d4bb', + '#2c99a8', + '#00c2ff', + '#344593', + '#6473ff', + '#0018ec', + '#8438ff', + '#520085', + '#cb38ff', + '#ff95c8', + '#ff37c7' +] as const + +export type BoundingBoxColor = (typeof boundingBoxColorsList)[number] + +export const getBoundingBoxColor = (ind: number): BoundingBoxColor => + boundingBoxColorsList[ind % boundingBoxColorsList.length] diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index ed644b4aec..d88c56bd64 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -27,7 +27,6 @@ import { pickFilterToAdd, pickFiltersToRemove } from './model/filterBy/quickPick' -import { Color } from './model/status/colors' import { MAX_SELECTED_EXPERIMENTS, UNSELECTED } from './model/status' import { starredSort } from './model/sortBy/constants' import { pickSortsToRemove, pickSortToAdd } from './model/sortBy/quickPick' @@ -39,6 +38,7 @@ import { Experiment, ColumnType, TableData, Column } from './webview/contract' import { WebviewMessages } from './webview/messages' import { DecorationProvider } from './model/decorationProvider' import { starredFilter } from './model/filterBy/constants' +import { Color } from '../common/colors' import { ResourceLocator } from '../resourceLocator' import { AvailableCommands, InternalCommands } from '../commands/internal' import { diff --git a/extension/src/experiments/model/index.test.ts b/extension/src/experiments/model/index.test.ts index 644d58adbb..1f98422452 100644 --- a/extension/src/experiments/model/index.test.ts +++ b/extension/src/experiments/model/index.test.ts @@ -2,7 +2,7 @@ import { join } from 'path' import { commands } from 'vscode' import { ExperimentsModel } from '.' -import { copyOriginalColors } from './status/colors' +import { copyOriginalColors } from '../../common/colors' import gitLogFixture from '../../test/fixtures/expShow/base/gitLog' import rowOrderFixture from '../../test/fixtures/expShow/base/rowOrder' import outputFixture from '../../test/fixtures/expShow/base/output' diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index b86ee85f81..da8638b2c9 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -16,9 +16,9 @@ import { collectSelectedColors, collectStartedRunningExperiments } from './status/collect' -import { Color, copyOriginalColors } from './status/colors' import { canSelect, ColoredStatus, UNSELECTED } from './status' import { collectFlatExperimentParams } from './modify/collect' +import { Color, copyOriginalColors } from '../../common/colors' import { Commit, Experiment, diff --git a/extension/src/experiments/model/status/collect.test.ts b/extension/src/experiments/model/status/collect.test.ts index f21db36dc1..3705e55fbf 100644 --- a/extension/src/experiments/model/status/collect.test.ts +++ b/extension/src/experiments/model/status/collect.test.ts @@ -1,6 +1,6 @@ import { UNSELECTED } from '.' import { collectColoredStatus } from './collect' -import { copyOriginalColors } from './colors' +import { copyOriginalColors } from '../../../common/colors' import { Experiment } from '../../webview/contract' import { ExecutorStatus, diff --git a/extension/src/experiments/model/status/collect.ts b/extension/src/experiments/model/status/collect.ts index 1baa940cf2..4f326c4e6e 100644 --- a/extension/src/experiments/model/status/collect.ts +++ b/extension/src/experiments/model/status/collect.ts @@ -5,7 +5,7 @@ import { tooManySelected, UNSELECTED } from '.' -import { Color, copyOriginalColors } from './colors' +import { Color, copyOriginalColors } from '../../../common/colors' import { hasKey } from '../../../util/object' import { Experiment, isQueued, RunningExperiment } from '../../webview/contract' import { definedAndNonEmpty, reorderListSubset } from '../../../util/array' diff --git a/extension/src/experiments/model/status/colors.ts b/extension/src/experiments/model/status/colors.ts deleted file mode 100644 index be5080ae43..0000000000 --- a/extension/src/experiments/model/status/colors.ts +++ /dev/null @@ -1,13 +0,0 @@ -const colorsList = [ - '#945dd6', - '#13adc7', - '#f46837', - '#48bb78', - '#4299e1', - '#ed8936', - '#f56565' -] as const - -export type Color = (typeof colorsList)[number] - -export const copyOriginalColors = (): Color[] => [...colorsList] diff --git a/extension/src/experiments/model/status/index.test.ts b/extension/src/experiments/model/status/index.test.ts index b993c3549c..a6bc606fac 100644 --- a/extension/src/experiments/model/status/index.test.ts +++ b/extension/src/experiments/model/status/index.test.ts @@ -1,5 +1,5 @@ import { canSelect, limitToMaxSelected } from '.' -import { copyOriginalColors } from './colors' +import { copyOriginalColors } from '../../../common/colors' import { Experiment } from '../../webview/contract' import { ExecutorStatus } from '../../../cli/dvc/contract' diff --git a/extension/src/experiments/model/status/index.ts b/extension/src/experiments/model/status/index.ts index 181ed29d7b..055643e5d0 100644 --- a/extension/src/experiments/model/status/index.ts +++ b/extension/src/experiments/model/status/index.ts @@ -1,4 +1,4 @@ -import { Color } from './colors' +import { Color } from '../../../common/colors' import { Experiment, isRunning } from '../../webview/contract' export const MAX_SELECTED_EXPERIMENTS = 7 diff --git a/extension/src/persistence/constants.ts b/extension/src/persistence/constants.ts index fe5538a8d4..471f35ae37 100644 --- a/extension/src/persistence/constants.ts +++ b/extension/src/persistence/constants.ts @@ -20,6 +20,7 @@ export enum PersistenceKey { PLOT_SELECTED_METRICS = 'plotSelectedMetrics:', PLOTS_SMOOTH_PLOT_VALUES = 'plotSmoothPlotValues:', PLOTS_COMPARISON_MULTI_PLOT_VALUES = 'plotComparisonMultiPlotValues:', + PLOTS_COMPARISON_CLASSES_SELECTED = 'plotComparisonClassesSelected:', PLOT_TEMPLATE_ORDER = 'plotTemplateOrder:', SHOW_ONLY_CHANGED = 'columnsShowOnlyChanged:' } diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index 24f8355b7b..24fe0256e0 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -114,12 +114,14 @@ describe('collectData', () => { ]) const heatmapPlot = join('plots', 'heatmap.png') + const boundingBoxPlot = join('plots', 'bounding_boxes.png') expect(Object.keys(comparisonData.main)).toStrictEqual([ join('plots', 'acc.png'), heatmapPlot, join('plots', 'loss.png'), - join('plots', 'image') + join('plots', 'image'), + boundingBoxPlot ]) const testBranchHeatmap = comparisonData['test-branch'][heatmapPlot] diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index f3c5d54129..4482d70a7a 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -1,5 +1,6 @@ import get from 'lodash.get' import type { TopLevelSpec } from 'vega-lite' +import isEmpty from 'lodash.isempty' import { getContent, CustomPlotsOrderValue, @@ -14,7 +15,12 @@ import { CustomPlotData, CustomPlotValues, ComparisonRevisionData, - ComparisonPlotImg + ComparisonPlotImg, + ComparisonClassDetails, + ComparisonPlotClasses, + ComparisonClassesSelected, + ComparisonPlotClass, + ComparisonPlotRow } from '../webview/contract' import { AnchorDefinitions, @@ -23,7 +29,8 @@ import { PlotOutput, PlotsOutput, PlotsType, - TemplatePlotOutput + TemplatePlotOutput, + ImagePlotOutput } from '../../cli/dvc/contract' import { splitColumnPath } from '../../experiments/columns/paths' import { ColumnType, Experiment } from '../../experiments/webview/contract' @@ -36,7 +43,7 @@ import { getParent, getPathArray } from '../../fileSystem/util' -import { Color } from '../../experiments/model/status/colors' +import { Color, getBoundingBoxColor } from '../../common/colors' export const getCustomPlotId = (metric: string, param: string) => `custom-${metric}-${param}` @@ -208,7 +215,7 @@ const getMultiImageInd = (path: string) => { const collectImageData = ( acc: ComparisonData, path: string, - plot: ImagePlot + plot: ImagePlotOutput ) => { const isMultiImgPlot = MULTI_IMAGE_PATH_REG.test(path) const pathLabel = isMultiImgPlot ? getMultiImagePath(path) : path @@ -226,12 +233,20 @@ const collectImageData = ( acc[id][pathLabel] = [] } - const imgPlot: ImagePlot = { ...plot } + const imgPlot: ImagePlot = { + revisions: plot.revisions, + type: plot.type, + url: plot.url + } if (isMultiImgPlot) { imgPlot.ind = getMultiImageInd(path) } + if (plot.annotations) { + imgPlot.annotations = plot.annotations + } + acc[id][pathLabel].push(imgPlot) } @@ -312,60 +327,109 @@ export const collectData = (output: PlotsOutput): DataAccumulator => { return acc } -type ComparisonPlotsAcc = { path: string; revisions: ComparisonRevisionData }[] - type GetComparisonPlotImg = ( img: ImagePlot, id: string, path: string ) => ComparisonPlotImg +const collectSelectedPlotImgClassLabels = ( + boundingBoxClassLabels: Set, + imgs: ImagePlot[] = [] +) => { + for (const img of imgs) { + if (!img.annotations) { + continue + } + + for (const label of Object.keys(img.annotations)) { + boundingBoxClassLabels.add(label) + } + } +} + +const getSelectedPathComparisonPlotClassDetails = ( + boundingBoxClassLabels: Set, + comparisonClassesSelected: ComparisonClassesSelected, + path: string +) => { + const classDetails: ComparisonClassDetails = {} + + let classLabelInd = 0 + for (const label of boundingBoxClassLabels) { + const selectedState = comparisonClassesSelected[path]?.[label] + classDetails[label] = { + color: getBoundingBoxColor(classLabelInd), + selected: selectedState === undefined ? true : selectedState + } + classLabelInd++ + } + + return classDetails +} + const collectSelectedPathComparisonPlots = ({ acc, comparisonData, + comparisonClassesSelected, path, selectedRevisionIds, getComparisonPlotImg }: { - acc: ComparisonPlotsAcc + acc: ComparisonPlotRow[] comparisonData: ComparisonData + comparisonClassesSelected: ComparisonClassesSelected path: string selectedRevisionIds: string[] getComparisonPlotImg: GetComparisonPlotImg }) => { + const boundingBoxClassLabels = new Set() const pathRevisions = { + classDetails: {} as ComparisonClassDetails, path, revisions: {} as ComparisonRevisionData } for (const id of selectedRevisionIds) { - const imgs = comparisonData[id]?.[path] + const imgs: ImagePlot[] | undefined = comparisonData[id]?.[path] + pathRevisions.revisions[id] = { id, imgs: imgs ? imgs.map(img => getComparisonPlotImg(img, id, path)) : [{ errors: undefined, loading: false, url: undefined }] } + collectSelectedPlotImgClassLabels(boundingBoxClassLabels, imgs) } + + pathRevisions.classDetails = getSelectedPathComparisonPlotClassDetails( + boundingBoxClassLabels, + comparisonClassesSelected, + path + ) + acc.push(pathRevisions) } export const collectSelectedComparisonPlots = ({ comparisonData, + comparisonClassesSelected, paths, selectedRevisionIds, getComparisonPlotImg }: { comparisonData: ComparisonData + comparisonClassesSelected: ComparisonClassesSelected paths: string[] selectedRevisionIds: string[] getComparisonPlotImg: GetComparisonPlotImg -}) => { - const acc: ComparisonPlotsAcc = [] +}): ComparisonPlotRow[] => { + const acc: ComparisonPlotRow[] = [] for (const path of paths) { collectSelectedPathComparisonPlots({ acc, + comparisonClassesSelected, comparisonData, getComparisonPlotImg, path, @@ -376,6 +440,99 @@ export const collectSelectedComparisonPlots = ({ return acc } +const getSelectedImgComparisonPlotClasses = ({ + selectedClassLabels, + img +}: { + selectedClassLabels: string[] + img: ImagePlot + path: string +}) => { + const imgAnnotations = img.annotations + if (!imgAnnotations) { + return [] + } + const imgClasses: ComparisonPlotClass[] = [] + + for (const label of selectedClassLabels) { + const boxes = imgAnnotations[label] + + if (boxes) { + imgClasses.push({ + boxes, + label + }) + } + } + + return imgClasses +} + +const collectedSelectedPathComparisonPlotClasses = ({ + acc, + id, + comparisonData, + path, + selectedClassLabels +}: { + acc: ComparisonPlotClasses + selectedClassLabels: string[] + comparisonData: ComparisonData + path: string + id: string +}) => { + for (const img of comparisonData[id][path]) { + const imgClasses = getSelectedImgComparisonPlotClasses({ + img, + path, + selectedClassLabels + }) + + if (imgClasses.length === 0) { + return + } + + if (!acc[id]) { + acc[id] = {} + } + + acc[id][path] = imgClasses + } +} + +export const collectSelectedComparisonPlotClasses = ({ + comparisonData, + plots, + selectedRevisionIds +}: { + comparisonData: ComparisonData + plots: ComparisonPlotRow[] + selectedRevisionIds: string[] +}): ComparisonPlotClasses => { + const acc: ComparisonPlotClasses = {} + + for (const { path, classDetails } of plots) { + const selectedClassLabels = Object.keys(classDetails).filter( + (label: string) => classDetails[label].selected + ) + if (isEmpty(classDetails) || isEmpty(selectedClassLabels)) { + continue + } + + for (const id of selectedRevisionIds) { + collectedSelectedPathComparisonPlotClasses({ + acc, + comparisonData, + id, + path, + selectedClassLabels + }) + } + } + + return acc +} + export type TemplateDetailsAccumulator = { [path: string]: { content: TopLevelSpec diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 394d5f77fd..b298d0e959 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -13,7 +13,8 @@ import { collectIdShas, collectSelectedTemplatePlotRawData, collectCustomPlotRawData, - collectSelectedComparisonPlots + collectSelectedComparisonPlots, + collectSelectedComparisonPlotClasses } from './collect' import { getRevisionSummaryColumns } from './util' import { cleanupOldOrderValue, CustomPlotsOrderValue } from './custom' @@ -30,7 +31,9 @@ import { SmoothPlotValues, ImagePlot, ComparisonMultiPlotValues, - ComparisonPlotImg + ComparisonPlotImg, + ComparisonClassesSelected, + ComparisonPlotRow } from '../webview/contract' import { EXPERIMENT_WORKSPACE_ID, @@ -76,6 +79,8 @@ export class PlotsModel extends ModelWithPersistence { private comparisonData: ComparisonData = {} private comparisonOrder: string[] private comparisonMultiPlotValues: ComparisonMultiPlotValues = {} + private comparisonClassesSelected: ComparisonClassesSelected = {} + private smoothPlotValues: SmoothPlotValues = {} private revisionData: RevisionData = {} @@ -112,6 +117,10 @@ export class PlotsModel extends ModelWithPersistence { PersistenceKey.PLOTS_COMPARISON_MULTI_PLOT_VALUES, {} ) + this.comparisonClassesSelected = this.revive( + PersistenceKey.PLOTS_COMPARISON_CLASSES_SELECTED, + {} + ) this.cleanupOutdatedCustomPlotsState() this.cleanupOutdatedTrendsState() @@ -266,6 +275,19 @@ export class PlotsModel extends ModelWithPersistence { return this.getSelectedComparisonPlots(paths, selectedRevisionIds) } + public getComparisonPlotClasses(plots: ComparisonPlotRow[]) { + const selectedRevisionIds = this.getSelectedRevisionIds() + if (!definedAndNonEmpty(selectedRevisionIds)) { + return {} + } + + return collectSelectedComparisonPlotClasses({ + comparisonData: this.comparisonData, + plots, + selectedRevisionIds + }) + } + public requiresUpdate() { return !sameContents([...this.fetchedRevs], this.getSelectedRevisionIds()) } @@ -310,10 +332,26 @@ export class PlotsModel extends ModelWithPersistence { ) } + public toggleComparisonClass(path: string, label: string, selected: boolean) { + if (!this.comparisonClassesSelected[path]) { + this.comparisonClassesSelected[path] = {} + } + + this.comparisonClassesSelected[path][label] = selected + this.persist( + PersistenceKey.PLOTS_COMPARISON_MULTI_PLOT_VALUES, + this.comparisonClassesSelected + ) + } + public getComparisonMultiPlotValues() { return this.comparisonMultiPlotValues } + public getComparisonClassesSelected() { + return this.comparisonClassesSelected + } + public getSelectedRevisionIds() { return this.experiments.getSelectedRevisions().map(({ id }) => id) } @@ -432,6 +470,7 @@ export class PlotsModel extends ModelWithPersistence { selectedRevisionIds: string[] ) { return collectSelectedComparisonPlots({ + comparisonClassesSelected: this.getComparisonClassesSelected(), comparisonData: this.comparisonData, getComparisonPlotImg: (image: ImagePlot, id: string, path: string) => { const errors = this.errors.getImageErrors(path, id) diff --git a/extension/src/plots/paths/collect.test.ts b/extension/src/plots/paths/collect.test.ts index 7284b46f87..47aa413486 100644 --- a/extension/src/plots/paths/collect.test.ts +++ b/extension/src/plots/paths/collect.test.ts @@ -64,6 +64,14 @@ const plotsDiffFixturePaths: PlotPath[] = [ revisions: new Set(REVISIONS), type: new Set([PathType.COMPARISON]) }, + { + hasChildren: false, + label: 'bounding_boxes.png', + parentPath: 'plots', + path: join('plots', 'bounding_boxes.png'), + revisions: new Set(REVISIONS), + type: new Set([PathType.COMPARISON]) + }, { hasChildren: false, label: 'loss.tsv', @@ -162,6 +170,7 @@ describe('collectPaths', () => { join('plots', 'heatmap.png'), join('plots', 'loss.png'), join('plots', 'image'), + join('plots', 'bounding_boxes.png'), join('logs', 'loss.tsv'), join('logs', 'acc.tsv'), 'predictions.json' diff --git a/extension/src/plots/paths/model.test.ts b/extension/src/plots/paths/model.test.ts index 7665aaeb15..c0f415cc2c 100644 --- a/extension/src/plots/paths/model.test.ts +++ b/extension/src/plots/paths/model.test.ts @@ -81,6 +81,15 @@ describe('PathsModel', () => { selected: true, type: comparisonType }, + { + hasChildren: false, + label: 'bounding_boxes.png', + parentPath: 'plots', + path: join('plots', 'bounding_boxes.png'), + revisions: new Set(REVISIONS), + selected: true, + type: comparisonType + }, { hasChildren: false, label: 'loss.tsv', @@ -369,14 +378,16 @@ describe('PathsModel', () => { join('plots', 'acc.png'), join('plots', 'heatmap.png'), join('plots', 'loss.png'), - join('plots', 'image') + join('plots', 'image'), + join('plots', 'bounding_boxes.png') ]) const newOrder = [ join('plots', 'heatmap.png'), join('plots', 'acc.png'), join('plots', 'loss.png'), - join('plots', 'image') + join('plots', 'image'), + join('plots', 'bounding_boxes.png') ] model.setComparisonPathsOrder(newOrder) @@ -411,7 +422,7 @@ describe('PathsModel', () => { tooltip: undefined }, { - descendantStatuses: [2, 2, 2, 2], + descendantStatuses: [2, 2, 2, 2, 2], hasChildren: true, label: 'plots', parentPath: undefined, diff --git a/extension/src/plots/vega/util.test.ts b/extension/src/plots/vega/util.test.ts index c8d27664f6..42c60d04e6 100644 --- a/extension/src/plots/vega/util.test.ts +++ b/extension/src/plots/vega/util.test.ts @@ -12,7 +12,7 @@ import confusionNormalizedTemplate from '../../test/fixtures/plotsDiff/templates import linearTemplate from '../../test/fixtures/plotsDiff/templates/linear' import scatterTemplate from '../../test/fixtures/plotsDiff/templates/scatter' import smoothTemplate from '../../test/fixtures/plotsDiff/templates/smooth' -import { copyOriginalColors } from '../../experiments/model/status/colors' +import { copyOriginalColors } from '../../common/colors' import { EXPERIMENT_WORKSPACE_ID, PLOT_ANCHORS } from '../../cli/dvc/contract' describe('isMultiViewPlot', () => { diff --git a/extension/src/plots/vega/util.ts b/extension/src/plots/vega/util.ts index 2620f70b93..70e53c5e61 100644 --- a/extension/src/plots/vega/util.ts +++ b/extension/src/plots/vega/util.ts @@ -14,7 +14,7 @@ import { } from 'vega-lite/build/src/spec/repeat' import { TopLevelUnitSpec } from 'vega-lite/build/src/spec/unit' import { ColorScale } from '../webview/contract' -import { Color } from '../../experiments/model/status/colors' +import { Color } from '../../common/colors' import { AnchorDefinitions, PLOT_ANCHORS, diff --git a/extension/src/plots/webview/contract.ts b/extension/src/plots/webview/contract.ts index 4a675cd1c3..dd0e3fe612 100644 --- a/extension/src/plots/webview/contract.ts +++ b/extension/src/plots/webview/contract.ts @@ -1,7 +1,8 @@ import type { TopLevelSpec } from 'vega-lite' -import { Color } from '../../experiments/model/status/colors' +import { BoundingBoxColor, Color } from '../../common/colors' import { AnchorDefinitions, + BoundingBox, ImagePlotOutput, PlotsType, TemplatePlotOutput @@ -50,10 +51,17 @@ export type SectionCollapsed = typeof DEFAULT_SECTION_COLLAPSED export type ComparisonRevisionData = { [revision: string]: ComparisonPlot } -export type ComparisonPlots = { +export type ComparisonClassDetails = { + [label: string]: { selected: boolean; color: BoundingBoxColor } +} + +export type ComparisonPlotRow = { + classDetails: ComparisonClassDetails path: string revisions: ComparisonRevisionData -}[] +} + +export type ComparisonPlots = ComparisonPlotRow[] export type RevisionSummaryColumns = Array<{ path: string @@ -76,12 +84,26 @@ export type ComparisonMultiPlotValues = { [revision: string]: { [path: string]: number } } +export type ComparisonClassesSelected = { + [path: string]: { [label: string]: boolean } +} + +export type ComparisonPlotClass = { + label: string + boxes: BoundingBox[] +} + +export type ComparisonPlotClasses = { + [revision: string]: { [path: string]: ComparisonPlotClass[] } +} + export interface PlotsComparisonData { plots: ComparisonPlots width: number height: PlotHeight revisions: Revision[] multiPlotValues: ComparisonMultiPlotValues + plotClasses: ComparisonPlotClasses } export type CustomPlotValues = { diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index 22e8b727e1..b524677f63 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -117,6 +117,12 @@ export class WebviewMessages { message.payload.path, message.payload.value ) + case MessageFromWebviewType.TOGGLE_COMPARISON_CLASS: + return this.toggleComparisonClass( + message.payload.path, + message.payload.label, + message.payload.selected + ) case MessageFromWebviewType.REMOVE_CUSTOM_PLOTS: return commands.executeCommand( RegisteredCommands.PLOTS_CUSTOM_REMOVE, @@ -283,6 +289,20 @@ export class WebviewMessages { ) } + private toggleComparisonClass( + path: string, + label: string, + selected: boolean + ) { + this.plots.toggleComparisonClass(path, label, selected) + this.sendComparisonPlots() + sendTelemetryEvent( + EventName.VIEWS_PLOTS_TOGGLE_COMPARISON_CLASS, + undefined, + undefined + ) + } + private setTemplateOrder(order: PlotsTemplatesReordered) { this.paths.setTemplateOrder(order) this.sendTemplatePlots() @@ -405,8 +425,13 @@ export class WebviewMessages { return { height: this.plots.getHeight(PlotsSection.COMPARISON_TABLE), multiPlotValues: this.plots.getComparisonMultiPlotValues(), - plots: comparison.map(({ path, revisions }) => { - return { path, revisions: this.getRevisionsWithCorrectUrls(revisions) } + plotClasses: this.plots.getComparisonPlotClasses(comparison), + plots: comparison.map(({ classDetails, path, revisions }) => { + return { + classDetails, + path, + revisions: this.getRevisionsWithCorrectUrls(revisions) + } }), revisions: this.plots.getComparisonRevisions(), width: this.plots.getNbItemsPerRowOrWidth(PlotsSection.COMPARISON_TABLE) diff --git a/extension/src/telemetry/constants.ts b/extension/src/telemetry/constants.ts index 1be21d9171..24ad6c259c 100644 --- a/extension/src/telemetry/constants.ts +++ b/extension/src/telemetry/constants.ts @@ -100,6 +100,7 @@ export const EventName = Object.assign( VIEWS_PLOTS_SET_COMPARISON_MULTI_PLOT_VALUE: 'view.plots.setComparisonMultiPlotValue', VIEWS_PLOTS_SET_SMOOTH_PLOT_VALUE: 'view.plots.setSmoothPlotValues', + VIEWS_PLOTS_TOGGLE_COMPARISON_CLASS: 'view.plots.toggleComparisonClass', VIEWS_PLOTS_ZOOM_PLOT: 'views.plots.zoomPlot', VIEWS_REORDER_PLOTS_CUSTOM: 'views.plots.customReordered', VIEWS_REORDER_PLOTS_TEMPLATES: 'views.plots.templatesReordered', @@ -305,6 +306,7 @@ export interface IEventNamePropertyMapping { [EventName.VIEWS_REORDER_PLOTS_TEMPLATES]: undefined [EventName.VIEWS_PLOTS_SET_COMPARISON_MULTI_PLOT_VALUE]: undefined [EventName.VIEWS_PLOTS_SET_SMOOTH_PLOT_VALUE]: undefined + [EventName.VIEWS_PLOTS_TOGGLE_COMPARISON_CLASS]: undefined [EventName.VIEWS_PLOTS_PATH_TREE_OPENED]: DvcRootCount diff --git a/extension/src/test/cli/plotsDiff.test.ts b/extension/src/test/cli/plotsDiff.test.ts index 80c7da1790..63b5af9bf0 100644 --- a/extension/src/test/cli/plotsDiff.test.ts +++ b/extension/src/test/cli/plotsDiff.test.ts @@ -3,14 +3,14 @@ import { expect } from 'chai' import { TEMP_DIR } from './constants' import { dvcReader, initializeDemoRepo, initializeEmptyRepo } from './util' import { dvcDemoPath } from '../util' -import { ImagePlot } from '../../plots/webview/contract' import { PLOT_ANCHORS, EXPERIMENT_WORKSPACE_ID, PlotsOutput, PlotsType, TemplatePlotOutput, - isImagePlotOutput + isImagePlotOutput, + ImagePlotOutput } from '../../cli/dvc/contract' import { isDvcError } from '../../cli/dvc/reader' @@ -43,7 +43,7 @@ suite('plots diff -o --split --show-json', () => { ).to.have.lengthOf.greaterThanOrEqual(1) // each plot - const expectImage = (plot: ImagePlot) => { + const expectImage = (plot: ImagePlotOutput) => { expect(plot.url).to.be.a('string') expect(plot.revisions, 'should have one revision').to.have.lengthOf(1) } diff --git a/extension/src/test/fixtures/expShow/base/rows.ts b/extension/src/test/fixtures/expShow/base/rows.ts index b14cdcd401..43710365bc 100644 --- a/extension/src/test/fixtures/expShow/base/rows.ts +++ b/extension/src/test/fixtures/expShow/base/rows.ts @@ -5,7 +5,7 @@ import { StudioLinkType, WORKSPACE_BRANCH } from '../../../../experiments/webview/contract' -import { copyOriginalColors } from '../../../../experiments/model/status/colors' +import { copyOriginalColors } from '../../../../common/colors' import { shortenForLabel } from '../../../../util/string' import { ExecutorStatus, diff --git a/extension/src/test/fixtures/plotsDiff/index.ts b/extension/src/test/fixtures/plotsDiff/index.ts index 7fed230687..5645c6cb6e 100644 --- a/extension/src/test/fixtures/plotsDiff/index.ts +++ b/extension/src/test/fixtures/plotsDiff/index.ts @@ -1,7 +1,9 @@ import type { TopLevelSpec } from 'vega-lite' import { isMultiViewPlot } from '../../../plots/vega/util' import { + BoundingBox, EXPERIMENT_WORKSPACE_ID, + ImagePlotOutput, PLOT_ANCHORS, PlotsOutput, PlotsType, @@ -18,10 +20,12 @@ import { DEFAULT_PLOT_HEIGHT, DEFAULT_NB_ITEMS_PER_ROW, DEFAULT_PLOT_WIDTH, - ComparisonPlotImg + ComparisonPlotImg, + ComparisonClassDetails, + ComparisonPlotClasses } from '../../../plots/webview/contract' import { join } from '../../util/path' -import { copyOriginalColors } from '../../../experiments/model/status/colors' +import { copyOriginalColors, getBoundingBoxColor } from '../../../common/colors' import { ColumnType } from '../../../experiments/webview/contract' const basicVega = { @@ -538,7 +542,10 @@ const getMultiImageData = ( return data } -const getImageData = (baseUrl: string, joinFunc = join) => ({ +const getImageData = ( + baseUrl: string, + joinFunc = join +): { [path: string]: ImagePlotOutput[] } => ({ [join('plots', 'acc.png')]: [ { type: PlotsType.IMAGE, @@ -626,7 +633,96 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({ 'exp-e7a67', 'test-branch', 'exp-83425' - ]) + ]), + [join('plots', 'bounding_boxes.png')]: [ + { + type: PlotsType.IMAGE, + revisions: [EXPERIMENT_WORKSPACE_ID], + url: joinFunc(baseUrl, 'bounding_boxes.png'), + annotations: { + 'traffic light': [ + { box: { left: 120, right: 195, top: 120, bottom: 210 }, score: 0.99 } + ], + car: [ + { box: { left: 150, right: 180, top: 320, bottom: 350 }, score: 0.5 }, + { + box: { left: 200, right: 230, top: 310, bottom: 340 }, + score: 0.354 + } + ], + sign: [ + { box: { left: 300, right: 450, top: 170, bottom: 220 }, score: 0.87 } + ] + } + }, + { + type: PlotsType.IMAGE, + revisions: ['main'], + url: joinFunc(baseUrl, 'bounding_boxes.png'), + annotations: { + 'traffic light': [ + { box: { left: 120, right: 195, top: 120, bottom: 210 }, score: 0.99 } + ], + car: [ + { box: { left: 150, right: 180, top: 320, bottom: 350 }, score: 0.5 } + ], + sign: [ + { box: { left: 300, right: 450, top: 170, bottom: 220 }, score: 0.87 } + ] + } + }, + { + type: PlotsType.IMAGE, + revisions: ['exp-e7a67'], + url: joinFunc(baseUrl, 'bounding_boxes.png'), + annotations: { + 'traffic light': [ + { box: { left: 120, right: 195, top: 120, bottom: 210 }, score: 0.99 } + ], + car: [ + { box: { left: 150, right: 180, top: 320, bottom: 350 }, score: 0.5 } + ] + } + }, + { + type: PlotsType.IMAGE, + revisions: ['test-branch'], + url: joinFunc(baseUrl, 'bounding_boxes.png'), + annotations: { + 'traffic light': [ + { + box: { left: 120, right: 195, top: 120, bottom: 210 }, + score: 0.764 + } + ], + car: [ + { + box: { left: 150, right: 180, top: 320, bottom: 350 }, + score: 0.984 + } + ] + } + }, + { + type: PlotsType.IMAGE, + revisions: ['exp-83425'], + url: joinFunc(baseUrl, 'bounding_boxes.png'), + annotations: { + 'traffic light': [ + { + box: { left: 120, right: 195, top: 120, bottom: 210 }, + score: 0.984 + } + ], + car: [ + { + box: { left: 150, right: 180, top: 320, bottom: 350 }, + score: 0.984 + } + ] + } + } + ] }) export const getOutput = (baseUrl: string): PlotsOutput => ({ @@ -908,26 +1004,69 @@ const getIndFromComparisonMultiImgPath = (path: string) => { return Number((pathIndMatches as string[])[1]) } +export const collectPlotClasses = ({ + plotClasses, + imgLabels, + imgAnnotations, + id, + path +}: { + plotClasses: ComparisonPlotClasses + imgLabels: string[] + imgAnnotations: { [label: string]: BoundingBox[] } + id: string + path: string +}) => { + const classAcc = [] + + for (const label of imgLabels) { + classAcc.push({ boxes: imgAnnotations[label], label }) + } + + if (!plotClasses[id]) { + plotClasses[id] = {} + } + + plotClasses[id][path] = Object.values(classAcc) +} + +export const collectClassLabels = ( + imgLabels: string[], + plotClasses: Set +) => { + for (const label of imgLabels) { + plotClasses.add(label) + } +} + export const getComparisonWebviewMessage = ( baseUrl: string, joinFunc: (...args: string[]) => string = join ): PlotsComparisonData => { + const plotClasses: ComparisonPlotClasses = {} + const plotAcc: { - [path: string]: { path: string; revisions: ComparisonRevisionData } + [path: string]: { + path: string + revisions: ComparisonRevisionData + classDetails: ComparisonClassDetails + } } = {} for (const [path, plots] of Object.entries(getImageData(baseUrl, joinFunc))) { const isMulti = path.includes('image') const pathLabel = isMulti ? join('plots', 'image') : path + const classLabels = new Set() if (!plotAcc[pathLabel]) { plotAcc[pathLabel] = { path: pathLabel, - revisions: {} + revisions: {}, + classDetails: {} } } - for (const { url, revisions } of plots) { + for (const { url, revisions, annotations } of plots) { const id = revisions?.[0] if (!id) { continue @@ -950,11 +1089,31 @@ export const getComparisonWebviewMessage = ( img.ind = getIndFromComparisonMultiImgPath(path) } + if (annotations) { + const imgLabels = Object.keys(annotations) + collectPlotClasses({ + plotClasses, + imgAnnotations: annotations, + imgLabels, + id, + path + }) + collectClassLabels(imgLabels, classLabels) + } + plotAcc[pathLabel].revisions[id].imgs.push(img) } + + for (const [ind, label] of [...classLabels].entries()) { + plotAcc[pathLabel].classDetails[label] = { + selected: true, + color: getBoundingBoxColor(ind) + } + } } return { + plotClasses, revisions: getRevisions(), multiPlotValues: {}, plots: Object.values(plotAcc), diff --git a/extension/src/test/fixtures/plotsDiff/staticImages/bounding_boxes.png b/extension/src/test/fixtures/plotsDiff/staticImages/bounding_boxes.png new file mode 100644 index 0000000000..124dbe2c7d Binary files /dev/null and b/extension/src/test/fixtures/plotsDiff/staticImages/bounding_boxes.png differ diff --git a/extension/src/test/suite/experiments/index.test.ts b/extension/src/test/suite/experiments/index.test.ts index 5a3aa5a7fe..71e166e4d3 100644 --- a/extension/src/test/suite/experiments/index.test.ts +++ b/extension/src/test/suite/experiments/index.test.ts @@ -68,7 +68,7 @@ import { buildMetricOrParamPath } from '../../../experiments/columns/paths' import { ColumnsModel } from '../../../experiments/columns/model' import { MessageFromWebviewType } from '../../../webview/contract' import { ExperimentsModel } from '../../../experiments/model' -import { copyOriginalColors } from '../../../experiments/model/status/colors' +import { copyOriginalColors } from '../../../common/colors' import { WEBVIEW_TEST_TIMEOUT } from '../timeouts' import * as Telemetry from '../../../telemetry' import { EventName } from '../../../telemetry/constants' diff --git a/extension/src/test/suite/experiments/model/tree.test.ts b/extension/src/test/suite/experiments/model/tree.test.ts index 16365bfb31..03173af4c4 100644 --- a/extension/src/test/suite/experiments/model/tree.test.ts +++ b/extension/src/test/suite/experiments/model/tree.test.ts @@ -41,7 +41,7 @@ import { DvcExecutor } from '../../../../cli/dvc/executor' import { Param } from '../../../../experiments/model/modify/collect' import { WorkspaceExperiments } from '../../../../experiments/workspace' import { EXPERIMENT_WORKSPACE_ID } from '../../../../cli/dvc/contract' -import { copyOriginalColors } from '../../../../experiments/model/status/colors' +import { copyOriginalColors } from '../../../../common/colors' import { Revision } from '../../../../plots/webview/contract' suite('Experiments Tree Test Suite', () => { diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index 042d62f652..74d4e64166 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -28,7 +28,7 @@ import { PlotsData as TPlotsData, PlotsSection, TemplatePlotGroup, - ImagePlot + ComparisonPlotClasses } from '../../../plots/webview/contract' import { TEMP_PLOTS_DIR } from '../../../cli/dvc/constants' import { WEBVIEW_TEST_TIMEOUT } from '../timeouts' @@ -46,7 +46,8 @@ import { EXPERIMENT_WORKSPACE_ID, ExpShowOutput, TemplatePlotOutput, - experimentHasError + experimentHasError, + ImagePlotOutput } from '../../../cli/dvc/contract' import { Experiment } from '../../../experiments/webview/contract' import { COMMITS_SEPARATOR } from '../../../cli/git/constants' @@ -402,7 +403,8 @@ suite('Plots Test Suite', () => { join('plots', 'acc.png'), join('plots', 'heatmap.png'), join('plots', 'loss.png'), - join('plots', 'image') + join('plots', 'image'), + join('plots', 'bounding_boxes.png') ] messageSpy.resetHistory() @@ -1036,7 +1038,7 @@ suite('Plots Test Suite', () => { const accPngPath = join('plots', 'acc.png') const accPng = [ ...plotsDiffFixture.data[join('plots', 'acc.png')] - ] as ImagePlot[] + ] as ImagePlotOutput[] const lossTsvPath = join('logs', 'loss.tsv') const lossTsv = [ ...plotsDiffFixture.data[lossTsvPath] @@ -1108,7 +1110,7 @@ suite('Plots Test Suite', () => { const accPngPath = join('plots', 'acc.png') const accPng = [ ...plotsDiffFixture.data[join('plots', 'acc.png')] - ] as ImagePlot[] + ] as ImagePlotOutput[] const lossTsvPath = join('logs', 'loss.tsv') const lossTsv = [ ...plotsDiffFixture.data[lossTsvPath] @@ -1276,6 +1278,79 @@ suite('Plots Test Suite', () => { ) }).timeout(WEBVIEW_TEST_TIMEOUT) + it('should handle an toggle comparison class message from the webview', async () => { + const { messageSpy, mockMessageReceived, plotsModel } = + await buildPlotsWebview({ + disposer: disposable, + plotsDiff: plotsDiffFixture + }) + const toggledLabel = 'car' + const boundingBoxPlot = comparisonPlotsFixture.plots[4] + + const filteredPlotsClasses: ComparisonPlotClasses = {} + for (const [id, classesByPath] of Object.entries( + comparisonPlotsFixture.plotClasses + )) { + const classes = classesByPath[boundingBoxPlot.path] + filteredPlotsClasses[id] = { + [boundingBoxPlot.path]: classes.filter( + ({ label }) => label !== toggledLabel + ) + } + } + + const filteredPlots = comparisonPlotsFixture.plots.map(plot => { + if (plot.path !== boundingBoxPlot.path) { + return plot + } + + const { color } = plot.classDetails[toggledLabel] + return { + ...plot, + classDetails: { + ...plot.classDetails, + [toggledLabel]: { color, selected: false } + } + } + }) + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const toggleComparisonClassSpy = spy(plotsModel, 'toggleComparisonClass') + + messageSpy.resetHistory() + mockMessageReceived.fire({ + payload: { + label: toggledLabel, + path: boundingBoxPlot.path, + selected: false + }, + type: MessageFromWebviewType.TOGGLE_COMPARISON_CLASS + }) + + expect(toggleComparisonClassSpy).to.be.called + expect(toggleComparisonClassSpy).to.be.calledWithExactly( + boundingBoxPlot.path, + 'car', + false + ) + expect( + messageSpy, + "should update the webview's comparison classes" + ).to.be.calledWithExactly({ + comparison: { + ...comparisonPlotsFixture, + plotClasses: filteredPlotsClasses, + plots: filteredPlots + } + }) + expect(mockSendTelemetryEvent).to.be.called + expect(mockSendTelemetryEvent).to.be.calledWithExactly( + EventName.VIEWS_PLOTS_TOGGLE_COMPARISON_CLASS, + undefined, + undefined + ) + }).timeout(WEBVIEW_TEST_TIMEOUT) + it('should handle an add plot message from the webview', async () => { const { mockMessageReceived } = await buildPlotsWebview({ disposer: disposable, diff --git a/extension/src/webview/contract.ts b/extension/src/webview/contract.ts index df16e9b04e..56bb310915 100644 --- a/extension/src/webview/contract.ts +++ b/extension/src/webview/contract.ts @@ -49,6 +49,7 @@ export enum MessageFromWebviewType { SAVE_STUDIO_TOKEN = 'save-studio-token', SAVE_STUDIO_URL = 'save-studio-url', SET_COMPARISON_MULTI_PLOT_VALUE = 'update-comparison-multi-plot-value', + TOGGLE_COMPARISON_CLASS = 'toggle-comparison-class', SET_SMOOTH_PLOT_VALUE = 'update-smooth-plot-value', SHOW_EXPERIMENT_LOGS = 'show-experiment-logs', SHOW_WALKTHROUGH = 'show-walkthrough', @@ -236,6 +237,10 @@ export type MessageFromWebview = type: MessageFromWebviewType.SET_COMPARISON_MULTI_PLOT_VALUE payload: { path: string; revision: string; value: number } } + | { + type: MessageFromWebviewType.TOGGLE_COMPARISON_CLASS + payload: { path: string; label: string; selected: boolean } + } | { type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM payload: string[] diff --git a/scripts/create-svgs.ts b/scripts/create-svgs.ts index 011fcae7f1..75c76125d3 100644 --- a/scripts/create-svgs.ts +++ b/scripts/create-svgs.ts @@ -1,5 +1,5 @@ import { readFileSync, writeFileSync } from 'fs' -import { copyOriginalColors } from 'dvc/src/experiments/model/status/colors' +import { copyOriginalColors } from 'dvc/src/common/colors' const colors = copyOriginalColors() diff --git a/webview/src/plots/components/App.test.tsx b/webview/src/plots/components/App.test.tsx index eada042721..5697a98def 100644 --- a/webview/src/plots/components/App.test.tsx +++ b/webview/src/plots/components/App.test.tsx @@ -226,8 +226,10 @@ describe('App', () => { comparison: { height: DEFAULT_PLOT_HEIGHT, multiPlotValues: {}, + plotClasses: {}, plots: [ { + classDetails: {}, path: 'training/plots/images/misclassified.jpg', revisions: { ad2b5ec: { @@ -281,8 +283,10 @@ describe('App', () => { comparison: { height: DEFAULT_PLOT_HEIGHT, multiPlotValues: {}, + plotClasses: {}, plots: [ { + classDetails: {}, path: 'training/plots/images/image', revisions: { ad2b5ec: { diff --git a/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx b/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx index 7e38b0c0e5..fd98ce4e19 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTable.test.tsx @@ -196,8 +196,8 @@ describe('ComparisonTable', () => { const rows = screen.getAllByRole('row') expect(rows.length).toBe( - Object.entries(comparisonTableFixture.plots).length * 2 + 1 - ) // 1 header row and 2 rows per plot + Object.entries(comparisonTableFixture.plots).length * 2 + 2 + ) // 1 header row, 1 bounding box classes row, and 2 rows per plot }) it('should display the plots in the rows in the same order as the columns', () => { @@ -297,8 +297,8 @@ describe('ComparisonTable', () => { renderTable({ ...comparisonTableFixture, - plots: comparisonTableFixture.plots.map(({ path, revisions }) => ({ - path, + plots: comparisonTableFixture.plots.map(({ revisions, ...rest }) => ({ + ...rest, revisions: { ...revisions, [revisionWithNoData]: { @@ -334,8 +334,8 @@ describe('ComparisonTable', () => { renderTable({ ...comparisonTableFixture, - plots: comparisonTableFixture.plots.map(({ path, revisions }) => ({ - path, + plots: comparisonTableFixture.plots.map(({ revisions, ...rest }) => ({ + ...rest, revisions: { ...revisions, [revisionWithNoData]: { @@ -743,4 +743,51 @@ describe('ComparisonTable', () => { }) }) }) + + describe('Plots With Bounding Boxes', () => { + it('should show toggable labels in the plot row', () => { + renderTable() + + const boundingBoxPlotClasses = screen.getByTestId( + 'row-bounding-box-classes' + ) + + expect( + within(boundingBoxPlotClasses).getByText('Classes') + ).toBeInTheDocument() + + const labelInput = within(boundingBoxPlotClasses).getByLabelText( + 'traffic light' + ) + const plotPath = comparisonTableFixture.plots[4].path + + expect(labelInput).toBeInTheDocument() + expect(labelInput).toBeChecked() + + fireEvent.click(within(boundingBoxPlotClasses).getByText('traffic light')) + + expect(labelInput).not.toBeChecked() + expect(mockPostMessage).toHaveBeenCalledTimes(1) + expect(mockPostMessage).toHaveBeenCalledWith({ + payload: { label: 'traffic light', path: plotPath, selected: false }, + type: MessageFromWebviewType.TOGGLE_COMPARISON_CLASS + }) + }) + + it('should show svgs with bounding boxes instead of images', () => { + renderTable() + + const boundingBoxPlotImage = screen.getByLabelText( + /bounding_boxes.png \(workspace\)/ + ) + const getBoxTextReg = (label: string) => new RegExp(`${label} 0\\.\\d+`) + expect(boundingBoxPlotImage).toHaveAttribute('viewBox') + expect( + within(boundingBoxPlotImage).getByText(getBoxTextReg('traffic light')) + ).toBeInTheDocument() + expect( + within(boundingBoxPlotImage).getAllByText(getBoxTextReg('car')) + ).toHaveLength(2) + }) + }) }) diff --git a/webview/src/plots/components/comparisonTable/ComparisonTablePinnedContentRow.tsx b/webview/src/plots/components/comparisonTable/ComparisonTablePinnedContentRow.tsx new file mode 100644 index 0000000000..44e23a8e21 --- /dev/null +++ b/webview/src/plots/components/comparisonTable/ComparisonTablePinnedContentRow.tsx @@ -0,0 +1,19 @@ +import React, { PropsWithChildren } from 'react' +import cx from 'classnames' +import styles from './styles.module.scss' + +export const ComparisonTablePinnedContentRow: React.FC< + PropsWithChildren<{ pinnedColumn: string; nbColumns: number }> +> = ({ children, pinnedColumn, nbColumns }) => ( + + + {children} + + {nbColumns > 1 && } + +) diff --git a/webview/src/plots/components/comparisonTable/ComparisonTableRow.test.tsx b/webview/src/plots/components/comparisonTable/ComparisonTableRow.test.tsx index 6124afd579..97ad295ed4 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTableRow.test.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTableRow.test.tsx @@ -20,6 +20,7 @@ jest.mock('../../../shared/api') describe('ComparisonTableRow', () => { const basicProps: ComparisonTableRowProps = { + classDetails: {}, nbColumns: 3, onLayoutChange: jest.fn(), order: ['path/to/the-file/image.png'], diff --git a/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx b/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx index d74d73d4f8..9aef192302 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTableRow.tsx @@ -1,4 +1,7 @@ -import { ComparisonPlot } from 'dvc/src/plots/webview/contract' +import { + ComparisonClassDetails, + ComparisonPlot +} from 'dvc/src/plots/webview/contract' import React, { useState, useEffect, @@ -12,6 +15,7 @@ import { useSelector } from 'react-redux' import styles from './styles.module.scss' import { ComparisonTableCell } from './cell/ComparisonTableCell' import { ComparisonTableMultiCell } from './cell/ComparisonTableMultiCell' +import { ComparisonTablePinnedContentRow } from './ComparisonTablePinnedContentRow' import { RowDropTarget } from './RowDropTarget' import { Icon } from '../../../shared/components/Icon' import { ChevronDown, ChevronRight } from '../../../shared/components/icons' @@ -23,6 +27,7 @@ import Tooltip, { } from '../../../shared/components/tooltip/Tooltip' import { useDragAndDrop } from '../../../shared/hooks/useDragAndDrop' import { DragDropItemWithTarget } from '../../../shared/components/dragDrop/DragDropItemWithTarget' +import { toggleComparisonClass } from '../../util/messages' export interface ComparisonTableRowProps { path: string @@ -32,6 +37,7 @@ export interface ComparisonTableRowProps { onLayoutChange: () => void setOrder: (order: string[]) => void order: string[] + classDetails: ComparisonClassDetails bodyRef?: RefObject } @@ -43,6 +49,7 @@ export const ComparisonTableRow: React.FC = ({ onLayoutChange, setOrder, order, + classDetails, bodyRef }) => { const plotsRowRef = useRef(null) @@ -67,6 +74,10 @@ export const ComparisonTableRow: React.FC = ({ type: , vertical: true }) + const classDetailsArr = Object.entries(classDetails) + const cellClasses = cx(styles.cell, { + [styles.cellHidden]: !isShown + }) useLayoutEffect(() => { onLayoutChange?.() @@ -123,30 +134,59 @@ export const ComparisonTableRow: React.FC = ({ id={path} ref={bodyRef} > - - -
- - + {path} + + + +
+ + {classDetailsArr.length > 0 && ( + +
+

Classes

+ {classDetailsArr.map(([label, { color, selected }]) => ( + + + toggleComparisonClass(path, label, event.target.checked) + } + /> + + + ))}
- - {nbColumns > 1 && pinnedColumn && } - +
+ )} {plots.map(plot => ( = ({ isInDragAndDropMode && draggedId === plot.id })} > -
+
{plot.imgs.length > 1 ? ( - + ) : ( - + )}
diff --git a/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx b/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx index 382d151e5c..526123c001 100644 --- a/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx +++ b/webview/src/plots/components/comparisonTable/ComparisonTableRows.tsx @@ -50,6 +50,7 @@ export const ComparisonTableRows: React.FC = ({ id: column.id, imgs: revs[column.id]?.imgs }))} + classDetails={plot.classDetails} nbColumns={columns.length} pinnedColumn={pinnedColumn} onLayoutChange={onLayoutChange} diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableBoundingBoxColorFilter.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableBoundingBoxColorFilter.tsx new file mode 100644 index 0000000000..f12a07be09 --- /dev/null +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableBoundingBoxColorFilter.tsx @@ -0,0 +1,15 @@ +import React from 'react' + +export const ComparisonTableBoundingBoxColorFilter: React.FC<{ + color: string +}> = ({ color }) => { + return ( + + + + + + + + ) +} diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableBoundingBoxImg.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableBoundingBoxImg.tsx new file mode 100644 index 0000000000..b2f761d554 --- /dev/null +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableBoundingBoxImg.tsx @@ -0,0 +1,88 @@ +import React, { useEffect, useState } from 'react' +import { useSelector } from 'react-redux' +import { createSelector } from '@reduxjs/toolkit' +import { + ComparisonClassDetails, + ComparisonPlotClasses +} from 'dvc/src/plots/webview/contract' +import { ComparisonTableBoundingBoxColorFilter } from './ComparisonTableBoundingBoxColorFilter' +import styles from '../styles.module.scss' +import { PlotsState } from '../../../store' + +const plotClassesSelector = (state: PlotsState) => state.comparison.plotClasses +const classesSelector = createSelector( + [plotClassesSelector, (_, id: string) => id, (_, id, path: string) => path], + (plotClasses: ComparisonPlotClasses, id: string, path: string) => + plotClasses[id]?.[path] || [] +) + +export const ComparisonTableBoundingBoxImg: React.FC<{ + id: string + src: string + path: string + classDetails: ComparisonClassDetails + alt: string +}> = ({ alt, classDetails, id, src, path }) => { + const classes = useSelector((state: PlotsState) => + classesSelector(state, id, path) + ) + const [naturalWidth, setNaturalWidth] = useState(0) + const [naturalHeight, setNaturalHeight] = useState(0) + + useEffect(() => { + const img = new Image() + img.src = src + + img.addEventListener('load', () => { + setNaturalWidth(img.naturalWidth) + setNaturalHeight(img.naturalHeight) + }) + }, [src]) + + return ( + + {Object.entries(classDetails).map( + ([label, { color, selected }]) => + selected && ( + + ) + )} + + {classes.map(({ label, boxes }) => { + const labelColor = classDetails[label]?.color + + if (!labelColor) { + return + } + + return boxes.map(({ box: { bottom, top, right, left }, score }) => ( + + + {label} {score} + + + + )) + })} + + ) +} diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx index 53bee7782e..48404c6512 100644 --- a/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableCell.tsx @@ -1,19 +1,25 @@ import React from 'react' -import { ComparisonPlot } from 'dvc/src/plots/webview/contract' +import { + ComparisonClassDetails, + ComparisonPlot +} from 'dvc/src/plots/webview/contract' import { ComparisonTableLoadingCell } from './ComparisonTableLoadingCell' import { ComparisonTableMissingCell } from './ComparisonTableMissingCell' +import { ComparisonTableBoundingBoxImg } from './ComparisonTableBoundingBoxImg' import styles from '../styles.module.scss' import { zoomPlot } from '../../../util/messages' export const ComparisonTableCell: React.FC<{ path: string plot: ComparisonPlot + classDetails: ComparisonClassDetails imgAlt?: string -}> = ({ path, plot, imgAlt }) => { +}> = ({ path, plot, imgAlt, classDetails }) => { const plotImg = plot.imgs[0] const loading = plotImg.loading const missing = !loading && !plotImg.url + const alt = imgAlt || `Plot of ${path} (${plot.id})` if (loading) { return @@ -29,12 +35,22 @@ export const ComparisonTableCell: React.FC<{ onClick={() => zoomPlot(plotImg.url)} data-testid="image-plot-button" > - {imgAlt + {plotImg.url && Object.keys(classDetails).length > 0 ? ( + + ) : ( + {alt} + )} ) } diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx index 59ec411c7e..ecc402a209 100644 --- a/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx @@ -1,6 +1,7 @@ import React, { useEffect, useCallback, useRef, useState } from 'react' import { useDispatch, useSelector } from 'react-redux' import { + ComparisonClassDetails, ComparisonPlot, ComparisonPlotImg } from 'dvc/src/plots/webview/contract' @@ -13,7 +14,8 @@ import { PlotsState } from '../../../store' export const ComparisonTableMultiCell: React.FC<{ path: string plot: ComparisonPlot -}> = ({ path, plot }) => { + classDetails: ComparisonClassDetails +}> = ({ path, plot, classDetails }) => { const values = useSelector( (state: PlotsState) => state.comparison.multiPlotValues ) @@ -58,6 +60,7 @@ export const ComparisonTableMultiCell: React.FC<{ imgs: [selectedImg] }} imgAlt={`${selectedImg.ind} of ${path} (${plot.id})`} + classDetails={classDetails} />
{ + sendMessage({ + payload: { label, path, selected }, + type: MessageFromWebviewType.TOGGLE_COMPARISON_CLASS + }) +} + export const togglePlotsSection = ( sectionKey: PlotsSection, sectionCollapsed: boolean diff --git a/webview/src/stories/ComparisonTable.stories.tsx b/webview/src/stories/ComparisonTable.stories.tsx index cfddfb4243..8b4e4859be 100644 --- a/webview/src/stories/ComparisonTable.stories.tsx +++ b/webview/src/stories/ComparisonTable.stories.tsx @@ -37,7 +37,7 @@ export default { title: 'Comparison Table' } as Meta -const Template: StoryFn = ({ plots, revisions }) => { +const Template: StoryFn = ({ plots, revisions, plotClasses = {} }) => { const store = configureStore({ reducer: plotsReducers }) @@ -48,6 +48,7 @@ const Template: StoryFn = ({ plots, revisions }) => { data={{ height: DEFAULT_PLOT_HEIGHT, multiPlotValues: {}, + plotClasses, plots, revisions, width: DEFAULT_NB_ITEMS_PER_ROW @@ -117,10 +118,13 @@ const removeImages = ( export const WithMissingData = Template.bind({}) WithMissingData.args = { - plots: comparisonTableFixture.plots.map(({ path, revisions }) => ({ - path, - revisions: removeImages(path, revisions) - })), + plots: comparisonTableFixture.plots.map( + ({ classDetails, path, revisions }) => ({ + classDetails, + path, + revisions: removeImages(path, revisions) + }) + ), revisions: comparisonTableFixture.revisions.map(revision => { if (revision.id === EXPERIMENT_WORKSPACE_ID) { return { ...revision, fetched: false } @@ -131,10 +135,13 @@ WithMissingData.args = { export const WithOnlyMissingData = Template.bind({}) WithOnlyMissingData.args = { - plots: comparisonTableFixture.plots.map(({ path, revisions }) => ({ - path, - revisions: removeImages(path, revisions) - })), + plots: comparisonTableFixture.plots.map( + ({ classDetails, path, revisions }) => ({ + classDetails, + path, + revisions: removeImages(path, revisions) + }) + ), revisions: comparisonTableFixture.revisions .map(revision => { if (revision.id === EXPERIMENT_WORKSPACE_ID) { diff --git a/webview/src/test/tableDataFixture.ts b/webview/src/test/tableDataFixture.ts index 3ff1eedcd8..0568548d55 100644 --- a/webview/src/test/tableDataFixture.ts +++ b/webview/src/test/tableDataFixture.ts @@ -1,4 +1,4 @@ -import { copyOriginalColors } from 'dvc/src/experiments/model/status/colors' +import { copyOriginalColors } from 'dvc/src/common/colors' import { Commit } from 'dvc/src/experiments/webview/contract' import { TableDataState } from '../experiments/state/tableDataSlice'