From 2656dc55f974e9e19d25886d9031d86ac86a7208 Mon Sep 17 00:00:00 2001 From: waleed Date: Mon, 15 Jun 2026 17:07:18 -0700 Subject: [PATCH] fix(providers): pin vLLM provider endpoint to validated IP Validate the user-supplied vLLM endpoint (request.azureEndpoint) against the central SSRF guard and pin the connection to the resolved IP before issuing any request, mirroring the Azure OpenAI/Anthropic providers. The operator-configured VLLM_BASE_URL stays trusted and unvalidated. --- apps/sim/providers/vllm/index.test.ts | 127 ++++++++++++++++++++++---- apps/sim/providers/vllm/index.ts | 28 +++++- 2 files changed, 136 insertions(+), 19 deletions(-) diff --git a/apps/sim/providers/vllm/index.test.ts b/apps/sim/providers/vllm/index.test.ts index 829c48168f1..d81e696dae8 100644 --- a/apps/sim/providers/vllm/index.test.ts +++ b/apps/sim/providers/vllm/index.test.ts @@ -5,31 +5,50 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' const { mockCreate, + openAIArgs, + mockOpenAI, mockExecuteTool, mockPrepareTools, mockCheckForced, mockCreateStream, + mockValidateUrlWithDNS, + mockCreatePinnedFetch, + pinnedFetchFn, envState, -} = vi.hoisted(() => ({ - mockCreate: vi.fn(), - mockExecuteTool: vi.fn(), - mockPrepareTools: vi.fn(), - mockCheckForced: vi.fn(), - mockCreateStream: vi.fn(), - envState: { - VLLM_BASE_URL: 'http://localhost:8000', - VLLM_API_KEY: undefined as string | undefined, - }, -})) - -vi.mock('openai', () => ({ - default: vi.fn().mockImplementation( - class { - chat = { completions: { create: mockCreate } } +} = vi.hoisted(() => { + const openAIArgs: Array> = [] + const mockCreate = vi.fn() + const pinnedFetchFn = vi.fn() + class MockOpenAI { + chat = { completions: { create: mockCreate } } + constructor(opts: Record) { + openAIArgs.push(opts) } - ), -})) + } + return { + mockCreate, + openAIArgs, + mockOpenAI: MockOpenAI, + mockExecuteTool: vi.fn(), + mockPrepareTools: vi.fn(), + mockCheckForced: vi.fn(), + mockCreateStream: vi.fn(), + mockValidateUrlWithDNS: vi.fn(), + mockCreatePinnedFetch: vi.fn(() => pinnedFetchFn), + pinnedFetchFn, + envState: { + VLLM_BASE_URL: 'http://localhost:8000', + VLLM_API_KEY: undefined as string | undefined, + }, + } +}) + +vi.mock('openai', () => ({ default: mockOpenAI })) vi.mock('@/lib/core/config/env', () => ({ env: envState })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + validateUrlWithDNS: mockValidateUrlWithDNS, + createPinnedFetch: mockCreatePinnedFetch, +})) vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) vi.mock('@/providers/models', () => ({ getProviderModels: vi.fn(() => []), @@ -94,6 +113,7 @@ const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0] describe('vllmProvider', () => { beforeEach(() => { vi.clearAllMocks() + openAIArgs.length = 0 envState.VLLM_BASE_URL = 'http://localhost:8000' envState.VLLM_API_KEY = undefined mockPrepareTools.mockReturnValue({ @@ -105,6 +125,77 @@ describe('vllmProvider', () => { mockCheckForced.mockReturnValue({ hasUsedForcedTool: false, usedForcedTools: [] }) mockCreateStream.mockReturnValue(new ReadableStream({ start: (c) => c.close() })) mockExecuteTool.mockResolvedValue({ success: true, output: { result: 'ok' } }) + mockValidateUrlWithDNS.mockResolvedValue({ isValid: true, resolvedIP: '203.0.113.10' }) + mockCreatePinnedFetch.mockReturnValue(pinnedFetchFn) + }) + + describe('endpoint SSRF protection', () => { + it('does not validate or pin when no endpoint is supplied (uses env base URL)', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('hi')) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + }) + + expect(mockValidateUrlWithDNS).not.toHaveBeenCalled() + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(openAIArgs[0].baseURL).toBe('http://localhost:8000/v1') + expect(openAIArgs[0].fetch).toBeUndefined() + }) + + it('validates a user-supplied endpoint and pins the connection to the resolved IP', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('hi')) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + azureEndpoint: 'https://my-vllm.example.com', + }) + + expect(mockValidateUrlWithDNS).toHaveBeenCalledWith( + 'https://my-vllm.example.com', + 'vLLM endpoint' + ) + expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10') + expect(openAIArgs[0].baseURL).toBe('https://my-vllm.example.com/v1') + expect(openAIArgs[0].fetch).toBe(pinnedFetchFn) + }) + + it('rejects a user-supplied endpoint that fails SSRF validation without issuing a request', async () => { + mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'vLLM endpoint resolves to a blocked IP address', + }) + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + azureEndpoint: 'http://169.254.169.254', + }) + ).rejects.toThrow('Invalid vLLM endpoint') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(openAIArgs).toHaveLength(0) + expect(mockCreate).not.toHaveBeenCalled() + }) + + it('rejects a validated endpoint that did not resolve to a pinnable IP', async () => { + mockValidateUrlWithDNS.mockResolvedValueOnce({ isValid: true }) + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + azureEndpoint: 'https://my-vllm.example.com', + }) + ).rejects.toThrow('could not resolve a pinnable IP address') + + expect(mockCreatePinnedFetch).not.toHaveBeenCalled() + expect(openAIArgs).toHaveLength(0) + expect(mockCreate).not.toHaveBeenCalled() + }) }) it('builds a chat payload with the vllm/ prefix stripped and messages assembled in order', async () => { diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index 7251ec1c31a..90f6c7c0a3b 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -3,6 +3,7 @@ import { getErrorMessage, toError } from '@sim/utils/errors' import OpenAI from 'openai' import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions' import { env } from '@/lib/core/config/env' +import { createPinnedFetch, validateUrlWithDNS } from '@/lib/core/security/input-validation.server' import type { StreamingExecution } from '@/executor/types' import { MAX_TOOL_ITERATIONS } from '@/providers' import { formatMessagesForProvider } from '@/providers/attachments' @@ -95,15 +96,40 @@ export const vllmProvider: ProviderConfig = { stream: !!request.stream, }) - const baseUrl = (request.azureEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '') + const userProvidedEndpoint = request.azureEndpoint + + const baseUrl = (userProvidedEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '') if (!baseUrl) { throw new Error('VLLM_BASE_URL is required for vLLM provider') } + /** + * A user-supplied endpoint is attacker-controlled: validate it against the + * central SSRF guard and pin the connection to the resolved IP to defeat DNS + * rebinding. The operator-configured `VLLM_BASE_URL` is trusted and left + * unvalidated, mirroring the Azure providers. + */ + let pinnedFetch: typeof fetch | undefined + if (userProvidedEndpoint) { + const validation = await validateUrlWithDNS(userProvidedEndpoint, 'vLLM endpoint') + if (!validation.isValid) { + logger.warn('Blocked SSRF attempt via vLLM endpoint', { + endpoint: userProvidedEndpoint, + error: validation.error, + }) + throw new Error(`Invalid vLLM endpoint: ${validation.error}`) + } + if (!validation.resolvedIP) { + throw new Error('Invalid vLLM endpoint: could not resolve a pinnable IP address') + } + pinnedFetch = createPinnedFetch(validation.resolvedIP) + } + const apiKey = request.apiKey || env.VLLM_API_KEY || 'empty' const vllm = new OpenAI({ apiKey, baseURL: `${baseUrl}/v1`, + ...(pinnedFetch ? { fetch: pinnedFetch } : {}), }) const allMessages: Message[] = []