Skip to content

Commit 3c453dd

Browse files
refactor: inject deps via context instead of closures
This refactor addresses performance issues in per-request server scenarios where creating ~90 handler closures per request was causing latency. Changes: - Add ContextWithDeps, DepsFromContext, MustDepsFromContext to dependencies.go - Add NewServerToolWithContextHandler, NewServerToolWithRawContextHandler to inventory - Convert all 89 tool handlers from closure pattern to direct context-based deps - Update all tests to inject deps into context before calling handlers - Mark old NewServerTool and NewServerToolFromHandler as deprecated The new pattern: - Before: func(deps) handler { return func(ctx, req, args) { use deps } } - After: func(ctx, deps, req, args) { use deps } Dependencies are now injected into context once (via ContextWithDeps) and extracted by NewTool internally before passing to handlers. This eliminates closure creation on the hot path for remote servers.
1 parent 97feb5c commit 3c453dd

35 files changed

+5409
-5499
lines changed

pkg/github/actions.go

Lines changed: 562 additions & 590 deletions
Large diffs are not rendered by default.

pkg/github/actions_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func Test_ListWorkflows(t *testing.T) {
114114
request := createMCPRequest(tc.requestArgs)
115115

116116
// Call handler
117-
result, err := handler(context.Background(), &request)
117+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
118118

119119
require.NoError(t, err)
120120
require.Equal(t, tc.expectError, result.IsError)
@@ -203,7 +203,7 @@ func Test_RunWorkflow(t *testing.T) {
203203
request := createMCPRequest(tc.requestArgs)
204204

205205
// Call handler
206-
result, err := handler(context.Background(), &request)
206+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
207207

208208
require.NoError(t, err)
209209
require.Equal(t, tc.expectError, result.IsError)
@@ -299,7 +299,7 @@ func Test_RunWorkflow_WithFilename(t *testing.T) {
299299
request := createMCPRequest(tc.requestArgs)
300300

301301
// Call handler
302-
result, err := handler(context.Background(), &request)
302+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
303303

304304
require.NoError(t, err)
305305
require.Equal(t, tc.expectError, result.IsError)
@@ -407,7 +407,7 @@ func Test_CancelWorkflowRun(t *testing.T) {
407407
request := createMCPRequest(tc.requestArgs)
408408

409409
// Call handler
410-
result, err := handler(context.Background(), &request)
410+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
411411

412412
require.NoError(t, err)
413413
require.Equal(t, tc.expectError, result.IsError)
@@ -537,7 +537,7 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) {
537537
request := createMCPRequest(tc.requestArgs)
538538

539539
// Call handler
540-
result, err := handler(context.Background(), &request)
540+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
541541

542542
require.NoError(t, err)
543543
require.Equal(t, tc.expectError, result.IsError)
@@ -627,7 +627,7 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) {
627627
request := createMCPRequest(tc.requestArgs)
628628

629629
// Call handler
630-
result, err := handler(context.Background(), &request)
630+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
631631

632632
require.NoError(t, err)
633633
require.Equal(t, tc.expectError, result.IsError)
@@ -713,7 +713,7 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) {
713713
request := createMCPRequest(tc.requestArgs)
714714

715715
// Call handler
716-
result, err := handler(context.Background(), &request)
716+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
717717

718718
require.NoError(t, err)
719719
require.Equal(t, tc.expectError, result.IsError)
@@ -817,7 +817,7 @@ func Test_GetWorkflowRunUsage(t *testing.T) {
817817
request := createMCPRequest(tc.requestArgs)
818818

819819
// Call handler
820-
result, err := handler(context.Background(), &request)
820+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
821821

822822
require.NoError(t, err)
823823
require.Equal(t, tc.expectError, result.IsError)
@@ -1082,7 +1082,7 @@ func Test_GetJobLogs(t *testing.T) {
10821082
request := createMCPRequest(tc.requestArgs)
10831083

10841084
// Call handler
1085-
result, err := handler(context.Background(), &request)
1085+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
10861086

10871087
require.NoError(t, err)
10881088
require.Equal(t, tc.expectError, result.IsError)
@@ -1149,7 +1149,7 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) {
11491149
"return_content": true,
11501150
})
11511151

1152-
result, err := handler(context.Background(), &request)
1152+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
11531153
require.NoError(t, err)
11541154
require.False(t, result.IsError)
11551155

@@ -1202,7 +1202,7 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) {
12021202
"tail_lines": float64(1), // Requesting last 1 line
12031203
})
12041204

1205-
result, err := handler(context.Background(), &request)
1205+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
12061206
require.NoError(t, err)
12071207
require.False(t, result.IsError)
12081208

@@ -1254,7 +1254,7 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) {
12541254
"tail_lines": float64(100),
12551255
})
12561256

1257-
result, err := handler(context.Background(), &request)
1257+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
12581258
require.NoError(t, err)
12591259
require.False(t, result.IsError)
12601260

pkg/github/code_scanning.go

Lines changed: 84 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -44,51 +44,49 @@ func GetCodeScanningAlert(t translations.TranslationHelperFunc) inventory.Server
4444
Required: []string{"owner", "repo", "alertNumber"},
4545
},
4646
},
47-
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
48-
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
49-
owner, err := RequiredParam[string](args, "owner")
50-
if err != nil {
51-
return utils.NewToolResultError(err.Error()), nil, nil
52-
}
53-
repo, err := RequiredParam[string](args, "repo")
54-
if err != nil {
55-
return utils.NewToolResultError(err.Error()), nil, nil
56-
}
57-
alertNumber, err := RequiredInt(args, "alertNumber")
58-
if err != nil {
59-
return utils.NewToolResultError(err.Error()), nil, nil
60-
}
61-
62-
client, err := deps.GetClient(ctx)
63-
if err != nil {
64-
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
65-
}
47+
func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
48+
owner, err := RequiredParam[string](args, "owner")
49+
if err != nil {
50+
return utils.NewToolResultError(err.Error()), nil, nil
51+
}
52+
repo, err := RequiredParam[string](args, "repo")
53+
if err != nil {
54+
return utils.NewToolResultError(err.Error()), nil, nil
55+
}
56+
alertNumber, err := RequiredInt(args, "alertNumber")
57+
if err != nil {
58+
return utils.NewToolResultError(err.Error()), nil, nil
59+
}
6660

67-
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
68-
if err != nil {
69-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
70-
"failed to get alert",
71-
resp,
72-
err,
73-
), nil, nil
74-
}
75-
defer func() { _ = resp.Body.Close() }()
61+
client, err := deps.GetClient(ctx)
62+
if err != nil {
63+
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
64+
}
7665

77-
if resp.StatusCode != http.StatusOK {
78-
body, err := io.ReadAll(resp.Body)
79-
if err != nil {
80-
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
81-
}
82-
return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil
83-
}
66+
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
67+
if err != nil {
68+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
69+
"failed to get alert",
70+
resp,
71+
err,
72+
), nil, nil
73+
}
74+
defer func() { _ = resp.Body.Close() }()
8475

85-
r, err := json.Marshal(alert)
76+
if resp.StatusCode != http.StatusOK {
77+
body, err := io.ReadAll(resp.Body)
8678
if err != nil {
87-
return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil
79+
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
8880
}
81+
return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil
82+
}
8983

90-
return utils.NewToolResultText(string(r)), nil, nil
84+
r, err := json.Marshal(alert)
85+
if err != nil {
86+
return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil
9187
}
88+
89+
return utils.NewToolResultText(string(r)), nil, nil
9290
},
9391
)
9492
}
@@ -137,62 +135,60 @@ func ListCodeScanningAlerts(t translations.TranslationHelperFunc) inventory.Serv
137135
Required: []string{"owner", "repo"},
138136
},
139137
},
140-
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
141-
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
142-
owner, err := RequiredParam[string](args, "owner")
143-
if err != nil {
144-
return utils.NewToolResultError(err.Error()), nil, nil
145-
}
146-
repo, err := RequiredParam[string](args, "repo")
147-
if err != nil {
148-
return utils.NewToolResultError(err.Error()), nil, nil
149-
}
150-
ref, err := OptionalParam[string](args, "ref")
151-
if err != nil {
152-
return utils.NewToolResultError(err.Error()), nil, nil
153-
}
154-
state, err := OptionalParam[string](args, "state")
155-
if err != nil {
156-
return utils.NewToolResultError(err.Error()), nil, nil
157-
}
158-
severity, err := OptionalParam[string](args, "severity")
159-
if err != nil {
160-
return utils.NewToolResultError(err.Error()), nil, nil
161-
}
162-
toolName, err := OptionalParam[string](args, "tool_name")
163-
if err != nil {
164-
return utils.NewToolResultError(err.Error()), nil, nil
165-
}
166-
167-
client, err := deps.GetClient(ctx)
168-
if err != nil {
169-
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
170-
}
171-
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName})
172-
if err != nil {
173-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
174-
"failed to list alerts",
175-
resp,
176-
err,
177-
), nil, nil
178-
}
179-
defer func() { _ = resp.Body.Close() }()
138+
func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
139+
owner, err := RequiredParam[string](args, "owner")
140+
if err != nil {
141+
return utils.NewToolResultError(err.Error()), nil, nil
142+
}
143+
repo, err := RequiredParam[string](args, "repo")
144+
if err != nil {
145+
return utils.NewToolResultError(err.Error()), nil, nil
146+
}
147+
ref, err := OptionalParam[string](args, "ref")
148+
if err != nil {
149+
return utils.NewToolResultError(err.Error()), nil, nil
150+
}
151+
state, err := OptionalParam[string](args, "state")
152+
if err != nil {
153+
return utils.NewToolResultError(err.Error()), nil, nil
154+
}
155+
severity, err := OptionalParam[string](args, "severity")
156+
if err != nil {
157+
return utils.NewToolResultError(err.Error()), nil, nil
158+
}
159+
toolName, err := OptionalParam[string](args, "tool_name")
160+
if err != nil {
161+
return utils.NewToolResultError(err.Error()), nil, nil
162+
}
180163

181-
if resp.StatusCode != http.StatusOK {
182-
body, err := io.ReadAll(resp.Body)
183-
if err != nil {
184-
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
185-
}
186-
return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil
187-
}
164+
client, err := deps.GetClient(ctx)
165+
if err != nil {
166+
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
167+
}
168+
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName})
169+
if err != nil {
170+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
171+
"failed to list alerts",
172+
resp,
173+
err,
174+
), nil, nil
175+
}
176+
defer func() { _ = resp.Body.Close() }()
188177

189-
r, err := json.Marshal(alerts)
178+
if resp.StatusCode != http.StatusOK {
179+
body, err := io.ReadAll(resp.Body)
190180
if err != nil {
191-
return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil
181+
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
192182
}
183+
return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil
184+
}
193185

194-
return utils.NewToolResultText(string(r)), nil, nil
186+
r, err := json.Marshal(alerts)
187+
if err != nil {
188+
return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil
195189
}
190+
191+
return utils.NewToolResultText(string(r)), nil, nil
196192
},
197193
)
198194
}

pkg/github/code_scanning_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func Test_GetCodeScanningAlert(t *testing.T) {
9090
request := createMCPRequest(tc.requestArgs)
9191

9292
// Call handler with new signature
93-
result, err := handler(context.Background(), &request)
93+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
9494

9595
// Verify results
9696
if tc.expectError {
@@ -216,7 +216,7 @@ func Test_ListCodeScanningAlerts(t *testing.T) {
216216
request := createMCPRequest(tc.requestArgs)
217217

218218
// Call handler with new signature
219-
result, err := handler(context.Background(), &request)
219+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
220220

221221
// Verify results
222222
if tc.expectError {

0 commit comments

Comments
 (0)