Skip to content

Commit 1ae20a1

Browse files
committed
pass in logger to middleware
1 parent ac87a22 commit 1ae20a1

File tree

2 files changed

+51
-35
lines changed

2 files changed

+51
-35
lines changed

backend/src/websockets/middleware.ts

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@ import { updateRequestContext } from './request-context'
1717
import { sendAction } from './websocket-action'
1818
import { withAppContext } from '../context/app-context'
1919
import { checkAuth } from '../util/check-auth'
20-
import { logger } from '../util/logger'
2120

2221
import type { UserInfo } from './auth'
2322
import type { ClientAction, ServerAction } from '@codebuff/common/actions'
23+
import type { Logger } from '@codebuff/types/logger'
2424
import type { WebSocket } from 'ws'
2525

26-
type MiddlewareCallback = (
27-
action: ClientAction,
28-
clientSessionId: string,
29-
ws: WebSocket,
30-
userInfo: UserInfo | undefined,
31-
) => Promise<void | ServerAction>
26+
type MiddlewareCallback = (params: {
27+
action: ClientAction
28+
clientSessionId: string
29+
ws: WebSocket
30+
userInfo: UserInfo | undefined
31+
logger: Logger
32+
}) => Promise<void | ServerAction>
3233

3334
function getServerErrorAction<T extends ClientAction>(
3435
action: T,
@@ -52,34 +53,39 @@ export class WebSocketMiddleware {
5253
private middlewares: Array<MiddlewareCallback> = []
5354

5455
use<T extends ClientAction['type']>(
55-
callback: (
56-
action: ClientAction<T>,
57-
clientSessionId: string,
58-
ws: WebSocket,
59-
userInfo: UserInfo | undefined,
60-
) => Promise<void | ServerAction>,
56+
callback: (params: {
57+
action: ClientAction<T>
58+
clientSessionId: string
59+
ws: WebSocket
60+
userInfo: UserInfo | undefined
61+
logger: Logger
62+
}) => Promise<void | ServerAction>,
6163
) {
6264
this.middlewares.push(callback as MiddlewareCallback)
6365
}
6466

65-
async execute(
66-
action: ClientAction,
67-
clientSessionId: string,
68-
ws: WebSocket,
69-
options: { silent?: boolean } = {},
70-
): Promise<boolean> {
67+
async execute(params: {
68+
action: ClientAction
69+
clientSessionId: string
70+
ws: WebSocket
71+
silent?: boolean
72+
logger: Logger
73+
}): Promise<boolean> {
74+
const { action, clientSessionId, ws, silent, logger } = params
75+
7176
const userInfo =
7277
'authToken' in action && action.authToken
7378
? await getUserInfoFromAuthToken(action.authToken)
7479
: undefined
7580

7681
for (const middleware of this.middlewares) {
77-
const actionOrContinue = await middleware(
82+
const actionOrContinue = await middleware({
7883
action,
7984
clientSessionId,
8085
ws,
8186
userInfo,
82-
)
87+
logger,
88+
})
8389
if (actionOrContinue) {
8490
logger.warn(
8591
{
@@ -89,7 +95,7 @@ export class WebSocketMiddleware {
8995
},
9096
'Middleware execution halted.',
9197
)
92-
if (!options.silent) {
98+
if (!silent) {
9399
sendAction(ws, actionOrContinue)
94100
}
95101
return false
@@ -98,14 +104,17 @@ export class WebSocketMiddleware {
98104
return true
99105
}
100106

101-
run<T extends ClientAction['type']>(
107+
run<T extends ClientAction['type']>(params: {
102108
baseAction: (
103109
action: ClientAction<T>,
104110
clientSessionId: string,
105111
ws: WebSocket,
106-
) => void,
107-
options: { silent?: boolean } = {},
108-
) {
112+
) => void
113+
silent?: boolean
114+
logger: Logger
115+
}) {
116+
const { baseAction, silent, logger } = params
117+
109118
return async (
110119
action: ClientAction<T>,
111120
clientSessionId: string,
@@ -126,12 +135,13 @@ export class WebSocketMiddleware {
126135
},
127136
{}, // request context starts empty
128137
async () => {
129-
const shouldContinue = await this.execute(
138+
const shouldContinue = await this.execute({
130139
action,
131140
clientSessionId,
132141
ws,
133-
options,
134-
)
142+
silent,
143+
logger,
144+
})
135145
if (shouldContinue) {
136146
baseAction(action, clientSessionId, ws)
137147
}
@@ -143,7 +153,7 @@ export class WebSocketMiddleware {
143153

144154
export const protec = new WebSocketMiddleware()
145155

146-
protec.use(async (action, clientSessionId, ws, userInfo) =>
156+
protec.use(async ({ action, clientSessionId, ws, userInfo, logger }) =>
147157
checkAuth({
148158
fingerprintId: 'fingerprintId' in action ? action.fingerprintId : undefined,
149159
authToken: 'authToken' in action ? action.authToken : undefined,
@@ -153,7 +163,7 @@ protec.use(async (action, clientSessionId, ws, userInfo) =>
153163
)
154164

155165
// Organization repository coverage detection middleware
156-
protec.use(async (action, clientSessionId, ws, userInfo) => {
166+
protec.use(async ({ action, clientSessionId, ws, userInfo, logger }) => {
157167
const userId = userInfo?.id
158168

159169
// Only process actions that have repoUrl as a valid string
@@ -293,7 +303,7 @@ protec.use(async (action, clientSessionId, ws, userInfo) => {
293303
return undefined
294304
})
295305

296-
protec.use(async (action, clientSessionId, ws, userInfo) => {
306+
protec.use(async ({ action, clientSessionId, ws, userInfo, logger }) => {
297307
const userId = userInfo?.id
298308
const fingerprintId =
299309
'fingerprintId' in action ? action.fingerprintId : 'unknown-fingerprint'

backend/src/websockets/websocket-action.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,15 @@ export const onWebsocketAction = async (
396396
}
397397

398398
// Register action handlers
399-
subscribeToAction('prompt', protec.run(onPrompt))
400-
subscribeToAction('init', protec.run(onInit, { silent: true }))
401-
subscribeToAction('cancel-user-input', protec.run(onCancelUserInput))
399+
subscribeToAction('prompt', protec.run({ baseAction: onPrompt, logger }))
400+
subscribeToAction(
401+
'init',
402+
protec.run({ baseAction: onInit, silent: true, logger }),
403+
)
404+
subscribeToAction(
405+
'cancel-user-input',
406+
protec.run({ baseAction: onCancelUserInput, logger }),
407+
)
402408

403409
/**
404410
* Requests multiple files from the client

0 commit comments

Comments
 (0)