Skip to content

Commit 89ef32f

Browse files
committed
pass in promptAiSdk
1 parent cf5aaa9 commit 89ef32f

16 files changed

+281
-359
lines changed

backend/src/__tests__/fast-rewrite.test.ts

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
import path from 'path'
22

33
import { TEST_USER_ID } from '@codebuff/common/old-constants'
4+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
45
import {
56
clearMockedModules,
67
mockModule,
78
} from '@codebuff/common/testing/mock-modules'
8-
import { afterAll, beforeAll, describe, expect, it } from 'bun:test'
9+
import { afterAll, beforeAll, beforeEach, describe, expect, it } from 'bun:test'
910
import { createPatch } from 'diff'
1011

1112
import { rewriteWithOpenAI } from '../fast-rewrite'
1213

13-
import type { Logger } from '@codebuff/common/types/contracts/logger'
14+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
1415

15-
const logger: Logger = {
16-
debug: () => {},
17-
info: () => {},
18-
warn: () => {},
19-
error: () => {},
20-
}
16+
let agentRuntimeImpl: AgentRuntimeDeps
2117

2218
describe.skip('rewriteWithOpenAI', () => {
2319
beforeAll(() => {
@@ -42,6 +38,10 @@ describe.skip('rewriteWithOpenAI', () => {
4238
}))
4339
})
4440

41+
beforeEach(() => {
42+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
43+
})
44+
4545
afterAll(() => {
4646
clearMockedModules()
4747
})
@@ -53,15 +53,13 @@ describe.skip('rewriteWithOpenAI', () => {
5353
const expectedResult = await Bun.file(`${testDataDir}/expected.go`).text()
5454

5555
const result = await rewriteWithOpenAI({
56+
...agentRuntimeImpl,
5657
oldContent: originalContent,
5758
editSnippet,
58-
filePath: 'taskruntoolcall.go',
5959
clientSessionId: 'clientSessionId',
6060
fingerprintId: 'fingerprintId',
6161
userInputId: 'userInputId',
6262
userId: TEST_USER_ID,
63-
userMessage: undefined,
64-
logger,
6563
})
6664

6765
const patch = createPatch('test.ts', expectedResult, result)

backend/src/__tests__/request-files-prompt.test.ts

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { finetunedVertexModels } from '@codebuff/common/old-constants'
2+
import { TEST_AGENT_RUNTIME_IMPL } from '@codebuff/common/testing/impl/agent-runtime'
23
import {
34
clearMockedModules,
45
mockModule,
@@ -18,10 +19,13 @@ import * as OriginalRequestFilesPromptModule from '../find-files/request-files-p
1819
import * as geminiWithFallbacksModule from '../llm-apis/gemini-with-fallbacks'
1920

2021
import type { CostMode } from '@codebuff/common/old-constants'
22+
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
2123
import type { Message } from '@codebuff/common/types/messages/codebuff-message'
2224
import type { ProjectFileContext } from '@codebuff/common/util/file'
2325
import type { Mock } from 'bun:test'
2426

27+
let agentRuntimeImpl: AgentRuntimeDeps
28+
2529
describe('requestRelevantFiles', () => {
2630
const mockMessages: Message[] = [{ role: 'user', content: 'test prompt' }]
2731
const mockSystem = 'test system'
@@ -58,12 +62,6 @@ describe('requestRelevantFiles', () => {
5862
const mockUserId = 'user1'
5963
const mockCostMode: CostMode = 'normal'
6064
const mockRepoId = 'owner/repo'
61-
const logger = {
62-
debug: () => {},
63-
info: () => {},
64-
warn: () => {},
65-
error: () => {},
66-
}
6765

6866
let getCustomFilePickerConfigForOrgSpy: any // Explicitly typed as any
6967

@@ -81,15 +79,6 @@ describe('requestRelevantFiles', () => {
8179
})),
8280
}))
8381

84-
mockModule('@codebuff/backend/util/logger', () => ({
85-
logger: {
86-
info: bunMockFn(() => {}),
87-
error: bunMockFn(() => {}),
88-
warn: bunMockFn(() => {}),
89-
debug: bunMockFn(() => {}),
90-
},
91-
}))
92-
9382
mockModule('@codebuff/common/db', () => ({
9483
default: {
9584
insert: bunMockFn(() => ({
@@ -109,6 +98,8 @@ describe('requestRelevantFiles', () => {
10998
})
11099

111100
beforeEach(() => {
101+
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
102+
112103
// If the spy was created in a previous test, restore it
113104
if (
114105
getCustomFilePickerConfigForOrgSpy &&
@@ -134,6 +125,7 @@ describe('requestRelevantFiles', () => {
134125

135126
it('should use default file counts and maxFiles when no custom config', async () => {
136127
await OriginalRequestFilesPromptModule.requestRelevantFiles({
128+
...agentRuntimeImpl,
137129
messages: mockMessages,
138130
system: mockSystem,
139131
fileContext: mockFileContext,
@@ -144,7 +136,6 @@ describe('requestRelevantFiles', () => {
144136
userInputId: mockUserInputId,
145137
userId: mockUserId,
146138
repoId: mockRepoId,
147-
logger,
148139
})
149140
expect(
150141
geminiWithFallbacksModule.promptFlashWithFallbacks,
@@ -161,6 +152,7 @@ describe('requestRelevantFiles', () => {
161152
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
162153

163154
await OriginalRequestFilesPromptModule.requestRelevantFiles({
155+
...agentRuntimeImpl,
164156
messages: mockMessages,
165157
system: mockSystem,
166158
fileContext: mockFileContext,
@@ -171,7 +163,6 @@ describe('requestRelevantFiles', () => {
171163
userInputId: mockUserInputId,
172164
userId: mockUserId,
173165
repoId: mockRepoId,
174-
logger,
175166
})
176167
expect(
177168
geminiWithFallbacksModule.promptFlashWithFallbacks,
@@ -187,6 +178,7 @@ describe('requestRelevantFiles', () => {
187178
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
188179

189180
const result = await OriginalRequestFilesPromptModule.requestRelevantFiles({
181+
...agentRuntimeImpl,
190182
messages: mockMessages,
191183
system: mockSystem,
192184
fileContext: mockFileContext,
@@ -197,7 +189,6 @@ describe('requestRelevantFiles', () => {
197189
userInputId: mockUserInputId,
198190
userId: mockUserId,
199191
repoId: mockRepoId,
200-
logger,
201192
})
202193
expect(result).toBeArray()
203194
if (result) {
@@ -213,6 +204,7 @@ describe('requestRelevantFiles', () => {
213204
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
214205

215206
await OriginalRequestFilesPromptModule.requestRelevantFiles({
207+
...agentRuntimeImpl,
216208
messages: mockMessages,
217209
system: mockSystem,
218210
fileContext: mockFileContext,
@@ -223,7 +215,6 @@ describe('requestRelevantFiles', () => {
223215
userInputId: mockUserInputId,
224216
userId: mockUserId,
225217
repoId: mockRepoId,
226-
logger,
227218
})
228219
expect(
229220
geminiWithFallbacksModule.promptFlashWithFallbacks,
@@ -242,6 +233,7 @@ describe('requestRelevantFiles', () => {
242233
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
243234

244235
await OriginalRequestFilesPromptModule.requestRelevantFiles({
236+
...agentRuntimeImpl,
245237
messages: mockMessages,
246238
system: mockSystem,
247239
fileContext: mockFileContext,
@@ -252,7 +244,6 @@ describe('requestRelevantFiles', () => {
252244
userInputId: mockUserInputId,
253245
userId: mockUserId,
254246
repoId: mockRepoId,
255-
logger,
256247
})
257248
const expectedModel = finetunedVertexModels.ft_filepicker_010
258249
expect(

backend/src/admin/grade-runs.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import { models, TEST_USER_ID } from '@codebuff/common/old-constants'
22
import { closeXml } from '@codebuff/common/util/xml'
33

4-
import { promptAiSdk } from '../llm-apis/vercel-ai-sdk/ai-sdk'
5-
64
import type { Relabel, GetRelevantFilesTrace } from '@codebuff/bigquery'
5+
import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm'
76
import type { Logger } from '@codebuff/common/types/contracts/logger'
87

98
const PROMPT = `
@@ -100,9 +99,10 @@ function extractResponse(response: string): {
10099
export async function gradeRun(params: {
101100
trace: GetRelevantFilesTrace
102101
relabels: Relabel[]
102+
promptAiSdk: PromptAiSdkFn
103103
logger: Logger
104104
}) {
105-
const { trace, relabels, logger } = params
105+
const { trace, relabels, promptAiSdk, logger } = params
106106
const messages = trace.payload.messages
107107

108108
const originalOutput = trace.payload.output

backend/src/admin/relabelRuns.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ import type {
2424
GetRelevantFilesTrace,
2525
Relabel,
2626
} from '@codebuff/bigquery'
27-
import type { Message } from '@codebuff/common/types/messages/codebuff-message'
27+
import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm'
2828
import type { Logger } from '@codebuff/common/types/contracts/logger'
29+
import type { ParamsExcluding } from '@codebuff/common/types/function-params'
30+
import type { Message } from '@codebuff/common/types/messages/codebuff-message'
2931
import type { Request, Response } from 'express'
3032

3133
// --- GET Handler Logic ---
@@ -153,6 +155,7 @@ export async function relabelForUserHandler(params: {
153155
const relaceResults = relabelUsingFullFilesForUser({
154156
userId,
155157
limit,
158+
promptAiSdk,
156159
logger,
157160
})
158161

@@ -265,11 +268,16 @@ export async function relabelForUserHandler(params: {
265268
}
266269
}
267270

268-
async function relabelUsingFullFilesForUser(params: {
269-
userId: string
270-
limit: number
271-
logger: Logger
272-
}) {
271+
async function relabelUsingFullFilesForUser(
272+
params: {
273+
userId: string
274+
limit: number
275+
logger: Logger
276+
} & ParamsExcluding<
277+
typeof relabelWithClaudeWithFullFileContext,
278+
'trace' | 'fileBlobs' | 'model'
279+
>,
280+
) {
273281
const { userId, limit, logger } = params
274282
// TODO: We need to figure out changing _everything_ to use `getTracesAndAllDataForUser`
275283
const tracesBundles = await getTracesAndAllDataForUser(userId)
@@ -308,10 +316,10 @@ async function relabelUsingFullFilesForUser(params: {
308316
) {
309317
relabelPromises.push(
310318
relabelWithClaudeWithFullFileContext({
319+
...params,
311320
trace,
312321
fileBlobs,
313322
model,
314-
logger,
315323
}),
316324
)
317325
didRelabel = true
@@ -392,9 +400,10 @@ export async function relabelWithClaudeWithFullFileContext(params: {
392400
fileBlobs: GetExpandedFileContextForTrainingBlobTrace
393401
model: string
394402
dataset?: string
403+
promptAiSdk: PromptAiSdkFn
395404
logger: Logger
396405
}) {
397-
const { trace, fileBlobs, model, dataset, logger } = params
406+
const { trace, fileBlobs, model, dataset, promptAiSdk, logger } = params
398407
if (dataset) {
399408
await setupBigQuery({ dataset, logger })
400409
}

backend/src/check-terminal-command.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import { models } from '@codebuff/common/old-constants'
22
import { withTimeout } from '@codebuff/common/util/promise'
33

4-
import { promptAiSdk } from './llm-apis/vercel-ai-sdk/ai-sdk'
5-
4+
import type { PromptAiSdkFn } from '@codebuff/common/types/contracts/llm'
65
import type { Logger } from '@codebuff/common/types/contracts/logger'
76
import type { ParamsExcluding } from '@codebuff/common/types/function-params'
87

@@ -13,10 +12,11 @@ import type { ParamsExcluding } from '@codebuff/common/types/function-params'
1312
export async function checkTerminalCommand(
1413
params: {
1514
prompt: string
15+
promptAiSdk: PromptAiSdkFn
1616
logger: Logger
17-
} & ParamsExcluding<typeof promptAiSdk, 'messages' | 'model'>,
17+
} & ParamsExcluding<PromptAiSdkFn, 'messages' | 'model'>,
1818
): Promise<string | null> {
19-
const { prompt, logger } = params
19+
const { prompt, promptAiSdk, logger } = params
2020
if (!prompt?.trim()) {
2121
return null
2222
}

0 commit comments

Comments
 (0)