From 22764d57b245b4f7120ed76432511c30a6328e35 Mon Sep 17 00:00:00 2001 From: functionstackx <47992694+functionstackx@users.noreply.github.com> Date: Sat, 6 Jun 2026 20:28:36 -0400 Subject: [PATCH] fix(inference): per-precision line labels when multiple precisions selected MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When more than one precision is shown, each precision is its own curve, but line labels were deduplicated by hardware key — so only one of the two curves for a given hardware got a label, and that label omitted the precision. Now, when >1 precision is selected, every curve gets its own line label and the text includes the precision (e.g. "B200 (vLLM) FP8" vs "B200 (vLLM) FP4"). With a single precision selected, behavior is unchanged (one label per hardware, no precision suffix). Applies to both the interactivity (greedy placement) and TTFT/E2EL (endpoint) chart types, in the static render and the zoom re-placement paths. Overlay (unofficial run) line labels also gain the precision suffix so the two precision curves of an overlay stay distinguishable. Co-Authored-By: Claude Opus 4.8 (1M context) --- packages/app/cypress/e2e/line-labels.cy.ts | 27 +++++ .../components/inference/ui/ScatterGraph.tsx | 104 ++++++++++++------ 2 files changed, 97 insertions(+), 34 deletions(-) diff --git a/packages/app/cypress/e2e/line-labels.cy.ts b/packages/app/cypress/e2e/line-labels.cy.ts index 011bbda3..a3fa4b68 100644 --- a/packages/app/cypress/e2e/line-labels.cy.ts +++ b/packages/app/cypress/e2e/line-labels.cy.ts @@ -78,4 +78,31 @@ describe('Line Labels Toggle', () => { // Labels should be rendered cy.get('[data-testid="scatter-graph"] svg g.line-label').should('have.length.greaterThan', 0); }); + + it('appends the precision to each line label when multiple precisions are selected', () => { + cy.visit('/inference?i_linelabel=1&i_prec=fp4,fp8', { + onBeforeLoad(win) { + win.localStorage.setItem('inferencex-star-modal-dismissed', String(Date.now())); + }, + }); + cy.get('[data-testid="scatter-graph"]').should('be.visible'); + + // With both FP4 and FP8 shown, each curve is its own line and the label + // must carry the precision so the two curves of the same hardware are + // distinguishable (e.g. "B200 (vLLM) FP8" vs "B200 (vLLM) FP4"). + cy.get('[data-testid="scatter-graph"] svg g.line-label .ll-text') + .should('have.length.greaterThan', 0) + .then(($texts) => { + const labels = $texts.toArray().map((el) => el.textContent ?? ''); + // At least one label for each selected precision. + expect( + labels.some((t) => /\bFP8\b/u.test(t)), + 'an FP8 line label exists', + ).to.equal(true); + expect( + labels.some((t) => /\bFP4\b/u.test(t)), + 'an FP4 line label exists', + ).to.equal(true); + }); + }); }); diff --git a/packages/app/src/components/inference/ui/ScatterGraph.tsx b/packages/app/src/components/inference/ui/ScatterGraph.tsx index 5109047d..56e0088e 100644 --- a/packages/app/src/components/inference/ui/ScatterGraph.tsx +++ b/packages/app/src/components/inference/ui/ScatterGraph.tsx @@ -10,7 +10,7 @@ import ChartLegend from '@/components/ui/chart-legend'; import { useUnofficialRun } from '@/components/unofficial-run-provider'; import { computeToggle } from '@/hooks/useTogglableSet'; import { getHardwareConfig, getModelSortIndex } from '@/lib/constants'; -import { getChartWatermark } from '@/lib/data-mappings'; +import { getChartWatermark, getPrecisionLabel, type Precision } from '@/lib/data-mappings'; import { formatNumber, getDisplayLabel, updateRepoUrl } from '@/lib/utils'; import { D3Chart } from '@/lib/d3-chart/D3Chart'; import type { @@ -101,6 +101,14 @@ const parseHwKeyToLabel = (hwKey: string): { name: string; label: string } => { return { name: config.label, label: getDisplayLabel(config) }; }; +// Line-label text for a curve. When more than one precision is shown, each curve +// is its own line, so append the precision (e.g. "B200 (vLLM) FP8") to keep the +// FP4 and FP8 curves of the same hardware distinguishable. +const lineLabelText = (hwKey: string, precision: string, includePrecision: boolean): string => { + const base = parseHwKeyToLabel(hwKey).label; + return includePrecision ? `${base} ${getPrecisionLabel(precision as Precision)}` : base; +}; + const ScatterGraph = React.memo( ({ chartId, @@ -914,6 +922,9 @@ const ScatterGraph = React.memo( if (showLineLabels) { const isInteractivity = chartDefinition.chartType === 'interactivity'; + // With >1 precision selected each precision is its own curve, so label + // every curve and include the precision in the text. + const multiPrecision = selectedPrecisions.length > 1; const LABEL_H = 18; const LABEL_W = 120; // approximate label width for overlap check @@ -924,16 +935,19 @@ const ScatterGraph = React.memo( const collides = (cx: number, cy: number) => placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W); - // Deduplicate by hw key — pick the roofline with most points per hw - const bestByHw = new Map(); + // Deduplicate by group key — one label per curve. With a single + // precision that's one per hw; with multiple it's one per (hw, + // precision) so each precision curve keeps its own label. + const bestByGroup = new Map(); for (const e of entries) { if (!e.visible || e.points.length < 2) continue; - const prev = bestByHw.get(e.hw); - if (!prev || e.points.length > prev.points.length) bestByHw.set(e.hw, e); + const groupKey = multiPrecision ? e.key : e.hw; + const prev = bestByGroup.get(groupKey); + if (!prev || e.points.length > prev.points.length) bestByGroup.set(groupKey, e); } // Sort entries by highest y-value first (top of chart) for priority - const sorted = [...bestByHw.values()].toSorted((a, b) => { + const sorted = [...bestByGroup.values()].toSorted((a, b) => { const ay = yScale(a.points[0].y); const by = yScale(b.points[0].y); return ay - by; // smaller pixel y = higher on chart @@ -948,7 +962,7 @@ const ScatterGraph = React.memo( pts.at(-1)!, // endpoint ]; - const { label } = parseHwKeyToLabel(entry.hw); + const label = lineLabelText(entry.hw, entry.precision, multiPrecision); let foundPlacement = false; for (const pt of candidates) { const px = xScale(pt.x); @@ -983,21 +997,21 @@ const ScatterGraph = React.memo( } } - // Also add hidden entries for non-visible hw (so D3 data-join is clean) - const labeledHw = new Set(lineLabels.map((l) => l.hw)); + // Also add hidden entries for any curve that wasn't placed (so the + // D3 data-join, keyed by series key, is clean). + const labeledKeys = new Set(lineLabels.map((l) => l.key)); for (const entry of entries) { - if (entry.points.length >= 2 && !labeledHw.has(entry.hw)) { - const { label } = parseHwKeyToLabel(entry.hw); + if (entry.points.length >= 2 && !labeledKeys.has(entry.key)) { lineLabels.push({ key: entry.key, hw: entry.hw, - label, + label: lineLabelText(entry.hw, entry.precision, multiPrecision), color: getCssColor(resolveColor(entry.hw)), x: xScale(entry.points[0].x), y: yScale(entry.points[0].y), visible: false, }); - labeledHw.add(entry.hw); + labeledKeys.add(entry.key); } } @@ -1005,11 +1019,18 @@ const ScatterGraph = React.memo( // run-palette color so they match the legend swatches. The label // text mirrors the overlay legend ("✕ " — falls back to the // hw label if run metadata isn't available, e.g. legacy callers). - const overlayLabelText = (runIndex: number, hwKey: string): string => { + const overlayLabelText = ( + runIndex: number, + hwKey: string, + precision: string, + ): string => { const info = unofficialRunInfos[runIndex]; - if (!info) return parseHwKeyToLabel(hwKey).label; - const branch = info.branch || `run ${info.id}`; - return `✕ ${branch}`; + const base = info + ? `✕ ${info.branch || `run ${info.id}`}` + : parseHwKeyToLabel(hwKey).label; + return multiPrecision + ? `${base} ${getPrecisionLabel(precision as Precision)}` + : base; }; const sortedOverlay = Object.entries(overlayRooflines) .filter( @@ -1026,7 +1047,11 @@ const ScatterGraph = React.memo( pts[Math.max(0, Math.floor((pts.length * 2) / 3))], pts.at(-1)!, ]; - const label = overlayLabelText(group.runIndex, group.hwKey); + const label = overlayLabelText( + group.runIndex, + group.hwKey, + group.points[0]?.precision ?? '', + ); let placedOverlay = false; for (const pt of candidates) { const px = xScale(pt.x); @@ -1060,21 +1085,23 @@ const ScatterGraph = React.memo( } } } else { - // TTFT / E2EL: endpoint labels, one per hw key - const seenHw = new Set(); + // TTFT / E2EL: endpoint labels, one per curve (per hw, or per + // (hw, precision) when multiple precisions are shown). + const seen = new Set(); for (const entry of entries) { - if (entry.points.length < 2 || seenHw.has(entry.hw)) continue; - seenHw.add(entry.hw); + if (entry.points.length < 2 || !entry.visible) continue; + const groupKey = multiPrecision ? entry.key : entry.hw; + if (seen.has(groupKey)) continue; + seen.add(groupKey); const pt = entry.points.at(-1)!; - const { label } = parseHwKeyToLabel(entry.hw); lineLabels.push({ key: entry.key, hw: entry.hw, - label, + label: lineLabelText(entry.hw, entry.precision, multiPrecision), color: getCssColor(resolveColor(entry.hw)), x: xScale(pt.x), y: yScale(pt.y), - visible: entry.visible, + visible: true, }); } // Endpoint labels for overlay rooflines too (one per (hw, runIndex)), @@ -1082,9 +1109,12 @@ const ScatterGraph = React.memo( for (const [ovKey, group] of Object.entries(overlayRooflines)) { if (group.points.length < 2 || !activeOverlayHwTypes.has(group.hwKey)) continue; const info = unofficialRunInfos[group.runIndex]; - const labelText = info + const branchOrHw = info ? `✕ ${info.branch || `run ${info.id}`}` : parseHwKeyToLabel(group.hwKey).label; + const labelText = multiPrecision + ? `${branchOrHw} ${getPrecisionLabel((group.points[0]?.precision ?? '') as Precision)}` + : branchOrHw; const labelKey = `overlay-${ovKey}`; const pt = group.points.at(-1)!; lineLabels.push({ @@ -1236,6 +1266,7 @@ const ScatterGraph = React.memo( // Update line label positions on zoom if (showLineLabels) { const isInteractivity = chartDefinition.chartType === 'interactivity'; + const multiPrecision = selectedPrecisions.length > 1; const LABEL_H = 18; const LABEL_W = 120; @@ -1245,17 +1276,19 @@ const ScatterGraph = React.memo( const collides = (cx: number, cy: number) => placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W); - // Deduplicate by hw key — pick roofline with most points per hw - const bestByHw = new Map(); + // Deduplicate by group key — one curve per hw, or per (hw, precision) + // when multiple precisions are shown (mirrors the static render). + const bestByGroup = new Map(); for (const [key, pts] of Object.entries(rooflines)) { if (pts.length < 2) continue; const hw = key.split('_').slice(0, -1).join('_'); const prec = key.split('_').pop()!; if (!effectiveActiveHwTypes.has(hw) || !selectedPrecisions.includes(prec)) continue; - const prev = bestByHw.get(hw); - if (!prev || pts.length > prev[1].length) bestByHw.set(hw, [key, pts]); + const groupKey = multiPrecision ? key : hw; + const prev = bestByGroup.get(groupKey); + if (!prev || pts.length > prev[1].length) bestByGroup.set(groupKey, [key, pts]); } - const visibleEntries = [...bestByHw.values()].toSorted( + const visibleEntries = [...bestByGroup.values()].toSorted( ([, a], [, b]) => newYScale(a[0].y) - newYScale(b[0].y), ); @@ -1343,12 +1376,15 @@ const ScatterGraph = React.memo( y: number; } const zoomLabels: ZoomLabel[] = []; - const seenHw = new Set(); + const seen = new Set(); Object.entries(rooflines).forEach(([key, pts]) => { if (pts.length < 2) return; const hw = key.split('_').slice(0, -1).join('_'); - if (seenHw.has(hw)) return; - seenHw.add(hw); + const prec = key.split('_').pop()!; + if (!effectiveActiveHwTypes.has(hw) || !selectedPrecisions.includes(prec)) return; + const groupKey = multiPrecision ? key : hw; + if (seen.has(groupKey)) return; + seen.add(groupKey); const pt = pts.at(-1)!; zoomLabels.push({ key, x: newXScale(pt.x), y: newYScale(pt.y) }); });