Skip to content

Commit 7d95bfc

Browse files
committed
add getUserInfoFromApiKey to AgentRuntimeDeps
1 parent 5a2e51e commit 7d95bfc

File tree

11 files changed

+96
-118
lines changed

11 files changed

+96
-118
lines changed

backend/src/__tests__/request-files-prompt.test.ts

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import {
99
beforeAll,
1010
beforeEach,
1111
mock as bunMockFn,
12-
spyOn as bunSpyOn,
1312
describe,
1413
expect,
1514
it,
@@ -63,8 +62,6 @@ describe('requestRelevantFiles', () => {
6362
const mockCostMode: CostMode = 'normal'
6463
const mockRepoId = 'owner/repo'
6564

66-
let getCustomFilePickerConfigForOrgSpy: any // Explicitly typed as any
67-
6865
beforeAll(() => {
6966
mockModule('@codebuff/backend/llm-apis/gemini-with-fallbacks', () => ({
7067
promptFlashWithFallbacks: bunMockFn(() =>
@@ -100,21 +97,6 @@ describe('requestRelevantFiles', () => {
10097
beforeEach(() => {
10198
agentRuntimeImpl = { ...TEST_AGENT_RUNTIME_IMPL }
10299

103-
// If the spy was created in a previous test, restore it
104-
if (
105-
getCustomFilePickerConfigForOrgSpy &&
106-
typeof getCustomFilePickerConfigForOrgSpy.mockRestore === 'function'
107-
) {
108-
getCustomFilePickerConfigForOrgSpy.mockRestore()
109-
getCustomFilePickerConfigForOrgSpy = undefined
110-
}
111-
112-
// Use the directly imported bunSpyOn
113-
getCustomFilePickerConfigForOrgSpy = bunSpyOn(
114-
OriginalRequestFilesPromptModule,
115-
'getCustomFilePickerConfigForOrg',
116-
).mockResolvedValue(null)
117-
118100
const promptFlashWithFallbacksMock =
119101
geminiWithFallbacksModule.promptFlashWithFallbacks as Mock<
120102
typeof geminiWithFallbacksModule.promptFlashWithFallbacks
@@ -140,7 +122,6 @@ describe('requestRelevantFiles', () => {
140122
expect(
141123
geminiWithFallbacksModule.promptFlashWithFallbacks,
142124
).toHaveBeenCalled()
143-
expect(getCustomFilePickerConfigForOrgSpy).toHaveBeenCalled()
144125
})
145126

146127
it('should use custom file counts from config', async () => {
@@ -149,7 +130,6 @@ describe('requestRelevantFiles', () => {
149130
customFileCounts: { normal: 5 },
150131
maxFilesPerRequest: 10,
151132
}
152-
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
153133

154134
await OriginalRequestFilesPromptModule.requestRelevantFiles({
155135
...agentRuntimeImpl,
@@ -167,15 +147,13 @@ describe('requestRelevantFiles', () => {
167147
expect(
168148
geminiWithFallbacksModule.promptFlashWithFallbacks,
169149
).toHaveBeenCalled()
170-
expect(getCustomFilePickerConfigForOrgSpy).toHaveBeenCalled()
171150
})
172151

173152
it('should use custom maxFilesPerRequest from config', async () => {
174153
const customConfig = {
175154
modelName: 'ft_filepicker_005',
176155
maxFilesPerRequest: 3,
177156
}
178-
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
179157

180158
const result = await OriginalRequestFilesPromptModule.requestRelevantFiles({
181159
...agentRuntimeImpl,
@@ -194,14 +172,12 @@ describe('requestRelevantFiles', () => {
194172
if (result) {
195173
expect(result.length).toBeLessThanOrEqual(3)
196174
}
197-
expect(getCustomFilePickerConfigForOrgSpy).toHaveBeenCalled()
198175
})
199176

200177
it('should use custom modelName from config', async () => {
201178
const customConfig = {
202179
modelName: 'ft_filepicker_010',
203180
}
204-
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
205181

206182
await OriginalRequestFilesPromptModule.requestRelevantFiles({
207183
...agentRuntimeImpl,
@@ -223,14 +199,12 @@ describe('requestRelevantFiles', () => {
223199
useFinetunedModel: finetunedVertexModels.ft_filepicker_010,
224200
}),
225201
)
226-
expect(getCustomFilePickerConfigForOrgSpy).toHaveBeenCalled()
227202
})
228203

229204
it('should use default model if custom modelName is invalid', async () => {
230205
const customConfig = {
231206
modelName: 'invalid-model-name',
232207
}
233-
getCustomFilePickerConfigForOrgSpy!.mockResolvedValue(customConfig as any)
234208

235209
await OriginalRequestFilesPromptModule.requestRelevantFiles({
236210
...agentRuntimeImpl,
@@ -253,6 +227,5 @@ describe('requestRelevantFiles', () => {
253227
useFinetunedModel: expectedModel,
254228
}),
255229
)
256-
expect(getCustomFilePickerConfigForOrgSpy).toHaveBeenCalled()
257230
})
258231
})

backend/src/api/org.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { z } from 'zod/v4'
44

55
import { extractAuthTokenFromHeader } from '../util/auth-helpers'
66
import { logger } from '../util/logger'
7-
import { getUserIdFromAuthToken } from '../websockets/websocket-action'
7+
import { getUserInfoFromApiKey } from '../websockets/auth'
88

99
import type {
1010
Request as ExpressRequest,
@@ -35,7 +35,9 @@ async function isRepoCoveredHandler(
3535
.status(401)
3636
.json({ error: 'Missing x-codebuff-api-key header' })
3737
}
38-
const userId = await getUserIdFromAuthToken({ authToken })
38+
const userId = (
39+
await getUserInfoFromApiKey({ apiKey: authToken, fields: ['id'] })
40+
)?.id
3941

4042
if (!userId) {
4143
return res.status(401).json({ error: INVALID_AUTH_TOKEN_MESSAGE })

backend/src/api/usage.ts

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import { getOrganizationUsageResponse } from '@codebuff/billing'
2-
import db from '@codebuff/common/db'
3-
import * as schema from '@codebuff/common/db/schema'
42
import { INVALID_AUTH_TOKEN_MESSAGE } from '@codebuff/common/old-constants'
5-
import { eq } from 'drizzle-orm'
63
import { z } from 'zod/v4'
74

85
import { checkAuth } from '../util/check-auth'
96
import { logger } from '../util/logger'
7+
import { getUserInfoFromApiKey } from '../websockets/auth'
108
import { genUsageResponse } from '../websockets/websocket-action'
119

1210
import type {
@@ -21,18 +19,6 @@ const usageRequestSchema = z.object({
2119
orgId: z.string().optional(),
2220
})
2321

24-
async function getUserIdFromAuthToken(
25-
token: string,
26-
): Promise<string | undefined> {
27-
const user = await db
28-
.select({ userId: schema.user.id })
29-
.from(schema.user)
30-
.innerJoin(schema.session, eq(schema.user.id, schema.session.userId))
31-
.where(eq(schema.session.sessionToken, token))
32-
.then((users) => users[0]?.userId)
33-
return user
34-
}
35-
3622
async function usageHandler(
3723
req: ExpressRequest,
3824
res: ExpressResponse,
@@ -59,7 +45,7 @@ async function usageHandler(
5945
}
6046

6147
const userId = authToken
62-
? await getUserIdFromAuthToken(authToken)
48+
? (await getUserInfoFromApiKey({ apiKey: authToken, fields: ['id'] }))?.id
6349
: undefined
6450

6551
if (!userId) {

backend/src/impl/agent-runtime.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import {
55
promptAiSdkStructured,
66
} from '../llm-apis/vercel-ai-sdk/ai-sdk'
77
import { logger } from '../util/logger'
8+
import { getUserInfoFromApiKey } from '../websockets/auth'
89

910
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
1011

1112
export const BACKEND_AGENT_RUNTIME_IMPL: AgentRuntimeDeps = Object.freeze({
1213
// Database
14+
getUserInfoFromApiKey,
1315
startAgentRun,
1416
finishAgentRun,
1517
addAgentStep,

backend/src/websockets/auth.ts

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,29 @@ import db from '@codebuff/common/db'
22
import * as schema from '@codebuff/common/db/schema'
33
import { eq } from 'drizzle-orm'
44

5-
export interface UserInfo {
6-
id: string
7-
email: string
8-
discord_id: string | null
9-
}
5+
import type {
6+
GetUserInfoFromApiKeyInput,
7+
GetUserInfoFromApiKeyOutput,
8+
UserColumn,
9+
} from '@codebuff/common/types/contracts/database'
1010

11-
export async function getUserIdFromAuthToken(
12-
authToken: string,
13-
): Promise<string | undefined> {
14-
const user = await db
15-
.select({ id: schema.user.id })
16-
.from(schema.user)
17-
.leftJoin(schema.session, eq(schema.user.id, schema.session.userId))
18-
.where(eq(schema.session.sessionToken, authToken))
19-
.limit(1)
20-
.then((rows) => rows[0])
11+
export async function getUserInfoFromApiKey<T extends UserColumn>(
12+
params: GetUserInfoFromApiKeyInput<T>,
13+
): GetUserInfoFromApiKeyOutput<T> {
14+
const { apiKey, fields } = params
2115

22-
return user?.id
23-
}
16+
// Build a typed selection object for user columns
17+
const userSelection = Object.fromEntries(
18+
fields.map((field) => [field, schema.user[field]]),
19+
) as { [K in T]: (typeof schema.user)[K] }
2420

25-
export async function getUserInfoFromAuthToken(
26-
authToken: string,
27-
): Promise<UserInfo | undefined> {
28-
const user = await db
29-
.select({
30-
id: schema.user.id,
31-
email: schema.user.email,
32-
discord_id: schema.user.discord_id,
33-
})
21+
const rows = await db
22+
.select({ user: userSelection }) // <-- important: nest under 'user'
3423
.from(schema.user)
3524
.leftJoin(schema.session, eq(schema.user.id, schema.session.userId))
36-
.where(eq(schema.session.sessionToken, authToken))
25+
.where(eq(schema.session.sessionToken, apiKey))
3726
.limit(1)
38-
.then((rows) => rows[0])
3927

40-
return user
28+
// Drizzle returns { user: ..., session: ... }, we return only the user part
29+
return rows[0]?.user ?? null
4130
}

backend/src/websockets/middleware.ts

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@ import * as schema from '@codebuff/common/db/schema'
1212
import { pluralize } from '@codebuff/common/util/string'
1313
import { eq } from 'drizzle-orm'
1414

15-
import { getUserInfoFromAuthToken } from './auth'
15+
import { getUserInfoFromApiKey } from './auth'
1616
import { updateRequestContext } from './request-context'
1717
import { sendAction } from './websocket-action'
1818
import { withAppContext } from '../context/app-context'
1919
import { BACKEND_AGENT_RUNTIME_IMPL } from '../impl/agent-runtime'
2020
import { checkAuth } from '../util/check-auth'
2121

22-
import type { UserInfo } from './auth'
2322
import type { ClientAction, ServerAction } from '@codebuff/common/actions'
2423
import type { AgentRuntimeDeps } from '@codebuff/common/types/contracts/agent-runtime'
24+
import type { GetUserInfoFromApiKeyFn } from '@codebuff/common/types/contracts/database'
2525
import type { Logger } from '@codebuff/common/types/contracts/logger'
2626
import type { WebSocket } from 'ws'
2727

2828
type MiddlewareCallback = (params: {
2929
action: ClientAction
3030
clientSessionId: string
3131
ws: WebSocket
32-
userInfo: UserInfo | undefined
32+
userInfo: { id: string } | null
3333
logger: Logger
3434
}) => Promise<void | ServerAction>
3535

@@ -64,7 +64,7 @@ export class WebSocketMiddleware {
6464
action: ClientAction<T>
6565
clientSessionId: string
6666
ws: WebSocket
67-
userInfo: UserInfo | undefined
67+
userInfo: { id: string } | null
6868
logger: Logger
6969
}) => Promise<void | ServerAction>,
7070
) {
@@ -76,14 +76,18 @@ export class WebSocketMiddleware {
7676
clientSessionId: string
7777
ws: WebSocket
7878
silent?: boolean
79+
getUserInfoFromApiKey: GetUserInfoFromApiKeyFn
7980
logger: Logger
8081
}): Promise<boolean> {
8182
const { action, clientSessionId, ws, silent, logger } = params
8283

8384
const userInfo =
8485
'authToken' in action && action.authToken
85-
? await getUserInfoFromAuthToken(action.authToken)
86-
: undefined
86+
? await getUserInfoFromApiKey({
87+
apiKey: action.authToken,
88+
fields: ['id'],
89+
})
90+
: null
8791

8892
for (const middleware of this.middlewares) {
8993
const actionOrContinue = await middleware({
@@ -130,7 +134,10 @@ export class WebSocketMiddleware {
130134
) => {
131135
const userInfo =
132136
'authToken' in action
133-
? await getUserInfoFromAuthToken(action.authToken!)
137+
? await getUserInfoFromApiKey({
138+
apiKey: action.authToken!,
139+
fields: ['id', 'email', 'discord_id'],
140+
})
134141
: undefined
135142

136143
// Use the new combined context - much cleaner!

0 commit comments

Comments
 (0)