diff --git a/apps/sim/lib/core/security/input-validation.server.ts b/apps/sim/lib/core/security/input-validation.server.ts index cdf5f28a9a1..81fbfe75e5a 100644 --- a/apps/sim/lib/core/security/input-validation.server.ts +++ b/apps/sim/lib/core/security/input-validation.server.ts @@ -5,6 +5,7 @@ import type { LookupFunction } from 'net' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import * as ipaddr from 'ipaddr.js' +import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } from 'undici' import { isHosted } from '@/lib/core/config/feature-flags' import { type ValidationResult, validateExternalUrl } from '@/lib/core/security/input-validation' import { PayloadSizeLimitError } from '@/lib/core/utils/stream-limits' @@ -400,6 +401,40 @@ export function createPinnedLookup(resolvedIP: string): LookupFunction { } } +/** + * Builds a standard `fetch`-compatible function that pins every outbound + * connection to `resolvedIP`, preventing DNS-rebinding (TOCTOU) between URL + * validation and connection. The original hostname is preserved for TLS SNI and + * the `Host` header so it still matches the certificate. This is the single + * source of truth for pinned outbound fetches — both the LLM providers and the + * MCP transport consume it. + * + * Pass the returned function as the `fetch` option to the OpenAI/Anthropic SDKs + * (or call it directly) after validating the URL with {@link validateUrlWithDNS} + * and capturing `resolvedIP`. Because the pinned lookup always returns + * `resolvedIP` regardless of hostname, any redirect the server returns also + * connects to the validated IP — an attacker cannot rebind a redirect target to + * an internal address. + * + * The `Agent` is captured for the lifetime of the returned function, so repeated + * calls (e.g. a provider tool loop) reuse its keep-alive connections. + */ +export function createPinnedFetch(resolvedIP: string): typeof fetch { + const dispatcher = new Agent({ connect: { lookup: createPinnedLookup(resolvedIP) } }) + + const pinned = async (input: RequestInfo | URL, init?: RequestInit): Promise => { + // double-cast-allowed: DOM RequestInfo/URL and undici fetch input types differ but are structurally compatible at runtime (Node's global fetch IS undici) + const undiciInput = input as unknown as Parameters[0] + // double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime but the TS types differ + const undiciInit: UndiciRequestInit = { ...(init as unknown as UndiciRequestInit), dispatcher } + const response = await undiciFetch(undiciInput, undiciInit) + // double-cast-allowed: undici Response and DOM Response are structurally compatible at runtime + return response as unknown as Response + } + + return pinned +} + /** * Performs a fetch with IP pinning to prevent DNS rebinding attacks. * Uses the pre-resolved IP address while preserving the original hostname for TLS SNI. diff --git a/apps/sim/lib/core/security/pinned-fetch.server.test.ts b/apps/sim/lib/core/security/pinned-fetch.server.test.ts new file mode 100644 index 00000000000..d63bad257ee --- /dev/null +++ b/apps/sim/lib/core/security/pinned-fetch.server.test.ts @@ -0,0 +1,126 @@ +/** + * @vitest-environment node + */ +import { featureFlagsMock } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockAgent, mockUndiciFetch, capturedAgentOptions, agentCloses } = vi.hoisted(() => { + const capturedAgentOptions: unknown[] = [] + const agentCloses: unknown[] = [] + class MockAgent { + constructor(options: unknown) { + capturedAgentOptions.push(options) + } + close() { + agentCloses.push(this) + return Promise.resolve() + } + } + return { + mockAgent: MockAgent, + mockUndiciFetch: vi.fn(), + capturedAgentOptions, + agentCloses, + } +}) + +vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch })) +vi.mock('@/lib/core/config/feature-flags', () => featureFlagsMock) + +import { createPinnedFetch } from '@/lib/core/security/input-validation.server' + +type LookupCallback = (err: Error | null, address: string, family: number) => void +type PinnedLookup = (hostname: string, options: { all?: boolean }, callback: LookupCallback) => void + +describe('createPinnedFetch', () => { + beforeEach(() => { + vi.clearAllMocks() + capturedAgentOptions.length = 0 + agentCloses.length = 0 + mockUndiciFetch.mockResolvedValue(new Response('ok')) + }) + + it('builds an undici Agent whose pinned lookup always resolves to the validated IP', async () => { + createPinnedFetch('203.0.113.10') + + expect(capturedAgentOptions).toHaveLength(1) + const { connect } = capturedAgentOptions[0] as { connect: { lookup: PinnedLookup } } + expect(typeof connect.lookup).toBe('function') + + const resolved = await new Promise<{ address: string; family: number }>((resolve) => { + connect.lookup('rebind.attacker.tld', {}, (_err, address, family) => + resolve({ address, family }) + ) + }) + expect(resolved).toEqual({ address: '203.0.113.10', family: 4 }) + }) + + it('uses IPv6 family when the validated IP is IPv6', async () => { + createPinnedFetch('2606:4700:4700::1111') + const { connect } = capturedAgentOptions[0] as { connect: { lookup: PinnedLookup } } + const resolved = await new Promise<{ address: string; family: number }>((resolve) => { + connect.lookup('example.com', {}, (_err, address, family) => resolve({ address, family })) + }) + expect(resolved).toEqual({ address: '2606:4700:4700::1111', family: 6 }) + }) + + it('forwards the pinned dispatcher on every call while preserving init options', async () => { + const pinned = createPinnedFetch('203.0.113.10') + const controller = new AbortController() + + await pinned('https://myresource.openai.azure.com/openai/v1/responses', { + method: 'POST', + headers: { 'api-key': 'secret' }, + body: '{}', + signal: controller.signal, + }) + + expect(mockUndiciFetch).toHaveBeenCalledTimes(1) + const [url, init] = mockUndiciFetch.mock.calls[0] + expect(url).toBe('https://myresource.openai.azure.com/openai/v1/responses') + const typedInit = init as RequestInit & { dispatcher?: unknown } + expect(typedInit.dispatcher).toBeInstanceOf(mockAgent) + expect(typedInit.method).toBe('POST') + expect(typedInit.headers).toEqual({ 'api-key': 'secret' }) + expect(typedInit.body).toBe('{}') + expect(typedInit.signal).toBe(controller.signal) + }) + + it('handles an undefined init by still attaching the dispatcher', async () => { + const pinned = createPinnedFetch('203.0.113.10') + await pinned('https://example.com') + const init = mockUndiciFetch.mock.calls[0][1] as { dispatcher?: unknown } + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('reuses one captured dispatcher across all calls of a single instance', async () => { + const pinned = createPinnedFetch('203.0.113.10') + await pinned('https://example.com/a') + await pinned('https://example.com/b') + + expect(capturedAgentOptions).toHaveLength(1) + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).toBe(d2) + }) + + it('creates an independent dispatcher per instance', async () => { + const a = createPinnedFetch('203.0.113.10') + const b = createPinnedFetch('203.0.113.10') + await a('https://example.com/a') + await b('https://example.com/b') + + expect(capturedAgentOptions).toHaveLength(2) + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).not.toBe(d2) + }) + + it('returns the response produced by undici fetch', async () => { + mockUndiciFetch.mockResolvedValueOnce(new Response('pong', { status: 201 })) + const pinned = createPinnedFetch('203.0.113.10') + const response = await pinned('https://example.com') + expect(response.status).toBe(201) + expect(await response.text()).toBe('pong') + }) +}) diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index bef88182c9e..569320ef882 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -11,8 +11,8 @@ import { import { createLogger } from '@sim/logger' import { getErrorMessage } from '@sim/utils/errors' import { getMaxExecutionTimeout } from '@/lib/core/execution-limits' +import { createPinnedFetch } from '@/lib/core/security/input-validation.server' import { McpOauthRedirectRequired } from '@/lib/mcp/oauth' -import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' import { type McpClientOptions, McpConnectionError, @@ -70,7 +70,7 @@ export class McpClient { this.transport = new StreamableHTTPClientTransport(new URL(this.config.url), { authProvider: useOauth ? this.authProvider : undefined, requestInit: { headers: this.config.headers }, - ...(resolvedIP ? { fetch: createMcpPinnedFetch(resolvedIP) } : {}), + ...(resolvedIP ? { fetch: createPinnedFetch(resolvedIP) } : {}), }) this.client = new Client( diff --git a/apps/sim/lib/mcp/oauth/probe.test.ts b/apps/sim/lib/mcp/oauth/probe.test.ts index 34e7d6199e4..d691f1178c7 100644 --- a/apps/sim/lib/mcp/oauth/probe.test.ts +++ b/apps/sim/lib/mcp/oauth/probe.test.ts @@ -3,24 +3,22 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest' -const { - mockCreateMcpPinnedFetch, - mockCreateSsrfGuardedMcpFetch, - mockPinnedFetch, - mockGuardedFetch, -} = vi.hoisted(() => { - const mockPinnedFetch = vi.fn() - const mockGuardedFetch = vi.fn() - return { - mockPinnedFetch, - mockGuardedFetch, - mockCreateMcpPinnedFetch: vi.fn(() => mockPinnedFetch), - mockCreateSsrfGuardedMcpFetch: vi.fn(() => mockGuardedFetch), - } -}) +const { mockCreatePinnedFetch, mockCreateSsrfGuardedMcpFetch, mockPinnedFetch, mockGuardedFetch } = + vi.hoisted(() => { + const mockPinnedFetch = vi.fn() + const mockGuardedFetch = vi.fn() + return { + mockPinnedFetch, + mockGuardedFetch, + mockCreatePinnedFetch: vi.fn(() => mockPinnedFetch), + mockCreateSsrfGuardedMcpFetch: vi.fn(() => mockGuardedFetch), + } + }) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + createPinnedFetch: mockCreatePinnedFetch, +})) vi.mock('@/lib/mcp/pinned-fetch', () => ({ - createMcpPinnedFetch: mockCreateMcpPinnedFetch, createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch, })) @@ -50,7 +48,7 @@ describe('detectMcpAuthType — connection pinning (SSRF / DNS-rebinding)', () = const authType = await detectMcpAuthType('https://rebind.example.com/mcp', '203.0.113.10') expect(authType).toBe('none') - expect(mockCreateMcpPinnedFetch).toHaveBeenCalledWith('203.0.113.10') + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') expect(mockCreateSsrfGuardedMcpFetch).not.toHaveBeenCalled() expect(mockPinnedFetch).toHaveBeenCalledTimes(1) // The unpinned global fetch must never be used — that was the SSRF sink. @@ -64,7 +62,7 @@ describe('detectMcpAuthType — connection pinning (SSRF / DNS-rebinding)', () = expect(authType).toBe('none') expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1) - expect(mockCreateMcpPinnedFetch).not.toHaveBeenCalled() + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() expect(mockGuardedFetch).toHaveBeenCalledTimes(1) expect(globalFetchSpy).not.toHaveBeenCalled() }) @@ -90,7 +88,7 @@ describe('detectMcpAuthType — connection pinning (SSRF / DNS-rebinding)', () = const authType = await detectMcpAuthType('http://example.com/mcp', '203.0.113.10') expect(authType).toBe('headers') - expect(mockCreateMcpPinnedFetch).not.toHaveBeenCalled() + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() expect(mockCreateSsrfGuardedMcpFetch).not.toHaveBeenCalled() expect(globalFetchSpy).not.toHaveBeenCalled() }) diff --git a/apps/sim/lib/mcp/oauth/probe.ts b/apps/sim/lib/mcp/oauth/probe.ts index 887ba8ce971..4343f27870a 100644 --- a/apps/sim/lib/mcp/oauth/probe.ts +++ b/apps/sim/lib/mcp/oauth/probe.ts @@ -1,8 +1,9 @@ import { extractWWWAuthenticateParams } from '@modelcontextprotocol/sdk/client/auth.js' import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js' import { createLogger } from '@sim/logger' +import { createPinnedFetch } from '@/lib/core/security/input-validation.server' import { isLoopbackHostname } from '@/lib/core/utils/urls' -import { createMcpPinnedFetch, createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch' +import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch' import type { McpAuthType } from '@/lib/mcp/types' const logger = createLogger('McpOauthProbe') @@ -33,7 +34,7 @@ export async function detectMcpAuthType( } const probeFetch: FetchLike = resolvedIP - ? createMcpPinnedFetch(resolvedIP) + ? createPinnedFetch(resolvedIP) : createSsrfGuardedMcpFetch() const controller = new AbortController() diff --git a/apps/sim/lib/mcp/oauth/revoke.test.ts b/apps/sim/lib/mcp/oauth/revoke.test.ts index d8f6342568b..ccd0caba98b 100644 --- a/apps/sim/lib/mcp/oauth/revoke.test.ts +++ b/apps/sim/lib/mcp/oauth/revoke.test.ts @@ -15,33 +15,23 @@ const PUBLIC_SERVER_URL = 'https://mcp.attacker.com' const PUBLIC_SERVER_IP = '203.0.113.10' const { - MockAgent, mockUndiciFetch, mockValidateMcpServerSsrf, mockDiscoverOAuthServerInfo, mockLoadOauthRow, mockDecryptSecret, mockDbSelect, -} = vi.hoisted(() => { - class MockAgent { - close() { - return Promise.resolve() - } - } - return { - MockAgent, - mockUndiciFetch: vi.fn(), - mockValidateMcpServerSsrf: vi.fn(), - mockDiscoverOAuthServerInfo: vi.fn(), - mockLoadOauthRow: vi.fn(), - mockDecryptSecret: vi.fn(), - mockDbSelect: vi.fn(), - } -}) +} = vi.hoisted(() => ({ + mockUndiciFetch: vi.fn(), + mockValidateMcpServerSsrf: vi.fn(), + mockDiscoverOAuthServerInfo: vi.fn(), + mockLoadOauthRow: vi.fn(), + mockDecryptSecret: vi.fn(), + mockDbSelect: vi.fn(), +})) -vi.mock('undici', () => ({ Agent: MockAgent, fetch: mockUndiciFetch })) vi.mock('@/lib/core/security/input-validation.server', () => ({ - createPinnedLookup: vi.fn(() => 'pinned-lookup-fn'), + createPinnedFetch: vi.fn(() => mockUndiciFetch), })) vi.mock('@/lib/mcp/domain-check', () => ({ validateMcpServerSsrf: mockValidateMcpServerSsrf, @@ -59,7 +49,6 @@ vi.mock('@sim/db', () => ({ db: { select: mockDbSelect }, })) -import { __resetPinnedAgentsForTests } from '@/lib/mcp/pinned-fetch' import { revokeMcpOauthTokens } from './revoke' function wireServerRow(row: Record) { @@ -74,7 +63,6 @@ function wireServerRow(row: Record) { describe('revokeMcpOauthTokens — SSRF guard', () => { beforeEach(() => { vi.clearAllMocks() - __resetPinnedAgentsForTests() mockLoadOauthRow.mockResolvedValue({ tokens: { access_token: 'access-secret', refresh_token: 'refresh-secret' }, diff --git a/apps/sim/lib/mcp/pinned-fetch.test.ts b/apps/sim/lib/mcp/pinned-fetch.test.ts index 9f6b5919bf2..64354ee708b 100644 --- a/apps/sim/lib/mcp/pinned-fetch.test.ts +++ b/apps/sim/lib/mcp/pinned-fetch.test.ts @@ -3,147 +3,26 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockAgent, mockCreatePinnedLookup, mockUndiciFetch, capturedAgentOptions, agentCloses } = - vi.hoisted(() => { - const capturedAgentOptions: unknown[] = [] - const agentCloses: unknown[] = [] - class MockAgent { - constructor(options: unknown) { - capturedAgentOptions.push(options) - } - close() { - agentCloses.push(this) - return Promise.resolve() - } - } - return { - mockAgent: MockAgent, - mockCreatePinnedLookup: vi.fn(), - mockUndiciFetch: vi.fn(), - capturedAgentOptions, - agentCloses, - } - }) - -const { mockValidateMcpServerSsrf } = vi.hoisted(() => ({ +const { mockCreatePinnedFetch, mockValidateMcpServerSsrf, sentinelFetch } = vi.hoisted(() => ({ + mockCreatePinnedFetch: vi.fn(), mockValidateMcpServerSsrf: vi.fn(), + sentinelFetch: vi.fn(), })) -vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch })) vi.mock('@/lib/core/security/input-validation.server', () => ({ - createPinnedLookup: mockCreatePinnedLookup, + createPinnedFetch: mockCreatePinnedFetch, })) vi.mock('@/lib/mcp/domain-check', () => ({ validateMcpServerSsrf: mockValidateMcpServerSsrf, })) -import { - __resetPinnedAgentsForTests, - createMcpPinnedFetch, - createSsrfGuardedMcpFetch, -} from '@/lib/mcp/pinned-fetch' - -describe('createMcpPinnedFetch', () => { - beforeEach(() => { - vi.clearAllMocks() - capturedAgentOptions.length = 0 - agentCloses.length = 0 - __resetPinnedAgentsForTests() - mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') - mockUndiciFetch.mockResolvedValue(new Response('ok')) - }) - - it('builds an undici Agent with the pinned lookup for the resolved IP', () => { - createMcpPinnedFetch('203.0.113.10') - expect(mockCreatePinnedLookup).toHaveBeenCalledWith('203.0.113.10') - expect(capturedAgentOptions).toHaveLength(1) - expect(capturedAgentOptions[0]).toEqual({ connect: { lookup: 'pinned-lookup-fn' } }) - }) - - it('forwards the dispatcher on every fetch call', async () => { - const fetchLike = createMcpPinnedFetch('203.0.113.10') - await fetchLike('https://example.com/mcp', { method: 'POST' }) - expect(mockUndiciFetch).toHaveBeenCalledTimes(1) - const [url, init] = mockUndiciFetch.mock.calls[0] - expect(url).toBe('https://example.com/mcp') - expect((init as { dispatcher?: unknown }).dispatcher).toBeInstanceOf(mockAgent) - expect((init as { method?: string }).method).toBe('POST') - }) - - it('preserves caller-provided init options (headers, signal)', async () => { - const fetchLike = createMcpPinnedFetch('203.0.113.10') - const controller = new AbortController() - await fetchLike('https://example.com/mcp', { - method: 'GET', - headers: { 'x-test': '1' }, - signal: controller.signal, - }) - const init = mockUndiciFetch.mock.calls[0][1] as RequestInit & { dispatcher?: unknown } - expect(init.headers).toEqual({ 'x-test': '1' }) - expect(init.signal).toBe(controller.signal) - expect(init.dispatcher).toBeInstanceOf(mockAgent) - }) - - it('handles undefined init gracefully', async () => { - const fetchLike = createMcpPinnedFetch('203.0.113.10') - await fetchLike('https://example.com/mcp') - const init = mockUndiciFetch.mock.calls[0][1] as { dispatcher?: unknown } - expect(init.dispatcher).toBeInstanceOf(mockAgent) - }) - - it('reuses the same dispatcher across calls within a fetch instance', async () => { - const fetchLike = createMcpPinnedFetch('203.0.113.10') - await fetchLike('https://example.com/a') - await fetchLike('https://example.com/b') - expect(capturedAgentOptions).toHaveLength(1) - const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher - const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher - expect(d1).toBe(d2) - }) - - it('pools agents by resolvedIP across createMcpPinnedFetch calls', async () => { - const a = createMcpPinnedFetch('203.0.113.10') - const b = createMcpPinnedFetch('203.0.113.10') - await a('https://example.com/a') - await b('https://example.com/b') - expect(capturedAgentOptions).toHaveLength(1) - const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher - const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher - expect(d1).toBe(d2) - }) - - it('creates separate agents for different resolved IPs', async () => { - const a = createMcpPinnedFetch('203.0.113.10') - const b = createMcpPinnedFetch('198.51.100.20') - await a('https://example.com/a') - await b('https://example.com/b') - expect(capturedAgentOptions).toHaveLength(2) - const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher - const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher - expect(d1).not.toBe(d2) - }) - - it('does not close evicted agents — captured closures keep working', async () => { - // Build an early closure whose agent will get evicted by later IPs. - const earlyClient = createMcpPinnedFetch('10.0.0.1') - // Fill the cache past its 64-entry limit so the early entry is evicted. - for (let i = 0; i < 64; i++) createMcpPinnedFetch(`10.1.${Math.floor(i / 256)}.${i % 256}`) - - // Eviction must NOT have closed any agents. - expect(agentCloses).toHaveLength(0) - // The early closure's captured dispatcher is still callable. - await earlyClient('https://example.com/still-works') - expect(mockUndiciFetch).toHaveBeenCalledTimes(1) - }) -}) +import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch' describe('createSsrfGuardedMcpFetch', () => { beforeEach(() => { vi.clearAllMocks() - capturedAgentOptions.length = 0 - __resetPinnedAgentsForTests() - mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') - mockUndiciFetch.mockResolvedValue(new Response('ok')) + mockCreatePinnedFetch.mockReturnValue(sentinelFetch) + sentinelFetch.mockResolvedValue(new Response('ok')) }) it('validates each request URL and pins to the resolved IP', async () => { @@ -152,11 +31,10 @@ describe('createSsrfGuardedMcpFetch', () => { await fetchLike('https://attacker.example/revoke', { method: 'POST' }) expect(mockValidateMcpServerSsrf).toHaveBeenCalledWith('https://attacker.example/revoke') - expect(mockUndiciFetch).toHaveBeenCalledTimes(1) - const [url, init] = mockUndiciFetch.mock.calls[0] - expect(url).toBe('https://attacker.example/revoke') - expect((init as { dispatcher?: unknown }).dispatcher).toBeInstanceOf(mockAgent) - expect((init as { method?: string }).method).toBe('POST') + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') + expect(sentinelFetch).toHaveBeenCalledWith('https://attacker.example/revoke', { + method: 'POST', + }) }) it('rejects URLs that resolve to blocked IPs without issuing the request', async () => { @@ -166,7 +44,8 @@ describe('createSsrfGuardedMcpFetch', () => { await expect( fetchLike('http://169.254.169.254/latest/meta-data/', { method: 'POST' }) ).rejects.toThrow('blocked') - expect(mockUndiciFetch).not.toHaveBeenCalled() + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(sentinelFetch).not.toHaveBeenCalled() }) it('accepts URL objects and validates their href', async () => { @@ -175,6 +54,14 @@ describe('createSsrfGuardedMcpFetch', () => { await fetchLike(new URL('https://attacker.example/discover')) expect(mockValidateMcpServerSsrf).toHaveBeenCalledWith('https://attacker.example/discover') - expect(mockUndiciFetch).toHaveBeenCalledTimes(1) + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') + }) + + it('falls back to global fetch when validation returns no IP', async () => { + mockValidateMcpServerSsrf.mockResolvedValue(null) + const fetchLike = createSsrfGuardedMcpFetch() + await fetchLike('https://allowed.internal/mcp') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts index f395d6b03c5..3184a0da7a9 100644 --- a/apps/sim/lib/mcp/pinned-fetch.ts +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -1,61 +1,7 @@ import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js' -import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } from 'undici' -import { createPinnedLookup } from '@/lib/core/security/input-validation.server' +import { createPinnedFetch } from '@/lib/core/security/input-validation.server' import { validateMcpServerSsrf } from '@/lib/mcp/domain-check' -/** - * Pins outbound HTTP connections to a pre-resolved IP to prevent DNS-rebinding - * between URL validation and connection. Hostname is preserved so TLS SNI and - * the Host header still match the certificate. - * - * Agents are pooled by `resolvedIP` so back-to-back calls to the same server - * reuse the same keep-alive connection pool instead of opening a fresh TCP + - * TLS connection per McpClient instance. - */ -const MAX_POOLED_AGENTS = 64 -const pinnedAgents = new Map() - -function getPinnedAgent(resolvedIP: string): Agent { - const existing = pinnedAgents.get(resolvedIP) - if (existing) { - // LRU touch — re-insert to mark as most recently used. - pinnedAgents.delete(resolvedIP) - pinnedAgents.set(resolvedIP, existing) - return existing - } - if (pinnedAgents.size >= MAX_POOLED_AGENTS) { - // Drop the oldest entry WITHOUT closing it — existing `createMcpPinnedFetch` - // closures may still hold a reference and have in-flight requests. The - // dispatcher is GC'd (and its sockets cleaned up) when the last closure - // releases it; undici closes idle keep-alive connections after its own - // timeout (default 4s). - const oldestKey = pinnedAgents.keys().next().value - if (oldestKey !== undefined) pinnedAgents.delete(oldestKey) - } - const agent = new Agent({ connect: { lookup: createPinnedLookup(resolvedIP) } }) - pinnedAgents.set(resolvedIP, agent) - return agent -} - -export function __resetPinnedAgentsForTests(): void { - pinnedAgents.clear() -} - -export function createMcpPinnedFetch(resolvedIP: string): FetchLike { - const dispatcher = getPinnedAgent(resolvedIP) - - return (async (url, init) => { - const undiciInit: UndiciRequestInit = { - // double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime (Node's global fetch IS undici) but the TS types differ - ...(init as unknown as UndiciRequestInit), - dispatcher, - } - const response = await undiciFetch(url as string | URL, undiciInit) - // double-cast-allowed: undici Response and DOM Response are structurally compatible at runtime; bridging the types is required to satisfy the FetchLike contract - return response as unknown as Response - }) satisfies FetchLike -} - /** * Builds a `FetchLike` that validates every outbound request URL against the * MCP SSRF policy before issuing it, then pins the connection to the resolved @@ -79,7 +25,7 @@ export function createSsrfGuardedMcpFetch(): FetchLike { return (async (url, init) => { const target = typeof url === 'string' ? url : url.href const resolvedIP = await validateMcpServerSsrf(target) - const pinnedFetch: FetchLike = resolvedIP ? createMcpPinnedFetch(resolvedIP) : globalThis.fetch + const pinnedFetch: FetchLike = resolvedIP ? createPinnedFetch(resolvedIP) : globalThis.fetch return pinnedFetch(url, init) }) satisfies FetchLike } diff --git a/apps/sim/providers/azure-anthropic/index.test.ts b/apps/sim/providers/azure-anthropic/index.test.ts new file mode 100644 index 00000000000..b5254f9eaf8 --- /dev/null +++ b/apps/sim/providers/azure-anthropic/index.test.ts @@ -0,0 +1,120 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import type { ProviderRequest } from '@/providers/types' + +const { + mockAnthropic, + anthropicArgs, + mockValidate, + mockCreatePinnedFetch, + mockExecuteAnthropic, + sentinelFetch, + envState, +} = vi.hoisted(() => { + const anthropicArgs: Array> = [] + const sentinelFetch = vi.fn() + class MockAnthropic { + constructor(opts: Record) { + anthropicArgs.push(opts) + } + } + return { + mockAnthropic: MockAnthropic, + anthropicArgs, + mockValidate: vi.fn(), + mockCreatePinnedFetch: vi.fn(() => sentinelFetch), + mockExecuteAnthropic: vi.fn(), + sentinelFetch, + envState: { + AZURE_ANTHROPIC_ENDPOINT: undefined as string | undefined, + AZURE_ANTHROPIC_API_VERSION: undefined as string | undefined, + }, + } +}) + +vi.mock('@anthropic-ai/sdk', () => ({ default: mockAnthropic })) +vi.mock('@/lib/core/config/env', () => ({ env: envState })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + validateUrlWithDNS: mockValidate, + createPinnedFetch: mockCreatePinnedFetch, +})) +vi.mock('@/providers/anthropic/core', () => ({ + executeAnthropicProviderRequest: mockExecuteAnthropic, +})) +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn(() => []), + getProviderDefaultModel: vi.fn(() => 'azure-anthropic/claude'), +})) + +import { azureAnthropicProvider } from '@/providers/azure-anthropic/index' + +function request(overrides: Partial): ProviderRequest { + return { model: 'azure-anthropic/claude-3-5-sonnet', apiKey: 'k', messages: [], ...overrides } +} + +/** Invokes the createClient factory handed to the Anthropic core and returns the SDK options it built. */ +function buildClientOptions(): Record { + const config = mockExecuteAnthropic.mock.calls[0][1] + config.createClient('k', false) + return anthropicArgs[0] +} + +describe('azureAnthropicProvider — SSRF pinning', () => { + beforeEach(() => { + vi.clearAllMocks() + anthropicArgs.length = 0 + envState.AZURE_ANTHROPIC_ENDPOINT = undefined + envState.AZURE_ANTHROPIC_API_VERSION = undefined + mockExecuteAnthropic.mockResolvedValue({ content: 'ok' }) + }) + + it('validates and pins the connection to the resolved IP for a user-supplied endpoint', async () => { + mockValidate.mockResolvedValue({ isValid: true, resolvedIP: '203.0.113.10' }) + + await azureAnthropicProvider.executeRequest( + request({ azureEndpoint: 'https://rebind.attacker.tld' }) + ) + + expect(mockValidate).toHaveBeenCalledWith('https://rebind.attacker.tld', 'azureEndpoint') + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') + expect(buildClientOptions()).toMatchObject({ fetch: sentinelFetch }) + }) + + it('does not pin when the endpoint comes from trusted server env', async () => { + envState.AZURE_ANTHROPIC_ENDPOINT = 'https://trusted.services.ai.azure.com' + + await azureAnthropicProvider.executeRequest(request({ azureEndpoint: undefined })) + + expect(mockValidate).not.toHaveBeenCalled() + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(buildClientOptions()).not.toHaveProperty('fetch') + }) + + it('throws and never builds a client when validation blocks the endpoint', async () => { + mockValidate.mockResolvedValue({ isValid: false, error: 'resolves to a blocked IP address' }) + + await expect( + azureAnthropicProvider.executeRequest( + request({ azureEndpoint: 'https://rebind.attacker.tld' }) + ) + ).rejects.toThrow('Invalid Azure Anthropic endpoint') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(mockExecuteAnthropic).not.toHaveBeenCalled() + }) + + it('fails closed when validation passes but yields no resolvable IP to pin', async () => { + mockValidate.mockResolvedValue({ isValid: true }) + + await expect( + azureAnthropicProvider.executeRequest( + request({ azureEndpoint: 'https://rebind.attacker.tld' }) + ) + ).rejects.toThrow('could not resolve a pinnable IP address') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(mockExecuteAnthropic).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/providers/azure-anthropic/index.ts b/apps/sim/providers/azure-anthropic/index.ts index 999dc0938f8..39980d77c2e 100644 --- a/apps/sim/providers/azure-anthropic/index.ts +++ b/apps/sim/providers/azure-anthropic/index.ts @@ -1,7 +1,7 @@ import Anthropic from '@anthropic-ai/sdk' import { createLogger } from '@sim/logger' import { env } from '@/lib/core/config/env' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { createPinnedFetch, validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import type { StreamingExecution } from '@/executor/types' import { executeAnthropicProviderRequest } from '@/providers/anthropic/core' import { getProviderDefaultModel, getProviderModels } from '@/providers/models' @@ -28,6 +28,7 @@ export const azureAnthropicProvider: ProviderConfig = { ) } + let pinnedFetch: typeof fetch | undefined if (userProvidedEndpoint) { const validation = await validateUrlWithDNS(userProvidedEndpoint, 'azureEndpoint') if (!validation.isValid) { @@ -37,6 +38,10 @@ export const azureAnthropicProvider: ProviderConfig = { }) throw new Error(`Invalid Azure Anthropic endpoint: ${validation.error}`) } + if (!validation.resolvedIP) { + throw new Error('Invalid Azure Anthropic endpoint: could not resolve a pinnable IP address') + } + pinnedFetch = createPinnedFetch(validation.resolvedIP) } const apiKey = request.apiKey @@ -67,6 +72,7 @@ export const azureAnthropicProvider: ProviderConfig = { new Anthropic({ baseURL, apiKey, + ...(pinnedFetch ? { fetch: pinnedFetch } : {}), defaultHeaders: { 'api-key': apiKey, 'anthropic-version': anthropicVersion, diff --git a/apps/sim/providers/azure-openai/index.test.ts b/apps/sim/providers/azure-openai/index.test.ts new file mode 100644 index 00000000000..7e18ea809df --- /dev/null +++ b/apps/sim/providers/azure-openai/index.test.ts @@ -0,0 +1,190 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import type { ProviderRequest } from '@/providers/types' + +const { + mockAzureOpenAI, + azureOpenAIArgs, + mockChatCreate, + mockValidate, + mockCreatePinnedFetch, + mockExecuteResponses, + sentinelFetch, + mockIsChatCompletionsEndpoint, + mockIsResponsesEndpoint, + envState, +} = vi.hoisted(() => { + const azureOpenAIArgs: Array> = [] + const sentinelFetch = vi.fn() + const mockChatCreate = vi.fn() + class MockAzureOpenAI { + chat = { completions: { create: mockChatCreate } } + constructor(opts: Record) { + azureOpenAIArgs.push(opts) + } + } + return { + mockAzureOpenAI: MockAzureOpenAI, + azureOpenAIArgs, + mockChatCreate, + mockValidate: vi.fn(), + mockCreatePinnedFetch: vi.fn(() => sentinelFetch), + mockExecuteResponses: vi.fn(), + sentinelFetch, + mockIsChatCompletionsEndpoint: vi.fn(() => false), + mockIsResponsesEndpoint: vi.fn(() => false), + envState: { + AZURE_OPENAI_ENDPOINT: undefined as string | undefined, + AZURE_OPENAI_API_VERSION: undefined as string | undefined, + }, + } +}) + +vi.mock('openai', () => ({ AzureOpenAI: mockAzureOpenAI })) +vi.mock('@/lib/core/config/env', () => ({ env: envState })) +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + validateUrlWithDNS: mockValidate, + createPinnedFetch: mockCreatePinnedFetch, +})) +vi.mock('@/providers/openai/core', () => ({ + executeResponsesProviderRequest: mockExecuteResponses, +})) +vi.mock('@/providers/azure-openai/utils', () => ({ + isChatCompletionsEndpoint: mockIsChatCompletionsEndpoint, + isResponsesEndpoint: mockIsResponsesEndpoint, + extractBaseUrl: vi.fn((url: string) => url), + extractDeploymentFromUrl: vi.fn(() => null), + extractApiVersionFromUrl: vi.fn(() => null), + createReadableStreamFromAzureOpenAIStream: vi.fn(), + checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), +})) +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn(() => []), + getProviderDefaultModel: vi.fn(() => 'azure/gpt-4o'), +})) +vi.mock('@/providers/attachments', () => ({ + prepareProviderAttachments: vi.fn(() => []), +})) +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + prepareToolExecution: vi.fn((_tool, args) => ({ toolParams: args, executionParams: args })), + prepareToolsWithUsageControl: vi.fn(() => ({ + tools: [], + toolChoice: undefined, + forcedTools: [], + })), + sumToolCosts: vi.fn(() => 0), +})) +vi.mock('@/tools', () => ({ executeTool: vi.fn() })) + +import { azureOpenAIProvider } from '@/providers/azure-openai/index' + +function request(overrides: Partial): ProviderRequest { + return { model: 'azure/gpt-4o', apiKey: 'k', messages: [], ...overrides } +} + +/** Config object passed to the Responses core on the Nth call. */ +const responsesConfig = (call = 0) => mockExecuteResponses.mock.calls[call][1] + +describe('azureOpenAIProvider — SSRF pinning', () => { + beforeEach(() => { + vi.clearAllMocks() + azureOpenAIArgs.length = 0 + envState.AZURE_OPENAI_ENDPOINT = undefined + envState.AZURE_OPENAI_API_VERSION = undefined + mockIsChatCompletionsEndpoint.mockReturnValue(false) + mockIsResponsesEndpoint.mockReturnValue(false) + mockExecuteResponses.mockResolvedValue({ content: 'ok' }) + }) + + describe('Responses API path', () => { + it('validates and threads the pinned fetch into the Responses core for a user endpoint', async () => { + mockValidate.mockResolvedValue({ isValid: true, resolvedIP: '203.0.113.10' }) + + await azureOpenAIProvider.executeRequest( + request({ azureEndpoint: 'https://rebind.attacker.tld' }) + ) + + expect(mockValidate).toHaveBeenCalledWith('https://rebind.attacker.tld', 'azureEndpoint') + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') + expect(responsesConfig().fetch).toBe(sentinelFetch) + }) + + it('passes no custom fetch when the endpoint comes from trusted server env', async () => { + envState.AZURE_OPENAI_ENDPOINT = 'https://trusted.openai.azure.com' + + await azureOpenAIProvider.executeRequest(request({ azureEndpoint: undefined })) + + expect(mockValidate).not.toHaveBeenCalled() + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(responsesConfig().fetch).toBeUndefined() + }) + + it('throws and never reaches the Responses core when validation blocks the endpoint', async () => { + mockValidate.mockResolvedValue({ isValid: false, error: 'resolves to a blocked IP address' }) + + await expect( + azureOpenAIProvider.executeRequest( + request({ azureEndpoint: 'https://rebind.attacker.tld' }) + ) + ).rejects.toThrow('Invalid Azure OpenAI endpoint') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(mockExecuteResponses).not.toHaveBeenCalled() + }) + + it('fails closed when validation passes but yields no resolvable IP to pin', async () => { + mockValidate.mockResolvedValue({ isValid: true }) + + await expect( + azureOpenAIProvider.executeRequest( + request({ azureEndpoint: 'https://rebind.attacker.tld' }) + ) + ).rejects.toThrow('could not resolve a pinnable IP address') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(mockExecuteResponses).not.toHaveBeenCalled() + }) + }) + + describe('Chat Completions path', () => { + it('constructs the AzureOpenAI client with the pinned fetch for a user endpoint', async () => { + mockIsChatCompletionsEndpoint.mockReturnValue(true) + mockValidate.mockResolvedValue({ isValid: true, resolvedIP: '203.0.113.10' }) + mockChatCreate.mockResolvedValue({ + choices: [{ message: { content: 'hi', tool_calls: undefined } }], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + }) + + await azureOpenAIProvider.executeRequest( + request({ + azureEndpoint: 'https://rebind.attacker.tld/openai/deployments/gpt-4o/chat/completions', + }) + ) + + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') + expect(azureOpenAIArgs[0]).toMatchObject({ fetch: sentinelFetch }) + }) + + it('constructs the AzureOpenAI client without a custom fetch for a trusted env endpoint', async () => { + mockIsChatCompletionsEndpoint.mockReturnValue(true) + envState.AZURE_OPENAI_ENDPOINT = + 'https://trusted.openai.azure.com/openai/deployments/gpt-4o/chat/completions' + mockChatCreate.mockResolvedValue({ + choices: [{ message: { content: 'hi', tool_calls: undefined } }], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + }) + + await azureOpenAIProvider.executeRequest(request({ azureEndpoint: undefined })) + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(azureOpenAIArgs[0]).not.toHaveProperty('fetch') + }) + }) +}) diff --git a/apps/sim/providers/azure-openai/index.ts b/apps/sim/providers/azure-openai/index.ts index ddae32b7fb4..24d07184282 100644 --- a/apps/sim/providers/azure-openai/index.ts +++ b/apps/sim/providers/azure-openai/index.ts @@ -12,7 +12,7 @@ import type { } from 'openai/resources/chat/completions' import type { ReasoningEffort } from 'openai/resources/shared' import { env } from '@/lib/core/config/env' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { createPinnedFetch, validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import type { StreamingExecution } from '@/executor/types' import { MAX_TOOL_ITERATIONS } from '@/providers' import { prepareProviderAttachments } from '@/providers/attachments' @@ -56,7 +56,8 @@ async function executeChatCompletionsRequest( request: ProviderRequest, azureEndpoint: string, azureApiVersion: string, - deploymentName: string + deploymentName: string, + pinnedFetch?: typeof fetch ): Promise { logger.info('Using Azure OpenAI Chat Completions API', { model: request.model, @@ -75,6 +76,7 @@ async function executeChatCompletionsRequest( apiKey: request.apiKey!, apiVersion: azureApiVersion, endpoint: azureEndpoint, + ...(pinnedFetch ? { fetch: pinnedFetch } : {}), }) const allMessages: ChatCompletionMessageParam[] = [] @@ -606,6 +608,7 @@ export const azureOpenAIProvider: ProviderConfig = { ) } + let pinnedFetch: typeof fetch | undefined if (userProvidedEndpoint) { const validation = await validateUrlWithDNS(userProvidedEndpoint, 'azureEndpoint') if (!validation.isValid) { @@ -615,6 +618,10 @@ export const azureOpenAIProvider: ProviderConfig = { }) throw new Error(`Invalid Azure OpenAI endpoint: ${validation.error}`) } + if (!validation.resolvedIP) { + throw new Error('Invalid Azure OpenAI endpoint: could not resolve a pinnable IP address') + } + pinnedFetch = createPinnedFetch(validation.resolvedIP) } const apiKey = request.apiKey @@ -652,7 +659,8 @@ export const azureOpenAIProvider: ProviderConfig = { { ...request, apiKey }, baseUrl, azureApiVersion, - deploymentName + deploymentName, + pinnedFetch ) } @@ -676,6 +684,7 @@ export const azureOpenAIProvider: ProviderConfig = { 'api-key': apiKey, }, logger, + fetch: pinnedFetch, } ) } @@ -700,6 +709,7 @@ export const azureOpenAIProvider: ProviderConfig = { 'api-key': apiKey, }, logger, + fetch: pinnedFetch, } ) }, diff --git a/apps/sim/providers/openai/core.ts b/apps/sim/providers/openai/core.ts index c0fa50def86..913700ef5d7 100644 --- a/apps/sim/providers/openai/core.ts +++ b/apps/sim/providers/openai/core.ts @@ -41,6 +41,12 @@ export interface ResponsesProviderConfig { endpoint: string headers: Record logger: Logger + /** + * Optional fetch implementation. Used to pin the connection to a pre-validated + * IP (DNS-rebinding/SSRF protection) when the endpoint is user-supplied. + * Defaults to the global fetch. + */ + fetch?: typeof fetch } /** @@ -51,6 +57,7 @@ export async function executeResponsesProviderRequest( config: ResponsesProviderConfig ): Promise { const { logger } = config + const fetchImpl = config.fetch ?? fetch logger.info(`Preparing ${config.providerLabel} request`, { model: request.model, @@ -207,7 +214,7 @@ export async function executeResponsesProviderRequest( const postResponses = async ( body: Record ): Promise => { - const response = await fetch(config.endpoint, { + const response = await fetchImpl(config.endpoint, { method: 'POST', headers: config.headers, body: JSON.stringify(body), @@ -229,7 +236,7 @@ export async function executeResponsesProviderRequest( if (request.stream && (!tools || tools.length === 0)) { logger.info(`Using streaming response for ${config.providerLabel} request`) - const streamResponse = await fetch(config.endpoint, { + const streamResponse = await fetchImpl(config.endpoint, { method: 'POST', headers: config.headers, body: JSON.stringify(createRequestBody(initialInput, { stream: true })), @@ -643,7 +650,7 @@ export async function executeResponsesProviderRequest( } } - const streamResponse = await fetch(config.endpoint, { + const streamResponse = await fetchImpl(config.endpoint, { method: 'POST', headers: config.headers, body: JSON.stringify(createRequestBody(currentInput, streamOverrides)),