Skip to content

Commit 9de95ed

Browse files
committed
inject requestFiles and requestOptionalFile
1 parent 1a73022 commit 9de95ed

21 files changed

+270
-276
lines changed

backend/src/__tests__/cost-aggregation.integration.test.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,15 @@ describe('Cost Aggregation Integration Tests', () => {
213213
}
214214

215215
// Mock file reading
216-
spyOn(websocketAction, 'requestFiles').mockImplementation(
217-
async (params: { ws: any; filePaths: string[] }) => {
218-
const results: Record<string, string | null> = {}
219-
params.filePaths.forEach((path) => {
220-
results[path] = path === 'hello.txt' ? 'Hello, World!' : null
221-
})
222-
return results
223-
},
224-
)
216+
agentRuntimeScopedImpl.requestFiles = async (params: {
217+
filePaths: string[]
218+
}) => {
219+
const results: Record<string, string | null> = {}
220+
params.filePaths.forEach((path) => {
221+
results[path] = path === 'hello.txt' ? 'Hello, World!' : null
222+
})
223+
return results
224+
}
225225

226226
// Mock live user input checking
227227
const liveUserInputs = await import('../live-user-inputs')

backend/src/__tests__/main-prompt.integration.test.ts

Lines changed: 5 additions & 9 deletions
Large diffs are not rendered by default.

backend/src/__tests__/main-prompt.test.ts

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import * as getDocumentationForQueryModule from '../get-documentation-for-query'
2727
import * as liveUserInputs from '../live-user-inputs'
2828
import { mainPrompt } from '../main-prompt'
2929
import * as processFileBlockModule from '../process-file-block'
30-
import * as websocketAction from '../websockets/websocket-action'
3130

3231
import type { AgentTemplate } from '@codebuff/common/types/agent-template'
3332
import type {
@@ -117,28 +116,24 @@ describe('mainPrompt', () => {
117116
mockAgentStream('Test response')
118117

119118
// Mock websocket actions
120-
spyOn(websocketAction, 'requestFiles').mockImplementation(
121-
async (params: { ws: any; filePaths: string[] }) => {
122-
const results: Record<string, string | null> = {}
123-
params.filePaths.forEach((p) => {
124-
if (p === 'test.txt') {
125-
results[p] = 'mock content for test.txt'
126-
} else {
127-
results[p] = null
128-
}
129-
})
130-
return results
131-
},
132-
)
133-
134-
spyOn(websocketAction, 'requestFile').mockImplementation(
135-
async (params: { ws: any; filePath: string }) => {
136-
if (params.filePath === 'test.txt') {
137-
return 'mock content for test.txt'
119+
agentRuntimeScopedImpl.requestFiles = async ({ filePaths }) => {
120+
const results: Record<string, string | null> = {}
121+
filePaths.forEach((p) => {
122+
if (p === 'test.txt') {
123+
results[p] = 'mock content for test.txt'
124+
} else {
125+
results[p] = null
138126
}
139-
return null
140-
},
141-
)
127+
})
128+
return results
129+
}
130+
131+
agentRuntimeScopedImpl.requestOptionalFile = async ({ filePath }) => {
132+
if (filePath === 'test.txt') {
133+
return 'mock content for test.txt'
134+
}
135+
return null
136+
}
142137

143138
agentRuntimeScopedImpl.requestToolCall = mock(
144139
async ({

backend/src/__tests__/malformed-tool-call.test.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import {
2020

2121
import { MockWebSocket, mockFileContext } from './test-utils'
2222
import { processStreamWithTools } from '../tools/stream-parser'
23-
import * as websocketAction from '../websockets/websocket-action'
2423

2524
import type { AgentTemplate } from '../templates/types'
2625
import type {
@@ -72,8 +71,8 @@ describe('malformed tool call error handling', () => {
7271
)
7372

7473
// Mock websocket actions
75-
spyOn(websocketAction, 'requestFiles').mockImplementation(async () => ({}))
76-
spyOn(websocketAction, 'requestFile').mockImplementation(async () => null)
74+
agentRuntimeScopedImpl.requestFiles = async () => ({})
75+
agentRuntimeScopedImpl.requestOptionalFile = async () => null
7776
agentRuntimeScopedImpl.requestToolCall = async () => ({
7877
output: [
7978
{

backend/src/__tests__/prompt-caching-subagents.test.ts

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import {
1515
} from 'bun:test'
1616

1717
import { loopAgentSteps } from '../run-agent-step'
18-
import * as websocketAction from '../websockets/websocket-action'
1918

2019
import type { AgentTemplate } from '../templates/types'
2120
import type {
@@ -126,15 +125,13 @@ describe('Prompt Caching for Subagents with inheritParentSystemPrompt', () => {
126125
}
127126

128127
// Mock file operations
129-
spyOn(websocketAction, 'requestFiles').mockImplementation(
130-
async (params: { ws: any; filePaths: string[] }) => {
131-
const results: Record<string, string | null> = {}
132-
params.filePaths.forEach((path) => {
133-
results[path] = null
134-
})
135-
return results
136-
},
137-
)
128+
agentRuntimeScopedImpl.requestFiles = async ({ filePaths }) => {
129+
const results: Record<string, string | null> = {}
130+
filePaths.forEach((path) => {
131+
results[path] = null
132+
})
133+
return results
134+
}
138135

139136
agentRuntimeScopedImpl.requestToolCall = async () => ({
140137
output: [

backend/src/__tests__/read-docs-tool.test.ts

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,24 +90,8 @@ describe('read_docs tool with researcher agent', () => {
9090
mockedFunctions.push({ name: 'bigquery.insertTrace', spy: insertTraceSpy })
9191

9292
// Mock websocket actions
93-
const requestFilesSpy = spyOn(
94-
websocketAction,
95-
'requestFiles',
96-
).mockImplementation(async () => ({}))
97-
mockedFunctions.push({
98-
name: 'websocketAction.requestFiles',
99-
spy: requestFilesSpy,
100-
})
101-
102-
const requestFileSpy = spyOn(
103-
websocketAction,
104-
'requestFile',
105-
).mockImplementation(async () => null)
106-
mockedFunctions.push({
107-
name: 'websocketAction.requestFile',
108-
spy: requestFileSpy,
109-
})
110-
93+
agentRuntimeScopedImpl.requestFiles = async () => ({})
94+
agentRuntimeScopedImpl.requestOptionalFile = async () => null
11195
agentRuntimeScopedImpl.requestToolCall = async () => ({
11296
output: [
11397
{

backend/src/__tests__/run-agent-step-tools.test.ts

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,24 @@ import * as liveUserInputs from '../live-user-inputs'
2424
import { runAgentStep } from '../run-agent-step'
2525
import { clearAgentGeneratorCache } from '../run-programmatic-step'
2626
import { asUserMessage } from '../util/messages'
27-
import * as websocketAction from '../websockets/websocket-action'
2827

2928
import type { AgentTemplate } from '../templates/types'
30-
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
29+
import type {
30+
AgentRuntimeDeps,
31+
AgentRuntimeScopedDeps,
32+
} from '@codebuff/common/types/contracts/agent-runtime'
3133
import type { ProjectFileContext } from '@codebuff/common/util/file'
3234
import type { WebSocket } from 'ws'
3335

3436
describe('runAgentStep - set_output tool', () => {
3537
let testAgent: AgentTemplate
36-
let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL }
38+
let agentRuntimeImpl: AgentRuntimeDeps
39+
let agentRuntimeScopedImpl: AgentRuntimeScopedDeps
3740

3841
beforeEach(async () => {
42+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
43+
agentRuntimeScopedImpl = { ...TEST_AGENT_RUNTIME_SCOPED_IMPL }
44+
3945
// Create a test agent that supports set_output
4046
testAgent = {
4147
id: 'test-set-output-agent',
@@ -78,32 +84,27 @@ describe('runAgentStep - set_output tool', () => {
7884
spyOn(liveUserInputs, 'startUserInput').mockImplementation(() => {})
7985
spyOn(liveUserInputs, 'setSessionConnected').mockImplementation(() => {})
8086

81-
spyOn(websocketAction, 'requestFiles').mockImplementation(
82-
async (params: { ws: any; filePaths: string[] }) => {
83-
const results: Record<string, string | null> = {}
84-
params.filePaths.forEach((p) => {
85-
if (p === 'src/auth.ts') {
86-
results[p] = 'export function authenticate() { return true; }'
87-
} else if (p === 'src/user.ts') {
88-
results[p] = 'export interface User { id: string; name: string; }'
89-
} else {
90-
results[p] = null
91-
}
92-
})
93-
return results
94-
},
95-
)
96-
97-
spyOn(websocketAction, 'requestFile').mockImplementation(
98-
async (params: { ws: any; filePath: string }) => {
99-
if (params.filePath === 'src/auth.ts') {
100-
return 'export function authenticate() { return true; }'
101-
} else if (params.filePath === 'src/user.ts') {
102-
return 'export interface User { id: string; name: string; }'
87+
agentRuntimeScopedImpl.requestFiles = async ({ filePaths }) => {
88+
const results: Record<string, string | null> = {}
89+
filePaths.forEach((p) => {
90+
if (p === 'src/auth.ts') {
91+
results[p] = 'export function authenticate() { return true; }'
92+
} else if (p === 'src/user.ts') {
93+
results[p] = 'export interface User { id: string; name: string; }'
94+
} else {
95+
results[p] = null
10396
}
104-
return null
105-
},
106-
)
97+
})
98+
return results
99+
}
100+
agentRuntimeScopedImpl.requestOptionalFile = async ({ filePath }) => {
101+
if (filePath === 'src/auth.ts') {
102+
return 'export function authenticate() { return true; }'
103+
} else if (filePath === 'src/user.ts') {
104+
return 'export interface User { id: string; name: string; }'
105+
}
106+
return null
107+
}
107108

108109
// Don't mock requestToolCall for integration test - let real tool execution happen
109110

@@ -116,7 +117,6 @@ describe('runAgentStep - set_output tool', () => {
116117

117118
afterEach(() => {
118119
mock.restore()
119-
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
120120
})
121121

122122
afterAll(() => {
@@ -363,19 +363,17 @@ describe('runAgentStep - set_output tool', () => {
363363
}
364364

365365
// Mock requestFiles to return test file content
366-
spyOn(websocketAction, 'requestFiles').mockImplementation(
367-
async (params: { ws: any; filePaths: string[] }) => {
368-
const results: Record<string, string | null> = {}
369-
params.filePaths.forEach((p) => {
370-
if (p === 'src/test.ts') {
371-
results[p] = 'export function testFunction() { return "test"; }'
372-
} else {
373-
results[p] = null
374-
}
375-
})
376-
return results
377-
},
378-
)
366+
agentRuntimeScopedImpl.requestFiles = async ({ filePaths }) => {
367+
const results: Record<string, string | null> = {}
368+
filePaths.forEach((p) => {
369+
if (p === 'src/test.ts') {
370+
results[p] = 'export function testFunction() { return "test"; }'
371+
} else {
372+
results[p] = null
373+
}
374+
})
375+
return results
376+
}
379377

380378
// Mock the LLM stream to return a response that doesn't end the turn
381379
agentRuntimeImpl.promptAiSdkStream = async function* ({}) {

backend/src/__tests__/web-search-tool.test.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import { MockWebSocket, mockFileContext } from './test-utils'
2828
import * as linkupApi from '../llm-apis/linkup-api'
2929
import { runAgentStep } from '../run-agent-step'
3030
import { assembleLocalAgentTemplates } from '../templates/agent-registry'
31-
import * as websocketAction from '../websockets/websocket-action'
3231

3332
import type {
3433
AgentRuntimeDeps,
@@ -64,8 +63,8 @@ describe('web_search tool with researcher agent', () => {
6463
)
6564

6665
// Mock websocket actions
67-
spyOn(websocketAction, 'requestFiles').mockImplementation(async () => ({}))
68-
spyOn(websocketAction, 'requestFile').mockImplementation(async () => null)
66+
agentRuntimeScopedImpl.requestFiles = async () => ({})
67+
agentRuntimeScopedImpl.requestOptionalFile = async () => null
6968
agentRuntimeScopedImpl.requestToolCall = async () => ({
7069
output: [
7170
{

backend/src/client-wrapper.ts

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
import { toOptionalFile } from '@codebuff/common/old-constants'
2+
import { ensureEndsWithNewline } from '@codebuff/common/util/file'
13
import { generateCompactId } from '@codebuff/common/util/string'
24

35
import { subscribeToAction } from './websockets/websocket-action'
46

57
import type { ServerAction } from '@codebuff/common/actions'
6-
import type { RequestMcpToolDataFn } from '@codebuff/common/types/contracts/client'
8+
import type {
9+
RequestFilesFn,
10+
RequestMcpToolDataFn,
11+
RequestOptionalFileFn,
12+
} from '@codebuff/common/types/contracts/client'
713
import type { ParamsOf } from '@codebuff/common/types/function-params'
814
import type { MCPConfig } from '@codebuff/common/types/mcp'
915
import type { ToolResultOutput } from '@codebuff/common/types/messages/content-part'
@@ -141,3 +147,44 @@ export async function requestMcpToolDataWs(
141147
})
142148
})
143149
}
150+
151+
/**
152+
* Requests multiple files from the client
153+
* @param ws - The WebSocket connection
154+
* @param filePaths - Array of file paths to request
155+
* @returns Promise resolving to an object mapping file paths to their contents
156+
*/
157+
export async function requestFilesWs(
158+
params: {
159+
ws: WebSocket
160+
} & ParamsOf<RequestFilesFn>,
161+
): ReturnType<RequestFilesFn> {
162+
const { ws, filePaths } = params
163+
return new Promise<Record<string, string | null>>((resolve) => {
164+
const requestId = generateCompactId()
165+
const unsubscribe = subscribeToAction('read-files-response', (action) => {
166+
for (const [filename, contents] of Object.entries(action.files)) {
167+
action.files[filename] = ensureEndsWithNewline(contents)
168+
}
169+
if (action.requestId === requestId) {
170+
unsubscribe()
171+
resolve(action.files)
172+
}
173+
})
174+
sendAction(ws, {
175+
type: 'read-files',
176+
filePaths,
177+
requestId,
178+
})
179+
})
180+
}
181+
182+
export async function requestOptionalFileWs(
183+
params: {
184+
ws: WebSocket
185+
} & ParamsOf<RequestOptionalFileFn>,
186+
): ReturnType<RequestOptionalFileFn> {
187+
const { ws, filePath } = params
188+
const files = await requestFilesWs({ ws, filePaths: [filePath] })
189+
return toOptionalFile(files[filePath] ?? null)
190+
}

backend/src/get-file-reading-updates.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import { uniq } from 'lodash'
2-
import type { WebSocket } from 'ws'
32

4-
import { requestFiles } from './websockets/websocket-action'
3+
import type { RequestFilesFn } from '@codebuff/common/types/contracts/client'
54

6-
export async function getFileReadingUpdates(
7-
ws: WebSocket,
8-
requestedFiles: string[],
9-
): Promise<
5+
export async function getFileReadingUpdates(params: {
6+
requestFiles: RequestFilesFn
7+
requestedFiles: string[]
8+
}): Promise<
109
{
1110
path: string
1211
content: string
1312
}[]
1413
> {
14+
const { requestFiles, requestedFiles } = params
15+
1516
const allFilePaths = uniq(requestedFiles)
16-
const loadedFiles = await requestFiles({ ws, filePaths: allFilePaths })
17+
const loadedFiles = await requestFiles({ filePaths: allFilePaths })
1718

1819
const addedFiles = allFilePaths
1920
.filter(

0 commit comments

Comments
 (0)