Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions packages/app/cypress/e2e/line-labels.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,31 @@ describe('Line Labels Toggle', () => {
// Labels should be rendered
cy.get('[data-testid="scatter-graph"] svg g.line-label').should('have.length.greaterThan', 0);
});

it('appends the precision to each line label when multiple precisions are selected', () => {
cy.visit('/inference?i_linelabel=1&i_prec=fp4,fp8', {
onBeforeLoad(win) {
win.localStorage.setItem('inferencex-star-modal-dismissed', String(Date.now()));
},
});
cy.get('[data-testid="scatter-graph"]').should('be.visible');

// With both FP4 and FP8 shown, each curve is its own line and the label
// must carry the precision so the two curves of the same hardware are
// distinguishable (e.g. "B200 (vLLM) FP8" vs "B200 (vLLM) FP4").
cy.get('[data-testid="scatter-graph"] svg g.line-label .ll-text')
.should('have.length.greaterThan', 0)
.then(($texts) => {
const labels = $texts.toArray().map((el) => el.textContent ?? '');
// At least one label for each selected precision.
expect(
labels.some((t) => /\bFP8\b/u.test(t)),
'an FP8 line label exists',
).to.equal(true);
expect(
labels.some((t) => /\bFP4\b/u.test(t)),
'an FP4 line label exists',
).to.equal(true);
});
});
});
104 changes: 70 additions & 34 deletions packages/app/src/components/inference/ui/ScatterGraph.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ChartLegend from '@/components/ui/chart-legend';
import { useUnofficialRun } from '@/components/unofficial-run-provider';
import { computeToggle } from '@/hooks/useTogglableSet';
import { getHardwareConfig, getModelSortIndex } from '@/lib/constants';
import { getChartWatermark } from '@/lib/data-mappings';
import { getChartWatermark, getPrecisionLabel, type Precision } from '@/lib/data-mappings';
import { formatNumber, getDisplayLabel, updateRepoUrl } from '@/lib/utils';
import { D3Chart } from '@/lib/d3-chart/D3Chart';
import type {
Expand Down Expand Up @@ -101,6 +101,14 @@ const parseHwKeyToLabel = (hwKey: string): { name: string; label: string } => {
return { name: config.label, label: getDisplayLabel(config) };
};

// Line-label text for a curve. When more than one precision is shown, each curve
// is its own line, so append the precision (e.g. "B200 (vLLM) FP8") to keep the
// FP4 and FP8 curves of the same hardware distinguishable.
const lineLabelText = (hwKey: string, precision: string, includePrecision: boolean): string => {
const base = parseHwKeyToLabel(hwKey).label;
return includePrecision ? `${base} ${getPrecisionLabel(precision as Precision)}` : base;
};

const ScatterGraph = React.memo(
({
chartId,
Expand Down Expand Up @@ -914,6 +922,9 @@ const ScatterGraph = React.memo(

if (showLineLabels) {
const isInteractivity = chartDefinition.chartType === 'interactivity';
// With >1 precision selected each precision is its own curve, so label
// every curve and include the precision in the text.
const multiPrecision = selectedPrecisions.length > 1;
const LABEL_H = 18;
const LABEL_W = 120; // approximate label width for overlap check

Expand All @@ -924,16 +935,19 @@ const ScatterGraph = React.memo(
const collides = (cx: number, cy: number) =>
placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W);

// Deduplicate by hw key — pick the roofline with most points per hw
const bestByHw = new Map<string, (typeof entries)[0]>();
// Deduplicate by group key — one label per curve. With a single
// precision that's one per hw; with multiple it's one per (hw,
// precision) so each precision curve keeps its own label.
const bestByGroup = new Map<string, (typeof entries)[0]>();
for (const e of entries) {
if (!e.visible || e.points.length < 2) continue;
const prev = bestByHw.get(e.hw);
if (!prev || e.points.length > prev.points.length) bestByHw.set(e.hw, e);
const groupKey = multiPrecision ? e.key : e.hw;
const prev = bestByGroup.get(groupKey);
if (!prev || e.points.length > prev.points.length) bestByGroup.set(groupKey, e);
}

// Sort entries by highest y-value first (top of chart) for priority
const sorted = [...bestByHw.values()].toSorted((a, b) => {
const sorted = [...bestByGroup.values()].toSorted((a, b) => {
const ay = yScale(a.points[0].y);
const by = yScale(b.points[0].y);
return ay - by; // smaller pixel y = higher on chart
Expand All @@ -948,7 +962,7 @@ const ScatterGraph = React.memo(
pts.at(-1)!, // endpoint
];

const { label } = parseHwKeyToLabel(entry.hw);
const label = lineLabelText(entry.hw, entry.precision, multiPrecision);
let foundPlacement = false;
for (const pt of candidates) {
const px = xScale(pt.x);
Expand Down Expand Up @@ -983,33 +997,40 @@ const ScatterGraph = React.memo(
}
}

// Also add hidden entries for non-visible hw (so D3 data-join is clean)
const labeledHw = new Set(lineLabels.map((l) => l.hw));
// Also add hidden entries for any curve that wasn't placed (so the
// D3 data-join, keyed by series key, is clean).
const labeledKeys = new Set(lineLabels.map((l) => l.key));
for (const entry of entries) {
if (entry.points.length >= 2 && !labeledHw.has(entry.hw)) {
const { label } = parseHwKeyToLabel(entry.hw);
if (entry.points.length >= 2 && !labeledKeys.has(entry.key)) {
lineLabels.push({
key: entry.key,
hw: entry.hw,
label,
label: lineLabelText(entry.hw, entry.precision, multiPrecision),
color: getCssColor(resolveColor(entry.hw)),
x: xScale(entry.points[0].x),
y: yScale(entry.points[0].y),
visible: false,
});
labeledHw.add(entry.hw);
labeledKeys.add(entry.key);
}
}

// Overlay (unofficial run) rooflines also get line labels using the
// run-palette color so they match the legend swatches. The label
// text mirrors the overlay legend ("✕ <branch>" — falls back to the
// hw label if run metadata isn't available, e.g. legacy callers).
const overlayLabelText = (runIndex: number, hwKey: string): string => {
const overlayLabelText = (
runIndex: number,
hwKey: string,
precision: string,
): string => {
const info = unofficialRunInfos[runIndex];
if (!info) return parseHwKeyToLabel(hwKey).label;
const branch = info.branch || `run ${info.id}`;
return `✕ ${branch}`;
const base = info
? `✕ ${info.branch || `run ${info.id}`}`
: parseHwKeyToLabel(hwKey).label;
return multiPrecision
? `${base} ${getPrecisionLabel(precision as Precision)}`
: base;
};
const sortedOverlay = Object.entries(overlayRooflines)
.filter(
Expand All @@ -1026,7 +1047,11 @@ const ScatterGraph = React.memo(
pts[Math.max(0, Math.floor((pts.length * 2) / 3))],
pts.at(-1)!,
];
const label = overlayLabelText(group.runIndex, group.hwKey);
const label = overlayLabelText(
group.runIndex,
group.hwKey,
group.points[0]?.precision ?? '',
);
let placedOverlay = false;
for (const pt of candidates) {
const px = xScale(pt.x);
Expand Down Expand Up @@ -1060,31 +1085,36 @@ const ScatterGraph = React.memo(
}
}
} else {
// TTFT / E2EL: endpoint labels, one per hw key
const seenHw = new Set<string>();
// TTFT / E2EL: endpoint labels, one per curve (per hw, or per
// (hw, precision) when multiple precisions are shown).
const seen = new Set<string>();
for (const entry of entries) {
if (entry.points.length < 2 || seenHw.has(entry.hw)) continue;
seenHw.add(entry.hw);
if (entry.points.length < 2 || !entry.visible) continue;
const groupKey = multiPrecision ? entry.key : entry.hw;
if (seen.has(groupKey)) continue;
seen.add(groupKey);
const pt = entry.points.at(-1)!;
const { label } = parseHwKeyToLabel(entry.hw);
lineLabels.push({
key: entry.key,
hw: entry.hw,
label,
label: lineLabelText(entry.hw, entry.precision, multiPrecision),
color: getCssColor(resolveColor(entry.hw)),
x: xScale(pt.x),
y: yScale(pt.y),
visible: entry.visible,
visible: true,
});
}
// Endpoint labels for overlay rooflines too (one per (hw, runIndex)),
// labeled with the run's branch name to mirror the overlay legend.
for (const [ovKey, group] of Object.entries(overlayRooflines)) {
if (group.points.length < 2 || !activeOverlayHwTypes.has(group.hwKey)) continue;
const info = unofficialRunInfos[group.runIndex];
const labelText = info
const branchOrHw = info
? `✕ ${info.branch || `run ${info.id}`}`
: parseHwKeyToLabel(group.hwKey).label;
const labelText = multiPrecision
? `${branchOrHw} ${getPrecisionLabel((group.points[0]?.precision ?? '') as Precision)}`
: branchOrHw;
const labelKey = `overlay-${ovKey}`;
const pt = group.points.at(-1)!;
lineLabels.push({
Expand Down Expand Up @@ -1236,6 +1266,7 @@ const ScatterGraph = React.memo(
// Update line label positions on zoom
if (showLineLabels) {
const isInteractivity = chartDefinition.chartType === 'interactivity';
const multiPrecision = selectedPrecisions.length > 1;
const LABEL_H = 18;
const LABEL_W = 120;

Expand All @@ -1245,17 +1276,19 @@ const ScatterGraph = React.memo(
const collides = (cx: number, cy: number) =>
placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W);

// Deduplicate by hw key — pick roofline with most points per hw
const bestByHw = new Map<string, [string, InferenceData[]]>();
// Deduplicate by group key — one curve per hw, or per (hw, precision)
// when multiple precisions are shown (mirrors the static render).
const bestByGroup = new Map<string, [string, InferenceData[]]>();
for (const [key, pts] of Object.entries(rooflines)) {
if (pts.length < 2) continue;
const hw = key.split('_').slice(0, -1).join('_');
const prec = key.split('_').pop()!;
if (!effectiveActiveHwTypes.has(hw) || !selectedPrecisions.includes(prec)) continue;
const prev = bestByHw.get(hw);
if (!prev || pts.length > prev[1].length) bestByHw.set(hw, [key, pts]);
const groupKey = multiPrecision ? key : hw;
const prev = bestByGroup.get(groupKey);
if (!prev || pts.length > prev[1].length) bestByGroup.set(groupKey, [key, pts]);
}
const visibleEntries = [...bestByHw.values()].toSorted(
const visibleEntries = [...bestByGroup.values()].toSorted(
([, a], [, b]) => newYScale(a[0].y) - newYScale(b[0].y),
);

Expand Down Expand Up @@ -1343,12 +1376,15 @@ const ScatterGraph = React.memo(
y: number;
}
const zoomLabels: ZoomLabel[] = [];
const seenHw = new Set<string>();
const seen = new Set<string>();
Object.entries(rooflines).forEach(([key, pts]) => {
if (pts.length < 2) return;
const hw = key.split('_').slice(0, -1).join('_');
if (seenHw.has(hw)) return;
seenHw.add(hw);
const prec = key.split('_').pop()!;
if (!effectiveActiveHwTypes.has(hw) || !selectedPrecisions.includes(prec)) return;
const groupKey = multiPrecision ? key : hw;
if (seen.has(groupKey)) return;
seen.add(groupKey);
const pt = pts.at(-1)!;
zoomLabels.push({ key, x: newXScale(pt.x), y: newYScale(pt.y) });
});
Expand Down
Loading