Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 109 additions & 18 deletions apps/sim/providers/vllm/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,50 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'

const {
mockCreate,
openAIArgs,
mockOpenAI,
mockExecuteTool,
mockPrepareTools,
mockCheckForced,
mockCreateStream,
mockValidateUrlWithDNS,
mockCreatePinnedFetch,
pinnedFetchFn,
envState,
} = vi.hoisted(() => ({
mockCreate: vi.fn(),
mockExecuteTool: vi.fn(),
mockPrepareTools: vi.fn(),
mockCheckForced: vi.fn(),
mockCreateStream: vi.fn(),
envState: {
VLLM_BASE_URL: 'http://localhost:8000',
VLLM_API_KEY: undefined as string | undefined,
},
}))

vi.mock('openai', () => ({
default: vi.fn().mockImplementation(
class {
chat = { completions: { create: mockCreate } }
} = vi.hoisted(() => {
const openAIArgs: Array<Record<string, unknown>> = []
const mockCreate = vi.fn()
const pinnedFetchFn = vi.fn()
class MockOpenAI {
chat = { completions: { create: mockCreate } }
constructor(opts: Record<string, unknown>) {
openAIArgs.push(opts)
}
),
}))
}
return {
mockCreate,
openAIArgs,
mockOpenAI: MockOpenAI,
mockExecuteTool: vi.fn(),
mockPrepareTools: vi.fn(),
mockCheckForced: vi.fn(),
mockCreateStream: vi.fn(),
mockValidateUrlWithDNS: vi.fn(),
mockCreatePinnedFetch: vi.fn(() => pinnedFetchFn),
pinnedFetchFn,
envState: {
VLLM_BASE_URL: 'http://localhost:8000',
VLLM_API_KEY: undefined as string | undefined,
},
}
})

vi.mock('openai', () => ({ default: mockOpenAI }))
vi.mock('@/lib/core/config/env', () => ({ env: envState }))
vi.mock('@/lib/core/security/input-validation.server', () => ({
validateUrlWithDNS: mockValidateUrlWithDNS,
createPinnedFetch: mockCreatePinnedFetch,
}))
vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 }))
vi.mock('@/providers/models', () => ({
getProviderModels: vi.fn(() => []),
Expand Down Expand Up @@ -94,6 +113,7 @@ const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0]
describe('vllmProvider', () => {
beforeEach(() => {
vi.clearAllMocks()
openAIArgs.length = 0
envState.VLLM_BASE_URL = 'http://localhost:8000'
envState.VLLM_API_KEY = undefined
mockPrepareTools.mockReturnValue({
Expand All @@ -105,6 +125,77 @@ describe('vllmProvider', () => {
mockCheckForced.mockReturnValue({ hasUsedForcedTool: false, usedForcedTools: [] })
mockCreateStream.mockReturnValue(new ReadableStream({ start: (c) => c.close() }))
mockExecuteTool.mockResolvedValue({ success: true, output: { result: 'ok' } })
mockValidateUrlWithDNS.mockResolvedValue({ isValid: true, resolvedIP: '203.0.113.10' })
mockCreatePinnedFetch.mockReturnValue(pinnedFetchFn)
})

describe('endpoint SSRF protection', () => {
it('does not validate or pin when no endpoint is supplied (uses env base URL)', async () => {
mockCreate.mockResolvedValueOnce(chatResponse('hi'))

await vllmProvider.executeRequest({
model: 'vllm/llama-3',
messages: [{ role: 'user', content: 'hi' }],
})

expect(mockValidateUrlWithDNS).not.toHaveBeenCalled()
expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
expect(openAIArgs[0].baseURL).toBe('http://localhost:8000/v1')
expect(openAIArgs[0].fetch).toBeUndefined()
})

it('validates a user-supplied endpoint and pins the connection to the resolved IP', async () => {
mockCreate.mockResolvedValueOnce(chatResponse('hi'))

await vllmProvider.executeRequest({
model: 'vllm/llama-3',
messages: [{ role: 'user', content: 'hi' }],
azureEndpoint: 'https://my-vllm.example.com',
})

expect(mockValidateUrlWithDNS).toHaveBeenCalledWith(
'https://my-vllm.example.com',
'vLLM endpoint'
)
expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10')
expect(openAIArgs[0].baseURL).toBe('https://my-vllm.example.com/v1')
expect(openAIArgs[0].fetch).toBe(pinnedFetchFn)
})

it('rejects a user-supplied endpoint that fails SSRF validation without issuing a request', async () => {
mockValidateUrlWithDNS.mockResolvedValueOnce({
isValid: false,
error: 'vLLM endpoint resolves to a blocked IP address',
})

await expect(
vllmProvider.executeRequest({
model: 'vllm/llama-3',
messages: [{ role: 'user', content: 'hi' }],
azureEndpoint: 'http://169.254.169.254',
})
).rejects.toThrow('Invalid vLLM endpoint')

expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
expect(openAIArgs).toHaveLength(0)
expect(mockCreate).not.toHaveBeenCalled()
})

it('rejects a validated endpoint that did not resolve to a pinnable IP', async () => {
mockValidateUrlWithDNS.mockResolvedValueOnce({ isValid: true })

await expect(
vllmProvider.executeRequest({
model: 'vllm/llama-3',
messages: [{ role: 'user', content: 'hi' }],
azureEndpoint: 'https://my-vllm.example.com',
})
).rejects.toThrow('could not resolve a pinnable IP address')

expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
expect(openAIArgs).toHaveLength(0)
expect(mockCreate).not.toHaveBeenCalled()
})
})

it('builds a chat payload with the vllm/ prefix stripped and messages assembled in order', async () => {
Expand Down
28 changes: 27 additions & 1 deletion apps/sim/providers/vllm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { getErrorMessage, toError } from '@sim/utils/errors'
import OpenAI from 'openai'
import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions'
import { env } from '@/lib/core/config/env'
import { createPinnedFetch, validateUrlWithDNS } from '@/lib/core/security/input-validation.server'
import type { StreamingExecution } from '@/executor/types'
import { MAX_TOOL_ITERATIONS } from '@/providers'
import { formatMessagesForProvider } from '@/providers/attachments'
Expand Down Expand Up @@ -95,15 +96,40 @@ export const vllmProvider: ProviderConfig = {
stream: !!request.stream,
})

const baseUrl = (request.azureEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '')
const userProvidedEndpoint = request.azureEndpoint

const baseUrl = (userProvidedEndpoint || env.VLLM_BASE_URL || '').replace(/\/$/, '')
if (!baseUrl) {
throw new Error('VLLM_BASE_URL is required for vLLM provider')
}

/**
* A user-supplied endpoint is attacker-controlled: validate it against the
* central SSRF guard and pin the connection to the resolved IP to defeat DNS
* rebinding. The operator-configured `VLLM_BASE_URL` is trusted and left
* unvalidated, mirroring the Azure providers.
*/
let pinnedFetch: typeof fetch | undefined
if (userProvidedEndpoint) {
const validation = await validateUrlWithDNS(userProvidedEndpoint, 'vLLM endpoint')
Comment thread
waleedlatif1 marked this conversation as resolved.
Comment thread
waleedlatif1 marked this conversation as resolved.
if (!validation.isValid) {
logger.warn('Blocked SSRF attempt via vLLM endpoint', {
endpoint: userProvidedEndpoint,
error: validation.error,
})
throw new Error(`Invalid vLLM endpoint: ${validation.error}`)
}
if (!validation.resolvedIP) {
throw new Error('Invalid vLLM endpoint: could not resolve a pinnable IP address')
}
pinnedFetch = createPinnedFetch(validation.resolvedIP)
}

const apiKey = request.apiKey || env.VLLM_API_KEY || 'empty'
const vllm = new OpenAI({
apiKey,
baseURL: `${baseUrl}/v1`,
...(pinnedFetch ? { fetch: pinnedFetch } : {}),
})

const allMessages: Message[] = []
Expand Down
Loading