Skip to content

Commit ed3f9dc

Browse files
Add tests for RegisterToolsWithMiddleware and RegisterAllWithMiddleware
Tests verify that middleware: - Receives tool name at registration time (outer function called) - Can read context values passed through from tool handlers - Can detect IsError on result without Go error (ghErrors pattern) - Can modify the result before returning - Is applied to all tools when using RegisterAllWithMiddleware - Handles nil middleware gracefully
1 parent b6289d3 commit ed3f9dc

File tree

1 file changed

+259
-0
lines changed

1 file changed

+259
-0
lines changed

pkg/inventory/registry_test.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,3 +1643,262 @@ func TestFilteringOrder(t *testing.T) {
16431643
}
16441644
}
16451645
}
1646+
1647+
// contextKey is a custom type for context keys to avoid collisions
1648+
type contextKey string
1649+
1650+
const testErrorKey contextKey = "test_error"
1651+
1652+
func TestRegisterToolsWithMiddleware(t *testing.T) {
1653+
// Test that middleware can:
1654+
// 1. Access the tool name
1655+
// 2. Read context values set by the tool handler
1656+
// 3. Inspect and modify the result
1657+
// 4. Access IsError on the result
1658+
1659+
t.Run("middleware receives tool name and can wrap handler", func(t *testing.T) {
1660+
var capturedToolNameAtRegistration string
1661+
var capturedToolNameAtInvocation string
1662+
var handlerWasCalled bool
1663+
1664+
tool := NewServerToolFromHandler(
1665+
mcp.Tool{
1666+
Name: "test_tool",
1667+
Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true},
1668+
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
1669+
},
1670+
testToolsetMetadata("test"),
1671+
func(_ any) mcp.ToolHandler {
1672+
return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1673+
handlerWasCalled = true
1674+
return &mcp.CallToolResult{
1675+
Content: []mcp.Content{&mcp.TextContent{Text: "success"}},
1676+
}, nil
1677+
}
1678+
},
1679+
)
1680+
1681+
middleware := func(toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
1682+
// This outer function is called at registration time
1683+
capturedToolNameAtRegistration = toolName
1684+
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1685+
// This inner function is called at invocation time
1686+
capturedToolNameAtInvocation = toolName
1687+
return handler(ctx, req)
1688+
}
1689+
}
1690+
1691+
server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil)
1692+
reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build()
1693+
reg.RegisterToolsWithMiddleware(context.Background(), server, nil, middleware)
1694+
1695+
// Verify the middleware is applied at registration time (wrapping)
1696+
// by checking that the tool name was captured in the outer function
1697+
if capturedToolNameAtRegistration != "test_tool" {
1698+
t.Errorf("Expected middleware to capture tool name at registration 'test_tool', got %q", capturedToolNameAtRegistration)
1699+
}
1700+
1701+
// The inner function (invocation-time capture) should not have been called yet
1702+
if capturedToolNameAtInvocation != "" {
1703+
t.Error("Inner middleware function should not be called during registration")
1704+
}
1705+
1706+
// Handler shouldn't be called until tool is invoked
1707+
if handlerWasCalled {
1708+
t.Error("Handler should not be called during registration")
1709+
}
1710+
})
1711+
1712+
t.Run("middleware can read context values set by tool", func(t *testing.T) {
1713+
var middlewareSeenError string
1714+
1715+
tool := NewServerToolFromHandler(
1716+
mcp.Tool{
1717+
Name: "error_tool",
1718+
Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true},
1719+
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
1720+
},
1721+
testToolsetMetadata("test"),
1722+
func(_ any) mcp.ToolHandler {
1723+
return func(ctx context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1724+
// Simulate storing an error in context (like ghErrors does)
1725+
// The context is passed by value, but if we use a pointer in the context
1726+
// the middleware can see modifications
1727+
if ptr, ok := ctx.Value(testErrorKey).(*string); ok && ptr != nil {
1728+
*ptr = "github_api_error: 404 not found"
1729+
}
1730+
return &mcp.CallToolResult{
1731+
Content: []mcp.Content{&mcp.TextContent{Text: "error occurred"}},
1732+
IsError: true,
1733+
}, nil // Note: returning nil error, but IsError is true
1734+
}
1735+
},
1736+
)
1737+
1738+
middleware := func(_ string, handler mcp.ToolHandler) mcp.ToolHandler {
1739+
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1740+
// Set up a pointer in context that the handler can write to
1741+
errorHolder := ""
1742+
ctx = context.WithValue(ctx, testErrorKey, &errorHolder)
1743+
1744+
result, err := handler(ctx, req)
1745+
1746+
// Read what the handler wrote
1747+
middlewareSeenError = errorHolder
1748+
1749+
return result, err
1750+
}
1751+
}
1752+
1753+
server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil)
1754+
reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build()
1755+
reg.RegisterToolsWithMiddleware(context.Background(), server, nil, middleware)
1756+
1757+
// Simulate calling the tool - recreate the middleware wrapper
1758+
handler := tool.Handler(nil)
1759+
wrappedHandler := middleware("error_tool", handler)
1760+
1761+
result, err := wrappedHandler(context.Background(), &mcp.CallToolRequest{})
1762+
1763+
if err != nil {
1764+
t.Errorf("Expected no Go error, got %v", err)
1765+
}
1766+
if !result.IsError {
1767+
t.Error("Expected result.IsError to be true")
1768+
}
1769+
if middlewareSeenError != "github_api_error: 404 not found" {
1770+
t.Errorf("Middleware didn't see context error, got: %q", middlewareSeenError)
1771+
}
1772+
})
1773+
1774+
t.Run("middleware can detect IsError without Go error", func(t *testing.T) {
1775+
// This tests the exact pattern used by ghErrors.NewGitHubAPIStatusErrorResponse
1776+
// which returns (result_with_IsError_true, nil)
1777+
1778+
var middlewareDetectedError bool
1779+
var middlewareDetectedGoError bool
1780+
1781+
tool := NewServerToolFromHandler(
1782+
mcp.Tool{
1783+
Name: "status_error_tool",
1784+
Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true},
1785+
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
1786+
},
1787+
testToolsetMetadata("test"),
1788+
func(_ any) mcp.ToolHandler {
1789+
return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1790+
// This is exactly how ghErrors.NewGitHubAPIStatusErrorResponse returns
1791+
return &mcp.CallToolResult{
1792+
Content: []mcp.Content{&mcp.TextContent{Text: "unexpected status 404: not found"}},
1793+
IsError: true,
1794+
}, nil // No Go error!
1795+
}
1796+
},
1797+
)
1798+
1799+
middleware := func(_ string, handler mcp.ToolHandler) mcp.ToolHandler {
1800+
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1801+
result, err := handler(ctx, req)
1802+
1803+
// Check both error indicators
1804+
middlewareDetectedGoError = (err != nil)
1805+
middlewareDetectedError = (result != nil && result.IsError)
1806+
1807+
return result, err
1808+
}
1809+
}
1810+
1811+
handler := tool.Handler(nil)
1812+
wrappedHandler := middleware("status_error_tool", handler)
1813+
1814+
_, _ = wrappedHandler(context.Background(), &mcp.CallToolRequest{})
1815+
1816+
if middlewareDetectedGoError {
1817+
t.Error("Should NOT detect Go error (it's nil)")
1818+
}
1819+
if !middlewareDetectedError {
1820+
t.Error("SHOULD detect error via result.IsError")
1821+
}
1822+
})
1823+
1824+
t.Run("middleware can modify result", func(t *testing.T) {
1825+
tool := NewServerToolFromHandler(
1826+
mcp.Tool{
1827+
Name: "modifiable_tool",
1828+
Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true},
1829+
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
1830+
},
1831+
testToolsetMetadata("test"),
1832+
func(_ any) mcp.ToolHandler {
1833+
return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1834+
return &mcp.CallToolResult{
1835+
Content: []mcp.Content{&mcp.TextContent{Text: "original"}},
1836+
}, nil
1837+
}
1838+
},
1839+
)
1840+
1841+
middleware := func(_ string, handler mcp.ToolHandler) mcp.ToolHandler {
1842+
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1843+
result, err := handler(ctx, req)
1844+
1845+
// Middleware can modify the result
1846+
if result != nil {
1847+
result.Content = []mcp.Content{&mcp.TextContent{Text: "modified by middleware"}}
1848+
}
1849+
1850+
return result, err
1851+
}
1852+
}
1853+
1854+
handler := tool.Handler(nil)
1855+
wrappedHandler := middleware("modifiable_tool", handler)
1856+
1857+
result, _ := wrappedHandler(context.Background(), &mcp.CallToolRequest{})
1858+
1859+
if result == nil {
1860+
t.Fatal("Expected result")
1861+
}
1862+
textContent, ok := result.Content[0].(*mcp.TextContent)
1863+
if !ok {
1864+
t.Fatal("Expected TextContent")
1865+
}
1866+
if textContent.Text != "modified by middleware" {
1867+
t.Errorf("Expected modified text, got: %s", textContent.Text)
1868+
}
1869+
})
1870+
1871+
t.Run("RegisterAllWithMiddleware applies middleware to tools", func(t *testing.T) {
1872+
middlewareCallCount := 0
1873+
1874+
tools := []ServerTool{
1875+
mockTool("tool1", "toolset1", true),
1876+
mockTool("tool2", "toolset1", true),
1877+
}
1878+
1879+
middleware := func(_ string, handler mcp.ToolHandler) mcp.ToolHandler {
1880+
middlewareCallCount++
1881+
return handler
1882+
}
1883+
1884+
server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil)
1885+
reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build()
1886+
reg.RegisterAllWithMiddleware(context.Background(), server, nil, middleware)
1887+
1888+
// Middleware should be called once per tool during registration
1889+
if middlewareCallCount != 2 {
1890+
t.Errorf("Expected middleware to be called 2 times, got %d", middlewareCallCount)
1891+
}
1892+
})
1893+
1894+
t.Run("nil middleware is handled gracefully", func(_ *testing.T) {
1895+
tool := mockTool("test_tool", "toolset1", true)
1896+
1897+
server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil)
1898+
reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build()
1899+
1900+
// Should not panic with nil middleware
1901+
reg.RegisterToolsWithMiddleware(context.Background(), server, nil, nil)
1902+
reg.RegisterAllWithMiddleware(context.Background(), server, nil, nil)
1903+
})
1904+
}

0 commit comments

Comments
 (0)