Skip to content

Commit 894b548

Browse files
committed
add promptAiSdkStream to AgentRuntimeDeps
1 parent 9d5c697 commit 894b548

23 files changed

+293
-317
lines changed

backend/src/__tests__/cost-aggregation.integration.test.ts

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { TEST_USER_ID } from '@codebuff/common/old-constants'
2-
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
2+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
33
import { getInitialSessionState } from '@codebuff/common/types/session-state'
44
import {
55
spyOn,
@@ -18,6 +18,7 @@ import * as agentRegistry from '../templates/agent-registry'
1818
import * as websocketAction from '../websockets/websocket-action'
1919

2020
import type { AgentTemplate } from '../templates/types'
21+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
2122
import type { ProjectFileContext } from '@codebuff/common/util/file'
2223
import type { WebSocket } from 'ws'
2324

@@ -99,6 +100,7 @@ class MockWebSocket {
99100
describe('Cost Aggregation Integration Tests', () => {
100101
let mockLocalAgentTemplates: Record<string, any>
101102
let mockWebSocket: MockWebSocket
103+
let agentRuntimeImpl: AgentRuntimeDeps = { ...TEST_AGENT_RUNTIME_IMPL }
102104

103105
beforeEach(async () => {
104106
mockWebSocket = new MockWebSocket()
@@ -150,33 +152,31 @@ describe('Cost Aggregation Integration Tests', () => {
150152
// Mock LLM streaming
151153
let callCount = 0
152154
const creditHistory: number[] = []
153-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
154-
async function* (options) {
155-
callCount++
156-
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
157-
creditHistory.push(credits)
155+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
156+
callCount++
157+
const credits = callCount === 1 ? 10 : 7 // Main agent vs subagent costs
158+
creditHistory.push(credits)
158159

159-
if (options.onCostCalculated) {
160-
await options.onCostCalculated(credits)
161-
}
160+
if (options.onCostCalculated) {
161+
await options.onCostCalculated(credits)
162+
}
162163

163-
// Simulate different responses based on call
164-
if (callCount === 1) {
165-
// Main agent spawns a subagent
166-
yield {
167-
type: 'text' as const,
168-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
169-
}
170-
} else {
171-
// Subagent writes a file
172-
yield {
173-
type: 'text' as const,
174-
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
175-
}
164+
// Simulate different responses based on call
165+
if (callCount === 1) {
166+
// Main agent spawns a subagent
167+
yield {
168+
type: 'text' as const,
169+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write a simple hello world file"}]}\n</codebuff_tool_call>',
176170
}
177-
return 'mock-message-id'
178-
},
179-
)
171+
} else {
172+
// Subagent writes a file
173+
yield {
174+
type: 'text' as const,
175+
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "hello.txt", "instructions": "Create hello world file", "content": "Hello, World!"}\n</codebuff_tool_call>',
176+
}
177+
}
178+
return 'mock-message-id'
179+
}
180180

181181
// Mock tool call execution
182182
spyOn(websocketAction, 'requestToolCall').mockImplementation(
@@ -231,6 +231,7 @@ describe('Cost Aggregation Integration Tests', () => {
231231

232232
afterEach(() => {
233233
mock.restore()
234+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
234235
})
235236

236237
it('should correctly aggregate costs across the entire main prompt flow', async () => {
@@ -250,7 +251,7 @@ describe('Cost Aggregation Integration Tests', () => {
250251
}
251252

252253
const result = await mainPrompt({
253-
...testAgentRuntimeImpl,
254+
...agentRuntimeImpl,
254255
ws: mockWebSocket as unknown as WebSocket,
255256
action,
256257
userId: TEST_USER_ID,
@@ -285,7 +286,7 @@ describe('Cost Aggregation Integration Tests', () => {
285286

286287
// Call through websocket action handler to test full integration
287288
await websocketAction.callMainPrompt({
288-
...testAgentRuntimeImpl,
289+
...agentRuntimeImpl,
289290
ws: mockWebSocket as unknown as WebSocket,
290291
action,
291292
userId: TEST_USER_ID,
@@ -308,37 +309,35 @@ describe('Cost Aggregation Integration Tests', () => {
308309
it('should handle multi-level subagent hierarchies correctly', async () => {
309310
// Mock a more complex scenario with nested subagents
310311
let callCount = 0
311-
spyOn(aisdk, 'promptAiSdkStream').mockImplementation(
312-
async function* (options) {
313-
callCount++
312+
agentRuntimeImpl.promptAiSdkStream = async function* (options) {
313+
callCount++
314314

315-
if (options.onCostCalculated) {
316-
await options.onCostCalculated(5) // Each call costs 5 credits
317-
}
315+
if (options.onCostCalculated) {
316+
await options.onCostCalculated(5) // Each call costs 5 credits
317+
}
318318

319-
if (callCount === 1) {
320-
// Main agent spawns first-level subagent
321-
yield {
322-
type: 'text' as const,
323-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
324-
}
325-
} else if (callCount === 2) {
326-
// First-level subagent spawns second-level subagent
327-
yield {
328-
type: 'text' as const,
329-
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
330-
}
331-
} else {
332-
// Second-level subagent does actual work
333-
yield {
334-
type: 'text' as const,
335-
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
336-
}
319+
if (callCount === 1) {
320+
// Main agent spawns first-level subagent
321+
yield {
322+
type: 'text' as const,
323+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Create files"}]}\n</codebuff_tool_call>',
324+
}
325+
} else if (callCount === 2) {
326+
// First-level subagent spawns second-level subagent
327+
yield {
328+
type: 'text' as const,
329+
text: '<codebuff_tool_call>\n{"cb_tool_name": "spawn_agents", "agents": [{"agent_type": "editor", "prompt": "Write specific file"}]}\n</codebuff_tool_call>',
330+
}
331+
} else {
332+
// Second-level subagent does actual work
333+
yield {
334+
type: 'text' as const,
335+
text: '<codebuff_tool_call>\n{"cb_tool_name": "write_file", "path": "nested.txt", "instructions": "Create nested file", "content": "Nested content"}\n</codebuff_tool_call>',
337336
}
337+
}
338338

339-
return 'mock-message-id'
340-
},
341-
)
339+
return 'mock-message-id'
340+
}
342341

343342
const sessionState = getInitialSessionState(mockFileContext)
344343
sessionState.mainAgentState.stepsRemaining = 10
@@ -355,7 +354,7 @@ describe('Cost Aggregation Integration Tests', () => {
355354
}
356355

357356
const result = await mainPrompt({
358-
...testAgentRuntimeImpl,
357+
...agentRuntimeImpl,
359358
ws: mockWebSocket as unknown as WebSocket,
360359
action,
361360
userId: TEST_USER_ID,
@@ -413,7 +412,7 @@ describe('Cost Aggregation Integration Tests', () => {
413412
let result
414413
try {
415414
result = await mainPrompt({
416-
...testAgentRuntimeImpl,
415+
...agentRuntimeImpl,
417416
ws: mockWebSocket as unknown as WebSocket,
418417
action,
419418
userId: TEST_USER_ID,
@@ -462,7 +461,7 @@ describe('Cost Aggregation Integration Tests', () => {
462461
}
463462

464463
await mainPrompt({
465-
...testAgentRuntimeImpl,
464+
...agentRuntimeImpl,
466465
ws: mockWebSocket as unknown as WebSocket,
467466
action,
468467
userId: TEST_USER_ID,
@@ -502,7 +501,7 @@ describe('Cost Aggregation Integration Tests', () => {
502501

503502
// Call through websocket action to test server-side reset
504503
await websocketAction.callMainPrompt({
505-
...testAgentRuntimeImpl,
504+
...agentRuntimeImpl,
506505
ws: mockWebSocket as unknown as WebSocket,
507506
action,
508507
userId: TEST_USER_ID,

backend/src/__tests__/cost-aggregation.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { testAgentRuntimeImpl } from '@codebuff/common/testing/impl/agent-runtime'
1+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
22
import {
33
getInitialAgentState,
44
getInitialSessionState,
@@ -180,7 +180,7 @@ describe('Cost Aggregation System', () => {
180180
}
181181

182182
const result = handleSpawnAgents({
183-
...testAgentRuntimeImpl,
183+
...TEST_AGENT_RUNTIME_IMPL,
184184
previousToolCallFinished: Promise.resolve(),
185185
toolCall: mockToolCall,
186186
fileContext: mockFileContext,
@@ -260,7 +260,7 @@ describe('Cost Aggregation System', () => {
260260
}
261261

262262
const result = handleSpawnAgents({
263-
...testAgentRuntimeImpl,
263+
...TEST_AGENT_RUNTIME_IMPL,
264264
previousToolCallFinished: Promise.resolve(),
265265
toolCall: mockToolCall,
266266
fileContext: mockFileContext,
@@ -417,7 +417,7 @@ describe('Cost Aggregation System', () => {
417417
}
418418

419419
const result = handleSpawnAgents({
420-
...testAgentRuntimeImpl,
420+
...TEST_AGENT_RUNTIME_IMPL,
421421
previousToolCallFinished: Promise.resolve(),
422422
toolCall: mockToolCall,
423423
fileContext: mockFileContext,

backend/src/__tests__/generate-diffs-prompt.test.ts

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
1+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
12
import { expect, describe, it } from 'bun:test'
23

34
import { parseAndGetDiffBlocksSingleFile } from '../generate-diffs-prompt'
45

5-
import type { Logger } from '@codebuff/common/types/contracts/logger'
6-
7-
const logger: Logger = {
8-
debug: () => {},
9-
info: () => {},
10-
warn: () => {},
11-
error: () => {},
12-
}
13-
146
describe('parseAndGetDiffBlocksSingleFile', () => {
157
it('should parse diff blocks with newline before closing marker', () => {
168
const oldContent = 'function test() {\n return true;\n}\n'
@@ -26,9 +18,9 @@ function test() {
2618
>>>>>>> REPLACE`
2719

2820
const result = parseAndGetDiffBlocksSingleFile({
21+
...TEST_AGENT_RUNTIME_IMPL,
2922
newContent,
3023
oldFileContent: oldContent,
31-
logger,
3224
})
3325
console.log(JSON.stringify({ result }))
3426

@@ -55,9 +47,9 @@ function test() {
5547
}>>>>>>> REPLACE`
5648

5749
const result = parseAndGetDiffBlocksSingleFile({
50+
...TEST_AGENT_RUNTIME_IMPL,
5851
newContent,
5952
oldFileContent: oldContent,
60-
logger,
6153
})
6254

6355
expect(result.diffBlocks.length).toBe(1)
@@ -108,9 +100,9 @@ function subtract(a, b) {
108100
>>>>>>> REPLACE`
109101

110102
const result = parseAndGetDiffBlocksSingleFile({
103+
...TEST_AGENT_RUNTIME_IMPL,
111104
newContent,
112105
oldFileContent: oldContent,
113-
logger,
114106
})
115107

116108
expect(result.diffBlocks.length).toBe(2)
@@ -136,9 +128,9 @@ function subtract(a, b) {
136128
>>>>>>> REPLACE`
137129

138130
const result = parseAndGetDiffBlocksSingleFile({
131+
...TEST_AGENT_RUNTIME_IMPL,
139132
newContent,
140133
oldFileContent: oldContent,
141-
logger,
142134
})
143135

144136
expect(result.diffBlocks.length).toBe(1)

0 commit comments

Comments
 (0)