diff --git a/src/host/hostMain.ts b/src/host/hostMain.ts index 663bfcb..c4c9385 100644 --- a/src/host/hostMain.ts +++ b/src/host/hostMain.ts @@ -49,7 +49,14 @@ import { sessionDir, socketPath, } from '../storage/sessionPaths.js'; +import { + addAbortListener, + makeAbortReason, + throwIfAborted, + waitForScopedOperation, +} from '../util/abort.js'; import { invariant } from '../util/assert.js'; +import { ResourceScope } from '../util/resourceScope.js'; const ALLOWED_SIGNALS = [ 'SIGTERM', @@ -169,7 +176,8 @@ export async function runHost(sessionId: string): Promise { let ptyHasExited = false; let lastOutputAt = Date.now(); let lastActivityAt = lastOutputAt; - let idleTimeoutHandle: ReturnType | null = null; + const hostAbortController = new AbortController(); + let idleTimeoutScope: ResourceScope | null = null; let rpcListenPromise: Promise | null = null; let shutdownPromise: Promise | null = null; let markPtyExited: () => void = () => { @@ -184,6 +192,10 @@ export async function runHost(sessionId: string): Promise { }); let ptyIngestionQueue: Promise = Promise.resolve(); + // Per-client wait-exit callbacks, cleaned up individually via ResourceScope. + // Using ptyExitPromise.then() would permanently attach to the shared promise. + const ptyExitWaiters = new Set<() => void>(); + const ptyExitPromise = new Promise((resolve) => { markPtyExited = (): void => { if (ptyHasExited) { @@ -192,6 +204,11 @@ export async function runHost(sessionId: string): Promise { ptyHasExited = true; resolve(); + const waiters = [...ptyExitWaiters]; + ptyExitWaiters.clear(); + for (const waiter of waiters) { + waiter(); + } }; }); @@ -250,25 +267,36 @@ export async function runHost(sessionId: string): Promise { const clearIdleTimeout = (): void => { // Idempotent: safe to call multiple times during shutdown and PTY exit. - if (idleTimeoutHandle === null) { + const scope = idleTimeoutScope; + idleTimeoutScope = null; + if (scope === null) { return; } - clearInterval(idleTimeoutHandle); - idleTimeoutHandle = null; + void scope.close().catch(rethrowAsync); }; const startIdlePolling = (): void => { if ( idleTimeoutMs <= 0 || !isSessionCommandable(state) || - idleTimeoutHandle !== null + idleTimeoutScope !== null ) { return; } + throwIfAborted(hostAbortController.signal); + + const scope = new ResourceScope(); + idleTimeoutScope = scope; const checkIntervalMs = Math.min(idleTimeoutMs, IDLE_CHECK_CAP_MS); + let idleTimeoutHandle: ReturnType | null = null; idleTimeoutHandle = setInterval(() => { + if (hostAbortController.signal.aborted) { + clearIdleTimeout(); + return; + } + if (!isSessionCommandable(state)) { clearIdleTimeout(); return; @@ -280,6 +308,20 @@ export async function runHost(sessionId: string): Promise { pty.kill(); } }, checkIntervalMs); + scope.add('idle timeout interval', () => { + if (idleTimeoutHandle !== null) { + clearInterval(idleTimeoutHandle); + idleTimeoutHandle = null; + } + }); + addAbortListener( + scope, + 'idle timeout abort listener', + hostAbortController.signal, + () => { + clearIdleTimeout(); + }, + ); }; const initiateShutdown = (): Promise => { @@ -289,6 +331,7 @@ export async function runHost(sessionId: string): Promise { shutdownPromise = (async () => { try { + hostAbortController.abort(makeAbortReason('Host is shutting down.')); clearIdleTimeout(); if (isSessionCommandable(state)) { pty.kill(); @@ -357,6 +400,15 @@ export async function runHost(sessionId: string): Promise { })().catch(rethrowAsync); }; + const makeWaitExitOutcome = (): WaitOutcome => { + const snapshot = state.snapshot(); + const result: WaitOutcome = { timedOut: false }; + if (snapshot.exitCode !== null) { + result.exitCode = snapshot.exitCode; + } + return result; + }; + const handlers: Record = { inspect: () => Promise.resolve({ session: state.snapshot() }), snapshot: async (params: unknown) => { @@ -478,9 +530,11 @@ export async function runHost(sessionId: string): Promise { await eventLog.append('input_paste', { data: encoded }); return {}; }, - run: async (params: unknown) => { + run: async (params: unknown, context) => { const { command, noWait, timeoutMs } = params as RunParams; + const { signal } = context; + throwIfAborted(signal); assertSessionCommandable(state); invariant( @@ -527,7 +581,7 @@ export async function runHost(sessionId: string): Promise { pty.write(injectedText); lastActivityAt = Date.now(); - const waitResult = await completion.wait(effectiveTimeoutMs); + const waitResult = await completion.wait(effectiveTimeoutMs, { signal }); const durationMs = Date.now() - startTime; if (waitResult.kind === 'completed') { @@ -656,11 +710,13 @@ export async function runHost(sessionId: string): Promise { await eventLog.append('signal', { signal }); return {}; }, - wait: async (params: unknown) => { + wait: async (params: unknown, context) => { const { exit, idleMs, timeoutMs } = params as WaitParams; + const { signal } = context; const hasExit = exit === true; const hasIdle = idleMs !== undefined; + throwIfAborted(signal); if (hasExit === hasIdle) { throw makeCliError(ERROR_CODES.INVALID_DURATION, { message: 'Specify exactly one of exit or idleMs.', @@ -680,26 +736,22 @@ export async function runHost(sessionId: string): Promise { ); } + const waitScope = new ResourceScope(); let waitCondition: Promise; - let clearWaitCondition: (() => void) | null = null; if (hasExit) { if (ptyHasExited) { - const snapshot = state.snapshot(); - const result: WaitOutcome = { timedOut: false }; - if (snapshot.exitCode !== null) { - result.exitCode = snapshot.exitCode; - } - return result; + return makeWaitExitOutcome(); } - waitCondition = ptyExitPromise.then(() => { - const snapshot = state.snapshot(); - const result: WaitOutcome = { timedOut: false }; - if (snapshot.exitCode !== null) { - result.exitCode = snapshot.exitCode; - } - return result; + waitCondition = new Promise((resolve) => { + const waiter = (): void => { + resolve(makeWaitExitOutcome()); + }; + ptyExitWaiters.add(waiter); + waitScope.add('wait exit waiter', () => { + ptyExitWaiters.delete(waiter); + }); }); } else { assertSessionCommandable(state); @@ -714,13 +766,16 @@ export async function runHost(sessionId: string): Promise { waitCondition = new Promise((resolve) => { const checkInterval = setInterval( () => { + if (signal.aborted) { + return; + } + const effectiveLastOutput = Math.max(lastOutputAt, idleAnchor); const elapsed = Date.now() - effectiveLastOutput; if (elapsed < idleDuration) { return; } - clearInterval(checkInterval); const snapshot = state.snapshot(); const result: WaitOutcome = { timedOut: false }; if (snapshot.exitCode !== null) { @@ -731,30 +786,26 @@ export async function runHost(sessionId: string): Promise { Math.min(idleDuration / 2, 100), ); - clearWaitCondition = (): void => { + waitScope.add('wait idle poll interval', () => { clearInterval(checkInterval); - }; + }); }); } - if (timeoutMs === undefined) { - return await waitCondition; - } - - return await new Promise((resolve) => { - const timeoutHandle = setTimeout(() => { - clearWaitCondition?.(); - resolve({ timedOut: true }); - }, timeoutMs); - - void waitCondition.then((result) => { - clearTimeout(timeoutHandle); - clearWaitCondition?.(); - resolve(result); - }); + return await waitForScopedOperation({ + operationName: 'wait', + operation: waitCondition, + scope: waitScope, + signal, + ...(timeoutMs === undefined + ? {} + : { + timeoutMs, + timeoutResult: () => ({ timedOut: true }), + }), }); }, - waitForRender: async (params: unknown) => { + waitForRender: async (params: unknown, context) => { const { text, regex, @@ -764,7 +815,9 @@ export async function runHost(sessionId: string): Promise { timeoutMs, rendererName: requestedRendererName, } = params as WaitForRenderParams; + const { signal } = context; + throwIfAborted(signal); const preparedCondition = prepareRenderWaitCondition({ text, regex, @@ -782,30 +835,33 @@ export async function runHost(sessionId: string): Promise { const rendererName = resolveHostRendererName(requestedRendererName); const profile = resolveProfile(DEFAULT_RENDER_PROFILE_NAME); const pollIntervalMs = 200; + const waitScope = new ResourceScope(); let lastVisibleText: string | undefined; let lastTextChangeAt = Date.now(); let latestCapturedAtSeq = 0; - let clearWaitPoll: (() => void) | null = null; const pollCondition = new Promise((resolve) => { let pollInFlight = false; let consecutiveFailures = 0; const checkInterval = setInterval(() => { - if (pollInFlight) { + if (signal.aborted || pollInFlight) { return; } pollInFlight = true; void (async () => { try { + throwIfAborted(signal); const replayInput = loadReplayInput(); const backend = await rendererManager.getBackend( rendererName, profile, replayInput, ); + throwIfAborted(signal); const snapshot = await backend.snapshot(); + throwIfAborted(signal); const visibleText = snapshot.visibleLines .map((line) => line.text) .join('\n'); @@ -831,7 +887,6 @@ export async function runHost(sessionId: string): Promise { ); if (match.matched) { - clearInterval(checkInterval); resolve({ matched: true, timedOut: false, @@ -844,10 +899,13 @@ export async function runHost(sessionId: string): Promise { }); } } catch (pollError) { + if (signal.aborted) { + return; + } + void pollError; consecutiveFailures += 1; if (consecutiveFailures >= MAX_CONSECUTIVE_POLL_FAILURES) { - clearInterval(checkInterval); resolve({ matched: false, timedOut: true, @@ -862,47 +920,35 @@ export async function runHost(sessionId: string): Promise { })(); }, pollIntervalMs); - clearWaitPoll = (): void => { + waitScope.add('waitForRender poll interval', () => { clearInterval(checkInterval); - }; + }); }); - if (timeoutMs === undefined) { - return await pollCondition; - } - - return await new Promise((resolve) => { - let resolved = false; - const timeoutHandle = setTimeout(() => { - if (resolved) { - return; - } - resolved = true; - clearWaitPoll?.(); - - try { - const replayInput = loadReplayInput(); - latestCapturedAtSeq = replayInput?.targetSeq ?? 0; - } catch { - // Best-effort snapshot for timeout reporting. - } - - resolve({ - matched: false, - timedOut: true, - capturedAtSeq: latestCapturedAtSeq, - }); - }, timeoutMs); - - void pollCondition.then((result) => { - if (resolved) { - return; - } - resolved = true; - clearTimeout(timeoutHandle); - clearWaitPoll?.(); - resolve(result); - }); + return await waitForScopedOperation({ + operationName: 'waitForRender', + operation: pollCondition, + scope: waitScope, + signal, + ...(timeoutMs === undefined + ? {} + : { + timeoutMs, + timeoutResult: () => { + try { + const replayInput = loadReplayInput(); + latestCapturedAtSeq = replayInput?.targetSeq ?? 0; + } catch { + // Best-effort snapshot for timeout reporting. + } + + return { + matched: false, + timedOut: true, + capturedAtSeq: latestCapturedAtSeq, + }; + }, + }), }); }, destroy: () => { diff --git a/src/host/lifecycle.ts b/src/host/lifecycle.ts index 1b423c7..8d8e70b 100644 --- a/src/host/lifecycle.ts +++ b/src/host/lifecycle.ts @@ -25,12 +25,41 @@ import { sessionDir, socketPath, } from '../storage/sessionPaths.js'; +import { makeAbortError, throwIfAborted } from '../util/abort.js'; import { invariant } from '../util/assert.js'; import { sendRpc } from './rpcClient.js'; const DESTROY_POLL_INTERVAL_MS = 100; const DESTROY_MAX_ATTEMPTS = 50; +interface PollOptions { + readonly signal?: AbortSignal; +} + +function pollOptions(signal?: AbortSignal): PollOptions { + return signal === undefined ? {} : { signal }; +} + +function delayOptions( + signal?: AbortSignal, +): { signal: AbortSignal } | undefined { + return signal === undefined ? undefined : { signal }; +} + +async function pollDelay( + intervalMs: number, + signal?: AbortSignal, +): Promise { + try { + await delay(intervalMs, undefined, delayOptions(signal)); + } catch (error) { + if (signal?.aborted === true) { + throw makeAbortError(signal); + } + throw error; + } +} + interface NodeError extends Error { code?: string; } @@ -246,6 +275,7 @@ async function waitForTerminalManifest( manifestFile: string, maxAttempts: number = DESTROY_MAX_ATTEMPTS, intervalMs: number = DESTROY_POLL_INTERVAL_MS, + options: PollOptions = {}, ): Promise { invariant( Number.isInteger(maxAttempts) && maxAttempts > 0, @@ -256,7 +286,11 @@ async function waitForTerminalManifest( 'intervalMs must be a non-negative integer', ); + const { signal } = options; + throwIfAborted(signal); + for (let attempt = 0; attempt < maxAttempts; attempt += 1) { + throwIfAborted(signal); const manifest = await readManifest(manifestFile); if (isTerminalSessionStatus(manifest.status)) { @@ -264,7 +298,7 @@ async function waitForTerminalManifest( } if (attempt + 1 < maxAttempts) { - await delay(intervalMs); + await pollDelay(intervalMs, signal); } } @@ -277,6 +311,7 @@ async function waitForProcessAndSocketShutdown( socketFile: string, maxAttempts: number = DESTROY_MAX_ATTEMPTS, intervalMs: number = DESTROY_POLL_INTERVAL_MS, + options: PollOptions = {}, ): Promise { invariant( Number.isInteger(maxAttempts) && maxAttempts > 0, @@ -287,7 +322,11 @@ async function waitForProcessAndSocketShutdown( 'intervalMs must be a non-negative integer', ); + const { signal } = options; + throwIfAborted(signal); + for (let attempt = 0; attempt < maxAttempts; attempt += 1) { + throwIfAborted(signal); const hostAlive = isProcessAlive(hostPid); const childAlive = isProcessAlive(childPid); const socketPresent = await pathExists(socketFile); @@ -297,7 +336,7 @@ async function waitForProcessAndSocketShutdown( } if (attempt + 1 < maxAttempts) { - await delay(intervalMs); + await pollDelay(intervalMs, signal); } } @@ -458,10 +497,17 @@ export function launchHost(config: LaunchHostConfig): number { return child.pid; } +export interface DestroySessionOptions { + readonly signal?: AbortSignal; +} + export async function destroySession( sessionId: string, force?: boolean, + options: DestroySessionOptions = {}, ): Promise { + const { signal } = options; + throwIfAborted(signal); const { sessionDirectory, manifestFile, socketFile } = getSessionPaths(sessionId); const manifest = await readSessionManifestOrThrow(sessionId, manifestFile); @@ -493,6 +539,9 @@ export async function destroySession( manifest.hostPid, manifest.childPid, socketFile, + DESTROY_MAX_ATTEMPTS, + DESTROY_POLL_INTERVAL_MS, + pollOptions(signal), ); await reconcileSession(sessionDirectory); @@ -514,7 +563,8 @@ export async function destroySession( } try { - await sendRpc(socketFile, 'destroy'); + throwIfAborted(signal); + await sendRpc(socketFile, 'destroy', undefined, undefined, signal); } catch (error) { if ( !(error instanceof CliError) || @@ -535,7 +585,12 @@ export async function destroySession( throw error; } - const terminalManifest = await waitForTerminalManifest(manifestFile); + const terminalManifest = await waitForTerminalManifest( + manifestFile, + DESTROY_MAX_ATTEMPTS, + DESTROY_POLL_INTERVAL_MS, + pollOptions(signal), + ); if (terminalManifest !== null) { return; } diff --git a/src/host/rpcClient.ts b/src/host/rpcClient.ts index e1bab2f..50d27e9 100644 --- a/src/host/rpcClient.ts +++ b/src/host/rpcClient.ts @@ -13,7 +13,14 @@ import { RpcResponseSchema, type RpcMethod, } from '../protocol/messages.js'; +import { + addAbortListener, + createResourceScopedSettlers, + makeAbortError, + throwIfAborted, +} from '../util/abort.js'; import { invariant } from '../util/assert.js'; +import { ResourceScope } from '../util/resourceScope.js'; const DEFAULT_TIMEOUT_MS = 5_000; const MAX_RPC_BUFFER_BYTES = 1_048_576; @@ -95,7 +102,9 @@ export async function sendRpc( method: string, params?: Record, timeoutMs?: number, + signal?: AbortSignal, ): Promise { + throwIfAborted(signal); const effectiveTimeoutMs = timeoutMs ?? DEFAULT_TIMEOUT_MS; invariant( Number.isFinite(effectiveTimeoutMs) && effectiveTimeoutMs >= 0, @@ -116,35 +125,25 @@ export async function sendRpc( return await new Promise((resolve, reject) => { const socket = net.connect({ path: socketPath }); - let settled = false; + const scope = new ResourceScope(); + scope.add('rpc client socket', () => { + socket.destroy(); + }); + const settlers = createResourceScopedSettlers(scope, resolve, reject); let responseHandled = false; let buffer = ''; - const rejectWithCliError = (error: CliError): void => { - if (settled) { - return; - } - - settled = true; - socket.destroy(); - reject(error); - }; - const rejectWithTransportError = (error: unknown): void => { - rejectWithCliError( + settlers.reject( toTransportCliError(error, socketPath, method, effectiveTimeoutMs), ); }; - const resolveWithResult = (result: unknown): void => { - if (settled) { - return; - } - - settled = true; - socket.destroy(); - resolve(result); - }; + if (signal !== undefined) { + addAbortListener(scope, 'rpc client abort listener', signal, () => { + settlers.reject(makeAbortError(signal)); + }); + } socket.setEncoding('utf8'); socket.setTimeout(effectiveTimeoutMs); @@ -154,7 +153,7 @@ export async function sendRpc( }); socket.on('timeout', () => { - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.HOST_TIMEOUT, { message: `RPC request timed out after ${String(effectiveTimeoutMs)}ms.`, details: { @@ -176,7 +175,7 @@ export async function sendRpc( } if (buffer.length + chunk.length > MAX_RPC_BUFFER_BYTES) { - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.RPC_ERROR, { message: 'RPC response exceeds maximum buffer size.', details: { method, socketPath }, @@ -200,7 +199,7 @@ export async function sendRpc( const responseResult = RpcResponseSchema.safeParse(rawResponse); if (!responseResult.success) { - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.RPC_ERROR, { message: 'RPC response failed schema validation.', details: { @@ -216,7 +215,7 @@ export async function sendRpc( const response = responseResult.data; if (response.id !== request.id) { - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.RPC_ERROR, { message: `RPC response id mismatch for method "${method}".`, details: { @@ -237,7 +236,7 @@ export async function sendRpc( ); if (!resultResult.success) { - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.RPC_ERROR, { message: `RPC result failed validation for method "${method}".`, details: { @@ -250,19 +249,19 @@ export async function sendRpc( return; } - resolveWithResult(resultResult.data); + settlers.resolve(resultResult.data); return; } - resolveWithResult(response.result); + settlers.resolve(response.result); return; } - rejectWithCliError( + settlers.reject( toResponseCliError(response.error.code, response.error.message), ); } catch (error) { - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.RPC_ERROR, { message: toErrorMessage( error, @@ -279,11 +278,11 @@ export async function sendRpc( }); socket.on('end', () => { - if (settled || responseHandled) { + if (settlers.isSettled() || responseHandled) { return; } - rejectWithCliError( + settlers.reject( makeCliError(ERROR_CODES.RPC_ERROR, { message: `RPC connection closed before a complete response was received for method "${method}".`, details: { diff --git a/src/host/rpcServer.ts b/src/host/rpcServer.ts index 873a523..04831c0 100644 --- a/src/host/rpcServer.ts +++ b/src/host/rpcServer.ts @@ -10,14 +10,33 @@ import { type RpcMethod, type RpcResponse, } from '../protocol/messages.js'; +import { + createResourceScopedSettlers, + makeAbortReason, +} from '../util/abort.js'; import { invariant } from '../util/assert.js'; +import { ResourceScope } from '../util/resourceScope.js'; const MAX_UNIX_SOCKET_PATH = 104; const MAX_RPC_BUFFER_BYTES = 1_048_576; +const SOCKET_LIVENESS_PROBE_TIMEOUT_MS = 1_000; const UNKNOWN_REQUEST_ID = 'unknown'; -export type MethodHandler = (params: unknown) => Promise; +/** + * Per-request context passed to RPC method handlers. + * + * `signal` aborts when the client socket closes, indicating the caller is no + * longer waiting for a response. + */ +export interface MethodContext { + readonly signal: AbortSignal; +} + +export type MethodHandler = ( + params: unknown, + context: MethodContext, +) => Promise; function isKnownRpcMethod(method: string): method is RpcMethod { return Object.hasOwn(RpcMethodSchemas, method); @@ -46,22 +65,33 @@ async function socketPathExists(socketPath: string): Promise { async function probeSocketLiveness(socketPath: string): Promise { return await new Promise((resolve, reject) => { + const scope = new ResourceScope(); const probe = net.connect({ path: socketPath }); + const settlers = createResourceScopedSettlers(scope, resolve, reject); + + scope.add('rpc liveness probe socket', () => { + probe.destroy(); + }); + const timeoutHandle = setTimeout(() => { + // If connect neither succeeds nor fails promptly, treat the socket path as + // stale rather than blocking host startup indefinitely. + settlers.resolve(false); + }, SOCKET_LIVENESS_PROBE_TIMEOUT_MS); + scope.add('rpc liveness probe timeout', () => { + clearTimeout(timeoutHandle); + }); probe.once('connect', () => { - probe.end(); - resolve(true); + settlers.resolve(true); }); probe.once('error', (error: NodeJS.ErrnoException) => { - probe.destroy(); - if (error.code === 'ECONNREFUSED' || error.code === 'ENOENT') { - resolve(false); + settlers.resolve(false); return; } - reject(error); + settlers.reject(error); }); }); } @@ -368,8 +398,26 @@ export class RpcServer { return; } + const requestScope = new ResourceScope(); + const requestAbortController = new AbortController(); + const abortRequest = (): void => { + requestAbortController.abort( + makeAbortReason('RPC client socket closed.'), + ); + }; + socket.once('close', abortRequest); + requestScope.add('rpc request close listener', () => { + socket.off('close', abortRequest); + }); + try { - const result = await handler(paramsResult.data); + const result = await handler(paramsResult.data, { + signal: requestAbortController.signal, + }); + if (requestAbortController.signal.aborted) { + return; + } + const resultResult = RpcMethodSchemas[request.method].result.safeParse(result); @@ -386,6 +434,10 @@ export class RpcServer { buildSuccessResponse(request.id, resultResult.data), ); } catch (error) { + if (requestAbortController.signal.aborted) { + return; + } + this.sendResponse( socket, error instanceof CliError @@ -398,6 +450,12 @@ export class RpcServer { ), ), ); + } finally { + try { + await requestScope.close(); + } catch (error) { + console.debug('RPC request ResourceScope cleanup failed:', error); + } } } diff --git a/src/host/runCompletionCoordinator.ts b/src/host/runCompletionCoordinator.ts index fb2f30e..8e38594 100644 --- a/src/host/runCompletionCoordinator.ts +++ b/src/host/runCompletionCoordinator.ts @@ -8,7 +8,9 @@ import { RunCompletionSentinelScanner, type SentinelPiece, } from './runCompletionSentinel.js'; +import { waitForScopedOperation } from '../util/abort.js'; import { invariant } from '../util/assert.js'; +import { ResourceScope } from '../util/resourceScope.js'; const RUN_COMPLETION_POSTAMBLE_ECHO_PREFIX = String.raw`printf '\033\137`; const RUN_COMPLETION_SIGNAL_TOKEN_BYTES = 4; @@ -48,11 +50,18 @@ export interface PreparedWaitedRun { marker: string; } +export interface RunCompletionWaitOptions { + readonly signal?: AbortSignal; +} + /** Registered completion state returned after `input_run` appends successfully. */ export interface RegisteredWaitedRunCompletion { postamble: string; sentinel: string; - wait(timeoutMs: number): Promise; + wait( + timeoutMs: number, + options?: RunCompletionWaitOptions, + ): Promise; } function shellOctalEscapedBytes(value: string): string { @@ -190,13 +199,21 @@ export class RunCompletionCoordinator { return { postamble, sentinel, - wait: (timeoutMs: number): Promise => { + wait: ( + timeoutMs: number, + options: RunCompletionWaitOptions = {}, + ): Promise => { invariant( !waitStarted, 'run completion wait must only be started once', ); waitStarted = true; - return this.#waitForRunCompletion(marker, completionPromise, timeoutMs); + return this.#waitForRunCompletion( + marker, + completionPromise, + timeoutMs, + options, + ); }, }; } @@ -255,6 +272,7 @@ export class RunCompletionCoordinator { marker: string, completionPromise: Promise, timeoutMs: number, + options: RunCompletionWaitOptions, ): Promise { assertRunMarker(marker); invariant( @@ -262,43 +280,25 @@ export class RunCompletionCoordinator { 'timeoutMs must be a positive integer', ); - const { promise, reject, resolve } = - Promise.withResolvers(); - let resolved = false; - const timeoutHandle = setTimeout(() => { - if (resolved) { - return; - } - - resolved = true; - // Keep sentinel/postamble registrations active after timeout so the - // eventual internal completion bytes are still hidden from artifacts. + const forgetWaiter = (): void => { + // Match timeout behavior: stop waiting for a client response but keep + // sentinel/postamble registrations active so eventual completion bytes + // remain hidden and replayable. this.#runCompletionWaiters.delete(marker); - resolve({ kind: 'timeout' }); - }, timeoutMs); - - void completionPromise.then( - (result) => { - if (resolved) { - return; - } - - resolved = true; - clearTimeout(timeoutHandle); - resolve(result); - }, - (error: unknown) => { - if (resolved) { - return; - } + }; - resolved = true; - clearTimeout(timeoutHandle); - reject(error instanceof Error ? error : new Error(String(error))); + return await waitForScopedOperation({ + operationName: 'run completion', + operation: completionPromise, + scope: new ResourceScope(), + signal: options.signal, + timeoutMs, + timeoutResult: () => { + forgetWaiter(); + return { kind: 'timeout' }; }, - ); - - return await promise; + onAbort: forgetWaiter, + }); } async #appendOutput(data: string): Promise { diff --git a/src/util/abort.ts b/src/util/abort.ts new file mode 100644 index 0000000..b523edd --- /dev/null +++ b/src/util/abort.ts @@ -0,0 +1,209 @@ +import { invariant } from './assert.js'; +import type { ResourceScope } from './resourceScope.js'; + +export interface ResourceScopedSettlers { + readonly isSettled: () => boolean; + readonly reject: (error: unknown) => void; + readonly resolve: (value: T) => void; +} + +type TimeoutConfig = + | { + readonly timeoutMs: number; + readonly timeoutResult: () => T; + } + | { + readonly timeoutMs?: never; + readonly timeoutResult?: never; + }; + +export type ScopedOperationOptions = { + readonly operationName: string; + readonly operation: Promise; + readonly scope: ResourceScope; + readonly signal?: AbortSignal | undefined; + readonly onAbort?: () => void; +} & TimeoutConfig; + +function toError(error: unknown): Error { + return error instanceof Error ? error : new Error(String(error)); +} + +function makeScopeCloseRejectionError( + originalError: unknown, + closeError: unknown, +): AggregateError { + return new AggregateError( + [toError(originalError), toError(closeError)], + 'ResourceScope close failed while rejecting operation.', + ); +} + +/** Creates a specific AbortError reason to pass into `AbortController.abort()`. */ +export function makeAbortReason(message: string): Error { + invariant(message.length > 0, 'abort reason message must not be empty'); + const error = new Error(message); + error.name = 'AbortError'; + return error; +} + +/** Extracts an AbortError from an observed signal, preserving `signal.reason`. */ +export function makeAbortError(signal?: AbortSignal): Error { + const reason: unknown = signal?.reason; + if (reason instanceof Error) { + return reason; + } + + const error = new Error( + typeof reason === 'string' && reason.length > 0 + ? reason + : 'Operation aborted.', + ); + error.name = 'AbortError'; + return error; +} + +/** Throws if `signal` is aborted; no-op when `signal` is undefined. */ +export function throwIfAborted(signal?: AbortSignal): void { + if (signal?.aborted === true) { + throw makeAbortError(signal); + } +} + +/** + * Registers an abort listener and removes it when `scope` closes. + * + * The signal must not already be aborted; callers should check + * `signal.aborted` or call `throwIfAborted()` before registering. + */ +export function addAbortListener( + scope: ResourceScope, + name: string, + signal: AbortSignal, + listener: () => void, +): void { + invariant( + !signal.aborted, + 'abort listener must be registered before signal aborts', + ); + invariant( + typeof listener === 'function', + 'abort listener must be a function', + ); + + signal.addEventListener('abort', listener, { once: true }); + scope.add(name, () => { + signal.removeEventListener('abort', listener); + }); +} + +/** + * Creates idempotent Promise settlers that close `scope` before resolving or + * rejecting the outer operation. If cleanup fails while resolving, the promise + * rejects with the cleanup error. If cleanup fails while rejecting, the original + * operation error is preserved alongside the cleanup failure in an + * `AggregateError`. + */ +export function createResourceScopedSettlers( + scope: ResourceScope, + resolve: (value: T) => void, + reject: (error: Error) => void, +): ResourceScopedSettlers { + invariant(typeof resolve === 'function', 'resolve must be a function'); + invariant(typeof reject === 'function', 'reject must be a function'); + + let settled = false; + + return { + isSettled: () => settled, + reject: (error: unknown): void => { + if (settled) { + return; + } + + settled = true; + void scope.close().then( + () => { + reject(toError(error)); + }, + (closeError: unknown) => { + reject(makeScopeCloseRejectionError(error, closeError)); + }, + ); + }, + resolve: (value: T): void => { + if (settled) { + return; + } + + settled = true; + void scope.close().then( + () => { + resolve(value); + }, + (closeError: unknown) => { + reject(toError(closeError)); + }, + ); + }, + }; +} + +function runAbortCallback(onAbort: (() => void) | undefined): Error | null { + try { + onAbort?.(); + return null; + } catch (error) { + return toError(error); + } +} + +/** + * Waits for `operation`, an optional timeout, or an optional abort signal while + * tying all timers/listeners to `scope`. The scope closes before the returned + * promise settles. `timeoutResult` is evaluated lazily when the timeout wins, + * and `onAbort` runs before scope cleanup for both pre-aborted and later-aborted + * signals. + */ +export async function waitForScopedOperation( + options: ScopedOperationOptions, +): Promise { + const { + operationName, + operation, + scope, + signal, + timeoutMs, + timeoutResult, + onAbort, + } = options; + invariant(operationName.length > 0, 'operationName must not be empty'); + if (signal?.aborted === true) { + const abortCallbackError = runAbortCallback(onAbort); + await scope.close(); + throw abortCallbackError ?? makeAbortError(signal); + } + + const { promise, reject, resolve } = Promise.withResolvers(); + const settlers = createResourceScopedSettlers(scope, resolve, reject); + + if (timeoutMs !== undefined) { + const timeoutHandle = setTimeout(() => { + settlers.resolve(timeoutResult()); + }, timeoutMs); + scope.add(`${operationName} timeout`, () => { + clearTimeout(timeoutHandle); + }); + } + + if (signal !== undefined) { + addAbortListener(scope, `${operationName} abort listener`, signal, () => { + const abortCallbackError = runAbortCallback(onAbort); + settlers.reject(abortCallbackError ?? makeAbortError(signal)); + }); + } + + void operation.then(settlers.resolve, settlers.reject); + + return await promise; +} diff --git a/test/unit/host/rpcClient.test.ts b/test/unit/host/rpcClient.test.ts new file mode 100644 index 0000000..58389b6 --- /dev/null +++ b/test/unit/host/rpcClient.test.ts @@ -0,0 +1,86 @@ +import { mkdtemp, rm } from 'node:fs/promises'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import net from 'node:net'; + +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { sendRpc } from '../../../src/host/rpcClient.js'; + +let tempDir = ''; + +describe('sendRpc abort handling', () => { + beforeEach(async () => { + tempDir = await mkdtemp(join(tmpdir(), 'agent-tty-rpc-client-')); + }); + + afterEach(async () => { + await rm(tempDir, { recursive: true, force: true }); + tempDir = ''; + }); + + it('rejects before opening a socket when the signal is already aborted', async () => { + const controller = new AbortController(); + const reason = new Error('already cancelled'); + controller.abort(reason); + + await expect( + sendRpc( + join(tempDir, 'missing.sock'), + 'inspect', + {}, + 5_000, + controller.signal, + ), + ).rejects.toThrow(reason); + }); + + it('destroys an in-flight socket when the signal aborts', async () => { + const socketFile = join(tempDir, 'rpc.sock'); + const connected = Promise.withResolvers(); + const serverSocketClosed = Promise.withResolvers(); + const server = net.createServer((socket) => { + connected.resolve(socket); + socket.once('close', () => { + serverSocketClosed.resolve(undefined); + }); + socket.resume(); + }); + await new Promise((resolve, reject) => { + server.once('error', reject); + server.listen(socketFile, () => { + server.off('error', reject); + resolve(); + }); + }); + + try { + const controller = new AbortController(); + const reason = new Error('client cancelled'); + const request = sendRpc( + socketFile, + 'inspect', + {}, + 5_000, + controller.signal, + ); + const serverSocket = await connected.promise; + + controller.abort(reason); + + await expect(request).rejects.toThrow(reason); + await serverSocketClosed.promise; + expect(serverSocket.destroyed).toBe(true); + } finally { + await new Promise((resolve, reject) => { + server.close((error) => { + if (error !== undefined) { + reject(error); + return; + } + resolve(); + }); + }); + } + }); +}); diff --git a/test/unit/host/rpcServer.test.ts b/test/unit/host/rpcServer.test.ts new file mode 100644 index 0000000..37e653e --- /dev/null +++ b/test/unit/host/rpcServer.test.ts @@ -0,0 +1,60 @@ +import { mkdtemp, rm } from 'node:fs/promises'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import net from 'node:net'; + +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { RpcServer } from '../../../src/host/rpcServer.js'; + +let tempDir = ''; + +describe('RpcServer request abort handling', () => { + beforeEach(async () => { + tempDir = await mkdtemp(join(tmpdir(), 'agent-tty-rpc-server-')); + }); + + afterEach(async () => { + await rm(tempDir, { recursive: true, force: true }); + tempDir = ''; + }); + + it('aborts the request context when the client socket closes', async () => { + const socketFile = join(tempDir, 'rpc.sock'); + const handlerStarted = Promise.withResolvers(); + const requestAborted = Promise.withResolvers(); + const server = new RpcServer(socketFile, { + inspect: async (_params, context) => { + handlerStarted.resolve(undefined); + context.signal.addEventListener( + 'abort', + () => { + requestAborted.resolve(context.signal); + }, + { once: true }, + ); + await requestAborted.promise; + return {}; + }, + }); + await server.listen(); + + const client = net.connect({ path: socketFile }); + await new Promise((resolve, reject) => { + client.once('connect', resolve); + client.once('error', reject); + }); + client.write( + `${JSON.stringify({ id: 'request-1', method: 'inspect', params: {} })}\n`, + ); + await handlerStarted.promise; + client.destroy(); + + try { + const signal = await requestAborted.promise; + expect(signal.aborted).toBe(true); + } finally { + await server.close(); + } + }); +}); diff --git a/test/unit/host/runCompletionCoordinator.test.ts b/test/unit/host/runCompletionCoordinator.test.ts index 2b029c8..e184548 100644 --- a/test/unit/host/runCompletionCoordinator.test.ts +++ b/test/unit/host/runCompletionCoordinator.test.ts @@ -174,6 +174,43 @@ describe('RunCompletionCoordinator', () => { } }); + it('aborts an active wait and clears its timeout while preserving hidden completion bytes', async () => { + vi.useFakeTimers(); + try { + const { appender, events } = createFakeAppender(); + const coordinator = new RunCompletionCoordinator(appender); + const prepared = coordinator.prepareWaitedRun(); + const completion = coordinator.registerWaitedRun({ + marker: prepared.marker, + inputRunSeq: 13, + }); + const controller = new AbortController(); + const abortReason = new Error('caller disconnected'); + const waitPromise = completion.wait(1_000, { signal: controller.signal }); + + controller.abort(abortReason); + + await expect(waitPromise).rejects.toThrow('caller disconnected'); + expect(vi.getTimerCount()).toBe(0); + + await vi.advanceTimersByTimeAsync(1_000); + await coordinator.ingestPtyData(`before${completion.sentinel}after`); + + expect(events).toEqual([ + { type: 'output', data: 'before' }, + { + type: 'run_complete', + marker: prepared.marker, + inputRunSeq: 13, + seq: 100, + }, + { type: 'output', data: 'after' }, + ]); + } finally { + vi.useRealTimers(); + } + }); + it('fails ingestion loudly when a timed-out completion later cannot append run_complete', async () => { vi.useFakeTimers(); try { diff --git a/test/unit/util/abort.test.ts b/test/unit/util/abort.test.ts new file mode 100644 index 0000000..96d968a --- /dev/null +++ b/test/unit/util/abort.test.ts @@ -0,0 +1,272 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { + addAbortListener, + createResourceScopedSettlers, + makeAbortError, + makeAbortReason, + throwIfAborted, + waitForScopedOperation, +} from '../../../src/util/abort.js'; +import { + ResourceScope, + ResourceScopeCloseError, +} from '../../../src/util/resourceScope.js'; + +describe('abort utilities', () => { + it('creates AbortError instances from missing or string abort reasons', () => { + const defaultError = makeAbortError(); + expect(defaultError.name).toBe('AbortError'); + expect(defaultError.message).toBe('Operation aborted.'); + + const controller = new AbortController(); + controller.abort('client disconnected'); + + const reasonError = makeAbortError(controller.signal); + expect(reasonError.name).toBe('AbortError'); + expect(reasonError.message).toBe('client disconnected'); + }); + + it('creates named abort reasons for internal controllers', () => { + const error = makeAbortReason('host shutdown'); + + expect(error.name).toBe('AbortError'); + expect(error.message).toBe('host shutdown'); + }); + + it('forwards Error abort reasons without wrapping them', () => { + const controller = new AbortController(); + const reason = new Error('stop now'); + controller.abort(reason); + + expect(makeAbortError(controller.signal)).toBe(reason); + expect(() => throwIfAborted(controller.signal)).toThrow(reason); + }); + + it('registers abort listeners with ResourceScope cleanup', async () => { + const scope = new ResourceScope(); + const controller = new AbortController(); + const listener = vi.fn(); + + addAbortListener(scope, 'test abort listener', controller.signal, listener); + await scope.close(); + controller.abort(); + + expect(listener).not.toHaveBeenCalled(); + }); + + it('asserts when registering a listener on an already aborted signal', () => { + const controller = new AbortController(); + controller.abort(); + + expect(() => + addAbortListener( + new ResourceScope(), + 'late listener', + controller.signal, + () => undefined, + ), + ).toThrow(/before signal aborts/u); + }); + + it('settles only once and closes the ResourceScope before resolving', async () => { + const scope = new ResourceScope(); + const releases: string[] = []; + scope.add('release', () => { + releases.push('closed'); + }); + const { promise, reject, resolve } = Promise.withResolvers(); + const settlers = createResourceScopedSettlers(scope, resolve, reject); + + settlers.resolve('ok'); + settlers.reject(new Error('late')); + + await expect(promise).resolves.toBe('ok'); + expect(releases).toEqual(['closed']); + }); + + it('preserves the original rejection when scope close also fails', async () => { + const scope = new ResourceScope(); + const closeFailure = new Error('close failed'); + const originalFailure = new Error('operation failed'); + scope.add('failing release', () => { + throw closeFailure; + }); + const { promise, reject, resolve } = Promise.withResolvers(); + const settlers = createResourceScopedSettlers(scope, resolve, reject); + + settlers.reject(originalFailure); + + await expect(promise).rejects.toMatchObject({ + errors: [originalFailure, expect.any(ResourceScopeCloseError)], + }); + }); + + it('rejects with the close failure when resolving cannot close the scope', async () => { + const scope = new ResourceScope(); + const closeFailure = new Error('close failed'); + scope.add('failing release', () => { + throw closeFailure; + }); + const { promise, reject, resolve } = Promise.withResolvers(); + const settlers = createResourceScopedSettlers(scope, resolve, reject); + + settlers.resolve('ok'); + + let caught: unknown; + try { + await promise; + } catch (error) { + caught = error; + } + + expect(caught).toBeInstanceOf(ResourceScopeCloseError); + expect((caught as ResourceScopeCloseError).failures).toEqual([ + { name: 'failing release', error: closeFailure }, + ]); + }); + + it('waitForScopedOperation resolves when the operation resolves first', async () => { + const releases: string[] = []; + const scope = new ResourceScope(); + scope.add('release', () => { + releases.push('closed'); + }); + + await expect( + waitForScopedOperation({ + operationName: 'test operation', + operation: Promise.resolve('done'), + scope, + }), + ).resolves.toBe('done'); + expect(releases).toEqual(['closed']); + }); + + it('waitForScopedOperation rejects when the operation rejects first', async () => { + const releases: string[] = []; + const failure = new Error('operation failed'); + const scope = new ResourceScope(); + scope.add('release', () => { + releases.push('closed'); + }); + + await expect( + waitForScopedOperation({ + operationName: 'test operation', + operation: Promise.reject(failure), + scope, + }), + ).rejects.toThrow(failure); + expect(releases).toEqual(['closed']); + }); + + it('waitForScopedOperation clears the timeout when the operation resolves first', async () => { + vi.useFakeTimers(); + try { + const promise = waitForScopedOperation({ + operationName: 'test operation', + operation: Promise.resolve('done'), + scope: new ResourceScope(), + timeoutMs: 100, + timeoutResult: () => 'timed out', + }); + + await expect(promise).resolves.toBe('done'); + expect(vi.getTimerCount()).toBe(0); + } finally { + vi.useRealTimers(); + } + }); + + it('waitForScopedOperation resolves timeout results and clears the timer', async () => { + vi.useFakeTimers(); + try { + const never = new Promise(() => undefined); + const promise = waitForScopedOperation({ + operationName: 'test operation', + operation: never, + scope: new ResourceScope(), + timeoutMs: 10, + timeoutResult: () => 'timed out', + }); + + await vi.advanceTimersByTimeAsync(10); + + await expect(promise).resolves.toBe('timed out'); + expect(vi.getTimerCount()).toBe(0); + } finally { + vi.useRealTimers(); + } + }); + + it('waitForScopedOperation handles pre-aborted signals with cleanup and onAbort', async () => { + const controller = new AbortController(); + const reason = new Error('already closed'); + const releases: string[] = []; + const onAbort = vi.fn(); + const scope = new ResourceScope(); + scope.add('release', () => { + releases.push('closed'); + }); + controller.abort(reason); + + await expect( + waitForScopedOperation({ + operationName: 'test operation', + operation: Promise.resolve('late'), + scope, + signal: controller.signal, + onAbort, + }), + ).rejects.toThrow(reason); + expect(onAbort).toHaveBeenCalledTimes(1); + expect(releases).toEqual(['closed']); + }); + + it('waitForScopedOperation rejects instead of throwing from an abort callback', async () => { + const controller = new AbortController(); + const callbackError = new Error('abort cleanup failed'); + const never = new Promise(() => undefined); + const promise = waitForScopedOperation({ + operationName: 'test operation', + operation: never, + scope: new ResourceScope(), + signal: controller.signal, + onAbort: () => { + throw callbackError; + }, + }); + + controller.abort(new Error('request closed')); + + await expect(promise).rejects.toThrow(callbackError); + }); + + it('waitForScopedOperation aborts, runs onAbort, and clears the timeout', async () => { + vi.useFakeTimers(); + try { + const controller = new AbortController(); + const onAbort = vi.fn(); + const reason = new Error('request closed'); + const never = new Promise(() => undefined); + const promise = waitForScopedOperation({ + operationName: 'test operation', + operation: never, + scope: new ResourceScope(), + signal: controller.signal, + timeoutMs: 100, + timeoutResult: () => 'timed out', + onAbort, + }); + + controller.abort(reason); + + await expect(promise).rejects.toThrow(reason); + expect(onAbort).toHaveBeenCalledTimes(1); + expect(vi.getTimerCount()).toBe(0); + } finally { + vi.useRealTimers(); + } + }); +});