Skip to content

Commit 4a7c2ef

Browse files
authored
fix(providers): pin vLLM provider endpoint to validated IP (#5077)
Validate the user-supplied vLLM endpoint (request.azureEndpoint) against the central SSRF guard and pin the connection to the resolved IP before issuing any request, mirroring the Azure OpenAI/Anthropic providers. The operator-configured VLLM_BASE_URL stays trusted and unvalidated.
1 parent a09e393 commit 4a7c2ef

2 files changed

Lines changed: 136 additions & 19 deletions

File tree

apps/sim/providers/vllm/index.test.ts

Lines changed: 109 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,50 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
55

66
const {
77
mockCreate,
8+
openAIArgs,
9+
mockOpenAI,
810
mockExecuteTool,
911
mockPrepareTools,
1012
mockCheckForced,
1113
mockCreateStream,
14+
mockValidateUrlWithDNS,
15+
mockCreatePinnedFetch,
16+
pinnedFetchFn,
1217
envState,
13-
} = vi.hoisted(() => ({
14-
mockCreate: vi.fn(),
15-
mockExecuteTool: vi.fn(),
16-
mockPrepareTools: vi.fn(),
17-
mockCheckForced: vi.fn(),
18-
mockCreateStream: vi.fn(),
19-
envState: {
20-
VLLM_BASE_URL: 'http://localhost:8000',
21-
VLLM_API_KEY: undefined as string | undefined,
22-
},
23-
}))
24-
25-
vi.mock('openai', () => ({
26-
default: vi.fn().mockImplementation(
27-
class {
28-
chat = { completions: { create: mockCreate } }
18+
} = vi.hoisted(() => {
19+
const openAIArgs: Array<Record<string, unknown>> = []
20+
const mockCreate = vi.fn()
21+
const pinnedFetchFn = vi.fn()
22+
class MockOpenAI {
23+
chat = { completions: { create: mockCreate } }
24+
constructor(opts: Record<string, unknown>) {
25+
openAIArgs.push(opts)
2926
}
30-
),
31-
}))
27+
}
28+
return {
29+
mockCreate,
30+
openAIArgs,
31+
mockOpenAI: MockOpenAI,
32+
mockExecuteTool: vi.fn(),
33+
mockPrepareTools: vi.fn(),
34+
mockCheckForced: vi.fn(),
35+
mockCreateStream: vi.fn(),
36+
mockValidateUrlWithDNS: vi.fn(),
37+
mockCreatePinnedFetch: vi.fn(() => pinnedFetchFn),
38+
pinnedFetchFn,
39+
envState: {
40+
VLLM_BASE_URL: 'http://localhost:8000',
41+
VLLM_API_KEY: undefined as string | undefined,
42+
},
43+
}
44+
})
45+
46+
vi.mock('openai', () => ({ default: mockOpenAI }))
3247
vi.mock('@/lib/core/config/env', () => ({ env: envState }))
48+
vi.mock('@/lib/core/security/input-validation.server', () => ({
49+
validateUrlWithDNS: mockValidateUrlWithDNS,
50+
createPinnedFetch: mockCreatePinnedFetch,
51+
}))
3352
vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 }))
3453
vi.mock('@/providers/models', () => ({
3554
getProviderModels: vi.fn(() => []),
@@ -94,6 +113,7 @@ const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0]
94113
describe('vllmProvider', () => {
95114
beforeEach(() => {
96115
vi.clearAllMocks()
116+
openAIArgs.length = 0
97117
envState.VLLM_BASE_URL = 'http://localhost:8000'
98118
envState.VLLM_API_KEY = undefined
99119
mockPrepareTools.mockReturnValue({
@@ -105,6 +125,77 @@ describe('vllmProvider', () => {
105125
mockCheckForced.mockReturnValue({ hasUsedForcedTool: false, usedForcedTools: [] })
106126
mockCreateStream.mockReturnValue(new ReadableStream({ start: (c) => c.close() }))
107127
mockExecuteTool.mockResolvedValue({ success: true, output: { result: 'ok' } })
128+
mockValidateUrlWithDNS.mockResolvedValue({ isValid: true, resolvedIP: '203.0.113.10' })
129+
mockCreatePinnedFetch.mockReturnValue(pinnedFetchFn)
130+
})
131+
132+
describe('endpoint SSRF protection', () => {
133+
it('does not validate or pin when no endpoint is supplied (uses env base URL)', async () => {
134+
mockCreate.mockResolvedValueOnce(chatResponse('hi'))
135+
136+
await vllmProvider.executeRequest({
137+
model: 'vllm/llama-3',
138+
messages: [{ role: 'user', content: 'hi' }],
139+
})
140+
141+
expect(mockValidateUrlWithDNS).not.toHaveBeenCalled()
142+
expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
143+
expect(openAIArgs[0].baseURL).toBe('http://localhost:8000/v1')
144+
expect(openAIArgs[0].fetch).toBeUndefined()
145+
})
146+
147+
it('validates a user-supplied endpoint and pins the connection to the resolved IP', async () => {
148+
mockCreate.mockResolvedValueOnce(chatResponse('hi'))
149+
150+
await vllmProvider.executeRequest({
151+
model: 'vllm/llama-3',
152+
messages: [{ role: 'user', content: 'hi' }],
153+
azureEndpoint: 'https://my-vllm.example.com',
154+
})
155+
156+
expect(mockValidateUrlWithDNS).toHaveBeenCalledWith(
157+
'https://my-vllm.example.com',
158+
'vLLM endpoint'
159+
)
160+
expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10')
161+
expect(openAIArgs[0].baseURL).toBe('https://my-vllm.example.com/v1')
162+
expect(openAIArgs[0].fetch).toBe(pinnedFetchFn)
163+
})
164+
165+
it('rejects a user-supplied endpoint that fails SSRF validation without issuing a request', async () => {
166+
mockValidateUrlWithDNS.mockResolvedValueOnce({
167+
isValid: false,
168+
error: 'vLLM endpoint resolves to a blocked IP address',
169+
})
170+
171+
await expect(
172+
vllmProvider.executeRequest({
173+
model: 'vllm/llama-3',
174+
messages: [{ role: 'user', content: 'hi' }],
175+
azureEndpoint: 'http://169.254.169.254',
176+
})
177+
).rejects.toThrow('Invalid vLLM endpoint')
178+
179+
expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
180+
expect(openAIArgs).toHaveLength(0)
181+
expect(mockCreate).not.toHaveBeenCalled()
182+
})
183+
184+
it('rejects a validated endpoint that did not resolve to a pinnable IP', async () => {
185+
mockValidateUrlWithDNS.mockResolvedValueOnce({ isValid: true })
186+
187+
await expect(
188+
vllmProvider.executeRequest({
189+
model: 'vllm/llama-3',
190+
messages: [{ role: 'user', content: 'hi' }],
191+
azureEndpoint: 'https://my-vllm.example.com',
192+
})
193+
).rejects.toThrow('could not resolve a pinnable IP address')
194+
195+
expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
196+
expect(openAIArgs).toHaveLength(0)
197+
expect(mockCreate).not.toHaveBeenCalled()
198+
})
108199
})
109200

110201
it('builds a chat payload with the vllm/ prefix stripped and messages assembled in order', async () => {

apps/sim/providers/vllm/index.ts

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { getErrorMessage, toError } from '@sim/utils/errors'
33
import OpenAI from 'openai'
44
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
55
import { env } from '@/lib/core/config/env'
6+
import { createPinnedFetch, validateUrlWithDNS } from '@/lib/core/security/input-validation.server'
67
import type { StreamingExecution } from '@/executor/types'
78
import { MAX_TOOL_ITERATIONS } from '@/providers'
89
import { formatMessagesForProvider } from '@/providers/attachments'
@@ -95,15 +96,40 @@ export const vllmProvider: ProviderConfig = {
9596
stream: !!request.stream,
9697
})
9798

98-
const baseUrl = (request.azureEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '')
99+
const userProvidedEndpoint = request.azureEndpoint
100+
101+
const baseUrl = (userProvidedEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '')
99102
if (!baseUrl) {
100103
throw new Error('VLLM_BASE_URL is required for vLLM provider')
101104
}
102105

106+
/**
107+
* A user-supplied endpoint is attacker-controlled: validate it against the
108+
* central SSRF guard and pin the connection to the resolved IP to defeat DNS
109+
* rebinding. The operator-configured `VLLM_BASE_URL` is trusted and left
110+
* unvalidated, mirroring the Azure providers.
111+
*/
112+
let pinnedFetch: typeof fetch | undefined
113+
if (userProvidedEndpoint) {
114+
const validation = await validateUrlWithDNS(userProvidedEndpoint, 'vLLM endpoint')
115+
if (!validation.isValid) {
116+
logger.warn('Blocked SSRF attempt via vLLM endpoint', {
117+
endpoint: userProvidedEndpoint,
118+
error: validation.error,
119+
})
120+
throw new Error(`Invalid vLLM endpoint: ${validation.error}`)
121+
}
122+
if (!validation.resolvedIP) {
123+
throw new Error('Invalid vLLM endpoint: could not resolve a pinnable IP address')
124+
}
125+
pinnedFetch = createPinnedFetch(validation.resolvedIP)
126+
}
127+
103128
const apiKey = request.apiKey || env.VLLM_API_KEY || 'empty'
104129
const vllm = new OpenAI({
105130
apiKey,
106131
baseURL: `${baseUrl}/v1`,
132+
...(pinnedFetch ? { fetch: pinnedFetch } : {}),
107133
})
108134

109135
const allMessages: Message[] = []

0 commit comments

Comments
 (0)