@@ -5,31 +5,50 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
55
66const {
77 mockCreate,
8+ openAIArgs,
9+ mockOpenAI,
810 mockExecuteTool,
911 mockPrepareTools,
1012 mockCheckForced,
1113 mockCreateStream,
14+ mockValidateUrlWithDNS,
15+ mockCreatePinnedFetch,
16+ pinnedFetchFn,
1217 envState,
13- } = vi . hoisted ( ( ) => ( {
14- mockCreate : vi . fn ( ) ,
15- mockExecuteTool : vi . fn ( ) ,
16- mockPrepareTools : vi . fn ( ) ,
17- mockCheckForced : vi . fn ( ) ,
18- mockCreateStream : vi . fn ( ) ,
19- envState : {
20- VLLM_BASE_URL : 'http://localhost:8000' ,
21- VLLM_API_KEY : undefined as string | undefined ,
22- } ,
23- } ) )
24-
25- vi . mock ( 'openai' , ( ) => ( {
26- default : vi . fn ( ) . mockImplementation (
27- class {
28- chat = { completions : { create : mockCreate } }
18+ } = vi . hoisted ( ( ) => {
19+ const openAIArgs : Array < Record < string , unknown > > = [ ]
20+ const mockCreate = vi . fn ( )
21+ const pinnedFetchFn = vi . fn ( )
22+ class MockOpenAI {
23+ chat = { completions : { create : mockCreate } }
24+ constructor ( opts : Record < string , unknown > ) {
25+ openAIArgs . push ( opts )
2926 }
30- ) ,
31- } ) )
27+ }
28+ return {
29+ mockCreate,
30+ openAIArgs,
31+ mockOpenAI : MockOpenAI ,
32+ mockExecuteTool : vi . fn ( ) ,
33+ mockPrepareTools : vi . fn ( ) ,
34+ mockCheckForced : vi . fn ( ) ,
35+ mockCreateStream : vi . fn ( ) ,
36+ mockValidateUrlWithDNS : vi . fn ( ) ,
37+ mockCreatePinnedFetch : vi . fn ( ( ) => pinnedFetchFn ) ,
38+ pinnedFetchFn,
39+ envState : {
40+ VLLM_BASE_URL : 'http://localhost:8000' ,
41+ VLLM_API_KEY : undefined as string | undefined ,
42+ } ,
43+ }
44+ } )
45+
46+ vi . mock ( 'openai' , ( ) => ( { default : mockOpenAI } ) )
3247vi . mock ( '@/lib/core/config/env' , ( ) => ( { env : envState } ) )
48+ vi . mock ( '@/lib/core/security/input-validation.server' , ( ) => ( {
49+ validateUrlWithDNS : mockValidateUrlWithDNS ,
50+ createPinnedFetch : mockCreatePinnedFetch ,
51+ } ) )
3352vi . mock ( '@/providers' , ( ) => ( { MAX_TOOL_ITERATIONS : 20 } ) )
3453vi . mock ( '@/providers/models' , ( ) => ( {
3554 getProviderModels : vi . fn ( ( ) => [ ] ) ,
@@ -94,6 +113,7 @@ const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0]
94113describe ( 'vllmProvider' , ( ) => {
95114 beforeEach ( ( ) => {
96115 vi . clearAllMocks ( )
116+ openAIArgs . length = 0
97117 envState . VLLM_BASE_URL = 'http://localhost:8000'
98118 envState . VLLM_API_KEY = undefined
99119 mockPrepareTools . mockReturnValue ( {
@@ -105,6 +125,77 @@ describe('vllmProvider', () => {
105125 mockCheckForced . mockReturnValue ( { hasUsedForcedTool : false , usedForcedTools : [ ] } )
106126 mockCreateStream . mockReturnValue ( new ReadableStream ( { start : ( c ) => c . close ( ) } ) )
107127 mockExecuteTool . mockResolvedValue ( { success : true , output : { result : 'ok' } } )
128+ mockValidateUrlWithDNS . mockResolvedValue ( { isValid : true , resolvedIP : '203.0.113.10' } )
129+ mockCreatePinnedFetch . mockReturnValue ( pinnedFetchFn )
130+ } )
131+
132+ describe ( 'endpoint SSRF protection' , ( ) => {
133+ it ( 'does not validate or pin when no endpoint is supplied (uses env base URL)' , async ( ) => {
134+ mockCreate . mockResolvedValueOnce ( chatResponse ( 'hi' ) )
135+
136+ await vllmProvider . executeRequest ( {
137+ model : 'vllm/llama-3' ,
138+ messages : [ { role : 'user' , content : 'hi' } ] ,
139+ } )
140+
141+ expect ( mockValidateUrlWithDNS ) . not . toHaveBeenCalled ( )
142+ expect ( mockCreatePinnedFetch ) . not . toHaveBeenCalled ( )
143+ expect ( openAIArgs [ 0 ] . baseURL ) . toBe ( 'http://localhost:8000/v1' )
144+ expect ( openAIArgs [ 0 ] . fetch ) . toBeUndefined ( )
145+ } )
146+
147+ it ( 'validates a user-supplied endpoint and pins the connection to the resolved IP' , async ( ) => {
148+ mockCreate . mockResolvedValueOnce ( chatResponse ( 'hi' ) )
149+
150+ await vllmProvider . executeRequest ( {
151+ model : 'vllm/llama-3' ,
152+ messages : [ { role : 'user' , content : 'hi' } ] ,
153+ azureEndpoint : 'https://my-vllm.example.com' ,
154+ } )
155+
156+ expect ( mockValidateUrlWithDNS ) . toHaveBeenCalledWith (
157+ 'https://my-vllm.example.com' ,
158+ 'vLLM endpoint'
159+ )
160+ expect ( mockCreatePinnedFetch ) . toHaveBeenCalledWith ( '203.0.113.10' )
161+ expect ( openAIArgs [ 0 ] . baseURL ) . toBe ( 'https://my-vllm.example.com/v1' )
162+ expect ( openAIArgs [ 0 ] . fetch ) . toBe ( pinnedFetchFn )
163+ } )
164+
165+ it ( 'rejects a user-supplied endpoint that fails SSRF validation without issuing a request' , async ( ) => {
166+ mockValidateUrlWithDNS . mockResolvedValueOnce ( {
167+ isValid : false ,
168+ error : 'vLLM endpoint resolves to a blocked IP address' ,
169+ } )
170+
171+ await expect (
172+ vllmProvider . executeRequest ( {
173+ model : 'vllm/llama-3' ,
174+ messages : [ { role : 'user' , content : 'hi' } ] ,
175+ azureEndpoint : 'http://169.254.169.254' ,
176+ } )
177+ ) . rejects . toThrow ( 'Invalid vLLM endpoint' )
178+
179+ expect ( mockCreatePinnedFetch ) . not . toHaveBeenCalled ( )
180+ expect ( openAIArgs ) . toHaveLength ( 0 )
181+ expect ( mockCreate ) . not . toHaveBeenCalled ( )
182+ } )
183+
184+ it ( 'rejects a validated endpoint that did not resolve to a pinnable IP' , async ( ) => {
185+ mockValidateUrlWithDNS . mockResolvedValueOnce ( { isValid : true } )
186+
187+ await expect (
188+ vllmProvider . executeRequest ( {
189+ model : 'vllm/llama-3' ,
190+ messages : [ { role : 'user' , content : 'hi' } ] ,
191+ azureEndpoint : 'https://my-vllm.example.com' ,
192+ } )
193+ ) . rejects . toThrow ( 'could not resolve a pinnable IP address' )
194+
195+ expect ( mockCreatePinnedFetch ) . not . toHaveBeenCalled ( )
196+ expect ( openAIArgs ) . toHaveLength ( 0 )
197+ expect ( mockCreate ) . not . toHaveBeenCalled ( )
198+ } )
108199 } )
109200
110201 it ( 'builds a chat payload with the vllm/ prefix stripped and messages assembled in order' , async ( ) => {
0 commit comments