Skip to content

Commit e405e7a

Browse files
authored
DI for ai-sdk (#337)
1 parent b91abec commit e405e7a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+820
-869
lines changed

backend/src/__tests__/cost-aggregation.integration.test.ts

Lines changed: 75 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { TEST_USER_ID } from '@codebuff/common/old-constants'
2-
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
2+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
33
import { getInitialSessionState } from '@codebuff/common/types/session-state'
44
import {
55
spyOn,
@@ -12,12 +12,12 @@ import {
1212
} from 'bun:test'
1313

1414
import * as messageCostTracker from '../llm-apis/message-cost-tracker'
15-
import * as aisdk from '../llm-apis/vercel-ai-sdk/ai-sdk'
1615
import { mainPrompt } from '../main-prompt'
1716
import * as agentRegistry from '../templates/agent-registry'
1817
import * as websocketAction from '../websockets/websocket-action'
1918

2019
import type { AgentTemplate } from '../templates/types'
20+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
2121
import type { ProjectFileContext } from '@codebuff/common/util/file'
2222
import type { WebSocket } from 'ws'
2323

@@ -99,8 +99,10 @@ class MockWebSocket {
9999
describe('Cost Aggregation Integration Tests', () => {
100100
let mockLocalAgentTemplates: Record<string, any>
101101
let mockWebSocket: MockWebSocket
102+
let agentRuntimeImpl: AgentRuntimeDeps
102103

103104
beforeEach(async () => {
105+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
104106
mockWebSocket = new MockWebSocket()
105107

106108
// Setup mock agent templates
@@ -150,33 +152,31 @@ describe('Cost Aggregation Integration Tests', () => {
150152
// Mock LLM streaming
151153
let callCount = 0
152154
const creditHistory: number[] = []
153-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
154-
async function* (options) {
155-
callCount++
156-
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
157-
creditHistory.push(credits)
158-
159-
if (options.onCostCalculated) {
160-
await options.onCostCalculated(credits)
161-
}
155+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
156+
callCount++
157+
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
158+
creditHistory.push(credits)
162159

163-
// Simulate different responses based on call
164-
if (callCount === 1) {
165-
// Main agent spawns a subagent
166-
yield {
167-
type: 'text' as const,
168-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
169-
}
170-
} else {
171-
// Subagent writes a file
172-
yield {
173-
type: 'text' as const,
174-
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
175-
}
160+
if (options.onCostCalculated) {
161+
await options.onCostCalculated(credits)
162+
}
163+
164+
// Simulate different responses based on call
165+
if (callCount === 1) {
166+
// Main agent spawns a subagent
167+
yield {
168+
type: 'text' as const,
169+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
176170
}
177-
return 'mock-message-id'
178-
},
179-
)
171+
} else {
172+
// Subagent writes a file
173+
yield {
174+
type: 'text' as const,
175+
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
176+
}
177+
}
178+
return 'mock-message-id'
179+
}
180180

181181
// Mock tool call execution
182182
spyOn(websocketAction, 'requestToolCall').mockImplementation(
@@ -250,7 +250,7 @@ describe('Cost Aggregation Integration Tests', () => {
250250
}
251251

252252
const result = await mainPrompt({
253-
...testAgentRuntimeImpl,
253+
...agentRuntimeImpl,
254254
ws: mockWebSocket as unknown as WebSocket,
255255
action,
256256
userId: TEST_USER_ID,
@@ -285,7 +285,7 @@ describe('Cost Aggregation Integration Tests', () => {
285285

286286
// Call through websocket action handler to test full integration
287287
await websocketAction.callMainPrompt({
288-
...testAgentRuntimeImpl,
288+
...agentRuntimeImpl,
289289
ws: mockWebSocket as unknown as WebSocket,
290290
action,
291291
userId: TEST_USER_ID,
@@ -308,37 +308,35 @@ describe('Cost Aggregation Integration Tests', () => {
308308
it('should handle multi-level subagent hierarchies correctly', async () => {
309309
// Mock a more complex scenario with nested subagents
310310
let callCount = 0
311-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
312-
async function* (options) {
313-
callCount++
311+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
312+
callCount++
314313

315-
if (options.onCostCalculated) {
316-
await options.onCostCalculated(5) // Each call costs 5 credits
317-
}
314+
if (options.onCostCalculated) {
315+
await options.onCostCalculated(5) // Each call costs 5 credits
316+
}
318317

319-
if (callCount === 1) {
320-
// Main agent spawns first-level subagent
321-
yield {
322-
type: 'text' as const,
323-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
324-
}
325-
} else if (callCount === 2) {
326-
// First-level subagent spawns second-level subagent
327-
yield {
328-
type: 'text' as const,
329-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
330-
}
331-
} else {
332-
// Second-level subagent does actual work
333-
yield {
334-
type: 'text' as const,
335-
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
336-
}
318+
if (callCount === 1) {
319+
// Main agent spawns first-level subagent
320+
yield {
321+
type: 'text' as const,
322+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
323+
}
324+
} else if (callCount === 2) {
325+
// First-level subagent spawns second-level subagent
326+
yield {
327+
type: 'text' as const,
328+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
337329
}
330+
} else {
331+
// Second-level subagent does actual work
332+
yield {
333+
type: 'text' as const,
334+
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
335+
}
336+
}
338337

339-
return 'mock-message-id'
340-
},
341-
)
338+
return 'mock-message-id'
339+
}
342340

343341
const sessionState = getInitialSessionState(mockFileContext)
344342
sessionState.mainAgentState.stepsRemaining = 10
@@ -355,7 +353,7 @@ describe('Cost Aggregation Integration Tests', () => {
355353
}
356354

357355
const result = await mainPrompt({
358-
...testAgentRuntimeImpl,
356+
...agentRuntimeImpl,
359357
ws: mockWebSocket as unknown as WebSocket,
360358
action,
361359
userId: TEST_USER_ID,
@@ -373,29 +371,27 @@ describe('Cost Aggregation Integration Tests', () => {
373371
it('should maintain cost integrity when subagents fail', async () => {
374372
// Mock scenario where subagent fails after incurring partial costs
375373
let callCount = 0
376-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
377-
async function* (options) {
378-
callCount++
374+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
375+
callCount++
379376

380-
if (options.onCostCalculated) {
381-
await options.onCostCalculated(6) // Each call costs 6 credits
382-
}
377+
if (options.onCostCalculated) {
378+
await options.onCostCalculated(6) // Each call costs 6 credits
379+
}
383380

384-
if (callCount === 1) {
385-
// Main agent spawns subagent
386-
yield {
387-
type: 'text' as const,
388-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n</codebuff_tool_call>',
389-
}
390-
} else {
391-
// Subagent fails after incurring cost
392-
yield { type: 'text' as const, text: 'Some response' }
393-
throw new Error('Subagent execution failed')
381+
if (callCount === 1) {
382+
// Main agent spawns subagent
383+
yield {
384+
type: 'text' as const,
385+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "This will fail"}]}\n</codebuff_tool_call>',
394386
}
387+
} else {
388+
// Subagent fails after incurring cost
389+
yield { type: 'text' as const, text: 'Some response' }
390+
throw new Error('Subagent execution failed')
391+
}
395392

396-
return 'mock-message-id'
397-
},
398-
)
393+
return 'mock-message-id'
394+
}
399395

400396
const sessionState = getInitialSessionState(mockFileContext)
401397
sessionState.mainAgentState.agentType = 'base'
@@ -413,7 +409,7 @@ describe('Cost Aggregation Integration Tests', () => {
413409
let result
414410
try {
415411
result = await mainPrompt({
416-
...testAgentRuntimeImpl,
412+
...agentRuntimeImpl,
417413
ws: mockWebSocket as unknown as WebSocket,
418414
action,
419415
userId: TEST_USER_ID,
@@ -462,7 +458,7 @@ describe('Cost Aggregation Integration Tests', () => {
462458
}
463459

464460
await mainPrompt({
465-
...testAgentRuntimeImpl,
461+
...agentRuntimeImpl,
466462
ws: mockWebSocket as unknown as WebSocket,
467463
action,
468464
userId: TEST_USER_ID,
@@ -502,7 +498,7 @@ describe('Cost Aggregation Integration Tests', () => {
502498

503499
// Call through websocket action to test server-side reset
504500
await websocketAction.callMainPrompt({
505-
...testAgentRuntimeImpl,
501+
...agentRuntimeImpl,
506502
ws: mockWebSocket as unknown as WebSocket,
507503
action,
508504
userId: TEST_USER_ID,

backend/src/__tests__/cost-aggregation.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
1+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
22
import {
33
getInitialAgentState,
44
getInitialSessionState,
@@ -180,7 +180,7 @@ describe('Cost Aggregation System', () => {
180180
}
181181

182182
const result = handleSpawnAgents({
183-
...testAgentRuntimeImpl,
183+
...TEST_AGENT_RUNTIME_IMPL,
184184
previousToolCallFinished: Promise.resolve(),
185185
toolCall: mockToolCall,
186186
fileContext: mockFileContext,
@@ -260,7 +260,7 @@ describe('Cost Aggregation System', () => {
260260
}
261261

262262
const result = handleSpawnAgents({
263-
...testAgentRuntimeImpl,
263+
...TEST_AGENT_RUNTIME_IMPL,
264264
previousToolCallFinished: Promise.resolve(),
265265
toolCall: mockToolCall,
266266
fileContext: mockFileContext,
@@ -417,7 +417,7 @@ describe('Cost Aggregation System', () => {
417417
}
418418

419419
const result = handleSpawnAgents({
420-
...testAgentRuntimeImpl,
420+
...TEST_AGENT_RUNTIME_IMPL,
421421
previousToolCallFinished: Promise.resolve(),
422422
toolCall: mockToolCall,
423423
fileContext: mockFileContext,

backend/src/__tests__/fast-rewrite.test.ts

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
import path from 'path'
22

33
import { TEST_USER_ID } from '@codebuff/common/old-constants'
4+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
45
import {
56
clearMockedModules,
67
mockModule,
78
} from '@codebuff/common/testing/mock-modules'
8-
import { afterAll, beforeAll, describe, expect, it } from 'bun:test'
9+
import { afterAll, beforeAll, beforeEach, describe, expect, it } from 'bun:test'
910
import { createPatch } from 'diff'
1011

1112
import { rewriteWithOpenAI } from '../fast-rewrite'
1213

13-
import type { Logger } from '@codebuff/common/types/contracts/logger'
14+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
1415

15-
const logger: Logger = {
16-
debug: () => {},
17-
info: () => {},
18-
warn: () => {},
19-
error: () => {},
20-
}
16+
let agentRuntimeImpl: AgentRuntimeDeps
2117

2218
describe.skip('rewriteWithOpenAI', () => {
2319
beforeAll(() => {
@@ -42,6 +38,10 @@ describe.skip('rewriteWithOpenAI', () => {
4238
}))
4339
})
4440

41+
beforeEach(() => {
42+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
43+
})
44+
4545
afterAll(() => {
4646
clearMockedModules()
4747
})
@@ -53,15 +53,13 @@ describe.skip('rewriteWithOpenAI', () => {
5353
const expectedResult = await Bun.file(`${testDataDir}/expected.go`).text()
5454

5555
const result = await rewriteWithOpenAI({
56+
...agentRuntimeImpl,
5657
oldContent: originalContent,
5758
editSnippet,
58-
filePath: 'taskruntoolcall.go',
5959
clientSessionId: 'clientSessionId',
6060
fingerprintId: 'fingerprintId',
6161
userInputId: 'userInputId',
6262
userId: TEST_USER_ID,
63-
userMessage: undefined,
64-
logger,
6563
})
6664

6765
const patch = createPatch('test.ts', expectedResult, result)

0 commit comments

Comments
 (0)