@@ -1643,3 +1643,262 @@ func TestFilteringOrder(t *testing.T) {
16431643 }
16441644 }
16451645}
1646+
1647+ // contextKey is a custom type for context keys to avoid collisions
1648+ type contextKey string
1649+
1650+ const testErrorKey contextKey = "test_error"
1651+
1652+ func TestRegisterToolsWithMiddleware (t * testing.T ) {
1653+ // Test that middleware can:
1654+ // 1. Access the tool name
1655+ // 2. Read context values set by the tool handler
1656+ // 3. Inspect and modify the result
1657+ // 4. Access IsError on the result
1658+
1659+ t .Run ("middleware receives tool name and can wrap handler" , func (t * testing.T ) {
1660+ var capturedToolNameAtRegistration string
1661+ var capturedToolNameAtInvocation string
1662+ var handlerWasCalled bool
1663+
1664+ tool := NewServerToolFromHandler (
1665+ mcp.Tool {
1666+ Name : "test_tool" ,
1667+ Annotations : & mcp.ToolAnnotations {ReadOnlyHint : true },
1668+ InputSchema : json .RawMessage (`{"type":"object","properties":{}}` ),
1669+ },
1670+ testToolsetMetadata ("test" ),
1671+ func (_ any ) mcp.ToolHandler {
1672+ return func (_ context.Context , _ * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1673+ handlerWasCalled = true
1674+ return & mcp.CallToolResult {
1675+ Content : []mcp.Content {& mcp.TextContent {Text : "success" }},
1676+ }, nil
1677+ }
1678+ },
1679+ )
1680+
1681+ middleware := func (toolName string , handler mcp.ToolHandler ) mcp.ToolHandler {
1682+ // This outer function is called at registration time
1683+ capturedToolNameAtRegistration = toolName
1684+ return func (ctx context.Context , req * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1685+ // This inner function is called at invocation time
1686+ capturedToolNameAtInvocation = toolName
1687+ return handler (ctx , req )
1688+ }
1689+ }
1690+
1691+ server := mcp .NewServer (& mcp.Implementation {Name : "test" }, nil )
1692+ reg := NewBuilder ().SetTools ([]ServerTool {tool }).WithToolsets ([]string {"all" }).Build ()
1693+ reg .RegisterToolsWithMiddleware (context .Background (), server , nil , middleware )
1694+
1695+ // Verify the middleware is applied at registration time (wrapping)
1696+ // by checking that the tool name was captured in the outer function
1697+ if capturedToolNameAtRegistration != "test_tool" {
1698+ t .Errorf ("Expected middleware to capture tool name at registration 'test_tool', got %q" , capturedToolNameAtRegistration )
1699+ }
1700+
1701+ // The inner function (invocation-time capture) should not have been called yet
1702+ if capturedToolNameAtInvocation != "" {
1703+ t .Error ("Inner middleware function should not be called during registration" )
1704+ }
1705+
1706+ // Handler shouldn't be called until tool is invoked
1707+ if handlerWasCalled {
1708+ t .Error ("Handler should not be called during registration" )
1709+ }
1710+ })
1711+
1712+ t .Run ("middleware can read context values set by tool" , func (t * testing.T ) {
1713+ var middlewareSeenError string
1714+
1715+ tool := NewServerToolFromHandler (
1716+ mcp.Tool {
1717+ Name : "error_tool" ,
1718+ Annotations : & mcp.ToolAnnotations {ReadOnlyHint : true },
1719+ InputSchema : json .RawMessage (`{"type":"object","properties":{}}` ),
1720+ },
1721+ testToolsetMetadata ("test" ),
1722+ func (_ any ) mcp.ToolHandler {
1723+ return func (ctx context.Context , _ * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1724+ // Simulate storing an error in context (like ghErrors does)
1725+ // The context is passed by value, but if we use a pointer in the context
1726+ // the middleware can see modifications
1727+ if ptr , ok := ctx .Value (testErrorKey ).(* string ); ok && ptr != nil {
1728+ * ptr = "github_api_error: 404 not found"
1729+ }
1730+ return & mcp.CallToolResult {
1731+ Content : []mcp.Content {& mcp.TextContent {Text : "error occurred" }},
1732+ IsError : true ,
1733+ }, nil // Note: returning nil error, but IsError is true
1734+ }
1735+ },
1736+ )
1737+
1738+ middleware := func (_ string , handler mcp.ToolHandler ) mcp.ToolHandler {
1739+ return func (ctx context.Context , req * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1740+ // Set up a pointer in context that the handler can write to
1741+ errorHolder := ""
1742+ ctx = context .WithValue (ctx , testErrorKey , & errorHolder )
1743+
1744+ result , err := handler (ctx , req )
1745+
1746+ // Read what the handler wrote
1747+ middlewareSeenError = errorHolder
1748+
1749+ return result , err
1750+ }
1751+ }
1752+
1753+ server := mcp .NewServer (& mcp.Implementation {Name : "test" }, nil )
1754+ reg := NewBuilder ().SetTools ([]ServerTool {tool }).WithToolsets ([]string {"all" }).Build ()
1755+ reg .RegisterToolsWithMiddleware (context .Background (), server , nil , middleware )
1756+
1757+ // Simulate calling the tool - recreate the middleware wrapper
1758+ handler := tool .Handler (nil )
1759+ wrappedHandler := middleware ("error_tool" , handler )
1760+
1761+ result , err := wrappedHandler (context .Background (), & mcp.CallToolRequest {})
1762+
1763+ if err != nil {
1764+ t .Errorf ("Expected no Go error, got %v" , err )
1765+ }
1766+ if ! result .IsError {
1767+ t .Error ("Expected result.IsError to be true" )
1768+ }
1769+ if middlewareSeenError != "github_api_error: 404 not found" {
1770+ t .Errorf ("Middleware didn't see context error, got: %q" , middlewareSeenError )
1771+ }
1772+ })
1773+
1774+ t .Run ("middleware can detect IsError without Go error" , func (t * testing.T ) {
1775+ // This tests the exact pattern used by ghErrors.NewGitHubAPIStatusErrorResponse
1776+ // which returns (result_with_IsError_true, nil)
1777+
1778+ var middlewareDetectedError bool
1779+ var middlewareDetectedGoError bool
1780+
1781+ tool := NewServerToolFromHandler (
1782+ mcp.Tool {
1783+ Name : "status_error_tool" ,
1784+ Annotations : & mcp.ToolAnnotations {ReadOnlyHint : true },
1785+ InputSchema : json .RawMessage (`{"type":"object","properties":{}}` ),
1786+ },
1787+ testToolsetMetadata ("test" ),
1788+ func (_ any ) mcp.ToolHandler {
1789+ return func (_ context.Context , _ * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1790+ // This is exactly how ghErrors.NewGitHubAPIStatusErrorResponse returns
1791+ return & mcp.CallToolResult {
1792+ Content : []mcp.Content {& mcp.TextContent {Text : "unexpected status 404: not found" }},
1793+ IsError : true ,
1794+ }, nil // No Go error!
1795+ }
1796+ },
1797+ )
1798+
1799+ middleware := func (_ string , handler mcp.ToolHandler ) mcp.ToolHandler {
1800+ return func (ctx context.Context , req * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1801+ result , err := handler (ctx , req )
1802+
1803+ // Check both error indicators
1804+ middlewareDetectedGoError = (err != nil )
1805+ middlewareDetectedError = (result != nil && result .IsError )
1806+
1807+ return result , err
1808+ }
1809+ }
1810+
1811+ handler := tool .Handler (nil )
1812+ wrappedHandler := middleware ("status_error_tool" , handler )
1813+
1814+ _ , _ = wrappedHandler (context .Background (), & mcp.CallToolRequest {})
1815+
1816+ if middlewareDetectedGoError {
1817+ t .Error ("Should NOT detect Go error (it's nil)" )
1818+ }
1819+ if ! middlewareDetectedError {
1820+ t .Error ("SHOULD detect error via result.IsError" )
1821+ }
1822+ })
1823+
1824+ t .Run ("middleware can modify result" , func (t * testing.T ) {
1825+ tool := NewServerToolFromHandler (
1826+ mcp.Tool {
1827+ Name : "modifiable_tool" ,
1828+ Annotations : & mcp.ToolAnnotations {ReadOnlyHint : true },
1829+ InputSchema : json .RawMessage (`{"type":"object","properties":{}}` ),
1830+ },
1831+ testToolsetMetadata ("test" ),
1832+ func (_ any ) mcp.ToolHandler {
1833+ return func (_ context.Context , _ * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1834+ return & mcp.CallToolResult {
1835+ Content : []mcp.Content {& mcp.TextContent {Text : "original" }},
1836+ }, nil
1837+ }
1838+ },
1839+ )
1840+
1841+ middleware := func (_ string , handler mcp.ToolHandler ) mcp.ToolHandler {
1842+ return func (ctx context.Context , req * mcp.CallToolRequest ) (* mcp.CallToolResult , error ) {
1843+ result , err := handler (ctx , req )
1844+
1845+ // Middleware can modify the result
1846+ if result != nil {
1847+ result .Content = []mcp.Content {& mcp.TextContent {Text : "modified by middleware" }}
1848+ }
1849+
1850+ return result , err
1851+ }
1852+ }
1853+
1854+ handler := tool .Handler (nil )
1855+ wrappedHandler := middleware ("modifiable_tool" , handler )
1856+
1857+ result , _ := wrappedHandler (context .Background (), & mcp.CallToolRequest {})
1858+
1859+ if result == nil {
1860+ t .Fatal ("Expected result" )
1861+ }
1862+ textContent , ok := result .Content [0 ].(* mcp.TextContent )
1863+ if ! ok {
1864+ t .Fatal ("Expected TextContent" )
1865+ }
1866+ if textContent .Text != "modified by middleware" {
1867+ t .Errorf ("Expected modified text, got: %s" , textContent .Text )
1868+ }
1869+ })
1870+
1871+ t .Run ("RegisterAllWithMiddleware applies middleware to tools" , func (t * testing.T ) {
1872+ middlewareCallCount := 0
1873+
1874+ tools := []ServerTool {
1875+ mockTool ("tool1" , "toolset1" , true ),
1876+ mockTool ("tool2" , "toolset1" , true ),
1877+ }
1878+
1879+ middleware := func (_ string , handler mcp.ToolHandler ) mcp.ToolHandler {
1880+ middlewareCallCount ++
1881+ return handler
1882+ }
1883+
1884+ server := mcp .NewServer (& mcp.Implementation {Name : "test" }, nil )
1885+ reg := NewBuilder ().SetTools (tools ).WithToolsets ([]string {"all" }).Build ()
1886+ reg .RegisterAllWithMiddleware (context .Background (), server , nil , middleware )
1887+
1888+ // Middleware should be called once per tool during registration
1889+ if middlewareCallCount != 2 {
1890+ t .Errorf ("Expected middleware to be called 2 times, got %d" , middlewareCallCount )
1891+ }
1892+ })
1893+
1894+ t .Run ("nil middleware is handled gracefully" , func (_ * testing.T ) {
1895+ tool := mockTool ("test_tool" , "toolset1" , true )
1896+
1897+ server := mcp .NewServer (& mcp.Implementation {Name : "test" }, nil )
1898+ reg := NewBuilder ().SetTools ([]ServerTool {tool }).WithToolsets ([]string {"all" }).Build ()
1899+
1900+ // Should not panic with nil middleware
1901+ reg .RegisterToolsWithMiddleware (context .Background (), server , nil , nil )
1902+ reg .RegisterAllWithMiddleware (context .Background (), server , nil , nil )
1903+ })
1904+ }
0 commit comments