diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml new file mode 100644 index 0000000..e148c6b --- /dev/null +++ b/.github/workflows/test_python.yml @@ -0,0 +1,32 @@ +name: Python SDK Test + +on: + push: + branches: [ main ] + paths: + - 'sam-mcp-python/**' + pull_request: + branches: [ main ] + paths: + - 'sam-mcp-python/**' + +jobs: + test_python: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.22' + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Run Python unit tests + run: make test-python + diff --git a/.gitignore b/.gitignore index 6d0cc8d..165aa1e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ *.dll *.so *.dylib +*.pyc # Test binary, built with `go test -c` *.test @@ -33,4 +34,6 @@ go.work.sum # bin/ tests/e2e/logs/ -tests/integration/scratch/ \ No newline at end of file +tests/integration/scratch/ +__pycache__/ +.venv/ diff --git a/Makefile b/Makefile index 58dea29..a56d3c1 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,15 @@ clean: test: CGO_ENABLED=1 go test -v -race -count 1 ./... +.PHONY: test-python test-python-e2e +test-python: + python3 -m venv sam-mcp-python/.venv + ./sam-mcp-python/.venv/bin/pip install -e ./sam-mcp-python[test] + ./sam-mcp-python/.venv/bin/pytest sam-mcp-python/tests/unit + +test-python-e2e: build docker-build + bats --verbose-run tests/e2e/python_sdk_test.bats + e2e-test: bats --verbose-run tests/e2e/ diff --git a/cmd/mcp-client/main.go b/cmd/mcp-client/main.go index fd6aa57..77660d5 100644 --- a/cmd/mcp-client/main.go +++ b/cmd/mcp-client/main.go @@ -20,8 +20,6 @@ import ( "flag" "fmt" "log" - "net" - "net/http" "os" "os/signal" "syscall" @@ -31,14 +29,14 @@ import ( ) func main() { - socketPath := flag.String("socket", "", "Path to Unix domain socket") + serverURL := flag.String("url", "", "MCP server URL (e.g. http://localhost:8080/)") toolName := flag.String("tool", "get_mesh_info", "Tool to call") toolArgs := flag.String("args", "{}", "JSON arguments for the tool") timoutArgs := flag.Int("timeout", 10, "Timeout in seconds") flag.Parse() - if *socketPath == "" { - log.Fatal("Must specify -socket") + if *serverURL == "" { + log.Fatal("Must specify -url") } var ctx context.Context @@ -58,21 +56,14 @@ func main() { cancel() }() - // Override default HTTP client transport to use Unix socket - http.DefaultClient.Transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return net.Dial("unix", *socketPath) - }, - } - // Create MCP client client := mcp.NewClient(&mcp.Implementation{ Name: "mcp-test-client", Version: "0.1.0", }, nil) - // Connect to server using the URL (host is ignored by custom dialer) - session, err := client.Connect(ctx, &mcp.StreamableClientTransport{Endpoint: "http://localhost/mcp"}, nil) + // Connect to server using the URL + session, err := client.Connect(ctx, &mcp.SSEClientTransport{Endpoint: *serverURL}, nil) if err != nil { log.Fatalf("Failed to connect: %v", err) } diff --git a/cmd/sam-node/main.go b/cmd/sam-node/main.go index 35bcab3..de14135 100644 --- a/cmd/sam-node/main.go +++ b/cmd/sam-node/main.go @@ -23,7 +23,6 @@ import ( "net/http" "os" "os/signal" - "path/filepath" "strings" "syscall" "time" @@ -59,7 +58,7 @@ var ( clientSecretFlag string tokenURLFlag string hubPublicKeyFlag string - mcpSocketFlag string + mcpAddrFlag string meshFlag string discoveryIntervalFlag string enableRelayFlag bool @@ -219,7 +218,7 @@ func main() { node.Host.SetStreamHandler(api.AuthProtocolID, node.HandleAuthHandshake) // Start MCP Server - startMCPServer(node, mcpSocketFlag, dataDir) + startMCPServer(node, mcpAddrFlag) fmt.Printf("SAM Node Online.\nPeerID: %s\nListening on: %v\n", node.Host.ID(), node.Host.Addrs()) @@ -306,7 +305,7 @@ func main() { runCmd.Flags().StringVar(&clientIDFlag, "client-id", os.Getenv("SAM_OIDC_ID"), "OIDC Client ID for M2M") runCmd.Flags().StringVar(&clientSecretFlag, "client-secret", os.Getenv("SAM_OIDC_SECRET"), "OIDC Client Secret for M2M") runCmd.Flags().StringVar(&hubPublicKeyFlag, "hub-public-key", "", "Hub Public Key (32-byte Hex)") - runCmd.Flags().StringVar(&mcpSocketFlag, "mcp-socket", "", "Path to Unix domain socket for local MCP server (default: /mcp.sock)") + runCmd.Flags().StringVar(&mcpAddrFlag, "mcp-addr", "127.0.0.1:8080", "Local TCP address for the MCP HTTP/SSE server") runCmd.Flags().StringVar(&meshFlag, "mesh", DefaultMeshName, "Mesh federation name") runCmd.Flags().StringVar(&discoveryIntervalFlag, "discovery-interval", DefaultDiscoveryInterval, "Polling interval for DHT discovery") runCmd.Flags().BoolVar(&enableRelayFlag, "enable-relay", false, "Allow this node to serve as a relay for others") @@ -357,20 +356,16 @@ func getOrGenerateKey(s *Store) crypto.PrivKey { return priv } -func startMCPServer(node *SamNode, socketPath string, dataDir string) { +func startMCPServer(node *SamNode, mcpAddr string) { mcpHandler := NewMCPHandler(node) go func() { - if socketPath == "" { - socketPath = filepath.Join(dataDir, "mcp.sock") + if mcpAddr == "" { + mcpAddr = "127.0.0.1:8080" } - if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { - logger.Errorf("Failed to remove old socket %s: %v", socketPath, err) - } - - listener, err := net.Listen("unix", socketPath) + listener, err := net.Listen("tcp", mcpAddr) if err != nil { - logger.Errorf("Failed to listen on Unix socket %s: %v", socketPath, err) + logger.Errorf("Failed to listen on TCP address %s: %v", mcpAddr, err) return } defer func() { @@ -379,7 +374,7 @@ func startMCPServer(node *SamNode, socketPath string, dataDir string) { } }() - fmt.Printf("Starting MCP server on Unix socket %s\n", socketPath) + fmt.Printf("Starting MCP server on TCP address %s\n", listener.Addr().String()) if err := http.Serve(listener, mcpHandler); err != nil { logger.Errorf("MCP server error: %v", err) } diff --git a/cmd/sam-node/mcp.go b/cmd/sam-node/mcp.go index 8472ffc..bde99d5 100644 --- a/cmd/sam-node/mcp.go +++ b/cmd/sam-node/mcp.go @@ -49,19 +49,19 @@ func handleSendMessage(ctx context.Context, req *mcp.CallToolRequest, params Sen // NewMCPHandler creates a new HTTP handler for the MCP server using the official SDK. func NewMCPHandler(node *SamNode) http.Handler { // Create an MCP server. - server := mcp.NewServer(&mcp.Implementation{ + mcpServer := mcp.NewServer(&mcp.Implementation{ Name: "sam-node-mcp", Version: "0.1.0", }, nil) // Add the send_message tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "send_message", Description: "Send a message to another agent in the mesh", }, handleSendMessage) // Add the mesh_pubsub_broadcast tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "mesh_pubsub_broadcast", Description: "Publish an event payload to a custom GossipSub topic", }, func(ctx context.Context, req *mcp.CallToolRequest, params struct { @@ -92,7 +92,7 @@ func NewMCPHandler(node *SamNode) http.Handler { }) // Add the poll_messages tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "poll_messages", Description: "Poll for incoming messages on custom GossipSub topics", }, func(ctx context.Context, req *mcp.CallToolRequest, params struct { @@ -112,7 +112,7 @@ func NewMCPHandler(node *SamNode) http.Handler { }) // Add the subscribe_topic tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "subscribe_topic", Description: "Subscribe to a custom GossipSub topic", }, func(ctx context.Context, req *mcp.CallToolRequest, params struct { @@ -129,7 +129,7 @@ func NewMCPHandler(node *SamNode) http.Handler { }) // Add the get_mesh_info tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "get_mesh_info", Description: "Get information about the mesh network", }, func(ctx context.Context, req *mcp.CallToolRequest, params struct{}) (*mcp.CallToolResult, any, error) { @@ -170,7 +170,7 @@ func NewMCPHandler(node *SamNode) http.Handler { }) // Add the call_remote_tool tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "call_remote_tool", Description: "Call an MCP tool on a remote agent", }, func(ctx context.Context, req *mcp.CallToolRequest, params struct { @@ -198,7 +198,7 @@ func NewMCPHandler(node *SamNode) http.Handler { }) // Add the connect_peer tool. - mcp.AddTool(server, &mcp.Tool{ + mcp.AddTool(mcpServer, &mcp.Tool{ Name: "connect_peer", Description: "Connect to a peer in the mesh", }, func(ctx context.Context, req *mcp.CallToolRequest, params struct { @@ -222,12 +222,22 @@ func NewMCPHandler(node *SamNode) http.Handler { }, nil, nil }) - // Create the streamable HTTP handler. - handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server { - return server + // Create the SSE handler using the SDK + sseHandler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { + return mcpServer }, nil) - return handler + mux := http.NewServeMux() + mux.Handle("/mcp/events", sseHandler) + mux.Handle("/mcp/message", sseHandler) + + // Wrap in logging middleware to debug incoming requests + wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger.Debugf("MCP Request: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr) + mux.ServeHTTP(w, r) + }) + + return wrappedHandler } // CallMCPTool opens a stream to a remote peer, performs the handshake, and calls a tool. diff --git a/cmd/sam-node/mcp_test.go b/cmd/sam-node/mcp_test.go index 68ecaeb..829530d 100644 --- a/cmd/sam-node/mcp_test.go +++ b/cmd/sam-node/mcp_test.go @@ -15,120 +15,50 @@ package main import ( - "io" + "net/http" + "net/http/httptest" "testing" - "time" - - lru "github.com/hashicorp/golang-lru/v2" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/google/sam/api" - "github.com/libp2p/go-libp2p/core/protocol" ) +func TestMCPHandler_HTTP(t *testing.T) { + // Setup a dummy node + node := &SamNode{} + handler := NewMCPHandler(node) -type mockStream struct { - r io.Reader - w io.Writer - protocol protocol.ID -} - -func (s *mockStream) Read(p []byte) (n int, err error) { - return s.r.Read(p) -} -func (s *mockStream) Write(p []byte) (n int, err error) { - return s.w.Write(p) -} -func (s *mockStream) Close() error { - if c, ok := s.w.(io.Closer); ok { - return c.Close() - } - return nil -} -func (s *mockStream) Protocol() protocol.ID { - return s.protocol -} - -type mockConn struct { - network.Conn // Embed interface - remotePeer peer.ID -} - -func (c *mockConn) RemotePeer() peer.ID { - return c.remotePeer -} - -func (s *mockStream) Conn() network.Conn { - return &mockConn{remotePeer: peer.ID("dummy-peer-id")} -} -func (s *mockStream) Reset() error { - return nil -} -func (s *mockStream) CloseRead() error { - return nil -} -func (s *mockStream) CloseWrite() error { - return nil -} -func (s *mockStream) ID() string { - return "dummy-stream-id" -} -func (s *mockStream) ResetWithError(code network.StreamErrorCode) error { - return nil -} -func (s *mockStream) Scope() network.StreamScope { - return nil -} -func (s *mockStream) SetDeadline(t time.Time) error { - return nil -} -func (s *mockStream) SetReadDeadline(t time.Time) error { - return nil -} -func (s *mockStream) SetWriteDeadline(t time.Time) error { - return nil -} -func (s *mockStream) SetProtocol(id protocol.ID) error { - s.protocol = id - return nil -} -func (s *mockStream) Stat() network.Stats { - return network.Stats{} -} + ts := httptest.NewServer(handler) + defer ts.Close() -func TestZeroTrustMCPServer(t *testing.T) { - pr1, pw1 := io.Pipe() - pr2, pw2 := io.Pipe() + client := &http.Client{} - serverStream := &mockStream{r: pr1, w: pw2, protocol: api.MCPProtocolID} - clientStream := &mockStream{r: pr2, w: pw1, protocol: api.MCPProtocolID} + // Test GET on root (should be 404 now) + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } - rl, _ := NewPeerRateLimiter(100) - rp, _ := lru.New[string, int64](100) - vc, _ := lru.New[string, string](100) - node := &SamNode{ - rateLimiter: rl, - revokedPeers: rp, - verificationCache: vc, + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() - go func() { - handler := node.WithBiscuitAuth(node.HandleMCPStream) - handler(serverStream) - }() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status NotFound on root, got %d", resp.StatusCode) + } - // Test: Skip sending AuthFrame and write MCP message directly! - if _, err := pw1.Write([]byte(`{"jsonrpc":"2.0","method":"initialize"}`)); err != nil { - t.Fatalf("failed to write to pipe: %v", err) + // Test GET on /mcp/events + req2, err := http.NewRequest("GET", ts.URL+"/mcp/events", nil) + if err != nil { + t.Fatal(err) } - if err := pw1.Close(); err != nil { - t.Fatalf("failed to close pipe: %v", err) + + resp2, err := client.Do(req2) + if err != nil { + t.Fatal(err) } + defer func() { _ = resp2.Body.Close() }() - // Server should read invalid auth frame and close stream! - msg := make([]byte, 100) - _, err := clientStream.Read(msg) - if err == nil { - t.Errorf("Expected error reading from stream (stream should be closed by server), got nil") + if resp2.StatusCode != http.StatusOK && resp2.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status OK or BadRequest on /mcp/events, got %d", resp2.StatusCode) } } diff --git a/cmd/sam-node/middleware_test.go b/cmd/sam-node/middleware_test.go index 09a427f..2877b83 100644 --- a/cmd/sam-node/middleware_test.go +++ b/cmd/sam-node/middleware_test.go @@ -15,6 +15,7 @@ package main import ( + "context" "crypto/ed25519" "crypto/sha256" "encoding/hex" @@ -29,12 +30,58 @@ import ( "github.com/google/sam/api" lru "github.com/hashicorp/golang-lru/v2" "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-msgio" + "github.com/multiformats/go-multiaddr" "google.golang.org/protobuf/proto" ) +type mockConn struct { + remotePeer peer.ID +} + +func (c *mockConn) RemotePeer() peer.ID { return c.remotePeer } +func (c *mockConn) LocalPeer() peer.ID { return "" } +func (c *mockConn) LocalMultiaddr() multiaddr.Multiaddr { return nil } +func (c *mockConn) RemoteMultiaddr() multiaddr.Multiaddr { return nil } +func (c *mockConn) Stat() network.ConnStats { return network.ConnStats{} } +func (c *mockConn) Scope() network.ConnScope { return nil } +func (c *mockConn) Close() error { return nil } +func (c *mockConn) CloseWithError(network.ConnErrorCode) error { return nil } +func (c *mockConn) ConnState() network.ConnectionState { return network.ConnectionState{} } +func (c *mockConn) GetStreams() []network.Stream { return nil } +func (c *mockConn) ID() string { return "" } +func (c *mockConn) IsClosed() bool { return false } +func (c *mockConn) NewStream(context.Context) (network.Stream, error) { return nil, nil } +func (c *mockConn) RemotePublicKey() crypto.PubKey { return nil } +func (c *mockConn) As(interface{}) bool { return false } + +type mockStream struct { + r io.Reader + w io.Writer + protocol protocol.ID + conn network.Conn +} + +func (s *mockStream) Read(p []byte) (n int, err error) { return s.r.Read(p) } +func (s *mockStream) Write(p []byte) (n int, err error) { return s.w.Write(p) } +func (s *mockStream) Close() error { return nil } +func (s *mockStream) Protocol() protocol.ID { return s.protocol } +func (s *mockStream) ID() string { return "" } +func (s *mockStream) SetProtocol(protocol.ID) error { return nil } +func (s *mockStream) CloseRead() error { return nil } +func (s *mockStream) CloseWrite() error { return nil } +func (s *mockStream) Reset() error { return nil } +func (s *mockStream) ResetWithError(network.StreamErrorCode) error { return nil } +func (s *mockStream) SetDeadline(time.Time) error { return nil } +func (s *mockStream) SetReadDeadline(time.Time) error { return nil } +func (s *mockStream) SetWriteDeadline(time.Time) error { return nil } +func (s *mockStream) Stat() network.Stats { return network.Stats{} } +func (s *mockStream) Conn() network.Conn { return s.conn } +func (s *mockStream) Scope() network.StreamScope { return nil } + func TestAuthorize(t *testing.T) { dir, err := os.MkdirTemp("", "middleware-test") if err != nil { @@ -318,7 +365,7 @@ func TestRevocation(t *testing.T) { pr1, pw1 := io.Pipe() pr2, pw2 := io.Pipe() - serverStream := &mockStream{r: pr1, w: pw2, protocol: protocol.ID("/test/proto")} + serverStream := &mockStream{r: pr1, w: pw2, protocol: protocol.ID("/test/proto"), conn: &mockConn{remotePeer: dummyPeer}} // Run handler in goroutine go func() { @@ -442,7 +489,7 @@ func TestHandleAuthHandshakeCache(t *testing.T) { pr1, pw1 := io.Pipe() - serverStream := &mockStream{r: pr1, w: io.Discard, protocol: api.AuthProtocolID} + serverStream := &mockStream{r: pr1, w: io.Discard, protocol: api.AuthProtocolID, conn: &mockConn{remotePeer: dummyPeer}} go func() { node.HandleAuthHandshake(serverStream) @@ -483,7 +530,7 @@ func TestHandleAuthHandshakeCache(t *testing.T) { node.Store = store2 pr5, pw5 := io.Pipe() - serverStream3 := &mockStream{r: pr5, w: io.Discard, protocol: api.AuthProtocolID} + serverStream3 := &mockStream{r: pr5, w: io.Discard, protocol: api.AuthProtocolID, conn: &mockConn{remotePeer: dummyPeer}} go func() { node.HandleAuthHandshake(serverStream3) diff --git a/sam-mcp-python/README.md b/sam-mcp-python/README.md new file mode 100644 index 0000000..916f2e8 --- /dev/null +++ b/sam-mcp-python/README.md @@ -0,0 +1,38 @@ +# SAM Python SDK (sam-mcp-python) + +The official Python SDK for the Sovereign Agent Mesh (SAM). + +This SDK acts as a "Thin Client" that connects to the local Go node via a Unix Domain Socket and communicates using the Model Context Protocol (MCP) over JSON-RPC 2.0. + +## Installation + +```bash +pip install . +``` + +## Usage + +```python +import asyncio +from sam_mcp.client import SamClient + +async def main(): + async with SamClient() as client: + tools = await client.get_tools() + print("Available tools:", tools) + + result = await client.call_tool("echo", {"message": "hello"}) + print("Result:", result) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Development + +### Running Tests + +```bash +pytest tests/unit +pytest tests/e2e +``` diff --git a/sam-mcp-python/pyproject.toml b/sam-mcp-python/pyproject.toml new file mode 100644 index 0000000..d63165e --- /dev/null +++ b/sam-mcp-python/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "sam-mcp" +version = "0.1.0" +description = "Python SDK for Sovereign Agent Mesh (SAM) using MCP" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "mcp>=1.0.0", +] + +[project.optional-dependencies] +langchain = [ + "langchain-core>=0.1.0", +] +test = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.20.0", +] + +[tool.hatch.build.targets.sdist] +include = ["src"] + +[tool.hatch.build.targets.wheel] +packages = ["src/sam_mcp"] + +[tool.pytest.ini_options] +pythonpath = ["src"] +asyncio_mode = "strict" diff --git a/sam-mcp-python/src/sam_mcp/__init__.py b/sam-mcp-python/src/sam_mcp/__init__.py new file mode 100644 index 0000000..1a06e96 --- /dev/null +++ b/sam-mcp-python/src/sam_mcp/__init__.py @@ -0,0 +1,3 @@ +"""SAM MCP Python SDK.""" + +__version__ = "0.1.0" diff --git a/sam-mcp-python/src/sam_mcp/adapters/__init__.py b/sam-mcp-python/src/sam_mcp/adapters/__init__.py new file mode 100644 index 0000000..6bc465b --- /dev/null +++ b/sam-mcp-python/src/sam_mcp/adapters/__init__.py @@ -0,0 +1 @@ +"""Adapters for popular frameworks.""" diff --git a/sam-mcp-python/src/sam_mcp/adapters/langchain.py b/sam-mcp-python/src/sam_mcp/adapters/langchain.py new file mode 100644 index 0000000..18f8554 --- /dev/null +++ b/sam-mcp-python/src/sam_mcp/adapters/langchain.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, List +from ..client import SamClient + +def get_langchain_tools(client: SamClient, tools: List[Dict[str, Any]]) -> List[Any]: + """Converts MCP tools into LangChain-compatible StructuredTool objects. + + Requires `langchain-core` and `pydantic` to be installed. + """ + try: + from langchain_core.tools import StructuredTool + from pydantic import create_model, Field + except ImportError: + raise ImportError( + "langchain-core and pydantic are required to use this adapter. " + "Install them or ensure they are available via langchain." + ) + + lc_tools = [] + for tool in tools: + name = tool.get("name") + description = tool.get("description", "") + input_schema = tool.get("inputSchema", {}) + + properties = input_schema.get("properties", {}) + required = input_schema.get("required", []) + + fields = {} + for prop_name, prop_schema in properties.items(): + prop_type = prop_schema.get("type") + prop_desc = prop_schema.get("description", "") + + python_type = Any + if prop_type == "string": + python_type = str + elif prop_type == "integer": + python_type = int + elif prop_type == "number": + python_type = float + elif prop_type == "boolean": + python_type = bool + elif prop_type == "array": + python_type = list + elif prop_type == "object": + python_type = dict + + default = ... if prop_name in required else None + + fields[prop_name] = (python_type, Field(default=default, description=prop_desc)) + + args_schema = None + if fields: + args_schema = create_model(f"{name}Schema", **fields) + + # Capture the tool name in the closure + def make_call(tool_name=name): + async def call_remote_tool(**kwargs): + return await client.call_tool(tool_name, kwargs) + return call_remote_tool + + lc_tool = StructuredTool.from_function( + name=name, + description=description, + coroutine=make_call(name), + args_schema=args_schema + ) + lc_tools.append(lc_tool) + + return lc_tools diff --git a/sam-mcp-python/src/sam_mcp/client.py b/sam-mcp-python/src/sam_mcp/client.py new file mode 100644 index 0000000..791c509 --- /dev/null +++ b/sam-mcp-python/src/sam_mcp/client.py @@ -0,0 +1,53 @@ +import asyncio +import os +from typing import Any, Dict, List, Optional +from mcp import ClientSession +from mcp.client.sse import sse_client + +class SamClient: + """High-level developer interface for SAM MCP using official SDK.""" + + def __init__(self, server_url: Optional[str] = None): + if server_url is None: + server_url = os.environ.get("SAM_MCP_URL", "http://localhost:8080/sse") + self.server_url = server_url + self.session: Optional[ClientSession] = None + self._sse_cm = None + + async def connect(self): + """Connects to the SAM node via SSE.""" + self._sse_cm = sse_client(self.server_url, headers={"Accept": "application/json, text/event-stream"}) + read_stream, write_stream = await self._sse_cm.__aenter__() + self.session = ClientSession(read_stream, write_stream) + await self.session.__aenter__() + await self.session.initialize() + + async def close(self): + """Closes the connection.""" + if self.session: + await self.session.__aexit__(None, None, None) + if self._sse_cm: + await self._sse_cm.__aexit__(None, None, None) + self.session = None + self._sse_cm = None + + async def get_tools(self) -> List[Dict[str, Any]]: + """Returns available mesh tools.""" + if not self.session: + raise RuntimeError("Not connected") + resp = await self.session.list_tools() + return [t.model_dump() if hasattr(t, "model_dump") else t for t in resp.tools] + + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Executes a tool over the mesh.""" + if not self.session: + raise RuntimeError("Not connected") + resp = await self.session.call_tool(name, arguments) + return resp.model_dump() if hasattr(resp, "model_dump") else resp + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() diff --git a/sam-mcp-python/test_client.py b/sam-mcp-python/test_client.py new file mode 100644 index 0000000..0944457 --- /dev/null +++ b/sam-mcp-python/test_client.py @@ -0,0 +1,27 @@ +import asyncio +import os +import sys +from sam_mcp.client import SamClient + +async def main(): + url = os.environ.get("SAM_MCP_URL", "http://sam-node-1:8080/mcp/events") + print(f"Connecting to {url}") + try: + async with SamClient(server_url=url) as client: + # Test get_tools + tools = await client.get_tools() + print(f"TOOLS_COUNT:{len(tools)}") + + # Test call_tool (get_mesh_info is a standard tool in sam-node) + result = await client.call_tool("get_mesh_info", {}) + print(f"CALL_RESULT:{result}") + + sys.exit(0) + except Exception as e: + import traceback + print(f"ERROR:{e}") + traceback.print_exc() + sys.exit(1) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sam-mcp-python/tests/unit/test_unit.py b/sam-mcp-python/tests/unit/test_unit.py new file mode 100644 index 0000000..feb08c1 --- /dev/null +++ b/sam-mcp-python/tests/unit/test_unit.py @@ -0,0 +1,53 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from sam_mcp.client import SamClient +from sam_mcp.adapters.langchain import get_langchain_tools + +# Client Tests +@pytest.mark.asyncio +async def test_client_get_tools(): + with patch("sam_mcp.client.sse_client") as mock_sse_client, \ + patch("sam_mcp.client.ClientSession") as MockClientSession: + + mock_cm = AsyncMock() + mock_sse_client.return_value = mock_cm + mock_cm.__aenter__.return_value = (MagicMock(), MagicMock()) + mock_cm.__aexit__ = AsyncMock() + + mock_session = MockClientSession.return_value + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock() + mock_session.initialize = AsyncMock() + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.model_dump.return_value = {"name": "test_tool"} + + mock_resp = MagicMock() + mock_resp.tools = [mock_tool] + mock_session.list_tools = AsyncMock(return_value=mock_resp) + + async with SamClient(server_url="http://localhost:8080/sse") as client: + tools = await client.get_tools() + assert len(tools) == 1 + assert tools[0]["name"] == "test_tool" + +# Adapter Tests +def test_langchain_adapter(): + class MockClient: + pass + + client = MockClient() + tools = [{"name": "test_tool", "description": "A test tool"}] + + # We need to mock langchain-core and pydantic imports if they are not installed + with patch.dict("sys.modules", {"langchain_core.tools": MagicMock(), "pydantic": MagicMock()}): + from langchain_core.tools import StructuredTool + + mock_structured_tool = MagicMock() + StructuredTool.from_function.return_value = mock_structured_tool + + lc_tools = get_langchain_tools(client, tools) + assert len(lc_tools) == 1 + assert lc_tools[0] == mock_structured_tool diff --git a/tests/e2e/lib/container_mesh.bash b/tests/e2e/lib/container_mesh.bash index 91be8d9..9e39ae9 100644 --- a/tests/e2e/lib/container_mesh.bash +++ b/tests/e2e/lib/container_mesh.bash @@ -105,31 +105,8 @@ if [[ -z "${MESH_HELPERS_LOADED:-}" ]]; then local idx="$1" local timeout_s="${2:-20}" local i - local data='{"jsonrpc":"2.0","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}},"id":1}' - for ((i=0; i/dev/null 2>&1; then + if docker run --rm --network "${MESH_NETWORK}" python:3.12 curl -s --max-time 5 -D - http://sam-node-${idx}:8080/mcp/events | grep -q "200 OK"; then return 0 fi sleep 1 @@ -140,7 +117,7 @@ sys.exit(1) mesh_get_node_count_via_mcp() { local idx="$1" local output - output="$(timeout 15s docker run --rm -v "${MESH_SOCKET_DIR}:/sockets" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -socket "/sockets/node-${idx}.sock" 2>/dev/null)" + output="$(timeout 15s docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-${idx}:8080/mcp/events" 2>/dev/null)" echo "${output}" | jq '.known_peers | length' } @@ -151,7 +128,7 @@ sys.exit(1) local i for ((i=0; i/dev/null)" + output="$(timeout 15s docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-${idx}:8080/mcp/events" 2>/dev/null)" echo "Node ${idx} get_mesh_info raw output: ${output}" local count count="$(echo "${output}" | jq '.known_peers | length')" @@ -171,7 +148,7 @@ sys.exit(1) local i for ((i=0; i/dev/null)" + output="$(timeout 15s docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-${idx}:8080/mcp/events" 2>/dev/null)" echo "[$(date +%T)] Node ${idx} get_mesh_info raw output: ${output}" local connected connected="$(echo "${output}" | jq -r --arg peer "$target_peer" '.connected_peers | index($peer) != null')" @@ -191,7 +168,7 @@ sys.exit(1) local i for ((i=0; i/dev/null)" + output="$(timeout 15s docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-${idx}:8080/mcp/events" 2>/dev/null)" echo "[$(date +%T)] Node ${idx} get_mesh_info raw output: ${output}" local connected connected="$(echo "${output}" | jq -r --arg peer "$target_peer" '.connected_peers | index($peer) != null')" @@ -255,7 +232,6 @@ sys.exit(1) --name "${name}" \ --network "${MESH_NETWORK}" \ --network-alias "sam-node-${idx}" \ - -v "${MESH_SOCKET_DIR}:/sockets" \ "sam-node:local" \ run \ ${flags} \ @@ -265,7 +241,7 @@ sys.exit(1) --token-url "http://mock-oidc:18080/token" \ --listen "/ip4/0.0.0.0/udp/5001/quic-v1" \ --listen "/ip4/0.0.0.0/tcp/5002" \ - --mcp-socket "/sockets/node-${idx}.sock" \ + --mcp-addr "0.0.0.0:8080" \ --mesh "e2e-mesh" >/dev/null MESH_CONTAINERS+=("${name}") diff --git a/tests/e2e/policy.bats b/tests/e2e/policy.bats index 297b8a4..79b2ceb 100644 --- a/tests/e2e/policy.bats +++ b/tests/e2e/policy.bats @@ -131,7 +131,7 @@ mesh_call_remote_tool() { local args="{\"peer_id\":\"${target_peer_id}\",\"tool_name\":\"${tool_name}\",\"arguments\":\"{}\"}" - docker run --rm -v "${MESH_SOCKET_DIR}:/sockets" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -socket "/sockets/node-${caller_idx}.sock" -tool "call_remote_tool" -args "${args}" + docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-${caller_idx}:8080/mcp/events" -tool "call_remote_tool" -args "${args}" } setup() { @@ -206,7 +206,6 @@ EOF" --name "${MESH_PREFIX}-node-1" \ --network "${MESH_NETWORK}" \ --network-alias "sam-node-1" \ - -v "${MESH_SOCKET_DIR}:/sockets" \ -v "${POLICY_VOL}:/etc/sam" \ "sam-node:local" \ run \ @@ -216,7 +215,7 @@ EOF" --token-url "http://mock-oidc:18080/token" \ --listen "/ip4/0.0.0.0/udp/5001/quic-v1" \ --listen "/ip4/0.0.0.0/tcp/5002" \ - --mcp-socket "/sockets/node-1.sock" \ + --mcp-addr "0.0.0.0:8080" \ --mesh "e2e-mesh" \ --local-policy "/etc/sam/local_policy.yaml" >/dev/null @@ -239,13 +238,16 @@ EOF" for ((i=0; i<30; i++)); do local output - output="$(docker run --rm -v "${MESH_SOCKET_DIR}:/sockets" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -socket "/sockets/node-2.sock" 2>/dev/null)" + output="$(docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-2:8080/mcp/events" 2>/dev/null)" TARGET_PEER_ID=$(echo "${output}" | grep -oE '12D3Koo[a-zA-Z0-9]+' | grep -v "${hub_id}" | grep -v "${node2_id}" | head -n 1) if [[ -n "${TARGET_PEER_ID}" ]]; then break fi sleep 1 done + + echo "Node 2 logs after discovery loop:" >&3 + docker logs "${MESH_PREFIX}-node-2" >&3 if [[ -z "${TARGET_PEER_ID}" ]]; then echo "Timeout waiting for discovery of Node 1" @@ -254,7 +256,7 @@ EOF" # Explicitly connect Node 2 to Node 1 to avoid "no addresses" error local node1_addr="/dns4/sam-node-1/tcp/5002/p2p/${TARGET_PEER_ID}" - docker run --rm -v "${MESH_SOCKET_DIR}:/sockets" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -socket "/sockets/node-2.sock" -tool "connect_peer" -args "{\"peer_addr\":\"${node1_addr}\"}" >/dev/null + docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-2:8080/mcp/events" -tool "connect_peer" -args "{\"peer_addr\":\"${node1_addr}\"}" >/dev/null } teardown() { diff --git a/tests/e2e/python_sdk_test.bats b/tests/e2e/python_sdk_test.bats new file mode 100644 index 0000000..1f3d4bf --- /dev/null +++ b/tests/e2e/python_sdk_test.bats @@ -0,0 +1,54 @@ +#!/usr/bin/env bats + +load "lib/container_mesh.bash" + +setup() { + if ! mesh_require_docker; then + skip "docker not available or daemon not running" + fi + + if [[ ! -x "./bin/sam-node" || ! -x "./bin/sam-hub" ]]; then + skip "missing binaries; run: make build" + fi + + mesh_setup_env +} + +teardown() { + if [[ "${BATS_TEST_COMPLETED:-0}" -ne 1 ]]; then + echo "Node 1 logs on failure (filtered):" + docker logs "${MESH_PREFIX}-node-1" 2>&1 | grep -i -E 'mcp|request|error|fatal|panic' || true + fi + mesh_cleanup_env +} + +@test "Python SDK: Connect, get tools, and call tool" { + run mesh_start_mock_oidc + [[ "$status" -eq 0 ]] + + run mesh_start_hub + [[ "$status" -eq 0 ]] + + run mesh_start_node 1 "--discovery-interval 100ms --log-level debug" + [[ "$status" -eq 0 ]] + + local node1_name="${MESH_PREFIX}-node-1" + mesh_wait_for_log "${node1_name}" "SAM Node Online" 20 + mesh_wait_for_mcp_ready 1 20 + + # Use the Python SDK to interact with the node + run docker run --rm \ + --network "${MESH_NETWORK}" \ + -v "$(pwd)/sam-mcp-python:/sam-mcp-python" \ + -e PYTHONPATH=/sam-mcp-python/src \ + python:3.12 \ + bash -c 'pip install mcp httpx && python3 /sam-mcp-python/test_client.py' + echo "Python SDK output: $output" + if [[ "$status" -ne 0 ]]; then + echo "Node 1 logs:" + docker logs "${node1_name}" + fi + [[ "$status" -eq 0 ]] + [[ "$output" == *"TOOLS_COUNT:"* ]] + [[ "$output" == *"CALL_RESULT:"* ]] +} diff --git a/tests/e2e/revocation_test.bats b/tests/e2e/revocation_test.bats index 89dda03..8134fe6 100644 --- a/tests/e2e/revocation_test.bats +++ b/tests/e2e/revocation_test.bats @@ -75,7 +75,7 @@ teardown() { # Explicitly connect Node 1 to Node 2 echo "[$(date +%T)] Explicitly connecting Node 1 to Node 2" local node2_addr="/dns4/sam-node-2/tcp/5002/p2p/${node2_peer_id}" - run docker run --rm -v "${MESH_SOCKET_DIR}:/sockets" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -socket "/sockets/node-1.sock" -tool "connect_peer" -args "{\"peer_addr\":\"${node2_addr}\"}" + run docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-1:8080/mcp/events" -tool "connect_peer" -args "{\"peer_addr\":\"${node2_addr}\"}" [[ "$status" -eq 0 ]] # Verify Node 1 connects to Node 2 @@ -117,7 +117,7 @@ teardown() { # Verify Node 1 cannot reconnect to Node 2 echo "[$(date +%T)] Attempting to reconnect (should fail)" local node2_addr="/dns4/sam-node-2/tcp/5002/p2p/${node2_peer_id}" - run docker run --rm -v "${MESH_SOCKET_DIR}:/sockets" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -socket "/sockets/node-1.sock" -tool "connect_peer" -args "{\"peer_addr\":\"${node2_addr}\"}" + run docker run --rm --network "${MESH_NETWORK}" -v "$(pwd)/bin/mcp-client:/mcp-client" python:3.12 /mcp-client -url "http://sam-node-1:8080/mcp/events" -tool "connect_peer" -args "{\"peer_addr\":\"${node2_addr}\"}" echo "Reconnect output: $output" [[ "$output" == *"gater disallows connection"* ]] } diff --git a/tests/integration/catalog_test.go b/tests/integration/catalog_test.go index 24a4dec..9dfb97e 100644 --- a/tests/integration/catalog_test.go +++ b/tests/integration/catalog_test.go @@ -3,8 +3,6 @@ package integration_test import ( "context" "encoding/json" - "net" - "net/http" "os" "os/exec" "path/filepath" @@ -21,7 +19,7 @@ func startBackgroundNode(t *testing.T, nodeBin string, hubAddr string, homeDir s "HOME="+homeDir, "XDG_CONFIG_HOME="+filepath.Join(homeDir, ".config"), ) - allArgs := append([]string{"run", "--hub", hubAddr, "--jwt", "test-jwt"}, args...) + allArgs := append([]string{"run", "--hub", hubAddr, "--jwt", "test-jwt", "--mcp-addr", "127.0.0.1:0"}, args...) cmd := exec.Command(nodeBin, allArgs...) cmd.Env = env @@ -48,25 +46,36 @@ func startBackgroundNode(t *testing.T, nodeBin string, hubAddr string, homeDir s return cmd } -func callMCP(t *testing.T, socketPath string, toolName string, params map[string]any) string { +func waitForMCPAddr(t *testing.T, logPath string) string { t.Helper() - ctx := context.Background() - - oldTransport := http.DefaultClient.Transport - defer func() { http.DefaultClient.Transport = oldTransport }() - - http.DefaultClient.Transport = &http.Transport{ - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPath) - }, + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + data, _ := os.ReadFile(logPath) + lines := strings.Split(string(data), "\n") + for _, line := range lines { + if strings.Contains(line, "Starting MCP server on TCP address ") { + parts := strings.Split(line, "Starting MCP server on TCP address ") + if len(parts) > 1 { + return strings.TrimSpace(parts[1]) + } + } + } + time.Sleep(100 * time.Millisecond) } - + t.Fatalf("timeout waiting for MCP addr in log: %s", logPath) + return "" +} + +func callMCP(t *testing.T, mcpAddr string, toolName string, params map[string]any) string { + t.Helper() + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{ Name: "test-client", Version: "0.1.0", }, nil) - - session, err := client.Connect(ctx, &mcp.StreamableClientTransport{Endpoint: "http://localhost/"}, nil) + + session, err := client.Connect(ctx, &mcp.SSEClientTransport{Endpoint: "http://" + mcpAddr + "/mcp/events"}, nil) if err != nil { t.Fatalf("Failed to connect: %v", err) } @@ -75,7 +84,7 @@ func callMCP(t *testing.T, socketPath string, toolName string, params map[string t.Logf("failed to close session: %v", err) } }() - + result, err := session.CallTool(ctx, &mcp.CallToolParams{ Name: toolName, Arguments: params, @@ -83,7 +92,7 @@ func callMCP(t *testing.T, socketPath string, toolName string, params map[string if err != nil { t.Fatalf("CallTool failed: %v", err) } - + for _, content := range result.Content { if textContent, ok := content.(*mcp.TextContent); ok { return textContent.Text @@ -126,30 +135,18 @@ func TestCatalogRoutingAndFailover(t *testing.T) { nodeBin := buildBinary(t, "./cmd/sam-node") _, hubAddr := startMockLibp2pHub(t) - scratchDir := filepath.Join(repoRoot(t), "tests", "integration", "scratch") - _ = os.RemoveAll(scratchDir) // Clear old logs - - homeA := filepath.Join(scratchDir, "homeA") - homeB := filepath.Join(scratchDir, "homeB") - homeC := filepath.Join(scratchDir, "homeC") - - if err := os.MkdirAll(homeA, 0755); err != nil { - t.Fatalf("failed to create homeA: %v", err) - } - if err := os.MkdirAll(homeB, 0755); err != nil { - t.Fatalf("failed to create homeB: %v", err) - } - if err := os.MkdirAll(homeC, 0755); err != nil { - t.Fatalf("failed to create homeC: %v", err) - } - - socketPathA := filepath.Join(homeA, ".config", "sam-mesh", "mcp.sock") + homeA := t.TempDir() + homeB := t.TempDir() + homeC := t.TempDir() // Start Node A (Client) t.Log("Starting Node A...") _ = startBackgroundNode(t, nodeBin, hubAddr, homeA, "--listen", "/ip4/127.0.0.1/udp/0/quic-v1", "--listen", "/ip4/127.0.0.1/tcp/0", "--discovery-interval", "100ms") t.Log("Node A started.") + // Wait for Node A to start and get its MCP address + mcpAddrA := waitForMCPAddr(t, filepath.Join(homeA, "node.log")) + // Start Node B (Provider 1) t.Log("Starting Node B...") cmdB := startBackgroundNode(t, nodeBin, hubAddr, homeB, "--listen", "/ip4/127.0.0.1/udp/0/quic-v1", "--listen", "/ip4/127.0.0.1/tcp/0", "--discovery-interval", "100ms") @@ -163,35 +160,25 @@ func TestCatalogRoutingAndFailover(t *testing.T) { // Wait for Node B and C to start and get their addresses addrB := waitForPeerInfoInLog(t, filepath.Join(homeB, "node.log")) addrC := waitForPeerInfoInLog(t, filepath.Join(homeC, "node.log")) - + // Force Node A to connect to Node B and Node C - callMCP(t, socketPathA, "connect_peer", map[string]any{"peer_addr": addrB}) - callMCP(t, socketPathA, "connect_peer", map[string]any{"peer_addr": addrC}) + callMCP(t, mcpAddrA, "connect_peer", map[string]any{"peer_addr": addrB}) + callMCP(t, mcpAddrA, "connect_peer", map[string]any{"peer_addr": addrC}) // Wait for them to discover each other and publish catalog by polling get_mesh_info t.Log("Polling for discovery...") deadline := time.Now().Add(2 * time.Second) var connected bool - - oldTransport := http.DefaultClient.Transport - defer func() { http.DefaultClient.Transport = oldTransport }() - - http.DefaultClient.Transport = &http.Transport{ - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPathA) - }, - } - for time.Now().Before(deadline) { client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, nil) - session, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: "http://localhost/"}, nil) + session, err := client.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: "http://" + mcpAddrA + "/mcp/events"}, nil) if err != nil { t.Logf("Poll: failed to connect: %v", err) time.Sleep(500 * time.Millisecond) continue } - + result, err := session.CallTool(context.Background(), &mcp.CallToolParams{Name: "get_mesh_info", Arguments: map[string]any{}}) if closeErr := session.Close(); closeErr != nil { t.Logf("Poll: failed to close session: %v", closeErr) @@ -201,7 +188,7 @@ func TestCatalogRoutingAndFailover(t *testing.T) { time.Sleep(500 * time.Millisecond) continue } - + var text string for _, content := range result.Content { if textContent, ok := content.(*mcp.TextContent); ok { @@ -233,7 +220,7 @@ func TestCatalogRoutingAndFailover(t *testing.T) { t.Fatalf("failed to discover peers (Hub + 2 nodes) in time") } - respData := callMCP(t, socketPathA, "send_message", map[string]any{"peer_id": "target-peer", "message": "hello"}) + respData := callMCP(t, mcpAddrA, "send_message", map[string]any{"peer_id": "target-peer", "message": "hello"}) t.Logf("First call response: %s", respData) // Now kill Node B and assert failover to Node C @@ -244,6 +231,6 @@ func TestCatalogRoutingAndFailover(t *testing.T) { // Wait a bit for catalog update or failover to happen on next call time.Sleep(500 * time.Millisecond) - respData2 := callMCP(t, socketPathA, "send_message", map[string]any{"peer_id": "target-peer", "message": "hello"}) + respData2 := callMCP(t, mcpAddrA, "send_message", map[string]any{"peer_id": "target-peer", "message": "hello"}) t.Logf("Second call response: %s", respData2) } diff --git a/tests/integration/pubsub_test.go b/tests/integration/pubsub_test.go index 7c135f0..9c0f222 100644 --- a/tests/integration/pubsub_test.go +++ b/tests/integration/pubsub_test.go @@ -16,8 +16,6 @@ package integration_test import ( "context" - "net" - "net/http" "os" "os/exec" "path/filepath" @@ -43,15 +41,12 @@ func TestPubSubTools(t *testing.T) { } t.Logf("Node 2 logs at: %s/node2.log", tmpHome2) - socket1 := filepath.Join(tmpHome1, "mcp.sock") - socket2 := filepath.Join(tmpHome2, "mcp.sock") - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // Start Node 1 env1 := append(os.Environ(), "HOME="+tmpHome1, "XDG_CONFIG_HOME="+filepath.Join(tmpHome1, ".config")) - cmd1 := exec.CommandContext(ctx, nodeBin, "run", "--hub", hubAddr, "--mcp-socket", socket1, "--listen", "/ip4/127.0.0.1/udp/5003/quic-v1", "--listen", "/ip4/127.0.0.1/tcp/5004", "--jwt", "dummy-token", "--log-level", "debug", "--discovery-interval", "100ms") + cmd1 := exec.CommandContext(ctx, nodeBin, "run", "--hub", hubAddr, "--mcp-addr", "127.0.0.1:0", "--listen", "/ip4/127.0.0.1/udp/5003/quic-v1", "--listen", "/ip4/127.0.0.1/tcp/5004", "--jwt", "dummy-token", "--log-level", "debug", "--discovery-interval", "100ms") cmd1.Env = env1 logFile1, err := os.Create(filepath.Join(tmpHome1, "node1.log")) if err != nil { @@ -66,7 +61,7 @@ func TestPubSubTools(t *testing.T) { // Start Node 2 env2 := append(os.Environ(), "HOME="+tmpHome2, "XDG_CONFIG_HOME="+filepath.Join(tmpHome2, ".config")) - cmd2 := exec.CommandContext(ctx, nodeBin, "run", "--hub", hubAddr, "--mcp-socket", socket2, "--listen", "/ip4/127.0.0.1/udp/5005/quic-v1", "--listen", "/ip4/127.0.0.1/tcp/5006", "--jwt", "dummy-token", "--log-level", "debug", "--discovery-interval", "100ms") + cmd2 := exec.CommandContext(ctx, nodeBin, "run", "--hub", hubAddr, "--mcp-addr", "127.0.0.1:0", "--listen", "/ip4/127.0.0.1/udp/5005/quic-v1", "--listen", "/ip4/127.0.0.1/tcp/5006", "--jwt", "dummy-token", "--log-level", "debug", "--discovery-interval", "100ms") cmd2.Env = env2 logFile2, err := os.Create(filepath.Join(tmpHome2, "node2.log")) if err != nil { @@ -79,36 +74,38 @@ func TestPubSubTools(t *testing.T) { } defer func() { _ = cmd2.Process.Kill(); _ = logFile2.Close() }() - // Wait for sockets to appear - waitForSocket := func(socketPath string) { - for i := 0; i < 50; i++ { - if _, err := os.Stat(socketPath); err == nil { - return + // Helper to wait for MCP addr in log + waitForMCPAddr := func(t *testing.T, logPath string) string { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + data, _ := os.ReadFile(logPath) + lines := strings.Split(string(data), "\n") + for _, line := range lines { + if strings.Contains(line, "Starting MCP server on TCP address ") { + parts := strings.Split(line, "Starting MCP server on TCP address ") + if len(parts) > 1 { + return strings.TrimSpace(parts[1]) + } + } } time.Sleep(100 * time.Millisecond) } - t.Fatalf("Socket %s did not appear", socketPath) + t.Fatalf("timeout waiting for MCP addr in log: %s", logPath) + return "" } - waitForSocket(socket1) - waitForSocket(socket2) - // Helper to call MCP tool - callTool := func(socketPath string, toolName string, params map[string]any) string { - oldTransport := http.DefaultClient.Transport - defer func() { http.DefaultClient.Transport = oldTransport }() - - http.DefaultClient.Transport = &http.Transport{ - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPath) - }, - } + mcpAddr1 := waitForMCPAddr(t, filepath.Join(tmpHome1, "node1.log")) + mcpAddr2 := waitForMCPAddr(t, filepath.Join(tmpHome2, "node2.log")) + // Helper to call MCP tool + callTool := func(mcpAddr string, toolName string, params map[string]any) string { client := mcp.NewClient(&mcp.Implementation{ Name: "test-client", Version: "0.1.0", }, nil) - session, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: "http://localhost/"}, nil) + session, err := client.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: "http://" + mcpAddr + "/mcp/events"}, nil) if err != nil { t.Fatalf("Failed to connect: %v", err) } @@ -170,10 +167,10 @@ func TestPubSubTools(t *testing.T) { // Force Node 1 to connect to Node 2 addr2 := waitForPeerInfoInLog(t, filepath.Join(tmpHome2, "node2.log")) t.Logf("Node 2 address: %s", addr2) - callTool(socket1, "connect_peer", map[string]any{"peer_addr": addr2}) + callTool(mcpAddr1, "connect_peer", map[string]any{"peer_addr": addr2}) // Node 1 broadcasts on topic "test-topic" - broadcastResult := callTool(socket1, "mesh_pubsub_broadcast", map[string]any{ + broadcastResult := callTool(mcpAddr1, "mesh_pubsub_broadcast", map[string]any{ "topic": "test-topic", "payload": "hello from node 1", }) @@ -182,7 +179,7 @@ func TestPubSubTools(t *testing.T) { } // Node 2 subscribes to topic "test-topic" - subscribeResult := callTool(socket2, "subscribe_topic", map[string]any{ + subscribeResult := callTool(mcpAddr2, "subscribe_topic", map[string]any{ "topic": "test-topic", }) if !strings.Contains(subscribeResult, "Subscribed") { @@ -193,7 +190,7 @@ func TestPubSubTools(t *testing.T) { deadline := time.Now().Add(10 * time.Second) for time.Now().Before(deadline) { // Node 1 broadcasts on topic "test-topic" again to ensure delivery after subscription - broadcastResult = callTool(socket1, "mesh_pubsub_broadcast", map[string]any{ + broadcastResult = callTool(mcpAddr1, "mesh_pubsub_broadcast", map[string]any{ "topic": "test-topic", "payload": "hello from node 1", }) @@ -202,7 +199,7 @@ func TestPubSubTools(t *testing.T) { } // Node 2 polls for messages on topic "test-topic" - pollResult = callTool(socket2, "poll_messages", map[string]any{ + pollResult = callTool(mcpAddr2, "poll_messages", map[string]any{ "topic": "test-topic", }) if strings.Contains(pollResult, "hello from node 1") {