diff --git a/lib/DuplicateCongestedPortSolver.ts b/lib/DuplicateCongestedPortSolver.ts new file mode 100644 index 0000000..e9f0f08 --- /dev/null +++ b/lib/DuplicateCongestedPortSolver.ts @@ -0,0 +1,512 @@ +import type { SerializedHyperGraph } from "@tscircuit/hypergraph" +import { BaseSolver } from "@tscircuit/solver-utils" +import { loadSerializedHyperGraph } from "./compat/loadSerializedHyperGraph" +import { + TinyHyperGraphSolver, + type TinyHyperGraphProblem, + type TinyHyperGraphSolverOptions, + type TinyHyperGraphTopology, +} from "./core" +import type { PortId, RouteId } from "./types" + +type SerializedPort = SerializedHyperGraph["ports"][number] +type SerializedRegion = SerializedHyperGraph["regions"][number] + +export const DUPLICATE_PORT_PROXIMITY = 0.05 + +export interface DuplicateCongestedPortSolverOptions { + duplicatePortProximity?: number + routeSolveOptions?: TinyHyperGraphSolverOptions +} + +export interface DuplicatedPortSummary { + sourcePortId: string + duplicatePortIds: string[] + useCount: number +} + +export interface DuplicateCongestedPortSolverReport { + portUseCounts: Record + duplicatedPorts: DuplicatedPortSummary[] +} + +interface Point { + x: number + y: number +} + +const EPSILON = 1e-9 + +const isRecord = (value: unknown): value is Record => + typeof value === "object" && value !== null + +const cloneSerializableValue = (value: T): T => { + if (Array.isArray(value)) { + return value.map((item) => cloneSerializableValue(item)) as T + } + + if (isRecord(value)) { + return Object.fromEntries( + Object.entries(value).map(([key, item]) => [ + key, + cloneSerializableValue(item), + ]), + ) as T + } + + return value +} + +const toObjectRecord = (value: unknown): Record => { + if (isRecord(value)) return { ...value } + if (value === undefined) return {} + return { value } +} + +const getNumber = (value: unknown, fallback = 0) => + typeof value === "number" && Number.isFinite(value) ? value : fallback + +const getPortPoint = (port: SerializedPort): Point => ({ + x: getNumber(port.d?.x), + y: getNumber(port.d?.y), +}) + +const getBoundaryKey = ( + port: Pick, +) => [port.region1Id, port.region2Id].sort().join("\u0000") + +const getDistance = (a: Point, b: Point) => Math.hypot(a.x - b.x, a.y - b.y) + +const normalize = (point: Point): Point | undefined => { + const length = Math.hypot(point.x, point.y) + if (length <= EPSILON) return undefined + return { + x: point.x / length, + y: point.y / length, + } +} + +const getRegionBounds = ( + region: SerializedRegion | undefined, +): { + minX: number + maxX: number + minY: number + maxY: number +} => { + const bounds = region?.d?.bounds + if ( + isRecord(bounds) && + typeof bounds.minX === "number" && + typeof bounds.maxX === "number" && + typeof bounds.minY === "number" && + typeof bounds.maxY === "number" + ) { + return { + minX: bounds.minX, + maxX: bounds.maxX, + minY: bounds.minY, + maxY: bounds.maxY, + } + } + + const center = region?.d?.center + const width = getNumber(region?.d?.width) + const height = getNumber(region?.d?.height) + if (isRecord(center)) { + const x = getNumber(center.x) + const y = getNumber(center.y) + return { + minX: x - width / 2, + maxX: x + width / 2, + minY: y - height / 2, + maxY: y + height / 2, + } + } + + return { + minX: 0, + maxX: 0, + minY: 0, + maxY: 0, + } +} + +const getRegionCenter = (region: SerializedRegion | undefined): Point => { + const center = region?.d?.center + if (isRecord(center)) { + return { + x: getNumber(center.x), + y: getNumber(center.y), + } + } + + const bounds = getRegionBounds(region) + return { + x: (bounds.minX + bounds.maxX) / 2, + y: (bounds.minY + bounds.maxY) / 2, + } +} + +const findNearestPortOnSameBoundary = ( + sourcePort: SerializedPort, + ports: SerializedPort[], +): SerializedPort | undefined => { + const sourceBoundaryKey = getBoundaryKey(sourcePort) + const sourcePoint = getPortPoint(sourcePort) + let nearestPort: SerializedPort | undefined + let nearestDistance = Number.POSITIVE_INFINITY + + for (const port of ports) { + if (port.portId === sourcePort.portId) continue + if (getBoundaryKey(port) !== sourceBoundaryKey) continue + + const distance = getDistance(sourcePoint, getPortPoint(port)) + if (distance <= EPSILON || distance >= nearestDistance) continue + + nearestPort = port + nearestDistance = distance + } + + return nearestPort +} + +const getFallbackBoundaryDirection = ( + sourcePort: SerializedPort, + regionById: Map, +): Point => { + const region1Center = getRegionCenter(regionById.get(sourcePort.region1Id)) + const region2Center = getRegionCenter(regionById.get(sourcePort.region2Id)) + const perpendicular = normalize({ + x: -(region2Center.y - region1Center.y), + y: region2Center.x - region1Center.x, + }) + + return perpendicular ?? { x: 1, y: 0 } +} + +const getDuplicateDirection = ( + sourcePort: SerializedPort, + nearestBoundaryPort: SerializedPort | undefined, + regionById: Map, +): Point => { + const sourcePoint = getPortPoint(sourcePort) + + if (nearestBoundaryPort) { + const nearestPoint = getPortPoint(nearestBoundaryPort) + const awayFromNearest = normalize({ + x: sourcePoint.x - nearestPoint.x, + y: sourcePoint.y - nearestPoint.y, + }) + if (awayFromNearest) return awayFromNearest + } + + return getFallbackBoundaryDirection(sourcePort, regionById) +} + +const createDuplicatePortId = ( + sourcePortId: string, + duplicateIndex: number, + usedPortIds: Set, +): string => { + const basePortId = `${sourcePortId}::dup${duplicateIndex}` + if (!usedPortIds.has(basePortId)) { + usedPortIds.add(basePortId) + return basePortId + } + + for (let collisionIndex = 2; ; collisionIndex++) { + const portId = `${basePortId}-${collisionIndex}` + if (!usedPortIds.has(portId)) { + usedPortIds.add(portId) + return portId + } + } +} + +const insertDuplicatePortIdsAfterSource = ( + pointIds: string[], + sourcePortId: string, + duplicatePortIds: string[], +) => { + const insertionIndex = pointIds.indexOf(sourcePortId) + if (insertionIndex === -1) { + pointIds.push(...duplicatePortIds) + return + } + + pointIds.splice(insertionIndex + 1, 0, ...duplicatePortIds) +} + +const getSerializedPortId = ( + topology: TinyHyperGraphTopology, + portId: PortId, +): string => { + const metadata = topology.portMetadata?.[portId] + if (isRecord(metadata)) { + if (typeof metadata.serializedPortId === "string") { + return metadata.serializedPortId + } + + if (typeof metadata.portId === "string") { + return metadata.portId + } + } + + return `port-${portId}` +} + +const createSingleRouteProblem = ( + problem: TinyHyperGraphProblem, + routeId: RouteId, +): TinyHyperGraphProblem => ({ + routeCount: 1, + portSectionMask: new Int8Array(problem.portSectionMask), + routeMetadata: + problem.routeMetadata === undefined + ? undefined + : [problem.routeMetadata[routeId]], + routeStartPort: Int32Array.from([problem.routeStartPort[routeId]]), + routeEndPort: Int32Array.from([problem.routeEndPort[routeId]]), + routeNet: Int32Array.from([problem.routeNet[routeId]]), + regionNetId: new Int32Array(problem.regionNetId), + portPenalty: + problem.portPenalty === undefined + ? undefined + : new Float64Array(problem.portPenalty), +}) + +const getUsedPortIdsForSolvedRoute = ( + solver: TinyHyperGraphSolver, +): Set => { + const usedPortIds = new Set() + + for (const regionSegments of solver.state.regionSegments) { + for (const [, fromPortId, toPortId] of regionSegments) { + usedPortIds.add(fromPortId) + usedPortIds.add(toPortId) + } + } + + if (usedPortIds.size === 0 && solver.problem.routeCount === 1) { + usedPortIds.add(solver.problem.routeStartPort[0]) + usedPortIds.add(solver.problem.routeEndPort[0]) + } + + return usedPortIds +} + +export class DuplicateCongestedPortSolver extends BaseSolver { + revisedSerializedHyperGraph?: SerializedHyperGraph + report: DuplicateCongestedPortSolverReport = { + portUseCounts: {}, + duplicatedPorts: [], + } + + constructor( + public serializedHyperGraph: SerializedHyperGraph, + public options: DuplicateCongestedPortSolverOptions = {}, + ) { + super() + } + + private getDuplicatePortProximity() { + return this.options.duplicatePortProximity ?? DUPLICATE_PORT_PROXIMITY + } + + private getIndividualRouteSolveOptions(): TinyHyperGraphSolverOptions { + return { + RIP_THRESHOLD_RAMP_ATTEMPTS: 0, + STATIC_REACHABILITY_PRECHECK: false, + ...this.options.routeSolveOptions, + } + } + + private getPortUseCounts(): Map { + const { topology, problem } = loadSerializedHyperGraph( + this.serializedHyperGraph, + ) + const portUseCounts = new Map() + + for (let routeId = 0; routeId < problem.routeCount; routeId++) { + const routeProblem = createSingleRouteProblem(problem, routeId) + const routeSolver = new TinyHyperGraphSolver( + topology, + routeProblem, + this.getIndividualRouteSolveOptions(), + ) + routeSolver.solve() + + if (!routeSolver.solved || routeSolver.failed) { + throw new Error( + `Route ${routeId} could not be solved independently: ${ + routeSolver.error ?? "unknown error" + }`, + ) + } + + for (const portId of getUsedPortIdsForSolvedRoute(routeSolver)) { + const serializedPortId = getSerializedPortId(topology, portId) + portUseCounts.set( + serializedPortId, + (portUseCounts.get(serializedPortId) ?? 0) + 1, + ) + } + } + + return portUseCounts + } + + private duplicateCongestedPorts( + portUseCounts: Map, + ): SerializedHyperGraph { + const duplicatePortProximity = this.getDuplicatePortProximity() + if (!(duplicatePortProximity > 0)) { + throw new Error("duplicatePortProximity must be greater than zero") + } + + const { solvedRoutes: _solvedRoutes, ...restHyperGraph } = + this.serializedHyperGraph + const regions: SerializedRegion[] = this.serializedHyperGraph.regions.map( + (region) => ({ + ...region, + pointIds: [...region.pointIds], + d: cloneSerializableValue(region.d), + }), + ) + const ports: SerializedPort[] = this.serializedHyperGraph.ports.map( + (port) => ({ + ...port, + d: cloneSerializableValue(port.d), + }), + ) + const regionById = new Map( + regions.map((region) => [region.regionId, region]), + ) + const sourcePortById = new Map( + ports.map((port) => [port.portId, port] as const), + ) + const usedPortIds = new Set(ports.map((port) => port.portId)) + const duplicatedPorts: DuplicatedPortSummary[] = [] + + for (const [sourcePortId, useCount] of [...portUseCounts.entries()].sort( + ([leftPortId], [rightPortId]) => leftPortId.localeCompare(rightPortId), + )) { + if (useCount <= 1) continue + + const sourcePort = sourcePortById.get(sourcePortId) + if (!sourcePort) continue + + const duplicateCount = useCount - 1 + const nearestBoundaryPort = findNearestPortOnSameBoundary( + sourcePort, + this.serializedHyperGraph.ports, + ) + const duplicateDirection = getDuplicateDirection( + sourcePort, + nearestBoundaryPort, + regionById, + ) + const sourcePoint = getPortPoint(sourcePort) + const duplicatePortIds: string[] = [] + + for ( + let duplicateIndex = 1; + duplicateIndex <= duplicateCount; + duplicateIndex++ + ) { + const duplicatePortId = createDuplicatePortId( + sourcePortId, + duplicateIndex, + usedPortIds, + ) + const offset = + (duplicatePortProximity * duplicateIndex) / (duplicateCount + 1) + const duplicatedPortData = toObjectRecord( + cloneSerializableValue(sourcePort.d), + ) + duplicatedPortData.x = sourcePoint.x + duplicateDirection.x * offset + duplicatedPortData.y = sourcePoint.y + duplicateDirection.y * offset + duplicatedPortData.duplicatedFromPortId = sourcePortId + duplicatedPortData.duplicateIndex = duplicateIndex + duplicatedPortData.duplicatePortUseCount = useCount + duplicatedPortData.duplicatePortProximity = duplicatePortProximity + duplicatedPortData.repairReason = "congested-port" + + ports.push({ + ...sourcePort, + portId: duplicatePortId, + d: duplicatedPortData, + }) + duplicatePortIds.push(duplicatePortId) + } + + for (const regionId of [sourcePort.region1Id, sourcePort.region2Id]) { + const region = regionById.get(regionId) + if (!region) continue + insertDuplicatePortIdsAfterSource( + region.pointIds, + sourcePortId, + duplicatePortIds, + ) + } + + duplicatedPorts.push({ + sourcePortId, + duplicatePortIds, + useCount, + }) + } + + this.report = { + portUseCounts: Object.fromEntries([...portUseCounts.entries()].sort()), + duplicatedPorts, + } + + return { + ...restHyperGraph, + regions, + ports, + connections: + this.serializedHyperGraph.connections === undefined + ? undefined + : cloneSerializableValue(this.serializedHyperGraph.connections), + } + } + + override _setup() { + try { + const portUseCounts = this.getPortUseCounts() + this.revisedSerializedHyperGraph = + this.duplicateCongestedPorts(portUseCounts) + this.stats = { + ...this.stats, + duplicateSourcePortCount: this.report.duplicatedPorts.length, + duplicatedPortCount: this.report.duplicatedPorts.reduce( + (sum, duplicatedPort) => sum + duplicatedPort.duplicatePortIds.length, + 0, + ), + } + this.solved = true + } catch (error) { + this.failed = true + this.error = error instanceof Error ? error.message : String(error) + } + } + + override _step() { + if (!this.failed) { + this.solved = true + } + } + + override getOutput(): SerializedHyperGraph { + if (!this.revisedSerializedHyperGraph || this.failed) { + throw new Error( + "DuplicateCongestedPortSolver does not have a repaired topology output", + ) + } + + return this.revisedSerializedHyperGraph + } +} diff --git a/lib/core.ts b/lib/core.ts index 78627e2..e50f3a5 100644 --- a/lib/core.ts +++ b/lib/core.ts @@ -29,6 +29,8 @@ import { visualizeTinyGraph } from "./visualizeTinyGraph" export type { StaticallyUnroutableRouteSummary } from "./static-reachability" +const GREEDY_FINAL_ROUTE_MAX_ITERATIONS = 50e3 + export const createEmptyRegionIntersectionCache = (): RegionIntersectionCache => ({ netIds: new Int32Array(0), @@ -42,6 +44,45 @@ export const createEmptyRegionIntersectionCache = existingSegmentCount: 0, }) +const cloneRegionSegments = ( + regionSegments: Array<[RouteId, PortId, PortId][]>, +): Array<[RouteId, PortId, PortId][]> => + regionSegments.map((segments) => + segments.map( + ([routeId, fromPortId, toPortId]) => + [routeId, fromPortId, toPortId] as [RouteId, PortId, PortId], + ), + ) + +const cloneRegionIntersectionCache = ( + regionIntersectionCache: RegionIntersectionCache, +): RegionIntersectionCache => ({ + netIds: new Int32Array(regionIntersectionCache.netIds), + lesserAngles: new Int32Array(regionIntersectionCache.lesserAngles), + greaterAngles: new Int32Array(regionIntersectionCache.greaterAngles), + layerMasks: new Int32Array(regionIntersectionCache.layerMasks), + existingCrossingLayerIntersections: + regionIntersectionCache.existingCrossingLayerIntersections, + existingSameLayerIntersections: + regionIntersectionCache.existingSameLayerIntersections, + existingEntryExitLayerChanges: + regionIntersectionCache.existingEntryExitLayerChanges, + existingRegionCost: regionIntersectionCache.existingRegionCost, + existingSegmentCount: regionIntersectionCache.existingSegmentCount, +}) + +const cloneSolvedStateSnapshot = ( + snapshot: SolvedStateSnapshot, +): SolvedStateSnapshot => ({ + portAssignment: new Int32Array(snapshot.portAssignment), + regionSegments: cloneRegionSegments(snapshot.regionSegments), + regionIntersectionCaches: snapshot.regionIntersectionCaches.map( + cloneRegionIntersectionCache, + ), + regionCongestionCost: new Float64Array(snapshot.regionCongestionCost), + ripCount: snapshot.ripCount, +}) + export interface TinyHyperGraphTopology { portCount: number regionCount: number @@ -96,6 +137,9 @@ export interface TinyHyperGraphProblem { routeNet: Int32Array // NetId[] /** regionNetId[regionId] = reserved net id for the region, -1 means freely traversable */ regionNetId: Int32Array + + /** portPenalty[portId] = extra cost paid when a route traverses the port */ + portPenalty?: Float64Array } export interface TinyHyperGraphProblemSetup { @@ -121,6 +165,14 @@ export interface RegionCostSummary { totalRegionCost: number } +interface SolvedStateSnapshot { + portAssignment: Int32Array + regionSegments: Array<[RouteId, PortId, PortId][]> + regionIntersectionCaches: RegionIntersectionCache[] + regionCongestionCost: Float64Array + ripCount: number +} + export interface NeverSuccessfullyRoutedRouteSummary { routeId: RouteId connectionId: string @@ -185,6 +237,8 @@ export interface TinyHyperGraphSolverOptions { VERBOSE?: boolean STATIC_REACHABILITY_PRECHECK?: boolean STATIC_REACHABILITY_PRECHECK_MAX_HOPS?: number + ACCEPT_BEST_SOLUTION_ON_TIMEOUT?: boolean + GREEDY_FINAL_ROUTE_ITERS?: number } export interface TinyHyperGraphSolverOptionTarget { @@ -200,6 +254,8 @@ export interface TinyHyperGraphSolverOptionTarget { VERBOSE: boolean STATIC_REACHABILITY_PRECHECK: boolean STATIC_REACHABILITY_PRECHECK_MAX_HOPS: number + ACCEPT_BEST_SOLUTION_ON_TIMEOUT: boolean + GREEDY_FINAL_ROUTE_ITERS: number } export const applyTinyHyperGraphSolverOptions = ( @@ -248,6 +304,13 @@ export const applyTinyHyperGraphSolverOptions = ( solver.STATIC_REACHABILITY_PRECHECK_MAX_HOPS = options.STATIC_REACHABILITY_PRECHECK_MAX_HOPS } + if (options.ACCEPT_BEST_SOLUTION_ON_TIMEOUT !== undefined) { + solver.ACCEPT_BEST_SOLUTION_ON_TIMEOUT = + options.ACCEPT_BEST_SOLUTION_ON_TIMEOUT + } + if (options.GREEDY_FINAL_ROUTE_ITERS !== undefined) { + solver.GREEDY_FINAL_ROUTE_ITERS = options.GREEDY_FINAL_ROUTE_ITERS + } } export const getTinyHyperGraphSolverOptions = ( @@ -266,6 +329,8 @@ export const getTinyHyperGraphSolverOptions = ( STATIC_REACHABILITY_PRECHECK: solver.STATIC_REACHABILITY_PRECHECK, STATIC_REACHABILITY_PRECHECK_MAX_HOPS: solver.STATIC_REACHABILITY_PRECHECK_MAX_HOPS, + ACCEPT_BEST_SOLUTION_ON_TIMEOUT: solver.ACCEPT_BEST_SOLUTION_ON_TIMEOUT, + GREEDY_FINAL_ROUTE_ITERS: solver.GREEDY_FINAL_ROUTE_ITERS, }) const compareCandidatesByF = (left: Candidate, right: Candidate) => @@ -283,6 +348,8 @@ export class TinyHyperGraphSolver extends BaseSolver { private _problemSetup?: TinyHyperGraphProblemSetup protected routeAttemptCountByRouteId: Uint32Array protected routeSuccessCountByRouteId: Uint32Array + protected bestSolvedStateSnapshot?: SolvedStateSnapshot + protected bestSolvedStateSummary?: RegionCostSummary private hasLoggedNeverSuccessfullyRoutedRoutes = false private segmentGeometryScratch: SegmentGeometryScratch = { lesserAngle: 0, @@ -306,6 +373,8 @@ export class TinyHyperGraphSolver extends BaseSolver { VERBOSE = false STATIC_REACHABILITY_PRECHECK = true STATIC_REACHABILITY_PRECHECK_MAX_HOPS = 16 + ACCEPT_BEST_SOLUTION_ON_TIMEOUT = true + GREEDY_FINAL_ROUTE_ITERS = 4 constructor( public topology: TinyHyperGraphTopology, @@ -968,6 +1037,199 @@ export class TinyHyperGraphSolver extends BaseSolver { ) } + protected compareRegionCostSummaries( + left: RegionCostSummary, + right: RegionCostSummary, + ) { + if (left.maxRegionCost !== right.maxRegionCost) { + return left.maxRegionCost - right.maxRegionCost + } + + return left.totalRegionCost - right.totalRegionCost + } + + protected captureBestSolvedState(summary: RegionCostSummary) { + if ( + this.bestSolvedStateSummary && + this.compareRegionCostSummaries(summary, this.bestSolvedStateSummary) >= 0 + ) { + return + } + + this.bestSolvedStateSummary = summary + this.bestSolvedStateSnapshot = cloneSolvedStateSnapshot({ + portAssignment: this.state.portAssignment, + regionSegments: this.state.regionSegments, + regionIntersectionCaches: this.state.regionIntersectionCaches, + regionCongestionCost: this.state.regionCongestionCost, + ripCount: this.state.ripCount, + }) + } + + protected restoreBestSolvedState() { + if (!this.bestSolvedStateSnapshot) { + return + } + + const snapshot = cloneSolvedStateSnapshot(this.bestSolvedStateSnapshot) + this.state.portAssignment = snapshot.portAssignment + this.state.regionSegments = snapshot.regionSegments + this.state.regionIntersectionCaches = snapshot.regionIntersectionCaches + this.state.regionCongestionCost = snapshot.regionCongestionCost + this.state.ripCount = snapshot.ripCount + this.state.currentRouteId = undefined + this.state.currentRouteNetId = undefined + this.state.unroutedRoutes = [] + this.state.candidateQueue.clear() + this.resetCandidateBestCosts() + this.state.goalPortId = -1 + } + + protected getRemainingRouteIdsForGreedyFinalRoute(): RouteId[] { + const routeIds = new Set(this.state.unroutedRoutes) + + if (this.state.currentRouteId !== undefined) { + routeIds.add(this.state.currentRouteId) + } + + return [...routeIds] + } + + protected applySnapshotToGreedyFinalRouteSolver( + solver: TinyHyperGraphSolver, + snapshot: SolvedStateSnapshot, + routeIds: RouteId[], + ) { + const clonedSnapshot = cloneSolvedStateSnapshot(snapshot) + + solver.state.portAssignment = clonedSnapshot.portAssignment + solver.state.regionSegments = clonedSnapshot.regionSegments + solver.state.regionIntersectionCaches = + clonedSnapshot.regionIntersectionCaches + solver.state.regionCongestionCost = clonedSnapshot.regionCongestionCost + solver.state.ripCount = 0 + solver.state.currentRouteId = undefined + solver.state.currentRouteNetId = undefined + solver.state.unroutedRoutes = [...routeIds] + solver.state.candidateQueue.clear() + solver.resetCandidateBestCosts() + solver.state.goalPortId = -1 + } + + protected summarizeSolvedState( + solver: TinyHyperGraphSolver, + ): RegionCostSummary { + let maxRegionCost = 0 + let totalRegionCost = 0 + + for (const regionIntersectionCache of solver.state + .regionIntersectionCaches) { + const regionCost = regionIntersectionCache.existingRegionCost + maxRegionCost = Math.max(maxRegionCost, regionCost) + totalRegionCost += regionCost + } + + return { + maxRegionCost, + totalRegionCost, + } + } + + protected tryGreedyFinalRouteAcceptance(): boolean { + const greedyFinalRouteIters = Math.max( + 0, + Math.floor(this.GREEDY_FINAL_ROUTE_ITERS), + ) + if (greedyFinalRouteIters === 0) { + return false + } + + const remainingRouteIds = this.getRemainingRouteIdsForGreedyFinalRoute() + if (remainingRouteIds.length === 0) { + return false + } + + const startingSnapshot = cloneSolvedStateSnapshot({ + portAssignment: this.state.portAssignment, + regionSegments: this.state.regionSegments, + regionIntersectionCaches: this.state.regionIntersectionCaches, + regionCongestionCost: this.state.regionCongestionCost, + ripCount: this.state.ripCount, + }) + + for ( + let greedyFinalRouteIter = 0; + greedyFinalRouteIter < greedyFinalRouteIters; + greedyFinalRouteIter++ + ) { + const routeIds = + greedyFinalRouteIter === 0 + ? remainingRouteIds + : shuffle( + remainingRouteIds, + this.state.ripCount + greedyFinalRouteIter, + ) + const greedySolver = new GreedyFinalRouteSolver( + this.topology, + this.problem, + { + ...getTinyHyperGraphSolverOptions(this), + ACCEPT_BEST_SOLUTION_ON_TIMEOUT: false, + GREEDY_FINAL_ROUTE_ITERS: 0, + MAX_ITERATIONS: GREEDY_FINAL_ROUTE_MAX_ITERATIONS, + RIP_THRESHOLD_RAMP_ATTEMPTS: 0, + STATIC_REACHABILITY_PRECHECK: false, + }, + ) + + this.applySnapshotToGreedyFinalRouteSolver( + greedySolver, + startingSnapshot, + routeIds, + ) + greedySolver.solve() + + if (!greedySolver.solved || greedySolver.failed) { + continue + } + + this.bestSolvedStateSnapshot = cloneSolvedStateSnapshot({ + portAssignment: greedySolver.state.portAssignment, + regionSegments: greedySolver.state.regionSegments, + regionIntersectionCaches: greedySolver.state.regionIntersectionCaches, + regionCongestionCost: greedySolver.state.regionCongestionCost, + ripCount: greedySolver.state.ripCount, + }) + this.bestSolvedStateSummary = this.summarizeSolvedState(greedySolver) + this.restoreBestSolvedState() + this.stats = { + ...this.stats, + acceptedGreedyFinalRouteOnTimeout: true, + greedyFinalRouteIter, + greedyFinalRouteRemainingRouteCount: remainingRouteIds.length, + greedyFinalRouteMaxIterations: GREEDY_FINAL_ROUTE_MAX_ITERATIONS, + neverSuccessfullyRoutedRouteCount: 0, + maxRegionCost: this.bestSolvedStateSummary.maxRegionCost, + totalRegionCost: this.bestSolvedStateSummary.totalRegionCost, + bestMaxRegionCost: this.bestSolvedStateSummary.maxRegionCost, + bestTotalRegionCost: this.bestSolvedStateSummary.totalRegionCost, + } + this.solved = true + this.failed = false + this.error = null + return true + } + + this.stats = { + ...this.stats, + greedyFinalRouteAttemptCount: greedyFinalRouteIters, + greedyFinalRouteRemainingRouteCount: remainingRouteIds.length, + greedyFinalRouteMaxIterations: GREEDY_FINAL_ROUTE_MAX_ITERATIONS, + } + + return false + } + onAllRoutesRouted() { const { topology, state } = this const ripThresholdProgress = @@ -981,23 +1243,33 @@ export class TinyHyperGraphSolver extends BaseSolver { const regionIdsOverCostThreshold: RegionId[] = [] const regionCosts = new Float64Array(topology.regionCount) let maxRegionCost = 0 + let totalRegionCost = 0 for (let regionId = 0; regionId < topology.regionCount; regionId++) { const regionCost = state.regionIntersectionCaches[regionId]?.existingRegionCost ?? 0 regionCosts[regionId] = regionCost maxRegionCost = Math.max(maxRegionCost, regionCost) + totalRegionCost += regionCost if (regionCost > currentRipThreshold) { regionIdsOverCostThreshold.push(regionId) } } + this.captureBestSolvedState({ + maxRegionCost, + totalRegionCost, + }) + this.stats = { ...this.stats, currentRipThreshold, hotRegionCount: regionIdsOverCostThreshold.length, maxRegionCost, + totalRegionCost, + bestMaxRegionCost: this.bestSolvedStateSummary?.maxRegionCost, + bestTotalRegionCost: this.bestSolvedStateSummary?.totalRegionCost, ripCount: state.ripCount, } @@ -1129,7 +1401,8 @@ export class TinyHyperGraphSolver extends BaseSolver { return ( currentCandidate.g + newRegionCost + - state.regionCongestionCost[nextRegionId] + state.regionCongestionCost[nextRegionId] + + (this.problem.portPenalty?.[neighborPortId] ?? 0) ) } @@ -1141,6 +1414,34 @@ export class TinyHyperGraphSolver extends BaseSolver { ...this.stats, neverSuccessfullyRoutedRouteCount: neverSuccessfullyRoutedRoutes.length, } + + if ( + this.ACCEPT_BEST_SOLUTION_ON_TIMEOUT && + this.bestSolvedStateSnapshot && + this.bestSolvedStateSummary + ) { + this.restoreBestSolvedState() + this.stats = { + ...this.stats, + acceptedBestSolutionOnTimeout: true, + maxRegionCost: this.bestSolvedStateSummary.maxRegionCost, + totalRegionCost: this.bestSolvedStateSummary.totalRegionCost, + bestMaxRegionCost: this.bestSolvedStateSummary.maxRegionCost, + bestTotalRegionCost: this.bestSolvedStateSummary.totalRegionCost, + } + this.solved = true + this.failed = false + this.error = null + return + } + + if ( + this.ACCEPT_BEST_SOLUTION_ON_TIMEOUT && + this.tryGreedyFinalRouteAcceptance() + ) { + return + } + this.logNeverSuccessfullyRoutedRoutes() } @@ -1168,3 +1469,12 @@ export class TinyHyperGraphSolver extends BaseSolver { return convertToSerializedHyperGraph(this) } } + +class GreedyFinalRouteSolver extends TinyHyperGraphSolver { + override computeG( + currentCandidate: Candidate, + _neighborPortId: PortId, + ): number { + return currentCandidate.g + } +} diff --git a/lib/index.ts b/lib/index.ts index b3f5c40..75ea4fb 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -1,4 +1,5 @@ export * from "./core" +export * from "./DuplicateCongestedPortSolver" export * from "./poly" export * from "./bus-solver" export * from "./region-graph" diff --git a/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts b/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts index 8614d61..9ccd4f9 100644 --- a/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts +++ b/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts @@ -47,11 +47,14 @@ const DEFAULT_SOLVE_GRAPH_OPTIONS: TinyHyperGraphSolverOptions = { RIP_THRESHOLD_RAMP_ATTEMPTS: 5, } +const DEFAULT_SECTION_SOLVER_MAX_ITERATIONS = 50_000 +const DEFAULT_SECTION_PIPELINE_MAX_ITERATIONS = 200_000 + const DEFAULT_SECTION_SOLVER_OPTIONS: TinyHyperGraphSectionSolverOptions = { DISTANCE_TO_COST: 0.05, RIP_THRESHOLD_RAMP_ATTEMPTS: 16, RIP_CONGESTION_REGION_COST_FACTOR: 0.1, - MAX_ITERATIONS: 1e6, + MAX_ITERATIONS: DEFAULT_SECTION_SOLVER_MAX_ITERATIONS, MAX_RIPS_WITHOUT_MAX_REGION_COST_IMPROVEMENT: 6, EXTRA_RIPS_AFTER_BEATING_BASELINE_MAX_REGION_COST: Number.POSITIVE_INFINITY, } @@ -120,6 +123,10 @@ const createProblemWithPortSectionMask = ( routeEndPort: new Int32Array(problem.routeEndPort), routeNet: new Int32Array(problem.routeNet), regionNetId: new Int32Array(problem.regionNetId), + portPenalty: + problem.portPenalty === undefined + ? undefined + : new Float64Array(problem.portPenalty), }) const getSectionMaskCandidates = ( @@ -317,6 +324,11 @@ export class TinyHyperGraphSectionPipelineSolver extends BasePipelineSolver("solveGraph")) { + this.stats = { + ...this.stats, + acceptedSolveGraphOutputOnSectionPipelineTimeout: true, + } + this.activeSubSolver = undefined + this.solved = true + this.failed = false + this.error = null + } + } } diff --git a/lib/section-solver/index.ts b/lib/section-solver/index.ts index db2ea74..e66f4a1 100644 --- a/lib/section-solver/index.ts +++ b/lib/section-solver/index.ts @@ -551,6 +551,10 @@ const createSectionRoutePlans = ( routeEndPort, routeNet: new Int32Array(problem.routeNet), regionNetId: new Int32Array(problem.regionNetId), + portPenalty: + problem.portPenalty === undefined + ? undefined + : new Float64Array(problem.portPenalty), }, routePlans, activeRouteIds, @@ -834,12 +838,25 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver { } override tryFinalAcceptance() { - if (!this.bestSnapshot) { - super.tryFinalAcceptance() + if (this.bestSnapshot) { + this.restoreBestState() + this.solved = true return } - this.restoreBestState() + if (this.fixedSnapshot) { + restoreSolvedStateSnapshot(this, this.fixedSnapshot) + } + this.state.currentRouteId = undefined + this.state.currentRouteNetId = undefined + this.state.unroutedRoutes = [...this.activeRouteIds] + this.state.candidateQueue.clear() + this.resetCandidateBestCosts() + this.state.goalPortId = -1 + this.stats = { + ...this.stats, + acceptedFixedSectionStateOnTimeout: true, + } this.solved = true } @@ -878,6 +895,8 @@ export class TinyHyperGraphSectionSolver extends BaseSolver { override MAX_ITERATIONS = 1e6 STATIC_REACHABILITY_PRECHECK = false STATIC_REACHABILITY_PRECHECK_MAX_HOPS = 16 + ACCEPT_BEST_SOLUTION_ON_TIMEOUT = true + GREEDY_FINAL_ROUTE_ITERS = 4 constructor( public topology: TinyHyperGraphTopology, @@ -975,8 +994,18 @@ export class TinyHyperGraphSectionSolver extends BaseSolver { } if (this.sectionSolver.failed) { - this.error = this.sectionSolver.error - this.failed = true + this.optimizedSolver = this.baselineSolver + this.stats = { + ...this.stats, + initialMaxRegionCost: this.baselineSummary.maxRegionCost, + initialTotalRegionCost: this.baselineSummary.totalRegionCost, + finalMaxRegionCost: this.baselineSummary.maxRegionCost, + finalTotalRegionCost: this.baselineSummary.totalRegionCost, + optimized: false, + sectionSearchFailedFallbackToBaseline: true, + sectionSearchError: this.sectionSolver.error, + } + this.solved = true return } @@ -1012,6 +1041,22 @@ export class TinyHyperGraphSectionSolver extends BaseSolver { this.solved = true } + override tryFinalAcceptance() { + this.optimizedSolver = this.baselineSolver + this.stats = { + ...this.stats, + initialMaxRegionCost: this.baselineSummary.maxRegionCost, + initialTotalRegionCost: this.baselineSummary.totalRegionCost, + finalMaxRegionCost: this.baselineSummary.maxRegionCost, + finalTotalRegionCost: this.baselineSummary.totalRegionCost, + optimized: false, + sectionSolverTimeoutFallbackToBaseline: true, + } + this.solved = true + this.failed = false + this.error = null + } + getSolvedSolver(): TinyHyperGraphSolver { if (!this.solved || this.failed || !this.optimizedSolver) { throw new Error( diff --git a/pages/port-chokepoint.page.tsx b/pages/port-chokepoint.page.tsx index 2baee30..4e9f30f 100644 --- a/pages/port-chokepoint.page.tsx +++ b/pages/port-chokepoint.page.tsx @@ -1,11 +1,18 @@ import type { SerializedHyperGraph } from "@tscircuit/hypergraph" import { loadSerializedHyperGraph } from "lib/compat/loadSerializedHyperGraph" -import { TinyHyperGraphSolver } from "lib/index" +import { DuplicateCongestedPortSolver, TinyHyperGraphSolver } from "lib/index" import { portChokepointFixture } from "../tests/fixtures/port-chokepoint.fixture" import { Debugger } from "./components/Debugger" const createSolver = (serializedHyperGraph: SerializedHyperGraph) => { - const { topology, problem } = loadSerializedHyperGraph(serializedHyperGraph) + const duplicateCongestedPortSolver = new DuplicateCongestedPortSolver( + serializedHyperGraph, + ) + duplicateCongestedPortSolver.solve() + const { topology, problem } = loadSerializedHyperGraph( + duplicateCongestedPortSolver.getOutput(), + ) + return new TinyHyperGraphSolver(topology, problem, { MAX_ITERATIONS: 20_000, STATIC_REACHABILITY_PRECHECK: false, @@ -18,10 +25,9 @@ export default function PortChokepointPage() {
Port chokepoint repro: connection-a routes from the top left endpoint to the top right endpoint, while connection-b{" "} - routes from the bottom left endpoint to the bottom right endpoint. Both - nets must pass through the single left-center-choke and{" "} - center-right-choke ports, so after one route claims the - corridor the other route has no valid port-disjoint path. + routes from the bottom left endpoint to the bottom right endpoint. The + page first runs DuplicateCongestedPortSolver, then routes + the repaired topology with the strict core solver.
({ + regionId, + pointIds, + d: { + center: { x: centerX, y: centerY }, + width, + height, + }, +}) + +const createPort = ( + portId: string, + region1Id: string, + region2Id: string, + x: number, + y: number, +): SerializedHyperGraph["ports"][number] => ({ + portId, + region1Id, + region2Id, + d: { x, y, z: 0 }, +}) + +const getNumber = (value: unknown) => + typeof value === "number" && Number.isFinite(value) ? value : 0 + +const getPortIndexBySerializedId = ( + topology: ReturnType["topology"], + serializedPortId: string, +) => + topology.portMetadata?.findIndex( + (metadata) => + typeof metadata === "object" && + metadata !== null && + "serializedPortId" in metadata && + metadata.serializedPortId === serializedPortId, + ) ?? -1 + +const createParallelPortFixture = (): SerializedHyperGraph => ({ + regions: [ + createRegion("start", -4, 0, 2, 2, ["start-port"]), + createRegion("left", -1, 0, 4, 6, ["start-port", "middle-a", "middle-b"]), + createRegion("right", 1, 0, 4, 6, ["middle-a", "middle-b", "end-port"]), + createRegion("end", 4, 0, 2, 2, ["end-port"]), + ], + ports: [ + createPort("start-port", "start", "left", -3, 0), + createPort("middle-a", "left", "right", 0, 0), + createPort("middle-b", "left", "right", 0, 2), + createPort("end-port", "right", "end", 3, 0), + ], + connections: [ + { + connectionId: "connection-a", + startRegionId: "start", + endRegionId: "end", + mutuallyConnectedNetworkId: "net-a", + }, + ], +}) + +const createDuplicatePortFixture = (): SerializedHyperGraph => ({ + regions: [ + createRegion("a-start", -4, 0, 2, 2, ["a-start-port"]), + createRegion("b-start", -4, -0.2, 2, 2, ["b-start-port"]), + createRegion("left", -1, 0, 4, 10, [ + "a-start-port", + "b-start-port", + "shared-choke", + "shared-neighbor", + ]), + createRegion("right", 1, 0, 4, 10, [ + "shared-choke", + "shared-neighbor", + "a-end-port", + "b-end-port", + ]), + createRegion("a-end", 4, 0, 2, 2, ["a-end-port"]), + createRegion("b-end", 4, -0.2, 2, 2, ["b-end-port"]), + ], + ports: [ + createPort("a-start-port", "a-start", "left", -3, 0), + createPort("b-start-port", "b-start", "left", -3, -0.2), + createPort("shared-choke", "left", "right", 0, 0), + createPort("shared-neighbor", "left", "right", 0, 4), + createPort("a-end-port", "right", "a-end", 3, 0), + createPort("b-end-port", "right", "b-end", 3, -0.2), + ], + connections: [ + { + connectionId: "connection-a", + startRegionId: "a-start", + endRegionId: "a-end", + mutuallyConnectedNetworkId: "net-a", + }, + { + connectionId: "connection-b", + startRegionId: "b-start", + endRegionId: "b-end", + mutuallyConnectedNetworkId: "net-b", + }, + ], +}) + +test("core solver applies port penalties when choosing an intermediate port", () => { + const { topology, problem } = loadSerializedHyperGraph( + createParallelPortFixture(), + ) + const penalizedPortIndex = getPortIndexBySerializedId(topology, "middle-a") + expect(penalizedPortIndex).toBeGreaterThanOrEqual(0) + + problem.portPenalty = new Float64Array(topology.portCount) + problem.portPenalty[penalizedPortIndex] = 1_000 + + const solver = new TinyHyperGraphSolver(topology, problem, { + RIP_THRESHOLD_RAMP_ATTEMPTS: 0, + STATIC_REACHABILITY_PRECHECK: false, + }) + solver.solve() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect( + solver.getOutput().solvedRoutes?.[0]?.path.map(({ portId }) => portId), + ).toEqual(["start-port", "middle-b", "end-port"]) +}) + +test("duplicate congested port solver duplicates independently reused ports in line with the boundary", () => { + const duplicatePortProximity = 0.2 + const solver = new DuplicateCongestedPortSolver( + createDuplicatePortFixture(), + { + duplicatePortProximity, + }, + ) + + solver.solve() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect(solver.report.portUseCounts["shared-choke"]).toBe(2) + expect(solver.report.duplicatedPorts).toContainEqual({ + sourcePortId: "shared-choke", + duplicatePortIds: ["shared-choke::dup1"], + useCount: 2, + }) + + const output = solver.getOutput() + const sourcePort = output.ports.find( + (port) => port.portId === "shared-choke", + )! + const nearestPort = output.ports.find( + (port) => port.portId === "shared-neighbor", + )! + const duplicatePort = output.ports.find( + (port) => port.portId === "shared-choke::dup1", + )! + + const sourcePoint = { + x: getNumber(sourcePort.d?.x), + y: getNumber(sourcePort.d?.y), + } + const nearestVector = { + x: getNumber(nearestPort.d?.x) - sourcePoint.x, + y: getNumber(nearestPort.d?.y) - sourcePoint.y, + } + const duplicateVector = { + x: getNumber(duplicatePort.d?.x) - sourcePoint.x, + y: getNumber(duplicatePort.d?.y) - sourcePoint.y, + } + const duplicateDistance = Math.hypot(duplicateVector.x, duplicateVector.y) + const crossProduct = + nearestVector.x * duplicateVector.y - nearestVector.y * duplicateVector.x + + expect(duplicatePort.d?.duplicatedFromPortId).toBe("shared-choke") + expect(duplicatePort.d?.duplicateIndex).toBe(1) + expect(duplicateDistance).toBeGreaterThan(0) + expect(duplicateDistance).toBeLessThanOrEqual(duplicatePortProximity) + expect(Math.abs(crossProduct)).toBeLessThan(1e-9) + expect(output.solvedRoutes).toBeUndefined() + expect( + output.regions + .filter( + (region) => region.regionId === "left" || region.regionId === "right", + ) + .every((region) => region.pointIds.includes("shared-choke::dup1")), + ).toBe(true) +}) diff --git a/tests/solver/on-all-routes-routed.test.ts b/tests/solver/on-all-routes-routed.test.ts index 53c6757..16431f3 100644 --- a/tests/solver/on-all-routes-routed.test.ts +++ b/tests/solver/on-all-routes-routed.test.ts @@ -71,6 +71,41 @@ const createTestSolver = ( return new TinyHyperGraphSolver(topology, problem, options) } +const createGreedyFinalRouteTestSolver = ( + options?: ConstructorParameters[2], +) => { + const portCount = 2 + const regionCount = 3 + const topology: TinyHyperGraphTopology = { + portCount, + regionCount, + regionIncidentPorts: [[0], [0, 1], [1]], + incidentPortRegion: [ + [1, 0], + [1, 2], + ], + regionWidth: new Float64Array(regionCount).fill(1), + regionHeight: new Float64Array(regionCount).fill(1), + regionCenterX: new Float64Array(regionCount).fill(0), + regionCenterY: new Float64Array(regionCount).fill(0), + portAngleForRegion1: new Int32Array(portCount), + portAngleForRegion2: new Int32Array(portCount), + portX: new Float64Array([0, 1]), + portY: new Float64Array(portCount), + portZ: new Int32Array(portCount), + } + const problem: TinyHyperGraphProblem = { + routeCount: 1, + portSectionMask: new Int8Array(portCount).fill(1), + routeStartPort: Int32Array.from([0]), + routeEndPort: Int32Array.from([1]), + routeNet: Int32Array.from([0]), + regionNetId: new Int32Array(regionCount).fill(-1), + } + + return new TinyHyperGraphSolver(topology, problem, options) +} + test("completed routing rerips when a region exceeds the current threshold", () => { const solver = createTestSolver() @@ -124,6 +159,81 @@ test("completed routing rerips when a region exceeds the current threshold", () expect(solver.state.goalPortId).toBe(-1) }) +test("completed routing can be accepted as best solution on timeout", () => { + const solver = createTestSolver({ MAX_ITERATIONS: 1 }) + + solver.state.unroutedRoutes = [] + solver.state.portAssignment.set([0, 0, 1, 1]) + solver.state.regionSegments[0] = [[0, 0, 1]] + solver.state.regionSegments[1] = [[1, 2, 3]] + solver.state.regionIntersectionCaches[0] = createRegionCache(0.5) + solver.state.regionIntersectionCaches[1] = createRegionCache(0.1) + + solver.step() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect(solver.stats.acceptedBestSolutionOnTimeout).toBe(true) + expect(solver.stats.bestMaxRegionCost).toBe(0.5) + expect(solver.stats.bestTotalRegionCost).toBe(0.6) + expect(Array.from(solver.state.portAssignment)).toEqual([0, 0, 1, 1]) + expect(solver.state.regionSegments[0]).toEqual([[0, 0, 1]]) + expect(solver.state.regionSegments[1]).toEqual([[1, 2, 3]]) + expect(solver.state.regionIntersectionCaches[0].existingRegionCost).toBe(0.5) + expect(solver.state.regionIntersectionCaches[1].existingRegionCost).toBe(0.1) + expect(solver.state.unroutedRoutes).toEqual([]) +}) + +test("best solution timeout acceptance can be disabled", () => { + const solver = createTestSolver({ + MAX_ITERATIONS: 1, + ACCEPT_BEST_SOLUTION_ON_TIMEOUT: false, + }) + + solver.state.unroutedRoutes = [] + solver.state.portAssignment.set([0, 0, 1, 1]) + solver.state.regionSegments[0] = [[0, 0, 1]] + solver.state.regionSegments[1] = [[1, 2, 3]] + solver.state.regionIntersectionCaches[0] = createRegionCache(0.5) + solver.state.regionIntersectionCaches[1] = createRegionCache(0.1) + + solver.step() + + expect(solver.solved).toBe(false) + expect(solver.failed).toBe(true) + expect(solver.error).toBe("TinyHyperGraphSolver ran out of iterations") + expect(solver.stats.acceptedBestSolutionOnTimeout).toBeUndefined() + expect(Array.from(solver.state.portAssignment)).toEqual([-1, -1, -1, -1]) +}) + +test("final acceptance greedily routes remaining routes when no complete snapshot exists", () => { + const solver = createGreedyFinalRouteTestSolver({ + GREEDY_FINAL_ROUTE_ITERS: 2, + }) + + solver.tryFinalAcceptance() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect(solver.stats.acceptedGreedyFinalRouteOnTimeout).toBe(true) + expect(solver.stats.greedyFinalRouteRemainingRouteCount).toBe(1) + expect(solver.state.unroutedRoutes).toEqual([]) + expect(Array.from(solver.state.portAssignment)).toEqual([0, 0]) + expect(solver.state.regionSegments[1]).toEqual([[0, 0, 1]]) +}) + +test("greedy final routing can be disabled independently", () => { + const solver = createGreedyFinalRouteTestSolver({ + GREEDY_FINAL_ROUTE_ITERS: 0, + }) + + solver.tryFinalAcceptance() + + expect(solver.solved).toBe(false) + expect(solver.failed).toBe(false) + expect(solver.stats.acceptedGreedyFinalRouteOnTimeout).toBeUndefined() +}) + test("completed routing is accepted once all region costs are under the threshold", () => { const solver = createTestSolver() @@ -147,6 +257,8 @@ test("constructor options override snake-case hyperparameters before setup", () RIP_THRESHOLD_RAMP_ATTEMPTS: 7, RIP_CONGESTION_REGION_COST_FACTOR: 0.45, MAX_ITERATIONS: 1234, + ACCEPT_BEST_SOLUTION_ON_TIMEOUT: false, + GREEDY_FINAL_ROUTE_ITERS: 6, }) expect(solver.DISTANCE_TO_COST).toBe(0.25) @@ -155,5 +267,7 @@ test("constructor options override snake-case hyperparameters before setup", () expect(solver.RIP_THRESHOLD_RAMP_ATTEMPTS).toBe(7) expect(solver.RIP_CONGESTION_REGION_COST_FACTOR).toBe(0.45) expect(solver.MAX_ITERATIONS).toBe(1234) + expect(solver.ACCEPT_BEST_SOLUTION_ON_TIMEOUT).toBe(false) + expect(solver.GREEDY_FINAL_ROUTE_ITERS).toBe(6) expect(solver.problemSetup.portHCostToEndOfRoute[0]).toBe(0.25) }) diff --git a/tests/solver/section-solver.test.ts b/tests/solver/section-solver.test.ts index 67b36a0..cd53248 100644 --- a/tests/solver/section-solver.test.ts +++ b/tests/solver/section-solver.test.ts @@ -215,6 +215,26 @@ test("section solver enforces section-specific rip thresholds and max rip cap", ).toBe(2) }) +test("section solver final acceptance falls back to the baseline solution", () => { + const { topology, problem, solution } = loadSerializedHyperGraph( + sectionSolverFixtureGraph, + ) + problem.portSectionMask = createSectionSolverFixturePortMask(topology) + + const sectionSolver = new TinyHyperGraphSectionSolver( + topology, + problem, + solution, + ) + + sectionSolver.tryFinalAcceptance() + + expect(sectionSolver.solved).toBe(true) + expect(sectionSolver.failed).toBe(false) + expect(sectionSolver.stats.sectionSolverTimeoutFallbackToBaseline).toBe(true) + expect(sectionSolver.getSolvedSolver()).toBe(sectionSolver.baselineSolver) +}) + test("section pipeline visualize renders the input graph at iteration zero", () => { const pipelineSolver = new TinyHyperGraphSectionPipelineSolver({ serializedHyperGraph: sectionSolverFixtureGraph, @@ -246,6 +266,36 @@ test("section pipeline visualize infers z-layer labels from incident ports", () expect(startPortCircle?.layer).toBe("z0") }) +test("section pipeline uses bounded default iteration limits", () => { + const pipelineSolver = new TinyHyperGraphSectionPipelineSolver({ + serializedHyperGraph: datasetHg07.sample029, + }) + + expect(pipelineSolver.MAX_ITERATIONS).toBe(200_000) + expect(pipelineSolver.getSectionSolverOptions().MAX_ITERATIONS).toBe(50_000) +}) + +test("section pipeline final acceptance falls back to solveGraph output", () => { + const pipelineSolver = new TinyHyperGraphSectionPipelineSolver({ + serializedHyperGraph: sectionSolverFixtureGraph, + }) + const pipelineOutputs = pipelineSolver.getAllOutputs() as Record< + string, + ReturnType + > + pipelineOutputs.solveGraph = sectionSolverFixtureGraph + pipelineSolver.pipelineOutputs = pipelineOutputs + + pipelineSolver.tryFinalAcceptance() + + expect(pipelineSolver.solved).toBe(true) + expect(pipelineSolver.failed).toBe(false) + expect( + pipelineSolver.stats.acceptedSolveGraphOutputOnSectionPipelineTimeout, + ).toBe(true) + expect(pipelineSolver.getOutput()).toBe(sectionSolverFixtureGraph) +}) + test("section pipeline searches multiple masks and commits an improving output on hg07 sample029", () => { const pipelineSolver = new TinyHyperGraphSectionPipelineSolver({ serializedHyperGraph: datasetHg07.sample029,