diff --git a/lib/DuplicateCongestedPortSolver.ts b/lib/DuplicateCongestedPortSolver.ts new file mode 100644 index 0000000..e9f0f08 --- /dev/null +++ b/lib/DuplicateCongestedPortSolver.ts @@ -0,0 +1,512 @@ +import type { SerializedHyperGraph } from "@tscircuit/hypergraph" +import { BaseSolver } from "@tscircuit/solver-utils" +import { loadSerializedHyperGraph } from "./compat/loadSerializedHyperGraph" +import { + TinyHyperGraphSolver, + type TinyHyperGraphProblem, + type TinyHyperGraphSolverOptions, + type TinyHyperGraphTopology, +} from "./core" +import type { PortId, RouteId } from "./types" + +type SerializedPort = SerializedHyperGraph["ports"][number] +type SerializedRegion = SerializedHyperGraph["regions"][number] + +export const DUPLICATE_PORT_PROXIMITY = 0.05 + +export interface DuplicateCongestedPortSolverOptions { + duplicatePortProximity?: number + routeSolveOptions?: TinyHyperGraphSolverOptions +} + +export interface DuplicatedPortSummary { + sourcePortId: string + duplicatePortIds: string[] + useCount: number +} + +export interface DuplicateCongestedPortSolverReport { + portUseCounts: Record + duplicatedPorts: DuplicatedPortSummary[] +} + +interface Point { + x: number + y: number +} + +const EPSILON = 1e-9 + +const isRecord = (value: unknown): value is Record => + typeof value === "object" && value !== null + +const cloneSerializableValue = (value: T): T => { + if (Array.isArray(value)) { + return value.map((item) => cloneSerializableValue(item)) as T + } + + if (isRecord(value)) { + return Object.fromEntries( + Object.entries(value).map(([key, item]) => [ + key, + cloneSerializableValue(item), + ]), + ) as T + } + + return value +} + +const toObjectRecord = (value: unknown): Record => { + if (isRecord(value)) return { ...value } + if (value === undefined) return {} + return { value } +} + +const getNumber = (value: unknown, fallback = 0) => + typeof value === "number" && Number.isFinite(value) ? value : fallback + +const getPortPoint = (port: SerializedPort): Point => ({ + x: getNumber(port.d?.x), + y: getNumber(port.d?.y), +}) + +const getBoundaryKey = ( + port: Pick, +) => [port.region1Id, port.region2Id].sort().join("\u0000") + +const getDistance = (a: Point, b: Point) => Math.hypot(a.x - b.x, a.y - b.y) + +const normalize = (point: Point): Point | undefined => { + const length = Math.hypot(point.x, point.y) + if (length <= EPSILON) return undefined + return { + x: point.x / length, + y: point.y / length, + } +} + +const getRegionBounds = ( + region: SerializedRegion | undefined, +): { + minX: number + maxX: number + minY: number + maxY: number +} => { + const bounds = region?.d?.bounds + if ( + isRecord(bounds) && + typeof bounds.minX === "number" && + typeof bounds.maxX === "number" && + typeof bounds.minY === "number" && + typeof bounds.maxY === "number" + ) { + return { + minX: bounds.minX, + maxX: bounds.maxX, + minY: bounds.minY, + maxY: bounds.maxY, + } + } + + const center = region?.d?.center + const width = getNumber(region?.d?.width) + const height = getNumber(region?.d?.height) + if (isRecord(center)) { + const x = getNumber(center.x) + const y = getNumber(center.y) + return { + minX: x - width / 2, + maxX: x + width / 2, + minY: y - height / 2, + maxY: y + height / 2, + } + } + + return { + minX: 0, + maxX: 0, + minY: 0, + maxY: 0, + } +} + +const getRegionCenter = (region: SerializedRegion | undefined): Point => { + const center = region?.d?.center + if (isRecord(center)) { + return { + x: getNumber(center.x), + y: getNumber(center.y), + } + } + + const bounds = getRegionBounds(region) + return { + x: (bounds.minX + bounds.maxX) / 2, + y: (bounds.minY + bounds.maxY) / 2, + } +} + +const findNearestPortOnSameBoundary = ( + sourcePort: SerializedPort, + ports: SerializedPort[], +): SerializedPort | undefined => { + const sourceBoundaryKey = getBoundaryKey(sourcePort) + const sourcePoint = getPortPoint(sourcePort) + let nearestPort: SerializedPort | undefined + let nearestDistance = Number.POSITIVE_INFINITY + + for (const port of ports) { + if (port.portId === sourcePort.portId) continue + if (getBoundaryKey(port) !== sourceBoundaryKey) continue + + const distance = getDistance(sourcePoint, getPortPoint(port)) + if (distance <= EPSILON || distance >= nearestDistance) continue + + nearestPort = port + nearestDistance = distance + } + + return nearestPort +} + +const getFallbackBoundaryDirection = ( + sourcePort: SerializedPort, + regionById: Map, +): Point => { + const region1Center = getRegionCenter(regionById.get(sourcePort.region1Id)) + const region2Center = getRegionCenter(regionById.get(sourcePort.region2Id)) + const perpendicular = normalize({ + x: -(region2Center.y - region1Center.y), + y: region2Center.x - region1Center.x, + }) + + return perpendicular ?? { x: 1, y: 0 } +} + +const getDuplicateDirection = ( + sourcePort: SerializedPort, + nearestBoundaryPort: SerializedPort | undefined, + regionById: Map, +): Point => { + const sourcePoint = getPortPoint(sourcePort) + + if (nearestBoundaryPort) { + const nearestPoint = getPortPoint(nearestBoundaryPort) + const awayFromNearest = normalize({ + x: sourcePoint.x - nearestPoint.x, + y: sourcePoint.y - nearestPoint.y, + }) + if (awayFromNearest) return awayFromNearest + } + + return getFallbackBoundaryDirection(sourcePort, regionById) +} + +const createDuplicatePortId = ( + sourcePortId: string, + duplicateIndex: number, + usedPortIds: Set, +): string => { + const basePortId = `${sourcePortId}::dup${duplicateIndex}` + if (!usedPortIds.has(basePortId)) { + usedPortIds.add(basePortId) + return basePortId + } + + for (let collisionIndex = 2; ; collisionIndex++) { + const portId = `${basePortId}-${collisionIndex}` + if (!usedPortIds.has(portId)) { + usedPortIds.add(portId) + return portId + } + } +} + +const insertDuplicatePortIdsAfterSource = ( + pointIds: string[], + sourcePortId: string, + duplicatePortIds: string[], +) => { + const insertionIndex = pointIds.indexOf(sourcePortId) + if (insertionIndex === -1) { + pointIds.push(...duplicatePortIds) + return + } + + pointIds.splice(insertionIndex + 1, 0, ...duplicatePortIds) +} + +const getSerializedPortId = ( + topology: TinyHyperGraphTopology, + portId: PortId, +): string => { + const metadata = topology.portMetadata?.[portId] + if (isRecord(metadata)) { + if (typeof metadata.serializedPortId === "string") { + return metadata.serializedPortId + } + + if (typeof metadata.portId === "string") { + return metadata.portId + } + } + + return `port-${portId}` +} + +const createSingleRouteProblem = ( + problem: TinyHyperGraphProblem, + routeId: RouteId, +): TinyHyperGraphProblem => ({ + routeCount: 1, + portSectionMask: new Int8Array(problem.portSectionMask), + routeMetadata: + problem.routeMetadata === undefined + ? undefined + : [problem.routeMetadata[routeId]], + routeStartPort: Int32Array.from([problem.routeStartPort[routeId]]), + routeEndPort: Int32Array.from([problem.routeEndPort[routeId]]), + routeNet: Int32Array.from([problem.routeNet[routeId]]), + regionNetId: new Int32Array(problem.regionNetId), + portPenalty: + problem.portPenalty === undefined + ? undefined + : new Float64Array(problem.portPenalty), +}) + +const getUsedPortIdsForSolvedRoute = ( + solver: TinyHyperGraphSolver, +): Set => { + const usedPortIds = new Set() + + for (const regionSegments of solver.state.regionSegments) { + for (const [, fromPortId, toPortId] of regionSegments) { + usedPortIds.add(fromPortId) + usedPortIds.add(toPortId) + } + } + + if (usedPortIds.size === 0 && solver.problem.routeCount === 1) { + usedPortIds.add(solver.problem.routeStartPort[0]) + usedPortIds.add(solver.problem.routeEndPort[0]) + } + + return usedPortIds +} + +export class DuplicateCongestedPortSolver extends BaseSolver { + revisedSerializedHyperGraph?: SerializedHyperGraph + report: DuplicateCongestedPortSolverReport = { + portUseCounts: {}, + duplicatedPorts: [], + } + + constructor( + public serializedHyperGraph: SerializedHyperGraph, + public options: DuplicateCongestedPortSolverOptions = {}, + ) { + super() + } + + private getDuplicatePortProximity() { + return this.options.duplicatePortProximity ?? DUPLICATE_PORT_PROXIMITY + } + + private getIndividualRouteSolveOptions(): TinyHyperGraphSolverOptions { + return { + RIP_THRESHOLD_RAMP_ATTEMPTS: 0, + STATIC_REACHABILITY_PRECHECK: false, + ...this.options.routeSolveOptions, + } + } + + private getPortUseCounts(): Map { + const { topology, problem } = loadSerializedHyperGraph( + this.serializedHyperGraph, + ) + const portUseCounts = new Map() + + for (let routeId = 0; routeId < problem.routeCount; routeId++) { + const routeProblem = createSingleRouteProblem(problem, routeId) + const routeSolver = new TinyHyperGraphSolver( + topology, + routeProblem, + this.getIndividualRouteSolveOptions(), + ) + routeSolver.solve() + + if (!routeSolver.solved || routeSolver.failed) { + throw new Error( + `Route ${routeId} could not be solved independently: ${ + routeSolver.error ?? "unknown error" + }`, + ) + } + + for (const portId of getUsedPortIdsForSolvedRoute(routeSolver)) { + const serializedPortId = getSerializedPortId(topology, portId) + portUseCounts.set( + serializedPortId, + (portUseCounts.get(serializedPortId) ?? 0) + 1, + ) + } + } + + return portUseCounts + } + + private duplicateCongestedPorts( + portUseCounts: Map, + ): SerializedHyperGraph { + const duplicatePortProximity = this.getDuplicatePortProximity() + if (!(duplicatePortProximity > 0)) { + throw new Error("duplicatePortProximity must be greater than zero") + } + + const { solvedRoutes: _solvedRoutes, ...restHyperGraph } = + this.serializedHyperGraph + const regions: SerializedRegion[] = this.serializedHyperGraph.regions.map( + (region) => ({ + ...region, + pointIds: [...region.pointIds], + d: cloneSerializableValue(region.d), + }), + ) + const ports: SerializedPort[] = this.serializedHyperGraph.ports.map( + (port) => ({ + ...port, + d: cloneSerializableValue(port.d), + }), + ) + const regionById = new Map( + regions.map((region) => [region.regionId, region]), + ) + const sourcePortById = new Map( + ports.map((port) => [port.portId, port] as const), + ) + const usedPortIds = new Set(ports.map((port) => port.portId)) + const duplicatedPorts: DuplicatedPortSummary[] = [] + + for (const [sourcePortId, useCount] of [...portUseCounts.entries()].sort( + ([leftPortId], [rightPortId]) => leftPortId.localeCompare(rightPortId), + )) { + if (useCount <= 1) continue + + const sourcePort = sourcePortById.get(sourcePortId) + if (!sourcePort) continue + + const duplicateCount = useCount - 1 + const nearestBoundaryPort = findNearestPortOnSameBoundary( + sourcePort, + this.serializedHyperGraph.ports, + ) + const duplicateDirection = getDuplicateDirection( + sourcePort, + nearestBoundaryPort, + regionById, + ) + const sourcePoint = getPortPoint(sourcePort) + const duplicatePortIds: string[] = [] + + for ( + let duplicateIndex = 1; + duplicateIndex <= duplicateCount; + duplicateIndex++ + ) { + const duplicatePortId = createDuplicatePortId( + sourcePortId, + duplicateIndex, + usedPortIds, + ) + const offset = + (duplicatePortProximity * duplicateIndex) / (duplicateCount + 1) + const duplicatedPortData = toObjectRecord( + cloneSerializableValue(sourcePort.d), + ) + duplicatedPortData.x = sourcePoint.x + duplicateDirection.x * offset + duplicatedPortData.y = sourcePoint.y + duplicateDirection.y * offset + duplicatedPortData.duplicatedFromPortId = sourcePortId + duplicatedPortData.duplicateIndex = duplicateIndex + duplicatedPortData.duplicatePortUseCount = useCount + duplicatedPortData.duplicatePortProximity = duplicatePortProximity + duplicatedPortData.repairReason = "congested-port" + + ports.push({ + ...sourcePort, + portId: duplicatePortId, + d: duplicatedPortData, + }) + duplicatePortIds.push(duplicatePortId) + } + + for (const regionId of [sourcePort.region1Id, sourcePort.region2Id]) { + const region = regionById.get(regionId) + if (!region) continue + insertDuplicatePortIdsAfterSource( + region.pointIds, + sourcePortId, + duplicatePortIds, + ) + } + + duplicatedPorts.push({ + sourcePortId, + duplicatePortIds, + useCount, + }) + } + + this.report = { + portUseCounts: Object.fromEntries([...portUseCounts.entries()].sort()), + duplicatedPorts, + } + + return { + ...restHyperGraph, + regions, + ports, + connections: + this.serializedHyperGraph.connections === undefined + ? undefined + : cloneSerializableValue(this.serializedHyperGraph.connections), + } + } + + override _setup() { + try { + const portUseCounts = this.getPortUseCounts() + this.revisedSerializedHyperGraph = + this.duplicateCongestedPorts(portUseCounts) + this.stats = { + ...this.stats, + duplicateSourcePortCount: this.report.duplicatedPorts.length, + duplicatedPortCount: this.report.duplicatedPorts.reduce( + (sum, duplicatedPort) => sum + duplicatedPort.duplicatePortIds.length, + 0, + ), + } + this.solved = true + } catch (error) { + this.failed = true + this.error = error instanceof Error ? error.message : String(error) + } + } + + override _step() { + if (!this.failed) { + this.solved = true + } + } + + override getOutput(): SerializedHyperGraph { + if (!this.revisedSerializedHyperGraph || this.failed) { + throw new Error( + "DuplicateCongestedPortSolver does not have a repaired topology output", + ) + } + + return this.revisedSerializedHyperGraph + } +} diff --git a/lib/core.ts b/lib/core.ts index 78627e2..ffeaec9 100644 --- a/lib/core.ts +++ b/lib/core.ts @@ -96,6 +96,9 @@ export interface TinyHyperGraphProblem { routeNet: Int32Array // NetId[] /** regionNetId[regionId] = reserved net id for the region, -1 means freely traversable */ regionNetId: Int32Array + + /** portPenalty[portId] = extra cost paid when a route traverses the port */ + portPenalty?: Float64Array } export interface TinyHyperGraphProblemSetup { @@ -1129,7 +1132,8 @@ export class TinyHyperGraphSolver extends BaseSolver { return ( currentCandidate.g + newRegionCost + - state.regionCongestionCost[nextRegionId] + state.regionCongestionCost[nextRegionId] + + (this.problem.portPenalty?.[neighborPortId] ?? 0) ) } diff --git a/lib/index.ts b/lib/index.ts index b3f5c40..75ea4fb 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -1,4 +1,5 @@ export * from "./core" +export * from "./DuplicateCongestedPortSolver" export * from "./poly" export * from "./bus-solver" export * from "./region-graph" diff --git a/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts b/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts index 8614d61..876ae97 100644 --- a/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts +++ b/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts @@ -120,6 +120,10 @@ const createProblemWithPortSectionMask = ( routeEndPort: new Int32Array(problem.routeEndPort), routeNet: new Int32Array(problem.routeNet), regionNetId: new Int32Array(problem.regionNetId), + portPenalty: + problem.portPenalty === undefined + ? undefined + : new Float64Array(problem.portPenalty), }) const getSectionMaskCandidates = ( diff --git a/lib/section-solver/index.ts b/lib/section-solver/index.ts index db2ea74..5d4b50a 100644 --- a/lib/section-solver/index.ts +++ b/lib/section-solver/index.ts @@ -551,6 +551,10 @@ const createSectionRoutePlans = ( routeEndPort, routeNet: new Int32Array(problem.routeNet), regionNetId: new Int32Array(problem.regionNetId), + portPenalty: + problem.portPenalty === undefined + ? undefined + : new Float64Array(problem.portPenalty), }, routePlans, activeRouteIds, diff --git a/pages/port-chokepoint.page.tsx b/pages/port-chokepoint.page.tsx index 2baee30..4e9f30f 100644 --- a/pages/port-chokepoint.page.tsx +++ b/pages/port-chokepoint.page.tsx @@ -1,11 +1,18 @@ import type { SerializedHyperGraph } from "@tscircuit/hypergraph" import { loadSerializedHyperGraph } from "lib/compat/loadSerializedHyperGraph" -import { TinyHyperGraphSolver } from "lib/index" +import { DuplicateCongestedPortSolver, TinyHyperGraphSolver } from "lib/index" import { portChokepointFixture } from "../tests/fixtures/port-chokepoint.fixture" import { Debugger } from "./components/Debugger" const createSolver = (serializedHyperGraph: SerializedHyperGraph) => { - const { topology, problem } = loadSerializedHyperGraph(serializedHyperGraph) + const duplicateCongestedPortSolver = new DuplicateCongestedPortSolver( + serializedHyperGraph, + ) + duplicateCongestedPortSolver.solve() + const { topology, problem } = loadSerializedHyperGraph( + duplicateCongestedPortSolver.getOutput(), + ) + return new TinyHyperGraphSolver(topology, problem, { MAX_ITERATIONS: 20_000, STATIC_REACHABILITY_PRECHECK: false, @@ -18,10 +25,9 @@ export default function PortChokepointPage() {
Port chokepoint repro: connection-a routes from the top left endpoint to the top right endpoint, while connection-b{" "} - routes from the bottom left endpoint to the bottom right endpoint. Both - nets must pass through the single left-center-choke and{" "} - center-right-choke ports, so after one route claims the - corridor the other route has no valid port-disjoint path. + routes from the bottom left endpoint to the bottom right endpoint. The + page first runs DuplicateCongestedPortSolver, then routes + the repaired topology with the strict core solver.
({ + regionId, + pointIds, + d: { + center: { x: centerX, y: centerY }, + width, + height, + }, +}) + +const createPort = ( + portId: string, + region1Id: string, + region2Id: string, + x: number, + y: number, +): SerializedHyperGraph["ports"][number] => ({ + portId, + region1Id, + region2Id, + d: { x, y, z: 0 }, +}) + +const getNumber = (value: unknown) => + typeof value === "number" && Number.isFinite(value) ? value : 0 + +const getPortIndexBySerializedId = ( + topology: ReturnType["topology"], + serializedPortId: string, +) => + topology.portMetadata?.findIndex( + (metadata) => + typeof metadata === "object" && + metadata !== null && + "serializedPortId" in metadata && + metadata.serializedPortId === serializedPortId, + ) ?? -1 + +const createParallelPortFixture = (): SerializedHyperGraph => ({ + regions: [ + createRegion("start", -4, 0, 2, 2, ["start-port"]), + createRegion("left", -1, 0, 4, 6, ["start-port", "middle-a", "middle-b"]), + createRegion("right", 1, 0, 4, 6, ["middle-a", "middle-b", "end-port"]), + createRegion("end", 4, 0, 2, 2, ["end-port"]), + ], + ports: [ + createPort("start-port", "start", "left", -3, 0), + createPort("middle-a", "left", "right", 0, 0), + createPort("middle-b", "left", "right", 0, 2), + createPort("end-port", "right", "end", 3, 0), + ], + connections: [ + { + connectionId: "connection-a", + startRegionId: "start", + endRegionId: "end", + mutuallyConnectedNetworkId: "net-a", + }, + ], +}) + +const createDuplicatePortFixture = (): SerializedHyperGraph => ({ + regions: [ + createRegion("a-start", -4, 0, 2, 2, ["a-start-port"]), + createRegion("b-start", -4, -0.2, 2, 2, ["b-start-port"]), + createRegion("left", -1, 0, 4, 10, [ + "a-start-port", + "b-start-port", + "shared-choke", + "shared-neighbor", + ]), + createRegion("right", 1, 0, 4, 10, [ + "shared-choke", + "shared-neighbor", + "a-end-port", + "b-end-port", + ]), + createRegion("a-end", 4, 0, 2, 2, ["a-end-port"]), + createRegion("b-end", 4, -0.2, 2, 2, ["b-end-port"]), + ], + ports: [ + createPort("a-start-port", "a-start", "left", -3, 0), + createPort("b-start-port", "b-start", "left", -3, -0.2), + createPort("shared-choke", "left", "right", 0, 0), + createPort("shared-neighbor", "left", "right", 0, 4), + createPort("a-end-port", "right", "a-end", 3, 0), + createPort("b-end-port", "right", "b-end", 3, -0.2), + ], + connections: [ + { + connectionId: "connection-a", + startRegionId: "a-start", + endRegionId: "a-end", + mutuallyConnectedNetworkId: "net-a", + }, + { + connectionId: "connection-b", + startRegionId: "b-start", + endRegionId: "b-end", + mutuallyConnectedNetworkId: "net-b", + }, + ], +}) + +test("core solver applies port penalties when choosing an intermediate port", () => { + const { topology, problem } = loadSerializedHyperGraph( + createParallelPortFixture(), + ) + const penalizedPortIndex = getPortIndexBySerializedId(topology, "middle-a") + expect(penalizedPortIndex).toBeGreaterThanOrEqual(0) + + problem.portPenalty = new Float64Array(topology.portCount) + problem.portPenalty[penalizedPortIndex] = 1_000 + + const solver = new TinyHyperGraphSolver(topology, problem, { + RIP_THRESHOLD_RAMP_ATTEMPTS: 0, + STATIC_REACHABILITY_PRECHECK: false, + }) + solver.solve() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect( + solver.getOutput().solvedRoutes?.[0]?.path.map(({ portId }) => portId), + ).toEqual(["start-port", "middle-b", "end-port"]) +}) + +test("duplicate congested port solver duplicates independently reused ports in line with the boundary", () => { + const duplicatePortProximity = 0.2 + const solver = new DuplicateCongestedPortSolver( + createDuplicatePortFixture(), + { + duplicatePortProximity, + }, + ) + + solver.solve() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect(solver.report.portUseCounts["shared-choke"]).toBe(2) + expect(solver.report.duplicatedPorts).toContainEqual({ + sourcePortId: "shared-choke", + duplicatePortIds: ["shared-choke::dup1"], + useCount: 2, + }) + + const output = solver.getOutput() + const sourcePort = output.ports.find( + (port) => port.portId === "shared-choke", + )! + const nearestPort = output.ports.find( + (port) => port.portId === "shared-neighbor", + )! + const duplicatePort = output.ports.find( + (port) => port.portId === "shared-choke::dup1", + )! + + const sourcePoint = { + x: getNumber(sourcePort.d?.x), + y: getNumber(sourcePort.d?.y), + } + const nearestVector = { + x: getNumber(nearestPort.d?.x) - sourcePoint.x, + y: getNumber(nearestPort.d?.y) - sourcePoint.y, + } + const duplicateVector = { + x: getNumber(duplicatePort.d?.x) - sourcePoint.x, + y: getNumber(duplicatePort.d?.y) - sourcePoint.y, + } + const duplicateDistance = Math.hypot(duplicateVector.x, duplicateVector.y) + const crossProduct = + nearestVector.x * duplicateVector.y - nearestVector.y * duplicateVector.x + + expect(duplicatePort.d?.duplicatedFromPortId).toBe("shared-choke") + expect(duplicatePort.d?.duplicateIndex).toBe(1) + expect(duplicateDistance).toBeGreaterThan(0) + expect(duplicateDistance).toBeLessThanOrEqual(duplicatePortProximity) + expect(Math.abs(crossProduct)).toBeLessThan(1e-9) + expect(output.solvedRoutes).toBeUndefined() + expect( + output.regions + .filter( + (region) => region.regionId === "left" || region.regionId === "right", + ) + .every((region) => region.pointIds.includes("shared-choke::dup1")), + ).toBe(true) +})