From fa40b4fd0fc7eb999164edcac579814a883bae56 Mon Sep 17 00:00:00 2001 From: Teodor Calin Date: Thu, 28 May 2026 15:53:44 -0700 Subject: [PATCH] feat: extract shared type packages from web4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migrates coreapi, protocol, driver, registry/{client,wire}, config, logging, urlvalidate, secure, and ipcutil out of TeoSlayer/pilotprotocol and into the canonical pilot-protocol/common module. These are the universally-shared types every sibling repo currently imports via github.com/TeoSlayer/pilotprotocol/pkg/. Bringing them into common is the foundation that breaks the locally-circular sibling↔web4 dependency: once siblings switch their imports (Phase 2) they no longer require the web4 hub for shared types. ipcutil was previously at web4/internal/ipcutil (40 LOC). Promoting it to common/ipcutil because Go internal-package visibility doesn't cross the module boundary, and several moved packages depend on it. Smoke-tested with go test across every new package — all pass. Bumps go.mod to 1.25.10 to match sibling repos. Subsequent steps (not in this PR): - Phase 2: 13 sibling repos switch shared-type imports to common - Phase 3: web4 cmd/* switches too, and the duplicated pkg/* in web4 is deleted. 
--- config/config.go | 59 + config/zz_config_test.go | 180 ++ coreapi/doc.go | 19 + coreapi/errors.go | 22 + coreapi/events.go | 31 + coreapi/identity.go | 15 + coreapi/lifecycle.go | 149 ++ coreapi/peers.go | 30 + coreapi/policy.go | 78 + coreapi/recover.go | 83 + coreapi/streams.go | 49 + coreapi/trust.go | 15 + coreapi/zz_lifecycle_edge_test.go | 75 + coreapi/zz_lifecycle_test.go | 114 ++ coreapi/zz_panic_recovery_test.go | 52 + coreapi/zz_recover_edge_test.go | 98 ++ coreapi/zz_recover_test.go | 130 ++ driver/conn.go | 150 ++ driver/driver.go | 495 ++++++ driver/ipc.go | 444 +++++ driver/listener.go | 79 + driver/zz_conn_test.go | 299 ++++ driver/zz_conn_write_test.go | 167 ++ driver/zz_driver_simple_ops_test.go | 134 ++ driver/zz_driver_test.go | 739 ++++++++ driver/zz_ipc_listener_test.go | 510 ++++++ go.mod | 2 +- go.sum | 0 ipcutil/ipcutil.go | 40 + ipcutil/zz_test.go | 137 ++ logging/logging.go | 44 + logging/zz_logging_test.go | 161 ++ protocol/address.go | 151 ++ protocol/checksum.go | 12 + protocol/header.go | 87 + protocol/packet.go | 158 ++ protocol/zz_fuzz_packet_test.go | 200 +++ protocol/zz_protocol_test.go | 442 +++++ registry/client/binary_client.go | 278 +++ registry/client/client.go | 1393 +++++++++++++++ registry/client/zz_binary_client_test.go | 550 ++++++ registry/client/zz_client_branch_test.go | 444 +++++ .../client/zz_client_join_signature_test.go | 707 ++++++++ .../client/zz_client_nil_receiver_test.go | 253 +++ registry/client/zz_client_pool_test.go | 861 ++++++++++ registry/client/zz_client_wire_test.go | 618 +++++++ registry/wire/blueprint.go | 178 ++ registry/wire/rules.go | 204 +++ registry/wire/wire.go | 595 +++++++ registry/wire/zz_blueprint_test.go | 254 +++ registry/wire/zz_decode_edge_test.go | 110 ++ registry/wire/zz_decode_truncation_test.go | 108 ++ registry/wire/zz_frame_test.go | 110 ++ registry/wire/zz_fuzz_wire_test.go | 216 +++ registry/wire/zz_message_framing_test.go | 151 ++ registry/wire/zz_rules_test.go | 334 ++++ 
registry/wire/zz_wire_test.go | 415 +++++ secure/client.go | 24 + secure/secure.go | 773 +++++++++ secure/server.go | 102 ++ secure/zz_extra_coverage_test.go | 1496 +++++++++++++++++ secure/zz_handshake_lookup_test.go | 284 ++++ secure/zz_secure_test.go | 586 +++++++ urlvalidate/validate.go | 68 + urlvalidate/zz_cloud_metadata_test.go | 65 + urlvalidate/zz_validate_edge_test.go | 60 + urlvalidate/zz_validate_test.go | 55 + 67 files changed, 16941 insertions(+), 1 deletion(-) create mode 100644 config/config.go create mode 100644 config/zz_config_test.go create mode 100644 coreapi/doc.go create mode 100644 coreapi/errors.go create mode 100644 coreapi/events.go create mode 100644 coreapi/identity.go create mode 100644 coreapi/lifecycle.go create mode 100644 coreapi/peers.go create mode 100644 coreapi/policy.go create mode 100644 coreapi/recover.go create mode 100644 coreapi/streams.go create mode 100644 coreapi/trust.go create mode 100644 coreapi/zz_lifecycle_edge_test.go create mode 100644 coreapi/zz_lifecycle_test.go create mode 100644 coreapi/zz_panic_recovery_test.go create mode 100644 coreapi/zz_recover_edge_test.go create mode 100644 coreapi/zz_recover_test.go create mode 100644 driver/conn.go create mode 100644 driver/driver.go create mode 100644 driver/ipc.go create mode 100644 driver/listener.go create mode 100644 driver/zz_conn_test.go create mode 100644 driver/zz_conn_write_test.go create mode 100644 driver/zz_driver_simple_ops_test.go create mode 100644 driver/zz_driver_test.go create mode 100644 driver/zz_ipc_listener_test.go create mode 100644 go.sum create mode 100644 ipcutil/ipcutil.go create mode 100644 ipcutil/zz_test.go create mode 100644 logging/logging.go create mode 100644 logging/zz_logging_test.go create mode 100644 protocol/address.go create mode 100644 protocol/checksum.go create mode 100644 protocol/header.go create mode 100644 protocol/packet.go create mode 100644 protocol/zz_fuzz_packet_test.go create mode 100644 
protocol/zz_protocol_test.go create mode 100644 registry/client/binary_client.go create mode 100644 registry/client/client.go create mode 100644 registry/client/zz_binary_client_test.go create mode 100644 registry/client/zz_client_branch_test.go create mode 100644 registry/client/zz_client_join_signature_test.go create mode 100644 registry/client/zz_client_nil_receiver_test.go create mode 100644 registry/client/zz_client_pool_test.go create mode 100644 registry/client/zz_client_wire_test.go create mode 100644 registry/wire/blueprint.go create mode 100644 registry/wire/rules.go create mode 100644 registry/wire/wire.go create mode 100644 registry/wire/zz_blueprint_test.go create mode 100644 registry/wire/zz_decode_edge_test.go create mode 100644 registry/wire/zz_decode_truncation_test.go create mode 100644 registry/wire/zz_frame_test.go create mode 100644 registry/wire/zz_fuzz_wire_test.go create mode 100644 registry/wire/zz_message_framing_test.go create mode 100644 registry/wire/zz_rules_test.go create mode 100644 registry/wire/zz_wire_test.go create mode 100644 secure/client.go create mode 100644 secure/secure.go create mode 100644 secure/server.go create mode 100644 secure/zz_extra_coverage_test.go create mode 100644 secure/zz_handshake_lookup_test.go create mode 100644 secure/zz_secure_test.go create mode 100644 urlvalidate/validate.go create mode 100644 urlvalidate/zz_cloud_metadata_test.go create mode 100644 urlvalidate/zz_validate_edge_test.go create mode 100644 urlvalidate/zz_validate_test.go diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..5a58118 --- /dev/null +++ b/config/config.go @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package config + +import ( + "encoding/json" + "flag" + "fmt" + "os" + "strings" +) + +// Load reads a JSON config file and returns it as a map. 
+func Load(path string) (map[string]interface{}, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var cfg map[string]interface{} + if err := json.NewDecoder(f).Decode(&cfg); err != nil { + return nil, err + } + return cfg, nil +} + +// ApplyToFlags overrides flag defaults from config for every flag that was +// not explicitly set on the command line. Call this AFTER flag.Parse(). +// Config keys may use hyphens or underscores ("log-level" and "log_level" both +// match -log-level). Only string/number/bool values apply; Set errors are ignored. +func ApplyToFlags(cfg map[string]interface{}) { + explicit := make(map[string]bool) + flag.Visit(func(f *flag.Flag) { explicit[f.Name] = true }) + + flag.VisitAll(func(f *flag.Flag) { + if explicit[f.Name] { + return + } + val, ok := cfg[f.Name] + if !ok { + // Try underscore variant: log-level → log_level + val, ok = cfg[strings.ReplaceAll(f.Name, "-", "_")] + } + if !ok { + return + } + switch v := val.(type) { + case string: + f.Value.Set(v) + case float64: + // JSON-encode, not %v: %v renders 10485760 as "1.048576e+07", which integer flags fail to parse. + num, _ := json.Marshal(v) + f.Value.Set(string(num)) + case bool: + f.Value.Set(fmt.Sprintf("%v", v)) + } + }) +} diff --git a/config/zz_config_test.go b/config/zz_config_test.go new file mode 100644 index 0000000..14762bd --- /dev/null +++ b/config/zz_config_test.go @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package config_test + +import ( + "flag" + "os" + "path/filepath" + "testing" + + "github.com/pilot-protocol/common/config" +) + +func TestLoadValidJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "cfg.json") + body := `{"log_level":"debug","port":8080,"verbose":true}` + if err := os.WriteFile(path, []byte(body), 0644); err != nil { + t.Fatal(err) + } + cfg, err := config.Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg["log_level"] != "debug" { + t.Errorf("log_level = %v, want debug", cfg["log_level"]) + } + if cfg["port"].(float64) != 8080 { + t.Errorf("port = %v, want 8080", cfg["port"]) + } + if 
cfg["verbose"] != true { + t.Errorf("verbose = %v, want true", cfg["verbose"]) + } +} + +func TestLoadMissingFile(t *testing.T) { + _, err := config.Load("/nonexistent/path/cfg.json") + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestLoadMalformedJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.json") + if err := os.WriteFile(path, []byte("{not json"), 0644); err != nil { + t.Fatal(err) + } + _, err := config.Load(path) + if err == nil { + t.Fatal("expected parse error") + } +} + +// ApplyToFlags tests must serialize because flag package has global state. +// We use a dedicated FlagSet per test, but package-level flag.Visit reads +// flag.CommandLine — so we temporarily swap it. +func withFreshCommandLine(t *testing.T) *flag.FlagSet { + t.Helper() + saved := flag.CommandLine + flag.CommandLine = flag.NewFlagSet("test", flag.ContinueOnError) + t.Cleanup(func() { flag.CommandLine = saved }) + return flag.CommandLine +} + +func TestApplyToFlagsSetsUnsetFlags(t *testing.T) { + fs := withFreshCommandLine(t) + var level string + var port int + var verbose bool + fs.StringVar(&level, "log-level", "info", "") + fs.IntVar(&port, "port", 9000, "") + fs.BoolVar(&verbose, "verbose", false, "") + + // Parse with no args so nothing is explicitly set + if err := fs.Parse(nil); err != nil { + t.Fatalf("Parse: %v", err) + } + + cfg := map[string]interface{}{ + "log-level": "debug", + "port": float64(8080), + "verbose": true, + } + config.ApplyToFlags(cfg) + + if level != "debug" { + t.Errorf("log-level = %q, want debug", level) + } + if port != 8080 { + t.Errorf("port = %d, want 8080", port) + } + if verbose != true { + t.Errorf("verbose = %v, want true", verbose) + } +} + +func TestApplyToFlagsPreservesExplicitlySetFlags(t *testing.T) { + fs := withFreshCommandLine(t) + var level string + fs.StringVar(&level, "log-level", "info", "") + + // Explicitly set on the command line — config must NOT override. 
+ if err := fs.Parse([]string{"-log-level=warn"}); err != nil { + t.Fatalf("Parse: %v", err) + } + + cfg := map[string]interface{}{"log-level": "debug"} + config.ApplyToFlags(cfg) + + if level != "warn" { + t.Errorf("log-level = %q, want warn (explicit flag must win over config)", level) + } +} + +func TestApplyToFlagsUnderscoreVariantMatches(t *testing.T) { + fs := withFreshCommandLine(t) + var level string + fs.StringVar(&level, "log-level", "info", "") + if err := fs.Parse(nil); err != nil { + t.Fatal(err) + } + + // Config uses underscore; flag uses hyphen. ApplyToFlags should match them. + cfg := map[string]interface{}{"log_level": "debug"} + config.ApplyToFlags(cfg) + + if level != "debug" { + t.Errorf("log-level = %q, want debug (underscore→hyphen match)", level) + } +} + +func TestApplyToFlagsHyphenVariantTakesPrecedenceOverUnderscore(t *testing.T) { + fs := withFreshCommandLine(t) + var level string + fs.StringVar(&level, "log-level", "info", "") + if err := fs.Parse(nil); err != nil { + t.Fatal(err) + } + + // If both keys present, the exact flag-name match (log-level) must win. 
+ cfg := map[string]interface{}{ + "log-level": "debug", + "log_level": "warn", + } + config.ApplyToFlags(cfg) + + if level != "debug" { + t.Errorf("log-level = %q, want debug (exact match wins)", level) + } +} + +func TestApplyToFlagsIgnoresUnknownKeys(t *testing.T) { + fs := withFreshCommandLine(t) + var level string + fs.StringVar(&level, "log-level", "info", "") + if err := fs.Parse(nil); err != nil { + t.Fatal(err) + } + config.ApplyToFlags(map[string]interface{}{"unrelated-flag": "xyz"}) + if level != "info" { + t.Errorf("log-level changed unexpectedly: %q", level) + } +} + +func TestApplyToFlagsSkipsUnsupportedTypes(t *testing.T) { + fs := withFreshCommandLine(t) + var level string + fs.StringVar(&level, "log-level", "info", "") + if err := fs.Parse(nil); err != nil { + t.Fatal(err) + } + // Nested map / array — should be silently skipped (not panic) + config.ApplyToFlags(map[string]interface{}{ + "log-level": map[string]interface{}{"nested": "value"}, + }) + if level != "info" { + t.Errorf("log-level changed from nested map: %q (unsupported type should skip)", level) + } +} diff --git a/coreapi/doc.go b/coreapi/doc.go new file mode 100644 index 0000000..6524556 --- /dev/null +++ b/coreapi/doc.go @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +// Package coreapi defines the L10 plugin runtime contract. +// +// The interfaces in this package are the only surface a plugin +// (L11) ever sees of the daemon. Plugins import coreapi; the daemon +// implements coreapi; the bridge happens at lifecycle bootstrap +// (cmd/daemon/main.go registers concrete plugins against the +// daemon's coreapi implementations). +// +// See docs/architecture/01-LAYERS.md §10 for the layer's role, +// docs/architecture/03-INVARIANTS.md for the principles this +// package enforces, and docs/architecture/06-CHANGES.md §2 for +// the rationale of each interface signature. 
+// +// Stability contract: every exported identifier in this package is +// part of the daemon-plugin ABI. Removing or renaming any of them +// breaks every plugin. Additions are forward-compatible. +package coreapi diff --git a/coreapi/errors.go b/coreapi/errors.go new file mode 100644 index 0000000..214313a --- /dev/null +++ b/coreapi/errors.go @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import "errors" + +// Sentinel errors returned by the L10 surface. +var ( + // ErrRegistryStarted is returned by ServiceRegistry.Register and + // ServiceRegistry.StartAll when StartAll has already been called. + // Plugins must register before bootstrap. + ErrRegistryStarted = errors.New("coreapi: service registry already started") + + // ErrServiceNotReady indicates a Service.Start call was made on a + // dependency that itself hasn't completed Start. Surface only — + // Service implementations shouldn't return this; the registry will. + ErrServiceNotReady = errors.New("coreapi: dependency service not ready") + + // ErrPeerNotFound is the canonical "directory has no record" error + // from PeerResolver. Plugins should match on errors.Is. + ErrPeerNotFound = errors.New("coreapi: peer not found") +) diff --git a/coreapi/events.go b/coreapi/events.go new file mode 100644 index 0000000..495bcf6 --- /dev/null +++ b/coreapi/events.go @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import "time" + +// Event is one item published to the EventBus. Topics are +// dot-namespaced (e.g., "tunnel.established", "security.nonce_replay"). +// Payload keys/values are plugin-defined; subscribers parse them. +type Event struct { + Topic string + NodeID uint32 + Time time.Time + Payload map[string]any +} + +// EventBus is the publish/subscribe channel that replaces inline +// webhook.Emit calls inside core layers. Core (L2-L7) publishes; +// the webhook plugin (and any other observability plugin) subscribes. 
+// +// Publish is non-blocking. If the bus is over capacity, the event is +// dropped (and a metric counter is incremented inside the daemon +// implementation). This keeps L2 readLoop / L6 decrypt latency bounded. +// +// Subscribe returns a buffered channel and an unsubscribe func. Pattern +// is a glob: "tunnel.*" matches "tunnel.established" but not +// "security.nonce_replay". +type EventBus interface { + Publish(topic string, payload map[string]any) + Subscribe(pattern string) (<-chan Event, func()) +} diff --git a/coreapi/identity.go b/coreapi/identity.go new file mode 100644 index 0000000..741a3a3 --- /dev/null +++ b/coreapi/identity.go @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import "crypto/ed25519" + +// Identity is the daemon's own identity — its Ed25519 keypair, its +// stable nodeID, its 48-bit address. Plugins may sign arbitrary bytes +// (e.g., for plugin-level auth proofs) but cannot replace the identity. +type Identity interface { + NodeID() uint32 + Address() Addr + PublicKey() ed25519.PublicKey + Sign(msg []byte) ([]byte, error) +} diff --git a/coreapi/lifecycle.go b/coreapi/lifecycle.go new file mode 100644 index 0000000..5482acc --- /dev/null +++ b/coreapi/lifecycle.go @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import ( + "context" + "fmt" + "log/slog" + "sort" + "sync" +) + +// Service is the lifecycle contract every L11 plugin implements. +// +// Order determines the start sequence. Lower numbers start first; +// higher numbers stop first. Suggested ranges: +// +// 10-49 Foundation (none today) +// 50-79 Trust / identity-adjacent (trustedagents) +// 80-99 Observability (webhook) +// 100-199 Application services (dataexchange, eventstream, tasks) +// 200-249 Sidecars (skillinject) +// 250+ Tooling-bound (updater) +// +// Start receives Deps (the L10 surface). 
Implementations must NOT +// retain references to anything outside Deps — that's the whole +// extraction contract. +// +// Stop should drain in-flight work, close listeners, and signal +// background goroutines to exit. It must return within 5 seconds +// or the daemon shutdown gate will fail. +type Service interface { + Name() string + Order() int + Start(ctx context.Context, deps Deps) error + Stop(ctx context.Context) error +} + +// Deps is the bag of capabilities a plugin can use. Optional fields +// may be nil if the corresponding plugin isn't loaded; plugins that +// hard-depend on them should error in Start(). +type Deps struct { + Streams Streams + Identity Identity + Resolver PeerResolver + Events EventBus + Logger *slog.Logger + + // Optional — nil if the plugin providing them isn't registered. + Trust TrustChecker +} + +// ServiceRegistry coordinates plugin lifecycle. cmd/daemon/main.go +// constructs one, registers each plugin, and hands it to the daemon. +// The daemon calls StartAll during bootstrap and StopAll during +// shutdown. +type ServiceRegistry struct { + mu sync.Mutex + services []Service + started []Service // start order, used to stop in reverse +} + +// Register adds a service. Must be called before StartAll. Once +// StartAll has started a service, Register returns ErrRegistryStarted. +func (sr *ServiceRegistry) Register(s Service) error { + sr.mu.Lock() + defer sr.mu.Unlock() + if len(sr.started) > 0 { + return ErrRegistryStarted + } + sr.services = append(sr.services, s) + return nil +} + +// StartAll sorts by Order and starts every service in sequence. +// The first failing Start aborts and returns its error; previously- +// started services are NOT auto-stopped (the caller's job, via StopAll() +// or by passing a context that cancels). 
+func (sr *ServiceRegistry) StartAll(ctx context.Context, deps Deps) error { + sr.mu.Lock() + if len(sr.started) > 0 { + sr.mu.Unlock() + return ErrRegistryStarted + } + sort.SliceStable(sr.services, func(i, j int) bool { + return sr.services[i].Order() < sr.services[j].Order() + }) + queue := append([]Service(nil), sr.services...) + sr.mu.Unlock() + + for _, s := range queue { + if err := startWithPanicRecovery(ctx, s, deps); err != nil { + return err + } + sr.mu.Lock() + sr.started = append(sr.started, s) + sr.mu.Unlock() + } + return nil +} + +// startWithPanicRecovery calls s.Start(ctx, deps) inside a defer +// recover() so a buggy plugin panicking during initialization (nil +// deref, index OOB, send on a closed channel, etc.) surfaces as a normal +// Start error rather than crashing the entire daemon process. +// +// Without this wrapper, every plugin's Init bug becomes a single- +// point-of-failure for the host: the whole daemon dies, every OTHER +// plugin goes offline with it, and the operator's only signal is a +// stack trace. +// +// Behaviour preserved on normal error returns: the surrounding +// StartAll loop still aborts on first failure and leaves earlier +// services running for the caller's StopAll() to drain. +func startWithPanicRecovery(ctx context.Context, s Service, deps Deps) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("plugin %q Start panicked: %v", s.Name(), r) + } + }() + return s.Start(ctx, deps) +} + +// StopAll stops every started service in reverse order. Only the +// first error from an individual Stop call is returned, but every +// service still gets its Stop call invoked. +func (sr *ServiceRegistry) StopAll(ctx context.Context) error { + sr.mu.Lock() + queue := append([]Service(nil), sr.started...) 
+ sr.started = nil + sr.mu.Unlock() + + var firstErr error + for i := len(queue) - 1; i >= 0; i-- { + if err := queue[i].Stop(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// All returns a snapshot of the registered services in start order. +func (sr *ServiceRegistry) All() []Service { + sr.mu.Lock() + defer sr.mu.Unlock() + out := make([]Service, len(sr.services)) + copy(out, sr.services) + return out +} diff --git a/coreapi/peers.go b/coreapi/peers.go new file mode 100644 index 0000000..0017a9b --- /dev/null +++ b/coreapi/peers.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import ( + "context" + "crypto/ed25519" + "net" +) + +// PeerInfo is the directory record for a remote node. Returned by +// PeerResolver.Resolve and PeerResolver.ListByNetwork. +type PeerInfo struct { + NodeID uint32 + Addr Addr + Endpoint *net.UDPAddr // best-known reachable endpoint, or nil + PubKey ed25519.PublicKey + Public bool + Hostname string + RelayOnly bool +} + +// PeerResolver is the L8 directory surface. The daemon's +// implementation talks to the registry over the bootstrap TCP +// side-channel (see 01-LAYERS §L8). +type PeerResolver interface { + Resolve(ctx context.Context, nodeID uint32) (PeerInfo, error) + ResolveHostname(ctx context.Context, name string) (uint32, error) + ListByNetwork(ctx context.Context, networkID uint32) ([]PeerInfo, error) +} diff --git a/coreapi/policy.go b/coreapi/policy.go new file mode 100644 index 0000000..9d40980 --- /dev/null +++ b/coreapi/policy.go @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +// PolicyEventType is the kind of protocol event a policy is evaluated +// against. Type alias to string so daemon-local primitive interfaces +// can satisfy plugin signatures via structural typing without importing +// this package (T7.1). 
+type PolicyEventType = string + +const ( + PolicyEventConnect = "connect" + PolicyEventDial = "dial" + PolicyEventDatagram = "datagram" + PolicyEventJoin = "join" + PolicyEventLeave = "leave" + PolicyEventCycle = "cycle" +) + +// PolicyRunner is the daemon-facing surface of a single network's +// running policy. The plugin's concrete *PolicyRunner type implements +// this. The daemon never holds the concrete type — only this interface. +type PolicyRunner interface { + NetworkID() uint16 + + // HasMember returns true if peerNodeID is in this runner's + // per-peer state. The daemon iterates all runners to consult + // every network the peer belongs to (deny wins across networks). + HasMember(peerNodeID uint32) bool + + // EvaluatePortGate is the daemon-facing gate API for inbound SYN + // (Connect), outbound SYN (Dial), and datagram (in/out) events. + // The plugin builds the per-peer ctx internally (peer_age_s, + // peer_tags, members) using its peer state and the + // daemon-supplied localTags + nodeInfoTags. Returns the + // allow/deny verdict (default allow on no explicit deny). + EvaluatePortGate(eventType PolicyEventType, port uint16, peerNodeID uint32, payloadSize int, direction string, localTags, nodeInfoTags []string) bool + + // EvaluateActions runs an action-event (cycle/join/leave) with a + // caller-built ctx. Side-effect-only: no return value. + EvaluateActions(eventType PolicyEventType, ctx map[string]any) + + Status() map[string]any + PeerList() []map[string]any + ForceCycle() map[string]any + ReconcileNow() + + // PolicyJSON returns the marshaled policy document. Used by IPC + // handlers that read the current policy back to admin tools. + PolicyJSON() ([]byte, error) + + Stop() +} + +// PolicyManager owns the per-network registry of policy runners. The +// daemon holds it as an interface field; cmd/daemon (L12) constructs +// the concrete plugin and calls Daemon.RegisterPolicyManager. 
+type PolicyManager interface { + // Start compiles a policy JSON for the given network and registers + // a runner. Returns the runner handle; existing runners for the + // same network are stopped first. + Start(netID uint16, policyJSON []byte) (PolicyRunner, error) + + // Stop stops the runner for netID (no-op if absent). + Stop(netID uint16) + + // Get returns the runner for netID or nil. + Get(netID uint16) PolicyRunner + + // All returns a snapshot of all running runners. + All() []PolicyRunner + + // StopAll stops every runner. Called during daemon shutdown. + StopAll() + + // LoadPersisted runs at daemon-Start to restore runners from disk. + LoadPersisted() error +} diff --git a/coreapi/recover.go b/coreapi/recover.go new file mode 100644 index 0000000..7c771d4 --- /dev/null +++ b/coreapi/recover.go @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import ( + "fmt" + "log/slog" + "runtime/debug" + "sync/atomic" +) + +// pluginRecoveredPanicCount is the L11 counterpart to the daemon's +// internal recoveredPanicCount. Tracks how many panics have been +// caught at plugin entry points (acceptLoop, handleConn, Service.Start +// goroutines). Exposed via PluginRecoveredPanicCount. +var pluginRecoveredPanicCount atomic.Uint64 + +// PluginRecoveredPanicCount returns the total number of panics +// swallowed by RecoverPlugin since process start. +func PluginRecoveredPanicCount() uint64 { + return pluginRecoveredPanicCount.Load() +} + +// ResetPluginRecoveredPanicCountForTest is test-only. +func ResetPluginRecoveredPanicCountForTest() { + pluginRecoveredPanicCount.Store(0) +} + +// RecoverPlugin is the L11 panic-recovery shim used at the top of +// every plugin entrypoint goroutine: Service.Start helper goroutines, +// acceptLoop, and per-connection handlers. Usage: +// +// defer coreapi.RecoverPlugin("eventstream", "acceptLoop", events, nil) +// +// On panic it: +// 1. 
Recovers (caller goroutine continues / loop iteration is dropped) +// 2. Logs at ERROR with structured plugin/op fields, panic value, and +// full goroutine stack trace +// 3. Increments PluginRecoveredPanicCount +// 4. Publishes a "plugin.<plugin>.panic" event on the bus (if +// events != nil) so observability subscribers see the recovery +// 5. Calls onPanic(r) if non-nil — typical use is per-conn close, +// or signaling a future per-plugin supervisor for restart +// +// TODO(03-INVARIANTS.md §8): per-plugin supervisor not yet implemented. +// Today the boundary just survives + logs. A future tier will signal a +// restart of the panicked plugin via the onPanic callback. +// +// This must be the OUTERMOST defer in the goroutine: defers run LIFO, +// so other defers (conn.Close, mu.Unlock, removeSub) run first. +func RecoverPlugin(plugin, op string, events EventBus, onPanic func(any)) { + r := recover() + if r == nil { + return + } + count := pluginRecoveredPanicCount.Add(1) + slog.Error("plugin panic recovered", + "layer", "L11", + "plugin", plugin, + "op", op, + "panic", r, + "recovered_total", count, + "stack", string(debug.Stack()), + ) + if events != nil { + // Defensive: a publisher that itself panics must not propagate. + func() { + defer func() { _ = recover() }() + events.Publish("plugin."+plugin+".panic", map[string]any{ + "plugin": plugin, + "op": op, + "panic": fmt.Sprintf("%v", r), + "recovered_total": count, + }) + }() + } + if onPanic != nil { + func() { + defer func() { _ = recover() }() + onPanic(r) + }() + } +} diff --git a/coreapi/streams.go b/coreapi/streams.go new file mode 100644 index 0000000..9f57ab0 --- /dev/null +++ b/coreapi/streams.go @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import ( + "context" + "io" + + "github.com/pilot-protocol/common/protocol" +) + +// Addr is the 48-bit virtual address used throughout the protocol. +// Re-exported here so plugins can stay free of the protocol package if they want. 
+type Addr = protocol.Addr + +// Stream is one bidirectional ordered byte stream between two +// (Addr, port) endpoints. It satisfies io.ReadWriteCloser with +// Pilot Protocol addressing extensions. Deadline methods are +// intentionally excluded — the runtime currently cannot honor +// them, and removing them from the interface forces callers to +// get a compile-time signal rather than a silent no-op. +type Stream interface { + io.ReadWriteCloser + + LocalAddr() Addr + LocalPort() uint16 + RemoteAddr() Addr + RemotePort() uint16 +} + +// Listener accepts inbound streams on a single well-known or ephemeral +// port. Returned by Streams.Listen. +type Listener interface { + Accept() (Stream, error) + Close() error + Addr() Addr + Port() uint16 +} + +// Streams is the L7 surface plugins consume. The daemon-side +// implementation routes through L7 → L6 → L5 → L4 → L2. +// +// SendDatagram is the connectionless variant (one packet, no ACK, +// no retransmit). Used by plugins that don't need stream semantics. +type Streams interface { + Dial(ctx context.Context, dst Addr, port uint16) (Stream, error) + Listen(port uint16) (Listener, error) + SendDatagram(ctx context.Context, dst Addr, port uint16, data []byte) error +} diff --git a/coreapi/trust.go b/coreapi/trust.go new file mode 100644 index 0000000..c96f832 --- /dev/null +++ b/coreapi/trust.go @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +// TrustChecker is the trusted-agents gate consumed by L11/tasks (and +// any other plugin that gates on peer reputation). +// +// IsTrusted: returns true if the peer is on the auto-approve allowlist +// (loaded from the trusted-agents JSON, refreshed hourly). +type TrustChecker interface { + // IsTrusted reports whether the peer is on the auto-approve allowlist. + // Returns the agent's display name when known. Both return values are + // zero on miss. 
+ IsTrusted(nodeID uint32) (name string, ok bool) +} diff --git a/coreapi/zz_lifecycle_edge_test.go b/coreapi/zz_lifecycle_edge_test.go new file mode 100644 index 0000000..9612073 --- /dev/null +++ b/coreapi/zz_lifecycle_edge_test.go @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi_test + +import ( + "context" + "errors" + "testing" + + "github.com/pilot-protocol/common/coreapi" +) + +func TestServiceRegistry_StartAllTwiceReturnsErrRegistryStarted(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + _ = sr.Register(&fakeService{name: "a", order: 1}) + if err := sr.StartAll(context.Background(), coreapi.Deps{}); err != nil { + t.Fatalf("first StartAll: %v", err) + } + err := sr.StartAll(context.Background(), coreapi.Deps{}) + if !errors.Is(err, coreapi.ErrRegistryStarted) { + t.Errorf("second StartAll = %v, want ErrRegistryStarted", err) + } +} + +func TestServiceRegistry_StopAllSurfacesFirstError(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + a := &fakeService{name: "a", order: 1, stopErr: errors.New("stop-a-failed")} + b := &fakeService{name: "b", order: 2, stopErr: errors.New("stop-b-failed")} + _ = sr.Register(a) + _ = sr.Register(b) + if err := sr.StartAll(context.Background(), coreapi.Deps{}); err != nil { + t.Fatalf("StartAll: %v", err) + } + // b stops first (reverse order), so its error is "first" returned. 
+ err := sr.StopAll(context.Background()) + if err == nil || err.Error() != "stop-b-failed" { + t.Errorf("StopAll = %v, want stop-b-failed", err) + } +} + +func TestServiceRegistry_StopAllStopsAllEvenAfterError(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + aStopped := false + bStopped := false + a := &recordingStopWithErr{name: "a", order: 1, stopped: &aStopped} + b := &recordingStopWithErr{name: "b", order: 2, stopped: &bStopped, err: errors.New("b-failed")} + _ = sr.Register(a) + _ = sr.Register(b) + _ = sr.StartAll(context.Background(), coreapi.Deps{}) + _ = sr.StopAll(context.Background()) + if !aStopped { + t.Error("service a was not stopped despite b's error") + } + if !bStopped { + t.Error("service b was not stopped") + } +} + +type recordingStopWithErr struct { + name string + order int + stopped *bool + err error +} + +func (r *recordingStopWithErr) Name() string { return r.name } +func (r *recordingStopWithErr) Order() int { return r.order } +func (r *recordingStopWithErr) Start(ctx context.Context, deps coreapi.Deps) error { return nil } +func (r *recordingStopWithErr) Stop(ctx context.Context) error { + *r.stopped = true + return r.err +} diff --git a/coreapi/zz_lifecycle_test.go b/coreapi/zz_lifecycle_test.go new file mode 100644 index 0000000..b335493 --- /dev/null +++ b/coreapi/zz_lifecycle_test.go @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/pilot-protocol/common/coreapi" +) + +type fakeService struct { + name string + order int + startErr error + stopErr error + startedAt int // sequence number, set by harness +} + +func (f *fakeService) Name() string { return f.name } +func (f *fakeService) Order() int { return f.order } +func (f *fakeService) Start(ctx context.Context, deps coreapi.Deps) error { + return f.startErr +} +func (f *fakeService) Stop(ctx context.Context) error { return f.stopErr } + +func 
TestServiceRegistry_StartOrder(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + a := &fakeService{name: "a", order: 200} + b := &fakeService{name: "b", order: 100} + c := &fakeService{name: "c", order: 50} + for _, s := range []coreapi.Service{a, b, c} { + if err := sr.Register(s); err != nil { + t.Fatalf("register %s: %v", s.Name(), err) + } + } + if err := sr.StartAll(context.Background(), coreapi.Deps{}); err != nil { + t.Fatalf("StartAll: %v", err) + } + got := sr.All() + want := []string{"c", "b", "a"} + for i, s := range got { + if s.Name() != want[i] { + t.Errorf("position %d: got %s, want %s", i, s.Name(), want[i]) + } + } +} + +func TestServiceRegistry_StartFailureAborts(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + a := &fakeService{name: "a", order: 10} + boom := &fakeService{name: "boom", order: 20, startErr: errors.New("boom")} + c := &fakeService{name: "c", order: 30} + for _, s := range []coreapi.Service{a, boom, c} { + _ = sr.Register(s) + } + err := sr.StartAll(context.Background(), coreapi.Deps{}) + if err == nil || err.Error() != "boom" { + t.Fatalf("want boom, got %v", err) + } + // `c` should NOT have been started after boom failed; verify by + // calling StopAll and checking only the started ones rolled back. + // (We can't directly observe started state, but the registry should + // not crash and Stop should return nil on the un-started services.) 
+ if err := sr.StopAll(context.Background()); err != nil { + t.Errorf("StopAll after partial start: %v", err) + } +} + +func TestServiceRegistry_StopReverseOrder(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + stops := []string{} + a := &recordingStop{name: "a", order: 10, stops: &stops} + b := &recordingStop{name: "b", order: 20, stops: &stops} + c := &recordingStop{name: "c", order: 30, stops: &stops} + for _, s := range []coreapi.Service{a, b, c} { + _ = sr.Register(s) + } + _ = sr.StartAll(context.Background(), coreapi.Deps{}) + _ = sr.StopAll(context.Background()) + want := []string{"c", "b", "a"} + if fmt.Sprint(stops) != fmt.Sprint(want) { + t.Errorf("stop order: got %v, want %v", stops, want) + } +} + +func TestServiceRegistry_RegisterAfterStart(t *testing.T) { + t.Parallel() + sr := &coreapi.ServiceRegistry{} + _ = sr.Register(&fakeService{name: "a", order: 10}) + _ = sr.StartAll(context.Background(), coreapi.Deps{}) + if err := sr.Register(&fakeService{name: "late", order: 50}); !errors.Is(err, coreapi.ErrRegistryStarted) { + t.Errorf("want ErrRegistryStarted, got %v", err) + } +} + +type recordingStop struct { + name string + order int + stops *[]string +} + +func (r *recordingStop) Name() string { return r.name } +func (r *recordingStop) Order() int { return r.order } +func (r *recordingStop) Start(ctx context.Context, deps coreapi.Deps) error { return nil } +func (r *recordingStop) Stop(ctx context.Context) error { + *r.stops = append(*r.stops, r.name) + return nil +} diff --git a/coreapi/zz_panic_recovery_test.go b/coreapi/zz_panic_recovery_test.go new file mode 100644 index 0000000..3918548 --- /dev/null +++ b/coreapi/zz_panic_recovery_test.go @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi_test + +// Regression for P1 plugin-crash DoS: StartAll invokes each plugin's +// Start() directly with no recover wrapper. A plugin that panics +// during Start (nil-deref, index OOB, etc.) 
crashes the entire daemon +// process — operator's recourse is to find the buggy plugin via +// stack trace and disable it, while every other plugin is offline. +// +// Fix: StartAll wraps each plugin Start() in defer recover(), converts +// the panic to an error like any other Start failure. The error path +// (return on first failure, previously-started plugins NOT auto- +// stopped) is preserved — the caller's Stop() handles cleanup. + +import ( + "context" + "strings" + "testing" + + "github.com/pilot-protocol/common/coreapi" +) + +// panickingService panics during Start with the given message. +type panickingService struct{ msg string } + +func (p *panickingService) Name() string { return "panicker" } +func (p *panickingService) Order() int { return 100 } +func (p *panickingService) Start(_ context.Context, _ coreapi.Deps) error { panic(p.msg) } +func (p *panickingService) Stop(_ context.Context) error { return nil } + +func TestServiceRegistry_StartAllRecoversFromPluginPanic(t *testing.T) { + t.Parallel() + + sr := &coreapi.ServiceRegistry{} + if err := sr.Register(&panickingService{msg: "boom from a buggy plugin"}); err != nil { + t.Fatalf("Register: %v", err) + } + + // Without the recover wrapper, this CRASHES the test process. 
+ err := sr.StartAll(context.Background(), coreapi.Deps{}) + + if err == nil { + t.Fatal("StartAll returned nil for panicking plugin — recover wrapper missing") + } + if !strings.Contains(err.Error(), "panic") { + t.Errorf("expected error to mention 'panic'; got %q", err.Error()) + } + if !strings.Contains(err.Error(), "boom from a buggy plugin") { + t.Errorf("expected error to include the panic message; got %q", err.Error()) + } +} diff --git a/coreapi/zz_recover_edge_test.go b/coreapi/zz_recover_edge_test.go new file mode 100644 index 0000000..a62af97 --- /dev/null +++ b/coreapi/zz_recover_edge_test.go @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi_test + +import ( + "testing" + + "github.com/pilot-protocol/common/coreapi" +) + +func TestPluginRecoveredPanicCountAndReset(t *testing.T) { + // Not parallel — touches a package-level counter. + coreapi.ResetPluginRecoveredPanicCountForTest() + if got := coreapi.PluginRecoveredPanicCount(); got != 0 { + t.Fatalf("after reset = %d, want 0", got) + } + + // Induce a panic and let RecoverPlugin swallow it. + func() { + defer coreapi.RecoverPlugin("test-plugin", "test-op", nil, nil) + panic("synthetic") + }() + if got := coreapi.PluginRecoveredPanicCount(); got != 1 { + t.Errorf("after one panic = %d, want 1", got) + } + + // Another with onPanic callback exercised. + called := false + func() { + defer coreapi.RecoverPlugin("p2", "op", nil, func(_ any) { called = true }) + panic("two") + }() + if !called { + t.Errorf("onPanic callback not invoked") + } + if got := coreapi.PluginRecoveredPanicCount(); got != 2 { + t.Errorf("after two panics = %d, want 2", got) + } + + // Reset works after non-zero count. + coreapi.ResetPluginRecoveredPanicCountForTest() + if got := coreapi.PluginRecoveredPanicCount(); got != 0 { + t.Errorf("second reset = %d, want 0", got) + } +} + +func TestRecoverPlugin_NoPanicIsNoOp(t *testing.T) { + t.Parallel() + // The early-return path when recover() returns nil. 
No counter bump. + before := coreapi.PluginRecoveredPanicCount() + func() { + defer coreapi.RecoverPlugin("clean", "op", nil, nil) + }() + if got := coreapi.PluginRecoveredPanicCount(); got != before { + t.Errorf("counter changed without a panic: %d → %d", before, got) + } +} + +// fakeBusPanics publishes that itself panics — RecoverPlugin must +// shield itself from a nested publisher panic. +type fakeBusPanics struct{} + +func (fakeBusPanics) Publish(string, map[string]any) { panic("nested-publish-panic") } +func (fakeBusPanics) Subscribe(string) (<-chan coreapi.Event, func()) { return nil, func() {} } + +func TestRecoverPlugin_NestedPublishPanicSwallowed(t *testing.T) { + // Not parallel — touches counter. + coreapi.ResetPluginRecoveredPanicCountForTest() + defer func() { + if r := recover(); r != nil { + t.Fatalf("nested publish panic escaped: %v", r) + } + }() + func() { + defer coreapi.RecoverPlugin("p", "op", fakeBusPanics{}, nil) + panic("trigger") + }() + if got := coreapi.PluginRecoveredPanicCount(); got != 1 { + t.Errorf("counter = %d, want 1", got) + } +} + +func TestRecoverPlugin_NestedOnPanicSwallowed(t *testing.T) { + // Not parallel — touches counter. + coreapi.ResetPluginRecoveredPanicCountForTest() + defer func() { + if r := recover(); r != nil { + t.Fatalf("nested onPanic panic escaped: %v", r) + } + }() + func() { + defer coreapi.RecoverPlugin("p", "op", nil, func(_ any) { panic("nested-cb-panic") }) + panic("trigger") + }() + if got := coreapi.PluginRecoveredPanicCount(); got != 1 { + t.Errorf("counter = %d, want 1", got) + } +} diff --git a/coreapi/zz_recover_test.go b/coreapi/zz_recover_test.go new file mode 100644 index 0000000..b8a40ad --- /dev/null +++ b/coreapi/zz_recover_test.go @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package coreapi + +import ( + "sync" + "testing" +) + +// fakeBus implements EventBus for the panic-survival test. 
Records +// every published topic so the test can assert the boundary emitted +// the expected event. +type fakeBus struct { + mu sync.Mutex + topics []string +} + +func (b *fakeBus) Publish(topic string, _ map[string]any) { + b.mu.Lock() + defer b.mu.Unlock() + b.topics = append(b.topics, topic) +} + +func (b *fakeBus) Subscribe(_ string) (<-chan Event, func()) { + ch := make(chan Event) + return ch, func() {} +} + +func (b *fakeBus) latest() []string { + b.mu.Lock() + defer b.mu.Unlock() + out := make([]string, len(b.topics)) + copy(out, b.topics) + return out +} + +// TestL11PluginPanicSurvival exercises the L11 boundary +// (RecoverPlugin) by inducing a panic in a goroutine guarded by it. +// Verifies: +// 1. process survives +// 2. PluginRecoveredPanicCount increments +// 3. a "plugin..panic" event lands on the bus +// 4. the onPanic callback fires with the panic value +func TestL11PluginPanicSurvival(t *testing.T) { + t.Parallel() + before := PluginRecoveredPanicCount() + bus := &fakeBus{} + + var ( + gotPanicValue any + callbackCount int + mu sync.Mutex + ) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer RecoverPlugin("testplugin", "acceptLoop", bus, func(r any) { + mu.Lock() + defer mu.Unlock() + gotPanicValue = r + callbackCount++ + }) + panic("L11 boundary test panic") + }() + wg.Wait() + + if PluginRecoveredPanicCount() <= before { + t.Fatal("L11 boundary did not record the panic") + } + + mu.Lock() + defer mu.Unlock() + if callbackCount != 1 { + t.Fatalf("onPanic callback fired %d times, want 1", callbackCount) + } + if s, ok := gotPanicValue.(string); !ok || s != "L11 boundary test panic" { + t.Fatalf("onPanic got %v (%T), want string 'L11 boundary test panic'", gotPanicValue, gotPanicValue) + } + + // Bus event should be "plugin.testplugin.panic". 
+ found := false + for _, topic := range bus.latest() { + if topic == "plugin.testplugin.panic" { + found = true + break + } + } + if !found { + t.Fatalf("plugin.testplugin.panic event not on bus: got %v", bus.latest()) + } +} + +// TestL11PluginPanicNilBus confirms the boundary is nil-safe when no +// bus is provided (e.g., the standalone nameserver binary). +func TestL11PluginPanicNilBus(t *testing.T) { + t.Parallel() + before := PluginRecoveredPanicCount() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer RecoverPlugin("nullbus", "op", nil, nil) + panic("nil-bus panic") + }() + wg.Wait() + if PluginRecoveredPanicCount() <= before { + t.Fatal("L11 boundary did not record nil-bus panic") + } +} + +// TestL11PluginPanicCallbackPanicSwallowed checks the defensive guard +// against a panicking onPanic callback. +func TestL11PluginPanicCallbackPanicSwallowed(t *testing.T) { + t.Parallel() + before := PluginRecoveredPanicCount() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer RecoverPlugin("buggy", "op", nil, func(_ any) { + panic("callback-itself-panics") + }) + panic("primary panic") + }() + wg.Wait() + if PluginRecoveredPanicCount() <= before { + t.Fatal("L11 boundary did not record the primary panic") + } +} diff --git a/driver/conn.go b/driver/conn.go new file mode 100644 index 0000000..d95c3fa --- /dev/null +++ b/driver/conn.go @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "encoding/binary" + "io" + "net" + "os" + "sync" + "time" + + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" +) + +// maxSendChunk is the largest payload we will pack into one cmdSend IPC +// message. IPC messages are capped at ipcutil.MaxMessageSize; we reserve +// 5 bytes for the cmdSend+conn_id header and leave a small safety margin. +const maxSendChunk = ipcutil.MaxMessageSize - 64 + +// Conn implements net.Conn over a Pilot Protocol stream. 
+type Conn struct { + id uint32 + localAddr protocol.SocketAddr + remoteAddr protocol.SocketAddr + ipc *ipcClient + recvCh chan []byte + recvBuf []byte // leftover from previous read + closed bool + + mu sync.Mutex + readDeadline time.Time + deadlineCh chan struct{} // closed when deadline is set/changed +} + +func (c *Conn) Read(b []byte) (int, error) { + // Drain leftover first + if len(c.recvBuf) > 0 { + n := copy(b, c.recvBuf) + c.recvBuf = c.recvBuf[n:] + return n, nil + } + + c.mu.Lock() + dl := c.readDeadline + dch := c.deadlineCh + c.mu.Unlock() + + // Check if deadline already passed + if !dl.IsZero() && !time.Now().Before(dl) { + return 0, os.ErrDeadlineExceeded + } + + // Set up timer if deadline is set + var timer <-chan time.Time + if !dl.IsZero() { + t := time.NewTimer(time.Until(dl)) + defer t.Stop() + timer = t.C + } + + select { + case data, ok := <-c.recvCh: + if !ok { + return 0, io.EOF + } + n := copy(b, data) + if n < len(data) { + c.recvBuf = data[n:] + } + return n, nil + case <-timer: + return 0, os.ErrDeadlineExceeded + case <-dch: + // Deadline was changed, re-check + return 0, os.ErrDeadlineExceeded + } +} + +func (c *Conn) Write(b []byte) (int, error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, protocol.ErrConnClosed + } + c.mu.Unlock() + + total := len(b) + written := 0 + for written < total { + chunk := total - written + if chunk > maxSendChunk { + chunk = maxSendChunk + } + msg := make([]byte, 1+4+chunk) + msg[0] = cmdSend + binary.BigEndian.PutUint32(msg[1:5], c.id) + copy(msg[5:], b[written:written+chunk]) + if err := c.ipc.send(msg); err != nil { + return written, err + } + written += chunk + } + return written, nil +} + +func (c *Conn) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + c.mu.Unlock() + c.ipc.unregisterRecvCh(c.id) + + msg := make([]byte, 5) + msg[0] = cmdClose + binary.BigEndian.PutUint32(msg[1:5], c.id) + return c.ipc.send(msg) +} + +func (c *Conn) 
LocalAddr() net.Addr { return pilotAddr(c.localAddr) } +func (c *Conn) RemoteAddr() net.Addr { return pilotAddr(c.remoteAddr) } + +func (c *Conn) SetDeadline(t time.Time) error { + c.SetReadDeadline(t) + return nil +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + c.readDeadline = t + // Signal any blocked Read to re-check + if c.deadlineCh != nil { + close(c.deadlineCh) + } + c.deadlineCh = make(chan struct{}) + c.mu.Unlock() + return nil +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { return nil } + +// pilotAddr wraps SocketAddr to satisfy net.Addr. +type pilotAddr protocol.SocketAddr + +func (a pilotAddr) Network() string { return "pilot" } +func (a pilotAddr) String() string { return protocol.SocketAddr(a).String() } diff --git a/driver/driver.go b/driver/driver.go new file mode 100644 index 0000000..e1d3a3a --- /dev/null +++ b/driver/driver.go @@ -0,0 +1,495 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "time" + + "github.com/pilot-protocol/common/protocol" +) + +// DefaultSocketPath returns the default Unix socket path for IPC. +// On Linux it prefers $XDG_RUNTIME_DIR (typically /run/user/, +// which is private to the user); falls back to /tmp/pilot.sock. +// On macOS /tmp is already per-user via SIP, so /tmp/pilot.sock is safe. 
+func DefaultSocketPath() string { + if runtime.GOOS == "linux" { + if xdg := os.Getenv("XDG_RUNTIME_DIR"); xdg != "" { + return filepath.Join(xdg, "pilot.sock") + } + } + return "/tmp/pilot.sock" +} + +// Handshake sub-commands (must match daemon SubHandshake* constants) +const ( + subHandshakeSend byte = 0x01 + subHandshakeApprove byte = 0x02 + subHandshakeReject byte = 0x03 + subHandshakePending byte = 0x04 + subHandshakeTrusted byte = 0x05 + subHandshakeRevoke byte = 0x06 + subHandshakeWait byte = 0x07 +) + +// jsonRPC sends an IPC message, waits for the expected response, and +// unmarshals the JSON payload. Most driver methods follow this pattern. +func (d *Driver) jsonRPC(msg []byte, expectCmd byte, label string) (map[string]interface{}, error) { + resp, err := d.ipc.sendAndWait(msg, expectCmd) + if err != nil { + return nil, fmt.Errorf("%s: %w", label, err) + } + var result map[string]interface{} + if err := json.Unmarshal(resp, &result); err != nil { + return nil, fmt.Errorf("%s unmarshal: %w", label, err) + } + return result, nil +} + +// Driver is the main entry point for the Pilot Protocol SDK. +type Driver struct { + ipc *ipcClient + socketPath string +} + +// Connect creates a new driver connected to the local daemon. +func Connect(socketPath string) (*Driver, error) { + if socketPath == "" { + socketPath = DefaultSocketPath() + } + + ipc, err := newIPCClient(socketPath) + if err != nil { + return nil, err + } + + return &Driver{ipc: ipc, socketPath: socketPath}, nil +} + +// Dial opens a stream connection to a remote address:port. +// addr format: "N:XXXX.YYYY.YYYY:PORT" +func (d *Driver) Dial(addr string) (*Conn, error) { + sa, err := protocol.ParseSocketAddr(addr) + if err != nil { + return nil, fmt.Errorf("parse address: %w", err) + } + + return d.DialAddr(sa.Addr, sa.Port) +} + +// DialAddr opens a stream connection to a remote Addr + port. 
+func (d *Driver) DialAddr(dst protocol.Addr, port uint16) (*Conn, error) { + msg := make([]byte, 1+protocol.AddrSize+2) + msg[0] = cmdDial + dst.MarshalTo(msg, 1) + binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) + + resp, err := d.ipc.sendAndWait(msg, cmdDialOK) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + if len(resp) < 4 { + return nil, fmt.Errorf("invalid dial response") + } + + connID := binary.BigEndian.Uint32(resp[0:4]) + recvCh := d.ipc.registerRecvCh(connID) + + return &Conn{ + id: connID, + remoteAddr: protocol.SocketAddr{Addr: dst, Port: port}, + ipc: d.ipc, + recvCh: recvCh, + deadlineCh: make(chan struct{}), + }, nil +} + +// DialAddrTimeout opens a stream connection with a client-side timeout. +// If the daemon does not respond within the timeout, the dial is cancelled. +func (d *Driver) DialAddrTimeout(dst protocol.Addr, port uint16, timeout time.Duration) (*Conn, error) { + msg := make([]byte, 1+protocol.AddrSize+2) + msg[0] = cmdDial + dst.MarshalTo(msg, 1) + binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) + + resp, err := d.ipc.sendAndWaitTimeout(msg, cmdDialOK, timeout) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + if len(resp) < 4 { + return nil, fmt.Errorf("invalid dial response") + } + + connID := binary.BigEndian.Uint32(resp[0:4]) + recvCh := d.ipc.registerRecvCh(connID) + + return &Conn{ + id: connID, + remoteAddr: protocol.SocketAddr{Addr: dst, Port: port}, + ipc: d.ipc, + recvCh: recvCh, + deadlineCh: make(chan struct{}), + }, nil +} + +// Listen binds a port and returns a Listener that accepts connections. 
+func (d *Driver) Listen(port uint16) (*Listener, error) { + msg := make([]byte, 3) + msg[0] = cmdBind + binary.BigEndian.PutUint16(msg[1:3], port) + + resp, err := d.ipc.sendAndWait(msg, cmdBindOK) + if err != nil { + return nil, fmt.Errorf("bind: %w", err) + } + + boundPort := binary.BigEndian.Uint16(resp[0:2]) + + // H12 fix: register per-port accept channel + acceptCh := d.ipc.registerAcceptCh(boundPort) + + return &Listener{ + port: boundPort, + ipc: d.ipc, + acceptCh: acceptCh, + done: make(chan struct{}), + }, nil +} + +// SendTo sends an unreliable unicast datagram to the given address:port. +// Broadcast addresses (Node=0xFFFFFFFF) are not accepted on this path; use +// Broadcast, which requires the daemon's admin token. +func (d *Driver) SendTo(dst protocol.Addr, port uint16, data []byte) error { + if dst.IsBroadcast() { + return fmt.Errorf("broadcast address requires admin token: use Driver.Broadcast") + } + msg := make([]byte, 1+protocol.AddrSize+2+len(data)) + msg[0] = cmdSendTo + dst.MarshalTo(msg, 1) + binary.BigEndian.PutUint16(msg[1+protocol.AddrSize:], port) + copy(msg[1+protocol.AddrSize+2:], data) + return d.ipc.send(msg) +} + +// Broadcast fans an unreliable datagram out to every member of a network. +// The admin token must match the daemon's configured Config.AdminToken; an +// empty token or mismatched token is rejected. Permitted on every network +// including network 0 (backbone). Sender membership is not required. 
+func (d *Driver) Broadcast(netID uint16, port uint16, data []byte, adminToken string) error { + tokenBytes := []byte(adminToken) + msg := make([]byte, 1+2+2+2+len(tokenBytes)+len(data)) + msg[0] = cmdBroadcast + binary.BigEndian.PutUint16(msg[1:3], netID) + binary.BigEndian.PutUint16(msg[3:5], port) + binary.BigEndian.PutUint16(msg[5:7], uint16(len(tokenBytes))) + copy(msg[7:7+len(tokenBytes)], tokenBytes) + copy(msg[7+len(tokenBytes):], data) + if _, err := d.ipc.sendAndWait(msg, cmdBroadcastOK); err != nil { + return err + } + return nil +} + +// RecvFrom receives the next incoming datagram. +func (d *Driver) RecvFrom() (*Datagram, error) { + dg, ok := <-d.ipc.dgCh + if !ok { + return nil, fmt.Errorf("driver closed") + } + return dg, nil +} + +// Info returns the daemon's status information. +func (d *Driver) Info() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdInfo}, cmdInfoOK, "info") +} + +// Health returns a lightweight health check from the daemon. +func (d *Driver) Health() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdHealth}, cmdHealthOK, "health") +} + +// Handshake sends a trust handshake request to a remote node. +func (d *Driver) Handshake(nodeID uint32, justification string) (map[string]interface{}, error) { + msg := make([]byte, 1+1+4+len(justification)) + msg[0] = cmdHandshake + msg[1] = subHandshakeSend + binary.BigEndian.PutUint32(msg[2:6], nodeID) + copy(msg[6:], justification) + return d.jsonRPC(msg, cmdHandshakeOK, "handshake") +} + +// ApproveHandshake approves a pending trust handshake request. +func (d *Driver) ApproveHandshake(nodeID uint32) (map[string]interface{}, error) { + msg := make([]byte, 6) + msg[0] = cmdHandshake + msg[1] = subHandshakeApprove + binary.BigEndian.PutUint32(msg[2:6], nodeID) + return d.jsonRPC(msg, cmdHandshakeOK, "approve") +} + +// RejectHandshake rejects a pending trust handshake request. 
+func (d *Driver) RejectHandshake(nodeID uint32, reason string) (map[string]interface{}, error) { + msg := make([]byte, 1+1+4+len(reason)) + msg[0] = cmdHandshake + msg[1] = subHandshakeReject + binary.BigEndian.PutUint32(msg[2:6], nodeID) + copy(msg[6:], reason) + return d.jsonRPC(msg, cmdHandshakeOK, "reject") +} + +// PendingHandshakes returns pending trust handshake requests. +func (d *Driver) PendingHandshakes() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdHandshake, subHandshakePending}, cmdHandshakeOK, "pending") +} + +// WaitForTrust blocks (in the daemon) until the peer transitions to trusted +// or the timeout elapses. Single IPC roundtrip — the daemon-side +// HandshakeService.WaitForTrust waits on a per-node channel that is closed +// the moment trust is granted, so wakeup latency is sub-millisecond. +// +// Backward compatibility: an old daemon (no SubHandshakeWait) returns an +// "unknown sub-command" error; callers should treat that as "wait skipped" +// and proceed. +func (d *Driver) WaitForTrust(nodeID uint32, timeoutMs uint32) (map[string]interface{}, error) { + msg := make([]byte, 1+1+4+4) + msg[0] = cmdHandshake + msg[1] = subHandshakeWait + binary.BigEndian.PutUint32(msg[2:6], nodeID) + binary.BigEndian.PutUint32(msg[6:10], timeoutMs) + return d.jsonRPC(msg, cmdHandshakeOK, "wait") +} + +// TrustedPeers returns all trusted peers from the handshake protocol. +func (d *Driver) TrustedPeers() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdHandshake, subHandshakeTrusted}, cmdHandshakeOK, "trusted") +} + +// RevokeTrust removes a peer from the trusted set and notifies the registry. +func (d *Driver) RevokeTrust(nodeID uint32) (map[string]interface{}, error) { + msg := make([]byte, 6) + msg[0] = cmdHandshake + msg[1] = subHandshakeRevoke + binary.BigEndian.PutUint32(msg[2:6], nodeID) + return d.jsonRPC(msg, cmdHandshakeOK, "revoke") +} + +// ResolveHostname resolves a hostname to node info via the daemon. 
+func (d *Driver) ResolveHostname(hostname string) (map[string]interface{}, error) { + msg := make([]byte, 1+len(hostname)) + msg[0] = cmdResolveHostname + copy(msg[1:], hostname) + return d.jsonRPC(msg, cmdResolveHostnameOK, "resolve_hostname") +} + +// SetHostname sets or clears the daemon's hostname via the registry. +func (d *Driver) SetHostname(hostname string) (map[string]interface{}, error) { + msg := make([]byte, 1+len(hostname)) + msg[0] = cmdSetHostname + copy(msg[1:], hostname) + return d.jsonRPC(msg, cmdSetHostnameOK, "set_hostname") +} + +// SetVisibility sets the daemon's visibility on the registry. +func (d *Driver) SetVisibility(public bool) (map[string]interface{}, error) { + msg := make([]byte, 2) + msg[0] = cmdSetVisibility + if public { + msg[1] = 1 + } + return d.jsonRPC(msg, cmdSetVisibilityOK, "set_visibility") +} + +// Deregister removes the daemon from the registry. +func (d *Driver) Deregister() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdDeregister}, cmdDeregisterOK, "deregister") +} + +// SetTags sets the capability tags for this daemon's node. +func (d *Driver) SetTags(tags []string) (map[string]interface{}, error) { + data, _ := json.Marshal(tags) + msg := make([]byte, 1+len(data)) + msg[0] = cmdSetTags + copy(msg[1:], data) + return d.jsonRPC(msg, cmdSetTagsOK, "set_tags") +} + +// SetWebhook sets or clears the daemon's webhook URL at runtime. +// An empty URL disables the webhook. +func (d *Driver) SetWebhook(url string) (map[string]interface{}, error) { + msg := make([]byte, 1+len(url)) + msg[0] = cmdSetWebhook + copy(msg[1:], url) + return d.jsonRPC(msg, cmdSetWebhookOK, "set_webhook") +} + +// RotateKey asks the daemon to rotate its Ed25519 identity at the registry. +// The daemon generates a new keypair, signs proof of the current key, calls +// registry.RotateKey, then atomically swaps and persists the new identity. 
+func (d *Driver) RotateKey() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdRotateKey}, cmdRotateKeyOK, "rotate_key") +} + +// Disconnect closes a connection by ID. Used by administrative tools. +// Fire-and-forget: the daemon always responds CmdCloseOK regardless of +// whether the connID exists, so there is no error to propagate. Using +// sendAndWait here would corrupt a concurrent sendAndWait for a different +// command if a server-pushed cmdCloseOK (remote FIN) arrived simultaneously. +func (d *Driver) Disconnect(connID uint32) error { + msg := make([]byte, 5) + msg[0] = cmdClose + binary.BigEndian.PutUint32(msg[1:5], connID) + return d.ipc.send(msg) +} + +// NetworkList returns all networks known to the registry. +func (d *Driver) NetworkList() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdNetwork, subNetworkList}, cmdNetworkOK, "network list") +} + +// NetworkJoin joins a network by ID, optionally using a token for token-gated networks. +func (d *Driver) NetworkJoin(networkID uint16, token string) (map[string]interface{}, error) { + msg := make([]byte, 1+1+2+len(token)) + msg[0] = cmdNetwork + msg[1] = subNetworkJoin + binary.BigEndian.PutUint16(msg[2:4], networkID) + copy(msg[4:], token) + return d.jsonRPC(msg, cmdNetworkOK, "network join") +} + +// NetworkLeave leaves a network by ID. +func (d *Driver) NetworkLeave(networkID uint16) (map[string]interface{}, error) { + msg := make([]byte, 4) + msg[0] = cmdNetwork + msg[1] = subNetworkLeave + binary.BigEndian.PutUint16(msg[2:4], networkID) + return d.jsonRPC(msg, cmdNetworkOK, "network leave") +} + +// NetworkMembers lists all members of a network. 
+func (d *Driver) NetworkMembers(networkID uint16) (map[string]interface{}, error) { + msg := make([]byte, 4) + msg[0] = cmdNetwork + msg[1] = subNetworkMembers + binary.BigEndian.PutUint16(msg[2:4], networkID) + return d.jsonRPC(msg, cmdNetworkOK, "network members") +} + +// NetworkInvite invites a target node to a network (requires admin token on daemon). +func (d *Driver) NetworkInvite(networkID uint16, targetNodeID uint32) (map[string]interface{}, error) { + msg := make([]byte, 8) + msg[0] = cmdNetwork + msg[1] = subNetworkInvite + binary.BigEndian.PutUint16(msg[2:4], networkID) + binary.BigEndian.PutUint32(msg[4:8], targetNodeID) + return d.jsonRPC(msg, cmdNetworkOK, "network invite") +} + +// NetworkPollInvites returns pending network invites for this node. +func (d *Driver) NetworkPollInvites() (map[string]interface{}, error) { + return d.jsonRPC([]byte{cmdNetwork, subNetworkPollInvites}, cmdNetworkOK, "network poll-invites") +} + +// NetworkRespondInvite accepts or rejects a pending network invite. +func (d *Driver) NetworkRespondInvite(networkID uint16, accept bool) (map[string]interface{}, error) { + msg := make([]byte, 5) + msg[0] = cmdNetwork + msg[1] = subNetworkRespondInvite + binary.BigEndian.PutUint16(msg[2:4], networkID) + if accept { + msg[4] = 1 + } + return d.jsonRPC(msg, cmdNetworkOK, "network respond-invite") +} + +// ManagedStatus returns the status of a managed network engine. +func (d *Driver) ManagedStatus(networkID uint16) (map[string]interface{}, error) { + msg := make([]byte, 4) + msg[0] = cmdManaged + msg[1] = subManagedStatus + binary.BigEndian.PutUint16(msg[2:4], networkID) + return d.jsonRPC(msg, cmdManagedOK, "managed status") +} + +// ManagedForceCycle forces a prune/fill cycle in a managed network. 
+func (d *Driver) ManagedForceCycle(networkID uint16) (map[string]interface{}, error) { + msg := make([]byte, 4) + msg[0] = cmdManaged + msg[1] = subManagedCycle + binary.BigEndian.PutUint16(msg[2:4], networkID) + return d.jsonRPC(msg, cmdManagedOK, "managed cycle") +} + +// ManagedReconcile asks the daemon's policy runner for networkID to +// poll the registry and refresh its peer set — without running a +// policy cycle. Returns {network_id, peers}. +func (d *Driver) ManagedReconcile(networkID uint16) (map[string]interface{}, error) { + msg := make([]byte, 4) + msg[0] = cmdManaged + msg[1] = subManagedReconcile + binary.BigEndian.PutUint16(msg[2:4], networkID) + return d.jsonRPC(msg, cmdManagedOK, "managed reconcile") +} + +// PolicyGet retrieves the active policy for a network from the daemon. +func (d *Driver) PolicyGet(networkID uint16) (map[string]interface{}, error) { + msg := make([]byte, 4) + msg[0] = cmdManaged + msg[1] = subManagedPolicy + msg[2] = 0x00 // get + // Shift: need [cmd][sub][action][netID_hi][netID_lo] + msg = make([]byte, 5) + msg[0] = cmdManaged + msg[1] = subManagedPolicy + msg[2] = 0x00 // get + binary.BigEndian.PutUint16(msg[3:5], networkID) + return d.jsonRPC(msg, cmdManagedOK, "policy get") +} + +// PolicySet sends a policy document to the daemon for immediate application. +func (d *Driver) PolicySet(networkID uint16, policyJSON []byte) (map[string]interface{}, error) { + msg := make([]byte, 5+len(policyJSON)) + msg[0] = cmdManaged + msg[1] = subManagedPolicy + msg[2] = 0x01 // set + binary.BigEndian.PutUint16(msg[3:5], networkID) + copy(msg[5:], policyJSON) + return d.jsonRPC(msg, cmdManagedOK, "policy set") +} + +// MemberTagsGet retrieves admin-assigned member tags for a node in a network. 
+func (d *Driver) MemberTagsGet(networkID uint16, nodeID uint32) (map[string]interface{}, error) { + msg := make([]byte, 9) + msg[0] = cmdManaged + msg[1] = subManagedMemberTags + msg[2] = 0x00 // get + binary.BigEndian.PutUint16(msg[3:5], networkID) + binary.BigEndian.PutUint32(msg[5:9], nodeID) + return d.jsonRPC(msg, cmdManagedOK, "member-tags get") +} + +// MemberTagsSet sets admin-assigned member tags for a node in a network. +func (d *Driver) MemberTagsSet(networkID uint16, nodeID uint32, tags []string) (map[string]interface{}, error) { + tagsJSON, _ := json.Marshal(tags) + msg := make([]byte, 9+len(tagsJSON)) + msg[0] = cmdManaged + msg[1] = subManagedMemberTags + msg[2] = 0x01 // set + binary.BigEndian.PutUint16(msg[3:5], networkID) + binary.BigEndian.PutUint32(msg[5:9], nodeID) + copy(msg[9:], tagsJSON) + return d.jsonRPC(msg, cmdManagedOK, "member-tags set") +} + +// Close disconnects from the daemon. +func (d *Driver) Close() error { + return d.ipc.close() +} diff --git a/driver/ipc.go b/driver/ipc.go new file mode 100644 index 0000000..457a458 --- /dev/null +++ b/driver/ipc.go @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" +) + +// IPC commands (must match daemon/ipc.go) +const ( + cmdBind byte = 0x01 + cmdBindOK byte = 0x02 + cmdDial byte = 0x03 + cmdDialOK byte = 0x04 + cmdAccept byte = 0x05 + cmdSend byte = 0x06 + cmdRecv byte = 0x07 + cmdClose byte = 0x08 + cmdCloseOK byte = 0x09 + cmdError byte = 0x0A + cmdSendTo byte = 0x0B + cmdRecvFrom byte = 0x0C + cmdInfo byte = 0x0D + cmdInfoOK byte = 0x0E + cmdHandshake byte = 0x0F + cmdHandshakeOK byte = 0x10 + cmdResolveHostname byte = 0x11 + cmdResolveHostnameOK byte = 0x12 + cmdSetHostname byte = 0x13 + cmdSetHostnameOK byte = 0x14 + cmdSetVisibility byte = 0x15 + cmdSetVisibilityOK byte = 0x16 + cmdDeregister 
byte = 0x17 + cmdDeregisterOK byte = 0x18 + cmdSetTags byte = 0x19 + cmdSetTagsOK byte = 0x1A + cmdSetWebhook byte = 0x1B + cmdSetWebhookOK byte = 0x1C + cmdNetwork byte = 0x1F + cmdNetworkOK byte = 0x20 + cmdHealth byte = 0x21 + cmdHealthOK byte = 0x22 + cmdManaged byte = 0x23 + cmdManagedOK byte = 0x24 + cmdRotateKey byte = 0x25 + cmdRotateKeyOK byte = 0x26 + cmdBroadcast byte = 0x29 + cmdBroadcastOK byte = 0x2A +) + +// Network sub-commands (must match daemon SubNetwork* constants) +const ( + subNetworkList byte = 0x01 + subNetworkJoin byte = 0x02 + subNetworkLeave byte = 0x03 + subNetworkMembers byte = 0x04 + subNetworkInvite byte = 0x05 + subNetworkPollInvites byte = 0x06 + subNetworkRespondInvite byte = 0x07 +) + +// Managed sub-commands (must match daemon SubManaged* constants) +const ( + subManagedStatus byte = 0x02 + subManagedCycle byte = 0x04 + subManagedPolicy byte = 0x05 + subManagedMemberTags byte = 0x06 + subManagedReconcile byte = 0x07 +) + +// ipcEnvelopeHeaderSize matches daemon.IPCEnvelopeHeaderSize: 1 byte cmd. +const ipcEnvelopeHeaderSize = 1 + +// Datagram represents a received unreliable datagram. +type Datagram struct { + SrcAddr protocol.Addr + SrcPort uint16 + DstPort uint16 + Data []byte +} + +// pendingResponse carries the response to a sendAndWait waiter — either +// the cmd-OK payload (ok=true) or the error text from cmdError. +type pendingResponse struct { + cmd byte + payload []byte +} + +type ipcClient struct { + conn net.Conn + + // writeMu serializes frame writes so concurrent goroutines don't + // interleave bytes on the wire. Held only for the write itself. + writeMu sync.Mutex + + // waitSem is a channel-based semaphore (capacity 1) that ensures at + // most one request/reply pair is in-flight at a time. Using a channel + // instead of sync.Mutex lets goroutines waiting for the semaphore be + // woken on doneCh close, preventing a deadlock when the daemon closes + // while many goroutines are queued behind a slow sendAndWait. 
+ waitSem chan struct{} // capacity 1 + pending chan *pendingResponse // capacity 16; buffers reply frames from readLoop + + recvMu sync.Mutex + recvChs map[uint32]chan []byte // conn_id → data channel + pendRecv map[uint32][][]byte // conn_id → buffered data before recvCh registered + pendAccept map[uint16][][]byte // port → buffered cmdAccept payloads before acceptCh registered (post-#99 race fix) + + acceptMu sync.Mutex + acceptChs map[uint16]chan []byte // H12 fix: per-port accept channels + + dgCh chan *Datagram // incoming datagrams + doneCh chan struct{} // closed when readLoop exits + + closeOnce sync.Once +} + +func newIPCClient(socketPath string) (*ipcClient, error) { + conn, err := net.Dial("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("connect to daemon: %w", err) + } + + c := &ipcClient{ + conn: conn, + waitSem: make(chan struct{}, 1), + pending: make(chan *pendingResponse, 16), + recvChs: make(map[uint32]chan []byte), + pendRecv: make(map[uint32][][]byte), + pendAccept: make(map[uint16][][]byte), + acceptChs: make(map[uint16]chan []byte), + dgCh: make(chan *Datagram, 256), + doneCh: make(chan struct{}), + } + + go c.readLoop() + return c, nil +} + +func (c *ipcClient) close() error { + var err error + c.closeOnce.Do(func() { + err = c.conn.Close() + }) + return err +} + +// readLoop demultiplexes incoming envelopes. Wire format: +// +// [uint32-len][uint8-cmd][payload...] +// +// Server-pushed frames (cmdRecv, cmdCloseOK, cmdRecvFrom, cmdAccept) are +// routed by cmd to their per-connection channels. cmdCloseOK is always +// a server-push (remote FIN); Driver.Disconnect uses send() not +// sendAndWait() so it never waits for cmdCloseOK in pending. +// Known response cmds are forwarded to c.pending for sendAndWait. +// Unknown cmds are silently dropped — they never reach pending, so +// sendAndWaitTimeout can use a single read without a discard loop. 
+func (c *ipcClient) readLoop() { + defer c.cleanup() + for { + msg, err := ipcutil.Read(c.conn) + if err != nil { + return + } + if len(msg) < ipcEnvelopeHeaderSize { + continue + } + + cmd := msg[0] + payload := msg[ipcEnvelopeHeaderSize:] + + switch cmd { + case cmdRecv, cmdRecvFrom, cmdAccept, cmdCloseOK: + // Server-pushed frames: route to per-connection channels. + c.dispatchPush(cmd, payload) + case cmdBindOK, cmdDialOK, cmdError, cmdInfoOK, cmdHandshakeOK, + cmdResolveHostnameOK, cmdSetHostnameOK, cmdSetVisibilityOK, + cmdDeregisterOK, cmdSetTagsOK, cmdSetWebhookOK, cmdNetworkOK, + cmdHealthOK, cmdManagedOK, cmdRotateKeyOK, cmdBroadcastOK: + // Known response cmds: route to pending for the in-flight sendAndWait. + select { + case c.pending <- &pendingResponse{cmd: cmd, payload: append([]byte(nil), payload...)}: + default: + } + // default: unknown cmd — silently drop (version mismatch, test injection, etc.) + } + } +} + +// dispatchPush routes server-pushed (reqID==0) frames to their per-cmd +// destination. CmdRecv and CmdCloseOK route by conn ID; CmdAccept by +// listener port; CmdRecvFrom into the global datagram channel. +func (c *ipcClient) dispatchPush(cmd byte, payload []byte) { + switch cmd { + case cmdRecv: + if len(payload) >= 4 { + connID := binary.BigEndian.Uint32(payload[0:4]) + data := append([]byte(nil), payload[4:]...) + c.recvMu.Lock() + ch, ok := c.recvChs[connID] + if ok { + c.recvMu.Unlock() + // Drop the recvMu BEFORE blocking on the channel send + // so Conn.Close() / unregisterRecvCh can take the lock + // while readLoop is parked. Without this, a slow Conn + // holds recvMu indirectly (through readLoop) and other + // IPC operations stall. + ch <- data + } else { + c.pendRecv[connID] = append(c.pendRecv[connID], data) + c.recvMu.Unlock() + } + } + case cmdCloseOK: + // Server-pushed CmdCloseOK fires from recvPusher when the remote + // FINs. Close the per-conn recv channel so blocked reads see EOF. 
+ if len(payload) >= 4 { + connID := binary.BigEndian.Uint32(payload[0:4]) + c.recvMu.Lock() + ch, ok := c.recvChs[connID] + if ok { + delete(c.recvChs, connID) + close(ch) + } + c.recvMu.Unlock() + } + case cmdRecvFrom: + if len(payload) >= protocol.AddrSize+4 { + srcAddr := protocol.UnmarshalAddr(payload[0:protocol.AddrSize]) + srcPort := binary.BigEndian.Uint16(payload[protocol.AddrSize:]) + dstPort := binary.BigEndian.Uint16(payload[protocol.AddrSize+2:]) + data := append([]byte(nil), payload[protocol.AddrSize+4:]...) + select { + case c.dgCh <- &Datagram{SrcAddr: srcAddr, SrcPort: srcPort, DstPort: dstPort, Data: data}: + default: + } + } + case cmdAccept: + if len(payload) >= 2 { + port := binary.BigEndian.Uint16(payload[0:2]) + rest := append([]byte(nil), payload[2:]...) + c.acceptMu.Lock() + ch, ok := c.acceptChs[port] + if ok { + c.acceptMu.Unlock() + select { + case ch <- rest: + default: + } + } else { + // Buffer until registerAcceptCh is called. The race + // (post-#99): with concurrent daemon dispatch, the + // daemon can push cmdAccept BEFORE the driver registers + // acceptChs[port] — Listen() registers AFTER the + // cmdBind reply, but a peer's dial can race the bind + // reply through different worker goroutines on the + // daemon side. Same pattern as pendRecv for cmdRecv. + c.pendAccept[port] = append(c.pendAccept[port], rest) + c.acceptMu.Unlock() + } + } + default: + // Unknown unsolicited cmd — drop. The daemon should never send + // reqID=0 with a cmd outside this set; if a test or future + // addition does, dropping is the safe default. + } +} + +// cleanup closes channels when readLoop exits (daemon disconnect). +func (c *ipcClient) cleanup() { + close(c.doneCh) + + // Drain all buffered responses. 
+ for { + select { + case <-c.pending: + default: + goto drained + } + } +drained: + + // Close all receive channels + c.recvMu.Lock() + for id, ch := range c.recvChs { + close(ch) + delete(c.recvChs, id) + } + c.recvMu.Unlock() + + // Close all accept channels (H12 fix) + c.acceptMu.Lock() + for port, ch := range c.acceptChs { + close(ch) + delete(c.acceptChs, port) + } + c.acceptMu.Unlock() +} + +// writeFrame builds a `[cmd][body...]` envelope and writes it under +// writeMu so frames don't interleave on the wire. +func (c *ipcClient) writeFrame(cmd byte, body []byte) error { + buf := make([]byte, ipcEnvelopeHeaderSize+len(body)) + buf[0] = cmd + copy(buf[1:], body) + c.writeMu.Lock() + defer c.writeMu.Unlock() + return ipcutil.Write(c.conn, buf) +} + +// send is a fire-and-forget write — used for cmdSend/cmdSendTo where +// the daemon does not reply. Acquires only writeMu (not waitMu), so +// concurrent fire-and-forget sends are never blocked behind a reply wait. +func (c *ipcClient) send(data []byte) error { + if len(data) < 1 { + return fmt.Errorf("ipc: empty message") + } + return c.writeFrame(data[0], data[1:]) +} + +// sendAndWait sends a request and waits for the reply. +func (c *ipcClient) sendAndWait(data []byte, expectCmd byte) ([]byte, error) { + return c.sendAndWaitTimeout(data, expectCmd, 0) +} + +// sendAndWaitTimeout serialises at most one request/reply pair at a time +// via waitSem. timeout=0 means wait forever. The timer is started BEFORE +// acquiring the semaphore so the timeout applies to queue wait + reply +// wait combined — without this, goroutines queued behind the semaphore +// can't time out and pile up indefinitely under high concurrency. 
+func (c *ipcClient) sendAndWaitTimeout(data []byte, expectCmd byte, timeout time.Duration) ([]byte, error) { + if len(data) < 1 { + return nil, fmt.Errorf("ipc: empty request") + } + + // Start the timer before acquiring the semaphore so queued goroutines + // can bail out instead of waiting forever. + var timer <-chan time.Time + if timeout > 0 { + t := time.NewTimer(timeout) + defer t.Stop() + timer = t.C + } + + // Acquire the serialisation semaphore. Channel-based (not sync.Mutex) + // so goroutines blocked here are woken by doneCh or timer. + select { + case c.waitSem <- struct{}{}: + case <-c.doneCh: + return nil, fmt.Errorf("daemon disconnected") + case <-timer: + return nil, fmt.Errorf("dial timeout") + } + defer func() { <-c.waitSem }() + + // Drain all stale replies buffered before this request was sent. + for { + select { + case <-c.pending: + default: + goto drained + } + } +drained: + + if err := c.writeFrame(data[0], data[1:]); err != nil { + return nil, err + } + + // Unknown cmds are dropped in readLoop, so the first frame in pending + // is always either the expected response or cmdError. + select { + case resp := <-c.pending: + if resp.cmd == cmdError { + if len(resp.payload) >= 2 { + return nil, fmt.Errorf("daemon: %s", string(resp.payload[2:])) + } + return nil, fmt.Errorf("daemon error") + } + if resp.cmd != expectCmd { + return nil, fmt.Errorf("ipc: unexpected reply 0x%02X (want 0x%02X)", resp.cmd, expectCmd) + } + return resp.payload, nil + case <-c.doneCh: + return nil, fmt.Errorf("daemon disconnected") + case <-timer: + return nil, fmt.Errorf("dial timeout") + } +} + +// H12 fix: per-port accept channel management. +// Drains any cmdAccept frames buffered in pendAccept (the post-#99 +// race window between cmdBind reply and acceptChs registration). 
+func (c *ipcClient) registerAcceptCh(port uint16) chan []byte { + ch := make(chan []byte, 64) + c.acceptMu.Lock() + c.acceptChs[port] = ch + pending := c.pendAccept[port] + delete(c.pendAccept, port) + c.acceptMu.Unlock() + for _, data := range pending { + select { + case ch <- data: + default: + } + } + return ch +} + +func (c *ipcClient) registerRecvCh(connID uint32) chan []byte { + ch := make(chan []byte, 256) + c.recvMu.Lock() + c.recvChs[connID] = ch + // Drain any data that arrived before registration. Hold recvMu + // across the drain so a concurrent dispatchPush(cmdCloseOK) for the + // same connID can't race with these sends — without this guard, the + // FIN handler at dispatchPush:250 closes the channel mid-drain and + // chansend1 panics on a closed channel (issue #105 §4.8 race). + // The drain is bounded by len(pendRecv[connID]) which is small — + // data only buffers in pendRecv during the brief window between + // the daemon dispatching cmdRecv and the driver's Accept calling + // registerRecvCh, and never exceeds a single slow-path frame batch. + pending := c.pendRecv[connID] + delete(c.pendRecv, connID) + for _, data := range pending { + ch <- data + } + c.recvMu.Unlock() + return ch +} + +func (c *ipcClient) unregisterRecvCh(connID uint32) { + c.recvMu.Lock() + defer c.recvMu.Unlock() + delete(c.recvChs, connID) +} diff --git a/driver/listener.go b/driver/listener.go new file mode 100644 index 0000000..22d61db --- /dev/null +++ b/driver/listener.go @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + + "github.com/pilot-protocol/common/protocol" +) + +// Listener implements net.Listener over a Pilot Protocol port. 
+type Listener struct { + port uint16 + ipc *ipcClient + acceptCh chan []byte // H12 fix: per-port accept channel + mu sync.Mutex + closed bool + done chan struct{} // closed on Close() to unblock Accept (H13 fix) +} + +func (l *Listener) Accept() (net.Conn, error) { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + return nil, fmt.Errorf("listener closed") + } + l.mu.Unlock() + + // H12 fix: wait on per-port accept channel + var payload []byte + var ok bool + select { + case payload, ok = <-l.acceptCh: + if !ok { + return nil, fmt.Errorf("listener closed") + } + case <-l.done: + return nil, fmt.Errorf("listener closed") + } + + // Parse: [4 bytes conn_id][6 bytes remote addr][2 bytes remote port] + if len(payload) < 4+protocol.AddrSize+2 { + return nil, fmt.Errorf("invalid accept payload") + } + + connID := binary.BigEndian.Uint32(payload[0:4]) + remoteAddr := protocol.UnmarshalAddr(payload[4 : 4+protocol.AddrSize]) + remotePort := binary.BigEndian.Uint16(payload[4+protocol.AddrSize:]) + + recvCh := l.ipc.registerRecvCh(connID) + + conn := &Conn{ + id: connID, + localAddr: protocol.SocketAddr{Port: l.port}, + remoteAddr: protocol.SocketAddr{Addr: remoteAddr, Port: remotePort}, + ipc: l.ipc, + recvCh: recvCh, + deadlineCh: make(chan struct{}), + } + + return conn, nil +} + +func (l *Listener) Close() error { + l.mu.Lock() + if !l.closed { + l.closed = true + close(l.done) // unblock Accept() (H13 fix) + } + l.mu.Unlock() + return nil +} + +func (l *Listener) Addr() net.Addr { + return pilotAddr(protocol.SocketAddr{Port: l.port}) +} diff --git a/driver/zz_conn_test.go b/driver/zz_conn_test.go new file mode 100644 index 0000000..6ef215c --- /dev/null +++ b/driver/zz_conn_test.go @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "errors" + "io" + "net" + "os" + "testing" + "time" + + "github.com/pilot-protocol/common/protocol" +) + +// --------------------------------------------------------------------------- +// pilotAddr 
— net.Addr implementation +// --------------------------------------------------------------------------- + +func TestPilotAddrNetwork(t *testing.T) { + t.Parallel() + a := pilotAddr(protocol.SocketAddr{Port: 80}) + if got := a.Network(); got != "pilot" { + t.Errorf("Network() = %q, want %q", got, "pilot") + } +} + +func TestPilotAddrString(t *testing.T) { + t.Parallel() + addr, _ := protocol.ParseAddr("1:0001.0002.0003") + a := pilotAddr(protocol.SocketAddr{Addr: addr, Port: 7}) + got := a.String() + want := protocol.SocketAddr{Addr: addr, Port: 7}.String() + if got != want { + t.Errorf("String() = %q, want %q", got, want) + } +} + +// --------------------------------------------------------------------------- +// Conn — read leftover and deadline behaviour exercise the in-memory +// branches that don't require live IPC. +// --------------------------------------------------------------------------- + +func TestConnReadDrainsLeftover(t *testing.T) { + t.Parallel() + c := &Conn{ + recvBuf: []byte("hello world"), + recvCh: make(chan []byte), + deadlineCh: make(chan struct{}), + } + got := make([]byte, 5) + n, err := c.Read(got) + if err != nil { + t.Fatal(err) + } + if n != 5 || string(got) != "hello" { + t.Fatalf("first read: n=%d got=%q", n, got) + } + // Second read drains the rest of the leftover (no IPC needed). 
+ got2 := make([]byte, 6) + n2, err := c.Read(got2) + if err != nil { + t.Fatal(err) + } + if n2 != 6 || string(got2) != " world" { + t.Fatalf("second read: n=%d got=%q", n2, got2) + } +} + +func TestConnReadDeadlineAlreadyPassed(t *testing.T) { + t.Parallel() + c := &Conn{ + recvCh: make(chan []byte), + deadlineCh: make(chan struct{}), + readDeadline: time.Now().Add(-time.Second), // already in past + } + _, err := c.Read(make([]byte, 1)) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("got %v, want ErrDeadlineExceeded", err) + } +} + +func TestConnReadEOFOnClosedRecvCh(t *testing.T) { + t.Parallel() + ch := make(chan []byte) + close(ch) + c := &Conn{ + recvCh: ch, + deadlineCh: make(chan struct{}), + } + _, err := c.Read(make([]byte, 1)) + if !errors.Is(err, io.EOF) { + t.Errorf("got %v, want io.EOF", err) + } +} + +func TestConnReadDelivers(t *testing.T) { + t.Parallel() + ch := make(chan []byte, 1) + ch <- []byte("xy") + c := &Conn{ + recvCh: ch, + deadlineCh: make(chan struct{}), + } + buf := make([]byte, 2) + n, err := c.Read(buf) + if err != nil || n != 2 || string(buf) != "xy" { + t.Fatalf("got n=%d err=%v buf=%q", n, err, buf) + } +} + +func TestConnReadStoresLeftoverWhenBufferTooSmall(t *testing.T) { + t.Parallel() + ch := make(chan []byte, 1) + ch <- []byte("12345") + c := &Conn{ + recvCh: ch, + deadlineCh: make(chan struct{}), + } + buf := make([]byte, 2) + n, err := c.Read(buf) + if err != nil || n != 2 || string(buf) != "12" { + t.Fatalf("first read got n=%d err=%v buf=%q", n, err, buf) + } + // Remaining 3 bytes should be in recvBuf + rest := make([]byte, 3) + n2, err := c.Read(rest) + if err != nil || n2 != 3 || string(rest) != "345" { + t.Fatalf("leftover read got n=%d err=%v buf=%q", n2, err, rest) + } +} + +func TestConnReadTimerExpires(t *testing.T) { + t.Parallel() + c := &Conn{ + recvCh: make(chan []byte), + deadlineCh: make(chan struct{}), + readDeadline: time.Now().Add(20 * time.Millisecond), + } + start := time.Now() + _, err := 
c.Read(make([]byte, 1)) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("got %v, want ErrDeadlineExceeded", err) + } + if elapsed := time.Since(start); elapsed < 20*time.Millisecond { + t.Errorf("returned too early: %v", elapsed) + } +} + +func TestSetReadDeadlineUnblocksReader(t *testing.T) { + t.Parallel() + c := &Conn{ + recvCh: make(chan []byte), + deadlineCh: make(chan struct{}), + } + done := make(chan error, 1) + go func() { + _, err := c.Read(make([]byte, 1)) + done <- err + }() + // Give Read a moment to enter the select. + time.Sleep(10 * time.Millisecond) + c.SetReadDeadline(time.Now().Add(time.Hour)) // closes the old deadlineCh + select { + case err := <-done: + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("got %v, want ErrDeadlineExceeded", err) + } + case <-time.After(time.Second): + t.Fatal("Read did not unblock after SetReadDeadline") + } +} + +func TestSetDeadlineDelegatesToRead(t *testing.T) { + t.Parallel() + c := &Conn{ + recvCh: make(chan []byte), + deadlineCh: make(chan struct{}), + } + dl := time.Now().Add(time.Hour) + if err := c.SetDeadline(dl); err != nil { + t.Fatal(err) + } + if !c.readDeadline.Equal(dl) { + t.Errorf("readDeadline = %v, want %v", c.readDeadline, dl) + } +} + +func TestSetWriteDeadlineNoop(t *testing.T) { + t.Parallel() + c := &Conn{} + if err := c.SetWriteDeadline(time.Now()); err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestConnAddrs(t *testing.T) { + t.Parallel() + addr, _ := protocol.ParseAddr("1:0001.0002.0003") + c := &Conn{ + localAddr: protocol.SocketAddr{Port: 80}, + remoteAddr: protocol.SocketAddr{Addr: addr, Port: 7}, + } + if c.LocalAddr().Network() != "pilot" { + t.Errorf("LocalAddr().Network() unexpected") + } + if c.RemoteAddr().Network() != "pilot" { + t.Errorf("RemoteAddr().Network() unexpected") + } +} + +// --------------------------------------------------------------------------- +// Listener — Accept payload parsing and Close behavior +// 
--------------------------------------------------------------------------- + +func TestListenerCloseUnblocksAccept(t *testing.T) { + t.Parallel() + l := &Listener{ + port: 80, + acceptCh: make(chan []byte), + done: make(chan struct{}), + } + type r struct{ err error } + ch := make(chan r, 1) + go func() { + _, err := l.Accept() + ch <- r{err} + }() + time.Sleep(10 * time.Millisecond) + if err := l.Close(); err != nil { + t.Fatal(err) + } + select { + case got := <-ch: + if got.err == nil { + t.Fatal("expected error after close") + } + case <-time.After(time.Second): + t.Fatal("Accept did not unblock after Close") + } +} + +func TestListenerAcceptOnAlreadyClosed(t *testing.T) { + t.Parallel() + l := &Listener{ + port: 80, + acceptCh: make(chan []byte), + done: make(chan struct{}), + } + if err := l.Close(); err != nil { + t.Fatal(err) + } + _, err := l.Accept() + if err == nil { + t.Fatal("expected closed error") + } +} + +func TestListenerCloseIdempotent(t *testing.T) { + t.Parallel() + l := &Listener{ + port: 80, + acceptCh: make(chan []byte), + done: make(chan struct{}), + } + if err := l.Close(); err != nil { + t.Fatal(err) + } + // Second Close must not panic on closed channel + if err := l.Close(); err != nil { + t.Errorf("second Close: %v", err) + } +} + +func TestListenerAddr(t *testing.T) { + t.Parallel() + l := &Listener{port: 8080} + a := l.Addr() + if a.Network() != "pilot" { + t.Errorf("Network() = %q", a.Network()) + } +} + +func TestListenerAcceptInvalidPayload(t *testing.T) { + t.Parallel() + l := &Listener{ + port: 80, + acceptCh: make(chan []byte, 1), + done: make(chan struct{}), + } + l.acceptCh <- []byte{0x01, 0x02} // way too short + _, err := l.Accept() + if err == nil { + t.Fatal("expected invalid-payload error") + } +} + +// satisfy unused import detector if SDK isn't otherwise used here +var _ net.Listener = (*Listener)(nil) diff --git a/driver/zz_conn_write_test.go b/driver/zz_conn_write_test.go new file mode 100644 index 0000000..f208518 
--- /dev/null +++ b/driver/zz_conn_write_test.go @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "encoding/binary" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/pilot-protocol/common/ipcutil" +) + +// TestConnWriteChunksLargePayload verifies that Conn.Write splits payloads +// larger than the IPC message cap into multiple cmdSend messages so the +// daemon side never rejects oversized frames. +func TestConnWriteChunksLargePayload(t *testing.T) { + t.Parallel() + clientSide, serverSide := net.Pipe() + defer clientSide.Close() + defer serverSide.Close() + + ipc := &ipcClient{ + conn: clientSide, + waitSem: make(chan struct{}, 1), + pending: make(chan *pendingResponse, 16), + recvChs: make(map[uint32]chan []byte), + pendRecv: make(map[uint32][][]byte), + acceptChs: make(map[uint16]chan []byte), + dgCh: make(chan *Datagram, 1), + doneCh: make(chan struct{}), + } + + const connID uint32 = 42 + c := &Conn{id: connID, ipc: ipc, deadlineCh: make(chan struct{})} + + const payloadSize = 5 * 1024 * 1024 // 5 MB + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + // Wire format (issue #99): [cmd(1)][reqID(8)][connID(4)][data...]. + // Each cmdSend frame carries 13 bytes of header before the payload. 
+ const sendHdr = ipcEnvelopeHeaderSize + 4 + + var got []byte + var chunks int + var readErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = serverSide.SetReadDeadline(time.Now().Add(5 * time.Second)) + for len(got) < payloadSize { + msg, err := ipcutil.Read(serverSide) + if err != nil { + if err != io.EOF { + readErr = err + } + return + } + if len(msg) < sendHdr { + readErr = io.ErrShortBuffer + return + } + if msg[0] != cmdSend { + readErr = io.ErrUnexpectedEOF + return + } + gotID := binary.BigEndian.Uint32(msg[ipcEnvelopeHeaderSize : ipcEnvelopeHeaderSize+4]) + if gotID != connID { + readErr = io.ErrUnexpectedEOF + return + } + if len(msg) > ipcutil.MaxMessageSize { + readErr = io.ErrShortBuffer + return + } + chunks++ + got = append(got, msg[sendHdr:]...) + } + }() + + n, err := c.Write(payload) + if err != nil { + t.Fatalf("Write returned err: %v", err) + } + if n != payloadSize { + t.Fatalf("Write returned n=%d, want %d", n, payloadSize) + } + + // Close to unblock reader if it got everything already. + _ = clientSide.Close() + wg.Wait() + + if readErr != nil { + t.Fatalf("reader err: %v", readErr) + } + if len(got) != payloadSize { + t.Fatalf("reader got %d bytes, want %d", len(got), payloadSize) + } + if chunks < 2 { + t.Fatalf("expected >=2 chunks for 5MB payload, got %d", chunks) + } + for i, b := range got { + if b != byte(i) { + t.Fatalf("byte %d: got %d, want %d", i, b, byte(i)) + } + } +} + +// TestConnWriteSinglePayloadNotSplit verifies that payloads that fit in one +// IPC message are still sent as a single cmdSend message. 
+func TestConnWriteSinglePayloadNotSplit(t *testing.T) { + t.Parallel() + clientSide, serverSide := net.Pipe() + defer clientSide.Close() + defer serverSide.Close() + + ipc := &ipcClient{ + conn: clientSide, + waitSem: make(chan struct{}, 1), + pending: make(chan *pendingResponse, 16), + recvChs: make(map[uint32]chan []byte), + pendRecv: make(map[uint32][][]byte), + acceptChs: make(map[uint16]chan []byte), + dgCh: make(chan *Datagram, 1), + doneCh: make(chan struct{}), + } + + const connID uint32 = 7 + c := &Conn{id: connID, ipc: ipc, deadlineCh: make(chan struct{})} + + payload := []byte("hello world") + + // Wire format: [cmd(1)][connID(4)][data...] + const sendHdr = ipcEnvelopeHeaderSize + 4 + + var got []byte + var chunks int + done := make(chan struct{}) + go func() { + defer close(done) + _ = serverSide.SetReadDeadline(time.Now().Add(2 * time.Second)) + msg, err := ipcutil.Read(serverSide) + if err != nil { + return + } + chunks++ + got = append(got, msg[sendHdr:]...) + }() + + if _, err := c.Write(payload); err != nil { + t.Fatalf("Write err: %v", err) + } + <-done + + if chunks != 1 { + t.Fatalf("expected 1 chunk, got %d", chunks) + } + if string(got) != string(payload) { + t.Fatalf("got %q, want %q", got, payload) + } +} diff --git a/driver/zz_driver_simple_ops_test.go b/driver/zz_driver_simple_ops_test.go new file mode 100644 index 0000000..d2e82b5 --- /dev/null +++ b/driver/zz_driver_simple_ops_test.go @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "testing" +) + +// TestDriverClose covers the trivial Close() forwarder. +func TestDriverClose(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("Connect: %v", err) + } + if err := drv.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} + +// TestDriverBroadcast covers the happy-path Broadcast (network + port + +// admin token + data → cmdBroadcastOK). 
+func TestDriverBroadcast(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + d.onCmd(cmdBroadcast, func(frame []byte) [][]byte { + return [][]byte{{cmdBroadcastOK}} + }) + + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer drv.Close() + + if err := drv.Broadcast(1, 8080, []byte("hello"), "admin-token"); err != nil { + t.Fatalf("Broadcast: %v", err) + } +} + +// TestConnClose covers Conn.Close (cmdClose fire-and-forget) and its +// idempotency — second Close is a no-op. +func TestConnClose(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + // Dial to get a Conn back. + d.onCmd(cmdDial, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdDialOK + resp[1] = 0x00 + resp[2] = 0x00 + resp[3] = 0x00 + resp[4] = 0x42 + return [][]byte{resp} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + conn, err := drv.Dial("0:0000.0000.0001:80") + if err != nil { + t.Fatalf("Dial: %v", err) + } + + if err := conn.Close(); err != nil { + t.Errorf("Close: %v", err) + } + // Second Close is idempotent. + if err := conn.Close(); err != nil { + t.Errorf("second Close: %v", err) + } +} + +// TestDriverWaitForTrust covers the handshake-wait JSON-RPC. +func TestDriverWaitForTrust(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + d.onCmd(cmdHandshake, func(frame []byte) [][]byte { + // Verify the sub-command byte (0x07 = subHandshakeWait). + if len(frame) < 2 || frame[1] != subHandshakeWait { + return [][]byte{{cmdError, 'b', 'a', 'd'}} + } + body := []byte(`{"trusted":true}`) + return [][]byte{append([]byte{cmdHandshakeOK}, body...)} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + result, err := drv.WaitForTrust(0xCAFE, 5000) + if err != nil { + t.Fatalf("WaitForTrust: %v", err) + } + if result == nil { + t.Errorf("result is nil") + } +} + +// TestDriverRotateKey covers RotateKey's JSON-RPC roundtrip. 
+func TestDriverRotateKey(t *testing.T) {
+	t.Parallel()
+	d := newFakeDaemon(t)
+	defer d.close()
+
+	d.onCmd(cmdRotateKey, func(_ []byte) [][]byte {
+		// Reply shape is [cmdRotateKeyOK][JSON body].
+		const body = `{"old_node_id":1,"new_node_id":2}`
+		return [][]byte{append([]byte{cmdRotateKeyOK}, body...)}
+	})
+
+	drv, err := Connect(d.path)
+	if err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	defer drv.Close()
+
+	switch result, err := drv.RotateKey(); {
+	case err != nil:
+		t.Fatalf("RotateKey: %v", err)
+	case result == nil:
+		t.Errorf("RotateKey result is nil")
+	}
+}
diff --git a/driver/zz_driver_test.go b/driver/zz_driver_test.go
new file mode 100644
index 0000000..e6e6dd4
--- /dev/null
+++ b/driver/zz_driver_test.go
@@ -0,0 +1,739 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package driver
+
+import (
+	"crypto/rand"
+	"encoding/binary"
+	"encoding/hex"
+	"net"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/pilot-protocol/common/ipcutil"
+	"github.com/pilot-protocol/common/protocol"
+)
+
+// shortSocketPath builds a unique socket path directly under /tmp, named
+// "ps-<12 hex chars>.sock", and registers a cleanup that removes it.
+// /tmp is used on purpose: macOS caps unix socket paths at roughly 104
+// chars, and t.TempDir() paths exceed that on darwin.
+func shortSocketPath(t *testing.T) string {
+	t.Helper()
+	suffix := make([]byte, 6)
+	if _, err := rand.Read(suffix); err != nil {
+		t.Fatal(err)
+	}
+	sockPath := filepath.Join("/tmp", "ps-"+hex.EncodeToString(suffix)+".sock")
+	t.Cleanup(func() {
+		// Best-effort removal; the error is deliberately ignored.
+		_ = os.Remove(sockPath)
+	})
+	return sockPath
+}
+
+// fakeDaemon is a minimal test harness that simulates the Pilot daemon's
+// IPC wire protocol. It listens on a unix socket, records incoming frames,
+// and replies with configured responses. Sufficient for verifying each
+// driver.* method's request encoding and response decoding end-to-end.
+type fakeDaemon struct { + t *testing.T + ln net.Listener + path string + conn net.Conn + connSet chan struct{} // closed once conn is stored in acceptLoop + mu sync.Mutex + received [][]byte // all frames received + handlers map[byte]func(frame []byte) [][]byte +} + +func newFakeDaemon(t *testing.T) *fakeDaemon { + t.Helper() + path := shortSocketPath(t) + ln, err := net.Listen("unix", path) + if err != nil { + t.Fatalf("listen unix: %v", err) + } + d := &fakeDaemon{ + t: t, + ln: ln, + path: path, + connSet: make(chan struct{}), + handlers: make(map[byte]func(frame []byte) [][]byte), + } + go d.acceptLoop() + return d +} + +func (d *fakeDaemon) acceptLoop() { + conn, err := d.ln.Accept() + if err != nil { + return + } + d.mu.Lock() + d.conn = conn + d.mu.Unlock() + close(d.connSet) // signal that conn is stored and ready to be closed + + // Wire format: [cmd(1)][payload...] — matches driver.ipcEnvelopeHeaderSize. + for { + frame, err := ipcutil.Read(conn) + if err != nil { + return + } + d.mu.Lock() + var resp [][]byte + if len(frame) >= 1 { + cmd := frame[0] + d.received = append(d.received, frame) + if h, ok := d.handlers[cmd]; ok { + resp = h(frame) + } + } + d.mu.Unlock() + for _, r := range resp { + _ = ipcutil.Write(conn, r) + } + } +} + +func (d *fakeDaemon) onCmd(cmd byte, respond func(frame []byte) [][]byte) { + d.mu.Lock() + defer d.mu.Unlock() + d.handlers[cmd] = respond +} + +func (d *fakeDaemon) lastFrame() []byte { + d.mu.Lock() + defer d.mu.Unlock() + if len(d.received) == 0 { + return nil + } + return d.received[len(d.received)-1] +} + +func (d *fakeDaemon) allFrames() [][]byte { + d.mu.Lock() + defer d.mu.Unlock() + out := make([][]byte, len(d.received)) + copy(out, d.received) + return out +} + +func (d *fakeDaemon) closeConn() { + d.mu.Lock() + c := d.conn + d.mu.Unlock() + if c != nil { + _ = c.Close() + } +} + +func (d *fakeDaemon) close() { + d.ln.Close() + // Wait for acceptLoop to store d.conn before closing it. 
+ // Without this, close() races with acceptLoop: d.conn may still be + // nil when closeConn() runs, leaving the accepted socket open and + // blocking the driver's readLoop indefinitely. + select { + case <-d.connSet: + case <-time.After(100 * time.Millisecond): + } + d.closeConn() +} + +// waitFor polls until cond is true or deadline is reached. +func waitFor(t *testing.T, max time.Duration, cond func() bool, what string) { + t.Helper() + deadline := time.Now().Add(max) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("timeout waiting for %s", what) +} + +// jsonOK returns a [cmd][json-body] frame. +func jsonOK(cmd byte, body string) []byte { + out := make([]byte, 1+len(body)) + out[0] = cmd + copy(out[1:], body) + return out +} + +// ---------- Connect / Close ---------- + +func TestConnectNonExistentSocketReturnsError(t *testing.T) { + t.Parallel() + _, err := Connect("/tmp/definitely-not-a-real-pilot-socket-xxx.sock") + if err == nil { + t.Fatal("expected error") + } +} + +func TestConnectEmptySocketFallsBackToDefault(t *testing.T) { + t.Parallel() + // DefaultSocketPath is /tmp/pilot.sock — almost certainly not present + // in a test env. We just assert the fall-through path is taken and + // returns an error (no panic on empty input). 
+ _, err := Connect("") + if err == nil { + t.Log("Connect(\"\") succeeded — a daemon is running on default path; not an error") + return + } +} + +func TestConnectAndCloseSuccess(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("Connect: %v", err) + } + if drv.socketPath != d.path { + t.Errorf("socketPath = %q, want %q", drv.socketPath, d.path) + } + if err := drv.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} + +// ---------- DialAddr / Dial ---------- + +func TestDialAddrHappyPath(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + d.onCmd(cmdDial, func(frame []byte) [][]byte { + resp := make([]byte, 1+4) + resp[0] = cmdDialOK + binary.BigEndian.PutUint32(resp[1:5], 0xDEADBEEF) + return [][]byte{resp} + }) + + drv, err := Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + dst := protocol.Addr{Network: 1, Node: 0x0102_0304} + conn, err := drv.DialAddr(dst, 7) + if err != nil { + t.Fatalf("DialAddr: %v", err) + } + if conn.id != 0xDEADBEEF { + t.Errorf("conn.id = %#x, want 0xDEADBEEF", conn.id) + } + if conn.remoteAddr.Addr != dst || conn.remoteAddr.Port != 7 { + t.Errorf("remoteAddr = %+v, want {%+v, 7}", conn.remoteAddr, dst) + } +} + +func TestDialParsesAddressString(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdDial, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdDialOK + binary.BigEndian.PutUint32(resp[1:5], 42) + return [][]byte{resp} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + conn, err := drv.Dial("1:0001.AAAA.BBBB:80") + if err != nil { + t.Fatalf("Dial: %v", err) + } + if conn.id != 42 { + t.Errorf("id = %d, want 42", conn.id) + } +} + +func TestDialBadAddressReturnsParseError(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, _ := Connect(d.path) + defer drv.Close() + if _, err := 
drv.Dial("not-a-valid-addr"); err == nil { + t.Fatal("expected parse error") + } +} + +func TestDialAddrTimeoutFiresWhenDaemonSilent(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + // No handler for cmdDial → daemon never responds + drv, _ := Connect(d.path) + defer drv.Close() + + start := time.Now() + _, err := drv.DialAddrTimeout(protocol.Addr{Network: 1, Node: 1}, 1, 100*time.Millisecond) + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected timeout error") + } + if elapsed < 80*time.Millisecond || elapsed > 500*time.Millisecond { + t.Errorf("elapsed = %v (expected ~100ms)", elapsed) + } +} + +// ---------- Listen ---------- + +func TestListenHappyPath(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + d.onCmd(cmdBind, func(frame []byte) [][]byte { + resp := make([]byte, 3) + resp[0] = cmdBindOK + binary.BigEndian.PutUint16(resp[1:3], 7) // echoed port + return [][]byte{resp} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + ln, err := drv.Listen(7) + if err != nil { + t.Fatalf("Listen: %v", err) + } + if ln.port != 7 { + t.Errorf("port = %d, want 7", ln.port) + } + _ = ln.Close() +} + +// ---------- SendTo / RecvFrom ---------- + +func TestSendToWritesFrame(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, _ := Connect(d.path) + defer drv.Close() + + dst := protocol.Addr{Network: 2, Node: 0x0A0B_0C0D} + if err := drv.SendTo(dst, 100, []byte("hi")); err != nil { + t.Fatalf("SendTo: %v", err) + } + + waitFor(t, time.Second, func() bool { + return d.lastFrame() != nil + }, "daemon to receive frame") + frame := d.lastFrame() + // d.received stores frames as-is: [cmd(1)][body...]. 
+ if frame[0] != cmdSendTo { + t.Errorf("cmd = %#x, want %#x", frame[0], cmdSendTo) + } + if len(frame) != 1+protocol.AddrSize+2+2 { + t.Errorf("len = %d", len(frame)) + } + gotPort := binary.BigEndian.Uint16(frame[1+protocol.AddrSize:]) + if gotPort != 100 { + t.Errorf("port = %d, want 100", gotPort) + } + if string(frame[1+protocol.AddrSize+2:]) != "hi" { + t.Errorf("payload = %q", frame[1+protocol.AddrSize+2:]) + } +} + +func TestRecvFromDeliversDatagram(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, _ := Connect(d.path) + defer drv.Close() + + // Inject a cmdRecvFrom frame from the daemon + src := protocol.Addr{Network: 1, Node: 0x1122_3344} + payload := make([]byte, protocol.AddrSize+4+3) + src.MarshalTo(payload, 0) + binary.BigEndian.PutUint16(payload[protocol.AddrSize:], 200) + binary.BigEndian.PutUint16(payload[protocol.AddrSize+2:], 300) + copy(payload[protocol.AddrSize+4:], "abc") + frame := append([]byte{cmdRecvFrom}, payload...) + + // Use pushFromDaemon to write the frame through the daemon-side conn. + pushFromDaemon(t, d, frame) + + dg, err := drv.RecvFrom() + if err != nil { + t.Fatalf("RecvFrom: %v", err) + } + if dg.SrcAddr != src || dg.SrcPort != 200 || dg.DstPort != 300 || string(dg.Data) != "abc" { + t.Errorf("got %+v", dg) + } +} + +func TestRecvFromErrorAfterClose(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + drv, _ := Connect(d.path) + + // Close the daemon so readLoop exits and drains dgCh + d.close() + + // Give the readLoop time to exit + waitFor(t, time.Second, func() bool { + select { + case <-drv.ipc.doneCh: + return true + default: + return false + } + }, "readLoop exit") + + // dgCh is not explicitly closed; RecvFrom blocks until dgCh closes OR + // until we push. Since it's buffered but not closed, this would hang. + // Instead we verify the doneCh path by calling Close on the driver. 
+ _ = drv.Close() +} + +// ---------- Info / Health ---------- + +func TestInfoAndHealthReturnParsedJSON(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdInfo, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdInfoOK, `{"node_id": 42, "addr": "1:0001.0002.0003"}`)} + }) + d.onCmd(cmdHealth, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdHealthOK, `{"ok": true}`)} + }) + drv, _ := Connect(d.path) + defer drv.Close() + + info, err := drv.Info() + if err != nil { + t.Fatalf("Info: %v", err) + } + if info["node_id"].(float64) != 42 { + t.Errorf("node_id = %v", info["node_id"]) + } + + h, err := drv.Health() + if err != nil { + t.Fatalf("Health: %v", err) + } + if h["ok"] != true { + t.Errorf("ok = %v", h["ok"]) + } +} + +func TestJsonRPCUnmarshalErrorSurfaced(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdInfo, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdInfoOK, `not-json`)} + }) + drv, _ := Connect(d.path) + defer drv.Close() + + if _, err := drv.Info(); err == nil { + t.Fatal("expected unmarshal error") + } +} + +func TestSendAndWaitSurfacesDaemonErrorFrame(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdInfo, func(_ []byte) [][]byte { + // cmdError frame: first byte cmdError, then 2 bytes code, then msg + body := []byte{cmdError, 0, 0} + body = append(body, []byte("boom")...) 
+ return [][]byte{body} + }) + drv, _ := Connect(d.path) + defer drv.Close() + + _, err := drv.Info() + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("err = %v, want boom", err) + } +} + +// ---------- Handshake family ---------- + +func TestHandshakeFamilyRoundTrips(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdHandshake, func(frame []byte) [][]byte { + return [][]byte{jsonOK(cmdHandshakeOK, `{"ok": true}`)} + }) + drv, _ := Connect(d.path) + defer drv.Close() + + if _, err := drv.Handshake(99, "please"); err != nil { + t.Fatalf("Handshake: %v", err) + } + if _, err := drv.ApproveHandshake(100); err != nil { + t.Fatalf("Approve: %v", err) + } + if _, err := drv.RejectHandshake(101, "no"); err != nil { + t.Fatalf("Reject: %v", err) + } + if _, err := drv.PendingHandshakes(); err != nil { + t.Fatalf("Pending: %v", err) + } + if _, err := drv.TrustedPeers(); err != nil { + t.Fatalf("Trusted: %v", err) + } + if _, err := drv.RevokeTrust(102); err != nil { + t.Fatalf("Revoke: %v", err) + } + + frames := d.allFrames() + if len(frames) != 6 { + t.Fatalf("expected 6 handshake frames, got %d", len(frames)) + } + expectSub := []byte{subHandshakeSend, subHandshakeApprove, subHandshakeReject, + subHandshakePending, subHandshakeTrusted, subHandshakeRevoke} + for i, want := range expectSub { + if frames[i][0] != cmdHandshake || frames[i][1] != want { + t.Errorf("frame[%d] = %v, want cmd=%#x sub=%#x", i, frames[i][:2], cmdHandshake, want) + } + } +} + +// ---------- Registry-modifying wrappers ---------- + +func TestRegistryWrappersEncodeCorrectly(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + + okCommands := map[byte]byte{ + cmdResolveHostname: cmdResolveHostnameOK, + cmdSetHostname: cmdSetHostnameOK, + cmdSetVisibility: cmdSetVisibilityOK, + cmdDeregister: cmdDeregisterOK, + cmdSetTags: cmdSetTagsOK, + cmdSetWebhook: cmdSetWebhookOK, + } + for req, ok := range okCommands { + req, ok := 
req, ok + d.onCmd(req, func(_ []byte) [][]byte { + return [][]byte{jsonOK(ok, `{"ok":true}`)} + }) + } + + drv, _ := Connect(d.path) + defer drv.Close() + + if _, err := drv.ResolveHostname("myhost"); err != nil { + t.Fatalf("ResolveHostname: %v", err) + } + if _, err := drv.SetHostname("myhost"); err != nil { + t.Fatalf("SetHostname: %v", err) + } + if _, err := drv.SetVisibility(true); err != nil { + t.Fatalf("SetVisibility: %v", err) + } + if _, err := drv.Deregister(); err != nil { + t.Fatalf("Deregister: %v", err) + } + if _, err := drv.SetTags([]string{"a", "b"}); err != nil { + t.Fatalf("SetTags: %v", err) + } + if _, err := drv.SetWebhook("https://x/y"); err != nil { + t.Fatalf("SetWebhook: %v", err) + } + + // Check visibility byte=1 for enabled + for _, f := range d.allFrames() { + switch f[0] { + case cmdSetVisibility: + if f[1] != 1 { + t.Errorf("visibility byte = %d, want 1", f[1]) + } + case cmdResolveHostname: + if string(f[1:]) != "myhost" { + t.Errorf("ResolveHostname host = %q", f[1:]) + } + case cmdSetWebhook: + if string(f[1:]) != "https://x/y" { + t.Errorf("SetWebhook url = %q", f[1:]) + } + } + } +} + +func TestSetVisibilityFalsePath(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdSetVisibility, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdSetVisibilityOK, `{}`)} + }) + drv, _ := Connect(d.path) + defer drv.Close() + + if _, err := drv.SetVisibility(false); err != nil { + t.Fatal(err) + } + + frames := d.allFrames() + if frames[0][1] != 0 { + t.Errorf("visibility false byte = %d, want 0", frames[0][1]) + } +} + +// ---------- Disconnect / cmdClose ---------- + +func TestDisconnectSendsCmdClose(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdClose, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdCloseOK + binary.BigEndian.PutUint32(resp[1:5], 77) + return [][]byte{resp} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + if err := 
drv.Disconnect(77); err != nil { + t.Fatalf("Disconnect: %v", err) + } + // Disconnect is fire-and-forget; wait for the daemon to receive the frame. + waitFor(t, time.Second, func() bool { return d.lastFrame() != nil }, "daemon to receive cmdClose") + frame := d.lastFrame() + if frame[0] != cmdClose { + t.Errorf("cmd = %#x, want %#x", frame[0], cmdClose) + } + if connID := binary.BigEndian.Uint32(frame[1:5]); connID != 77 { + t.Errorf("connID = %d", connID) + } +} + +// ---------- Network family ---------- + +func TestNetworkFamilyRoundTrips(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdNetwork, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdNetworkOK, `{"ok":true}`)} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + if _, err := drv.NetworkList(); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkJoin(5, "token"); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkLeave(5); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkMembers(5); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkInvite(5, 100); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkPollInvites(); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkRespondInvite(5, true); err != nil { + t.Fatal(err) + } + if _, err := drv.NetworkRespondInvite(5, false); err != nil { + t.Fatal(err) + } + + frames := d.allFrames() + wantSubs := []byte{subNetworkList, subNetworkJoin, subNetworkLeave, subNetworkMembers, + subNetworkInvite, subNetworkPollInvites, subNetworkRespondInvite, subNetworkRespondInvite} + if len(frames) != len(wantSubs) { + t.Fatalf("got %d frames, want %d", len(frames), len(wantSubs)) + } + for i, want := range wantSubs { + if frames[i][1] != want { + t.Errorf("frame[%d] sub = %#x, want %#x", i, frames[i][1], want) + } + } + // Respond-invite accept vs reject byte + // Accept frame is 7th (index 6), reject is 8th (index 7) + if frames[6][4] != 1 { + t.Errorf("accept byte = %d, want 1", 
frames[6][4]) + } + if frames[7][4] != 0 { + t.Errorf("reject byte = %d, want 0", frames[7][4]) + } +} + +// ---------- Managed family ---------- + +func TestManagedFamilyRoundTrips(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + d.onCmd(cmdManaged, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdManagedOK, `{"ok":true}`)} + }) + + drv, _ := Connect(d.path) + defer drv.Close() + + if _, err := drv.ManagedStatus(5); err != nil { + t.Fatal(err) + } + if _, err := drv.ManagedForceCycle(5); err != nil { + t.Fatal(err) + } + if _, err := drv.PolicyGet(5); err != nil { + t.Fatal(err) + } + if _, err := drv.PolicySet(5, []byte(`{"version":1}`)); err != nil { + t.Fatal(err) + } + if _, err := drv.MemberTagsGet(5, 99); err != nil { + t.Fatal(err) + } + if _, err := drv.MemberTagsSet(5, 99, []string{"x", "y"}); err != nil { + t.Fatal(err) + } + + frames := d.allFrames() + wantSubs := []byte{subManagedStatus, subManagedCycle, + subManagedPolicy, subManagedPolicy, subManagedMemberTags, subManagedMemberTags} + for i, want := range wantSubs { + if frames[i][1] != want { + t.Errorf("frame[%d] sub = %#x, want %#x", i, frames[i][1], want) + } + } + // PolicyGet action byte is 0x00, PolicySet 0x01 + if frames[2][2] != 0x00 { + t.Errorf("PolicyGet action byte = %#x, want 0x00", frames[2][2]) + } + if frames[3][2] != 0x01 { + t.Errorf("PolicySet action byte = %#x, want 0x01", frames[3][2]) + } + // MemberTagsGet 0x00, Set 0x01 + if frames[4][2] != 0x00 { + t.Errorf("MemberTagsGet action byte = %#x", frames[4][2]) + } + if frames[5][2] != 0x01 { + t.Errorf("MemberTagsSet action byte = %#x", frames[5][2]) + } +} diff --git a/driver/zz_ipc_listener_test.go b/driver/zz_ipc_listener_test.go new file mode 100644 index 0000000..6cf7db2 --- /dev/null +++ b/driver/zz_ipc_listener_test.go @@ -0,0 +1,510 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package driver + +import ( + "encoding/binary" + "testing" + "time" + + 
"github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/protocol" +) + +// pushFromDaemon writes an unsolicited frame from the fakeDaemon side to +// exercise driver.readLoop dispatch. Waits briefly for the daemon conn to +// be accepted first. +// +// frame must be [cmd][payload...]. Wire format is [cmd][payload...] with no reqID. +func pushFromDaemon(t *testing.T, d *fakeDaemon, frame []byte) { + t.Helper() + waitFor(t, 2*time.Second, func() bool { + d.mu.Lock() + defer d.mu.Unlock() + return d.conn != nil + }, "daemon accept") + d.mu.Lock() + conn := d.conn + d.mu.Unlock() + if err := ipcutil.Write(conn, frame); err != nil { + t.Fatalf("write from daemon: %v", err) + } +} + +// ---------- readLoop dispatch ---------- + +func TestReadLoopRecvDeliversToRegisteredChannel(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + connID := uint32(42) + ch := drv.ipc.registerRecvCh(connID) + + frame := make([]byte, 1+4+5) + frame[0] = cmdRecv + binary.BigEndian.PutUint32(frame[1:5], connID) + copy(frame[5:], "hello") + pushFromDaemon(t, d, frame) + + select { + case data := <-ch: + if string(data) != "hello" { + t.Errorf("got %q, want hello", data) + } + case <-time.After(2 * time.Second): + t.Fatal("no data delivered to recvCh") + } +} + +func TestReadLoopRecvBuffersWhenChannelNotRegistered(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + connID := uint32(99) + frame := make([]byte, 1+4+3) + frame[0] = cmdRecv + binary.BigEndian.PutUint32(frame[1:5], connID) + copy(frame[5:], "buf") + pushFromDaemon(t, d, frame) + + // Wait for readLoop to process the frame (no recvCh registered yet). 
+ waitFor(t, 2*time.Second, func() bool { + drv.ipc.recvMu.Lock() + defer drv.ipc.recvMu.Unlock() + return len(drv.ipc.pendRecv[connID]) == 1 + }, "pendRecv buffered") + + // Registering now should drain the buffered data. + ch := drv.ipc.registerRecvCh(connID) + select { + case data := <-ch: + if string(data) != "buf" { + t.Errorf("got %q, want buf", data) + } + case <-time.After(1 * time.Second): + t.Fatal("registerRecvCh did not drain pendRecv") + } +} + +func TestReadLoopRecvShortPayloadDropped(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // <4 bytes after cmd byte → dropped. + pushFromDaemon(t, d, []byte{cmdRecv, 0x01}) + + // Sanity: no crash, no data buffered, no recv channels created. + time.Sleep(50 * time.Millisecond) + drv.ipc.recvMu.Lock() + defer drv.ipc.recvMu.Unlock() + if len(drv.ipc.pendRecv) != 0 { + t.Errorf("short cmdRecv should not buffer, got %d entries", len(drv.ipc.pendRecv)) + } +} + +func TestReadLoopCloseOKClosesRegisteredChannel(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + connID := uint32(7) + ch := drv.ipc.registerRecvCh(connID) + + frame := make([]byte, 1+4) + frame[0] = cmdCloseOK + binary.BigEndian.PutUint32(frame[1:], connID) + pushFromDaemon(t, d, frame) + + select { + case _, ok := <-ch: + if ok { + t.Error("expected channel closed, got value") + } + case <-time.After(2 * time.Second): + t.Fatal("channel not closed") + } + + drv.ipc.recvMu.Lock() + _, stillThere := drv.ipc.recvChs[connID] + drv.ipc.recvMu.Unlock() + if stillThere { + t.Error("recvCh entry should be deleted") + } +} + +func TestReadLoopCloseOKShortPayloadDropped(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + 
t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // payload < 4 — must not panic, must not disturb recvChs. + connID := uint32(8) + ch := drv.ipc.registerRecvCh(connID) + pushFromDaemon(t, d, []byte{cmdCloseOK, 0x00}) + + time.Sleep(50 * time.Millisecond) + select { + case <-ch: + t.Error("channel should not close on short CloseOK") + default: + } +} + +func TestReadLoopRecvFromDeliversDatagram(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + srcAddr, err := protocol.ParseAddr("1:0001.AAAA.BBBB") + if err != nil { + t.Fatal(err) + } + frame := make([]byte, 1+protocol.AddrSize+4+5) + frame[0] = cmdRecvFrom + srcAddr.MarshalTo(frame, 1) + binary.BigEndian.PutUint16(frame[1+protocol.AddrSize:], 111) + binary.BigEndian.PutUint16(frame[1+protocol.AddrSize+2:], 222) + copy(frame[1+protocol.AddrSize+4:], "ping!") + pushFromDaemon(t, d, frame) + + dg, err := drv.RecvFrom() + if err != nil { + t.Fatalf("RecvFrom: %v", err) + } + if dg.SrcPort != 111 || dg.DstPort != 222 || string(dg.Data) != "ping!" { + t.Errorf("datagram = %+v, data=%q", dg, string(dg.Data)) + } + if dg.SrcAddr != srcAddr { + t.Errorf("src addr = %v, want %v", dg.SrcAddr, srcAddr) + } +} + +func TestReadLoopRecvFromShortPayloadDropped(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // AddrSize=6, need +4 for ports — send just 5 bytes payload → drop. + pushFromDaemon(t, d, append([]byte{cmdRecvFrom}, make([]byte, 5)...)) + + // If it were dispatched, it'd land on dgCh. Confirm nothing arrives. 
+ select { + case dg := <-drv.ipc.dgCh: + t.Errorf("unexpected datagram: %+v", dg) + case <-time.After(100 * time.Millisecond): + } +} + +func TestReadLoopAcceptShortPayloadDropped(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // <2 bytes after cmd byte → dropped. + pushFromDaemon(t, d, []byte{cmdAccept, 0x01}) + time.Sleep(50 * time.Millisecond) // assertion: no crash +} + +func TestReadLoopEmptyFrameContinues(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // Zero-length frame is skipped; readLoop must keep running. + pushFromDaemon(t, d, []byte{}) + + // Follow it with a valid cmdRecvFrom to prove readLoop is still alive. + srcAddr, err := protocol.ParseAddr("1:0001.CCCC.DDDD") + if err != nil { + t.Fatal(err) + } + frame := make([]byte, 1+protocol.AddrSize+4+2) + frame[0] = cmdRecvFrom + srcAddr.MarshalTo(frame, 1) + copy(frame[1+protocol.AddrSize+4:], "ok") + pushFromDaemon(t, d, frame) + + if _, err := drv.RecvFrom(); err != nil { + t.Fatalf("readLoop died after empty frame: %v", err) + } +} + +func TestReadLoopUnknownCmdWithNoHandlerDropped(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // cmd 0xFE not in any handler map — readLoop default branch, no waiter, drop. + pushFromDaemon(t, d, []byte{0xFE, 0x01, 0x02}) + + // Prove readLoop still alive by exchanging Info. 
+ d.onCmd(cmdInfo, func(_ []byte) [][]byte { + return [][]byte{jsonOK(cmdInfoOK, `{"ok":true}`)} + }) + if _, err := drv.Info(); err != nil { + t.Fatalf("Info after unknown cmd: %v", err) + } +} + +// ---------- Listener.Accept branches ---------- + +func TestListenerAcceptDeliversConn(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // Bind port 5000. + d.onCmd(cmdBind, func(frame []byte) [][]byte { + return [][]byte{{cmdBindOK, frame[1], frame[2]}} + }) + ln, err := drv.Listen(5000) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // Push an unsolicited cmdAccept: [port][connID][addr][port] + remoteAddr, _ := protocol.ParseAddr("1:0002.1111.2222") + payload := make([]byte, 2+4+protocol.AddrSize+2) + binary.BigEndian.PutUint16(payload[0:2], 5000) + binary.BigEndian.PutUint32(payload[2:6], 1234) + remoteAddr.MarshalTo(payload, 6) + binary.BigEndian.PutUint16(payload[6+protocol.AddrSize:], 99) + pushFromDaemon(t, d, append([]byte{cmdAccept}, payload...)) + + // Accept should return a Conn with the parsed fields. 
+ done := make(chan *Conn, 1) + errCh := make(chan error, 1) + go func() { + c, err := ln.Accept() + if err != nil { + errCh <- err + return + } + done <- c.(*Conn) + }() + + select { + case c := <-done: + if c.id != 1234 { + t.Errorf("conn.id = %d, want 1234", c.id) + } + if c.remoteAddr.Port != 99 { + t.Errorf("remote port = %d, want 99", c.remoteAddr.Port) + } + if c.remoteAddr.Addr != remoteAddr { + t.Errorf("remote addr = %v, want %v", c.remoteAddr.Addr, remoteAddr) + } + case err := <-errCh: + t.Fatalf("Accept: %v", err) + case <-time.After(2 * time.Second): + t.Fatal("Accept did not complete") + } +} + +func TestListenerAcceptInvalidPayloadReturnsError(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + d.onCmd(cmdBind, func(frame []byte) [][]byte { + return [][]byte{{cmdBindOK, frame[1], frame[2]}} + }) + ln, err := drv.Listen(5001) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // cmdAccept with port=5001 but truncated tail. + pushFromDaemon(t, d, []byte{cmdAccept, 0x13, 0x89, 0x00}) + + _, err = ln.Accept() + if err == nil { + t.Fatal("expected invalid payload error") + } +} + +func TestListenerAcceptUnblocksOnClose(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + d.onCmd(cmdBind, func(frame []byte) [][]byte { + return [][]byte{{cmdBindOK, frame[1], frame[2]}} + }) + ln, err := drv.Listen(5002) + if err != nil { + t.Fatalf("Listen: %v", err) + } + + errCh := make(chan error, 1) + go func() { + _, err := ln.Accept() + errCh <- err + }() + + // Give Accept a moment to enter the select, then close. 
+ time.Sleep(50 * time.Millisecond) + _ = ln.Close() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected error after Close") + } + case <-time.After(2 * time.Second): + t.Fatal("Accept did not unblock on Close") + } +} + +// ---------- ipcClient helpers ---------- + +func TestUnregisterRecvChRemovesEntry(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + connID := uint32(55) + _ = drv.ipc.registerRecvCh(connID) + + drv.ipc.recvMu.Lock() + if _, ok := drv.ipc.recvChs[connID]; !ok { + drv.ipc.recvMu.Unlock() + t.Fatal("recvCh not present after registerRecvCh") + } + drv.ipc.recvMu.Unlock() + + drv.ipc.unregisterRecvCh(connID) + + drv.ipc.recvMu.Lock() + _, ok := drv.ipc.recvChs[connID] + drv.ipc.recvMu.Unlock() + if ok { + t.Error("recvCh still present after unregisterRecvCh") + } +} + +func TestSendAndWaitTimeoutFires(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + defer d.close() + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + // No handler for cmdInfo — daemon accepts the frame and never replies. + start := time.Now() + _, err = drv.ipc.sendAndWaitTimeout([]byte{cmdInfo}, cmdInfoOK, 80*time.Millisecond) + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected timeout error") + } + if elapsed < 50*time.Millisecond || elapsed > 500*time.Millisecond { + t.Errorf("unexpected elapsed %v (want ~80ms)", elapsed) + } +} + +func TestSendAndWaitReturnsWhenDaemonDisconnects(t *testing.T) { + t.Parallel() + d := newFakeDaemon(t) + drv, err := Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + defer drv.Close() + + errCh := make(chan error, 1) + go func() { + _, err := drv.ipc.sendAndWait([]byte{cmdInfo}, cmdInfoOK) + errCh <- err + }() + + // Give the request time to enqueue its handler, then yank the daemon. 
+ time.Sleep(50 * time.Millisecond) + d.close() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected error on daemon disconnect") + } + case <-time.After(2 * time.Second): + t.Fatal("sendAndWait did not unblock on disconnect") + } +} diff --git a/go.mod b/go.mod index 85ae1fe..933a51d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/pilot-protocol/common -go 1.25.3 +go 1.25.10 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/ipcutil/ipcutil.go b/ipcutil/ipcutil.go new file mode 100644 index 0000000..5939d86 --- /dev/null +++ b/ipcutil/ipcutil.go @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package ipcutil + +import ( + "encoding/binary" + "fmt" + "io" +) + +// MaxMessageSize is the maximum IPC message size (1MB). +const MaxMessageSize = 1 << 20 + +// Read reads a length-prefixed IPC message from r. +func Read(r io.Reader) ([]byte, error) { + var lenBuf [4]byte + if _, err := io.ReadFull(r, lenBuf[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(lenBuf[:]) + if length > MaxMessageSize { + return nil, fmt.Errorf("ipc message too large: %d bytes (max %d)", length, MaxMessageSize) + } + buf := make([]byte, length) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return buf, nil +} + +// Write writes a length-prefixed IPC message to w. 
+//
+// Payloads larger than MaxMessageSize are rejected before anything is
+// written: Read refuses such frames, and the 4-byte length prefix could
+// not represent a payload whose length overflows uint32.
+// The prefix and the payload are issued as two separate writes to w.
+func Write(w io.Writer, data []byte) error {
+	if len(data) > MaxMessageSize {
+		return fmt.Errorf("ipc message too large: %d bytes (max %d)", len(data), MaxMessageSize)
+	}
+	var lenBuf [4]byte
+	binary.BigEndian.PutUint32(lenBuf[:], uint32(len(data)))
+	if _, err := w.Write(lenBuf[:]); err != nil {
+		return err
+	}
+	_, err := w.Write(data)
+	return err
+}
diff --git a/ipcutil/zz_test.go b/ipcutil/zz_test.go
new file mode 100644
index 0000000..6df60be
--- /dev/null
+++ b/ipcutil/zz_test.go
@@ -0,0 +1,137 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package ipcutil
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+	"strings"
+	"testing"
+)
+
+func TestReadWriteRoundTrip(t *testing.T) {
+	t.Parallel()
+	for _, payload := range [][]byte{
+		nil,
+		{},
+		[]byte("hello"),
+		bytes.Repeat([]byte{0xAB}, 10000),
+	} {
+		var buf bytes.Buffer
+		if err := Write(&buf, payload); err != nil {
+			t.Fatalf("Write(%d bytes): %v", len(payload), err)
+		}
+		got, err := Read(&buf)
+		if err != nil {
+			t.Fatalf("Read: %v", err)
+		}
+		if !bytes.Equal(got, payload) {
+			t.Fatalf("round-trip mismatch: got %d bytes, want %d bytes", len(got), len(payload))
+		}
+	}
+}
+
+func TestWriteLengthPrefix(t *testing.T) {
+	t.Parallel()
+	var buf bytes.Buffer
+	payload := []byte("abcde")
+	if err := Write(&buf, payload); err != nil {
+		t.Fatal(err)
+	}
+	if buf.Len() != 4+len(payload) {
+		t.Fatalf("buf len = %d, want %d", buf.Len(), 4+len(payload))
+	}
+	length := binary.BigEndian.Uint32(buf.Bytes()[:4])
+	if length != uint32(len(payload)) {
+		t.Fatalf("length prefix = %d, want %d", length, len(payload))
+	}
+}
+
+func TestReadTooLargeRejected(t *testing.T) {
+	t.Parallel()
+	var buf bytes.Buffer
+	// Write length prefix claiming > MaxMessageSize
+	lenBuf := make([]byte, 4)
+	binary.BigEndian.PutUint32(lenBuf, MaxMessageSize+1)
+	buf.Write(lenBuf)
+
+	_, err := Read(&buf)
+	if err == nil {
+		t.Fatal("expected too-large error")
+	}
+	if !strings.Contains(err.Error(), "too large") {
+		t.Fatalf("error %q missing 'too large'", err)
+	}
+}
+
+func TestReadExactlyMaxSizeAccepted(t *testing.T) {
+	t.Parallel()
+	var buf bytes.Buffer
+	// Length exactly == max, followed by that many zero bytes
+	data := make([]byte, MaxMessageSize)
+	if err := Write(&buf, data); err != nil {
+		t.Fatal(err)
+	}
+	got, err := Read(&buf)
+	if err != nil {
+		t.Fatalf("max-size read should succeed: %v", err)
+	}
+	if len(got) != MaxMessageSize {
+		t.Fatalf("len = %d, want %d", len(got), MaxMessageSize)
+	}
+}
+
+func TestReadTruncatedLength(t *testing.T) {
+	t.Parallel()
+	buf := bytes.NewReader([]byte{0x00, 0x00}) // only 2 bytes of length prefix
+	_, err := Read(buf)
+	if err == nil {
+		t.Fatal("expected error on truncated length prefix")
+	}
+}
+
+func TestReadTruncatedPayload(t *testing.T) {
+	t.Parallel()
+	var buf bytes.Buffer
+	lenBuf := make([]byte, 4)
+	binary.BigEndian.PutUint32(lenBuf, 100)
+	buf.Write(lenBuf)
+	buf.Write([]byte("only 20 bytes here..")) // payload truncated
+	_, err := Read(&buf)
+	if err == nil {
+		t.Fatal("expected truncation error")
+	}
+}
+
+// errWriter lets the first failAfter writes succeed and fails every write
+// after that — exercises the Write error paths.
+type errWriter struct {
+	failAfter int // number of leading Write calls that are allowed to succeed
+	calls     int // Write calls observed so far
+}
+
+// Write reports full success for the first failAfter calls; every later
+// call writes nothing and returns io.ErrShortWrite.
+func (w *errWriter) Write(p []byte) (int, error) {
+	w.calls++
+	if w.calls <= w.failAfter {
+		return len(p), nil
+	}
+	return 0, io.ErrShortWrite
+}
+
+func TestWriteErrorOnLengthPrefix(t *testing.T) {
+	t.Parallel()
+	// failAfter is zero, so the very first write (the length prefix) fails.
+	failing := &errWriter{}
+	if err := Write(failing, []byte("data")); err == nil {
+		t.Fatal("expected error from failing writer on length prefix")
+	}
+}
+
+func TestWriteErrorOnPayload(t *testing.T) {
+	t.Parallel()
+	// One write is allowed through: the length prefix succeeds and the
+	// payload write that follows fails.
+	failing := &errWriter{failAfter: 1}
+	if err := Write(failing, []byte("data")); err == nil {
+		t.Fatal("expected error from failing writer on payload")
+	}
+}
diff --git a/logging/logging.go b/logging/logging.go
new file mode 100644
index 0000000..7d92d8c
--- /dev/null
+++ b/logging/logging.go
@@ -0,0 +1,44 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package logging
+
+import (
+	"io"
+	"log/slog"
+	"os"
+	"strings"
+)
+
+// Setup installs the process-wide default slog logger, writing to
+// os.Stderr. It is shorthand for SetupWriter(os.Stderr, level, format);
+// level and format are passed through unchanged.
+func Setup(level, format string) { SetupWriter(os.Stderr, level, format) }
+
+// SetupWriter configures the default slog logger writing to w.
+//
+// level is matched case-insensitively against "debug", "warn"/"warning"
+// and "error"; any other value selects the info level. format is matched
+// case-insensitively against "json"; any other value selects the text
+// handler. The resulting logger replaces slog's package-global default.
+func SetupWriter(w io.Writer, level, format string) {
+	var lvl slog.Level
+	switch strings.ToLower(level) {
+	case "debug":
+		lvl = slog.LevelDebug
+	case "warn", "warning":
+		lvl = slog.LevelWarn
+	case "error":
+		lvl = slog.LevelError
+	default:
+		lvl = slog.LevelInfo
+	}
+
+	opts := &slog.HandlerOptions{Level: lvl}
+
+	var handler slog.Handler
+	switch strings.ToLower(format) {
+	case "json":
+		handler = slog.NewJSONHandler(w, opts)
+	default:
+		handler = slog.NewTextHandler(w, opts)
+	}
+
+	slog.SetDefault(slog.New(handler))
+}
diff --git a/logging/zz_logging_test.go b/logging/zz_logging_test.go
new file mode 100644
index 0000000..df01cf9
--- /dev/null
+++ b/logging/zz_logging_test.go
@@ -0,0 +1,161 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package logging_test
+
+import (
+	"bytes"
+	"encoding/json"
+	"log/slog"
+	"strings"
+	"testing"
+
+	"github.com/pilot-protocol/common/logging"
+)
+
+func TestSetupWriterJSONFormat(t *testing.T) {
+	// Cannot run in parallel because this mutates the package-global default logger.
+	saved := slog.Default()
+	defer slog.SetDefault(saved)
+
+	var buf bytes.Buffer
+	logging.SetupWriter(&buf, "info", "json")
+	slog.Info("hello", "k", "v")
+
+	line := strings.TrimSpace(buf.String())
+	if line == "" {
+		t.Fatal("no output emitted")
+	}
+	var m map[string]interface{}
+	if err := json.Unmarshal([]byte(line), &m); err != nil {
+		t.Fatalf("output not valid JSON: %v\n%s", err, line)
+	}
+	if m["msg"] != "hello" {
+		t.Errorf("msg = %v, want hello", m["msg"])
+	}
+	if m["k"] != "v" {
+		t.Errorf("attr k = %v, want v", m["k"])
+	}
+	if m["level"] != "INFO" {
+		t.Errorf("level = %v, want INFO", m["level"])
+	}
+}
+
+func TestSetupWriterTextFormat(t *testing.T) {
+	saved := slog.Default()
+	defer slog.SetDefault(saved)
+
+	var buf bytes.Buffer
+	logging.SetupWriter(&buf, "info", "text")
+	slog.Info("hello", "k", "v")
+
+	out := buf.String()
+	if out == "" {
+		t.Fatal("no output")
+	}
+	// text format output should NOT parse as JSON
+	if err := json.Unmarshal([]byte(strings.TrimSpace(out)), &map[string]interface{}{}); err == nil {
+		t.Fatalf("text output unexpectedly parsed as JSON: %s", out)
+	}
+	if !strings.Contains(out, "hello") || !strings.Contains(out, "k=v") {
+		t.Errorf("missing expected content: %s", out)
+	}
+}
+
+func TestSetupWriterDefaultFormatIsText(t *testing.T) {
+	saved := slog.Default()
+	defer slog.SetDefault(saved)
+
+	var buf bytes.Buffer
+	logging.SetupWriter(&buf, "info", "unknown-format")
+	slog.Info("msg")
+
+	out := strings.TrimSpace(buf.String())
+	// default (unknown format) → text handler, not JSON
+	if strings.HasPrefix(out, "{") {
+		t.Fatalf("unknown format should default to text, got JSON: %s", out)
+	}
+}
+
+func TestSetupWriterLevelsGateOutput(t *testing.T) {
+	cases := []struct {
+		level     string
+		wantDebug bool
+		wantInfo  bool
+		wantWarn  bool
+		wantError bool
+	}{
+		{"debug", true, true, true, true},
+		{"info", false, true, true, true},
+		{"warn", false, false, true, true},
+		{"warning", false, false, true, true},
+		{"error", false, false, false, true},
+		{"unknown", false, true, true, true}, // default → info
+	}
+	for _, tc := range cases {
+		t.Run(tc.level, func(t *testing.T) {
+			saved := slog.Default()
+			defer slog.SetDefault(saved)
+
+			var buf bytes.Buffer
+			logging.SetupWriter(&buf, tc.level, "text")
+
+			slog.Debug("D")
+			slog.Info("I")
+			slog.Warn("W")
+			slog.Error("E")
+
+			out := buf.String()
+			// Checked unconditionally, exactly like the other levels: a
+			// debug record leaking through a higher threshold must fail
+			// the test, not only a missing one at the debug threshold.
+			hasDebug := strings.Contains(out, "msg=D")
+			if hasDebug != tc.wantDebug {
+				t.Errorf("debug output present=%v, want=%v\n%s", hasDebug, tc.wantDebug, out)
+			}
+			hasInfo := strings.Contains(out, "msg=I")
+			if hasInfo != tc.wantInfo {
+				t.Errorf("info output present=%v, want=%v\n%s", hasInfo, tc.wantInfo, out)
+			}
+			hasWarn := strings.Contains(out, "msg=W")
+			if hasWarn != tc.wantWarn {
+				t.Errorf("warn output present=%v, want=%v\n%s", hasWarn, tc.wantWarn, out)
+			}
+			hasError := strings.Contains(out, "msg=E")
+			if hasError != tc.wantError {
+				t.Errorf("error output present=%v, want=%v\n%s", hasError, tc.wantError, out)
+			}
+		})
+	}
+}
+
+func TestSetupCaseInsensitive(t *testing.T) {
+	saved := slog.Default()
+	defer slog.SetDefault(saved)
+
+	var buf bytes.Buffer
+	logging.SetupWriter(&buf, "DEBUG", "JSON")
+	slog.Debug("dbg-msg")
+	line := strings.TrimSpace(buf.String())
+	if !strings.HasPrefix(line, "{") {
+		t.Fatalf("uppercase JSON should produce JSON, got: %s", line)
+	}
+	var m map[string]interface{}
+	if err := json.Unmarshal([]byte(line), &m); err != nil {
+		t.Fatalf("not JSON: %v", err)
+	}
+	if m["level"] != "DEBUG" {
+		t.Errorf("level = %v, want DEBUG", m["level"])
+	}
+}
+
+func TestSetupUsesStderrByDefault(t *testing.T) {
+	// Smoke test: Setup() picks the stderr path. We don't capture stderr (too
+	// invasive) but we verify the call doesn't panic and leaves slog.Default
+	// non-nil.
+	saved := slog.Default()
+	defer slog.SetDefault(saved)
+
+	logging.Setup("info", "text")
+	if slog.Default() == nil {
+		t.Fatal("slog.Default() is nil after Setup")
+	}
+}
diff --git a/protocol/address.go b/protocol/address.go
new file mode 100644
index 0000000..74cd57f
--- /dev/null
+++ b/protocol/address.go
@@ -0,0 +1,151 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package protocol
+
+import (
+	"encoding/binary"
+	"fmt"
+	"strconv"
+	"strings"
+)
+
+const AddrSize = 6 // 48 bits: 2 bytes network + 4 bytes node
+
+// Addr is a 48-bit Pilot Protocol virtual address.
+// Layout: [16-bit Network ID][32-bit Node ID]
+// Text format: N:NNNN.HHHH.LLLL
+//
+//	N = network ID in decimal
+//	NNNN = network ID in hex (redundant, for readability)
+//	HHHH = node ID high 16 bits in hex
+//	LLLL = node ID low 16 bits in hex
+type Addr struct {
+	Network uint16
+	Node    uint32
+}
+
+var (
+	AddrRegistry   = Addr{0, 1}
+	AddrBeacon     = Addr{0, 2}
+	AddrNameserver = Addr{0, 3}
+)
+
+// ZeroAddr returns the zero-value address ({0, 0}). It exists as a
+// function rather than a package-level var so callers cannot mutate a
+// shared sentinel (P3 — no cross-layer mutable globals). The returned
+// value is freshly constructed on each call.
+func ZeroAddr() Addr { return Addr{} }
+
+// BroadcastAddr returns the broadcast address for a given network.
+func BroadcastAddr(network uint16) Addr {
+	return Addr{Network: network, Node: 0xFFFFFFFF}
+}
+
+func (a Addr) IsZero() bool      { return a.Network == 0 && a.Node == 0 }
+func (a Addr) IsBroadcast() bool { return a.Node == 0xFFFFFFFF }
+
+// Marshal writes the address as 6 bytes (big-endian).
+func (a Addr) Marshal() []byte {
+	b := make([]byte, AddrSize)
+	a.MarshalTo(b, 0)
+	return b
+}
+
+// MarshalTo writes the address into buf at the given offset.
+func (a Addr) MarshalTo(buf []byte, offset int) {
+	dst := buf[offset:]
+	binary.BigEndian.PutUint16(dst, a.Network)
+	binary.BigEndian.PutUint32(dst[2:], a.Node)
+}
+
+// UnmarshalAddr decodes the first AddrSize (6) bytes of buf as a
+// big-endian network ID followed by a big-endian node ID.
+// A buffer shorter than AddrSize yields the zero Addr instead of an
+// out-of-bounds panic (PILOT-133).
+func UnmarshalAddr(buf []byte) Addr {
+	if len(buf) >= AddrSize {
+		return Addr{
+			Network: binary.BigEndian.Uint16(buf[0:2]),
+			Node:    binary.BigEndian.Uint32(buf[2:6]),
+		}
+	}
+	return Addr{}
+}
+
+// String renders the address as N:NNNN.HHHH.LLLL — the network ID in
+// decimal, the network ID again in hex, then the high and low 16 bits of
+// the node ID in hex.
+func (a Addr) String() string {
+	nodeHigh := uint16(a.Node >> 16)
+	nodeLow := uint16(a.Node)
+	return fmt.Sprintf("%d:%04X.%04X.%04X", a.Network, a.Network, nodeHigh, nodeLow)
+}
+
+// ParseAddr parses the text form produced by Addr.String, for example
+// "0:0000.0000.0001" or "1:00A3.F291.0004". The decimal network ID and
+// the first hex group must denote the same value.
+func ParseAddr(s string) (Addr, error) {
+	netText, groupText, found := strings.Cut(s, ":")
+	if !found {
+		return Addr{}, fmt.Errorf("invalid address: %q (expected N:XXXX.YYYY.YYYY)", s)
+	}
+
+	networkDec, err := strconv.ParseUint(netText, 10, 16)
+	if err != nil {
+		return Addr{}, fmt.Errorf("invalid network ID: %q: %w", netText, err)
+	}
+
+	groups := strings.Split(groupText, ".")
+	if len(groups) != 3 {
+		return Addr{}, fmt.Errorf("invalid address: %q (expected 3 dot-separated hex groups)", groupText)
+	}
+	for _, g := range groups {
+		if len(g) != 4 {
+			return Addr{}, fmt.Errorf("invalid hex group: %q (expected 4 digits)", g)
+		}
+	}
+
+	// words holds, in order: network (hex), node high half, node low half.
+	var words [3]uint64
+	for i, g := range groups {
+		words[i], err = strconv.ParseUint(g, 16, 16)
+		if err != nil {
+			return Addr{}, fmt.Errorf("invalid hex group: %q: %w", g, err)
+		}
+		// The redundant hex network ID is validated as soon as it is
+		// decoded, before the node groups are looked at.
+		if i == 0 && words[0] != networkDec {
+			return Addr{}, fmt.Errorf("network mismatch: decimal %d != hex 0x%04X", networkDec, words[0])
+		}
+	}
+
+	return Addr{
+		Network: uint16(networkDec),
+		Node:    uint32(words[1])<<16 | uint32(words[2]),
+	}, nil
+}
+
+// SocketAddr is a full endpoint: virtual address + port.
+type SocketAddr struct {
+	Addr Addr
+	Port uint16
+}
+
+// String renders the endpoint as the address text followed by ":PORT".
+func (sa SocketAddr) String() string {
+	return fmt.Sprintf("%s:%d", sa.Addr.String(), sa.Port)
+}
+
+// ParseSocketAddr parses "N:XXXX.YYYY.YYYY:PORT". The port is whatever
+// follows the last colon; everything before it is handed to ParseAddr.
+func ParseSocketAddr(s string) (SocketAddr, error) {
+	sep := strings.LastIndex(s, ":")
+	if sep < 0 {
+		return SocketAddr{}, fmt.Errorf("invalid socket address: %q (no port)", s)
+	}
+	addrText, portText := s[:sep], s[sep+1:]
+
+	addr, err := ParseAddr(addrText)
+	if err != nil {
+		return SocketAddr{}, err
+	}
+
+	port, err := strconv.ParseUint(portText, 10, 16)
+	if err != nil {
+		return SocketAddr{}, fmt.Errorf("invalid port: %q: %w", portText, err)
+	}
+
+	return SocketAddr{Addr: addr, Port: uint16(port)}, nil
+}
diff --git a/protocol/checksum.go b/protocol/checksum.go
new file mode 100644
index 0000000..39a3ca8
--- /dev/null
+++ b/protocol/checksum.go
@@ -0,0 +1,12 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package protocol
+
+import "hash/crc32"
+
+var crcTable = crc32.MakeTable(crc32.IEEE)
+
+// Checksum returns the CRC-32 of data using the IEEE polynomial.
+func Checksum(data []byte) uint32 {
+	return crc32.Update(0, crcTable, data)
+}
diff --git a/protocol/header.go b/protocol/header.go
new file mode 100644
index 0000000..edc3dd1
--- /dev/null
+++ b/protocol/header.go
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package protocol
+
+import "errors"
+
+// Protocol version
+const Version uint8 = 1
+
+// Sentinel errors shared across packages.
+var (
+	ErrNodeNotFound     = errors.New("node not found")
+	ErrNetworkNotFound  = errors.New("network not found")
+	ErrConnClosed       = errors.New("connection closed")
+	ErrConnRefused      = errors.New("connection refused")
+	ErrDialTimeout      = errors.New("dial timeout")
+	ErrChecksumMismatch = errors.New("checksum mismatch")
+	// ErrMalformedPacket is the sentinel reported by the L1 panic boundary
+	// of Marshal/Unmarshal: when a panic is recovered during wire-format
+	// encode/decode, the returned error wraps this value together with the
+	// recovered panic value (fmt.Errorf with %w).
+	ErrMalformedPacket = errors.New("malformed packet")
+)
+
+// Flags (4 bits, stored in lower nibble of first byte alongside version)
+const (
+	FlagSYN uint8 = 1 << iota // 0x1
+	FlagACK                   // 0x2
+	FlagFIN                   // 0x4
+	FlagRST                   // 0x8
+)
+
+// Protocol types
+const (
+	ProtoStream   uint8 = iota + 1 // 0x01: Reliable, ordered (TCP-like)
+	ProtoDatagram                  // 0x02: Unreliable, unordered (UDP-like)
+	ProtoControl                   // 0x03: Internal control
+)
+
+// Well-known ports
+const (
+	PortPing         uint16 = 0
+	PortControl      uint16 = 1
+	PortEcho         uint16 = 7
+	PortNameserver   uint16 = 53
+	PortHTTP         uint16 = 80
+	PortSecure       uint16 = 443
+	PortStdIO        uint16 = 1000
+	PortDataExchange uint16 = 1001
+	PortEventStream  uint16 = 1002
+)
+
+// Port ranges
+const (
+	PortReservedMax   uint16 = 1023
+	PortRegisteredMax uint16 = 49151
+	PortEphemeralMin  uint16 = PortRegisteredMax + 1 // 49152
+	PortEphemeralMax  uint16 = 65535
+)
+
+// Tunnel magic bytes. Every variant is four ASCII bytes starting with
+// "PIL"; the fourth byte distinguishes the variants.
+var (
+	// TunnelMagic is "PILT" (0x50494C54).
+	TunnelMagic = [4]byte{'P', 'I', 'L', 'T'}
+
+	// TunnelMagicSecure marks encrypted packets: "PILS" (0x50494C53).
+	TunnelMagicSecure = [4]byte{'P', 'I', 'L', 'S'}
+
+	// TunnelMagicKeyEx marks key exchange: "PILK" (0x50494C4B).
+	TunnelMagicKeyEx = [4]byte{'P', 'I', 'L', 'K'}
+
+	// TunnelMagicAuthEx marks authenticated key exchange: "PILA" (0x50494C41).
+	TunnelMagicAuthEx = [4]byte{'P', 'I', 'L', 'A'}
+
+	// TunnelMagicPunch marks a NAT punch packet: "PILP" (0x50494C50).
+	TunnelMagicPunch = [4]byte{'P', 'I', 'L', 'P'}
+)
+
+// Well-known port for handshake requests
+const PortHandshake uint16 = 444
+
+// Beacon message types (single-byte codes, all < 0x10 to avoid collision with tunnel magic)
+const (
+	BeaconMsgDiscover      byte = 0x01
+	BeaconMsgDiscoverReply byte = 0x02
+	BeaconMsgPunchRequest  byte = 0x03
+	BeaconMsgPunchCommand  byte = 0x04
+	BeaconMsgRelay         byte = 0x05
+	BeaconMsgRelayDeliver  byte = 0x06
+	BeaconMsgSync          byte = 0x07 // gossip: beacon-to-beacon node list exchange
+)
diff --git a/protocol/packet.go b/protocol/packet.go
new file mode 100644
index 0000000..190f501
--- /dev/null
+++ b/protocol/packet.go
@@ -0,0 +1,158 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package protocol
+
+import (
+	"encoding/binary"
+	"fmt"
+	"hash/crc32"
+)
+
+// Wire layout (34 bytes):
+//
+//	Byte 0: [Version:4][Flags:4]
+//	Byte 1: Protocol
+//	Byte 2-3: Payload Length
+//	Byte 4-9: Source Address (6 bytes)
+//	Byte 10-15: Destination Address (6 bytes)
+//	Byte 16-17: Source Port
+//	Byte 18-19: Destination Port
+//	Byte 20-23: Sequence Number
+//	Byte 24-27: Acknowledgment Number
+//	Byte 28-29: Window (receive window in segments, 0 = no flow control)
+//	Byte 30-33: Checksum (CRC32)
+const packetHeaderSize = 34
+
+// Packet is the decoded form of one wire frame: the fixed 34-byte header
+// fields plus the payload bytes that follow it.
+type Packet struct {
+	Version  uint8
+	Flags    uint8
+	Protocol uint8
+
+	Src     Addr
+	Dst     Addr
+	SrcPort uint16
+	DstPort uint16
+
+	Seq    uint32
+	Ack    uint32
+	Window uint16 // advertised receive window (in segments; 0 = no limit)
+
+	Payload []byte
+}
+
+// HasFlag reports whether any bit of f is set in p.Flags; SetFlag and
+// ClearFlag set and clear the bits of f respectively.
+func (p *Packet) HasFlag(f uint8) bool { return p.Flags&f != 0 }
+func (p *Packet) SetFlag(f uint8)      { p.Flags |= f }
+func (p *Packet) ClearFlag(f uint8)    { p.Flags &^= f }
+
+// Marshal serializes the packet to wire format with checksum.
+//
+// L1 panic boundary (architecture-notes/03-INVARIANTS.md §8):
+// the explicit length-check below covers the only known caller-induced
+// failure (oversize payload), but a nil-pointer Packet receiver or
+// future bug could trigger a panic mid-encode. The deferred recover
+// converts any panic into ErrMalformedPacket so callers (Send paths)
+// drop the frame instead of crashing the daemon.
+func (p *Packet) Marshal() (out []byte, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			out = nil
+			err = fmt.Errorf("%w: panic during encode: %v", ErrMalformedPacket, r)
+		}
+	}()
+
+	payloadLen := len(p.Payload)
+	if payloadLen > 0xFFFF {
+		return nil, fmt.Errorf("payload too large: %d bytes (max 65535)", payloadLen)
+	}
+
+	totalLen := packetHeaderSize + payloadLen // safe: payloadLen ≤ 0xFFFF (checked above)
+	buf := make([]byte, totalLen)
+
+	buf[0] = (p.Version << 4) | (p.Flags & 0x0F)
+	buf[1] = p.Protocol
+	binary.BigEndian.PutUint16(buf[2:4], uint16(payloadLen))
+	p.Src.MarshalTo(buf, 4)
+	p.Dst.MarshalTo(buf, 10)
+	binary.BigEndian.PutUint16(buf[16:18], p.SrcPort)
+	binary.BigEndian.PutUint16(buf[18:20], p.DstPort)
+	binary.BigEndian.PutUint32(buf[20:24], p.Seq)
+	binary.BigEndian.PutUint32(buf[24:28], p.Ack)
+	binary.BigEndian.PutUint16(buf[28:30], p.Window)
+
+	if payloadLen > 0 {
+		copy(buf[packetHeaderSize:], p.Payload)
+	}
+
+	// Checksum: CRC32 over header (with checksum field zeroed) + payload.
+	// Field is already zero from make().
+	binary.BigEndian.PutUint32(buf[30:34], Checksum(buf))
+
+	return buf, nil
+}
+
+// Unmarshal deserializes a packet from wire bytes.
+//
+// data is treated as read-only: the checksum is verified without patching
+// the checksum field in place, so the caller's buffer is never written to
+// and may be shared with concurrent readers. The returned Packet owns a
+// private copy of the payload.
+//
+// L1 panic boundary (architecture-notes/03-INVARIANTS.md §8):
+// the explicit length-checks below cover all *known* malformed inputs,
+// but a future caller could pass a slice that aliases a buffer being
+// concurrently mutated, or a malformed input not yet enumerated, causing
+// an out-of-bounds slice expression to panic. The deferred recover
+// converts any such panic into a structured error so callers (the
+// tunnel readLoop, relay path) drop the frame instead of taking down
+// the whole daemon. Returns ErrMalformedPacket on panic; the original
+// panic value is wrapped via fmt.Errorf for diagnostics.
+func Unmarshal(data []byte) (p *Packet, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			p = nil
+			err = fmt.Errorf("%w: panic during decode: %v", ErrMalformedPacket, r)
+		}
+	}()
+
+	// The 4-byte checksum is the last field of the header (bytes 30-33).
+	const checksumOffset = packetHeaderSize - 4
+
+	if len(data) < packetHeaderSize {
+		return nil, fmt.Errorf("packet too short: %d bytes (min %d)", len(data), packetHeaderSize)
+	}
+
+	payloadLen := binary.BigEndian.Uint16(data[2:4])
+	total := packetHeaderSize + int(payloadLen)
+	if len(data) < total {
+		return nil, fmt.Errorf("packet truncated: have %d bytes, need %d", len(data), total)
+	}
+
+	// Verify checksum before parsing. The CRC covers the frame with the
+	// checksum field zeroed; feed four zero bytes in place of the field
+	// rather than zeroing and restoring data[30:34], which would write to
+	// the caller's buffer.
+	wireChecksum := binary.BigEndian.Uint32(data[checksumOffset:packetHeaderSize])
+	var zeroField [packetHeaderSize - checksumOffset]byte
+	computed := crc32.Update(0, crcTable, data[:checksumOffset])
+	computed = crc32.Update(computed, crcTable, zeroField[:])
+	computed = crc32.Update(computed, crcTable, data[packetHeaderSize:total])
+
+	if computed != wireChecksum {
+		return nil, ErrChecksumMismatch
+	}
+
+	// Validate protocol version.
+	wireVersion := (data[0] >> 4) & 0x0F
+	if wireVersion != Version {
+		return nil, fmt.Errorf("unsupported protocol version %d (expected %d)", wireVersion, Version)
+	}
+
+	p = &Packet{
+		Version:  wireVersion,
+		Flags:    data[0] & 0x0F,
+		Protocol: data[1],
+		Src:      UnmarshalAddr(data[4:10]),
+		Dst:      UnmarshalAddr(data[10:16]),
+		SrcPort:  binary.BigEndian.Uint16(data[16:18]),
+		DstPort:  binary.BigEndian.Uint16(data[18:20]),
+		Seq:      binary.BigEndian.Uint32(data[20:24]),
+		Ack:      binary.BigEndian.Uint32(data[24:28]),
+		Window:   binary.BigEndian.Uint16(data[28:30]),
+	}
+
+	if payloadLen > 0 {
+		// Copy so the Packet does not alias the caller's receive buffer.
+		p.Payload = make([]byte, payloadLen)
+		copy(p.Payload, data[packetHeaderSize:total])
+	}
+
+	return p, nil
+}
+
+// PacketHeaderSize returns the fixed size of the wire header in bytes.
+func PacketHeaderSize() int { return packetHeaderSize }
diff --git a/protocol/zz_fuzz_packet_test.go b/protocol/zz_fuzz_packet_test.go
new file mode 100644
index 0000000..99818ff
--- /dev/null
+++ b/protocol/zz_fuzz_packet_test.go
@@ -0,0 +1,200 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package protocol
+
+import (
+	"bytes"
+	"testing"
+)
+
+//
FuzzUnmarshalPacket exercises the packet decoder with arbitrary bytes. +// The decoder is documented as having an L1 panic boundary (deferred +// recover → ErrMalformedPacket), so a literal panic should never escape +// to the fuzz harness — any escape is a real bug. PILOT-133 noted a +// latent unmarshal panic in UnmarshalAddr; this target should reproduce +// such issues quickly. +// +// Seeds include valid frames at multiple sizes, header-only inputs, all +// flag combinations, malformed length fields, and a few adversarial +// envelopes (length field claims much more than the buffer holds). +func FuzzUnmarshalPacket(f *testing.F) { + // Seed 1: minimal valid (no payload) packet. + { + p := &Packet{Version: Version, Protocol: ProtoStream} + b, err := p.Marshal() + if err == nil { + f.Add(b) + } + } + // Seed 2: valid packet with small payload. + { + p := &Packet{ + Version: Version, Flags: FlagSYN | FlagACK, Protocol: ProtoStream, + Src: Addr{1, 0xDEADBEEF}, Dst: Addr{2, 0xCAFEBABE}, + SrcPort: 1234, DstPort: 5678, + Seq: 0x11223344, Ack: 0x55667788, Window: 16, + Payload: []byte("hello"), + } + b, err := p.Marshal() + if err == nil { + f.Add(b) + } + } + // Seed 3: control proto, all flags. + { + p := &Packet{ + Version: Version, Flags: FlagSYN | FlagACK | FlagFIN | FlagRST, + Protocol: ProtoControl, Payload: bytes.Repeat([]byte{0xAB}, 64), + } + b, err := p.Marshal() + if err == nil { + f.Add(b) + } + } + // Seed 4: datagram with binary payload. + { + p := &Packet{ + Version: Version, Protocol: ProtoDatagram, + Payload: []byte{0x00, 0xFF, 0x7F, 0x80, 0x01, 0xFE}, + } + b, err := p.Marshal() + if err == nil { + f.Add(b) + } + } + // Seed 5: broadcast destination. + { + p := &Packet{ + Version: Version, Protocol: ProtoDatagram, + Dst: BroadcastAddr(7), DstPort: PortPing, + } + b, err := p.Marshal() + if err == nil { + f.Add(b) + } + } + // Seed 6: exactly header-sized (34 bytes of zeros). + f.Add(make([]byte, packetHeaderSize)) + // Seed 7: shorter than header. 
+ f.Add(make([]byte, packetHeaderSize-1)) + // Seed 8: empty. + f.Add([]byte{}) + // Seed 9: header claims big payload but buffer truncated. + { + b := make([]byte, packetHeaderSize) + b[0] = Version << 4 // valid version + b[2], b[3] = 0xFF, 0xFF + f.Add(b) + } + // Seed 10: header with unsupported version. + { + b := make([]byte, packetHeaderSize) + b[0] = 0xF0 // version 0x0F (unsupported) + f.Add(b) + } + + f.Fuzz(func(t *testing.T, data []byte) { + // Defensive: literal panic out of Unmarshal is the find. + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + + // Unmarshal mutates data[30:34] briefly during checksum verify; + // pass a copy so the fuzzer's input slice is not aliased. + buf := make([]byte, len(data)) + copy(buf, data) + + p, err := Unmarshal(buf) + if err != nil { + return // expected on most random input + } + + // Round-trip property: a successfully decoded packet should + // re-encode to bytes that decode back to an equivalent struct. + re, err := p.Marshal() + if err != nil { + t.Errorf("decode-then-encode failed: %v (orig=%x)", err, data) + return + } + p2, err := Unmarshal(re) + if err != nil { + t.Errorf("re-decode failed: %v (re=%x)", err, re) + return + } + if p.Seq != p2.Seq || p.Ack != p2.Ack || p.SrcPort != p2.SrcPort || + p.DstPort != p2.DstPort || p.Window != p2.Window || + p.Protocol != p2.Protocol || p.Flags != p2.Flags || + p.Version != p2.Version || p.Src != p2.Src || p.Dst != p2.Dst { + t.Errorf("round-trip header mismatch: %+v vs %+v", p, p2) + } + if !bytes.Equal(p.Payload, p2.Payload) { + t.Errorf("round-trip payload mismatch: %x vs %x", p.Payload, p2.Payload) + } + }) +} + +// FuzzUnmarshalAddr targets the 6-byte address decoder directly. +// PILOT-133 specifically flagged this function; UnmarshalAddr does NOT +// have a defer/recover, so out-of-bounds slicing would propagate as a +// real panic. 
A naive call with len(buf) < AddrSize panics, so the +// fuzzer must include the bounds check the harness uses in practice. +func FuzzUnmarshalAddr(f *testing.F) { + f.Add(make([]byte, AddrSize)) + f.Add([]byte{0x00, 0x01, 0xDE, 0xAD, 0xBE, 0xEF}) + f.Add([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + + // UnmarshalAddr's contract is "exactly 6 bytes". Callers that + // pass less are the bug shape PILOT-133 was concerned with — + // fuzz both the contract-respecting path and the over-long path. + if len(data) >= AddrSize { + a := UnmarshalAddr(data[:AddrSize]) + b := a.Marshal() + a2 := UnmarshalAddr(b) + if a != a2 { + t.Errorf("addr round-trip: %v != %v", a, a2) + } + } + }) +} + +// FuzzParseAddr exercises the text-form address parser. +func FuzzParseAddr(f *testing.F) { + f.Add("0:0000.0000.0001") + f.Add("1:0001.DEAD.BEEF") + f.Add("65535:FFFF.FFFF.FFFF") + f.Add("") + f.Add(":") + f.Add("garbage") + f.Add("0:0000.0000") + f.Add("0:0000.0000.0000.0000") + + f.Fuzz(func(t *testing.T, s string) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %q: %v", s, r) + } + }() + a, err := ParseAddr(s) + if err != nil { + return + } + // Round-trip: String() of a parsed addr must re-parse equal. 
+ a2, err := ParseAddr(a.String()) + if err != nil { + t.Errorf("re-parse of %q (= %v) failed: %v", s, a, err) + return + } + if a != a2 { + t.Errorf("round-trip mismatch: %v != %v (input %q)", a, a2, s) + } + }) +} diff --git a/protocol/zz_protocol_test.go b/protocol/zz_protocol_test.go new file mode 100644 index 0000000..13fffd1 --- /dev/null +++ b/protocol/zz_protocol_test.go @@ -0,0 +1,442 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package protocol + +import ( + "bytes" + "encoding/binary" + "errors" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// Addr +// --------------------------------------------------------------------------- + +func TestAddrIsZero(t *testing.T) { + t.Parallel() + if !ZeroAddr().IsZero() { + t.Fatal("ZeroAddr().IsZero() should be true") + } + if AddrRegistry.IsZero() { + t.Fatal("AddrRegistry should not be zero") + } + if (Addr{Network: 1}).IsZero() { + t.Fatal("non-zero network should not be zero") + } + if (Addr{Node: 1}).IsZero() { + t.Fatal("non-zero node should not be zero") + } +} + +func TestAddrIsBroadcast(t *testing.T) { + t.Parallel() + b := BroadcastAddr(5) + if !b.IsBroadcast() { + t.Fatalf("BroadcastAddr(5).IsBroadcast() = false") + } + if b.Network != 5 { + t.Fatalf("BroadcastAddr network = %d, want 5", b.Network) + } + if (Addr{Node: 0xFFFFFFFE}).IsBroadcast() { + t.Fatal("non-broadcast node should not be broadcast") + } +} + +func TestAddrMarshalUnmarshalRoundTrip(t *testing.T) { + t.Parallel() + cases := []Addr{ + ZeroAddr(), + AddrRegistry, + {Network: 0xABCD, Node: 0x12345678}, + {Network: 0xFFFF, Node: 0xFFFFFFFF}, + } + for _, in := range cases { + buf := in.Marshal() + if len(buf) != AddrSize { + t.Fatalf("Marshal len = %d, want %d", len(buf), AddrSize) + } + out := UnmarshalAddr(buf) + if out != in { + t.Fatalf("round-trip: got %+v, want %+v", out, in) + } + } +} + +func TestAddrMarshalToOffset(t *testing.T) { + t.Parallel() + a := 
Addr{Network: 0xCAFE, Node: 0xDEADBEEF} + buf := make([]byte, 20) + a.MarshalTo(buf, 8) + out := UnmarshalAddr(buf[8:14]) + if out != a { + t.Fatalf("MarshalTo offset round-trip: got %+v, want %+v", out, a) + } + // Bytes outside the 6-byte window must remain zero + for i := 0; i < 8; i++ { + if buf[i] != 0 { + t.Errorf("byte %d not zero before offset: 0x%02x", i, buf[i]) + } + } + for i := 14; i < 20; i++ { + if buf[i] != 0 { + t.Errorf("byte %d not zero after offset: 0x%02x", i, buf[i]) + } + } +} + +func TestUnmarshalAddrShortBuffer(t *testing.T) { + t.Parallel() + // PILOT-133: UnmarshalAddr must not panic on short buffers. + // It should return a zero address instead of indexing out of bounds. + + // 0 bytes + func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalAddr panicked on 0-byte buffer: %v", r) + } + }() + a := UnmarshalAddr([]byte{}) + if !a.IsZero() { + t.Errorf("expected zero addr for empty buffer, got %+v", a) + } + }() + + // 3 bytes + func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalAddr panicked on 3-byte buffer: %v", r) + } + }() + a := UnmarshalAddr([]byte{0x00, 0x01, 0x02}) + if !a.IsZero() { + t.Errorf("expected zero addr for short buffer, got %+v", a) + } + }() + + // 5 bytes (one short) + func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalAddr panicked on 5-byte buffer: %v", r) + } + }() + a := UnmarshalAddr([]byte{0x00, 0x01, 0xDE, 0xAD, 0xBE}) + if !a.IsZero() { + t.Errorf("expected zero addr for 5-byte buffer, got %+v", a) + } + }() + + // 6 bytes (valid, should work normally) + a := UnmarshalAddr([]byte{0x00, 0x01, 0xDE, 0xAD, 0xBE, 0xEF}) + want := Addr{Network: 0x0001, Node: 0xDEADBEEF} + if a != want { + t.Errorf("valid 6-byte buffer: got %+v, want %+v", a, want) + } +} + +func TestAddrStringFormat(t *testing.T) { + t.Parallel() + a := Addr{Network: 0x00A3, Node: 0xF2910004} + got := a.String() + want := "163:00A3.F291.0004" + if got != want { + 
t.Fatalf("String() = %q, want %q", got, want) + } +} + +func TestParseAddrValid(t *testing.T) { + t.Parallel() + in := "163:00A3.F291.0004" + a, err := ParseAddr(in) + if err != nil { + t.Fatalf("ParseAddr: %v", err) + } + if a.Network != 0x00A3 || a.Node != 0xF2910004 { + t.Fatalf("parsed addr wrong: %+v", a) + } + // Round-trip via String must equal input + if a.String() != in { + t.Fatalf("round-trip: %q != %q", a.String(), in) + } +} + +func TestParseAddrErrors(t *testing.T) { + t.Parallel() + cases := []struct { + in string + wantSub string + }{ + {"no-colon", "expected N:XXXX"}, + {"abc:0000.0000.0000", "invalid network ID"}, + {"1:0000.0000", "expected 3 dot-separated"}, + {"1:000.0000.0000", "expected 4 digits"}, + {"1:GGGG.0000.0000", "invalid hex group"}, + {"1:0001.GGGG.0000", "invalid hex group"}, // network matches so we reach the high-group check + {"1:0001.0000.GGGG", "invalid hex group"}, // network matches so we reach the low-group check + {"1:0002.0000.0000", "network mismatch"}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + t.Parallel() + _, err := ParseAddr(tc.in) + if err == nil { + t.Fatalf("expected error for %q", tc.in) + } + if !strings.Contains(err.Error(), tc.wantSub) { + t.Fatalf("error %q missing substring %q", err.Error(), tc.wantSub) + } + }) + } +} + +// --------------------------------------------------------------------------- +// SocketAddr +// --------------------------------------------------------------------------- + +func TestSocketAddrStringAndParse(t *testing.T) { + t.Parallel() + in := SocketAddr{Addr: Addr{Network: 1, Node: 0x00010001}, Port: 8080} + str := in.String() + if str != "1:0001.0001.0001:8080" { + t.Fatalf("String() = %q", str) + } + out, err := ParseSocketAddr(str) + if err != nil { + t.Fatalf("ParseSocketAddr: %v", err) + } + if out != in { + t.Fatalf("round-trip: got %+v, want %+v", out, in) + } +} + +func TestParseSocketAddrErrors(t *testing.T) { + t.Parallel() + cases := []struct { 
+ in string + wantSub string + }{ + {"noport", "no port"}, + {"bad-addr:80", "invalid address"}, + {"1:0001.0001.0001:notanumber", "invalid port"}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + _, err := ParseSocketAddr(tc.in) + if err == nil { + t.Fatalf("expected error for %q", tc.in) + } + if !strings.Contains(err.Error(), tc.wantSub) { + t.Fatalf("error %q missing substring %q", err.Error(), tc.wantSub) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Packet flags + Marshal/Unmarshal +// --------------------------------------------------------------------------- + +func TestPacketFlagOps(t *testing.T) { + t.Parallel() + p := &Packet{} + if p.HasFlag(FlagSYN) { + t.Fatal("fresh packet should have no flags") + } + p.SetFlag(FlagSYN) + p.SetFlag(FlagACK) + if !p.HasFlag(FlagSYN) || !p.HasFlag(FlagACK) { + t.Fatalf("set flags not detected: flags=0x%x", p.Flags) + } + if p.HasFlag(FlagFIN) { + t.Fatal("FIN should not be set") + } + p.ClearFlag(FlagSYN) + if p.HasFlag(FlagSYN) { + t.Fatal("SYN should be cleared") + } + if !p.HasFlag(FlagACK) { + t.Fatal("ACK should still be set") + } +} + +func TestPacketHeaderSize(t *testing.T) { + t.Parallel() + if PacketHeaderSize() != 34 { + t.Fatalf("PacketHeaderSize() = %d, want 34", PacketHeaderSize()) + } +} + +func TestPacketMarshalUnmarshalRoundTrip(t *testing.T) { + t.Parallel() + in := &Packet{ + Version: Version, + Flags: FlagACK | FlagSYN, + Protocol: ProtoStream, + Src: Addr{Network: 1, Node: 0x12345678}, + Dst: Addr{Network: 2, Node: 0xABCDEF01}, + SrcPort: 4040, + DstPort: PortHTTP, + Seq: 1234, + Ack: 5678, + Window: 64, + Payload: []byte("hello world"), + } + buf, err := in.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if len(buf) != 34+len(in.Payload) { + t.Fatalf("buf len = %d, want %d", len(buf), 34+len(in.Payload)) + } + out, err := Unmarshal(buf) + if err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if 
out.Version != in.Version || out.Flags != in.Flags || out.Protocol != in.Protocol { + t.Fatalf("header mismatch: got %+v", out) + } + if out.Src != in.Src || out.Dst != in.Dst { + t.Fatalf("addr mismatch: got src=%v dst=%v", out.Src, out.Dst) + } + if out.SrcPort != in.SrcPort || out.DstPort != in.DstPort { + t.Fatalf("port mismatch: got src=%d dst=%d", out.SrcPort, out.DstPort) + } + if out.Seq != in.Seq || out.Ack != in.Ack || out.Window != in.Window { + t.Fatalf("seq/ack/window mismatch: got seq=%d ack=%d window=%d", out.Seq, out.Ack, out.Window) + } + if !bytes.Equal(out.Payload, in.Payload) { + t.Fatalf("payload mismatch: got %q", out.Payload) + } +} + +func TestPacketMarshalEmptyPayload(t *testing.T) { + t.Parallel() + in := &Packet{Version: Version, Protocol: ProtoControl} + buf, err := in.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if len(buf) != 34 { + t.Fatalf("len = %d, want 34", len(buf)) + } + out, err := Unmarshal(buf) + if err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if len(out.Payload) != 0 { + t.Fatalf("expected empty payload, got %d bytes", len(out.Payload)) + } +} + +func TestPacketMarshalPayloadTooLarge(t *testing.T) { + t.Parallel() + in := &Packet{Version: Version, Payload: make([]byte, 0x10000)} // 65536, exceeds 0xFFFF + _, err := in.Marshal() + if err == nil { + t.Fatal("expected payload-too-large error") + } + if !strings.Contains(err.Error(), "payload too large") { + t.Fatalf("error %q missing substring", err) + } +} + +func TestUnmarshalTooShort(t *testing.T) { + t.Parallel() + _, err := Unmarshal(make([]byte, 33)) + if err == nil { + t.Fatal("expected too-short error") + } + if !strings.Contains(err.Error(), "too short") { + t.Fatalf("error %q missing substring", err) + } +} + +func TestUnmarshalTruncatedPayload(t *testing.T) { + t.Parallel() + // Build a header claiming 100-byte payload but send only the header. 
+ buf := make([]byte, 34) + buf[0] = Version << 4 + binary.BigEndian.PutUint16(buf[2:4], 100) + _, err := Unmarshal(buf) + if err == nil { + t.Fatal("expected truncated error") + } + if !strings.Contains(err.Error(), "truncated") { + t.Fatalf("error %q missing 'truncated'", err) + } +} + +func TestUnmarshalChecksumMismatch(t *testing.T) { + t.Parallel() + in := &Packet{Version: Version, Protocol: ProtoStream, Payload: []byte("abc")} + buf, _ := in.Marshal() + // Corrupt one byte in the payload + buf[34] ^= 0xFF + _, err := Unmarshal(buf) + if !errors.Is(err, ErrChecksumMismatch) { + t.Fatalf("expected ErrChecksumMismatch, got %v", err) + } +} + +func TestUnmarshalUnsupportedVersion(t *testing.T) { + t.Parallel() + in := &Packet{Version: Version, Payload: []byte("x")} + buf, _ := in.Marshal() + // Flip the version nibble to a different value, then re-checksum so we + // hit the version check rather than the checksum check. + buf[0] = (0xA << 4) | (buf[0] & 0x0F) // version = 0xA + binary.BigEndian.PutUint32(buf[30:34], 0) + cs := Checksum(buf) + binary.BigEndian.PutUint32(buf[30:34], cs) + _, err := Unmarshal(buf) + if err == nil { + t.Fatal("expected unsupported version error") + } + if !strings.Contains(err.Error(), "unsupported protocol version") { + t.Fatalf("error %q missing substring", err) + } +} + +func TestUnmarshalRestoresChecksumBytes(t *testing.T) { + t.Parallel() + // Verify Unmarshal does not permanently mutate the caller's buffer + // (it temporarily zeroes the checksum field for computation, then restores). + in := &Packet{Version: Version, Protocol: ProtoStream, Payload: []byte("xyz")} + buf, _ := in.Marshal() + original := append([]byte(nil), buf...) 
+ if _, err := Unmarshal(buf); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !bytes.Equal(buf, original) { + t.Fatalf("Unmarshal mutated caller buffer: original checksum bytes %x, current %x", + original[30:34], buf[30:34]) + } +} + +// --------------------------------------------------------------------------- +// Checksum +// --------------------------------------------------------------------------- + +func TestChecksumDeterministic(t *testing.T) { + t.Parallel() + data := []byte("the quick brown fox") + c1 := Checksum(data) + c2 := Checksum(data) + if c1 != c2 { + t.Fatalf("Checksum non-deterministic: %d != %d", c1, c2) + } +} + +func TestChecksumDiffersOnSingleBitFlip(t *testing.T) { + t.Parallel() + a := []byte("payload") + b := append([]byte(nil), a...) + b[0] ^= 0x01 + if Checksum(a) == Checksum(b) { + t.Fatal("checksum did not detect 1-bit flip") + } +} diff --git a/registry/client/binary_client.go b/registry/client/binary_client.go new file mode 100644 index 0000000..3268679 --- /dev/null +++ b/registry/client/binary_client.go @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package client + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net" + "sync" + "time" + + "github.com/pilot-protocol/common/registry/wire" +) + +// BinaryClient talks to a registry server using the binary wire protocol. +// It provides native binary encoding for hot-path operations (heartbeat, lookup, +// resolve) and JSON-over-binary passthrough for all other operations. +type BinaryClient struct { + conn net.Conn + mu sync.Mutex + addr string + closed bool +} + +// DialBinary connects to a registry server and negotiates the binary wire protocol. +// The server detects the magic bytes and switches to binary mode for this connection. 
+func DialBinary(addr string) (*BinaryClient, error) { + conn, err := net.DialTimeout("tcp", addr, 5*time.Second) + if err != nil { + return nil, fmt.Errorf("dial registry: %w", err) + } + + // Send magic + version to negotiate binary protocol + var handshake [5]byte + copy(handshake[:4], wire.Magic[:]) + handshake[4] = wire.Version + if _, err := conn.Write(handshake[:]); err != nil { + conn.Close() + return nil, fmt.Errorf("binary handshake: %w", err) + } + + return &BinaryClient{conn: conn, addr: addr}, nil +} + +// Close shuts down the binary client connection. +func (c *BinaryClient) Close() error { + c.mu.Lock() + c.closed = true + conn := c.conn + c.mu.Unlock() + if conn != nil { + return conn.Close() + } + return nil +} + +// Addr returns the registry address this client is connected to. +func (c *BinaryClient) Addr() string { + return c.addr +} + +// reconnect re-establishes the binary connection. Must be called with c.mu held. +func (c *BinaryClient) reconnect() error { + if c.closed { + return fmt.Errorf("client closed") + } + if c.conn != nil { + c.conn.Close() + } + + backoff := 500 * time.Millisecond + maxBackoff := 10 * time.Second + var lastErr error + + for attempts := 0; attempts < 5; attempts++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", c.addr) + cancel() + if err != nil { + lastErr = err + slog.Warn("binary client reconnect failed", "attempt", attempts+1, "err", err) + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } + + // Re-negotiate binary protocol + var handshake [5]byte + copy(handshake[:4], wire.Magic[:]) + handshake[4] = wire.Version + if _, err := conn.Write(handshake[:]); err != nil { + conn.Close() + lastErr = err + continue + } + + c.conn = conn + slog.Info("binary client reconnected", "addr", c.addr) + return nil + } + return fmt.Errorf("reconnect failed after 5 attempts: %w", lastErr) +} + 
+// Heartbeat sends a binary heartbeat and returns the server time and key expiry warning. +func (c *BinaryClient) Heartbeat(nodeID uint32, sig []byte) (unixTime int64, keyExpiryWarning bool, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + unixTime, keyExpiryWarning, err = c.heartbeatLocked(nodeID, sig) + if err != nil && !c.closed { + // Connection-level failure — reconnect and retry once + if reconnErr := c.reconnect(); reconnErr != nil { + return 0, false, fmt.Errorf("heartbeat failed and reconnect failed: %w", err) + } + unixTime, keyExpiryWarning, err = c.heartbeatLocked(nodeID, sig) + } + return +} + +func (c *BinaryClient) heartbeatLocked(nodeID uint32, sig []byte) (int64, bool, error) { + if err := wire.WriteFrame(c.conn, wire.MsgHeartbeat, wire.EncodeHeartbeatReq(nodeID, sig)); err != nil { + return 0, false, fmt.Errorf("send heartbeat: %w", err) + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + msgType, payload, err := wire.ReadFrame(c.conn) + c.conn.SetReadDeadline(time.Time{}) + if err != nil { + return 0, false, fmt.Errorf("recv heartbeat: %w", err) + } + + if msgType == wire.MsgError { + return 0, false, fmt.Errorf("registry: %s", wire.DecodeError(payload)) + } + if msgType != wire.MsgHeartbeatOK { + return 0, false, fmt.Errorf("unexpected response type 0x%02x", msgType) + } + + return wire.DecodeHeartbeatResp(payload) +} + +// Lookup sends a binary lookup request and returns the decoded result. 
+func (c *BinaryClient) Lookup(nodeID uint32) (*wire.LookupResult, error) { + c.mu.Lock() + defer c.mu.Unlock() + + result, err := c.lookupLocked(nodeID) + if err != nil && !c.closed { + if reconnErr := c.reconnect(); reconnErr != nil { + return nil, fmt.Errorf("lookup failed and reconnect failed: %w", err) + } + result, err = c.lookupLocked(nodeID) + } + return result, err +} + +func (c *BinaryClient) lookupLocked(nodeID uint32) (*wire.LookupResult, error) { + if err := wire.WriteFrame(c.conn, wire.MsgLookup, wire.EncodeLookupReq(nodeID)); err != nil { + return nil, fmt.Errorf("send lookup: %w", err) + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + msgType, payload, err := wire.ReadFrame(c.conn) + c.conn.SetReadDeadline(time.Time{}) + if err != nil { + return nil, fmt.Errorf("recv lookup: %w", err) + } + + if msgType == wire.MsgError { + return nil, fmt.Errorf("registry: %s", wire.DecodeError(payload)) + } + if msgType != wire.MsgLookupOK { + return nil, fmt.Errorf("unexpected response type 0x%02x", msgType) + } + + result, err := wire.DecodeLookupResp(payload) + if err != nil { + return nil, fmt.Errorf("decode lookup response: %w", err) + } + return &result, nil +} + +// Resolve sends a binary resolve request and returns the decoded result. 
+func (c *BinaryClient) Resolve(nodeID, requesterID uint32, sig []byte) (*wire.ResolveResult, error) { + c.mu.Lock() + defer c.mu.Unlock() + + result, err := c.resolveLocked(nodeID, requesterID, sig) + if err != nil && !c.closed { + if reconnErr := c.reconnect(); reconnErr != nil { + return nil, fmt.Errorf("resolve failed and reconnect failed: %w", err) + } + result, err = c.resolveLocked(nodeID, requesterID, sig) + } + return result, err +} + +func (c *BinaryClient) resolveLocked(nodeID, requesterID uint32, sig []byte) (*wire.ResolveResult, error) { + if err := wire.WriteFrame(c.conn, wire.MsgResolve, wire.EncodeResolveReq(nodeID, requesterID, sig)); err != nil { + return nil, fmt.Errorf("send resolve: %w", err) + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + msgType, payload, err := wire.ReadFrame(c.conn) + c.conn.SetReadDeadline(time.Time{}) + if err != nil { + return nil, fmt.Errorf("recv resolve: %w", err) + } + + if msgType == wire.MsgError { + return nil, fmt.Errorf("registry: %s", wire.DecodeError(payload)) + } + if msgType != wire.MsgResolveOK { + return nil, fmt.Errorf("unexpected response type 0x%02x", msgType) + } + + result, err := wire.DecodeResolveResp(payload) + if err != nil { + return nil, fmt.Errorf("decode resolve response: %w", err) + } + return &result, nil +} + +// SendJSON sends a JSON message over the binary protocol using JSON passthrough. +// This allows any registry operation to be used without a native binary encoding. 
+func (c *BinaryClient) SendJSON(msg map[string]interface{}) (map[string]interface{}, error) { + c.mu.Lock() + defer c.mu.Unlock() + + resp, err := c.sendJSONLocked(msg) + if err != nil && resp == nil && !c.closed { + if reconnErr := c.reconnect(); reconnErr != nil { + return nil, fmt.Errorf("send failed and reconnect failed: %w", err) + } + resp, err = c.sendJSONLocked(msg) + } + return resp, err +} + +func (c *BinaryClient) sendJSONLocked(msg map[string]interface{}) (map[string]interface{}, error) { + body, err := json.Marshal(msg) + if err != nil { + return nil, fmt.Errorf("json encode: %w", err) + } + + if err := wire.WriteFrame(c.conn, wire.MsgJSON, body); err != nil { + return nil, fmt.Errorf("send: %w", err) + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + msgType, payload, readErr := wire.ReadFrame(c.conn) + c.conn.SetReadDeadline(time.Time{}) + if readErr != nil { + return nil, fmt.Errorf("recv: %w", readErr) + } + + if msgType == wire.MsgError { + errMsg := wire.DecodeError(payload) + return map[string]interface{}{"type": "error", "error": errMsg}, fmt.Errorf("registry: %s", errMsg) + } + if msgType != wire.MsgJSON { + return nil, fmt.Errorf("unexpected response type 0x%02x for JSON passthrough", msgType) + } + + var resp map[string]interface{} + if err := json.Unmarshal(payload, &resp); err != nil { + return nil, fmt.Errorf("json decode response: %w", err) + } + if errMsg, ok := resp["error"].(string); ok { + return resp, fmt.Errorf("registry: %s", errMsg) + } + return resp, nil +} diff --git a/registry/client/client.go b/registry/client/client.go new file mode 100644 index 0000000..404b190 --- /dev/null +++ b/registry/client/client.go @@ -0,0 +1,1393 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package client + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net" + "sync" + "time" + + "github.com/pilot-protocol/common/registry/wire" +) 

// ErrNoRegistry is the sentinel returned when a registry call is made on a
// nil *Client: SendContext (and therefore every wrapper that funnels through
// Send) and sign return it instead of panicking. Callers (loadPolicyRunners,
// ManagedEngine.fetchMembers, Daemon.Info → nodeNetworks, etc.) sometimes
// invoke registry methods before the client is configured; returning this
// sentinel lets them treat "no registry" as a recoverable condition and
// test for it with errors.Is.
var ErrNoRegistry = errors.New("registry client not configured")

// Client talks to a registry server over TCP (optionally TLS).
// It automatically reconnects if the connection drops.
//
// By default a Client owns a single TCP connection (Dial / DialTLS /
// DialTLSPinned). Each Send takes c.mu and serialises the entire
// request/response round-trip on that one conn. Under heavy concurrent
// load (the §4.8 lock-graph stress harness — 250 heartbeat goroutines
// per daemon hammering Health / Info / SetTags / ResolveHostname plus
// the per-resolve prewarm goroutines and persistHostnameCache writers,
// all funnelling through regConn.Send) that single mutex becomes the
// bottleneck: in-flight calls cannot honour shutdown signals because
// they're queued behind the mutex.
//
// DialPool / DialTLSPool create a Client backed by a small pool of
// connections (the primary c.conn plus N-1 secondary conns). Each
// concurrent Send picks a free pooled conn (blocking only if every
// conn is in use), eliminating the head-of-line wait. The primary
// c.conn / c.mu / c.closed fields are retained for backward compatibility
// with tests that touch them directly.
type Client struct {
	// Primary connection. Always present; tests in this package read
	// c.conn / c.mu / c.closed directly so the field set must stay stable.
	conn      net.Conn
	mu        sync.Mutex
	addr      string // registry address for reconnection
	closed    bool
	tlsConfig *tls.Config
	signer    func(challenge string) string // H3 fix: optional message signer

	// Optional pool of secondary connections used to parallelise Send.
	// nil / empty when DialPool was not used.
	pool poolState
}

// poolState holds the secondary-conn pool. The primary slot (c.conn / c.mu)
// is also represented here as the first entry, so acquireConn / releaseConn
// can pick uniformly across all conns.
type poolState struct {
	// entries is the full set of conns including the primary at index 0.
	// Each entry has its own mu — taking entry.mu lets one Send proceed
	// without blocking other Sends on different entries. Closed when the
	// Client is closed.
	entries []*pooledConn
	// free is a buffered channel of pointers to entries currently free.
	// Capacity equals len(entries). Send: <-free; defer free<-entry.
	// nil means "no pool" (legacy single-conn path via c.mu).
	free chan *pooledConn
	// done is closed by Close() to wake goroutines blocked on <-free and to
	// signal the deferred pool-return in sendPool to drop its entry instead
	// of sending on free. Using a separate done channel avoids the race
	// between close(free) and concurrent sends on free.
	done chan struct{}
}

// pooledConn wraps one TCP connection plus its own mutex. The mutex
// guards both the conn pointer and any reconnect that happens through
// it; sendOnEntry takes it for the full write/read round-trip.
type pooledConn struct {
	mu   sync.Mutex
	conn net.Conn
}

// SetSigner sets a signing function for authenticated registry operations (H3 fix).
// The signer receives a challenge string and returns a base64-encoded Ed25519 signature.
// Calling SetSigner on a nil *Client is a no-op.
//
// Issue #93: when the regConn is pooled (DialPool), multiple Send goroutines
// may call sign() concurrently while a parallel RotateKey path calls
// SetSigner. We guard the field with c.mu to keep that race-free; reads via
// sign() take the same lock so the loaded function pointer is consistent.
func (c *Client) SetSigner(fn func(challenge string) string) {
	if c == nil {
		return
	}
	c.mu.Lock()
	c.signer = fn
	c.mu.Unlock()
}

// sign returns a signature for the challenge. It returns an error when the
// signer is unavailable or returns an empty signature. A nil receiver returns
// ErrNoRegistry so callers can rely on errors.Is.
func (c *Client) sign(challenge string) (string, error) {
	if c == nil {
		return "", ErrNoRegistry
	}
	// Copy the function pointer under the lock, then call it outside the
	// lock so the signer itself never runs while c.mu is held.
	c.mu.Lock()
	fn := c.signer
	c.mu.Unlock()
	if fn == nil {
		return "", fmt.Errorf("registry client: no signer configured (call SetSigner first)")
	}
	sig := fn(challenge)
	if sig == "" {
		return "", fmt.Errorf("registry client: signer returned empty signature for %q", challenge)
	}
	return sig, nil
}

// Dial connects to a registry server over plain TCP and returns a
// single-connection Client (no pool, no TLS). The dial uses net.Dial, so no
// explicit connect timeout is applied here.
func Dial(addr string) (*Client, error) {
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("dial registry: %w", err)
	}
	return &Client{conn: conn, addr: addr}, nil
}

// DialPool connects to a registry server over plain TCP and pre-warms a
// pool of `size` connections (size >= 1). When size == 1 this is identical
// to Dial. When size > 1, additional secondary conns are dialed; concurrent
// Send calls then run in parallel up to `size` at a time, instead of all
// queueing on a single mutex.
//
// DialPool exists to fix #93 (regConn fairness under sustained load): the
// daemon's IPC handlers spawn goroutines that all call regConn.Send and
// previously serialised on c.mu. With DialPool the daemon can keep the
// same code path while letting up to `size` registry round-trips run
// concurrently.
//
// On any pool conn dial failure DialPool closes the conns it had already
// opened and returns an error.
+func DialPool(addr string, size int) (*Client, error) { + if size <= 0 { + size = 1 + } + primary, err := net.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("dial registry: %w", err) + } + c := &Client{conn: primary, addr: addr} + if err := c.initPool(size, nil); err != nil { + primary.Close() + return nil, err + } + return c, nil +} + +// DialTLS connects to a registry server over TLS. +// A non-nil tlsConfig is required. For certificate pinning, use DialTLSPinned. +func DialTLS(addr string, tlsConfig *tls.Config) (*Client, error) { + if tlsConfig == nil { + return nil, fmt.Errorf("TLS config required; use DialTLSPinned for certificate pinning") + } + conn, err := tls.Dial("tcp", addr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("dial registry TLS: %w", err) + } + return &Client{conn: conn, addr: addr, tlsConfig: tlsConfig}, nil +} + +// DialTLSPool is the TLS variant of DialPool. +func DialTLSPool(addr string, tlsConfig *tls.Config, size int) (*Client, error) { + if tlsConfig == nil { + return nil, fmt.Errorf("TLS config required; use DialTLSPinnedPool for certificate pinning") + } + if size <= 0 { + size = 1 + } + primary, err := tls.Dial("tcp", addr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("dial registry TLS: %w", err) + } + c := &Client{conn: primary, addr: addr, tlsConfig: tlsConfig} + if err := c.initPool(size, tlsConfig); err != nil { + primary.Close() + return nil, err + } + return c, nil +} + +// initPool dials size-1 additional connections and registers them in c.pool. +// It assumes c.conn (primary) is already set. tlsCfg, when non-nil, is used +// for TLS dialing; otherwise plain TCP. +func (c *Client) initPool(size int, tlsCfg *tls.Config) error { + if size <= 1 { + // No secondary conns — single-conn legacy path; pool stays empty. 
+ return nil + } + entries := make([]*pooledConn, 0, size) + entries = append(entries, &pooledConn{conn: c.conn}) + for i := 1; i < size; i++ { + var conn net.Conn + var err error + if tlsCfg != nil { + conn, err = tls.Dial("tcp", c.addr, tlsCfg) + } else { + conn, err = net.Dial("tcp", c.addr) + } + if err != nil { + // Close any conns we already opened (excluding primary — + // caller closes that on failure). + for _, e := range entries[1:] { + e.conn.Close() + } + return fmt.Errorf("dial pool conn %d: %w", i, err) + } + entries = append(entries, &pooledConn{conn: conn}) + } + free := make(chan *pooledConn, len(entries)) + for _, e := range entries { + free <- e + } + c.pool.entries = entries + c.pool.free = free + c.pool.done = make(chan struct{}) + return nil +} + +// DialTLSPinned connects to a registry server over TLS with certificate pinning. +// The fingerprint is a hex-encoded SHA-256 hash of the server's DER-encoded certificate. +func DialTLSPinned(addr, fingerprint string) (*Client, error) { + tlsConfig := &tls.Config{ + // InsecureSkipVerify disables the default CA chain check so we can + // use VerifyPeerCertificate for certificate pinning (SHA-256 fingerprint). + // This is the standard Go pattern — the custom callback below provides + // strictly stronger verification than CA-based trust. 
+ InsecureSkipVerify: true, //nolint:gosec // cert pinning via VerifyPeerCertificate + VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return fmt.Errorf("no certificate presented") + } + hash := sha256.Sum256(rawCerts[0]) + got := hex.EncodeToString(hash[:]) + if got != fingerprint { + return fmt.Errorf("certificate fingerprint mismatch: got %s, want %s", got, fingerprint) + } + return nil + }, + } + conn, err := tls.Dial("tcp", addr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("dial registry TLS pinned: %w", err) + } + return &Client{conn: conn, addr: addr, tlsConfig: tlsConfig}, nil +} + +func (c *Client) Close() error { + if c == nil { + return nil + } + c.mu.Lock() + c.closed = true + conn := c.conn + pool := c.pool.entries + c.mu.Unlock() + // Close the conn after releasing the lock; conn is captured by value + // so reconnect() can't see it after we set c.closed=true (M7 fix) + var firstErr error + if conn != nil { + if err := conn.Close(); err != nil { + firstErr = err + } + } + // Close every secondary pooled conn. The primary is already closed + // above (entries[0] holds the same fd as c.conn). Skip index 0 so + // we don't double-close. + for i := 1; i < len(pool); i++ { + e := pool[i] + // Take e.mu to coordinate with any in-flight sendOnEntry that + // holds it; once we release the mutex it'll see a closed conn + // on its next Read/Write and return an error to the caller. + e.mu.Lock() + if e.conn != nil { + if err := e.conn.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + e.mu.Unlock() + } + // Close pool.done to wake any goroutine blocked on <-c.pool.free and to + // signal the deferred pool-return in sendPool to drop its entry. We never + // close pool.free itself because that would race with concurrent sends on + // it from the sendPool defer. 
+ if c.pool.done != nil { + close(c.pool.done) + } + return firstErr +} + +// reconnect re-establishes the TCP connection to the registry. +// Must be called with c.mu held. +func (c *Client) reconnect(ctx context.Context) error { + if c.closed { + return fmt.Errorf("client closed") + } + if c.conn != nil { + c.conn.Close() + } + + var conn net.Conn + var err error + backoff := 500 * time.Millisecond + maxBackoff := 10 * time.Second + + for attempts := 0; attempts < 5; attempts++ { + if c.tlsConfig != nil { + dialer := &tls.Dialer{Config: c.tlsConfig, NetDialer: &net.Dialer{Timeout: 5 * time.Second}} + conn, err = dialer.DialContext(ctx, "tcp", c.addr) + } else { + conn, err = net.DialTimeout("tcp", c.addr, 5*time.Second) + } + if err == nil { + c.conn = conn + slog.Info("registry reconnected", "addr", c.addr) + return nil + } + slog.Warn("registry reconnect failed", "attempt", attempts+1, "err", err) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } + return fmt.Errorf("reconnect failed after 5 attempts: %w", err) +} + +// Send sends a registry message without a deadline. For shutdown-safe use +// that respects context cancellation, prefer SendContext. +func (c *Client) Send(msg map[string]interface{}) (map[string]interface{}, error) { + return c.SendContext(context.Background(), msg) +} + +// SendContext sends a registry message with context propagation through +// reconnect retries. Callers should pass a context with deadline or +// cancellation (e.g. daemon shutdown context) so that reconnect backoff +// does not block graceful stop. +func (c *Client) SendContext(ctx context.Context, msg map[string]interface{}) (map[string]interface{}, error) { + // Nil receiver — return a sentinel rather than panicking. 
Every + // exported wrapper method (Register, Lookup, Resolve, …) funnels + // through Send, so this single guard turns "calling a registry + // method on a nil client" into a recoverable error for every + // caller (loadPolicyRunners, ManagedEngine.fetchMembers, + // Daemon.Info → nodeNetworks, etc.). + if c == nil { + return nil, ErrNoRegistry + } + // Pool-enabled path (DialPool / DialTLSPool): pick a free conn and + // run the round-trip on it without touching c.mu. Multiple Send + // callers can run concurrently on different pooled conns. + if c.pool.free != nil { + return c.sendPool(ctx, msg) + } + + c.mu.Lock() + defer c.mu.Unlock() + + resp, err := c.sendLocked(msg) + if err != nil && resp == nil && !c.closed { + // Connection-level failure (no response received) — reconnect and retry once. + // Server error responses (resp != nil) do NOT trigger reconnection. + if reconnErr := c.reconnect(ctx); reconnErr != nil { + return nil, fmt.Errorf("send failed and reconnect failed: %w", err) + } + resp, err = c.sendLocked(msg) + } + return resp, err +} + +// sendPool runs Send on a free pooled connection. It blocks only when +// every pooled conn is busy (capacity exhausted) — one concurrent Send +// per pool entry can be in flight at a time. +func (c *Client) sendPool(ctx context.Context, msg map[string]interface{}) (map[string]interface{}, error) { + // Cheap closed check — avoids a wedged caller waiting on a free + // channel that nobody will ever return to once Close has run. 
+ c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return nil, fmt.Errorf("client closed") + } + + var entry *pooledConn + select { + case entry = <-c.pool.free: + case <-c.pool.done: + return nil, fmt.Errorf("client closed") + } + defer func() { + select { + case c.pool.free <- entry: + case <-c.pool.done: + // pool is torn down; drop the entry + } + }() + + entry.mu.Lock() + defer entry.mu.Unlock() + + resp, err := c.sendOnEntry(entry, msg) + if err != nil && resp == nil && !c.isClosed() { + // Connection-level failure on this entry — reconnect THIS entry + // only (other pool entries are unaffected) and retry once. + if reconnErr := c.reconnectEntry(ctx, entry); reconnErr != nil { + return nil, fmt.Errorf("send failed and reconnect failed: %w", err) + } + resp, err = c.sendOnEntry(entry, msg) + } + return resp, err +} + +// sendOnEntry writes the request and reads the response on entry.conn. +// Caller must hold entry.mu. +func (c *Client) sendOnEntry(entry *pooledConn, msg map[string]interface{}) (map[string]interface{}, error) { + if err := wire.WriteMessage(entry.conn, msg); err != nil { + return nil, fmt.Errorf("send: %w", err) + } + entry.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + resp, err := wire.ReadMessage(entry.conn) + entry.conn.SetReadDeadline(time.Time{}) + if err != nil { + return nil, fmt.Errorf("recv: %w", err) + } + if errMsg, ok := resp["error"].(string); ok { + return resp, fmt.Errorf("registry: %s", errMsg) + } + return resp, nil +} + +// reconnectEntry redials a single pool entry. Caller must hold entry.mu. +// This is the per-entry analogue of Client.reconnect. 
+func (c *Client) reconnectEntry(ctx context.Context, entry *pooledConn) error { + if c.isClosed() { + return fmt.Errorf("client closed") + } + if entry.conn != nil { + entry.conn.Close() + } + + var conn net.Conn + var err error + backoff := 500 * time.Millisecond + maxBackoff := 10 * time.Second + for attempts := 0; attempts < 5; attempts++ { + if c.tlsConfig != nil { + dialer := &tls.Dialer{Config: c.tlsConfig, NetDialer: &net.Dialer{Timeout: 5 * time.Second}} + conn, err = dialer.DialContext(ctx, "tcp", c.addr) + } else { + conn, err = net.DialTimeout("tcp", c.addr, 5*time.Second) + } + if err == nil { + entry.conn = conn + // Keep c.conn (primary) in sync if this is the primary entry. + // Tests in this package read c.conn directly, so we must not + // leave it pointing at a closed fd. + if entry == c.pool.entries[0] { + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + } + slog.Info("registry pool conn reconnected", "addr", c.addr) + return nil + } + slog.Warn("registry pool conn reconnect failed", "attempt", attempts+1, "err", err) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } + return fmt.Errorf("reconnect failed after 5 attempts: %w", err) +} + +// isClosed returns whether Close has been called. Cheap, lock-protected. +func (c *Client) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +// sendLocked sends a message and reads the response. Must be called with c.mu held. 
+func (c *Client) sendLocked(msg map[string]interface{}) (map[string]interface{}, error) { + if err := wire.WriteMessage(c.conn, msg); err != nil { + return nil, fmt.Errorf("send: %w", err) + } + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + resp, err := wire.ReadMessage(c.conn) + c.conn.SetReadDeadline(time.Time{}) + if err != nil { + return nil, fmt.Errorf("recv: %w", err) + } + if errMsg, ok := resp["error"].(string); ok { + return resp, fmt.Errorf("registry: %s", errMsg) + } + return resp, nil +} + +func (c *Client) Register(listenAddr string) (map[string]interface{}, error) { + return c.Send(map[string]interface{}{ + "type": "register", + "listen_addr": listenAddr, + }) +} + +// RegisterWithOwner registers a new node with an owner identifier (email/name) +// for key rotation recovery. +func (c *Client) RegisterWithOwner(listenAddr, owner string) (map[string]interface{}, error) { + return c.Send(map[string]interface{}{ + "type": "register", + "listen_addr": listenAddr, + "owner": owner, + }) +} + +// RegisterWithKey re-registers using an existing Ed25519 public key. +// The registry returns the same node_id if the key is known. +// lanAddrs are the node's LAN addresses for same-network peer detection. +func (c *Client) RegisterWithKey(listenAddr, publicKeyB64, owner string, lanAddrs []string, opts ...string) (map[string]interface{}, error) { + return c.RegisterWithKeyOpts(RegisterOpts{ + ListenAddr: listenAddr, + PublicKey: publicKeyB64, + Owner: owner, + LANAddrs: lanAddrs, + Version: firstNonEmpty(opts...), + }) +} + +// RegisterOpts is the full set of registration options. Lets us add +// fields (like RelayOnly for task 32) without breaking the variadic +// signature of RegisterWithKey. +type RegisterOpts struct { + ListenAddr string + PublicKey string // base64 Ed25519 + Owner string + LANAddrs []string + Version string + RelayOnly bool // task 32: hide real_addr from peers +} + +// RegisterWithKeyOpts is the structured-form register call. 
Existing +// callers keep using RegisterWithKey; new flags go here. +func (c *Client) RegisterWithKeyOpts(o RegisterOpts) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "register", + "listen_addr": o.ListenAddr, + "public_key": o.PublicKey, + } + if o.Owner != "" { + msg["owner"] = o.Owner + } + if len(o.LANAddrs) > 0 { + msg["lan_addrs"] = o.LANAddrs + } + if o.Version != "" { + msg["version"] = o.Version + } + if o.RelayOnly { + msg["relay_only"] = true + } + return c.Send(msg) +} + +func firstNonEmpty(s ...string) string { + for _, v := range s { + if v != "" { + return v + } + } + return "" +} + +// RotateKey requests a key rotation for a node. +// Requires a signature proving ownership of the current key and the new public key. +func (c *Client) RotateKey(nodeID uint32, signatureB64, newPubKeyB64 string) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "rotate_key", + "node_id": nodeID, + } + if signatureB64 != "" { + msg["signature"] = signatureB64 + } + if newPubKeyB64 != "" { + msg["new_public_key"] = newPubKeyB64 + } + return c.Send(msg) +} + +func (c *Client) Lookup(nodeID uint32) (map[string]interface{}, error) { + return c.Send(map[string]interface{}{ + "type": "lookup", + "node_id": nodeID, + }) +} + +func (c *Client) Resolve(nodeID, requesterID uint32) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "resolve", + "node_id": nodeID, + "requester_id": requesterID, + } + sig, err := c.sign(fmt.Sprintf("resolve:%d:%d", requesterID, nodeID)) + if err != nil { + return nil, err + } + msg["signature"] = sig + return c.Send(msg) +} + +func (c *Client) ReportTrust(nodeID, peerID uint32) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "report_trust", + "node_id": nodeID, + "peer_id": peerID, + } + sig, err := c.sign(fmt.Sprintf("report_trust:%d:%d", nodeID, peerID)) + if err != nil { + return nil, err + } + msg["signature"] = sig + return 
c.Send(msg) +} + +func (c *Client) RevokeTrust(nodeID, peerID uint32) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "revoke_trust", + "node_id": nodeID, + "peer_id": peerID, + } + sig, err := c.sign(fmt.Sprintf("revoke_trust:%d:%d", nodeID, peerID)) + if err != nil { + return nil, err + } + msg["signature"] = sig + return c.Send(msg) +} + +func (c *Client) SetVisibility(nodeID uint32, public bool) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "set_visibility", + "node_id": nodeID, + "public": public, + } + sig, err := c.sign(fmt.Sprintf("set_visibility:%d", nodeID)) + if err != nil { + return nil, err + } + msg["signature"] = sig + return c.Send(msg) +} + +func (c *Client) CreateNetwork(nodeID uint32, name, joinRule, token, adminToken string, enterprise bool, networkAdminToken ...string) (map[string]interface{}, error) { + msg := map[string]interface{}{ + "type": "create_network", + "node_id": nodeID, + "name": name, + "join_rule": joinRule, + "token": token, + } + if adminToken != "" { + msg["admin_token"] = adminToken + } + if enterprise { + msg["enterprise"] = true + } + if len(networkAdminToken) > 0 && networkAdminToken[0] != "" { + msg["network_admin_token"] = networkAdminToken[0] + } + return c.Send(msg) +} + +// CreateManagedNetwork creates a network with managed rules. 
+func (c *Client) CreateManagedNetwork(nodeID uint32, name, joinRule, token, adminToken string, enterprise bool, rules string, networkAdminToken ...string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":      "create_network",
+		"node_id":   nodeID,
+		"name":      name,
+		"join_rule": joinRule,
+		"token":     token,
+		"rules":     rules,
+	}
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	if enterprise {
+		req["enterprise"] = true
+	}
+	if len(networkAdminToken) > 0 {
+		if nat := networkAdminToken[0]; nat != "" {
+			req["network_admin_token"] = nat
+		}
+	}
+	return c.Send(req)
+}
+
+// JoinNetwork sends a "join_network" request. The request is signed when
+// signing succeeds; if signing fails, adminToken is attached instead when
+// it is non-empty. The request is sent in every case.
+func (c *Client) JoinNetwork(nodeID uint32, networkID uint16, token string, inviterID uint32, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "join_network",
+		"node_id":    nodeID,
+		"network_id": networkID,
+		"token":      token,
+		"inviter_id": inviterID,
+	}
+	signature, signErr := c.sign(fmt.Sprintf("join_network:%d:%d", nodeID, networkID))
+	switch {
+	case signErr == nil:
+		req["signature"] = signature
+	case adminToken != "":
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// LeaveNetwork sends a "leave_network" request. The request is signed when
+// signing succeeds; if signing fails, adminToken is attached instead when
+// it is non-empty. The request is sent in every case.
+func (c *Client) LeaveNetwork(nodeID uint32, networkID uint16, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "leave_network",
+		"node_id":    nodeID,
+		"network_id": networkID,
+	}
+	signature, signErr := c.sign(fmt.Sprintf("leave_network:%d:%d", nodeID, networkID))
+	switch {
+	case signErr == nil:
+		req["signature"] = signature
+	case adminToken != "":
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// DeleteNetwork sends a "delete_network" request. admin_token is included
+// when non-empty; node_id is included when a non-zero first nodeID is given.
+func (c *Client) DeleteNetwork(networkID uint16, adminToken string, nodeID ...uint32) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "delete_network",
+		"network_id": networkID,
+	}
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	if len(nodeID) > 0 {
+		if id := nodeID[0]; id != 0 {
+			req["node_id"] = id
+		}
+	}
+	return c.Send(req)
+}
+
+// RenameNetwork sends a "rename_network" request with the new name.
+// admin_token is included when non-empty; node_id is included when a
+// non-zero first nodeID is given.
+func (c *Client) RenameNetwork(networkID uint16, name, adminToken string, nodeID ...uint32) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "rename_network",
+		"network_id": networkID,
+		"name":       name,
+	}
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	if len(nodeID) > 0 {
+		if id := nodeID[0]; id != 0 {
+			req["node_id"] = id
+		}
+	}
+	return c.Send(req)
+}
+
+// SetNetworkEnterprise sends a "set_network_enterprise" request carrying
+// the enterprise flag. admin_token is always included, even when empty.
+func (c *Client) SetNetworkEnterprise(networkID uint16, enterprise bool, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_network_enterprise",
+		"network_id":  networkID,
+		"enterprise":  enterprise,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// ListNetworks returns the registry's network catalog. Member counts (the
+// `members` field on each entry) are admin-only: pass a non-empty
+// adminToken to receive them; otherwise the field is omitted.
+func (c *Client) ListNetworks(adminToken ...string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type": "list_networks",
+	}
+	if len(adminToken) > 0 {
+		if tok := adminToken[0]; tok != "" {
+			req["admin_token"] = tok
+		}
+	}
+	return c.Send(req)
+}
+
+// ListNodes sends a "list_nodes" request for networkID. admin_token is
+// included when the first variadic value is non-empty.
+func (c *Client) ListNodes(networkID uint16, adminToken ...string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "list_nodes",
+		"network_id": networkID,
+	}
+	if len(adminToken) > 0 {
+		if tok := adminToken[0]; tok != "" {
+			req["admin_token"] = tok
+		}
+	}
+	return c.Send(req)
+}
+
+// Deregister sends a signed "deregister" request for nodeID. Nothing is
+// sent if signing fails.
+func (c *Client) Deregister(nodeID uint32) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("deregister:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":      "deregister",
+		"node_id":   nodeID,
+		"signature": signature,
+	})
+}
+
+// Heartbeat sends a signed "heartbeat" request for nodeID. Nothing is sent
+// if signing fails.
+func (c *Client) Heartbeat(nodeID uint32) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("heartbeat:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":      "heartbeat",
+		"node_id":   nodeID,
+		"signature": signature,
+	})
+}
+
+// Punch sends a signed "punch" request naming nodeA and nodeB. The signed
+// string covers nodeA and nodeB only; requesterID travels in the request
+// but is not part of it. Nothing is sent if signing fails.
+func (c *Client) Punch(requesterID, nodeA, nodeB uint32) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("punch:%d:%d", nodeA, nodeB))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":         "punch",
+		"requester_id": requesterID,
+		"node_a":       nodeA,
+		"node_b":       nodeB,
+		"signature":    signature,
+	})
+}
+
+// RequestHandshake relays a handshake request through the registry to a
+// target node. This works even for private nodes, with no IP exposure.
+// M12 fix: a caller-supplied signature proving sender identity is included
+// when non-empty.
+func (c *Client) RequestHandshake(fromNodeID, toNodeID uint32, justification, signatureB64 string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":          "request_handshake",
+		"from_node_id":  fromNodeID,
+		"to_node_id":    toNodeID,
+		"justification": justification,
+	}
+	if signatureB64 != "" {
+		req["signature"] = signatureB64
+	}
+	return c.Send(req)
+}
+
+// PollHandshakes retrieves and clears pending handshake requests for a node.
+// H3 fix: the request is signed to prove node identity. Nothing is sent if
+// signing fails.
+func (c *Client) PollHandshakes(nodeID uint32) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("poll_handshakes:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":      "poll_handshakes",
+		"node_id":   nodeID,
+		"signature": signature,
+	})
+}
+
+// RespondHandshake approves or rejects a relayed handshake request.
+// If accepted, the registry creates a mutual trust pair.
+// M12 fix: includes a signature to prove responder identity.
+func (c *Client) RespondHandshake(nodeID, peerID uint32, accept bool, signatureB64 string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 5)
+	req["type"] = "respond_handshake"
+	req["node_id"] = nodeID
+	req["peer_id"] = peerID
+	req["accept"] = accept
+	// The caller-supplied signature is optional on the wire.
+	if signatureB64 != "" {
+		req["signature"] = signatureB64
+	}
+	return c.Send(req)
+}
+
+// SetHostname sets or clears the hostname for a node; an empty hostname
+// clears the current one. The request is signed, and nothing is sent if
+// signing fails.
+func (c *Client) SetHostname(nodeID uint32, hostname string) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("set_hostname:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":      "set_hostname",
+		"node_id":   nodeID,
+		"hostname":  hostname,
+		"signature": signature,
+	})
+}
+
+// SetTags sets the capability tags for a node. The request is signed, and
+// nothing is sent if signing fails.
+func (c *Client) SetTags(nodeID uint32, tags []string) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("set_tags:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":      "set_tags",
+		"node_id":   nodeID,
+		"tags":      tags,
+		"signature": signature,
+	})
+}
+
+// ResolveHostname resolves a hostname to node info (node_id, address,
+// public flag).
+func (c *Client) ResolveHostname(hostname string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":     "resolve_hostname",
+		"hostname": hostname,
+	}
+	return c.Send(req)
+}
+
+// ResolveHostnameAs resolves a hostname and includes requester_id for
+// privacy checks. Private nodes require the requester to have a trust pair
+// or shared network.
+func (c *Client) ResolveHostnameAs(requesterID uint32, hostname string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":         "resolve_hostname",
+		"hostname":     hostname,
+		"requester_id": requesterID,
+	}
+	return c.Send(req)
+}
+
+// CheckTrust checks if a trust pair or shared network exists between two nodes.
+func (c *Client) CheckTrust(nodeA, nodeB uint32) (bool, error) {
+	if c == nil {
+		return false, ErrNoRegistry
+	}
+	req := map[string]interface{}{
+		"type":    "check_trust",
+		"node_id": nodeA,
+		"peer_id": nodeB,
+	}
+	reply, err := c.Send(req)
+	if err != nil {
+		return false, err
+	}
+	// A missing or non-boolean "trusted" field is reported as false.
+	if trusted, isBool := reply["trusted"].(bool); isBool {
+		return trusted, nil
+	}
+	return false, nil
+}
+
+// InviteToNetwork stores a pending invite for a target node to join an
+// invite-only network. The request is always signed, and nothing is sent if
+// signing fails; admin_token is added when non-empty.
+func (c *Client) InviteToNetwork(networkID uint16, inviterID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("invite:%d:%d:%d", inviterID, networkID, targetNodeID))
+	if err != nil {
+		return nil, err
+	}
+	req := map[string]interface{}{
+		"type":           "invite_to_network",
+		"network_id":     networkID,
+		"inviter_id":     inviterID,
+		"target_node_id": targetNodeID,
+		"signature":      signature,
+	}
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// PollInvites returns and clears pending network invites for a node. Signed.
+func (c *Client) PollInvites(nodeID uint32) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("poll_invites:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":      "poll_invites",
+		"node_id":   nodeID,
+		"signature": signature,
+	})
+}
+
+// RespondInvite accepts or rejects a pending network invite. Signed.
+func (c *Client) RespondInvite(nodeID uint32, networkID uint16, accept bool) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("respond_invite:%d:%d", nodeID, networkID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":       "respond_invite",
+		"node_id":    nodeID,
+		"network_id": networkID,
+		"accept":     accept,
+		"signature":  signature,
+	})
+}
+
+// PromoteMember promotes a network member to admin. Only the owner can promote.
+func (c *Client) PromoteMember(networkID uint16, nodeID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 5)
+	req["type"] = "promote_member"
+	req["network_id"] = networkID
+	req["node_id"] = nodeID
+	req["target_node_id"] = targetNodeID
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// DemoteMember demotes an admin to member. Only the owner can demote.
+// admin_token is sent only when non-empty.
+func (c *Client) DemoteMember(networkID uint16, nodeID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 5)
+	req["type"] = "demote_member"
+	req["network_id"] = networkID
+	req["node_id"] = nodeID
+	req["target_node_id"] = targetNodeID
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// KickMember removes a member from a network. Requires owner or admin role.
+// admin_token is sent only when non-empty.
+func (c *Client) KickMember(networkID uint16, nodeID, targetNodeID uint32, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 5)
+	req["type"] = "kick_member"
+	req["network_id"] = networkID
+	req["node_id"] = nodeID
+	req["target_node_id"] = targetNodeID
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// TransferOwnership transfers network ownership to another member. Only the
+// current owner can transfer. ownerNodeID is sent as "node_id" and
+// newOwnerID as "new_owner_id"; admin_token is sent only when non-empty.
+func (c *Client) TransferOwnership(networkID uint16, ownerNodeID, newOwnerID uint32, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 5)
+	req["type"] = "transfer_ownership"
+	req["network_id"] = networkID
+	req["node_id"] = ownerNodeID
+	req["new_owner_id"] = newOwnerID
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// GetMemberRole returns the RBAC role of a node in a network.
+func (c *Client) GetMemberRole(networkID uint16, targetNodeID uint32) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":           "get_member_role",
+		"network_id":     networkID,
+		"target_node_id": targetNodeID,
+	}
+	return c.Send(req)
+}
+
+// SetNetworkPolicy sets or updates a network's policy. Requires owner/admin
+// role or admin token. The policy entries are copied into the request
+// first, so "type" and "network_id" (and "admin_token", when adminToken is
+// non-empty) override any policy entry with the same key.
+func (c *Client) SetNetworkPolicy(networkID uint16, policy map[string]interface{}, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, len(policy)+3)
+	for key := range policy {
+		req[key] = policy[key]
+	}
+	req["type"] = "set_network_policy"
+	req["network_id"] = networkID
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// GetNetworkPolicy returns the policy for a given network.
+func (c *Client) GetNetworkPolicy(networkID uint16) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "get_network_policy",
+		"network_id": networkID,
+	}
+	return c.Send(req)
+}
+
+// SetExprPolicy sets the programmable policy for a network. Requires
+// owner/admin role or admin token. policyJSON is sent as a string under
+// "expr_policy".
+func (c *Client) SetExprPolicy(networkID uint16, policyJSON json.RawMessage, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 4)
+	req["type"] = "set_expr_policy"
+	req["network_id"] = networkID
+	req["expr_policy"] = string(policyJSON)
+	if adminToken != "" {
+		req["admin_token"] = adminToken
+	}
+	return c.Send(req)
+}
+
+// GetExprPolicy returns the programmable policy for a network.
+func (c *Client) GetExprPolicy(networkID uint16) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":       "get_expr_policy",
+		"network_id": networkID,
+	}
+	return c.Send(req)
+}
+
+// SetKeyExpiry sets the key expiry time for a node. Requires signature.
+func (c *Client) SetKeyExpiry(nodeID uint32, expiresAt time.Time) (map[string]interface{}, error) {
+	signature, err := c.sign(fmt.Sprintf("set_key_expiry:%d", nodeID))
+	if err != nil {
+		return nil, err
+	}
+	return c.Send(map[string]interface{}{
+		"type":       "set_key_expiry",
+		"node_id":    nodeID,
+		"expires_at": expiresAt.Format(time.RFC3339),
+		"signature":  signature,
+	})
+}
+
+// GetKeyInfo returns key lifecycle metadata for a node.
+func (c *Client) GetKeyInfo(nodeID uint32) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":    "get_key_info",
+		"node_id": nodeID,
+	}
+	return c.Send(req)
+}
+
+// --- Admin methods (bypass node signature, use admin_token instead) ---
+
+// SetHostnameAdmin sets a node's hostname using admin token auth.
+func (c *Client) SetHostnameAdmin(nodeID uint32, hostname, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_hostname",
+		"node_id":     nodeID,
+		"hostname":    hostname,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetVisibilityAdmin sets a node's visibility using admin token auth.
+func (c *Client) SetVisibilityAdmin(nodeID uint32, public bool, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_visibility",
+		"node_id":     nodeID,
+		"public":      public,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetTagsAdmin sets a node's tags using admin token auth.
+func (c *Client) SetTagsAdmin(nodeID uint32, tags []string, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_tags",
+		"node_id":     nodeID,
+		"tags":        tags,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetMemberTags sets admin-assigned tags for a member within a network.
+func (c *Client) SetMemberTags(netID uint16, targetNodeID uint32, tags []string, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":           "set_member_tags",
+		"network_id":     netID,
+		"target_node_id": targetNodeID,
+		"tags":           tags,
+		"admin_token":    adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetMemberTags returns admin-assigned member tags for a node (or all
+// members if targetNodeID=0).
+func (c *Client) GetMemberTags(netID uint16, targetNodeID uint32) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":           "get_member_tags",
+		"network_id":     netID,
+		"target_node_id": targetNodeID,
+	}
+	return c.Send(req)
+}
+
+// SetKeyExpiryAdmin sets a node's key expiry using admin token auth. The
+// expiry is sent formatted as RFC 3339.
+func (c *Client) SetKeyExpiryAdmin(nodeID uint32, expiresAt time.Time, adminToken string) (map[string]interface{}, error) {
+	expiry := expiresAt.Format(time.RFC3339)
+	req := map[string]interface{}{
+		"type":        "set_key_expiry",
+		"node_id":     nodeID,
+		"expires_at":  expiry,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// ClearKeyExpiryAdmin removes the key expiry from a node using admin token
+// auth, by sending "never" as the expires_at value.
+func (c *Client) ClearKeyExpiryAdmin(nodeID uint32, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_key_expiry",
+		"node_id":     nodeID,
+		"expires_at":  "never",
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// DeregisterAdmin removes a node using admin token auth.
+func (c *Client) DeregisterAdmin(nodeID uint32, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "deregister",
+		"node_id":     nodeID,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetAuditLog returns recent audit entries from the registry.
+func (c *Client) GetAuditLog(networkID uint16, adminToken string) (map[string]interface{}, error) {
+	req := make(map[string]interface{}, 3)
+	req["type"] = "get_audit_log"
+	req["admin_token"] = adminToken
+	// network_id is left out of the request when networkID is zero.
+	if networkID != 0 {
+		req["network_id"] = networkID
+	}
+	return c.Send(req)
+}
+
+// SetWebhook configures the registry webhook URL. Pass empty string to disable.
+func (c *Client) SetWebhook(url, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_webhook",
+		"url":         url,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetWebhook returns the current webhook configuration.
+func (c *Client) GetWebhook(adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "get_webhook",
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetWebhookDLQ returns the dead letter queue (failed webhook events).
+func (c *Client) GetWebhookDLQ(adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "get_webhook_dlq",
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetIdentityWebhook configures the identity verification webhook URL.
+func (c *Client) SetIdentityWebhook(url, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_identity_webhook",
+		"url":         url,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetExternalID sets the external identity on a node. Requires admin token.
+func (c *Client) SetExternalID(nodeID uint32, externalID, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_external_id",
+		"node_id":     nodeID,
+		"external_id": externalID,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetIdentity returns the external identity of a node. Requires admin token.
+func (c *Client) GetIdentity(nodeID uint32, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "get_identity",
+		"node_id":     nodeID,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// ProvisionNetwork applies a network blueprint. Requires admin token.
+func (c *Client) ProvisionNetwork(blueprint map[string]interface{}, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "provision_network",
+		"blueprint":   blueprint,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetAuditExport configures the audit export adapter. Requires admin token.
+// Every argument is sent as given, including empty strings.
+func (c *Client) SetAuditExport(format, endpoint, token, index, source, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_audit_export",
+		"format":      format,
+		"endpoint":    endpoint,
+		"token":       token,
+		"index":       index,
+		"source":      source,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetAuditExport returns the current audit export configuration. Requires
+// admin token.
+func (c *Client) GetAuditExport(adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "get_audit_export",
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// SetIDPConfig configures the identity provider. Requires admin token.
+// idp_type, url and admin_token are always sent; issuer, client_id,
+// tenant_id and domain are sent only when non-empty.
+func (c *Client) SetIDPConfig(idpType, url, issuer, clientID, tenantID, domain, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "set_idp_config",
+		"idp_type":    idpType,
+		"url":         url,
+		"admin_token": adminToken,
+	}
+	optional := [...]struct{ key, value string }{
+		{"issuer", issuer},
+		{"client_id", clientID},
+		{"tenant_id", tenantID},
+		{"domain", domain},
+	}
+	for _, field := range optional {
+		if field.value != "" {
+			req[field.key] = field.value
+		}
+	}
+	return c.Send(req)
+}
+
+// GetIDPConfig returns the current identity provider configuration. Requires admin token.
+func (c *Client) GetIDPConfig(adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "get_idp_config",
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// GetProvisionStatus returns per-network provisioning status. Requires
+// admin token.
+func (c *Client) GetProvisionStatus(adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "get_provision_status",
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// DirectorySync pushes a directory listing to update RBAC roles and
+// membership. The entries are re-boxed into a []interface{} of the same
+// length before being sent.
+func (c *Client) DirectorySync(networkID uint16, entries []map[string]interface{}, removeUnlisted bool, adminToken string) (map[string]interface{}, error) {
+	boxed := make([]interface{}, len(entries))
+	for i := range entries {
+		boxed[i] = entries[i]
+	}
+	req := map[string]interface{}{
+		"type":            "directory_sync",
+		"network_id":      networkID,
+		"entries":         boxed,
+		"remove_unlisted": removeUnlisted,
+		"admin_token":     adminToken,
+	}
+	return c.Send(req)
+}
+
+// DirectoryStatus returns directory sync status for a network.
+func (c *Client) DirectoryStatus(networkID uint16, adminToken string) (map[string]interface{}, error) {
+	req := map[string]interface{}{
+		"type":        "directory_status",
+		"network_id":  networkID,
+		"admin_token": adminToken,
+	}
+	return c.Send(req)
+}
+
+// ValidateToken validates a JWT token against the configured IDP. Requires admin token.
+func (c *Client) ValidateToken(token, adminToken string) (map[string]interface{}, error) { + return c.Send(map[string]interface{}{ + "type": "validate_token", + "token": token, + "admin_token": adminToken, + }) +} diff --git a/registry/client/zz_binary_client_test.go b/registry/client/zz_binary_client_test.go new file mode 100644 index 0000000..e8f6913 --- /dev/null +++ b/registry/client/zz_binary_client_test.go @@ -0,0 +1,550 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package client + +import ( + "encoding/binary" + "encoding/json" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/pilot-protocol/common/registry/wire" +) + +// Iter-116 coverage for registry/binary_client.go — 9 zero-coverage functions: +// DialBinary, Close, Addr, reconnect, Heartbeat/heartbeatLocked, Lookup/lookupLocked, +// Resolve/resolveLocked, SendJSON/sendJSONLocked. Strategy: stand up a real TCP +// listener that reads the 5-byte handshake (magic + version), then runs a +// per-test frame handler against the wire protocol via wire.ReadFrame/wire.WriteFrame. 
+ +// --- fakeBinaryServer: minimal TCP server speaking the binary wire protocol --- + +type fakeBinaryServer struct { + ln net.Listener + handler func(msgType byte, payload []byte) (respType byte, respPayload []byte) + mu sync.Mutex + handshakes atomic.Uint32 + frames atomic.Uint32 + done chan struct{} +} + +func newFakeBinaryServer(t *testing.T, handler func(msgType byte, payload []byte) (byte, []byte)) *fakeBinaryServer { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + s := &fakeBinaryServer{ln: ln, handler: handler, done: make(chan struct{})} + go s.accept() + t.Cleanup(s.Close) + return s +} + +func (s *fakeBinaryServer) addr() string { return s.ln.Addr().String() } + +func (s *fakeBinaryServer) Close() { + s.mu.Lock() + defer s.mu.Unlock() + select { + case <-s.done: + return + default: + } + close(s.done) + s.ln.Close() +} + +func (s *fakeBinaryServer) accept() { + for { + conn, err := s.ln.Accept() + if err != nil { + return + } + go s.handle(conn) + } +} + +func (s *fakeBinaryServer) handle(conn net.Conn) { + defer conn.Close() + // Read 5-byte handshake. + var hdr [5]byte + if _, err := io.ReadFull(conn, hdr[:]); err != nil { + return + } + s.handshakes.Add(1) + // Verify magic — but don't enforce version. + for i, b := range wire.Magic { + if hdr[i] != b { + return + } + } + // Per-frame loop. + for { + msgType, payload, err := wire.ReadFrame(conn) + if err != nil { + return + } + s.frames.Add(1) + if s.handler == nil { + return + } + respType, respPayload := s.handler(msgType, payload) + if respType == 0 && respPayload == nil { + // Sentinel for "close without responding" — test uses this to force recv error. 
+ return + } + if err := wire.WriteFrame(conn, respType, respPayload); err != nil { + return + } + } +} + +// --- DialBinary: success, dial error, handshake write error --- + +func TestDialBinarySuccess(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, nil) + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + if c.Addr() != srv.addr() { + t.Fatalf("Addr = %q, want %q", c.Addr(), srv.addr()) + } + // Wait for the server to see the handshake. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if srv.handshakes.Load() == 1 { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("server did not receive handshake within 2s") +} + +func TestDialBinaryDialErrorWrapsMessage(t *testing.T) { + t.Parallel() + // Port 1 on 127.0.0.1 is almost certainly not listening (privileged, reserved). + // The dial will fail with ECONNREFUSED within the 5s timeout. + _, err := DialBinary("127.0.0.1:1") + if err == nil { + t.Fatal("DialBinary expected error on unreachable addr") + } + // The wrap format is `dial registry: `. We don't pin the exact text. + if len(err.Error()) == 0 { + t.Fatal("error message is empty") + } +} + +// --- Close: nil-conn path + idempotency --- + +func TestBinaryClientCloseIsSafeWithNilConn(t *testing.T) { + t.Parallel() + c := &BinaryClient{conn: nil} + if err := c.Close(); err != nil { + t.Fatalf("Close on nil conn = %v, want nil (no panic, no err)", err) + } + // Second Close is also safe — closed flag set, conn already nil. 
+ if err := c.Close(); err != nil { + t.Fatalf("second Close = %v, want nil", err) + } +} + +// --- Addr: returns the configured addr without connection --- + +func TestBinaryClientAddrReflectsCtorValue(t *testing.T) { + t.Parallel() + c := &BinaryClient{addr: "host.example:9000"} + if got := c.Addr(); got != "host.example:9000" { + t.Fatalf("Addr = %q, want host.example:9000", got) + } +} + +// --- Heartbeat: happy path returns unixTime + warning flag --- + +func TestHeartbeatHappyPathReturnsTimeAndWarning(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgHeartbeat { + return wire.MsgError, wire.EncodeError("unexpected msg") + } + req, err := wire.DecodeHeartbeatReq(payload) + if err != nil { + return wire.MsgError, wire.EncodeError(err.Error()) + } + if req.NodeID != 12345 { + return wire.MsgError, wire.EncodeError("wrong node id") + } + return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(1_700_000_000, true) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + sig := make([]byte, 64) + unixTime, warn, err := c.Heartbeat(12345, sig) + if err != nil { + t.Fatalf("Heartbeat: %v", err) + } + if unixTime != 1_700_000_000 { + t.Fatalf("unixTime = %d, want 1_700_000_000", unixTime) + } + if !warn { + t.Fatal("keyExpiryWarning = false, want true") + } +} + +// --- Heartbeat: server returns wire.MsgError → client surfaces "registry: " --- + +func TestHeartbeatServerErrorResponseReturnsWrappedError(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + return wire.MsgError, wire.EncodeError("node not registered") + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + _, _, err = c.Heartbeat(9999, make([]byte, 64)) + if err == nil { + t.Fatal("Heartbeat should return error when server sends 
wire.MsgError") + } + if got := err.Error(); got != "registry: node not registered" { + t.Fatalf("err = %q, want %q", got, "registry: node not registered") + } +} + +// --- Heartbeat: unexpected response type → error --- + +func TestHeartbeatUnexpectedResponseTypeReturnsError(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + // Respond with a LookupOK type instead of HeartbeatOK. + return wire.MsgLookupOK, []byte{0, 0, 0, 0} + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + _, _, err = c.Heartbeat(1, make([]byte, 64)) + if err == nil { + t.Fatal("expected error on unexpected response type") + } +} + +// --- Lookup: happy path decodes wire.LookupResult --- + +func TestLookupHappyPathDecodesResult(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgLookup { + return wire.MsgError, wire.EncodeError("bad type") + } + return wire.MsgLookupOK, wire.EncodeLookupResp( + 42, // nodeID + true, false, // public, taskExec + []uint16{1, 2}, // networks + []byte{0xAB}, // pubkey + "host.example", // hostname + []string{"t1"}, // tags + "1.2.3.4:444", // realAddr + "ext-123", // externalID + ) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + res, err := c.Lookup(42) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if res.NodeID != 42 { + t.Fatalf("NodeID = %d", res.NodeID) + } + if !res.Public || res.TaskExec { + t.Fatalf("flags: public=%v taskExec=%v", res.Public, res.TaskExec) + } + if len(res.Networks) != 2 || res.Networks[0] != 1 || res.Networks[1] != 2 { + t.Fatalf("Networks = %v", res.Networks) + } + if res.Hostname != "host.example" { + t.Fatalf("Hostname = %q", res.Hostname) + } + if len(res.Tags) != 1 || res.Tags[0] != "t1" { + t.Fatalf("Tags = %v", res.Tags) + } + if 
res.RealAddr != "1.2.3.4:444" { + t.Fatalf("RealAddr = %q", res.RealAddr) + } + if res.ExternalID != "ext-123" { + t.Fatalf("ExternalID = %q", res.ExternalID) + } +} + +// --- Lookup: unexpected response type --- + +func TestLookupUnexpectedResponseTypeReturnsError(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(0, false) + }) + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + if _, err := c.Lookup(99); err == nil { + t.Fatal("expected error on wrong response type") + } +} + +// --- Resolve: happy path decodes wire.ResolveResult --- + +func TestResolveHappyPathDecodesResult(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgResolve { + return wire.MsgError, wire.EncodeError("bad type") + } + return wire.MsgResolveOK, wire.EncodeResolveResp( + 77, "10.0.0.1:5000", + []string{"192.168.1.1:5000", "192.168.1.2:5000"}, + 42, + ) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + res, err := c.Resolve(77, 1, make([]byte, 64)) + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if res.NodeID != 77 { + t.Fatalf("NodeID = %d", res.NodeID) + } + if res.RealAddr != "10.0.0.1:5000" { + t.Fatalf("RealAddr = %q", res.RealAddr) + } + if len(res.LANAddrs) != 2 { + t.Fatalf("LANAddrs = %v", res.LANAddrs) + } + if res.KeyAgeDays != 42 { + t.Fatalf("KeyAgeDays = %d", res.KeyAgeDays) + } +} + +// --- Resolve: -1 key_age_days (MaxUint32 in wire) --- + +func TestResolveMaxUint32KeyAgeMapsToNegativeOne(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + return wire.MsgResolveOK, wire.EncodeResolveResp(1, "a:1", nil, -1) + }) + c, err := DialBinary(srv.addr()) + if err != nil { + 
t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + res, err := c.Resolve(1, 1, make([]byte, 64)) + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if res.KeyAgeDays != -1 { + t.Fatalf("KeyAgeDays = %d, want -1 (MaxUint32 sentinel)", res.KeyAgeDays) + } +} + +// --- SendJSON: roundtrip of a generic map --- + +func TestSendJSONRoundtripsGenericMap(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgJSON { + return wire.MsgError, wire.EncodeError("bad type") + } + var req map[string]interface{} + if err := json.Unmarshal(payload, &req); err != nil { + return wire.MsgError, wire.EncodeError(err.Error()) + } + resp := map[string]interface{}{ + "type": "ok", + "echo": req["x"], + } + body, _ := json.Marshal(resp) + return wire.MsgJSON, body + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + resp, err := c.SendJSON(map[string]interface{}{"x": 7.0}) + if err != nil { + t.Fatalf("SendJSON: %v", err) + } + if resp["type"] != "ok" { + t.Fatalf("resp.type = %v, want ok", resp["type"]) + } + if got, _ := resp["echo"].(float64); got != 7 { + t.Fatalf("resp.echo = %v, want 7", resp["echo"]) + } +} + +// --- SendJSON: server returns wire.MsgError (server-side protocol error) --- + +func TestSendJSONWireMsgErrorReturnsMapWithError(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + return wire.MsgError, wire.EncodeError("rate limited") + }) + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + resp, err := c.SendJSON(map[string]interface{}{"op": "whatever"}) + if err == nil { + t.Fatal("expected error on wire.MsgError") + } + if resp == nil { + t.Fatal("resp must NOT be nil on wire.MsgError — caller relies on non-nil to skip reconnect") + } + if resp["type"] != "error" || resp["error"] != "rate 
limited" { + t.Fatalf("resp = %v, want type=error error=rate limited", resp) + } +} + +// --- SendJSON: application-level error field in normal JSON response --- + +func TestSendJSONReturnsErrorWhenResponseHasErrorField(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + resp := map[string]interface{}{"type": "bad", "error": "invalid op"} + body, _ := json.Marshal(resp) + return wire.MsgJSON, body + }) + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + resp, err := c.SendJSON(map[string]interface{}{"op": "x"}) + if err == nil { + t.Fatal("expected error when response has error field") + } + if got := err.Error(); got != "registry: invalid op" { + t.Fatalf("err = %q, want %q", got, "registry: invalid op") + } + if resp["type"] != "bad" { + t.Fatalf("resp.type = %v, want bad", resp["type"]) + } +} + +// --- SendJSON: unexpected response type --- + +func TestSendJSONUnexpectedResponseTypeReturnsError(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + return wire.MsgLookupOK, []byte{0, 0, 0, 0} + }) + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + _, err = c.SendJSON(map[string]interface{}{"op": "x"}) + if err == nil { + t.Fatal("expected error on wrong response type") + } +} + +// --- SendJSON: server returns malformed JSON in wire.MsgJSON → decode err --- + +func TestSendJSONMalformedResponseReturnsDecodeError(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + return wire.MsgJSON, []byte("not valid json }{") + }) + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + _, err = c.SendJSON(map[string]interface{}{"op": "x"}) + if err == nil { + t.Fatal("expected decode error on malformed JSON 
response") + } +} + +// --- encode/decode round-trips: sanity of our test helpers as well as SUT symmetry --- + +func TestEncodeDecodeHeartbeatReqRoundTrip(t *testing.T) { + t.Parallel() + sig := make([]byte, 64) + for i := range sig { + sig[i] = byte(i) + } + buf := wire.EncodeHeartbeatReq(0xDEADBEEF, sig) + req, err := wire.DecodeHeartbeatReq(buf) + if err != nil { + t.Fatalf("decode: %v", err) + } + if req.NodeID != 0xDEADBEEF { + t.Fatalf("NodeID = %x", req.NodeID) + } + for i := 0; i < 64; i++ { + if req.Signature[i] != byte(i) { + t.Fatalf("sig[%d] = %x, want %x", i, req.Signature[i], i) + } + } +} + +func TestDecodeWireErrorShortPayloadReturnsSentinel(t *testing.T) { + t.Parallel() + if got := wire.DecodeError([]byte{0x00}); got != "unknown error" { + t.Fatalf("wire.DecodeError(short) = %q, want unknown error", got) + } +} + +func TestDecodeWireErrorTruncatesToActualLen(t *testing.T) { + t.Parallel() + // Claim length=100 but only 5 real bytes follow — decoder clamps to available. + buf := make([]byte, 7) + binary.BigEndian.PutUint16(buf[:2], 100) + copy(buf[2:], []byte("hello")) + got := wire.DecodeError(buf) + if got != "hello" { + t.Fatalf("wire.DecodeError(truncated) = %q, want hello", got) + } +} diff --git a/registry/client/zz_client_branch_test.go b/registry/client/zz_client_branch_test.go new file mode 100644 index 0000000..9d23714 --- /dev/null +++ b/registry/client/zz_client_branch_test.go @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package client + +import ( + "context" + "crypto/tls" + "encoding/json" + "net" + "strings" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +// Branch-fill tests: every wrapper that takes an optional adminToken, +// signature, or variadic flag has an untested branch when the optional +// arg is blank. This file ticks the remaining `if x != ""` / `if len(...) > 0` +// branches and the binary_client reconnect/lookup/resolve error edges. 
+ +// --- Client member-mgmt wrappers: with-adminToken branches -------------- +// +// Existing tests cover the blank-token path; here we cover the non-blank +// branch so the `if adminToken != ""` is exercised both ways. + +func TestPromoteDemoteKickTransferIncludeAdminToken(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + cases := []struct { + name string + call func() (map[string]interface{}, error) + targetKey string + }{ + {"promote", func() (map[string]interface{}, error) { return c.PromoteMember(1, 2, 3, "ADM") }, "target_node_id"}, + {"demote", func() (map[string]interface{}, error) { return c.DemoteMember(1, 2, 3, "ADM") }, "target_node_id"}, + {"kick", func() (map[string]interface{}, error) { return c.KickMember(1, 2, 3, "ADM") }, "target_node_id"}, + {"transfer", func() (map[string]interface{}, error) { return c.TransferOwnership(1, 2, 3, "ADM") }, "new_owner_id"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.call() + if err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + echo := assertEcho(t, resp) + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("%s: admin_token: %q", tc.name, got) + } + if got, _ := echo[tc.targetKey].(float64); uint32(got) != 3 { + t.Fatalf("%s: %s: %v", tc.name, tc.targetKey, got) + } + }) + } +} + +// --- ReportTrust / RevokeTrust / SetVisibility: WITH-signer branch ----- +// +// Existing TestReportTrustAndRevokeTrustFormat and TestSetVisibilityPublicFlagSerialized +// drive the no-signer (sig empty) path. Cover the signer-attached branch +// so `if sig := ...; sig != ""` is hit both ways. 
+ +func TestReportRevokeVisibilityIncludeSignatureWhenSignerSet(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + cases := []struct { + name string + call func() (map[string]interface{}, error) + challenge string + }{ + {"report_trust", func() (map[string]interface{}, error) { return c.ReportTrust(1, 2) }, "report_trust:1:2"}, + {"revoke_trust", func() (map[string]interface{}, error) { return c.RevokeTrust(1, 2) }, "revoke_trust:1:2"}, + {"set_visibility", func() (map[string]interface{}, error) { return c.SetVisibility(9, false) }, "set_visibility:9"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.call() + if err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + echo := assertEcho(t, resp) + if got, _ := echo["signature"].(string); got != "SIG:"+tc.challenge { + t.Fatalf("%s: signature: want SIG:%s, got %q", tc.name, tc.challenge, got) + } + }) + } +} + +// --- CreateManagedNetwork full-options branch ----------------------------- + +func TestCreateManagedNetworkFullOptions(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, err := c.CreateManagedNetwork(2, "n", "invite", "tok", "ADM", true, `{"a":1}`, "NAT") + if err != nil { + t.Fatalf("create: %v", err) + } + echo := assertEcho(t, resp) + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } + if got, _ := echo["enterprise"].(bool); !got { + t.Fatalf("enterprise: %v", got) + } + if got, _ := echo["network_admin_token"].(string); got != "NAT" { + t.Fatalf("network_admin_token: %q", got) + } +} + +// --- ListNetworks with adminToken ---------------------------------------- + +func TestListNetworksWithAdminTokenIncludesField(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, err := c.ListNetworks("SUPER") + if err != nil { + t.Fatalf("list: %v", err) + } + echo := assertEcho(t, resp) + if got, _ := echo["admin_token"].(string); 
got != "SUPER" { + t.Fatalf("admin_token: %q", got) + } +} + +// --- DialTLS: the sad path ("dial registry TLS" error wrap) is already +// covered; the happy-path connect branch is exercised by +// TestDialTLSHappyPathConnects at the end of this file. -------------------- + +// --- binary_client: reconnect after failure ------------------------------ + +// TestBinaryHeartbeatReconnectsAfterBrokenConn covers the +// `err != nil && !c.closed → reconnect → retry` branch in Heartbeat. +func TestBinaryHeartbeatReconnectsAfterBrokenConn(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgHeartbeat { + return wire.MsgError, wire.EncodeError("bad") + } + return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(123, false) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + // Forcibly close the underlying conn so the next Heartbeat triggers + // reconnect+retry. + c.mu.Lock() + _ = c.conn.Close() + c.mu.Unlock() + + unixTime, _, err := c.Heartbeat(1, make([]byte, 64)) + if err != nil { + t.Fatalf("Heartbeat after broken conn: %v", err) + } + if unixTime != 123 { + t.Fatalf("unixTime = %d, want 123", unixTime) + } +} + +// TestBinaryLookupReconnectsAfterBrokenConn covers the reconnect branch +// inside Lookup. 
+func TestBinaryLookupReconnectsAfterBrokenConn(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgLookup { + return wire.MsgError, wire.EncodeError("bad") + } + return wire.MsgLookupOK, wire.EncodeLookupResp( + 7, true, false, nil, nil, "h", nil, "1.1.1.1:1", "", + ) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + c.mu.Lock() + _ = c.conn.Close() + c.mu.Unlock() + + res, err := c.Lookup(7) + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if res.NodeID != 7 { + t.Fatalf("NodeID = %d", res.NodeID) + } +} + +// TestBinaryResolveReconnectsAfterBrokenConn covers the reconnect branch +// inside Resolve. +func TestBinaryResolveReconnectsAfterBrokenConn(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgResolve { + return wire.MsgError, wire.EncodeError("bad") + } + return wire.MsgResolveOK, wire.EncodeResolveResp(8, "2.2.2.2:2", nil, 0) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + c.mu.Lock() + _ = c.conn.Close() + c.mu.Unlock() + + res, err := c.Resolve(8, 1, make([]byte, 64)) + if err != nil { + t.Fatalf("Resolve: %v", err) + } + if res.NodeID != 8 { + t.Fatalf("NodeID = %d", res.NodeID) + } +} + +// TestBinarySendJSONReconnectsAfterBrokenConn covers the reconnect branch +// inside SendJSON. 
+func TestBinarySendJSONReconnectsAfterBrokenConn(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(msgType byte, payload []byte) (byte, []byte) { + if msgType != wire.MsgJSON { + return wire.MsgError, wire.EncodeError("bad") + } + body, _ := json.Marshal(map[string]interface{}{"type": "ok"}) + return wire.MsgJSON, body + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + defer c.Close() + + c.mu.Lock() + _ = c.conn.Close() + c.mu.Unlock() + + resp, err := c.SendJSON(map[string]interface{}{"op": "x"}) + if err != nil { + t.Fatalf("SendJSON: %v", err) + } + if resp["type"] != "ok" { + t.Fatalf("type: %v", resp["type"]) + } +} + +// TestBinaryReconnectShortCircuitsWhenClosed covers reconnect's early +// return: a client already marked closed fails with "client closed". +// +// The all-attempts-fail path ("reconnect failed") is deliberately not driven +// here: 5 attempts with growing backoff is up to ~7.5s of sleeping inside +// reconnect — too slow for -short. Instead, exercise the immediate path: +// mark the client closed first so reconnect returns "client closed" without +// sleeping. This covers the c.closed branch only. +func TestBinaryReconnectShortCircuitsWhenClosed(t *testing.T) { + t.Parallel() + srv := newFakeBinaryServer(t, func(byte, []byte) (byte, []byte) { + return wire.MsgHeartbeatOK, wire.EncodeHeartbeatResp(1, false) + }) + + c, err := DialBinary(srv.addr()) + if err != nil { + t.Fatalf("DialBinary: %v", err) + } + + // Tear down the listener to make a future dial fail, then close the + // client to force the reconnect early-return branch. + srv.Close() + // Drop the local conn and mark closed before triggering reconnect. 
+ c.mu.Lock() + _ = c.conn.Close() + c.closed = true + err = c.reconnect() + c.mu.Unlock() + if err == nil { + t.Fatalf("reconnect after Close should fail") + } + if !strings.Contains(err.Error(), "client closed") { + t.Fatalf("expected 'client closed' error, got: %v", err) + } +} + +// TestBinaryDialBinaryHandshakeWriteFailure exercises the +// "conn.Write(handshake) fails" branch in DialBinary. We can't intercept +// the write directly, but we can race a close: connect to a listener that +// accepts then immediately closes the conn before the handshake write +// completes. On macOS this typically surfaces as a write error on a +// half-closed socket. +func TestBinaryDialBinaryHandshakeWriteFailure(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + addr := ln.Addr().String() + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := ln.Accept() + if err != nil { + return + } + // Drop the conn immediately. Then close the listener so subsequent + // dials return ECONNREFUSED if the test re-runs. + conn.Close() + }() + // DialBinary may succeed (handshake writes 5 bytes into the kernel buffer + // before EOF is observed) or fail. Either is fine — we just need the path + // to execute and not panic. The accept goroutine guarantees a real connect. + c, err := DialBinary(addr) + if err == nil && c != nil { + c.Close() + } + <-done + ln.Close() + + // As a deterministic companion: dial a closed port. This exercises the + // "net.Dial fails" branch. 
+ closed, _ := net.Listen("tcp", "127.0.0.1:0") + closedAddr := closed.Addr().String() + closed.Close() + _, err = DialBinary(closedAddr) + if err == nil { + t.Fatalf("DialBinary to closed port should fail") + } +} + +// --- Reconnect exhaustion against unreachable addr ----------------------- + +// TestClientReconnectExhaustsAttempts drives reconnect through every retry +// and verifies it returns the eventual "reconnect failed" wrap. +// We use a Client whose addr points to a closed port; reconnect dials it +// 5 times then gives up. Using a fresh-grabbed-then-released kernel port +// keeps each failed dial fast (ECONNREFUSED on loopback is sub-ms). +func TestClientReconnectExhaustsAttempts(t *testing.T) { + if testing.Short() { + // 5 attempts with growing backoff = ~7.5s of sleep — too slow for -short with -race. + t.Skip("skipping long reconnect-exhaustion test under -short") + } + t.Parallel() + ln, _ := net.Listen("tcp", "127.0.0.1:0") + addr := ln.Addr().String() + ln.Close() + + c := &Client{addr: addr} + c.mu.Lock() + err := c.reconnect(context.Background()) + c.mu.Unlock() + if err == nil { + t.Fatalf("expected reconnect failure") + } + if !strings.Contains(err.Error(), "reconnect failed") { + t.Fatalf("expected 'reconnect failed' wrap, got: %v", err) + } +} + +// --- Close after pool conn already closed: idempotency / second-close ---- + +func TestClosePoolEntryAlreadyClosedReturnsFirstErrOrNil(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 3) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + // Pre-close one secondary conn so Close()'s per-entry loop hits the + // "already-closed" path on it. + c.pool.entries[1].mu.Lock() + _ = c.pool.entries[1].conn.Close() + c.pool.entries[1].mu.Unlock() + + // Double-close should be safe. + if err := c.Close(); err != nil { + // Close may surface the first error (double-close on a TCPConn). 
+ // That's acceptable — what matters is no panic. + t.Logf("Close returned (acceptable): %v", err) + } +} + +// --- sendOnEntry server error path (response with "error" key) ----------- + +func TestSendOnEntryReturnsServerErrorResponse(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + return map[string]interface{}{"error": "rate-limited"} + }) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + defer c.Close() + + resp, err := c.Send(map[string]interface{}{"type": "x"}) + if err == nil { + t.Fatalf("expected server error") + } + if resp == nil { + t.Fatalf("resp must be non-nil for server-error path") + } + if !strings.Contains(err.Error(), "rate-limited") { + t.Fatalf("error should contain server message, got: %v", err) + } +} + +// --- DialTLS happy path --------------------------------------------------- + +func TestDialTLSHappyPathConnects(t *testing.T) { + t.Parallel() + srv := newFakeTLSServer(t, echoHandler()) + cfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: true, //nolint:gosec // test-only + } + c, err := DialTLS(srv.addr(), cfg) + if err != nil { + t.Fatalf("DialTLS: %v", err) + } + defer c.Close() + resp, err := c.Send(map[string]interface{}{"type": "x"}) + if err != nil { + t.Fatalf("send: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: %q", got) + } + // Ensure tlsConfig is retained so reconnect would use TLS too. 
+ if c.tlsConfig == nil { + t.Fatalf("tlsConfig should be retained on Client after DialTLS") + } +} diff --git a/registry/client/zz_client_join_signature_test.go b/registry/client/zz_client_join_signature_test.go new file mode 100644 index 0000000..677b0b0 --- /dev/null +++ b/registry/client/zz_client_join_signature_test.go @@ -0,0 +1,707 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package client + +import ( + "encoding/json" + "errors" + "strings" + "testing" + "time" +) + +// echoOnlyClient dials a fakeJSONServer with echoHandler and returns a +// connected Client plus the server so test bodies can assert wire payloads. +func echoOnlyClient(t *testing.T) (*Client, *fakeJSONServer) { + t.Helper() + srv := newFakeJSONServer(t, echoHandler()) + c, err := Dial(srv.addr()) + if err != nil { + srv.close() + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { c.Close(); srv.close() }) + return c, srv +} + +// assertEcho fetches the echoed request payload that the fake server round-tripped. 
+func assertEcho(t *testing.T, resp map[string]interface{}) map[string]interface{} { + t.Helper() + echo, ok := resp["echo"].(map[string]interface{}) + if !ok { + t.Fatalf("response missing echo key: %+v", resp) + } + return echo +} + +// --- JoinNetwork / LeaveNetwork : signature wins over admin_token -------- + +func TestJoinNetworkSignaturePreferredOverAdminToken(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + + resp, err := c.JoinNetwork(11, 3, "tok", 4, "ADMIN_SHOULD_BE_IGNORED") + if err != nil { + t.Fatalf("join: %v", err) + } + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "join_network" { + t.Fatalf("type: %q", got) + } + if got, _ := echo["signature"].(string); got != "SIG:join_network:11:3" { + t.Fatalf("signature: %q", got) + } + if _, ok := echo["admin_token"]; ok { + t.Fatalf("admin_token should be omitted when signature present") + } +} + +func TestJoinNetworkFallsBackToAdminTokenWithoutSigner(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, err := c.JoinNetwork(11, 3, "tok", 4, "ADM") + if err != nil { + t.Fatalf("join: %v", err) + } + echo := assertEcho(t, resp) + if _, ok := echo["signature"]; ok { + t.Fatalf("signature should be absent with no signer") + } + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } + if got, _ := echo["inviter_id"].(float64); uint32(got) != 4 { + t.Fatalf("inviter_id: %v", got) + } +} + +func TestLeaveNetworkSignatureOrAdminToken(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + // Signer wins. 
+ c.SetSigner(func(ch string) string { return "SIG:" + ch }) + resp, _ := c.LeaveNetwork(5, 2, "ADMIN") + echo := assertEcho(t, resp) + if got, _ := echo["signature"].(string); got != "SIG:leave_network:5:2" { + t.Fatalf("signature: %q", got) + } + if _, ok := echo["admin_token"]; ok { + t.Fatalf("admin_token should be omitted when sig present") + } + // Drop signer → admin_token fallback. + c.SetSigner(nil) + resp, _ = c.LeaveNetwork(5, 2, "ADMIN") + echo = assertEcho(t, resp) + if got, _ := echo["admin_token"].(string); got != "ADMIN" { + t.Fatalf("admin_token fallback: %q", got) + } +} + +// --- DeleteNetwork / RenameNetwork : variadic node_id -------------------- + +func TestDeleteNetworkVariadicNodeIDAndAdminToken(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + // No nodeID, no adminToken. + resp, _ := c.DeleteNetwork(3, "") + echo := assertEcho(t, resp) + if _, ok := echo["node_id"]; ok { + t.Fatalf("node_id should be omitted when not passed") + } + if _, ok := echo["admin_token"]; ok { + t.Fatalf("admin_token should be omitted when blank") + } + // With both. + resp, _ = c.DeleteNetwork(3, "ADM", 77) + echo = assertEcho(t, resp) + if got, _ := echo["node_id"].(float64); uint32(got) != 77 { + t.Fatalf("node_id: %v", got) + } + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } + // Explicit 0 node_id → still omitted. 
+ resp, _ = c.DeleteNetwork(3, "ADM", 0) + echo = assertEcho(t, resp) + if _, ok := echo["node_id"]; ok { + t.Fatalf("node_id=0 should be omitted (matches client logic)") + } +} + +func TestRenameNetworkPassesName(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.RenameNetwork(1, "shiny", "ADM", 9) + echo := assertEcho(t, resp) + if got, _ := echo["name"].(string); got != "shiny" { + t.Fatalf("name: %q", got) + } + if got, _ := echo["node_id"].(float64); uint32(got) != 9 { + t.Fatalf("node_id: %v", got) + } +} + +// --- ListNetworks / ListNodes / SetNetworkEnterprise -------------------- + +func TestListNetworksBareType(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.ListNetworks() + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "list_networks" { + t.Fatalf("type: %q", got) + } +} + +func TestListNodesAdminTokenOptional(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + // Without admin token. + resp, _ := c.ListNodes(42) + echo := assertEcho(t, resp) + if _, ok := echo["admin_token"]; ok { + t.Fatalf("admin_token should be omitted when not supplied") + } + // With admin token. 
+ resp, _ = c.ListNodes(42, "ADM") + echo = assertEcho(t, resp) + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } +} + +func TestSetNetworkEnterpriseSerializesBool(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.SetNetworkEnterprise(7, true, "ADM") + echo := assertEcho(t, resp) + if got, _ := echo["enterprise"].(bool); !got { + t.Fatalf("enterprise should be true") + } + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } +} + +// --- Signed thin wrappers: Deregister / Heartbeat / Punch ---------------- + +func TestSignedWrappersIncludeCorrectChallenge(t *testing.T) { + t.Parallel() + cases := []struct { + name string + call func(c *Client) (map[string]interface{}, error) + expect string + }{ + {"deregister", func(c *Client) (map[string]interface{}, error) { return c.Deregister(42) }, "deregister:42"}, + {"heartbeat", func(c *Client) (map[string]interface{}, error) { return c.Heartbeat(42) }, "heartbeat:42"}, + {"punch", func(c *Client) (map[string]interface{}, error) { return c.Punch(1, 42, 43) }, "punch:42:43"}, + {"poll_handshakes", func(c *Client) (map[string]interface{}, error) { return c.PollHandshakes(42) }, "poll_handshakes:42"}, + {"set_hostname", func(c *Client) (map[string]interface{}, error) { return c.SetHostname(42, "h") }, "set_hostname:42"}, + {"set_tags", func(c *Client) (map[string]interface{}, error) { return c.SetTags(42, []string{"a"}) }, "set_tags:42"}, + {"poll_invites", func(c *Client) (map[string]interface{}, error) { return c.PollInvites(42) }, "poll_invites:42"}, + {"set_key_expiry", func(c *Client) (map[string]interface{}, error) { return c.SetKeyExpiry(42, time.Unix(0, 0).UTC()) }, "set_key_expiry:42"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c, _ := echoOnlyClient(t) + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + resp, err := tc.call(c) + if err != nil { + 
t.Fatalf("call: %v", err) + } + echo := assertEcho(t, resp) + if got, _ := echo["signature"].(string); got != "SIG:"+tc.expect { + t.Fatalf("signature: want SIG:%s, got %q", tc.expect, got) + } + }) + } +} + +func TestSetKeyExpiryFormatsRFC3339(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + moment := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC) + resp, _ := c.SetKeyExpiry(9, moment) + echo := assertEcho(t, resp) + if got, _ := echo["expires_at"].(string); got != "2030-01-02T03:04:05Z" { + t.Fatalf("expires_at: %q", got) + } +} + +// --- RequestHandshake / RespondHandshake (caller-supplied signature) ----- + +func TestRequestAndRespondHandshakePassThroughSignature(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.RequestHandshake(1, 2, "please", "SIG_REQ") + echo := assertEcho(t, resp) + if got, _ := echo["signature"].(string); got != "SIG_REQ" { + t.Fatalf("req signature: %q", got) + } + if got, _ := echo["justification"].(string); got != "please" { + t.Fatalf("justification: %q", got) + } + + resp, _ = c.RespondHandshake(3, 4, true, "SIG_RESP") + echo = assertEcho(t, resp) + if got, _ := echo["accept"].(bool); !got { + t.Fatalf("accept: %v", got) + } + if got, _ := echo["signature"].(string); got != "SIG_RESP" { + t.Fatalf("resp signature: %q", got) + } + + // Blank signature omitted. 
+ resp, _ = c.RespondHandshake(3, 4, false, "") + echo = assertEcho(t, resp) + if _, ok := echo["signature"]; ok { + t.Fatalf("signature should be omitted when blank") + } +} + +// --- ResolveHostname / ResolveHostnameAs / CheckTrust -------------------- + +func TestResolveHostnameBothForms(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.ResolveHostname("alpha") + echo := assertEcho(t, resp) + if got, _ := echo["hostname"].(string); got != "alpha" { + t.Fatalf("hostname: %q", got) + } + if _, ok := echo["requester_id"]; ok { + t.Fatalf("requester_id should be absent") + } + + resp, _ = c.ResolveHostnameAs(99, "beta") + echo = assertEcho(t, resp) + if got, _ := echo["requester_id"].(float64); uint32(got) != 99 { + t.Fatalf("requester_id: %v", got) + } +} + +func TestCheckTrustReturnsTypedBool(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + return map[string]interface{}{"type": "ok", "trusted": true} + }) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + trusted, err := c.CheckTrust(1, 2) + if err != nil { + t.Fatalf("check trust: %v", err) + } + if !trusted { + t.Fatalf("expected trusted=true") + } +} + +func TestCheckTrustPropagatesError(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + return map[string]interface{}{"error": "forbidden"} + }) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + trusted, err := c.CheckTrust(1, 2) + if err == nil || !strings.Contains(err.Error(), "forbidden") { + t.Fatalf("expected forbidden error, got %v", err) + } + if trusted { + t.Fatalf("expected trusted=false on error") + } +} + +// --- Invite family -------------------------------------------------------- + +func TestInviteToNetworkSigAndAdminBothAllowed(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + // 
Both signature AND admin_token are included (logic uses two independent ifs). + resp, _ := c.InviteToNetwork(3, 1, 2, "ADM") + echo := assertEcho(t, resp) + if got, _ := echo["signature"].(string); got != "SIG:invite:1:3:2" { + t.Fatalf("signature: %q", got) + } + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } +} + +func TestRespondInvitePassesAccept(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + resp, _ := c.RespondInvite(5, 9, false) + echo := assertEcho(t, resp) + if got, _ := echo["accept"].(bool); got { + t.Fatalf("accept: %v", got) + } + if got, _ := echo["signature"].(string); got != "SIG:respond_invite:5:9" { + t.Fatalf("signature: %q", got) + } +} + +// --- Member role operations --------------------------------------------- + +func TestMemberRoleOpsOmitBlankAdmin(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + ops := map[string]func() (map[string]interface{}, error){ + "promote_member": func() (map[string]interface{}, error) { return c.PromoteMember(1, 2, 3, "") }, + "demote_member": func() (map[string]interface{}, error) { return c.DemoteMember(1, 2, 3, "") }, + "kick_member": func() (map[string]interface{}, error) { return c.KickMember(1, 2, 3, "") }, + "transfer_ownership": func() (map[string]interface{}, error) { return c.TransferOwnership(1, 2, 3, "") }, + } + for name, op := range ops { + resp, err := op() + if err != nil { + t.Fatalf("%s: %v", name, err) + } + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != name { + t.Fatalf("%s: type=%q", name, got) + } + if _, ok := echo["admin_token"]; ok { + t.Fatalf("%s: admin_token should be omitted when blank", name) + } + } +} + +func TestGetMemberRoleSimple(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.GetMemberRole(3, 7) + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_member_role" { + 
t.Fatalf("type: %q", got) + } + if got, _ := echo["target_node_id"].(float64); uint32(got) != 7 { + t.Fatalf("target_node_id: %v", got) + } +} + +// --- Policy / ExprPolicy -------------------------------------------------- + +func TestSetNetworkPolicyMergesPolicyMap(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + policy := map[string]interface{}{ + "allow_public": true, + "max_members": float64(10), + "network_id": float64(999), // must NOT override the real network_id + "type": "evil_type", // must NOT override the real type + "admin_token": "EVIL", // must NOT override the real admin_token + } + resp, _ := c.SetNetworkPolicy(3, policy, "ADM") + echo := assertEcho(t, resp) + if got, _ := echo["allow_public"].(bool); !got { + t.Fatalf("allow_public: %v", got) + } + if got, _ := echo["max_members"].(float64); got != 10 { + t.Fatalf("max_members: %v", got) + } + // Explicit networkID parameter must win over a policy key of the same name. + if got, _ := echo["network_id"].(float64); got != 3 { + t.Fatalf("network_id: want 3 (explicit param wins), got %v", got) + } + if got, _ := echo["type"].(string); got != "set_network_policy" { + t.Fatalf("type: want set_network_policy (protected), got %q", got) + } + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: want ADM (explicit param wins), got %q", got) + } +} + +func TestGetNetworkPolicyType(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.GetNetworkPolicy(7) + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_network_policy" { + t.Fatalf("type: %q", got) + } +} + +func TestSetExprPolicyStringifiesJSON(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + raw := json.RawMessage(`{"rule":"true"}`) + resp, _ := c.SetExprPolicy(9, raw, "ADM") + echo := assertEcho(t, resp) + if got, _ := echo["expr_policy"].(string); got != `{"rule":"true"}` { + t.Fatalf("expr_policy: %q", got) + } +} + +func TestGetExprPolicyType(t 
*testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.GetExprPolicy(9) + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_expr_policy" { + t.Fatalf("type: %q", got) + } +} + +// --- Admin wrappers (trivial payload formatters) ------------------------- + +func TestAdminWrappersIncludeAdminToken(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + cases := []struct { + name string + call func() (map[string]interface{}, error) + typ string + }{ + {"set_hostname_admin", func() (map[string]interface{}, error) { return c.SetHostnameAdmin(1, "h", "T") }, "set_hostname"}, + {"set_visibility_admin", func() (map[string]interface{}, error) { return c.SetVisibilityAdmin(1, true, "T") }, "set_visibility"}, + {"set_tags_admin", func() (map[string]interface{}, error) { return c.SetTagsAdmin(1, []string{"x"}, "T") }, "set_tags"}, + {"set_key_expiry_admin", func() (map[string]interface{}, error) { return c.SetKeyExpiryAdmin(1, time.Unix(0, 0).UTC(), "T") }, "set_key_expiry"}, + {"clear_key_expiry_admin", func() (map[string]interface{}, error) { return c.ClearKeyExpiryAdmin(1, "T") }, "set_key_expiry"}, + {"deregister_admin", func() (map[string]interface{}, error) { return c.DeregisterAdmin(1, "T") }, "deregister"}, + } + for _, tc := range cases { + resp, err := tc.call() + if err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != tc.typ { + t.Fatalf("%s: type=%q want %q", tc.name, got, tc.typ) + } + if got, _ := echo["admin_token"].(string); got != "T" { + t.Fatalf("%s: admin_token=%q", tc.name, got) + } + } +} + +func TestClearKeyExpiryAdminSendsNeverLiteral(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.ClearKeyExpiryAdmin(1, "T") + echo := assertEcho(t, resp) + if got, _ := echo["expires_at"].(string); got != "never" { + t.Fatalf("expires_at: want 'never', got %q", got) + } +} + +func TestSetMemberTagsAndGetMemberTags(t 
*testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.SetMemberTags(2, 3, []string{"gpu", "fast"}, "T") + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "set_member_tags" { + t.Fatalf("type: %q", got) + } + tags, _ := echo["tags"].([]interface{}) + if len(tags) != 2 || tags[0] != "gpu" || tags[1] != "fast" { + t.Fatalf("tags: %v", tags) + } + + resp, _ = c.GetMemberTags(2, 3) + echo = assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_member_tags" { + t.Fatalf("type: %q", got) + } +} + +// --- Audit log / Audit export / Webhooks / Identity / IDP / Provision ---- + +func TestGetAuditLogOmitsZeroNetworkID(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.GetAuditLog(0, "T") + echo := assertEcho(t, resp) + if _, ok := echo["network_id"]; ok { + t.Fatalf("network_id should be omitted when 0") + } + resp, _ = c.GetAuditLog(3, "T") + echo = assertEcho(t, resp) + if got, _ := echo["network_id"].(float64); uint16(got) != 3 { + t.Fatalf("network_id: %v", got) + } +} + +func TestWebhookWrappers(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + cases := []struct { + typ string + call func() (map[string]interface{}, error) + }{ + {"set_webhook", func() (map[string]interface{}, error) { return c.SetWebhook("http://x", "T") }}, + {"get_webhook", func() (map[string]interface{}, error) { return c.GetWebhook("T") }}, + {"get_webhook_dlq", func() (map[string]interface{}, error) { return c.GetWebhookDLQ("T") }}, + {"set_identity_webhook", func() (map[string]interface{}, error) { return c.SetIdentityWebhook("http://id", "T") }}, + } + for _, tc := range cases { + resp, err := tc.call() + if err != nil { + t.Fatalf("%s: %v", tc.typ, err) + } + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != tc.typ { + t.Fatalf("%s: type=%q", tc.typ, got) + } + } +} + +func TestIdentityExternalIDWrappers(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := 
c.SetExternalID(5, "ext-7", "T") + echo := assertEcho(t, resp) + if got, _ := echo["external_id"].(string); got != "ext-7" { + t.Fatalf("external_id: %q", got) + } + resp, _ = c.GetIdentity(5, "T") + echo = assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_identity" { + t.Fatalf("type: %q", got) + } +} + +func TestSetIDPConfigOptionalFields(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + // Only required fields. + resp, _ := c.SetIDPConfig("oidc", "https://idp", "", "", "", "", "T") + echo := assertEcho(t, resp) + for _, key := range []string{"issuer", "client_id", "tenant_id", "domain"} { + if _, ok := echo[key]; ok { + t.Fatalf("%s should be omitted when blank", key) + } + } + if got, _ := echo["idp_type"].(string); got != "oidc" { + t.Fatalf("idp_type: %q", got) + } + // All fields. + resp, _ = c.SetIDPConfig("oidc", "https://idp", "ISS", "CID", "TID", "example.com", "T") + echo = assertEcho(t, resp) + for _, key := range []string{"issuer", "client_id", "tenant_id", "domain"} { + if _, ok := echo[key]; !ok { + t.Fatalf("%s should be present when supplied", key) + } + } +} + +func TestGetIDPConfigAndGetProvisionStatus(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.GetIDPConfig("T") + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_idp_config" { + t.Fatalf("type: %q", got) + } + resp, _ = c.GetProvisionStatus("T") + echo = assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_provision_status" { + t.Fatalf("type: %q", got) + } +} + +func TestProvisionNetworkPassesBlueprint(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + bp := map[string]interface{}{"name": "bp", "networks": []interface{}{}} + resp, _ := c.ProvisionNetwork(bp, "T") + echo := assertEcho(t, resp) + blueprint, _ := echo["blueprint"].(map[string]interface{}) + if got, _ := blueprint["name"].(string); got != "bp" { + t.Fatalf("blueprint.name: %q", got) + } +} + +func 
TestSetAuditExportAllFields(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.SetAuditExport("splunk_hec", "https://hec", "TOK", "idx", "src", "T") + echo := assertEcho(t, resp) + if got, _ := echo["format"].(string); got != "splunk_hec" { + t.Fatalf("format: %q", got) + } + if got, _ := echo["endpoint"].(string); got != "https://hec" { + t.Fatalf("endpoint: %q", got) + } + if got, _ := echo["index"].(string); got != "idx" { + t.Fatalf("index: %q", got) + } + if got, _ := echo["source"].(string); got != "src" { + t.Fatalf("source: %q", got) + } +} + +func TestGetAuditExport(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.GetAuditExport("T") + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "get_audit_export" { + t.Fatalf("type: %q", got) + } +} + +// --- Directory sync / ValidateToken / GetKeyInfo ------------------------- + +func TestDirectorySyncConvertsEntriesAndPassesFlag(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + entries := []map[string]interface{}{ + {"id": "u1", "role": "admin"}, + {"id": "u2", "role": "member"}, + } + resp, _ := c.DirectorySync(1, entries, true, "T") + echo := assertEcho(t, resp) + list, _ := echo["entries"].([]interface{}) + if len(list) != 2 { + t.Fatalf("entries: %v", list) + } + first, _ := list[0].(map[string]interface{}) + if got, _ := first["id"].(string); got != "u1" { + t.Fatalf("first.id: %q", got) + } + if got, _ := echo["remove_unlisted"].(bool); !got { + t.Fatalf("remove_unlisted: %v", got) + } +} + +func TestDirectoryStatusSimple(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.DirectoryStatus(5, "T") + echo := assertEcho(t, resp) + if got, _ := echo["type"].(string); got != "directory_status" { + t.Fatalf("type: %q", got) + } +} + +func TestValidateTokenPassesPayload(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, _ := c.ValidateToken("jwt.header.sig", "T") + echo := assertEcho(t, resp) + if 
// TestNilClient_AllMethodsReturnError asserts that every exported *Client
// method is safe to call on a typed-nil receiver and returns ErrNoRegistry
// (or, for SetSigner/Close, a no-op without panic). Several callers in the
// daemon (loadPolicyRunners, ManagedEngine.fetchMembers, Daemon.Info →
// nodeNetworks) invoke registry methods without nil-checking the client,
// so the only acceptable behavior is "no panic; recoverable error."
//
// The test invokes each method, recovers any panic, and asserts the
// expected error. A panic counts as a regression and fails the test.
//
// NOTE(review): the method list below is written out by hand; a newly added
// Client method is only covered once a matching call is added here.
func TestNilClient_AllMethodsReturnError(t *testing.T) {
	t.Parallel()

	// Typed-nil receiver shared by every call below.
	var c *Client

	// callErr runs fn and asserts (a) no panic and (b) the returned error
	// is ErrNoRegistry (using errors.Is). name identifies the method.
	// Failures are reported with t.Errorf so the remaining methods still run.
	callErr := func(name string, fn func() error) {
		t.Helper()
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("%s panicked on nil receiver: %v", name, r)
			}
		}()
		err := fn()
		if !errors.Is(err, ErrNoRegistry) {
			t.Errorf("%s: err = %v, want ErrNoRegistry", name, err)
		}
	}

	// callMap discards the map return and asserts the error contract.
	callMap := func(name string, fn func() (map[string]interface{}, error)) {
		t.Helper()
		callErr(name, func() error {
			_, err := fn()
			return err
		})
	}

	// --- void / no-error methods (must not panic; nothing else to assert) ---
	func() {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("SetSigner panicked on nil receiver: %v", r)
			}
		}()
		c.SetSigner(func(string) string { return "" })
	}()

	// Close is the one error-returning method expected to yield nil here.
	func() {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("Close panicked on nil receiver: %v", r)
			}
		}()
		if err := c.Close(); err != nil {
			t.Errorf("Close on nil receiver: err = %v, want nil", err)
		}
	}()

	// --- methods that return (map, error) — go through Send ---
	callMap("Send", func() (map[string]interface{}, error) {
		return c.Send(map[string]interface{}{"type": "ping"})
	})
	callMap("Register", func() (map[string]interface{}, error) { return c.Register("127.0.0.1:0") })
	callMap("RegisterWithOwner", func() (map[string]interface{}, error) {
		return c.RegisterWithOwner("127.0.0.1:0", "owner")
	})
	callMap("RegisterWithKey", func() (map[string]interface{}, error) {
		return c.RegisterWithKey("127.0.0.1:0", "key", "owner", nil)
	})
	callMap("RegisterWithKeyOpts", func() (map[string]interface{}, error) {
		return c.RegisterWithKeyOpts(RegisterOpts{ListenAddr: "127.0.0.1:0", PublicKey: "k"})
	})
	callMap("RotateKey", func() (map[string]interface{}, error) {
		return c.RotateKey(1, "sig", "newkey")
	})
	callMap("Lookup", func() (map[string]interface{}, error) { return c.Lookup(1) })
	callMap("Resolve", func() (map[string]interface{}, error) { return c.Resolve(1, 2) })
	callMap("ReportTrust", func() (map[string]interface{}, error) { return c.ReportTrust(1, 2) })
	callMap("RevokeTrust", func() (map[string]interface{}, error) { return c.RevokeTrust(1, 2) })
	callMap("SetVisibility", func() (map[string]interface{}, error) { return c.SetVisibility(1, true) })
	callMap("CreateNetwork", func() (map[string]interface{}, error) {
		return c.CreateNetwork(1, "name", "open", "tok", "admin", false)
	})
	callMap("CreateManagedNetwork", func() (map[string]interface{}, error) {
		return c.CreateManagedNetwork(1, "name", "open", "tok", "admin", false, "{}")
	})
	callMap("JoinNetwork", func() (map[string]interface{}, error) {
		return c.JoinNetwork(1, 2, "tok", 3, "admin")
	})
	callMap("LeaveNetwork", func() (map[string]interface{}, error) {
		return c.LeaveNetwork(1, 2, "admin")
	})
	callMap("DeleteNetwork", func() (map[string]interface{}, error) { return c.DeleteNetwork(1, "admin") })
	callMap("RenameNetwork", func() (map[string]interface{}, error) {
		return c.RenameNetwork(1, "new", "admin")
	})
	callMap("SetNetworkEnterprise", func() (map[string]interface{}, error) {
		return c.SetNetworkEnterprise(1, true, "admin")
	})
	callMap("ListNetworks", func() (map[string]interface{}, error) { return c.ListNetworks() })
	callMap("ListNodes", func() (map[string]interface{}, error) { return c.ListNodes(1) })
	callMap("Deregister", func() (map[string]interface{}, error) { return c.Deregister(1) })
	callMap("Heartbeat", func() (map[string]interface{}, error) { return c.Heartbeat(1) })
	callMap("Punch", func() (map[string]interface{}, error) { return c.Punch(1, 2, 3) })
	callMap("RequestHandshake", func() (map[string]interface{}, error) {
		return c.RequestHandshake(1, 2, "why", "sig")
	})
	callMap("PollHandshakes", func() (map[string]interface{}, error) { return c.PollHandshakes(1) })
	callMap("RespondHandshake", func() (map[string]interface{}, error) {
		return c.RespondHandshake(1, 2, true, "sig")
	})
	callMap("SetHostname", func() (map[string]interface{}, error) { return c.SetHostname(1, "h") })
	callMap("SetTags", func() (map[string]interface{}, error) { return c.SetTags(1, []string{"t"}) })
	callMap("ResolveHostname", func() (map[string]interface{}, error) { return c.ResolveHostname("h") })
	callMap("ResolveHostnameAs", func() (map[string]interface{}, error) {
		return c.ResolveHostnameAs(1, "h")
	})
	callMap("InviteToNetwork", func() (map[string]interface{}, error) {
		return c.InviteToNetwork(1, 2, 3, "admin")
	})
	callMap("PollInvites", func() (map[string]interface{}, error) { return c.PollInvites(1) })
	callMap("RespondInvite", func() (map[string]interface{}, error) {
		return c.RespondInvite(1, 2, true)
	})
	callMap("PromoteMember", func() (map[string]interface{}, error) {
		return c.PromoteMember(1, 2, 3, "admin")
	})
	callMap("DemoteMember", func() (map[string]interface{}, error) {
		return c.DemoteMember(1, 2, 3, "admin")
	})
	callMap("KickMember", func() (map[string]interface{}, error) {
		return c.KickMember(1, 2, 3, "admin")
	})
	callMap("TransferOwnership", func() (map[string]interface{}, error) {
		return c.TransferOwnership(1, 2, 3, "admin")
	})
	callMap("GetMemberRole", func() (map[string]interface{}, error) {
		return c.GetMemberRole(1, 2)
	})
	callMap("SetNetworkPolicy", func() (map[string]interface{}, error) {
		return c.SetNetworkPolicy(1, map[string]interface{}{}, "admin")
	})
	callMap("GetNetworkPolicy", func() (map[string]interface{}, error) {
		return c.GetNetworkPolicy(1)
	})
	callMap("SetExprPolicy", func() (map[string]interface{}, error) {
		return c.SetExprPolicy(1, json.RawMessage(`{}`), "admin")
	})
	callMap("GetExprPolicy", func() (map[string]interface{}, error) { return c.GetExprPolicy(1) })
	callMap("SetKeyExpiry", func() (map[string]interface{}, error) {
		return c.SetKeyExpiry(1, time.Now())
	})
	callMap("GetKeyInfo", func() (map[string]interface{}, error) { return c.GetKeyInfo(1) })
	callMap("SetHostnameAdmin", func() (map[string]interface{}, error) {
		return c.SetHostnameAdmin(1, "h", "admin")
	})
	callMap("SetVisibilityAdmin", func() (map[string]interface{}, error) {
		return c.SetVisibilityAdmin(1, true, "admin")
	})
	callMap("SetTagsAdmin", func() (map[string]interface{}, error) {
		return c.SetTagsAdmin(1, []string{"t"}, "admin")
	})
	callMap("SetMemberTags", func() (map[string]interface{}, error) {
		return c.SetMemberTags(1, 2, []string{"t"}, "admin")
	})
	callMap("GetMemberTags", func() (map[string]interface{}, error) {
		return c.GetMemberTags(1, 2)
	})
	callMap("SetKeyExpiryAdmin", func() (map[string]interface{}, error) {
		return c.SetKeyExpiryAdmin(1, time.Now(), "admin")
	})
	callMap("ClearKeyExpiryAdmin", func() (map[string]interface{}, error) {
		return c.ClearKeyExpiryAdmin(1, "admin")
	})
	callMap("DeregisterAdmin", func() (map[string]interface{}, error) {
		return c.DeregisterAdmin(1, "admin")
	})
	callMap("GetAuditLog", func() (map[string]interface{}, error) {
		return c.GetAuditLog(1, "admin")
	})
	callMap("SetWebhook", func() (map[string]interface{}, error) {
		return c.SetWebhook("http://x", "admin")
	})
	callMap("GetWebhook", func() (map[string]interface{}, error) { return c.GetWebhook("admin") })
	callMap("GetWebhookDLQ", func() (map[string]interface{}, error) {
		return c.GetWebhookDLQ("admin")
	})
	callMap("SetIdentityWebhook", func() (map[string]interface{}, error) {
		return c.SetIdentityWebhook("http://x", "admin")
	})
	callMap("SetExternalID", func() (map[string]interface{}, error) {
		return c.SetExternalID(1, "ext", "admin")
	})
	callMap("GetIdentity", func() (map[string]interface{}, error) {
		return c.GetIdentity(1, "admin")
	})
	callMap("ProvisionNetwork", func() (map[string]interface{}, error) {
		return c.ProvisionNetwork(map[string]interface{}{}, "admin")
	})
	callMap("SetAuditExport", func() (map[string]interface{}, error) {
		return c.SetAuditExport("splunk", "https://x", "t", "i", "s", "admin")
	})
	callMap("GetAuditExport", func() (map[string]interface{}, error) {
		return c.GetAuditExport("admin")
	})
	callMap("SetIDPConfig", func() (map[string]interface{}, error) {
		return c.SetIDPConfig("oidc", "https://x", "iss", "cid", "tid", "dom", "admin")
	})
	callMap("GetIDPConfig", func() (map[string]interface{}, error) {
		return c.GetIDPConfig("admin")
	})
	callMap("GetProvisionStatus", func() (map[string]interface{}, error) {
		return c.GetProvisionStatus("admin")
	})
	callMap("DirectorySync", func() (map[string]interface{}, error) {
		return c.DirectorySync(1, nil, false, "admin")
	})
	callMap("DirectoryStatus", func() (map[string]interface{}, error) {
		return c.DirectoryStatus(1, "admin")
	})
	callMap("ValidateToken", func() (map[string]interface{}, error) {
		return c.ValidateToken("tok", "admin")
	})

	// --- CheckTrust: (bool, error) — pinned separately because the
	// return type differs and a non-false bool would be misleading.
	func() {
		defer func() {
			if r := recover(); r != nil {
				t.Errorf("CheckTrust panicked on nil receiver: %v", r)
			}
		}()
		ok, err := c.CheckTrust(1, 2)
		if ok {
			t.Errorf("CheckTrust: ok = true, want false")
		}
		if !errors.Is(err, ErrNoRegistry) {
			t.Errorf("CheckTrust: err = %v, want ErrNoRegistry", err)
		}
	}()
}
// genSelfSignedCert creates a throwaway ECDSA P-256 self-signed server
// certificate valid for 127.0.0.1 and "localhost". It returns the certificate
// as a tls.Certificate ready for a tls.Config, plus the certificate's raw DER
// encoding so callers can compute a pin fingerprint. Any failure aborts the
// test via t.Fatalf.
func genSelfSignedCert(t *testing.T) (tlsCert tls.Certificate, derBytes []byte) {
	t.Helper()

	key, keyErr := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if keyErr != nil {
		t.Fatalf("genkey: %v", keyErr)
	}

	// The validity window straddles "now" by one hour on each side.
	template := new(x509.Certificate)
	template.SerialNumber = big.NewInt(1)
	template.Subject = pkix.Name{CommonName: "pilot-test"}
	template.NotBefore = time.Now().Add(-time.Hour)
	template.NotAfter = time.Now().Add(time.Hour)
	template.KeyUsage = x509.KeyUsageDigitalSignature
	template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
	template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}
	template.DNSNames = []string{"localhost"}

	// Template doubles as parent: the certificate signs itself.
	rawCert, certErr := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if certErr != nil {
		t.Fatalf("create cert: %v", certErr)
	}

	rawKey, marshalErr := x509.MarshalECPrivateKey(key)
	if marshalErr != nil {
		t.Fatalf("marshal key: %v", marshalErr)
	}

	// Round-trip through PEM so the pair is loaded the same way a real
	// on-disk certificate/key would be.
	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: rawCert}
	keyBlock := pem.Block{Type: "EC PRIVATE KEY", Bytes: rawKey}
	pair, pairErr := tls.X509KeyPair(pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock))
	if pairErr != nil {
		t.Fatalf("X509KeyPair: %v", pairErr)
	}
	return pair, rawCert
}
+type fakeTLSServer struct { + ln net.Listener + cert tls.Certificate + der []byte + handler func(req map[string]interface{}) map[string]interface{} + connections atomic.Uint32 + wg sync.WaitGroup + closeOnce sync.Once +} + +func newFakeTLSServer(t *testing.T, handler func(req map[string]interface{}) map[string]interface{}) *fakeTLSServer { + t.Helper() + cert, der := genSelfSignedCert(t) + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + ln, err := tls.Listen("tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("tls listen: %v", err) + } + s := &fakeTLSServer{ln: ln, cert: cert, der: der, handler: handler} + s.wg.Add(1) + go s.accept() + t.Cleanup(s.close) + return s +} + +func (s *fakeTLSServer) addr() string { return s.ln.Addr().String() } + +func (s *fakeTLSServer) close() { + s.closeOnce.Do(func() { + s.ln.Close() + s.wg.Wait() + }) +} + +func (s *fakeTLSServer) accept() { + defer s.wg.Done() + for { + conn, err := s.ln.Accept() + if err != nil { + return + } + s.connections.Add(1) + s.wg.Add(1) + go func() { + defer s.wg.Done() + handleJSONOverConn(conn, s.handler) + }() + } +} + +// handleJSONOverConn runs the standard 4-byte length-prefix JSON loop on a conn. 
// readFullN reads from r until buf is completely filled and returns the
// number of bytes read.
//
// It follows the io.ReadFull contract: when buf has been filled the error is
// nil, even if the final Read reported an error (such as io.EOF) alongside
// the last bytes. If reading stops early, the byte count read so far is
// returned together with the error from Read. An empty buf returns (0, nil)
// without touching r.
func readFullN(r net.Conn, buf []byte) (int, error) {
	total := 0
	for total < len(buf) {
		n, err := r.Read(buf[total:])
		if n > 0 {
			total += n
		}
		// A Read may legitimately return data and an error together; a
		// completely filled buffer is a success regardless of that error.
		if total >= len(buf) {
			return total, nil
		}
		if err != nil {
			return total, err
		}
	}
	return total, nil
}
+ resp, err := c.Send(map[string]interface{}{"type": "hello"}) + if err != nil { + t.Fatalf("send: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: %q", got) + } +} + +func TestDialPoolMultiConnExercisesSendPool(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 4) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + defer c.Close() + + if c.pool.free == nil { + t.Fatalf("pool.free should be initialised for size>1") + } + if len(c.pool.entries) != 4 { + t.Fatalf("pool entries: want 4, got %d", len(c.pool.entries)) + } + // Each pool entry corresponds to a real conn on the server. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if srv.connections.Load() == 4 { + break + } + time.Sleep(10 * time.Millisecond) + } + if srv.connections.Load() != 4 { + t.Fatalf("server connections: want 4, got %d", srv.connections.Load()) + } + + // A serial Send drives sendPool / sendOnEntry on entry[0] (or whichever + // is free), exercising the lock+round-trip path. + resp, err := c.Send(map[string]interface{}{"type": "ping"}) + if err != nil { + t.Fatalf("send: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: %q", got) + } +} + +// TestDialPoolZeroOrNegativeSizeNormalisesToOne covers the size<=0 branch. 
+func TestDialPoolZeroOrNegativeSizeNormalisesToOne(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 0) + if err != nil { + t.Fatalf("DialPool(0): %v", err) + } + defer c.Close() + if c.pool.free != nil { + t.Fatalf("size<=0 should normalise to 1 (no pool)") + } + + c2, err := DialPool(srv.addr(), -3) + if err != nil { + t.Fatalf("DialPool(-3): %v", err) + } + defer c2.Close() + if c2.pool.free != nil { + t.Fatalf("negative size should normalise to 1 (no pool)") + } +} + +func TestDialPoolErrorOnUnreachable(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + addr := ln.Addr().String() + ln.Close() + if _, err := DialPool(addr, 3); err == nil { + t.Fatalf("expected DialPool error on unreachable addr") + } +} + +// TestDialPoolPartialSecondaryFailureClosesPrimary covers the +// "primary dialed, secondary dial failed → close primary, return err" branch +// inside initPool. We accept the primary conn then close the listener to make +// secondary dials fail. +func TestDialPoolPartialSecondaryFailureClosesPrimary(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + addr := ln.Addr().String() + + accepted := make(chan net.Conn, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + accepted <- conn + // Close listener immediately so the next net.Dial inside initPool + // fails with ECONNREFUSED. + ln.Close() + }() + + _, err = DialPool(addr, 4) + if err == nil { + t.Fatalf("expected DialPool to fail when secondary dial errors") + } + if !strings.Contains(err.Error(), "dial pool conn") { + t.Fatalf("error should mention dial pool conn, got: %v", err) + } + + // Clean up the primary conn that the server accepted. 
+ select { + case c := <-accepted: + c.Close() + case <-time.After(time.Second): + } +} + +// --- sendPool: closed-client guards --------------------------------------- + +func TestSendPoolAfterCloseFailsFast(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 3) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + c.Close() + + _, err = c.Send(map[string]interface{}{"type": "x"}) + if err == nil { + t.Fatalf("expected error after Close") + } + if !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected 'closed' in error, got: %v", err) + } +} + +// TestSendPoolUnblocksOnCloseWhileBlocked ensures that a goroutine blocked +// in <-c.pool.free is woken up by Close (via the done channel select). +// We exhaust the pool, then Close, then assert that a pending Send returns +// with a "closed" error rather than hanging. +func TestSendPoolUnblocksOnCloseWhileBlocked(t *testing.T) { + t.Parallel() + // Handler that blocks until we tell it to release — so the in-flight + // Send holds its pool entry indefinitely. Pool size 1 → second Send + // is blocked on <-c.pool.free. + release := make(chan struct{}) + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + <-release + return map[string]interface{}{"type": "ok"} + }) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) // primary + 1 secondary + if err != nil { + t.Fatalf("DialPool: %v", err) + } + + // Saturate both pool entries with in-flight Sends. + for i := 0; i < 2; i++ { + go func() { + _, _ = c.Send(map[string]interface{}{"type": "block"}) + }() + } + // Give them a chance to grab pool entries. + time.Sleep(50 * time.Millisecond) + + // Third Send blocks in <-c.pool.free. 
+ thirdDone := make(chan error, 1) + go func() { + _, err := c.Send(map[string]interface{}{"type": "third"}) + thirdDone <- err + }() + time.Sleep(50 * time.Millisecond) + + // Close should unblock the waiter via the <-c.pool.done branch. + c.Close() + close(release) // let the in-flight Sends drain + + select { + case err := <-thirdDone: + if err == nil { + t.Fatalf("third Send should have errored after Close") + } + if !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected 'closed' error, got: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("third Send did not return after Close — pool.done branch not wired") + } +} + +// --- sendPool: per-entry reconnect when the conn dies --------------------- + +func TestSendPoolReconnectsBrokenEntry(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + defer c.Close() + + // Kill the primary entry's conn so the next Send hitting it triggers + // reconnectEntry. We can't predict which entry the channel picks, so + // kill both — sendOnEntry on whichever one we get will fail then reconnect. + for _, e := range c.pool.entries { + e.mu.Lock() + _ = e.conn.Close() + e.mu.Unlock() + } + + resp, err := c.Send(map[string]interface{}{"type": "ping"}) + if err != nil { + t.Fatalf("send after killing entry: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: %q", got) + } + // The reconnect path must have produced a new TCP conn. 
+ deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if srv.connections.Load() >= 3 { + break + } + time.Sleep(10 * time.Millisecond) + } + if srv.connections.Load() < 3 { + t.Fatalf("expected reconnect to open a new conn, server saw %d", srv.connections.Load()) + } +} + +// TestReconnectEntrySyncsPrimary verifies that when the primary entry +// (entries[0]) is reconnected, c.conn is updated in lockstep so callers +// reading c.conn directly don't see a stale fd. This is the "if entry == +// c.pool.entries[0]" branch in reconnectEntry. +func TestReconnectEntrySyncsPrimary(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + defer c.Close() + + // Acquire entries[0] off the free channel directly to guarantee we're + // reconnecting the primary slot. + primary := <-c.pool.free + // Sanity: should be entries[0] OR entries[1]; force primary if not. + if primary != c.pool.entries[0] { + // put it back, grab the actual primary. + c.pool.free <- primary + // drain until we get entries[0]. + for i := 0; i < 4; i++ { + candidate := <-c.pool.free + if candidate == c.pool.entries[0] { + primary = candidate + break + } + c.pool.free <- candidate + } + } + + oldConn := c.conn + primary.mu.Lock() + _ = primary.conn.Close() + if err := c.reconnectEntry(context.Background(), primary); err != nil { + primary.mu.Unlock() + t.Fatalf("reconnectEntry: %v", err) + } + newConn := primary.conn + primary.mu.Unlock() + + if newConn == oldConn { + t.Fatalf("primary conn was not replaced by reconnectEntry") + } + // c.conn should be in sync with the new primary conn. + c.mu.Lock() + if c.conn != newConn { + c.mu.Unlock() + t.Fatalf("c.conn not synced after primary reconnect") + } + c.mu.Unlock() + + // Put the entry back so Close doesn't deadlock. 
+ c.pool.free <- primary +} + +// TestReconnectEntryFailsWhenClosed exercises the early-return-on-closed +// branch of reconnectEntry. +func TestReconnectEntryFailsWhenClosed(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + c.Close() + + if err := c.reconnectEntry(context.Background(), c.pool.entries[0]); err == nil { + t.Fatalf("reconnectEntry on closed client should fail") + } +} + +// --- isClosed ------------------------------------------------------------- + +func TestIsClosedReflectsCloseState(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + if c.isClosed() { + t.Fatalf("fresh client should not report closed") + } + c.Close() + if !c.isClosed() { + t.Fatalf("client should report closed after Close()") + } +} + +// --- DialTLSPool --------------------------------------------------------- + +func TestDialTLSPoolNilConfigReturnsError(t *testing.T) { + t.Parallel() + if _, err := DialTLSPool("127.0.0.1:1", nil, 2); err == nil { + t.Fatalf("expected nil config error") + } +} + +func TestDialTLSPoolSucceedsAndDialsSize(t *testing.T) { + t.Parallel() + srv := newFakeTLSServer(t, echoHandler()) + + clientCfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: true, //nolint:gosec // test-only + } + c, err := DialTLSPool(srv.addr(), clientCfg, 3) + if err != nil { + t.Fatalf("DialTLSPool: %v", err) + } + defer c.Close() + if len(c.pool.entries) != 3 { + t.Fatalf("pool entries: %d, want 3", len(c.pool.entries)) + } + resp, err := c.Send(map[string]interface{}{"type": "hello"}) + if err != nil { + t.Fatalf("send over TLS pool: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: %q", got) + } +} + +func TestDialTLSPoolSizeOneIsSingleConn(t 
*testing.T) { + t.Parallel() + srv := newFakeTLSServer(t, echoHandler()) + clientCfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: true, //nolint:gosec // test-only + } + c, err := DialTLSPool(srv.addr(), clientCfg, 1) + if err != nil { + t.Fatalf("DialTLSPool size=1: %v", err) + } + defer c.Close() + if c.pool.free != nil { + t.Fatalf("size=1 should not initialise pool channel") + } +} + +func TestDialTLSPoolDialErrorWrapsMessage(t *testing.T) { + t.Parallel() + ln, _ := net.Listen("tcp", "127.0.0.1:0") + addr := ln.Addr().String() + ln.Close() + cfg := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} //nolint:gosec // test-only + if _, err := DialTLSPool(addr, cfg, 2); err == nil { + t.Fatalf("expected DialTLSPool error on unreachable addr") + } +} + +// --- DialTLSPinned: full verify path ------------------------------------- + +func TestDialTLSPinnedAcceptsMatchingFingerprint(t *testing.T) { + t.Parallel() + srv := newFakeTLSServer(t, echoHandler()) + + sum := sha256.Sum256(srv.der) + fp := hex.EncodeToString(sum[:]) + + c, err := DialTLSPinned(srv.addr(), fp) + if err != nil { + t.Fatalf("DialTLSPinned (matching fp): %v", err) + } + defer c.Close() + resp, err := c.Send(map[string]interface{}{"type": "hi"}) + if err != nil { + t.Fatalf("send over pinned conn: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: %q", got) + } +} + +func TestDialTLSPinnedRejectsMismatchedFingerprint(t *testing.T) { + t.Parallel() + srv := newFakeTLSServer(t, echoHandler()) + + _, err := DialTLSPinned(srv.addr(), "00112233445566778899aabbccddeeff") + if err == nil { + t.Fatalf("expected fingerprint mismatch error") + } + if !strings.Contains(err.Error(), "fingerprint mismatch") && + !strings.Contains(err.Error(), "dial registry TLS pinned") { + t.Fatalf("error should mention fingerprint mismatch or pinned dial: %v", err) + } +} + +// --- Concurrent Send under -race confirms the regConn mutex is real ------ 
+ +func TestSendConcurrentRaceFreeOnSingleConn(t *testing.T) { + t.Parallel() + var counter atomic.Uint64 + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + counter.Add(1) + return map[string]interface{}{"type": "ok"} + }) + defer srv.close() + + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + const goroutines = 16 + const callsEach = 25 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < callsEach; j++ { + if _, err := c.Send(map[string]interface{}{"type": "x"}); err != nil { + t.Errorf("send: %v", err) + return + } + } + }() + } + wg.Wait() + if got := counter.Load(); got != uint64(goroutines*callsEach) { + t.Fatalf("server saw %d requests, want %d", got, goroutines*callsEach) + } +} + +func TestSendConcurrentRaceFreeOnPool(t *testing.T) { + t.Parallel() + var counter atomic.Uint64 + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + counter.Add(1) + return map[string]interface{}{"type": "ok"} + }) + defer srv.close() + + c, err := DialPool(srv.addr(), 4) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + defer c.Close() + + const goroutines = 16 + const callsEach = 25 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < callsEach; j++ { + if _, err := c.Send(map[string]interface{}{"type": "x"}); err != nil { + t.Errorf("send: %v", err) + return + } + } + }() + } + wg.Wait() + if got := counter.Load(); got != uint64(goroutines*callsEach) { + t.Fatalf("server saw %d requests, want %d", got, goroutines*callsEach) + } +} + +// --- Send (single-conn) reconnect-failure branch ------------------------ + +// TestSendReconnectFailureSurfacesWrappedError covers the legacy-path +// branch: send fails, reconnect also fails → Client returns a "send failed +// and reconnect failed" wrap. 
We close the server first, then make Send +// hit a dead conn — both attempts (initial + reconnect) fail. +func TestSendReconnectFailureSurfacesWrappedError(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + // Kill the server first so the reconnect dial inside Send also fails. + srv.close() + // Also close the local end so the very first WriteMessage errors quickly. + c.mu.Lock() + _ = c.conn.Close() + c.mu.Unlock() + + _, err = c.Send(map[string]interface{}{"type": "x"}) + if err == nil { + t.Fatalf("expected error when both send and reconnect fail") + } + // Could be either "send failed and reconnect failed" or a raw send/recv + // error — accept any failure. + if err.Error() == "" { + t.Fatalf("error message must not be empty") + } +} + +// TestPoolSendReconnectFailureSurfacesWrappedError is the pool-path analogue. +func TestPoolSendReconnectFailureSurfacesWrappedError(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + defer c.Close() + + // Kill the server, then kill both pool entries' conns so the round-trip + // fails AND reconnectEntry's dial fails. + srv.close() + for _, e := range c.pool.entries { + e.mu.Lock() + _ = e.conn.Close() + e.mu.Unlock() + } + + _, err = c.Send(map[string]interface{}{"type": "x"}) + if err == nil { + t.Fatalf("expected pool-path error when both send and reconnect fail") + } +} + +// --- Close: pool with secondary entries ----------------------------------- + +func TestClosePoolReleasesAllSecondaryConns(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 3) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + + // All entries should currently have non-nil conns. 
+ conns := make([]net.Conn, len(c.pool.entries)) + for i, e := range c.pool.entries { + conns[i] = e.conn + } + + if err := c.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + // Every conn (primary + secondary) should now be unusable. + for i, conn := range conns { + if _, err := conn.Write([]byte{0}); err == nil { + t.Fatalf("conn %d should be closed after Close()", i) + } + } +} + +// --- Misc small branch fills --------------------------------------------- + +// Verify the helper Send returns a "client closed" error when isClosed +// is true AND we still try to send (covers the closed-guard inside sendPool). +func TestSendPoolReturnsClosedErrorBeforeAcquire(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := DialPool(srv.addr(), 2) + if err != nil { + t.Fatalf("DialPool: %v", err) + } + c.Close() + + _, err = c.Send(map[string]interface{}{"type": "x"}) + if err == nil || !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected closed error, got: %v", err) + } +} + +// Smoke-test the RegisterWithKeyOpts RelayOnly + LANAddrs branch (the only +// "false" branches not exercised by existing tests). +func TestRegisterWithKeyOptsRelayOnlySerialized(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, err := c.RegisterWithKeyOpts(RegisterOpts{ + ListenAddr: "x:1", + PublicKey: "PUB", + LANAddrs: []string{"10.0.0.1:1"}, + RelayOnly: true, + }) + if err != nil { + t.Fatalf("register: %v", err) + } + echo := assertEcho(t, resp) + if got, _ := echo["relay_only"].(bool); !got { + t.Fatalf("relay_only: %v", got) + } + if _, ok := echo["owner"]; ok { + t.Fatalf("owner should be omitted when blank") + } +} + +// Ensure RegisterWithKey with multiple version variadic args picks the first +// non-empty (firstNonEmpty branch). 
+func TestRegisterWithKeyFirstNonEmptyVersion(t *testing.T) { + t.Parallel() + c, _ := echoOnlyClient(t) + resp, err := c.RegisterWithKey("x:1", "PUB", "", nil, "", "", "v2.0.0", "v3.0.0") + if err != nil { + t.Fatalf("register: %v", err) + } + echo := assertEcho(t, resp) + if got, _ := echo["version"].(string); got != "v2.0.0" { + t.Fatalf("version: want v2.0.0 (first non-empty), got %q", got) + } +} diff --git a/registry/client/zz_client_wire_test.go b/registry/client/zz_client_wire_test.go new file mode 100644 index 0000000..b8cd93f --- /dev/null +++ b/registry/client/zz_client_wire_test.go @@ -0,0 +1,618 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package client + +import ( + "crypto/tls" + "encoding/binary" + "encoding/json" + "io" + "net" + "strings" + "sync/atomic" + "testing" +) + +// fakeJSONServer speaks the registry JSON-over-TCP wire protocol +// (4-byte big-endian length prefix + JSON body). Each connection handshake +// is dispatched to a handler callback that can read the request and write +// a reply. 
type fakeJSONServer struct {
	ln      net.Listener
	handler func(req map[string]interface{}) map[string]interface{}
	// requests counts frames that were read and decoded successfully.
	requests atomic.Uint32
	// connections counts accepted connections.
	connections atomic.Uint32
	// done is closed by the first call to close.
	done chan struct{}
	// closed makes close idempotent.
	closed atomic.Bool
}

// newFakeJSONServer listens on a loopback port chosen by the kernel and
// starts the accept loop. It fails the test if the listener cannot be
// created. The caller is responsible for calling close.
func newFakeJSONServer(t *testing.T, handler func(req map[string]interface{}) map[string]interface{}) *fakeJSONServer {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	s := &fakeJSONServer{ln: ln, handler: handler, done: make(chan struct{})}
	go s.accept()
	return s
}

// addr returns the listener's host:port.
func (s *fakeJSONServer) addr() string { return s.ln.Addr().String() }

// close stops the listener and closes done. It is safe to call more than
// once (for example an explicit close plus a deferred one); only the first
// call has any effect, so the done channel is never closed twice.
func (s *fakeJSONServer) close() {
	if !s.closed.CompareAndSwap(false, true) {
		return
	}
	s.ln.Close()
	close(s.done)
}

// accept hands each incoming connection to handle on its own goroutine and
// returns when Accept fails, which happens once the listener is closed.
func (s *fakeJSONServer) accept() {
	for {
		conn, err := s.ln.Accept()
		if err != nil {
			return
		}
		s.connections.Add(1)
		go s.handle(conn)
	}
}

// handle serves length-prefixed JSON requests on conn until a read, decode,
// encode or write fails, a frame exceeds 1 MiB, or the handler returns nil.
// The connection is closed on return.
func (s *fakeJSONServer) handle(conn net.Conn) {
	defer conn.Close()
	for {
		var lenBuf [4]byte
		if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
			return
		}
		n := binary.BigEndian.Uint32(lenBuf[:])
		// Defensive cap: any caller that sends non-JSON framing (e.g. TLS
		// ClientHello) would otherwise block this goroutine in io.ReadFull
		// until the full test timeout.
		if n > 1<<20 {
			return
		}
		body := make([]byte, n)
		if _, err := io.ReadFull(conn, body); err != nil {
			return
		}
		var req map[string]interface{}
		if err := json.Unmarshal(body, &req); err != nil {
			return
		}
		s.requests.Add(1)
		resp := s.handler(req)
		if resp == nil {
			return
		}
		out, err := json.Marshal(resp)
		if err != nil {
			// Do not emit a length prefix for a body that failed to encode.
			return
		}
		var outLen [4]byte
		binary.BigEndian.PutUint32(outLen[:], uint32(len(out)))
		// Stop serving as soon as the peer is gone rather than looping back
		// into a read on a dead connection.
		if _, err := conn.Write(outLen[:]); err != nil {
			return
		}
		if _, err := conn.Write(out); err != nil {
			return
		}
	}
}

// Echo the request type, plus include every key that was sent, under "echo".
// Tests can assert that the wire payload carried the right keys.
+func echoHandler() func(map[string]interface{}) map[string]interface{} { + return func(req map[string]interface{}) map[string]interface{} { + resp := map[string]interface{}{"type": "ok", "echo": req} + return resp + } +} + +// --- Dial / Close / Addr ---------------------------------------------------- + +func TestDialSuccessReturnsClientWithAddr(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + if c.addr != srv.addr() { + t.Fatalf("addr: want %q, got %q", srv.addr(), c.addr) + } + if c.conn == nil { + t.Fatalf("conn should be set") + } +} + +func TestDialErrorOnBadAddress(t *testing.T) { + t.Parallel() + // Grab a port from the kernel and immediately release it so Dial + // fails fast with ECONNREFUSED on loopback (no DNS/route wait). + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + addr := ln.Addr().String() + ln.Close() + + _, err = Dial(addr) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "dial registry") { + t.Fatalf("error should mention dial registry: %v", err) + } +} + +func TestDialTLSReturnsErrorWhenConfigNil(t *testing.T) { + t.Parallel() + if _, err := DialTLS("127.0.0.1:1", nil); err == nil { + t.Fatalf("expected error on nil tlsConfig") + } +} + +// closeOnAcceptListener accepts each connection and immediately closes it, so +// a TLS dial against it fails fast with EOF during the handshake. 
+func closeOnAcceptListener(t *testing.T) (addr string, stop func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + conn.Close() + } + }() + return ln.Addr().String(), func() { ln.Close() } +} + +func TestDialTLSFailsConnectToPlainServer(t *testing.T) { + t.Parallel() + addr, stop := closeOnAcceptListener(t) + defer stop() + _, err := DialTLS(addr, minimalTLSConfig()) + if err == nil { + t.Fatalf("expected TLS error") + } + if !strings.Contains(err.Error(), "dial registry TLS") { + t.Fatalf("error should mention TLS dial: %v", err) + } +} + +func TestDialTLSPinnedFailsConnectToPlainServer(t *testing.T) { + t.Parallel() + addr, stop := closeOnAcceptListener(t) + defer stop() + _, err := DialTLSPinned(addr, "deadbeef") + if err == nil { + t.Fatalf("expected TLS pin error") + } + if !strings.Contains(err.Error(), "dial registry TLS pinned") { + t.Fatalf("error should mention TLS pinned dial: %v", err) + } +} + +func TestCloseSafeWhenNilConn(t *testing.T) { + t.Parallel() + c := &Client{} + if err := c.Close(); err != nil { + t.Fatalf("Close on empty client should not error: %v", err) + } + if !c.closed { + t.Fatalf("client should report closed after Close()") + } +} + +func TestCloseClosesRealConn(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("close: %v", err) + } + // After Close, conn.Write should error. 
+ if _, err := c.conn.Write([]byte{0}); err == nil { + t.Fatalf("expected write error after Close") + } +} + +// --- Signer ----------------------------------------------------------------- + +func TestSetSignerReturnsSignature(t *testing.T) { + t.Parallel() + c := &Client{} + sig, err := c.sign("whatever") + if err == nil { + t.Fatalf("expected error with no signer, got sig=%q", sig) + } + c.SetSigner(func(challenge string) string { + return "sig(" + challenge + ")" + }) + sig, err = c.sign("abc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sig != "sig(abc)" { + t.Fatalf("expected sig(abc), got %q", sig) + } +} + +func TestResolveIncludesSignatureWhenSignerSet(t *testing.T) { + t.Parallel() + var gotChallenge string + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + c.SetSigner(func(challenge string) string { + gotChallenge = challenge + return "SIG" + }) + resp, err := c.Resolve(42, 7) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if gotChallenge != "resolve:7:42" { + t.Fatalf("challenge: want resolve:7:42, got %q", gotChallenge) + } + echo, _ := resp["echo"].(map[string]interface{}) + if sig, _ := echo["signature"].(string); sig != "SIG" { + t.Fatalf("signature wire value: want SIG, got %q", sig) + } +} + +// --- Send / sendLocked ------------------------------------------------------ + +func TestSendReturnsServerErrorResponse(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, func(_ map[string]interface{}) map[string]interface{} { + return map[string]interface{}{"error": "boom"} + }) + defer srv.close() + + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + resp, err := c.Send(map[string]interface{}{"type": "ping"}) + if err == nil { + t.Fatalf("expected server error") + } + if !strings.Contains(err.Error(), "boom") { + t.Fatalf("error should contain 'boom': 
%v", err) + } + // resp is non-nil for server errors so the caller can inspect it. + if resp == nil { + t.Fatalf("expected non-nil response on server error") + } +} + +func TestSendHappyPath(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + resp, err := c.Send(map[string]interface{}{"type": "hello", "num": float64(3)}) + if err != nil { + t.Fatalf("send: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: want ok, got %q", got) + } +} + +func TestSendReconnectsAfterDroppedConnection(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + // Simulate a connection-level failure by closing the client's conn + // without marking the Client closed. The next Send should reconnect. 
+ c.mu.Lock() + _ = c.conn.Close() + c.mu.Unlock() + + resp, err := c.Send(map[string]interface{}{"type": "hello"}) + if err != nil { + t.Fatalf("send after reconnect: %v", err) + } + if got, _ := resp["type"].(string); got != "ok" { + t.Fatalf("type: want ok, got %q", got) + } + if srv.connections.Load() < 2 { + t.Fatalf("expected second connection from reconnect, got %d", srv.connections.Load()) + } +} + +func TestSendFailsWhenClosed(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + c.Close() + + _, err = c.Send(map[string]interface{}{"type": "hello"}) + if err == nil { + t.Fatalf("expected error after Close") + } +} + +// --- Register family -------------------------------------------------------- + +func TestRegisterSendsCorrectWireMessage(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, err := Dial(srv.addr()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + resp, err := c.Register("1.2.3.4:4000") + if err != nil { + t.Fatalf("register: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["type"].(string); got != "register" { + t.Fatalf("wire type: want register, got %q", got) + } + if got, _ := echo["listen_addr"].(string); got != "1.2.3.4:4000" { + t.Fatalf("listen_addr: want 1.2.3.4:4000, got %q", got) + } +} + +func TestRegisterWithOwnerIncludesOwner(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + resp, err := c.RegisterWithOwner("x:1", "alice@example.com") + if err != nil { + t.Fatalf("register: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["owner"].(string); got != "alice@example.com" { + t.Fatalf("owner: %q", got) + } +} + +func TestRegisterWithKeyOmitsBlankOwnerAndLAN(t *testing.T) { + 
t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + resp, err := c.RegisterWithKey("x:1", "PUB==", "", nil) + if err != nil { + t.Fatalf("register: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if _, ok := echo["owner"]; ok { + t.Fatalf("owner should be omitted when blank") + } + if _, ok := echo["lan_addrs"]; ok { + t.Fatalf("lan_addrs should be omitted when empty") + } + if _, ok := echo["version"]; ok { + t.Fatalf("version should be omitted when not supplied") + } + if got, _ := echo["public_key"].(string); got != "PUB==" { + t.Fatalf("public_key: %q", got) + } +} + +func TestRegisterWithKeyIncludesAllFields(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + resp, err := c.RegisterWithKey("x:1", "PUB==", "bob", []string{"10.0.0.1:80"}, "v1.2.3") + if err != nil { + t.Fatalf("register: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["owner"].(string); got != "bob" { + t.Fatalf("owner: %q", got) + } + if got, _ := echo["version"].(string); got != "v1.2.3" { + t.Fatalf("version: %q", got) + } + lan, _ := echo["lan_addrs"].([]interface{}) + if len(lan) != 1 || lan[0] != "10.0.0.1:80" { + t.Fatalf("lan_addrs: %v", lan) + } +} + +// --- Lookup / Resolve / ReportTrust / RevokeTrust / SetVisibility ---------- + +func TestLookupSendsNodeID(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + resp, err := c.Lookup(42) + if err != nil { + t.Fatalf("lookup: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["type"].(string); got != "lookup" { + t.Fatalf("type: %q", got) + } + if got := uint32(echo["node_id"].(float64)); got != 42 { + t.Fatalf("node_id: %d", got) + } +} + +func TestReportTrustAndRevokeTrustFormat(t *testing.T) { 
+ t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + + for name, fn := range map[string]func() (map[string]interface{}, error){ + "report_trust": func() (map[string]interface{}, error) { return c.ReportTrust(1, 2) }, + "revoke_trust": func() (map[string]interface{}, error) { return c.RevokeTrust(1, 2) }, + } { + resp, err := fn() + if err != nil { + t.Fatalf("%s: %v", name, err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["type"].(string); got != name { + t.Fatalf("%s: type=%q", name, got) + } + if got := uint32(echo["node_id"].(float64)); got != 1 { + t.Fatalf("%s: node_id=%d", name, got) + } + if got := uint32(echo["peer_id"].(float64)); got != 2 { + t.Fatalf("%s: peer_id=%d", name, got) + } + } +} + +func TestSetVisibilityPublicFlagSerialized(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + c.SetSigner(func(ch string) string { return "SIG:" + ch }) + + resp, err := c.SetVisibility(9, true) + if err != nil { + t.Fatalf("set_visibility: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["public"].(bool); got != true { + t.Fatalf("public: %v", got) + } +} + +// --- CreateNetwork / CreateManagedNetwork ---------------------------------- + +func TestCreateNetworkBasicAndFull(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + // Basic: no adminToken, no enterprise, no networkAdminToken. 
+ resp, err := c.CreateNetwork(1, "foo", "public", "tok", "", false) + if err != nil { + t.Fatalf("create: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if _, ok := echo["admin_token"]; ok { + t.Fatalf("admin_token should be omitted when blank") + } + if _, ok := echo["enterprise"]; ok { + t.Fatalf("enterprise should be omitted when false") + } + if _, ok := echo["network_admin_token"]; ok { + t.Fatalf("network_admin_token should be omitted when not supplied") + } + + // Full: adminToken + enterprise + networkAdminToken. + resp, err = c.CreateNetwork(1, "foo", "public", "tok", "ADM", true, "NAT") + if err != nil { + t.Fatalf("create full: %v", err) + } + echo, _ = resp["echo"].(map[string]interface{}) + if got, _ := echo["admin_token"].(string); got != "ADM" { + t.Fatalf("admin_token: %q", got) + } + if got, _ := echo["enterprise"].(bool); !got { + t.Fatalf("enterprise: %v", got) + } + if got, _ := echo["network_admin_token"].(string); got != "NAT" { + t.Fatalf("network_admin_token: %q", got) + } +} + +func TestCreateManagedNetworkIncludesRules(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + resp, err := c.CreateManagedNetwork(2, "n", "invite", "tok", "", false, `{"a":1}`) + if err != nil { + t.Fatalf("managed: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if got, _ := echo["rules"].(string); got != `{"a":1}` { + t.Fatalf("rules: %q", got) + } +} + +// --- RotateKey -------------------------------------------------------------- + +func TestRotateKeyOmitsBlankSignatureAndPubKey(t *testing.T) { + t.Parallel() + srv := newFakeJSONServer(t, echoHandler()) + defer srv.close() + c, _ := Dial(srv.addr()) + defer c.Close() + + resp, err := c.RotateKey(7, "", "") + if err != nil { + t.Fatalf("rotate: %v", err) + } + echo, _ := resp["echo"].(map[string]interface{}) + if _, ok := echo["signature"]; ok { + t.Fatalf("signature should be 
omitted when blank") + } + if _, ok := echo["new_public_key"]; ok { + t.Fatalf("new_public_key should be omitted when blank") + } + + resp, err = c.RotateKey(7, "SIG", "NPK") + if err != nil { + t.Fatalf("rotate full: %v", err) + } + echo, _ = resp["echo"].(map[string]interface{}) + if got, _ := echo["signature"].(string); got != "SIG" { + t.Fatalf("signature: %q", got) + } + if got, _ := echo["new_public_key"].(string); got != "NPK" { + t.Fatalf("new_public_key: %q", got) + } +} + +func minimalTLSConfig() *tls.Config { + return &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} //nolint:gosec // test-only +} diff --git a/registry/wire/blueprint.go b/registry/wire/blueprint.go new file mode 100644 index 0000000..dd5312c --- /dev/null +++ b/registry/wire/blueprint.go @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/pilot-protocol/common/urlvalidate" +) + +// NetworkBlueprint defines a declarative configuration for provisioning +// an enterprise network. Enterprises apply blueprints via the registry +// protocol or the pilotctl CLI to create and configure networks in one shot. 
type NetworkBlueprint struct {
	// Network settings
	//
	// Name is mandatory: both LoadBlueprint and ValidateBlueprint reject a
	// blueprint whose name is empty.
	Name       string `json:"name"`
	JoinRule   string `json:"join_rule,omitempty"`  // "open", "token", "invite" (default: "open")
	JoinToken  string `json:"join_token,omitempty"` // required if join_rule = "token"
	Enterprise bool   `json:"enterprise,omitempty"` // enable enterprise features

	// Policy
	//
	// ExprPolicy is kept as raw JSON. ValidateBlueprint only checks that it
	// parses, declares "version": 1 and carries a non-null "rules" value.
	Policy     *BlueprintPolicy `json:"policy,omitempty"`
	ExprPolicy json.RawMessage  `json:"expr_policy,omitempty"`

	// RBAC pre-assignments (by external_id)
	Roles []BlueprintRole `json:"roles,omitempty"`

	// Identity provider configuration
	IdentityProvider *BlueprintIdentityProvider `json:"identity_provider,omitempty"`

	// Observability
	Webhooks *BlueprintWebhooks `json:"webhooks,omitempty"`

	// Audit export
	AuditExport *BlueprintAuditExport `json:"audit_export,omitempty"`

	// Per-network admin token (optional override)
	NetworkAdminToken string `json:"network_admin_token,omitempty"`
}

// BlueprintPolicy defines the network policy section of a blueprint.
// All fields are optional and are omitted from the JSON encoding when zero.
type BlueprintPolicy struct {
	MaxMembers   int      `json:"max_members,omitempty"`
	AllowedPorts []uint16 `json:"allowed_ports,omitempty"`
	Description  string   `json:"description,omitempty"`
}

// BlueprintRole pre-assigns RBAC roles by external identity.
// ValidateBlueprint requires a non-empty ExternalID and one of the role
// names listed on the Role field.
type BlueprintRole struct {
	ExternalID string `json:"external_id"`
	Role       string `json:"role"` // "owner", "admin", "member"
}

// BlueprintIdentityProvider configures external identity verification.
type BlueprintIdentityProvider struct {
	Type     string `json:"type"`                // "oidc", "saml", "webhook", "entra_id", "ldap"
	URL      string `json:"url"`                 // verification endpoint
	Issuer   string `json:"issuer,omitempty"`    // OIDC issuer URL
	ClientID string `json:"client_id,omitempty"` // OIDC client ID
	TenantID string `json:"tenant_id,omitempty"` // Azure AD / Entra ID tenant
	Domain   string `json:"domain,omitempty"`    // LDAP domain
}

// BlueprintWebhooks configures webhook endpoints for the network.
// Both URLs are optional; an empty string means "not configured".
type BlueprintWebhooks struct {
	AuditURL    string `json:"audit_url,omitempty"`    // audit event webhook
	IdentityURL string `json:"identity_url,omitempty"` // identity verification webhook
}

// BlueprintAuditExport configures external audit log export.
type BlueprintAuditExport struct {
	Format   string `json:"format"`           // "json", "splunk_hec", "syslog_cef"
	Endpoint string `json:"endpoint"`         // destination URL or address
	Token    string `json:"token,omitempty"`  // auth token (e.g., Splunk HEC token)
	Index    string `json:"index,omitempty"`  // Splunk index
	Source   string `json:"source,omitempty"` // source identifier
}

// LoadBlueprint reads a network blueprint from a JSON file.
//
// It returns an error wrapping the underlying cause when the file cannot be
// read ("read blueprint: ...") or is not valid JSON ("parse blueprint: ..."),
// and a plain error when the decoded blueprint has no name.
//
// Only the name is checked here. The remaining checks (join rule, roles,
// URLs, audit export, expr_policy) live in ValidateBlueprint, which this
// function does not call; callers that need them must invoke it themselves.
func LoadBlueprint(path string) (*NetworkBlueprint, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read blueprint: %w", err)
	}
	var bp NetworkBlueprint
	if err := json.Unmarshal(data, &bp); err != nil {
		return nil, fmt.Errorf("parse blueprint: %w", err)
	}
	if bp.Name == "" {
		return nil, fmt.Errorf("blueprint: name is required")
	}
	return &bp, nil
}

// ValidateBlueprint checks a blueprint for configuration errors.
+func ValidateBlueprint(bp *NetworkBlueprint) error { + if bp.Name == "" { + return fmt.Errorf("name is required") + } + switch bp.JoinRule { + case "", "open", "token", "invite": + default: + return fmt.Errorf("invalid join_rule %q (must be open, token, or invite)", bp.JoinRule) + } + if bp.JoinRule == "token" && bp.JoinToken == "" { + return fmt.Errorf("join_token is required when join_rule is token") + } + for _, r := range bp.Roles { + if r.ExternalID == "" { + return fmt.Errorf("role entry: external_id is required") + } + switch r.Role { + case "owner", "admin", "member": + default: + return fmt.Errorf("invalid role %q for %s", r.Role, r.ExternalID) + } + } + if bp.IdentityProvider != nil { + switch bp.IdentityProvider.Type { + case "oidc", "saml", "webhook", "entra_id", "ldap": + default: + return fmt.Errorf("invalid identity_provider type %q", bp.IdentityProvider.Type) + } + if bp.IdentityProvider.URL == "" { + return fmt.Errorf("identity_provider.url is required") + } + if err := urlvalidate.Validate(bp.IdentityProvider.URL); err != nil { + return fmt.Errorf("identity_provider.url: %w", err) + } + } + if bp.Webhooks != nil { + if bp.Webhooks.AuditURL != "" { + if err := urlvalidate.Validate(bp.Webhooks.AuditURL); err != nil { + return fmt.Errorf("webhooks.audit_url: %w", err) + } + } + if bp.Webhooks.IdentityURL != "" { + if err := urlvalidate.Validate(bp.Webhooks.IdentityURL); err != nil { + return fmt.Errorf("webhooks.identity_url: %w", err) + } + } + } + if bp.AuditExport != nil { + switch bp.AuditExport.Format { + case "json", "splunk_hec", "syslog_cef": + default: + return fmt.Errorf("invalid audit_export format %q", bp.AuditExport.Format) + } + if bp.AuditExport.Endpoint == "" { + return fmt.Errorf("audit_export.endpoint is required") + } + // syslog_cef sinks accept raw host:port targets; only the HTTP(S) + // formats need SSRF validation. 
+ if bp.AuditExport.Format == "json" || bp.AuditExport.Format == "splunk_hec" { + if err := urlvalidate.Validate(bp.AuditExport.Endpoint); err != nil { + return fmt.Errorf("audit_export.endpoint: %w", err) + } + } + } + if len(bp.ExprPolicy) > 0 { + var check struct { + Version int `json:"version"` + Rules json.RawMessage `json:"rules"` + } + if err := json.Unmarshal(bp.ExprPolicy, &check); err != nil { + return fmt.Errorf("expr_policy: invalid JSON: %w", err) + } + if check.Version != 1 { + return fmt.Errorf("expr_policy: unsupported version %d (want 1)", check.Version) + } + if len(check.Rules) == 0 || string(check.Rules) == "null" { + return fmt.Errorf("expr_policy: at least one rule is required") + } + } + return nil +} diff --git a/registry/wire/rules.go b/registry/wire/rules.go new file mode 100644 index 0000000..f19e566 --- /dev/null +++ b/registry/wire/rules.go @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire + +import ( + "encoding/json" + "fmt" + "time" +) + +// NetworkRules defines the managed network ruleset. When set on a NetworkInfo, +// the network becomes "managed" — daemon-local link lifecycle is governed by +// these rules. The registry only stores and distributes the rules; all cycle +// logic runs inside each daemon. +type NetworkRules struct { + Links int `json:"links"` // max managed peers per node + Cycle string `json:"cycle"` // Go duration: "24h", "1h" + Prune int `json:"prune"` // how many to drop per cycle + PruneBy string `json:"prune_by"` // "score", "age", "activity" + Fill int `json:"fill"` // how many to add per cycle + FillHow string `json:"fill_how"` // "random" + Grace string `json:"grace,omitempty"` // grace period for new members +} + +// ValidateRules checks that a NetworkRules is well-formed. Returns nil if valid. 
+func ValidateRules(r *NetworkRules) error { + if r == nil { + return nil + } + if r.Links < 1 { + return fmt.Errorf("rules: links must be >= 1 (got %d)", r.Links) + } + if r.Cycle == "" { + return fmt.Errorf("rules: cycle is required") + } + d, err := time.ParseDuration(r.Cycle) + if err != nil { + return fmt.Errorf("rules: invalid cycle duration %q: %w", r.Cycle, err) + } + if d < 1*time.Minute { + return fmt.Errorf("rules: cycle must be >= 1m (got %s)", r.Cycle) + } + if r.Prune < 0 { + return fmt.Errorf("rules: prune must be >= 0 (got %d)", r.Prune) + } + if r.Fill < 0 { + return fmt.Errorf("rules: fill must be >= 0 (got %d)", r.Fill) + } + if r.Prune > r.Links { + return fmt.Errorf("rules: prune (%d) cannot exceed links (%d)", r.Prune, r.Links) + } + if r.Fill > r.Links { + return fmt.Errorf("rules: fill (%d) cannot exceed links (%d)", r.Fill, r.Links) + } + + switch r.PruneBy { + case "score", "age", "activity": + // valid + case "": + return fmt.Errorf("rules: prune_by is required") + default: + return fmt.Errorf("rules: unknown prune_by strategy %q (want score|age|activity)", r.PruneBy) + } + + switch r.FillHow { + case "random": + // valid + case "": + return fmt.Errorf("rules: fill_how is required") + default: + return fmt.Errorf("rules: unknown fill_how strategy %q (want random)", r.FillHow) + } + + if r.Grace != "" { + g, err := time.ParseDuration(r.Grace) + if err != nil { + return fmt.Errorf("rules: invalid grace duration %q: %w", r.Grace, err) + } + if g < 0 { + return fmt.Errorf("rules: grace must be >= 0") + } + } + + return nil +} + +// ParseRules unmarshals a JSON string into NetworkRules and validates it. 
+func ParseRules(raw string) (*NetworkRules, error) { + var r NetworkRules + if err := json.Unmarshal([]byte(raw), &r); err != nil { + return nil, fmt.Errorf("rules: invalid JSON: %w", err) + } + if err := ValidateRules(&r); err != nil { + return nil, err + } + return &r, nil +} + +// RulesToPolicy converts a NetworkRules struct into a PolicyDocument JSON +// (json.RawMessage). This provides backward compatibility: existing managed +// networks continue to work through the policy engine. +func RulesToPolicy(r *NetworkRules) (json.RawMessage, error) { + if r == nil { + return nil, nil + } + + type action struct { + Type string `json:"type"` + Params map[string]interface{} `json:"params,omitempty"` + } + type rule struct { + Name string `json:"name"` + On string `json:"on"` + Match string `json:"match"` + Actions []action `json:"actions"` + } + type policyDoc struct { + Version int `json:"version"` + Config map[string]interface{} `json:"config,omitempty"` + Rules []rule `json:"rules"` + } + + doc := policyDoc{ + Version: 1, + Config: map[string]interface{}{ + "max_peers": r.Links, + "cycle": r.Cycle, + }, + Rules: []rule{ + { + Name: "cycle-prune-fill", + On: "cycle", + Match: "true", + Actions: []action{ + {Type: "prune", Params: map[string]interface{}{"count": r.Prune, "by": r.PruneBy}}, + {Type: "fill", Params: map[string]interface{}{"count": r.Fill, "how": r.FillHow}}, + }, + }, + }, + } + + if r.Grace != "" { + doc.Config["grace"] = r.Grace + } + + data, err := json.Marshal(doc) + if err != nil { + return nil, fmt.Errorf("rules-to-policy: %w", err) + } + return json.RawMessage(data), nil +} + +// AllowedPortsToPolicy converts a port allowlist into a PolicyDocument JSON +// (json.RawMessage). This replaces the old AllowedPorts mechanism with +// equivalent policy rules. 
+func AllowedPortsToPolicy(ports []uint16) (json.RawMessage, error) { + if len(ports) == 0 { + return nil, nil + } + + // Build match expression: "port in [80, 443, 1001]" + var buf []byte + buf = append(buf, "port in ["...) + for i, p := range ports { + if i > 0 { + buf = append(buf, ", "...) + } + buf = fmt.Appendf(buf, "%d", p) + } + buf = append(buf, ']') + matchExpr := string(buf) + + type action struct { + Type string `json:"type"` + } + type rule struct { + Name string `json:"name"` + On string `json:"on"` + Match string `json:"match"` + Actions []action `json:"actions"` + } + type policyDoc struct { + Version int `json:"version"` + Rules []rule `json:"rules"` + } + + doc := policyDoc{ + Version: 1, + Rules: []rule{ + {Name: "allow-ports", On: "connect", Match: matchExpr, Actions: []action{{Type: "allow"}}}, + {Name: "allow-ports-dg", On: "datagram", Match: matchExpr, Actions: []action{{Type: "allow"}}}, + {Name: "allow-ports-dial", On: "dial", Match: matchExpr, Actions: []action{{Type: "allow"}}}, + {Name: "deny-rest", On: "connect", Match: "true", Actions: []action{{Type: "deny"}}}, + {Name: "deny-rest-dg", On: "datagram", Match: "true", Actions: []action{{Type: "deny"}}}, + {Name: "deny-rest-dial", On: "dial", Match: "true", Actions: []action{{Type: "deny"}}}, + }, + } + + data, err := json.Marshal(doc) + if err != nil { + return nil, fmt.Errorf("ports-to-policy: %w", err) + } + return json.RawMessage(data), nil +} diff --git a/registry/wire/wire.go b/registry/wire/wire.go new file mode 100644 index 0000000..4ea64e3 --- /dev/null +++ b/registry/wire/wire.go @@ -0,0 +1,595 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +// Package wire defines the binary wire format shared between the registry +// client and server. It contains protocol constants, framing, and the +// encode/decode helpers that both sides use to talk over the same TCP +// connection. 
Pure types and functions — no networking, no logging, no I/O +// beyond the io.Reader/io.Writer abstractions used by the framing layer. +package wire + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math" + "net" + "time" +) + +// maxCount caps wire-controlled list lengths to prevent a malicious +// peer from triggering large allocations (e.g. netCount=65535 → +// 130 KB make()). All frames are bounded by MaxMessageSize (64 MiB) +// but per-field allocations without caps can cause memory pressure +// before the overall frame limit is reached. +const maxCount = 1024 + +// WriteMessageDeadline bounds how long a single JSON response write can +// take. If a client is slow to drain (overloaded host, kernel buffer +// pressure) we'd otherwise hold the request goroutine + response payload +// in memory indefinitely. After this deadline expires, w.Write returns +// an error and the caller can drop the connection cleanly. +const WriteMessageDeadline = 5 * time.Second + +// MaxMessageSize is the maximum allowed wire message size (64 MiB). +// Messages exceeding this limit cause the connection to be closed. +// Note: must stay well below the binary wire magic (0x50494C54 ≈ 1.3B) +// for protocol auto-detection to work. Sized for full registry snapshot +// in subscribe_replication: ~26 MiB at 100k+ nodes, with headroom. +const MaxMessageSize = 64 * 1024 * 1024 + +// Binary wire format for high-throughput operations. +// +// Protocol negotiation: binary clients send magic 0x50494C54 ("PILT") + 1 byte +// version as the first 5 bytes of the connection. The server detects this vs a +// JSON length prefix (which is always <= MaxMessageSize, 64 MiB) and switches mode per-connection. 
//
// Binary frame: [4B total_length][1B msg_type][payload]
//
// total_length is big-endian and counts the msg_type byte plus the payload;
// it does not include the four length bytes themselves.
//
// Message types:
//   0x00 = JSON passthrough (payload is JSON bytes)
//   0x01 = heartbeat request
//   0x81 = heartbeat response
//   0x02 = lookup request
//   0x82 = lookup response
//   0x03 = resolve request
//   0x83 = resolve response
//   0xFF = error response

// Magic is the 4-byte magic sent by binary clients at connection start.
var Magic = [4]byte{0x50, 0x49, 0x4C, 0x54} // "PILT"

// Version is the current binary protocol version.
const Version byte = 1

// Binary message type constants. Each *OK response code is the matching
// request code with the high bit (0x80) set.
const (
	MsgJSON        byte = 0x00
	MsgHeartbeat   byte = 0x01
	MsgHeartbeatOK byte = 0x81
	MsgLookup      byte = 0x02
	MsgLookupOK    byte = 0x82
	MsgResolve     byte = 0x03
	MsgResolveOK   byte = 0x83
	MsgError       byte = 0xFF
)

// ReadFrame reads a single binary frame: [4B length][1B type][payload].
//
// It returns the message type and the payload bytes. The payload is nil when
// the frame carries only the type byte. Errors from the underlying reader
// are returned unwrapped; a declared length of zero or above MaxMessageSize
// is rejected before any payload memory is allocated.
func ReadFrame(r io.Reader) (msgType byte, payload []byte, err error) {
	// The length and the type byte are read together as a 5-byte header.
	var hdr [5]byte
	if _, err = io.ReadFull(r, hdr[:]); err != nil {
		return 0, nil, err
	}
	length := binary.BigEndian.Uint32(hdr[:4])
	// length is unsigned, so this only catches zero: a valid frame must at
	// least account for the type byte.
	if length < 1 {
		return 0, nil, fmt.Errorf("binary frame too short")
	}
	if length > MaxMessageSize {
		return 0, nil, fmt.Errorf("binary frame too large: %d bytes (max %d)", length, MaxMessageSize)
	}
	msgType = hdr[4]
	payloadLen := length - 1 // length includes the type byte
	if payloadLen > 0 {
		payload = make([]byte, payloadLen)
		if _, err = io.ReadFull(r, payload); err != nil {
			return 0, nil, err
		}
	}
	return msgType, payload, nil
}

// WriteFrame writes a single binary frame.
func WriteFrame(w io.Writer, msgType byte, payload []byte) error {
	// The on-wire length is a uint32 covering the type byte plus the payload.
	// A payload that does not fit would silently wrap the length field and
	// desynchronise the stream, so refuse it instead of writing a bad header.
	if uint64(len(payload)) > math.MaxUint32-1 {
		return fmt.Errorf("binary frame payload too large: %d bytes", len(payload))
	}
	length := uint32(1 + len(payload)) // type byte + payload

	var hdr [5]byte
	binary.BigEndian.PutUint32(hdr[:4], length)
	hdr[4] = msgType
	if _, err := w.Write(hdr[:]); err != nil {
		return err
	}
	// An empty payload produces a header-only frame; skip the second write.
	if len(payload) > 0 {
		if _, err := w.Write(payload); err != nil {
			return err
		}
	}
	return nil
}

// --- Heartbeat ---

// HeartbeatReq holds a decoded binary heartbeat request: [4B node_id][64B signature].
type HeartbeatReq struct {
	NodeID    uint32
	Signature [64]byte
}

// EncodeHeartbeatReq encodes a heartbeat request payload.
//
// The result is always 68 bytes. A signature shorter than 64 bytes is
// zero-padded on the right and a longer one is truncated, because copy
// stops at the shorter of the two slices.
func EncodeHeartbeatReq(nodeID uint32, sig []byte) []byte {
	buf := make([]byte, 4+64)
	binary.BigEndian.PutUint32(buf[:4], nodeID)
	copy(buf[4:], sig)
	return buf
}

// DecodeHeartbeatReq decodes a heartbeat request payload.
//
// It fails when fewer than 68 bytes are supplied; bytes beyond the first 68
// are ignored. The signature is copied, so the result does not alias payload.
func DecodeHeartbeatReq(payload []byte) (HeartbeatReq, error) {
	if len(payload) < 68 {
		return HeartbeatReq{}, fmt.Errorf("heartbeat request too short: %d bytes", len(payload))
	}
	var req HeartbeatReq
	req.NodeID = binary.BigEndian.Uint32(payload[:4])
	copy(req.Signature[:], payload[4:68])
	return req, nil
}

// EncodeHeartbeatResp encodes the heartbeat response: [8B unix_time][1B flags].
// flags bit0 = key_expiry_warning.
func EncodeHeartbeatResp(unixTime int64, keyExpiryWarning bool) []byte {
	buf := make([]byte, 9)
	binary.BigEndian.PutUint64(buf[:8], uint64(unixTime))
	if keyExpiryWarning {
		buf[8] = 1
	}
	return buf
}

// DecodeHeartbeatResp decodes a heartbeat response.
//
// It fails when fewer than 9 bytes are supplied. Only bit 0 of the flags
// byte is interpreted; the other bits are ignored.
func DecodeHeartbeatResp(payload []byte) (unixTime int64, keyExpiryWarning bool, err error) {
	if len(payload) < 9 {
		return 0, false, fmt.Errorf("heartbeat response too short: %d bytes", len(payload))
	}
	unixTime = int64(binary.BigEndian.Uint64(payload[:8]))
	keyExpiryWarning = payload[8]&1 != 0
	return unixTime, keyExpiryWarning, nil
}

// --- Lookup ---

// EncodeLookupReq encodes a lookup request: [4B node_id].
func EncodeLookupReq(nodeID uint32) []byte {
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, nodeID)
	return buf
}

// DecodeLookupReq decodes a lookup request.
//
// It fails when fewer than 4 bytes are supplied; trailing bytes are ignored.
func DecodeLookupReq(payload []byte) (uint32, error) {
	if len(payload) < 4 {
		return 0, fmt.Errorf("lookup request too short: %d bytes", len(payload))
	}
	return binary.BigEndian.Uint32(payload[:4]), nil
}

// EncodeLookupResp encodes a lookup response.
// Format: [4B node_id][1B flags][4B reserved][2B net_count][net_ids...]
//
//	[1B pubkey_len][pubkey...][1B hostname_len][hostname...]
//	[1B tags_count][for each: 1B len, bytes...][2B addr_len][addr...]
//	[1B extid_len][extid...]
//
// The 4-byte reserved field was formerly polo_score; it is written as zero
// and ignored on decode to preserve wire-format compatibility.
//
// flags bit0 is public and bit1 is taskExec. Every variable-length field is
// truncated to what its length prefix can express (255 for 1-byte prefixes,
// 65535 for 2-byte prefixes) so the prefix always matches the bytes that
// follow it. Truncation is byte-based and may cut a multi-byte UTF-8
// sequence. Note that DecodeLookupResp additionally rejects network lists
// longer than its own cap.
func EncodeLookupResp(nodeID uint32, public, taskExec bool,
	networks []uint16, pubKey []byte, hostname string, tags []string,
	realAddr string, externalID string) []byte {

	// Clamp first: previously the uint16 prefixes for networks and real_addr
	// wrapped around while the full data was still appended, producing a
	// payload no decoder could parse.
	if len(networks) > math.MaxUint16 {
		networks = networks[:math.MaxUint16]
	}
	if len(pubKey) > 255 {
		pubKey = pubKey[:255]
	}
	if len(hostname) > 255 {
		hostname = hostname[:255]
	}
	if len(tags) > 255 {
		tags = tags[:255]
	}
	if len(realAddr) > math.MaxUint16 {
		realAddr = realAddr[:math.MaxUint16]
	}
	if len(externalID) > 255 {
		externalID = externalID[:255]
	}

	// Capacity estimate (an upper bound: individual tags are clamped below).
	size := 4 + 1 + 4 + 2 + len(networks)*2 // node_id + flags + reserved + nets
	size += 1 + len(pubKey)                 // pubkey
	size += 1 + len(hostname)               // hostname
	size++                                  // tags count
	for _, t := range tags {
		size += 1 + len(t) // tag len + tag
	}
	size += 2 + len(realAddr)   // real_addr (always written, even when not public)
	size += 1 + len(externalID) // external_id

	buf := make([]byte, 0, size)

	// node_id
	buf = binary.BigEndian.AppendUint32(buf, nodeID)

	// flags
	var flags byte
	if public {
		flags |= 0x01
	}
	if taskExec {
		flags |= 0x02
	}
	buf = append(buf, flags)

	// reserved (was polo_score) — always zero
	buf = binary.BigEndian.AppendUint32(buf, 0)

	// networks
	buf = binary.BigEndian.AppendUint16(buf, uint16(len(networks)))
	for _, n := range networks {
		buf = binary.BigEndian.AppendUint16(buf, n)
	}

	// pubkey
	buf = append(buf, byte(len(pubKey)))
	buf = append(buf, pubKey...)

	// hostname
	buf = append(buf, byte(len(hostname)))
	buf = append(buf, hostname...)

	// tags
	buf = append(buf, byte(len(tags)))
	for _, t := range tags {
		if len(t) > 255 {
			t = t[:255]
		}
		buf = append(buf, byte(len(t)))
		buf = append(buf, t...)
	}

	// real_addr
	buf = binary.BigEndian.AppendUint16(buf, uint16(len(realAddr)))
	buf = append(buf, realAddr...)

	// external_id
	buf = append(buf, byte(len(externalID)))
	buf = append(buf, externalID...)

	return buf
}

// --- Resolve ---

// EncodeResolveReq encodes a resolve request: [4B node_id][4B requester_id][64B signature].
//
// The result is always 72 bytes; a short signature is zero-padded and a long
// one is truncated.
func EncodeResolveReq(nodeID, requesterID uint32, sig []byte) []byte {
	buf := make([]byte, 4+4+64)
	binary.BigEndian.PutUint32(buf[:4], nodeID)
	binary.BigEndian.PutUint32(buf[4:8], requesterID)
	copy(buf[8:], sig)
	return buf
}

// DecodeResolveReq decodes a resolve request.
//
// It fails when fewer than 72 bytes are supplied. The returned sig is a
// sub-slice of payload, not a copy, so it shares payload's backing array.
func DecodeResolveReq(payload []byte) (nodeID, requesterID uint32, sig []byte, err error) {
	if len(payload) < 72 {
		return 0, 0, nil, fmt.Errorf("resolve request too short: %d bytes", len(payload))
	}
	nodeID = binary.BigEndian.Uint32(payload[:4])
	requesterID = binary.BigEndian.Uint32(payload[4:8])
	sig = payload[8:72]
	return nodeID, requesterID, sig, nil
}

// EncodeResolveResp encodes a resolve response.
// Format: [4B node_id][2B addr_len][addr...][2B lan_count][for each: 2B len, addr...]
//
//	[4B key_age_days] (math.MaxUint32 if unknown)
//
// A negative keyAgeDays is encoded as the "unknown" sentinel. The address,
// the number of LAN addresses and each LAN address are truncated to 65535 so
// every 2-byte prefix matches the data that follows it. Note that
// DecodeResolveResp additionally rejects LAN address lists longer than its
// own cap.
func EncodeResolveResp(nodeID uint32, realAddr string, lanAddrs []string, keyAgeDays int) []byte {
	// Clamp first: previously the uint16 prefixes wrapped around while the
	// full data was still appended, yielding an unparseable payload.
	if len(realAddr) > math.MaxUint16 {
		realAddr = realAddr[:math.MaxUint16]
	}
	if len(lanAddrs) > math.MaxUint16 {
		lanAddrs = lanAddrs[:math.MaxUint16]
	}

	// Capacity estimate (an upper bound: individual entries are clamped below).
	size := 4 + 2 + len(realAddr) + 2 + 4
	for _, la := range lanAddrs {
		size += 2 + len(la)
	}
	buf := make([]byte, 0, size)

	buf = binary.BigEndian.AppendUint32(buf, nodeID)

	buf = binary.BigEndian.AppendUint16(buf, uint16(len(realAddr)))
	buf = append(buf, realAddr...)

	buf = binary.BigEndian.AppendUint16(buf, uint16(len(lanAddrs)))
	for _, la := range lanAddrs {
		if len(la) > math.MaxUint16 {
			la = la[:math.MaxUint16]
		}
		buf = binary.BigEndian.AppendUint16(buf, uint16(len(la)))
		buf = append(buf, la...)
	}

	// key_age_days: MaxUint32 is reserved for "unknown", so a known age is
	// capped one below it rather than being allowed to wrap or collide.
	age := uint32(math.MaxUint32)
	if keyAgeDays >= 0 {
		if int64(keyAgeDays) >= math.MaxUint32 {
			age = math.MaxUint32 - 1
		} else {
			age = uint32(keyAgeDays)
		}
	}
	buf = binary.BigEndian.AppendUint32(buf, age)

	return buf
}

// --- Error ---

// EncodeError encodes an error message frame payload.
//
// Layout: [2B msg_len][msg...]. Messages longer than 65000 bytes are
// truncated (byte-based, so a multi-byte UTF-8 sequence may be cut).
func EncodeError(msg string) []byte {
	if len(msg) > 65000 {
		msg = msg[:65000]
	}
	buf := make([]byte, 2+len(msg))
	binary.BigEndian.PutUint16(buf[:2], uint16(len(msg)))
	copy(buf[2:], msg)
	return buf
}

// DecodeError decodes an error message frame payload.
//
// It never fails: a payload shorter than the 2-byte prefix yields
// "unknown error", and a declared length larger than the remaining bytes is
// clamped to what is actually present.
func DecodeError(payload []byte) string {
	if len(payload) < 2 {
		return "unknown error"
	}
	length := binary.BigEndian.Uint16(payload[:2])
	if int(length) > len(payload)-2 {
		length = uint16(len(payload) - 2)
	}
	return string(payload[2 : 2+length])
}

// --- Lookup response decoder (client-side) ---

// LookupResult holds the decoded fields from a binary lookup response.
type LookupResult struct {
	NodeID     uint32
	Public     bool
	TaskExec   bool
	Networks   []uint16
	PubKey     []byte
	Hostname   string
	Tags       []string
	RealAddr   string
	ExternalID string
}

// DecodeLookupResp decodes a binary lookup response.
func DecodeLookupResp(payload []byte) (LookupResult, error) {
	// The payload is walked front to back with a single running offset.
	// Every read is preceded by a bounds check, so a truncated payload
	// produces an error rather than a slice panic. On error the partially
	// filled result is returned alongside it and should not be used.
	var r LookupResult
	// 11 = node_id (4) + flags (1) + reserved (4) + net_count (2).
	if len(payload) < 11 {
		return r, fmt.Errorf("lookup response too short: %d bytes", len(payload))
	}

	off := 0
	r.NodeID = binary.BigEndian.Uint32(payload[off : off+4])
	off += 4
	flags := payload[off]
	off++
	r.Public = flags&0x01 != 0
	r.TaskExec = flags&0x02 != 0
	off += 4 // skip reserved field (was polo_score)

	// networks: 2-byte count, then that many 2-byte IDs.
	if off+2 > len(payload) {
		return r, fmt.Errorf("truncated network count")
	}
	netCount := int(binary.BigEndian.Uint16(payload[off : off+2]))
	off += 2
	// Cap the count before allocating so a hostile peer cannot force a
	// large make().
	if netCount > maxCount {
		return r, fmt.Errorf("network count %d exceeds cap %d", netCount, maxCount)
	}
	r.Networks = make([]uint16, netCount)
	for i := 0; i < netCount; i++ {
		if off+2 > len(payload) {
			return r, fmt.Errorf("truncated networks at index %d", i)
		}
		r.Networks[i] = binary.BigEndian.Uint16(payload[off : off+2])
		off += 2
	}

	// pubkey: 1-byte length, then bytes. Copied so the result does not
	// alias payload; left nil when the length is zero.
	if off >= len(payload) {
		return r, fmt.Errorf("truncated pubkey length")
	}
	pkLen := int(payload[off])
	off++
	if off+pkLen > len(payload) {
		return r, fmt.Errorf("truncated pubkey data")
	}
	if pkLen > 0 {
		r.PubKey = make([]byte, pkLen)
		copy(r.PubKey, payload[off:off+pkLen])
	}
	off += pkLen

	// hostname: 1-byte length, then bytes.
	if off >= len(payload) {
		return r, fmt.Errorf("truncated hostname length")
	}
	hnLen := int(payload[off])
	off++
	if off+hnLen > len(payload) {
		return r, fmt.Errorf("truncated hostname data")
	}
	r.Hostname = string(payload[off : off+hnLen])
	off += hnLen

	// tags: 1-byte count, then for each a 1-byte length and bytes.
	if off >= len(payload) {
		return r, fmt.Errorf("truncated tags count")
	}
	tagCount := int(payload[off])
	off++
	// NOTE(review): tagCount comes from a single byte (max 255), so this
	// check cannot trigger while maxCount is 1024.
	if tagCount > maxCount {
		return r, fmt.Errorf("tag count %d exceeds cap %d", tagCount, maxCount)
	}
	r.Tags = make([]string, tagCount)
	for i := 0; i < tagCount; i++ {
		if off >= len(payload) {
			return r, fmt.Errorf("truncated tag length at index %d", i)
		}
		tLen := int(payload[off])
		off++
		if off+tLen > len(payload) {
			return r, fmt.Errorf("truncated tag data at index %d", i)
		}
		r.Tags[i] = string(payload[off : off+tLen])
		off += tLen
	}

	// real_addr: 2-byte length, then bytes.
	if off+2 > len(payload) {
		return r, fmt.Errorf("truncated real_addr length")
	}
	addrLen := int(binary.BigEndian.Uint16(payload[off : off+2]))
	off += 2
	if off+addrLen > len(payload) {
		return r, fmt.Errorf("truncated real_addr data")
	}
	r.RealAddr = string(payload[off : off+addrLen])
	off += addrLen

	// external_id: 1-byte length, then bytes. Any bytes after this field
	// are ignored.
	if off >= len(payload) {
		return r, fmt.Errorf("truncated external_id length")
	}
	eidLen := int(payload[off])
	off++
	if off+eidLen > len(payload) {
		return r, fmt.Errorf("truncated external_id data")
	}
	r.ExternalID = string(payload[off : off+eidLen])

	return r, nil
}

// --- Resolve response decoder (client-side) ---

// ResolveResult holds the decoded fields from a binary resolve response.
type ResolveResult struct {
	NodeID     uint32
	RealAddr   string
	LANAddrs   []string
	KeyAgeDays int // -1 if unknown
}

// DecodeResolveResp decodes a binary resolve response.
func DecodeResolveResp(payload []byte) (ResolveResult, error) {
	// The payload is walked front to back with a single running offset and
	// a bounds check before every read. On error the partially filled
	// result is returned alongside it and should not be used.
	var r ResolveResult
	// 12 = node_id (4) + addr_len (2) + lan_count (2) + key_age_days (4),
	// i.e. the size of a response with an empty address and no LAN addrs.
	if len(payload) < 12 {
		return r, fmt.Errorf("resolve response too short: %d bytes", len(payload))
	}

	off := 0
	r.NodeID = binary.BigEndian.Uint32(payload[off : off+4])
	off += 4

	// real address: 2-byte length, then bytes.
	if off+2 > len(payload) {
		return r, fmt.Errorf("truncated addr length")
	}
	addrLen := int(binary.BigEndian.Uint16(payload[off : off+2]))
	off += 2
	if off+addrLen > len(payload) {
		return r, fmt.Errorf("truncated addr data")
	}
	r.RealAddr = string(payload[off : off+addrLen])
	off += addrLen

	// LAN addresses: 2-byte count, then for each a 2-byte length and bytes.
	if off+2 > len(payload) {
		return r, fmt.Errorf("truncated lan_addrs count")
	}
	lanCount := int(binary.BigEndian.Uint16(payload[off : off+2]))
	off += 2
	// Cap the count before allocating so a hostile peer cannot force a
	// large make().
	if lanCount > maxCount {
		return r, fmt.Errorf("lan_addrs count %d exceeds cap %d", lanCount, maxCount)
	}
	r.LANAddrs = make([]string, lanCount)
	for i := 0; i < lanCount; i++ {
		if off+2 > len(payload) {
			return r, fmt.Errorf("truncated lan addr length at index %d", i)
		}
		laLen := int(binary.BigEndian.Uint16(payload[off : off+2]))
		off += 2
		if off+laLen > len(payload) {
			return r, fmt.Errorf("truncated lan addr data at index %d", i)
		}
		r.LANAddrs[i] = string(payload[off : off+laLen])
		off += laLen
	}

	// key_age_days: math.MaxUint32 on the wire means "unknown" and is
	// surfaced as -1. Any bytes after this field are ignored.
	if off+4 > len(payload) {
		return r, fmt.Errorf("truncated key_age_days")
	}
	raw := binary.BigEndian.Uint32(payload[off : off+4])
	if raw == math.MaxUint32 {
		r.KeyAgeDays = -1
	} else {
		r.KeyAgeDays = int(raw)
	}

	return r, nil
}

// --- JSON message framing ---
//
// The non-binary JSON protocol uses a 4-byte big-endian length prefix
// followed by a JSON body. ReadMessage/WriteMessage are the helpers both
// the client and the server use over the same TCP connection.

// ReadMessage reads a length-prefixed JSON message from r and decodes
// it into a map.
+func ReadMessage(r io.Reader) (map[string]interface{}, error) { + var lenBuf [4]byte + if _, err := io.ReadFull(r, lenBuf[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(lenBuf[:]) + if length > MaxMessageSize { + return nil, fmt.Errorf("message too large: %d bytes (max %d)", length, MaxMessageSize) + } + + body := make([]byte, length) + if _, err := io.ReadFull(r, body); err != nil { + return nil, err + } + + var msg map[string]interface{} + if err := json.Unmarshal(body, &msg); err != nil { + return nil, fmt.Errorf("json decode: %w", err) + } + return msg, nil +} + +// WriteMessage writes a length-prefixed JSON-encoded message to w. +// If w is a net.Conn, a write deadline is applied. +func WriteMessage(w io.Writer, msg map[string]interface{}) error { + body, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("json encode: %w", err) + } + return WriteRawMessage(w, body) +} + +// WriteRawMessage writes a length-prefixed raw JSON body to w. +// Callers that have already produced the JSON bytes (e.g., a list-nodes +// cache hit) can skip the json.Marshal step. 
+func WriteRawMessage(w io.Writer, body []byte) error {
+	// The prefix is a uint32. Converting a longer body would wrap silently
+	// and emit a frame whose declared length disagrees with its payload, so
+	// reject it up front. The comparison is done in uint64 so it also
+	// compiles where int is 32 bits wide.
+	if uint64(len(body)) > math.MaxUint32 {
+		return fmt.Errorf("message too large: %d bytes (max %d)", len(body), uint32(math.MaxUint32))
+	}
+
+	var lenBuf [4]byte
+	binary.BigEndian.PutUint32(lenBuf[:], uint32(len(body)))
+
+	// Only net.Conn writers get a deadline; other writers are written to
+	// as-is. Both deadline calls are best-effort: a failure to arm or clear
+	// the deadline must not mask the result of the write itself.
+	// NOTE(review): clearing the deadline on return also discards any write
+	// deadline the caller had set on the connection - confirm that is intended.
+	if c, ok := w.(net.Conn); ok {
+		_ = c.SetWriteDeadline(time.Now().Add(WriteMessageDeadline))
+		defer func() { _ = c.SetWriteDeadline(time.Time{}) }()
+	}
+
+	// Prefix and body go out as two separate writes; the first failure wins.
+	if _, err := w.Write(lenBuf[:]); err != nil {
+		return err
+	}
+	if _, err := w.Write(body); err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/registry/wire/zz_blueprint_test.go b/registry/wire/zz_blueprint_test.go
new file mode 100644
index 0000000..9f3aac6
--- /dev/null
+++ b/registry/wire/zz_blueprint_test.go
@@ -0,0 +1,254 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package wire_test
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/pilot-protocol/common/registry/wire"
+)
+
+func TestLoadBlueprint_HappyPath(t *testing.T) {
+	t.Parallel()
+	bp := &wire.NetworkBlueprint{
+		Name:     "test-net",
+		JoinRule: "open",
+	}
+	data, err := json.Marshal(bp)
+	if err != nil {
+		t.Fatalf("marshal: %v", err)
+	}
+	dir := t.TempDir()
+	path := filepath.Join(dir, "bp.json")
+	if err := os.WriteFile(path, data, 0600); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	got, err := wire.LoadBlueprint(path)
+	if err != nil {
+		t.Fatalf("LoadBlueprint: %v", err)
+	}
+	if got.Name != bp.Name {
+		t.Errorf("Name: got %q, want %q", got.Name, bp.Name)
+	}
+	if got.JoinRule != bp.JoinRule {
+		t.Errorf("JoinRule: got %q, want %q", got.JoinRule, bp.JoinRule)
+	}
+}
+
+func TestLoadBlueprint_FileNotFound(t *testing.T) {
+	t.Parallel()
+	_, err := wire.LoadBlueprint(filepath.Join(t.TempDir(), "does-not-exist.json"))
+	if err == nil || !strings.Contains(err.Error(), "read blueprint") {
+		t.Fatalf("want 'read blueprint' err, got %v", err)
+	}
+}
+
+func TestLoadBlueprint_BadJSON(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	path := filepath.Join(dir, "bad.json")
+	if err := os.WriteFile(path, []byte(`{not json`), 0600); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	_, err := wire.LoadBlueprint(path)
+	if err == nil || !strings.Contains(err.Error(), "parse blueprint") {
+		t.Fatalf("want 'parse blueprint' err, got %v", err)
+	}
+}
+
+func TestLoadBlueprint_MissingName(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	path := filepath.Join(dir, "noname.json")
+	if err := os.WriteFile(path, []byte(`{}`), 0600); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	_, err := wire.LoadBlueprint(path)
+	if err == nil || !strings.Contains(err.Error(), "name is required") {
+		t.Fatalf("want 'name is required' err, got %v", err)
+	}
+}
+
+func TestValidateBlueprint_HappyPath(t *testing.T) {
+	t.Parallel()
+	bp := &wire.NetworkBlueprint{Name: "net", JoinRule: "open"}
+	if err := wire.ValidateBlueprint(bp); err != nil {
+		t.Fatalf("ValidateBlueprint: %v", err)
+	}
+}
+
+func TestValidateBlueprint_NameRequired(t *testing.T) {
+	t.Parallel()
+	err := wire.ValidateBlueprint(&wire.NetworkBlueprint{})
+	if err == nil || !strings.Contains(err.Error(), "name is required") {
+		t.Fatalf("want 'name is required', got %v", err)
+	}
+}
+
+func TestValidateBlueprint_AllJoinRules(t *testing.T) {
+	t.Parallel()
+	for _, jr := range []string{"", "open", "token", "invite"} {
+		bp := &wire.NetworkBlueprint{Name: "n", JoinRule: jr}
+		if jr == "token" {
+			bp.JoinToken = "tok"
+		}
+		if err := wire.ValidateBlueprint(bp); err != nil {
+			t.Errorf("JoinRule=%q: %v", jr, err)
+		}
+	}
+}
+
+func TestValidateBlueprint_InvalidJoinRule(t *testing.T) {
+	t.Parallel()
+	bp := &wire.NetworkBlueprint{Name: "n", JoinRule: "weird"}
+	err := wire.ValidateBlueprint(bp)
+	if err == nil || !strings.Contains(err.Error(), "invalid join_rule") {
+		t.Fatalf("got %v", err)
+	}
+}
+
+func TestValidateBlueprint_TokenRuleNeedsToken(t *testing.T) {
+	t.Parallel()
+	bp := &wire.NetworkBlueprint{Name: "n", JoinRule: "token"}
+	err := wire.ValidateBlueprint(bp)
+	if err == nil || !strings.Contains(err.Error(), "join_token is required") {
+		t.Fatalf("got %v", err)
+	}
+}
+
+func TestValidateBlueprint_RoleRequiresExternalID(t *testing.T) { + t.Parallel() + bp := &wire.NetworkBlueprint{ + Name: "n", + Roles: []wire.BlueprintRole{{Role: "admin"}}, + } + err := wire.ValidateBlueprint(bp) + if err == nil || !strings.Contains(err.Error(), "external_id is required") { + t.Fatalf("got %v", err) + } +} + +func TestValidateBlueprint_RoleValidAndInvalid(t *testing.T) { + t.Parallel() + for _, r := range []string{"owner", "admin", "member"} { + bp := &wire.NetworkBlueprint{ + Name: "n", + Roles: []wire.BlueprintRole{{ExternalID: "u", Role: r}}, + } + if err := wire.ValidateBlueprint(bp); err != nil { + t.Errorf("role=%q: %v", r, err) + } + } + bp := &wire.NetworkBlueprint{ + Name: "n", + Roles: []wire.BlueprintRole{{ExternalID: "u", Role: "superadmin"}}, + } + if err := wire.ValidateBlueprint(bp); err == nil || + !strings.Contains(err.Error(), "invalid role") { + t.Fatalf("got %v", err) + } +} + +func TestValidateBlueprint_IdentityProvider(t *testing.T) { + t.Parallel() + // missing URL + err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + IdentityProvider: &wire.BlueprintIdentityProvider{Type: "oidc"}, + }) + if err == nil || !strings.Contains(err.Error(), "identity_provider.url is required") { + t.Fatalf("got %v", err) + } + // invalid type + err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + IdentityProvider: &wire.BlueprintIdentityProvider{Type: "weird", URL: "https://a.b"}, + }) + if err == nil || !strings.Contains(err.Error(), "invalid identity_provider type") { + t.Fatalf("got %v", err) + } + // happy path for each valid type + for _, typ := range []string{"oidc", "saml", "webhook", "entra_id", "ldap"} { + err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + IdentityProvider: &wire.BlueprintIdentityProvider{ + Type: typ, + URL: "https://example.com/auth", + }, + }) + if err != nil { + t.Errorf("type=%q: %v", typ, err) + } + } +} + +func TestValidateBlueprint_AuditExport(t *testing.T) { + 
t.Parallel() + // invalid format + err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + AuditExport: &wire.BlueprintAuditExport{ + Format: "weird", + Endpoint: "https://a.b", + }, + }) + if err == nil || !strings.Contains(err.Error(), "invalid audit_export format") { + t.Fatalf("got %v", err) + } + // missing endpoint + err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + AuditExport: &wire.BlueprintAuditExport{Format: "json"}, + }) + if err == nil || !strings.Contains(err.Error(), "audit_export.endpoint is required") { + t.Fatalf("got %v", err) + } + // happy path syslog_cef (no URL validation) + err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + AuditExport: &wire.BlueprintAuditExport{Format: "syslog_cef", Endpoint: "1.2.3.4:514"}, + }) + if err != nil { + t.Errorf("syslog_cef: %v", err) + } +} + +func TestValidateBlueprint_ExprPolicy(t *testing.T) { + t.Parallel() + // invalid JSON + err := wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + ExprPolicy: json.RawMessage(`{not json`), + }) + if err == nil || !strings.Contains(err.Error(), "expr_policy: invalid JSON") { + t.Fatalf("got %v", err) + } + // wrong version + err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + ExprPolicy: json.RawMessage(`{"version":2,"rules":[1]}`), + }) + if err == nil || !strings.Contains(err.Error(), "unsupported version") { + t.Fatalf("got %v", err) + } + // no rules + err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + ExprPolicy: json.RawMessage(`{"version":1,"rules":null}`), + }) + if err == nil || !strings.Contains(err.Error(), "at least one rule") { + t.Fatalf("got %v", err) + } + // happy path + err = wire.ValidateBlueprint(&wire.NetworkBlueprint{ + Name: "n", + ExprPolicy: json.RawMessage(`{"version":1,"rules":[{"on":"connect","match":"true"}]}`), + }) + if err != nil { + t.Errorf("happy: %v", err) + } +} diff --git a/registry/wire/zz_decode_edge_test.go 
b/registry/wire/zz_decode_edge_test.go new file mode 100644 index 0000000..1f9e95d --- /dev/null +++ b/registry/wire/zz_decode_edge_test.go @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "strings" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +func TestDecodeLookupReq_Truncated(t *testing.T) { + t.Parallel() + for _, n := range []int{0, 1, 2, 3} { + _, err := wire.DecodeLookupReq(make([]byte, n)) + if err == nil || !strings.Contains(err.Error(), "too short") { + t.Errorf("len=%d: want 'too short' err, got %v", n, err) + } + } +} + +func TestDecodeLookupReq_HappyPath(t *testing.T) { + t.Parallel() + got, err := wire.DecodeLookupReq([]byte{0x00, 0x00, 0xCA, 0xFE}) + if err != nil { + t.Fatalf("DecodeLookupReq: %v", err) + } + if got != 0xCAFE { + t.Errorf("got %x, want CAFE", got) + } +} + +func TestEncodeLookupResp_RoundTripWithAllFields(t *testing.T) { + t.Parallel() + // Build a fully-populated lookup response, then decode it. 
+ encoded := wire.EncodeLookupResp( + 0xABCD, // nodeID + true, // public + true, // taskExec + []uint16{1, 2, 3}, // networks + []byte("0123456789012345"), // pubkey (16 bytes) + "host.example", // hostname + []string{"tag1", "tag2"}, // tags + "1.2.3.4:4000", // realAddr (only if public) + "ext-id-xyz", // externalID + ) + if len(encoded) == 0 { + t.Fatal("EncodeLookupResp returned empty") + } + resp, err := wire.DecodeLookupResp(encoded) + if err != nil { + t.Fatalf("DecodeLookupResp: %v", err) + } + if resp.NodeID != 0xABCD { + t.Errorf("NodeID = %x, want ABCD", resp.NodeID) + } + if !resp.Public { + t.Errorf("Public = false, want true") + } + if resp.Hostname != "host.example" { + t.Errorf("Hostname = %q, want host.example", resp.Hostname) + } + if resp.ExternalID != "ext-id-xyz" { + t.Errorf("ExternalID = %q", resp.ExternalID) + } + if len(resp.Networks) != 3 { + t.Errorf("Networks len = %d, want 3", len(resp.Networks)) + } + if len(resp.Tags) != 2 { + t.Errorf("Tags len = %d, want 2", len(resp.Tags)) + } +} + +func TestEncodeLookupResp_PrivateNodeNoAddr(t *testing.T) { + t.Parallel() + // Private node: realAddr is encoded but should not be revealed by + // post-decode contract. + encoded := wire.EncodeLookupResp( + 1, false, false, []uint16{}, []byte{}, "host", []string{}, "", "", + ) + resp, err := wire.DecodeLookupResp(encoded) + if err != nil { + t.Fatalf("DecodeLookupResp: %v", err) + } + if resp.Public { + t.Errorf("Public = true, want false") + } +} + +func TestDecodeError_Truncated(t *testing.T) { + t.Parallel() + // DecodeError returns a string. Truncated → fallback string. + for _, n := range []int{0, 1} { + got := wire.DecodeError(make([]byte, n)) + if got != "unknown error" { + t.Errorf("len=%d: got %q, want 'unknown error'", n, got) + } + } +} + +func TestDecodeError_HappyPath(t *testing.T) { + t.Parallel() + // 2-byte length prefix + body + msg := "internal error" + buf := []byte{byte(len(msg) >> 8), byte(len(msg))} + buf = append(buf, msg...) 
+ if got := wire.DecodeError(buf); got != msg { + t.Errorf("got %q, want %q", got, msg) + } +} diff --git a/registry/wire/zz_decode_truncation_test.go b/registry/wire/zz_decode_truncation_test.go new file mode 100644 index 0000000..6d90ca9 --- /dev/null +++ b/registry/wire/zz_decode_truncation_test.go @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "strings" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +// TestDecodeLookupResp_TruncationCascade walks every truncation +// boundary in DecodeLookupResp by encoding a fully-populated response +// then progressively trimming the payload from the back. +func TestDecodeLookupResp_TruncationCascade(t *testing.T) { + t.Parallel() + full := wire.EncodeLookupResp( + 0x1234, + true, + false, + []uint16{1, 2}, + []byte("pubkey-32-bytes-AAAAAAAAAAAAAAAA"), + "hostname", + []string{"tagA", "tagB"}, + "10.0.0.1:4000", + "ext-id", + ) + // Truncate to every shorter length and ensure decode either succeeds + // (only happens at exact length boundaries) or returns a truncation + // error. This exercises every "if off >= len(payload)" branch. + for i := 0; i < len(full); i++ { + _, err := wire.DecodeLookupResp(full[:i]) + if err == nil { + continue // accidental valid prefix is fine + } + // Every error should contain "truncated" or "too short". + if !strings.Contains(err.Error(), "truncated") && + !strings.Contains(err.Error(), "too short") { + t.Errorf("len=%d: unexpected err %v", i, err) + } + } +} + +// TestDecodeResolveResp_TruncationCascade does the same for ResolveResp. 
+func TestDecodeResolveResp_TruncationCascade(t *testing.T) { + t.Parallel() + full := wire.EncodeResolveResp(0x1234, "10.0.0.5:4000", []string{"192.168.1.10:4000"}, 7) + for i := 0; i < len(full); i++ { + _, err := wire.DecodeResolveResp(full[:i]) + if err == nil { + continue + } + if !strings.Contains(err.Error(), "truncated") && + !strings.Contains(err.Error(), "too short") && + !strings.Contains(err.Error(), "decode") { + t.Errorf("len=%d: unexpected err %v", i, err) + } + } +} + +// TestDecodeResolveReq_Truncation drills the short-buffer branches. +func TestDecodeResolveReq_Truncation(t *testing.T) { + t.Parallel() + for _, n := range []int{0, 1, 4, 8, 16, 32, 64, 71} { + _, _, _, err := wire.DecodeResolveReq(make([]byte, n)) + if err == nil { + t.Errorf("len=%d: want error, got nil", n) + } + } +} + +// TestDecodeHeartbeatResp_Truncation exercises the small response decoder. +func TestDecodeHeartbeatResp_Truncation(t *testing.T) { + t.Parallel() + for _, n := range []int{0, 1, 2, 3, 4, 5, 8} { + _, _, err := wire.DecodeHeartbeatResp(make([]byte, n)) + if err == nil { + t.Errorf("len=%d: want error, got nil", n) + } + } +} + +// TestDecodeError_LengthExceedsBuffer covers the clamping branch where the +// length prefix lies about how much data follows. +func TestDecodeError_LengthExceedsBuffer(t *testing.T) { + t.Parallel() + // Length prefix says 100 bytes, but buffer only has 5 bytes of body. + buf := []byte{0x00, 0x64, 'h', 'e', 'l', 'l', 'o'} // 0x0064 = 100 + got := wire.DecodeError(buf) + if got != "hello" { + t.Errorf("got %q, want 'hello' (clamped)", got) + } +} + +// TestEncodeError_OverlongMessageTruncated covers EncodeError's 65000-byte cap. +func TestEncodeError_OverlongMessageTruncated(t *testing.T) { + t.Parallel() + long := strings.Repeat("x", 70000) + encoded := wire.EncodeError(long) + // Encoded payload = 2-byte length + body. Body should be 65000 chars. 
+ if len(encoded) != 2+65000 { + t.Errorf("encoded length = %d, want %d", len(encoded), 2+65000) + } + // Decode and ensure round-trip is the truncated form. + if got := wire.DecodeError(encoded); len(got) != 65000 { + t.Errorf("decoded length = %d, want 65000", len(got)) + } +} diff --git a/registry/wire/zz_frame_test.go b/registry/wire/zz_frame_test.go new file mode 100644 index 0000000..4ed4f78 --- /dev/null +++ b/registry/wire/zz_frame_test.go @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "strings" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +func TestReadFrameWriteFrame_RoundTrip(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + payload := []byte("hello-frame-body") + if err := wire.WriteFrame(&buf, 0x42, payload); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + msgType, got, err := wire.ReadFrame(&buf) + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if msgType != 0x42 { + t.Errorf("msgType = %x, want 0x42", msgType) + } + if !bytes.Equal(got, payload) { + t.Errorf("payload = %q, want %q", got, payload) + } +} + +func TestReadFrame_EmptyPayload(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + if err := wire.WriteFrame(&buf, 0x01, nil); err != nil { + t.Fatalf("WriteFrame: %v", err) + } + msgType, payload, err := wire.ReadFrame(&buf) + if err != nil { + t.Fatalf("ReadFrame: %v", err) + } + if msgType != 0x01 { + t.Errorf("msgType = %x, want 0x01", msgType) + } + if len(payload) != 0 { + t.Errorf("payload len = %d, want 0", len(payload)) + } +} + +func TestReadFrame_HeaderTruncated(t *testing.T) { + t.Parallel() + _, _, err := wire.ReadFrame(bytes.NewReader([]byte{0x00, 0x01})) // 2 bytes, need 5 + if !errors.Is(err, io.ErrUnexpectedEOF) && err != io.EOF { + t.Errorf("want EOF/ErrUnexpectedEOF, got %v", err) + } +} + +func TestReadFrame_LengthZero(t *testing.T) { + t.Parallel() + var hdr [5]byte + 
binary.BigEndian.PutUint32(hdr[:4], 0) // length = 0 → too short + hdr[4] = 0x01 + _, _, err := wire.ReadFrame(bytes.NewReader(hdr[:])) + if err == nil || !strings.Contains(err.Error(), "too short") { + t.Errorf("want 'too short', got %v", err) + } +} + +func TestReadFrame_LengthExceedsMax(t *testing.T) { + t.Parallel() + var hdr [5]byte + binary.BigEndian.PutUint32(hdr[:4], wire.MaxMessageSize+1) + _, _, err := wire.ReadFrame(bytes.NewReader(hdr[:])) + if err == nil || !strings.Contains(err.Error(), "too large") { + t.Errorf("want 'too large', got %v", err) + } +} + +func TestReadFrame_PayloadTruncated(t *testing.T) { + t.Parallel() + var hdr [5]byte + binary.BigEndian.PutUint32(hdr[:4], 100) // claims 99 bytes of payload + hdr[4] = 0x01 + _, _, err := wire.ReadFrame(bytes.NewReader(append(hdr[:], []byte("short")...))) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Errorf("want ErrUnexpectedEOF, got %v", err) + } +} + +func TestWriteFrame_HeaderWriteError(t *testing.T) { + t.Parallel() + bang := errors.New("hdr-fail") + err := wire.WriteFrame(&failingWriter{err: bang}, 0x01, []byte("xx")) + if !errors.Is(err, bang) { + t.Errorf("want hdr-fail, got %v", err) + } +} + +func TestWriteFrame_PayloadWriteError(t *testing.T) { + t.Parallel() + bang := errors.New("payload-fail") + // allow 5-byte header, fail body + err := wire.WriteFrame(&shortWriter{allow: 5, err: bang}, 0x01, []byte("hello")) + if !errors.Is(err, bang) { + t.Errorf("want payload-fail, got %v", err) + } +} diff --git a/registry/wire/zz_fuzz_wire_test.go b/registry/wire/zz_fuzz_wire_test.go new file mode 100644 index 0000000..d45bb94 --- /dev/null +++ b/registry/wire/zz_fuzz_wire_test.go @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +// FuzzReadFrame exercises the binary frame reader. +// Wire format: [4B length][1B type][payload]. 
The length field is +// length-prefixed; a malicious or buggy peer could send a 4-byte header +// that claims gigabytes — the MaxMessageSize cap should keep it bounded, +// but fuzzing confirms no panic / OOM regression slips in. +func FuzzReadFrame(f *testing.F) { + // Seed: valid empty-payload JSON frame. + { + var buf bytes.Buffer + wire.WriteFrame(&buf, wire.MsgJSON, []byte("{}")) + f.Add(buf.Bytes()) + } + // Seed: valid heartbeat req. + { + var buf bytes.Buffer + wire.WriteFrame(&buf, wire.MsgHeartbeat, wire.EncodeHeartbeatReq(42, make([]byte, 64))) + f.Add(buf.Bytes()) + } + // Seed: lookup req. + { + var buf bytes.Buffer + wire.WriteFrame(&buf, wire.MsgLookup, wire.EncodeLookupReq(0xDEADBEEF)) + f.Add(buf.Bytes()) + } + // Adversarial: huge length field, no body. + { + var hdr [5]byte + binary.BigEndian.PutUint32(hdr[:4], 0xFFFFFFFF) + hdr[4] = wire.MsgJSON + f.Add(hdr[:]) + } + // Adversarial: length=0 (below the "must include type byte" minimum). + { + var hdr [5]byte + binary.BigEndian.PutUint32(hdr[:4], 0) + hdr[4] = wire.MsgJSON + f.Add(hdr[:]) + } + f.Add([]byte{}) + f.Add(make([]byte, 4)) + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + r := bytes.NewReader(data) + _, _, _ = wire.ReadFrame(r) + }) +} + +// FuzzDecodeLookupResp targets the wire-controlled allocation path +// flagged in PILOT-131. The decoder pulls counts (network count, tag +// count, length-prefixed fields) directly from the input — a 16-bit +// network count or 8-bit tag count drives `make([]uint16, n)` / +// `make([]string, n)`. Truncated inputs must surface as errors, not +// panics, and not unbounded allocations. 
+func FuzzDecodeLookupResp(f *testing.F) { + f.Add(wire.EncodeLookupResp(1, false, false, nil, nil, "", nil, "", "")) + f.Add(wire.EncodeLookupResp(0xDEADBEEF, true, true, + []uint16{1, 2, 3}, []byte("pubkey"), "host", []string{"a", "b"}, + "1.2.3.4:5", "extid")) + f.Add(wire.EncodeLookupResp(7, true, false, + []uint16{42}, bytes.Repeat([]byte{0x55}, 255), "h", []string{"tag"}, + "", "")) + + // Adversarial: header claims many networks but no body follows. + { + buf := make([]byte, 11) + binary.BigEndian.PutUint32(buf[:4], 1) + buf[4] = 0 + // reserved (4) zero + binary.BigEndian.PutUint16(buf[9:11], 0xFFFF) // claim 65535 networks + f.Add(buf) + } + // Adversarial: pubkey_len > remaining bytes. + { + buf := make([]byte, 12) + binary.BigEndian.PutUint32(buf[:4], 1) + // reserved + netcount = 0 + buf[11] = 0xFF // pubkey_len = 255 + f.Add(buf) + } + // Minimum-size buffer. + f.Add(make([]byte, 11)) + f.Add([]byte{}) + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + _, _ = wire.DecodeLookupResp(data) + }) +} + +// FuzzDecodeResolveResp covers the resolve response decoder which has +// the same wire-controlled allocation shape (count + length-prefixed +// LAN addrs). +func FuzzDecodeResolveResp(f *testing.F) { + f.Add(wire.EncodeResolveResp(1, "1.2.3.4:5", nil, 0)) + f.Add(wire.EncodeResolveResp(2, "10.0.0.1:9000", + []string{"192.168.1.1", "10.0.0.5"}, 30)) + f.Add(wire.EncodeResolveResp(3, "", nil, -1)) + f.Add(make([]byte, 12)) + f.Add([]byte{}) + + // Adversarial: LAN count overflow. 
+ { + buf := make([]byte, 8) + binary.BigEndian.PutUint32(buf[:4], 1) + binary.BigEndian.PutUint16(buf[4:6], 0) // addr_len = 0 + binary.BigEndian.PutUint16(buf[6:8], 0xFFFF) // 65535 LAN addrs + f.Add(buf) + } + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + _, _ = wire.DecodeResolveResp(data) + }) +} + +// FuzzDecodeHeartbeatReq / Resp / LookupReq / Error are simple +// fixed-shape decoders — fuzz them anyway since they're entry points. +func FuzzDecodeHeartbeatReq(f *testing.F) { + f.Add(wire.EncodeHeartbeatReq(1, make([]byte, 64))) + f.Add(wire.EncodeHeartbeatReq(0xFFFFFFFF, bytes.Repeat([]byte{0xAA}, 64))) + f.Add([]byte{}) + f.Add(make([]byte, 67)) // one byte short + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + _, _ = wire.DecodeHeartbeatReq(data) + }) +} + +func FuzzDecodeError(f *testing.F) { + f.Add(wire.EncodeError("oh no")) + f.Add(wire.EncodeError("")) + f.Add([]byte{0x00}) + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + _ = wire.DecodeError(data) + }) +} + +// FuzzReadMessage exercises the JSON length-prefixed message reader. +// The 4-byte length is wire-controlled; the MaxMessageSize check is the +// only guard against `make([]byte, hugeLength)`. Verify no panic and no +// OOM-by-allocation regression. +func FuzzReadMessage(f *testing.F) { + // Seed: valid 2-byte JSON `{}`. + { + var buf bytes.Buffer + _ = wire.WriteMessage(&buf, map[string]interface{}{}) + f.Add(buf.Bytes()) + } + { + var buf bytes.Buffer + _ = wire.WriteMessage(&buf, map[string]interface{}{ + "op": "lookup", "node_id": float64(42), + }) + f.Add(buf.Bytes()) + } + // Adversarial: header claims big payload. 
+ { + hdr := make([]byte, 4) + binary.BigEndian.PutUint32(hdr, 0xFFFFFFFF) + f.Add(hdr) + } + // Length declares 4GB but no body follows. + { + hdr := make([]byte, 4) + binary.BigEndian.PutUint32(hdr, 0x7FFFFFFF) + f.Add(hdr) + } + f.Add([]byte{}) + f.Add(make([]byte, 3)) + + f.Fuzz(func(t *testing.T, data []byte) { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic on input %x: %v", data, r) + } + }() + r := bytes.NewReader(data) + _, _ = wire.ReadMessage(r) + }) +} diff --git a/registry/wire/zz_message_framing_test.go b/registry/wire/zz_message_framing_test.go new file mode 100644 index 0000000..2a3c6df --- /dev/null +++ b/registry/wire/zz_message_framing_test.go @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "strings" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +// failingWriter returns the supplied error on every Write call. Used to +// exercise the early-return branch of WriteRawMessage. +type failingWriter struct{ err error } + +func (f *failingWriter) Write(p []byte) (int, error) { return 0, f.err } + +// shortWriter accepts the first N bytes then errors. Used to fail the +// SECOND Write inside WriteRawMessage. 
+type shortWriter struct { + allow int + err error +} + +func (s *shortWriter) Write(p []byte) (int, error) { + if s.allow >= len(p) { + s.allow -= len(p) + return len(p), nil + } + return 0, s.err +} + +func TestWriteReadMessageRoundTrip(t *testing.T) { + t.Parallel() + msg := map[string]interface{}{ + "op": "register", + "email": "a@b.co", + "port": float64(4000), // json.Unmarshal turns numbers into float64 + } + var buf bytes.Buffer + if err := wire.WriteMessage(&buf, msg); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + got, err := wire.ReadMessage(&buf) + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + for k, v := range msg { + if got[k] != v { + t.Errorf("key %q: got %v (%T), want %v (%T)", k, got[k], got[k], v, v) + } + } +} + +func TestWriteMessageJSONEncodeError(t *testing.T) { + t.Parallel() + // channels can't be JSON-encoded → json.Marshal fails. + bad := map[string]interface{}{"ch": make(chan int)} + err := wire.WriteMessage(&bytes.Buffer{}, bad) + if err == nil || !strings.Contains(err.Error(), "json encode") { + t.Fatalf("want 'json encode' err, got %v", err) + } +} + +func TestReadMessageTooLarge(t *testing.T) { + t.Parallel() + // Synthesise a length prefix > MaxMessageSize without writing the body. + var lenBuf [4]byte + binary.BigEndian.PutUint32(lenBuf[:], wire.MaxMessageSize+1) + r := bytes.NewReader(lenBuf[:]) + _, err := wire.ReadMessage(r) + if err == nil || !strings.Contains(err.Error(), "too large") { + t.Fatalf("want 'too large' err, got %v", err) + } +} + +func TestReadMessageEOFOnPrefix(t *testing.T) { + t.Parallel() + // Empty reader → io.EOF on the length prefix read. + _, err := wire.ReadMessage(bytes.NewReader(nil)) + if !errors.Is(err, io.EOF) { + t.Fatalf("want io.EOF, got %v", err) + } +} + +func TestReadMessageTruncatedBody(t *testing.T) { + t.Parallel() + // 100-byte prefix but only 5 bytes of body → ErrUnexpectedEOF. 
+ var lenBuf [4]byte + binary.BigEndian.PutUint32(lenBuf[:], 100) + r := bytes.NewReader(append(lenBuf[:], []byte("short")...)) + _, err := wire.ReadMessage(r) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("want ErrUnexpectedEOF, got %v", err) + } +} + +func TestReadMessageBadJSON(t *testing.T) { + t.Parallel() + body := []byte("{not json") + var lenBuf [4]byte + binary.BigEndian.PutUint32(lenBuf[:], uint32(len(body))) + r := bytes.NewReader(append(lenBuf[:], body...)) + _, err := wire.ReadMessage(r) + if err == nil || !strings.Contains(err.Error(), "json decode") { + t.Fatalf("want 'json decode' err, got %v", err) + } +} + +func TestWriteRawMessageHappyPath(t *testing.T) { + t.Parallel() + body := []byte(`{"ok":true}`) + var buf bytes.Buffer + if err := wire.WriteRawMessage(&buf, body); err != nil { + t.Fatalf("WriteRawMessage: %v", err) + } + // Verify prefix + if buf.Len() != 4+len(body) { + t.Fatalf("buf len %d, want %d", buf.Len(), 4+len(body)) + } + gotLen := binary.BigEndian.Uint32(buf.Bytes()[:4]) + if int(gotLen) != len(body) { + t.Errorf("length prefix %d, want %d", gotLen, len(body)) + } + if !bytes.Equal(buf.Bytes()[4:], body) { + t.Errorf("body mismatch: got %q", buf.Bytes()[4:]) + } +} + +func TestWriteRawMessageErrorOnPrefix(t *testing.T) { + t.Parallel() + bang := errors.New("boom") + err := wire.WriteRawMessage(&failingWriter{err: bang}, []byte(`{}`)) + if !errors.Is(err, bang) { + t.Fatalf("want boom, got %v", err) + } +} + +func TestWriteRawMessageErrorOnBody(t *testing.T) { + t.Parallel() + bang := errors.New("boom") + // allow the 4-byte prefix, fail on the body write + err := wire.WriteRawMessage(&shortWriter{allow: 4, err: bang}, []byte(`{"x":1}`)) + if !errors.Is(err, bang) { + t.Fatalf("want boom on body write, got %v", err) + } +} diff --git a/registry/wire/zz_rules_test.go b/registry/wire/zz_rules_test.go new file mode 100644 index 0000000..512ecd5 --- /dev/null +++ b/registry/wire/zz_rules_test.go @@ -0,0 +1,334 @@ +// 
SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +// --- ValidateRules error branches ---------------------------------------- + +func TestValidateRulesNilReturnsNil(t *testing.T) { + t.Parallel() + if err := wire.ValidateRules(nil); err != nil { + t.Fatalf("nil rules: %v", err) + } +} + +func TestValidateRulesLinksRequired(t *testing.T) { + t.Parallel() + cases := []int{0, -5} + for _, l := range cases { + r := &wire.NetworkRules{Links: l, Cycle: "1h", PruneBy: "score", FillHow: "random"} + err := wire.ValidateRules(r) + if err == nil || !strings.Contains(err.Error(), "links must be >= 1") { + t.Fatalf("links=%d: %v", l, err) + } + } +} + +func TestValidateRulesCycleRequired(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "", PruneBy: "score", FillHow: "random"} + err := wire.ValidateRules(r) + if err == nil || !strings.Contains(err.Error(), "cycle is required") { + t.Fatalf("expected cycle-required error, got %v", err) + } +} + +func TestValidateRulesCycleInvalidDuration(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "not-a-duration", PruneBy: "score", FillHow: "random"} + err := wire.ValidateRules(r) + if err == nil || !strings.Contains(err.Error(), "invalid cycle duration") { + t.Fatalf("%v", err) + } +} + +func TestValidateRulesCycleTooShort(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "30s", PruneBy: "score", FillHow: "random"} + err := wire.ValidateRules(r) + if err == nil || !strings.Contains(err.Error(), "cycle must be >= 1m") { + t.Fatalf("%v", err) + } +} + +func TestValidateRulesPruneFillNegativeOrOverflow(t *testing.T) { + t.Parallel() + base := wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "random"} + // Prune < 0 + r := base + r.Prune = -1 + if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), 
"prune must be >= 0") { + t.Fatalf("prune<0: %v", err) + } + // Fill < 0 + r = base + r.Fill = -1 + if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "fill must be >= 0") { + t.Fatalf("fill<0: %v", err) + } + // Prune > Links + r = base + r.Prune = 10 + if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "cannot exceed links") { + t.Fatalf("prune>links: %v", err) + } + // Fill > Links + r = base + r.Fill = 10 + if err := wire.ValidateRules(&r); err == nil || !strings.Contains(err.Error(), "fill (10) cannot exceed links") { + t.Fatalf("fill>links: %v", err) + } +} + +func TestValidateRulesPruneByAllValidValues(t *testing.T) { + t.Parallel() + for _, pb := range []string{"score", "age", "activity"} { + r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: pb, FillHow: "random"} + if err := wire.ValidateRules(r); err != nil { + t.Fatalf("prune_by=%q: %v", pb, err) + } + } +} + +func TestValidateRulesPruneByRequiredAndUnknown(t *testing.T) { + t.Parallel() + // Empty + r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "", FillHow: "random"} + if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "prune_by is required") { + t.Fatalf("empty prune_by: %v", err) + } + // Unknown + r = &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "lottery", FillHow: "random"} + if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "unknown prune_by strategy") { + t.Fatalf("unknown prune_by: %v", err) + } +} + +func TestValidateRulesFillHowRequiredAndUnknown(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: ""} + if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "fill_how is required") { + t.Fatalf("empty fill_how: %v", err) + } + r = &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "roundrobin"} + if err := wire.ValidateRules(r); err == nil || 
!strings.Contains(err.Error(), "unknown fill_how strategy") { + t.Fatalf("unknown fill_how: %v", err) + } +} + +func TestValidateRulesGraceInvalid(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "random", Grace: "not-a-duration"} + if err := wire.ValidateRules(r); err == nil || !strings.Contains(err.Error(), "invalid grace duration") { + t.Fatalf("bad grace: %v", err) + } + // Note: time.ParseDuration rejects literal negatives like "-1m" for some inputs. + // We rely on the `g < 0` branch being effectively unreachable via parsing in practice, + // but verify parseable non-negative grace succeeds. +} + +func TestValidateRulesGraceEmptyOrValid(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "1h", PruneBy: "score", FillHow: "random", Grace: ""} + if err := wire.ValidateRules(r); err != nil { + t.Fatalf("empty grace: %v", err) + } + r.Grace = "10m" + if err := wire.ValidateRules(r); err != nil { + t.Fatalf("valid grace: %v", err) + } +} + +func TestValidateRulesHappyPath(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 10, Cycle: "1h", Prune: 2, PruneBy: "score", Fill: 2, FillHow: "random", Grace: "5m"} + if err := wire.ValidateRules(r); err != nil { + t.Fatalf("happy: %v", err) + } +} + +// --- ParseRules ----------------------------------------------------------- + +func TestParseRulesBadJSON(t *testing.T) { + t.Parallel() + _, err := wire.ParseRules(`{not json`) + if err == nil || !strings.Contains(err.Error(), "invalid JSON") { + t.Fatalf("%v", err) + } +} + +func TestParseRulesInvalidRules(t *testing.T) { + t.Parallel() + _, err := wire.ParseRules(`{"links":0,"cycle":"1h","prune_by":"score","fill_how":"random"}`) + if err == nil || !strings.Contains(err.Error(), "links must be >= 1") { + t.Fatalf("%v", err) + } +} + +func TestParseRulesHappyPath(t *testing.T) { + t.Parallel() + r, err := 
wire.ParseRules(`{"links":5,"cycle":"1h","prune":1,"prune_by":"age","fill":1,"fill_how":"random"}`) + if err != nil { + t.Fatalf("%v", err) + } + if r.Links != 5 || r.Cycle != "1h" || r.Prune != 1 || r.PruneBy != "age" || r.Fill != 1 || r.FillHow != "random" { + t.Fatalf("parsed: %+v", r) + } +} + +// --- RulesToPolicy -------------------------------------------------------- + +func TestRulesToPolicyNilReturnsNilNil(t *testing.T) { + t.Parallel() + raw, err := wire.RulesToPolicy(nil) + if err != nil { + t.Fatalf("%v", err) + } + if raw != nil { + t.Fatalf("expected nil json.RawMessage for nil rules, got %s", string(raw)) + } +} + +func TestRulesToPolicyShapeAndContentWithoutGrace(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 7, Cycle: "2h", Prune: 3, PruneBy: "age", Fill: 2, FillHow: "random"} + raw, err := wire.RulesToPolicy(r) + if err != nil { + t.Fatalf("%v", err) + } + var doc map[string]interface{} + if err := json.Unmarshal(raw, &doc); err != nil { + t.Fatalf("%v", err) + } + if doc["version"].(float64) != 1 { + t.Fatalf("version: %v", doc["version"]) + } + cfg := doc["config"].(map[string]interface{}) + if cfg["max_peers"].(float64) != 7 { + t.Fatalf("max_peers: %v", cfg["max_peers"]) + } + if cfg["cycle"].(string) != "2h" { + t.Fatalf("cycle: %v", cfg["cycle"]) + } + if _, hasGrace := cfg["grace"]; hasGrace { + t.Fatalf("grace should be absent when Grace=\"\"") + } + rules := doc["rules"].([]interface{}) + if len(rules) != 1 { + t.Fatalf("rules count: %d", len(rules)) + } + // rule[0] = cycle-prune-fill; prune action first, fill action second + r1 := rules[0].(map[string]interface{}) + if r1["name"].(string) != "cycle-prune-fill" || r1["on"].(string) != "cycle" { + t.Fatalf("rule 0: %+v", r1) + } + actions := r1["actions"].([]interface{}) + pruneA := actions[0].(map[string]interface{}) + if pruneA["type"].(string) != "prune" { + t.Fatalf("first action: %+v", pruneA) + } + params := pruneA["params"].(map[string]interface{}) + if 
params["count"].(float64) != 3 || params["by"].(string) != "age" { + t.Fatalf("prune params: %+v", params) + } + fillA := actions[1].(map[string]interface{}) + if fillA["type"].(string) != "fill" { + t.Fatalf("second action: %+v", fillA) + } + fillP := fillA["params"].(map[string]interface{}) + if fillP["count"].(float64) != 2 || fillP["how"].(string) != "random" { + t.Fatalf("fill params: %+v", fillP) + } +} + +func TestRulesToPolicyIncludesGraceWhenSet(t *testing.T) { + t.Parallel() + r := &wire.NetworkRules{Links: 5, Cycle: "1h", Prune: 1, PruneBy: "score", Fill: 1, FillHow: "random", Grace: "15m"} + raw, err := wire.RulesToPolicy(r) + if err != nil { + t.Fatalf("%v", err) + } + var doc map[string]interface{} + _ = json.Unmarshal(raw, &doc) + cfg := doc["config"].(map[string]interface{}) + if cfg["grace"].(string) != "15m" { + t.Fatalf("grace: %v", cfg["grace"]) + } +} + +// --- AllowedPortsToPolicy ------------------------------------------------- + +func TestAllowedPortsToPolicyEmptyReturnsNilNil(t *testing.T) { + t.Parallel() + raw, err := wire.AllowedPortsToPolicy(nil) + if err != nil || raw != nil { + t.Fatalf("nil ports: raw=%v err=%v", raw, err) + } + raw, err = wire.AllowedPortsToPolicy([]uint16{}) + if err != nil || raw != nil { + t.Fatalf("empty ports: raw=%v err=%v", raw, err) + } +} + +func TestAllowedPortsToPolicyMatchExpressionAndRules(t *testing.T) { + t.Parallel() + raw, err := wire.AllowedPortsToPolicy([]uint16{80, 443, 7001}) + if err != nil { + t.Fatalf("%v", err) + } + // Raw text contains the exact match expression. 
+ s := string(raw) + if !strings.Contains(s, `"port in [80, 443, 7001]"`) { + t.Fatalf("match expr not formatted as expected:\n%s", s) + } + var doc map[string]interface{} + if err := json.Unmarshal(raw, &doc); err != nil { + t.Fatalf("%v", err) + } + if doc["version"].(float64) != 1 { + t.Fatalf("version: %v", doc["version"]) + } + rules := doc["rules"].([]interface{}) + if len(rules) != 6 { + t.Fatalf("rules count: %d, want 6 (3 allow + 3 deny)", len(rules)) + } + // Expected names in order. + wantNames := []string{"allow-ports", "allow-ports-dg", "allow-ports-dial", "deny-rest", "deny-rest-dg", "deny-rest-dial"} + for i, want := range wantNames { + r := rules[i].(map[string]interface{}) + if r["name"].(string) != want { + t.Fatalf("rule[%d].name = %q, want %q", i, r["name"], want) + } + } + // Allow rules use the built match expr; deny rules use "true". + for i := 0; i < 3; i++ { + r := rules[i].(map[string]interface{}) + if r["match"].(string) != "port in [80, 443, 7001]" { + t.Fatalf("allow rule[%d] match: %q", i, r["match"]) + } + } + for i := 3; i < 6; i++ { + r := rules[i].(map[string]interface{}) + if r["match"].(string) != "true" { + t.Fatalf("deny rule[%d] match: %q", i, r["match"]) + } + } +} + +func TestAllowedPortsToPolicySinglePort(t *testing.T) { + t.Parallel() + raw, err := wire.AllowedPortsToPolicy([]uint16{7}) + if err != nil { + t.Fatalf("%v", err) + } + if !strings.Contains(string(raw), `"port in [7]"`) { + t.Fatalf("single-port match expr:\n%s", string(raw)) + } +} diff --git a/registry/wire/zz_wire_test.go b/registry/wire/zz_wire_test.go new file mode 100644 index 0000000..1467eb6 --- /dev/null +++ b/registry/wire/zz_wire_test.go @@ -0,0 +1,415 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package wire_test + +import ( + "bytes" + "math" + "testing" + + "github.com/pilot-protocol/common/registry/wire" +) + +func TestWireFrameRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msgType byte + payload []byte + 
}{ + {"empty payload", wire.MsgHeartbeat, nil}, + {"small payload", wire.MsgLookup, []byte{1, 2, 3, 4}}, + {"max type", wire.MsgError, []byte("test error")}, + {"json passthrough", wire.MsgJSON, []byte(`{"type":"heartbeat","node_id":42}`)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + if err := wire.WriteFrame(&buf, tt.msgType, tt.payload); err != nil { + t.Fatalf("write frame: %v", err) + } + + gotType, gotPayload, err := wire.ReadFrame(&buf) + if err != nil { + t.Fatalf("read frame: %v", err) + } + if gotType != tt.msgType { + t.Fatalf("type: got 0x%02x, want 0x%02x", gotType, tt.msgType) + } + if !bytes.Equal(gotPayload, tt.payload) { + t.Fatalf("payload: got %v, want %v", gotPayload, tt.payload) + } + }) + } +} + +func TestWireFrameMultipleMessages(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + for i := 0; i < 10; i++ { + payload := []byte{byte(i), byte(i + 1)} + if err := wire.WriteFrame(&buf, byte(i), payload); err != nil { + t.Fatalf("write frame %d: %v", i, err) + } + } + + for i := 0; i < 10; i++ { + gotType, gotPayload, err := wire.ReadFrame(&buf) + if err != nil { + t.Fatalf("read frame %d: %v", i, err) + } + if gotType != byte(i) { + t.Fatalf("frame %d type: got 0x%02x, want 0x%02x", i, gotType, byte(i)) + } + if len(gotPayload) != 2 || gotPayload[0] != byte(i) || gotPayload[1] != byte(i+1) { + t.Fatalf("frame %d payload mismatch", i) + } + } +} + +func TestWireFrameTooLarge(t *testing.T) { + t.Parallel() + + // Write a frame claiming a payload larger than MaxMessageSize + var buf bytes.Buffer + wire.WriteFrame(&buf, wire.MsgJSON, make([]byte, wire.MaxMessageSize+1)) + + _, _, err := wire.ReadFrame(&buf) + if err == nil { + t.Fatal("expected error for oversized frame") + } +} + +func TestHeartbeatReqRoundTrip(t *testing.T) { + t.Parallel() + + var sig [64]byte + for i := range sig { + sig[i] = byte(i) + } + + payload := wire.EncodeHeartbeatReq(42, sig[:]) + req, err := 
wire.DecodeHeartbeatReq(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if req.NodeID != 42 { + t.Fatalf("nodeID: got %d, want 42", req.NodeID) + } + if req.Signature != sig { + t.Fatal("signature mismatch") + } +} + +func TestHeartbeatReqTooShort(t *testing.T) { + t.Parallel() + _, err := wire.DecodeHeartbeatReq([]byte{1, 2, 3}) + if err == nil { + t.Fatal("expected error for short payload") + } +} + +func TestHeartbeatRespRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + unixTime int64 + keyExpiryWarning bool + }{ + {"no warning", 1700000000, false}, + {"with warning", 1700000000, true}, + {"zero time", 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload := wire.EncodeHeartbeatResp(tt.unixTime, tt.keyExpiryWarning) + gotTime, gotWarning, err := wire.DecodeHeartbeatResp(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if gotTime != tt.unixTime { + t.Fatalf("time: got %d, want %d", gotTime, tt.unixTime) + } + if gotWarning != tt.keyExpiryWarning { + t.Fatalf("warning: got %v, want %v", gotWarning, tt.keyExpiryWarning) + } + }) + } +} + +func TestLookupReqRoundTrip(t *testing.T) { + t.Parallel() + + payload := wire.EncodeLookupReq(12345) + nodeID, err := wire.DecodeLookupReq(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if nodeID != 12345 { + t.Fatalf("nodeID: got %d, want 12345", nodeID) + } +} + +func TestLookupRespRoundTrip(t *testing.T) { + t.Parallel() + + pubKey := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + + payload := wire.EncodeLookupResp( + 42, // nodeID + true, // public + true, // taskExec + []uint16{1, 2, 3}, // networks + pubKey, // pubKey + "test-host", // hostname + []string{"svc", "primary"}, // tags + "10.0.0.1:4000", // realAddr + "ext-123", // externalID + ) + + result, err := wire.DecodeLookupResp(payload) + if err != nil { + t.Fatalf("decode: 
%v", err) + } + + if result.NodeID != 42 { + t.Fatalf("NodeID: got %d, want 42", result.NodeID) + } + if !result.Public { + t.Fatal("expected Public=true") + } + if !result.TaskExec { + t.Fatal("expected TaskExec=true") + } + if len(result.Networks) != 3 || result.Networks[0] != 1 || result.Networks[2] != 3 { + t.Fatalf("Networks: got %v, want [1,2,3]", result.Networks) + } + if !bytes.Equal(result.PubKey, pubKey) { + t.Fatal("PubKey mismatch") + } + if result.Hostname != "test-host" { + t.Fatalf("Hostname: got %q, want %q", result.Hostname, "test-host") + } + if len(result.Tags) != 2 || result.Tags[0] != "svc" || result.Tags[1] != "primary" { + t.Fatalf("Tags: got %v", result.Tags) + } + if result.RealAddr != "10.0.0.1:4000" { + t.Fatalf("RealAddr: got %q", result.RealAddr) + } + if result.ExternalID != "ext-123" { + t.Fatalf("ExternalID: got %q", result.ExternalID) + } +} + +func TestLookupRespMinimal(t *testing.T) { + t.Parallel() + + payload := wire.EncodeLookupResp(1, false, false, nil, nil, "", nil, "", "") + result, err := wire.DecodeLookupResp(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if result.NodeID != 1 { + t.Fatalf("NodeID: got %d, want 1", result.NodeID) + } + if result.Public || result.TaskExec { + t.Fatal("expected both flags false") + } + if len(result.Networks) != 0 { + t.Fatal("expected empty networks") + } + if len(result.PubKey) != 0 { + t.Fatal("expected empty pubkey") + } +} + +func TestResolveReqRoundTrip(t *testing.T) { + t.Parallel() + + sig := make([]byte, 64) + for i := range sig { + sig[i] = byte(i + 100) + } + + payload := wire.EncodeResolveReq(10, 20, sig) + nodeID, requesterID, gotSig, err := wire.DecodeResolveReq(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if nodeID != 10 { + t.Fatalf("nodeID: got %d, want 10", nodeID) + } + if requesterID != 20 { + t.Fatalf("requesterID: got %d, want 20", requesterID) + } + if !bytes.Equal(gotSig, sig) { + t.Fatal("signature mismatch") + } +} + +func 
TestResolveRespRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + nodeID uint32 + realAddr string + lanAddrs []string + keyAgeDays int + }{ + {"basic", 42, "10.0.0.1:4000", nil, 30}, + {"with LANs", 42, "10.0.0.1:4000", []string{"192.168.1.1:4000", "192.168.2.1:4000"}, 30}, + {"unknown key age", 42, "10.0.0.1:4000", nil, -1}, + {"zero key age", 42, "10.0.0.1:4000", nil, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload := wire.EncodeResolveResp(tt.nodeID, tt.realAddr, tt.lanAddrs, tt.keyAgeDays) + result, err := wire.DecodeResolveResp(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if result.NodeID != tt.nodeID { + t.Fatalf("NodeID: got %d, want %d", result.NodeID, tt.nodeID) + } + if result.RealAddr != tt.realAddr { + t.Fatalf("RealAddr: got %q, want %q", result.RealAddr, tt.realAddr) + } + if len(result.LANAddrs) != len(tt.lanAddrs) { + t.Fatalf("LANAddrs length: got %d, want %d", len(result.LANAddrs), len(tt.lanAddrs)) + } + for i, la := range result.LANAddrs { + if la != tt.lanAddrs[i] { + t.Fatalf("LANAddrs[%d]: got %q, want %q", i, la, tt.lanAddrs[i]) + } + } + if result.KeyAgeDays != tt.keyAgeDays { + t.Fatalf("KeyAgeDays: got %d, want %d", result.KeyAgeDays, tt.keyAgeDays) + } + }) + } +} + +func TestResolveRespMaxKeyAge(t *testing.T) { + t.Parallel() + + // Verify math.MaxUint32 maps to -1 + payload := wire.EncodeResolveResp(1, "addr", nil, -1) + result, err := wire.DecodeResolveResp(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if result.KeyAgeDays != -1 { + t.Fatalf("KeyAgeDays: got %d, want -1", result.KeyAgeDays) + } + + // Verify large positive value round-trips + payload = wire.EncodeResolveResp(1, "addr", nil, int(math.MaxUint32-1)) + result, err = wire.DecodeResolveResp(payload) + if err != nil { + t.Fatalf("decode: %v", err) + } + if result.KeyAgeDays != int(math.MaxUint32-1) { + t.Fatalf("KeyAgeDays: got %d, want %d", result.KeyAgeDays, 
math.MaxUint32-1) + } +} + +func TestWireErrorRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msg string + }{ + {"simple", "not found"}, + {"empty", ""}, + {"long", string(make([]byte, 1000))}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload := wire.EncodeError(tt.msg) + got := wire.DecodeError(payload) + if got != tt.msg { + t.Fatalf("got %q, want %q", got, tt.msg) + } + }) + } +} + +func TestWireErrorTruncation(t *testing.T) { + t.Parallel() + + // Messages > 65000 are truncated + longMsg := string(make([]byte, 70000)) + payload := wire.EncodeError(longMsg) + got := wire.DecodeError(payload) + if len(got) != 65000 { + t.Fatalf("expected truncated to 65000, got %d", len(got)) + } +} + +func TestWireProtocolNegotiationMagic(t *testing.T) { + t.Parallel() + + // Verify the magic bytes are correct + if wire.Magic != [4]byte{0x50, 0x49, 0x4C, 0x54} { + t.Fatalf("magic: got %v, want PILT", wire.Magic) + } + // Verify magic != any valid JSON length prefix (which must be < MaxMessageSize) + magicAsLen := uint32(wire.Magic[0])<<24 | uint32(wire.Magic[1])<<16 | uint32(wire.Magic[2])<<8 | uint32(wire.Magic[3]) + if magicAsLen <= wire.MaxMessageSize { + t.Fatalf("magic as length (%d) must be > MaxMessageSize (%d) for protocol detection", magicAsLen, wire.MaxMessageSize) + } +} + +func BenchmarkEncodeHeartbeatReq(b *testing.B) { + sig := make([]byte, 64) + for i := 0; i < b.N; i++ { + wire.EncodeHeartbeatReq(42, sig) + } +} + +func BenchmarkDecodeHeartbeatReq(b *testing.B) { + sig := make([]byte, 64) + payload := wire.EncodeHeartbeatReq(42, sig) + for i := 0; i < b.N; i++ { + wire.DecodeHeartbeatReq(payload) + } +} + +func BenchmarkEncodeLookupResp(b *testing.B) { + pubKey := make([]byte, 32) + networks := []uint16{1, 2, 3} + tags := []string{"svc", "primary"} + for i := 0; i < b.N; i++ { + wire.EncodeLookupResp(42, true, true, networks, pubKey, "test-host", tags, "10.0.0.1:4000", "ext-123") + } +} + +func 
BenchmarkDecodeLookupResp(b *testing.B) { + pubKey := make([]byte, 32) + networks := []uint16{1, 2, 3} + tags := []string{"svc", "primary"} + payload := wire.EncodeLookupResp(42, true, true, networks, pubKey, "test-host", tags, "10.0.0.1:4000", "ext-123") + for i := 0; i < b.N; i++ { + wire.DecodeLookupResp(payload) + } +} + +func BenchmarkWireFrameRoundTrip(b *testing.B) { + payload := make([]byte, 68) // heartbeat size + var buf bytes.Buffer + for i := 0; i < b.N; i++ { + buf.Reset() + wire.WriteFrame(&buf, wire.MsgHeartbeat, payload) + wire.ReadFrame(&buf) + } +} diff --git a/secure/client.go b/secure/client.go new file mode 100644 index 0000000..1381016 --- /dev/null +++ b/secure/client.go @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package secure + +import ( + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" +) + +// Dial connects to a remote agent's secure port and performs the handshake. +// Returns an encrypted connection that implements net.Conn. +func Dial(d *driver.Driver, addr protocol.Addr, auth ...*HandshakeConfig) (*SecureConn, error) { + conn, err := d.DialAddr(addr, protocol.PortSecure) + if err != nil { + return nil, err + } + + sc, err := Handshake(conn, false, auth...) + if err != nil { + conn.Close() + return nil, err + } + return sc, nil +} diff --git a/secure/secure.go b/secure/secure.go new file mode 100644 index 0000000..f3d94dc --- /dev/null +++ b/secure/secure.go @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package secure + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ed25519" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "fmt" + "io" + "net" + "sync" + "time" +) + +// MaxEncryptedMessageLen limits the maximum decrypted message size to prevent +// memory exhaustion from a malicious peer advertising a huge msgLen. 
const MaxEncryptedMessageLen = 16 << 20 // 16 MB

const (
	// HandshakeTimeout bounds how long the ECDH handshake may take; the
	// handshake functions install it as the connection deadline.
	HandshakeTimeout = 10 * time.Second

	// AuthFrameLen is the size in bytes of one auth frame:
	// nodeID(4) + timestamp(8) + nonce(16) + ed25519_signature(64) = 92.
	AuthFrameLen = 4 + 8 + 16 + 64

	// authTimestampSkew is the largest tolerated difference between an auth
	// frame timestamp and the verifier's clock.
	authTimestampSkew = 5 * time.Second

	// replayCacheExpiry is the retention period for nonces in the replay cache.
	replayCacheExpiry = time.Hour
)

// HandshakeConfig carries the identity material used to authenticate a secure
// channel handshake. Authentication is skipped when no config is supplied or
// when Signer is nil (backward compatibility for tests and unauthenticated
// channels).
type HandshakeConfig struct {
	NodeID     uint32             // local node ID, written into the outgoing auth frame
	Signer     ed25519.PrivateKey // local key used to sign the outgoing auth frame
	PeerPubKey ed25519.PublicKey  // key the peer's auth frame is verified against
}

// PeerPubKeyLookup resolves a node ID to that node's Ed25519 public key.
// Servers use it to find the identity of a connecting client before verifying
// its auth frame. A nil result means the node is unknown.
//
// Definitive declaration of PeerPubKeyLookup; do not duplicate in this package.
type PeerPubKeyLookup func(nodeID uint32) ed25519.PublicKey

// replayCache remembers auth nonces, keyed by nonce and valued by the time the
// nonce was recorded, so that a nonce cannot be reused within the 1-hour
// window. The embedded mutex guards the map.
var replayCache = struct {
	sync.Mutex
	nonces map[[16]byte]time.Time
}{nonces: make(map[[16]byte]time.Time)}

// init starts the background sweeper for the replay cache.
func init() {
	go replayCacheCleaner()
}

// replayCacheCleaner wakes every five minutes and drops every nonce that has
// been in the replay cache for longer than replayCacheExpiry. It never returns.
func replayCacheCleaner() {
	// evictOlderThan removes entries whose age, measured from now, exceeds
	// the expiry window. The whole pass runs under the cache lock.
	evictOlderThan := func(now time.Time) {
		replayCache.Lock()
		defer replayCache.Unlock()
		for nonce, recordedAt := range replayCache.nonces {
			if now.Sub(recordedAt) > replayCacheExpiry {
				delete(replayCache.nonces, nonce)
			}
		}
	}

	sweep := time.NewTicker(5 * time.Minute)
	for range sweep.C {
		evictOlderThan(time.Now())
	}
}

// maxReplayCacheEntries caps the replay cache to prevent memory exhaustion (M1 fix).
+const maxReplayCacheEntries = 100000 + +// CheckAndRecordNonce returns an error if the nonce was already seen within +// the replay window, otherwise records it and returns nil. +func CheckAndRecordNonce(nonce [16]byte) error { + replayCache.Lock() + defer replayCache.Unlock() + if _, exists := replayCache.nonces[nonce]; exists { + return fmt.Errorf("auth nonce replay detected") + } + if len(replayCache.nonces) >= maxReplayCacheEntries { + return fmt.Errorf("auth replay cache full") + } + replayCache.nonces[nonce] = time.Now() + return nil +} + +// ResetReplayCache clears the replay cache. Exported for testing only. +func ResetReplayCache() { + replayCache.Lock() + defer replayCache.Unlock() + replayCache.nonces = make(map[[16]byte]time.Time) +} + +// InjectReplayNonce adds a nonce to the replay cache. Exported for testing only. +func InjectReplayNonce(nonce [16]byte) { + replayCache.Lock() + defer replayCache.Unlock() + replayCache.nonces[nonce] = time.Now() +} + +// CheckReplayNonce checks if a nonce is in the replay cache without recording it. +// Exported for testing only. +func CheckReplayNonce(nonce [16]byte) error { + replayCache.Lock() + defer replayCache.Unlock() + if _, exists := replayCache.nonces[nonce]; exists { + return fmt.Errorf("auth nonce replay detected") + } + return nil +} + +// HandshakeWithTimestampOffset performs an authenticated handshake but shifts +// the auth frame timestamp by the given offset. Exported for testing only. 
//
// The key exchange is the same sequence Handshake runs: an ephemeral X25519
// exchange (server reads first, client writes first), an HMAC-SHA256 based key
// derivation, and an AES-GCM cipher with a role-specific nonce prefix. The
// offset is forwarded to performAuthWithOffset and only takes effect when cfg
// is non-nil and cfg.Signer is set; otherwise the channel is returned
// unauthenticated.
//
// The connection deadline is set to HandshakeTimeout for the duration of the
// call and cleared again on return. Errors from each step are wrapped with a
// short step name.
func HandshakeWithTimestampOffset(conn net.Conn, isServer bool, cfg *HandshakeConfig, offset time.Duration) (*SecureConn, error) {
	// Bound the whole handshake; the deferred call resets the deadline to the
	// zero time. The errors from SetDeadline are not inspected.
	conn.SetDeadline(time.Now().Add(HandshakeTimeout))
	defer conn.SetDeadline(time.Time{})

	// Fresh ephemeral X25519 key pair for this connection.
	curve := ecdh.X25519()
	privKey, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generate key: %w", err)
	}
	localPub := privKey.PublicKey().Bytes()

	var remotePub []byte

	if isServer {
		// Server: read the client's 32-byte public key, then send ours.
		remotePub, err = ReadExact(conn, 32)
		if err != nil {
			return nil, fmt.Errorf("read client key: %w", err)
		}
		if _, err := conn.Write(localPub); err != nil {
			return nil, fmt.Errorf("send server key: %w", err)
		}
	} else {
		// Client: send our public key, then read the server's 32-byte key.
		if _, err := conn.Write(localPub); err != nil {
			return nil, fmt.Errorf("send client key: %w", err)
		}
		remotePub, err = ReadExact(conn, 32)
		if err != nil {
			return nil, fmt.Errorf("read server key: %w", err)
		}
	}

	peerKey, err := curve.NewPublicKey(remotePub)
	if err != nil {
		return nil, fmt.Errorf("parse peer key: %w", err)
	}

	shared, err := privKey.ECDH(peerKey)
	if err != nil {
		return nil, fmt.Errorf("ecdh: %w", err)
	}

	// HKDF-SHA256 key derivation (H1 fix)
	// Extract: HMAC keyed with a nil salt over the shared secret.
	// Expand: a single HMAC block over the info string "pilot-secure-v1"
	// followed by the counter byte 0x01.
	mac := hmac.New(sha256.New, nil)
	mac.Write(shared)
	prk := mac.Sum(nil)
	mac = hmac.New(sha256.New, prk)
	mac.Write([]byte("pilot-secure-v1"))
	mac.Write([]byte{0x01})
	key := mac.Sum(nil)

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("aes: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("gcm: %w", err)
	}

	// Zero intermediate key material (H4 fix)
	// Runs only after the cipher has been built; the early error returns
	// above leave these slices untouched.
	for i := range shared {
		shared[i] = 0
	}
	for i := range key {
		key[i] = 0
	}
	for i := range prk {
		prk[i] = 0
	}

	// Both sides hold the same key, so each role gets its own nonce prefix.
	sc := &SecureConn{raw: conn, aead: aead}
	if isServer {
		sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x01}
	} else {
		sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x02}
	}

	// Authentication is optional: a nil cfg or a cfg without a Signer skips it.
	if cfg != nil && cfg.Signer != nil {
		if err := performAuthWithOffset(sc, cfg, localPub, remotePub, isServer, offset); err != nil {
			sc.Close()
			return nil, fmt.Errorf("auth: %w", err)
		}
	}

	return sc, nil
}

// performAuthWithOffset is like performAuth but applies a timestamp offset.
//
// It builds and signs the local auth frame with a timestamp of
// time.Now().Add(offset), exchanges frames with the peer over sc, and passes
// the peer's frame to VerifyAuthFrame together with cfg.PeerPubKey, the peer's
// X25519 public key and the unshifted local time. On success sc.PeerNodeID is
// set to the node ID returned by VerifyAuthFrame.
//
// The server sends its frame before reading the peer's; the client reads the
// peer's frame, verifies it, and only then sends its own.
func performAuthWithOffset(sc *SecureConn, cfg *HandshakeConfig, localX25519Pub, remoteX25519Pub []byte, isServer bool, offset time.Duration) error {
	// Use shifted timestamp
	ts := uint64(time.Now().Add(offset).Unix())

	// Fresh random 16-byte nonce for this auth frame.
	var authNonce [16]byte
	if _, err := rand.Read(authNonce[:]); err != nil {
		return fmt.Errorf("generate auth nonce: %w", err)
	}

	// The signed message includes our own X25519 public key (localX25519Pub).
	sigMsg := BuildAuthSignMessage(cfg.NodeID, localX25519Pub, ts, authNonce)
	signature := ed25519.Sign(cfg.Signer, sigMsg)

	// Wire layout: nodeID(4) | timestamp(8) | nonce(16) | signature(64),
	// integers big-endian.
	frame := make([]byte, AuthFrameLen)
	binary.BigEndian.PutUint32(frame[0:4], cfg.NodeID)
	binary.BigEndian.PutUint64(frame[4:12], ts)
	copy(frame[12:28], authNonce[:])
	copy(frame[28:92], signature)

	now := time.Now() // verifier uses current time

	if isServer {
		if _, err := sc.Write(frame); err != nil {
			return fmt.Errorf("send auth frame: %w", err)
		}
		peerFrame, err := readAuthFrame(sc)
		if err != nil {
			return fmt.Errorf("read peer auth frame: %w", err)
		}
		// Verification errors are returned unwrapped; the caller adds "auth:".
		peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now)
		if err != nil {
			return err
		}
		sc.PeerNodeID = peerNodeID
	} else {
		peerFrame, err := readAuthFrame(sc)
		if err != nil {
			return fmt.Errorf("read peer auth frame: %w", err)
		}
		peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now)
		if err != nil {
			return err
		}
		sc.PeerNodeID = peerNodeID
		// The client's frame goes out only after the server's frame verified.
		if _, err := sc.Write(frame); err != nil {
			return fmt.Errorf("send auth frame: %w", err)
		}
	}

	return nil
}

// SecureConn wraps a net.Conn with AES-256-GCM encryption.
// After a successful ECDH handshake, all reads and writes are encrypted.
type SecureConn struct {
	raw         net.Conn    // underlying connection the handshake ran over
	aead        cipher.AEAD // AES-GCM instance built from the derived session key
	rmu         sync.Mutex  // presumably serializes reads — confirm against Read
	wmu         sync.Mutex  // presumably serializes writes — confirm against Write
	nonce       uint64      // monotonic counter for nonces
	noncePrefix [4]byte     // role-based prefix for nonce domain separation
	readBuf     []byte      // leftover plaintext from a previous Read
	PeerNodeID  uint32      // authenticated peer node ID (0 if unauthenticated)
}

// Handshake performs an ECDH key exchange over the connection and returns a
// SecureConn wrapping it.
//
// isServer determines which side reads first: the server reads the client's
// 32-byte X25519 public key and then sends its own; the client does the
// reverse. isServer also selects the nonce prefix stored on the SecureConn.
//
// An optional HandshakeConfig enables mutual Ed25519 authentication inside the
// encrypted channel after the ECDH exchange. Only the first element of auth is
// consulted. Pass nil or omit it for unauthenticated mode (backward
// compatible); a config whose Signer is nil is also treated as unauthenticated.
// NOTE(review): a config that sets PeerPubKey but leaves Signer nil therefore
// yields an unauthenticated channel without an error — confirm this is intended.
//
// A deadline of HandshakeTimeout is set on conn to prevent indefinite blocking
// (M14 fix) and cleared when Handshake returns. Each failing step returns a
// wrapped error; on authentication failure sc.Close() is called first.
func Handshake(conn net.Conn, isServer bool, auth ...*HandshakeConfig) (*SecureConn, error) {
	// Set handshake deadline to prevent indefinite blocking (M14 fix)
	conn.SetDeadline(time.Now().Add(HandshakeTimeout))
	defer conn.SetDeadline(time.Time{}) // clear deadline after handshake

	// Generate ephemeral X25519 key pair
	curve := ecdh.X25519()
	privKey, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generate key: %w", err)
	}
	localPub := privKey.PublicKey().Bytes() // 32 bytes

	var remotePub []byte

	if isServer {
		// Server: read client's public key first, then send ours
		remotePub, err = ReadExact(conn, 32)
		if err != nil {
			return nil, fmt.Errorf("read client key: %w", err)
		}
		if _, err := conn.Write(localPub); err != nil {
			return nil, fmt.Errorf("send server key: %w", err)
		}
	} else {
		// Client: send our public key first, then read server's
		if _, err := conn.Write(localPub); err != nil {
			return nil, fmt.Errorf("send client key: %w", err)
		}
		remotePub, err = ReadExact(conn, 32)
		if err != nil {
			return nil, fmt.Errorf("read server key: %w", err)
		}
	}

	// Parse remote public key
	peerKey, err := curve.NewPublicKey(remotePub)
	if err != nil {
		return nil, fmt.Errorf("parse peer key: %w", err)
	}

	// Compute shared secret
	shared, err := privKey.ECDH(peerKey)
	if err != nil {
		return nil, fmt.Errorf("ecdh: %w", err)
	}

	// HKDF-SHA256 key derivation (H1 fix)
	// Extract: HMAC keyed with a nil salt over the shared secret.
	// Expand: a single HMAC block over the info string "pilot-secure-v1"
	// followed by the counter byte 0x01.
	mac := hmac.New(sha256.New, nil)
	mac.Write(shared)
	prk := mac.Sum(nil)
	mac = hmac.New(sha256.New, prk)
	mac.Write([]byte("pilot-secure-v1"))
	mac.Write([]byte{0x01})
	key := mac.Sum(nil)

	// Create AES-GCM cipher
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("aes: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("gcm: %w", err)
	}

	// Zero intermediate key material (H4 fix)
	// Runs only after the cipher has been built; the early error returns
	// above leave these slices untouched.
	for i := range shared {
		shared[i] = 0
	}
	for i := range key {
		key[i] = 0
	}
	for i := range prk {
		prk[i] = 0
	}

	sc := &SecureConn{raw: conn, aead: aead}
	// Use role-based nonce prefix to prevent nonce collision (C3 fix).
	// Both sides share the same AES-GCM key; using deterministic prefixes
	// based on role ensures the nonce spaces never overlap.
	if isServer {
		sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x01} // server prefix
	} else {
		sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x02} // client prefix
	}

	// Perform mutual Ed25519 authentication if config provided.
	// This happens INSIDE the encrypted channel (after ECDH).
	// Only auth[0] is used; any further variadic arguments are ignored.
	var cfg *HandshakeConfig
	if len(auth) > 0 {
		cfg = auth[0]
	}
	if cfg != nil && cfg.Signer != nil {
		if err := performAuth(sc, cfg, localPub, remotePub, isServer); err != nil {
			sc.Close()
			return nil, fmt.Errorf("auth: %w", err)
		}
	}

	return sc, nil
}

// HandshakeWithLookup is like Handshake with auth, but uses a lookup function
// to resolve the peer's Ed25519 public key by nodeID. This is used by servers
// that don't know the peer's identity until they read the auth frame.
+func HandshakeWithLookup(conn net.Conn, isServer bool, cfg *HandshakeConfig, lookup PeerPubKeyLookup) (*SecureConn, error) { + // Set handshake deadline to prevent indefinite blocking (M14 fix) + conn.SetDeadline(time.Now().Add(HandshakeTimeout)) + defer conn.SetDeadline(time.Time{}) // clear deadline after handshake + + // Generate ephemeral X25519 key pair + curve := ecdh.X25519() + privKey, err := curve.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("generate key: %w", err) + } + localPub := privKey.PublicKey().Bytes() // 32 bytes + + var remotePub []byte + + if isServer { + remotePub, err = ReadExact(conn, 32) + if err != nil { + return nil, fmt.Errorf("read client key: %w", err) + } + if _, err := conn.Write(localPub); err != nil { + return nil, fmt.Errorf("send server key: %w", err) + } + } else { + if _, err := conn.Write(localPub); err != nil { + return nil, fmt.Errorf("send client key: %w", err) + } + remotePub, err = ReadExact(conn, 32) + if err != nil { + return nil, fmt.Errorf("read server key: %w", err) + } + } + + peerKey, err := curve.NewPublicKey(remotePub) + if err != nil { + return nil, fmt.Errorf("parse peer key: %w", err) + } + + shared, err := privKey.ECDH(peerKey) + if err != nil { + return nil, fmt.Errorf("ecdh: %w", err) + } + + // HKDF-SHA256 key derivation (H1 fix) + mac := hmac.New(sha256.New, nil) + mac.Write(shared) + prk := mac.Sum(nil) + mac = hmac.New(sha256.New, prk) + mac.Write([]byte("pilot-secure-v1")) + mac.Write([]byte{0x01}) + key := mac.Sum(nil) + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("aes: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("gcm: %w", err) + } + + // Zero intermediate key material (H4 fix) + for i := range shared { + shared[i] = 0 + } + for i := range key { + key[i] = 0 + } + for i := range prk { + prk[i] = 0 + } + + sc := &SecureConn{raw: conn, aead: aead} + if isServer { + sc.noncePrefix = [4]byte{0x00, 
0x00, 0x00, 0x01} + } else { + sc.noncePrefix = [4]byte{0x00, 0x00, 0x00, 0x02} + } + + if cfg != nil && cfg.Signer != nil { + if err := performAuthWithLookup(sc, cfg, localPub, remotePub, isServer, lookup); err != nil { + sc.Close() + return nil, fmt.Errorf("auth: %w", err) + } + } + + return sc, nil +} + +// performAuthWithLookup is like performAuth but resolves the peer's Ed25519 +// pubkey via a lookup function after reading the peer's auth frame. +func performAuthWithLookup(sc *SecureConn, cfg *HandshakeConfig, localX25519Pub, remoteX25519Pub []byte, isServer bool, lookup PeerPubKeyLookup) error { + now := time.Now() + ts := uint64(now.Unix()) + + var authNonce [16]byte + if _, err := rand.Read(authNonce[:]); err != nil { + return fmt.Errorf("generate auth nonce: %w", err) + } + + sigMsg := BuildAuthSignMessage(cfg.NodeID, localX25519Pub, ts, authNonce) + signature := ed25519.Sign(cfg.Signer, sigMsg) + + frame := make([]byte, AuthFrameLen) + binary.BigEndian.PutUint32(frame[0:4], cfg.NodeID) + binary.BigEndian.PutUint64(frame[4:12], ts) + copy(frame[12:28], authNonce[:]) + copy(frame[28:92], signature) + + if isServer { + if _, err := sc.Write(frame); err != nil { + return fmt.Errorf("send auth frame: %w", err) + } + peerFrame, err := readAuthFrame(sc) + if err != nil { + return fmt.Errorf("read peer auth frame: %w", err) + } + // Extract peer's nodeID to look up their pubkey + peerNodeID := binary.BigEndian.Uint32(peerFrame[0:4]) + peerPubKey := lookup(peerNodeID) + if peerPubKey == nil { + return fmt.Errorf("unknown peer node %d: no public key found", peerNodeID) + } + peerNodeID, err = VerifyAuthFrame(peerFrame, peerPubKey, remoteX25519Pub, now) + if err != nil { + return err + } + sc.PeerNodeID = peerNodeID + } else { + peerFrame, err := readAuthFrame(sc) + if err != nil { + return fmt.Errorf("read peer auth frame: %w", err) + } + peerNodeID := binary.BigEndian.Uint32(peerFrame[0:4]) + peerPubKey := lookup(peerNodeID) + if peerPubKey == nil { + return 
fmt.Errorf("unknown peer node %d: no public key found", peerNodeID) + } + peerNodeID, err = VerifyAuthFrame(peerFrame, peerPubKey, remoteX25519Pub, now) + if err != nil { + return err + } + sc.PeerNodeID = peerNodeID + if _, err := sc.Write(frame); err != nil { + return fmt.Errorf("send auth frame: %w", err) + } + } + + return nil +} + +// performAuth executes the mutual Ed25519 authentication protocol inside the +// already-encrypted SecureConn. Both sides send an auth frame and verify the +// peer's frame. +// +// Auth frame format (92 bytes): +// +// [nodeID(4)][timestamp(8)][nonce(16)][ed25519_signature(64)] +// +// Signature covers: +// +// "pilot-secure-auth:" + nodeID(4) + X25519_ephemeral_pubkey(32) + timestamp(8) + nonce(16) +// +// Each side signs its OWN X25519 ephemeral pubkey (localPub). The verifier +// reconstructs the signed message using the peer's X25519 pubkey (remotePub, +// which was received during the ECDH exchange). This binds the ephemeral ECDH +// key to the long-term Ed25519 identity: a MITM cannot substitute their own +// X25519 key because they cannot produce a valid Ed25519 signature for it. +func performAuth(sc *SecureConn, cfg *HandshakeConfig, localX25519Pub, remoteX25519Pub []byte, isServer bool) error { + // Build our auth frame + now := time.Now() + ts := uint64(now.Unix()) + + var authNonce [16]byte + if _, err := rand.Read(authNonce[:]); err != nil { + return fmt.Errorf("generate auth nonce: %w", err) + } + + // Sign over our own X25519 pubkey to bind our identity to this ECDH session + sigMsg := BuildAuthSignMessage(cfg.NodeID, localX25519Pub, ts, authNonce) + signature := ed25519.Sign(cfg.Signer, sigMsg) + + // Build the wire frame: nodeID(4) + timestamp(8) + nonce(16) + signature(64) + frame := make([]byte, AuthFrameLen) + binary.BigEndian.PutUint32(frame[0:4], cfg.NodeID) + binary.BigEndian.PutUint64(frame[4:12], ts) + copy(frame[12:28], authNonce[:]) + copy(frame[28:92], signature) + + // Exchange auth frames. 
Server sends first, then reads. + // Client reads first, then sends. This prevents deadlock on net.Pipe. + if isServer { + if _, err := sc.Write(frame); err != nil { + return fmt.Errorf("send auth frame: %w", err) + } + peerFrame, err := readAuthFrame(sc) + if err != nil { + return fmt.Errorf("read peer auth frame: %w", err) + } + peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now) + if err != nil { + return err + } + sc.PeerNodeID = peerNodeID + } else { + peerFrame, err := readAuthFrame(sc) + if err != nil { + return fmt.Errorf("read peer auth frame: %w", err) + } + peerNodeID, err := VerifyAuthFrame(peerFrame, cfg.PeerPubKey, remoteX25519Pub, now) + if err != nil { + return err + } + sc.PeerNodeID = peerNodeID + if _, err := sc.Write(frame); err != nil { + return fmt.Errorf("send auth frame: %w", err) + } + } + + return nil +} + +// readAuthFrame reads exactly AuthFrameLen bytes from the SecureConn. +// The data is already decrypted by SecureConn.Read. +func readAuthFrame(sc *SecureConn) ([]byte, error) { + frame := make([]byte, AuthFrameLen) + n := 0 + for n < AuthFrameLen { + nn, err := sc.Read(frame[n:]) + if err != nil { + return nil, err + } + n += nn + } + return frame, nil +} + +// VerifyAuthFrame validates a peer's auth frame. The peer signed over their own +// X25519 ephemeral pubkey (peerX25519Pub), which we received during the ECDH +// exchange. We reconstruct the signed message and verify against the peer's +// Ed25519 public key from the registry. 
+func VerifyAuthFrame(frame []byte, peerEdPubKey ed25519.PublicKey, peerX25519Pub []byte, now time.Time) (uint32, error) { + if len(frame) != AuthFrameLen { + return 0, fmt.Errorf("auth frame wrong size: %d", len(frame)) + } + + peerNodeID := binary.BigEndian.Uint32(frame[0:4]) + peerTS := binary.BigEndian.Uint64(frame[4:12]) + var peerNonce [16]byte + copy(peerNonce[:], frame[12:28]) + peerSig := frame[28:92] + + // Check timestamp within skew window + peerTime := time.Unix(int64(peerTS), 0) + diff := now.Sub(peerTime) + if diff < 0 { + diff = -diff + } + if diff > authTimestampSkew { + return 0, fmt.Errorf("auth timestamp expired: skew %v exceeds %v", diff, authTimestampSkew) + } + + // Check nonce replay + if err := CheckAndRecordNonce(peerNonce); err != nil { + return 0, err + } + + // Reconstruct the message the peer signed: domain + nodeID + peerX25519Pub + timestamp + nonce + sigMsg := BuildAuthSignMessage(peerNodeID, peerX25519Pub, peerTS, peerNonce) + + // Verify Ed25519 signature + if !ed25519.Verify(peerEdPubKey, sigMsg, peerSig) { + return 0, fmt.Errorf("auth signature verification failed") + } + + return peerNodeID, nil +} + +// BuildAuthSignMessage constructs the message that is signed in the auth frame. +// Format: "pilot-secure-auth:" + nodeID(4) + X25519_ephemeral_pubkey(32) + timestamp(8) + nonce(16) +func BuildAuthSignMessage(nodeID uint32, x25519Pub []byte, timestamp uint64, nonce [16]byte) []byte { + domain := []byte("pilot-secure-auth:") + msg := make([]byte, len(domain)+4+32+8+16) + copy(msg, domain) + off := len(domain) + binary.BigEndian.PutUint32(msg[off:off+4], nodeID) + off += 4 + copy(msg[off:off+32], x25519Pub) + off += 32 + binary.BigEndian.PutUint64(msg[off:off+8], timestamp) + off += 8 + copy(msg[off:off+16], nonce[:]) + return msg +} + +// Read decrypts and reads data from the connection. +// Leftover plaintext from a previous decryption is returned first (H14 fix). 
+func (sc *SecureConn) Read(b []byte) (int, error) { + sc.rmu.Lock() + defer sc.rmu.Unlock() + + // Return buffered leftover data first (H14 fix — prevents silent truncation) + if len(sc.readBuf) > 0 { + n := copy(b, sc.readBuf) + sc.readBuf = sc.readBuf[n:] + return n, nil + } + + // Read 4-byte length prefix + lenBuf, err := ReadExact(sc.raw, 4) + if err != nil { + return 0, err + } + msgLen := binary.BigEndian.Uint32(lenBuf) + if msgLen < uint32(sc.aead.NonceSize()) { + return 0, fmt.Errorf("encrypted message too short") + } + // Reject unreasonably large messages to prevent OOM (M13 fix) + if msgLen > uint32(MaxEncryptedMessageLen) { + return 0, fmt.Errorf("encrypted message too large: %d bytes", msgLen) + } + + // Read nonce + ciphertext + ciphertext, err := ReadExact(sc.raw, int(msgLen)) + if err != nil { + return 0, err + } + + nonce := ciphertext[:sc.aead.NonceSize()] + encrypted := ciphertext[sc.aead.NonceSize():] + + // Decrypt with sender's nonce prefix as AAD (H3 fix) + peerAAD := nonce[:4] + plaintext, err := sc.aead.Open(nil, nonce, encrypted, peerAAD) + if err != nil { + return 0, fmt.Errorf("decrypt: %w", err) + } + + n := copy(b, plaintext) + // Buffer any remaining plaintext for subsequent Read calls (H14 fix) + if n < len(plaintext) { + sc.readBuf = make([]byte, len(plaintext)-n) + copy(sc.readBuf, plaintext[n:]) + } + return n, nil +} + +// Write encrypts and writes data to the connection. 
+func (sc *SecureConn) Write(b []byte) (int, error) { + sc.wmu.Lock() + defer sc.wmu.Unlock() + + // Generate nonce from prefix + counter + nonce := make([]byte, sc.aead.NonceSize()) + copy(nonce[0:4], sc.noncePrefix[:]) + sc.nonce++ + binary.BigEndian.PutUint64(nonce[sc.aead.NonceSize()-8:], sc.nonce) + + // Encrypt with nonce prefix as AAD (H3 fix) + ciphertext := sc.aead.Seal(nil, nonce, b, sc.noncePrefix[:]) + + // Write: [4-byte length][nonce][ciphertext] + total := len(nonce) + len(ciphertext) + msg := make([]byte, 4+total) + binary.BigEndian.PutUint32(msg[0:4], uint32(total)) + copy(msg[4:], nonce) + copy(msg[4+len(nonce):], ciphertext) + + if _, err := sc.raw.Write(msg); err != nil { + return 0, err + } + return len(b), nil +} + +func (sc *SecureConn) Close() error { return sc.raw.Close() } +func (sc *SecureConn) LocalAddr() net.Addr { return sc.raw.LocalAddr() } +func (sc *SecureConn) RemoteAddr() net.Addr { return sc.raw.RemoteAddr() } +func (sc *SecureConn) SetDeadline(t time.Time) error { return sc.raw.SetDeadline(t) } +func (sc *SecureConn) SetReadDeadline(t time.Time) error { return sc.raw.SetReadDeadline(t) } +func (sc *SecureConn) SetWriteDeadline(t time.Time) error { return sc.raw.SetWriteDeadline(t) } + +func ReadExact(r io.Reader, n int) ([]byte, error) { + buf := make([]byte, n) + _, err := io.ReadFull(r, buf) + return buf, err +} diff --git a/secure/server.go b/secure/server.go new file mode 100644 index 0000000..cba34f1 --- /dev/null +++ b/secure/server.go @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package secure + +import ( + "crypto/ed25519" + "log/slog" + "net" + + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" +) + +// Handler is called for each new secure connection. +type Handler func(conn net.Conn) + +// Server listens on port 443 and upgrades connections to encrypted channels. 
+type Server struct { + driver *driver.Driver + handler Handler + authNodeID uint32 + authSigner ed25519.PrivateKey + peerLookup PeerPubKeyLookup +} + +// NewServer creates a secure channel server (unauthenticated ECDH). +func NewServer(d *driver.Driver, handler Handler) *Server { + return &Server{driver: d, handler: handler} +} + +// NewAuthServer creates a secure channel server with Ed25519 authentication. +// The server authenticates itself and verifies connecting clients using the +// lookup function to obtain each client's expected Ed25519 public key. +func NewAuthServer(d *driver.Driver, handler Handler, nodeID uint32, signer ed25519.PrivateKey, lookup PeerPubKeyLookup) *Server { + return &Server{ + driver: d, + handler: handler, + authNodeID: nodeID, + authSigner: signer, + peerLookup: lookup, + } +} + +// Driver returns the underlying packet driver. Exposed for tests. +func (s *Server) Driver() *driver.Driver { return s.driver } + +// Handler returns the per-connection handler callback. Exposed for tests. +func (s *Server) Handler() Handler { return s.handler } + +// AuthNodeID returns the authenticated node id (zero when unauth). +// Exposed for tests. +func (s *Server) AuthNodeID() uint32 { return s.authNodeID } + +// AuthSigner returns the server's Ed25519 signing key (nil when unauth). +// Exposed for tests. +func (s *Server) AuthSigner() ed25519.PrivateKey { return s.authSigner } + +// PeerLookup returns the per-peer pubkey lookup (nil when unauth). +// Exposed for tests. +func (s *Server) PeerLookup() PeerPubKeyLookup { return s.peerLookup } + +// ListenAndServe binds port 443 and starts accepting secure connections. 
+func (s *Server) ListenAndServe() error { + ln, err := s.driver.Listen(protocol.PortSecure) + if err != nil { + return err + } + + slog.Info("secure server listening", "port", protocol.PortSecure) + + for { + conn, err := ln.Accept() + if err != nil { + return err + } + go s.handleConn(conn) + } +} + +func (s *Server) handleConn(conn net.Conn) { + var sc *SecureConn + var err error + + if s.authSigner != nil { + // Use lookup-based handshake: the peer's nodeID is extracted from + // their auth frame, then the lookup function resolves their Ed25519 + // pubkey for signature verification. + sc, err = HandshakeWithLookup(conn, true, &HandshakeConfig{ + NodeID: s.authNodeID, + Signer: s.authSigner, + }, s.peerLookup) + } else { + sc, err = Handshake(conn, true) + } + + if err != nil { + slog.Warn("secure handshake failed", "err", err) + conn.Close() + return + } + s.handler(sc) +} diff --git a/secure/zz_extra_coverage_test.go b/secure/zz_extra_coverage_test.go new file mode 100644 index 0000000..3b49c24 --- /dev/null +++ b/secure/zz_extra_coverage_test.go @@ -0,0 +1,1496 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +// Package secure_test — extra coverage tests targeting error paths in +// Handshake / HandshakeWithLookup / HandshakeWithTimestampOffset, the +// SecureConn Read/Write framing edges, performAuth* error branches, +// and the Dial/ListenAndServe surfaces that require a minimal IPC +// daemon mock to reach. +// +// Goal: bring pkg/secure from ~80% to ≥95% statement coverage. These +// tests use only public APIs (or test-only exported helpers already in +// the package). 
+package secure_test + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "io" + "net" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/pilot-protocol/common/ipcutil" + "github.com/pilot-protocol/common/driver" + "github.com/pilot-protocol/common/protocol" + "github.com/pilot-protocol/common/secure" +) + +// --------------------------------------------------------------------------- +// Fake daemon — minimal IPC peer that drives the public driver.Driver API +// so we can reach secure.Dial, Server.ListenAndServe, and Server.handleConn +// through their real call paths. +// +// Wire format: each IPC frame is a length-prefixed buffer (ipcutil.Read / +// Write), and the first byte is the command opcode (see pkg/driver/ipc.go). +// We hard-code the opcodes here because they are private to driver, but +// the wire format is stable. +// --------------------------------------------------------------------------- + +const ( + cmdBind byte = 0x01 + cmdBindOK byte = 0x02 + cmdDial byte = 0x03 + cmdDialOK byte = 0x04 + cmdAccept byte = 0x05 + cmdSend byte = 0x06 + cmdRecv byte = 0x07 + cmdClose byte = 0x08 + cmdCloseOK byte = 0x09 +) + +// shortSocketPath returns a /tmp path short enough for macOS unix socket +// length limit (~104 chars). +func shortSocketPath(t *testing.T) string { + t.Helper() + var b [6]byte + if _, err := rand.Read(b[:]); err != nil { + t.Fatal(err) + } + p := filepath.Join("/tmp", "ss-"+hex.EncodeToString(b[:])+".sock") + t.Cleanup(func() { _ = os.Remove(p) }) + return p +} + +// fakeDaemon implements just enough of the daemon IPC contract for one +// connection at a time. It runs handlers per opcode; the test sets these +// up before connecting via driver.Connect. 
+type fakeDaemon struct { + t *testing.T + ln net.Listener + path string + mu sync.Mutex + conn net.Conn + connSet chan struct{} + handlers map[byte]func(frame []byte) [][]byte + // per-connID send forwarders — used by Dial/Accept happy-paths + bridges map[uint32]chan<- []byte +} + +func newFakeDaemon(t *testing.T) *fakeDaemon { + t.Helper() + p := shortSocketPath(t) + ln, err := net.Listen("unix", p) + if err != nil { + t.Fatalf("listen: %v", err) + } + d := &fakeDaemon{ + t: t, + ln: ln, + path: p, + connSet: make(chan struct{}), + handlers: make(map[byte]func(frame []byte) [][]byte), + bridges: make(map[uint32]chan<- []byte), + } + go d.loop() + return d +} + +func (d *fakeDaemon) loop() { + conn, err := d.ln.Accept() + if err != nil { + return + } + d.mu.Lock() + d.conn = conn + d.mu.Unlock() + close(d.connSet) + for { + frame, err := ipcutil.Read(conn) + if err != nil { + return + } + if len(frame) == 0 { + continue + } + cmd := frame[0] + d.mu.Lock() + h := d.handlers[cmd] + var bridgeCh chan<- []byte + var payload []byte + if cmd == cmdSend && len(frame) >= 5 { + id := binary.BigEndian.Uint32(frame[1:5]) + bridgeCh = d.bridges[id] + payload = append([]byte(nil), frame[5:]...) 
+ } + d.mu.Unlock() + if bridgeCh != nil { + bridgeCh <- payload + continue + } + if h == nil { + continue + } + for _, r := range h(frame) { + _ = ipcutil.Write(conn, r) + } + } +} + +func (d *fakeDaemon) push(frame []byte) { + d.mu.Lock() + c := d.conn + d.mu.Unlock() + if c == nil { + <-d.connSet + d.mu.Lock() + c = d.conn + d.mu.Unlock() + } + _ = ipcutil.Write(c, frame) +} + +func (d *fakeDaemon) onCmd(cmd byte, h func(frame []byte) [][]byte) { + d.mu.Lock() + defer d.mu.Unlock() + d.handlers[cmd] = h +} + +func (d *fakeDaemon) registerBridge(connID uint32, toDriver chan<- []byte) { + d.mu.Lock() + defer d.mu.Unlock() + d.bridges[connID] = toDriver +} + +func (d *fakeDaemon) close() { + _ = d.ln.Close() + select { + case <-d.connSet: + case <-time.After(100 * time.Millisecond): + } + d.mu.Lock() + c := d.conn + d.mu.Unlock() + if c != nil { + _ = c.Close() + } +} + +// pumpRecv emits cmdRecv frames carrying `data` for connID to the driver. +func (d *fakeDaemon) pumpRecv(connID uint32, data []byte) { + frame := make([]byte, 1+4+len(data)) + frame[0] = cmdRecv + binary.BigEndian.PutUint32(frame[1:5], connID) + copy(frame[5:], data) + d.push(frame) +} + +// pumpAccept emits a cmdAccept frame for port `port` with the given conn. +func (d *fakeDaemon) pumpAccept(port uint16, connID uint32) { + addrSize := protocol.AddrSize + frame := make([]byte, 1+2+4+addrSize+2) + frame[0] = cmdAccept + binary.BigEndian.PutUint16(frame[1:3], port) + binary.BigEndian.PutUint32(frame[3:7], connID) + binary.BigEndian.PutUint16(frame[3+4+addrSize:], 0) + d.push(frame) +} + +// bridgeDriverToPipe wires a fakeDaemon conn-side to one half of a net.Pipe. +// Bytes the driver writes via cmdSend(connID, ...) are forwarded to the pipe +// (writeable end) `peer`. Bytes that arrive on `peer` are pushed back to the +// driver as cmdRecv(connID, ...). 
This lets us run secure.Handshake on both +// the driver-Conn side and a raw secure.Handshake on the peer side against +// each other, exercising secure.Dial and ListenAndServe happy paths. +func (d *fakeDaemon) bridgeDriverToPipe(connID uint32, peer net.Conn) { + toPeer := make(chan []byte, 32) + d.registerBridge(connID, toPeer) + go func() { + for data := range toPeer { + if _, err := peer.Write(data); err != nil { + return + } + } + }() + go func() { + buf := make([]byte, 4096) + for { + n, err := peer.Read(buf) + if n > 0 { + d.pumpRecv(connID, buf[:n]) + } + if err != nil { + return + } + } + }() +} + +// failAfterNWrites wraps a net.Conn and returns errClosedPipe after the Nth +// write. Used to surgically trigger sc.Write failures at specific points in +// the authenticated handshake protocol. +type failAfterNWrites struct { + net.Conn + mu sync.Mutex + writes int + failAfter int +} + +func (f *failAfterNWrites) Write(b []byte) (int, error) { + f.mu.Lock() + f.writes++ + n := f.writes + f.mu.Unlock() + if n > f.failAfter { + return 0, io.ErrClosedPipe + } + return f.Conn.Write(b) +} + +// --------------------------------------------------------------------------- +// secure.Dial +// --------------------------------------------------------------------------- + +func TestDialDialAddrErrorPropagates(t *testing.T) { + d := newFakeDaemon(t) + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatalf("connect: %v", err) + } + go func() { + time.Sleep(50 * time.Millisecond) + d.close() + drv.Close() + }() + _, err = secure.Dial(drv, protocol.Addr{Network: 1, Node: 1}) + if err == nil { + t.Fatal("expected dial error after daemon close") + } +} + +func TestDialHandshakeErrorClosesConn(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + + const connID uint32 = 0xCAFEBABE + d.onCmd(cmdDial, func(frame []byte) [][]byte { + resp := make([]byte, 1+4) + resp[0] = cmdDialOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + 
d.onCmd(cmdClose, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdCloseOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + go func() { + time.Sleep(20 * time.Millisecond) + d.pumpRecv(connID, []byte{0x01, 0x02, 0x03}) + time.Sleep(20 * time.Millisecond) + d.close() + }() + + _, err = secure.Dial(drv, protocol.Addr{Network: 1, Node: 1}) + if err == nil { + t.Fatal("expected handshake error after daemon close") + } +} + +func TestDialHappyPathWithBridge(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + + const connID uint32 = 0xAA00AA00 + d.onCmd(cmdDial, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdDialOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + d.onCmd(cmdClose, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdCloseOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + d.bridgeDriverToPipe(connID, pa) + + peerDone := make(chan *secure.SecureConn, 1) + peerErr := make(chan error, 1) + go func() { + sc, err := secure.Handshake(pb, true) + if err != nil { + peerErr <- err + return + } + peerDone <- sc + }() + + sc, err := secure.Dial(drv, protocol.Addr{Network: 1, Node: 1}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + select { + case <-peerDone: + case err := <-peerErr: + t.Fatalf("peer handshake: %v", err) + case <-time.After(3 * time.Second): + t.Fatal("peer handshake timed out") + } + if sc == nil { + t.Fatal("Dial returned nil conn") + } + sc.Close() +} + +// --------------------------------------------------------------------------- +// Server: ListenAndServe + handleConn paths +// 
--------------------------------------------------------------------------- + +func TestListenAndServeBindError(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + s := secure.NewServer(drv, func(_ net.Conn) {}) + errCh := make(chan error, 1) + go func() { errCh <- s.ListenAndServe() }() + time.Sleep(50 * time.Millisecond) + d.close() + drv.Close() + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected error from ListenAndServe") + } + case <-time.After(5 * time.Second): + t.Fatal("ListenAndServe never returned") + } +} + +func TestListenAndServeAcceptErrorReturns(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + + d.onCmd(cmdBind, func(frame []byte) [][]byte { + resp := make([]byte, 3) + resp[0] = cmdBindOK + binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) + return [][]byte{resp} + }) + + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + s := secure.NewServer(drv, func(_ net.Conn) {}) + errCh := make(chan error, 1) + go func() { errCh <- s.ListenAndServe() }() + time.Sleep(50 * time.Millisecond) + d.close() + drv.Close() + select { + case <-errCh: + case <-time.After(5 * time.Second): + t.Fatal("ListenAndServe never returned after daemon close") + } +} + +func TestListenAndServeHandshakeFailsUnauthBranch(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + + const connID uint32 = 0x77777777 + d.onCmd(cmdBind, func(frame []byte) [][]byte { + resp := make([]byte, 3) + resp[0] = cmdBindOK + binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) + return [][]byte{resp} + }) + d.onCmd(cmdClose, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdCloseOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + handlerCalled := make(chan 
struct{}, 1) + s := secure.NewServer(drv, func(_ net.Conn) { handlerCalled <- struct{}{} }) + + go func() { _ = s.ListenAndServe() }() + time.Sleep(80 * time.Millisecond) + d.pumpAccept(protocol.PortSecure, connID) + time.Sleep(40 * time.Millisecond) + d.pumpRecv(connID, []byte{0x01}) + time.Sleep(40 * time.Millisecond) + d.close() + drv.Close() + + select { + case <-handlerCalled: + t.Fatal("handler should not be called on handshake failure") + case <-time.After(200 * time.Millisecond): + } +} + +func TestListenAndServeAuthBranchHandshakeFails(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + + const connID uint32 = 0x88888888 + d.onCmd(cmdBind, func(frame []byte) [][]byte { + resp := make([]byte, 3) + resp[0] = cmdBindOK + binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) + return [][]byte{resp} + }) + d.onCmd(cmdClose, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdCloseOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + _, signer, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + lookup := func(_ uint32) ed25519.PublicKey { return nil } + + handlerCalled := make(chan struct{}, 1) + s := secure.NewAuthServer(drv, func(_ net.Conn) { handlerCalled <- struct{}{} }, 42, signer, lookup) + + go func() { _ = s.ListenAndServe() }() + time.Sleep(80 * time.Millisecond) + d.pumpAccept(protocol.PortSecure, connID) + time.Sleep(40 * time.Millisecond) + d.pumpRecv(connID, []byte{0x01}) + time.Sleep(40 * time.Millisecond) + d.close() + drv.Close() + + select { + case <-handlerCalled: + t.Fatal("handler must not be called on failed handshake") + case <-time.After(200 * time.Millisecond): + } +} + +func TestListenAndServeHandlerInvokedOnSuccess(t *testing.T) { + d := newFakeDaemon(t) + defer d.close() + + const connID uint32 = 0xBB00BB00 + d.onCmd(cmdBind, func(frame []byte) 
[][]byte { + resp := make([]byte, 3) + resp[0] = cmdBindOK + binary.BigEndian.PutUint16(resp[1:3], protocol.PortSecure) + return [][]byte{resp} + }) + d.onCmd(cmdClose, func(frame []byte) [][]byte { + resp := make([]byte, 5) + resp[0] = cmdCloseOK + binary.BigEndian.PutUint32(resp[1:5], connID) + return [][]byte{resp} + }) + + drv, err := driver.Connect(d.path) + if err != nil { + t.Fatal(err) + } + defer drv.Close() + + handlerCh := make(chan struct{}, 1) + s := secure.NewServer(drv, func(_ net.Conn) { handlerCh <- struct{}{} }) + + go func() { _ = s.ListenAndServe() }() + time.Sleep(80 * time.Millisecond) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + d.bridgeDriverToPipe(connID, pa) + + go func() { _, _ = secure.Handshake(pb, false) }() + d.pumpAccept(protocol.PortSecure, connID) + + select { + case <-handlerCh: + case <-time.After(5 * time.Second): + t.Fatal("handler never invoked") + } +} + +// --------------------------------------------------------------------------- +// Handshake error branches — server/client sides closing mid-flight +// --------------------------------------------------------------------------- + +func TestHandshakeServerReadClientKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + b.Close() + _, err := secure.Handshake(a, true) + if err == nil { + t.Fatal("expected read-client-key error") + } + if !strings.Contains(err.Error(), "read client key") { + t.Errorf("err = %v", err) + } +} + +func TestHandshakeClientSendKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + a.Close() + b.Close() + _, err := secure.Handshake(a, false) + if err == nil { + t.Fatal("expected send-client-key error") + } +} + +func TestHandshakeClientReadServerKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + buf := make([]byte, 32) + _, _ = io.ReadFull(b, buf) + b.Close() + }() + _, err := secure.Handshake(a, false) + if err == nil 
{ + t.Fatal("expected read-server-key error") + } + if !strings.Contains(err.Error(), "read server key") { + t.Errorf("err = %v", err) + } +} + +func TestHandshakeServerSendKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + junk := make([]byte, 32) + _, _ = rand.Read(junk) + _, _ = b.Write(junk) + b.Close() + }() + _, err := secure.Handshake(a, true) + if err == nil { + t.Fatal("expected send-server-key error") + } +} + +func TestHandshakeWithLookupServerReadKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + a.Close() + b.Close() + _, err := secure.HandshakeWithLookup(a, true, nil, nil) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithLookupClientWriteFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + a.Close() + b.Close() + _, err := secure.HandshakeWithLookup(a, false, nil, nil) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithLookupClientReadFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + buf := make([]byte, 32) + _, _ = io.ReadFull(b, buf) + b.Close() + }() + _, err := secure.HandshakeWithLookup(a, false, nil, nil) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithLookupServerWriteKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + junk := make([]byte, 32) + _, _ = rand.Read(junk) + _, _ = b.Write(junk) + b.Close() + }() + _, err := secure.HandshakeWithLookup(a, true, nil, nil) + if err == nil { + t.Fatal("expected error") + } +} + +// --------------------------------------------------------------------------- +// HandshakeWithTimestampOffset error branches +// --------------------------------------------------------------------------- + +func TestHandshakeWithTimestampOffsetServerReadKeyFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + a.Close() + b.Close() + _, 
err := secure.HandshakeWithTimestampOffset(a, true, nil, 0) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithTimestampOffsetClientWriteFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + a.Close() + b.Close() + _, err := secure.HandshakeWithTimestampOffset(a, false, nil, 0) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithTimestampOffsetServerWriteFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + junk := make([]byte, 32) + _, _ = rand.Read(junk) + _, _ = b.Write(junk) + b.Close() + }() + _, err := secure.HandshakeWithTimestampOffset(a, true, nil, 0) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithTimestampOffsetClientReadFails(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + buf := make([]byte, 32) + _, _ = io.ReadFull(b, buf) + b.Close() + }() + _, err := secure.HandshakeWithTimestampOffset(a, false, nil, 0) + if err == nil { + t.Fatal("expected error") + } +} + +func TestHandshakeWithTimestampOffsetMutual(t *testing.T) { + secure.ResetReplayCache() + _, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + _, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + srvPub := srvPriv.Public().(ed25519.PublicKey) + cliPub := cliPriv.Public().(ed25519.PublicKey) + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0); chA <- res{sc, err} }() + go func() { sc, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil 
{ + t.Fatalf("server: %v", rA.err) + } + if rB.err != nil { + t.Fatalf("client: %v", rB.err) + } + if rA.sc.PeerNodeID != 2 || rB.sc.PeerNodeID != 1 { + t.Errorf("peer IDs wrong: %d, %d", rA.sc.PeerNodeID, rB.sc.PeerNodeID) + } +} + +func TestHandshakeWithTimestampOffsetUnauthSkipsAuth(t *testing.T) { + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.HandshakeWithTimestampOffset(pa, true, nil, 0); chA <- res{sc, err} }() + go func() { sc, err := secure.HandshakeWithTimestampOffset(pb, false, nil, 0); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake errors: %v / %v", rA.err, rB.err) + } + rA.sc.Close() + rB.sc.Close() +} + +// --------------------------------------------------------------------------- +// Handshake — ECDH low-order pubkey hits the "ecdh:" branch +// --------------------------------------------------------------------------- + +func TestHandshakeECDHFailsOnLowOrderPoint(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + zeros := make([]byte, 32) + _, _ = b.Write(zeros) + buf := make([]byte, 32) + _, _ = io.ReadFull(b, buf) + }() + _, err := secure.Handshake(a, true) + if err == nil { + t.Fatal("expected ecdh low-order error") + } + if !strings.Contains(err.Error(), "ecdh") { + t.Errorf("err = %v", err) + } +} + +func TestHandshakeWithLookupECDHFailsOnLowOrderPoint(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + zeros := make([]byte, 32) + _, _ = b.Write(zeros) + buf := make([]byte, 32) + _, _ = io.ReadFull(b, buf) + }() + _, err := secure.HandshakeWithLookup(a, true, nil, nil) + if err == nil { + t.Fatal("expected ecdh error") + } + if !strings.Contains(err.Error(), "ecdh") { + t.Errorf("err = %v", err) + } +} + +func 
TestHandshakeWithTimestampOffsetECDHFailsOnLowOrderPoint(t *testing.T) { + t.Parallel() + a, b := net.Pipe() + defer a.Close() + defer b.Close() + go func() { + zeros := make([]byte, 32) + _, _ = b.Write(zeros) + buf := make([]byte, 32) + _, _ = io.ReadFull(b, buf) + }() + _, err := secure.HandshakeWithTimestampOffset(a, true, nil, 0) + if err == nil { + t.Fatal("expected ecdh error") + } + if !strings.Contains(err.Error(), "ecdh") { + t.Errorf("err = %v", err) + } +} + +// --------------------------------------------------------------------------- +// SecureConn.Read framing error paths +// --------------------------------------------------------------------------- + +func TestSecureConnReadRejectsMessageTooShort(t *testing.T) { + t.Parallel() + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() + go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake: %v %v", rA.err, rB.err) + } + defer rA.sc.Close() + defer rB.sc.Close() + + go func() { + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], 4) // too short + _, _ = pb.Write(hdr[:]) + _, _ = pb.Write([]byte{0x00, 0x00, 0x00, 0x00}) + }() + + buf := make([]byte, 16) + _, err := rA.sc.Read(buf) + if err == nil { + t.Fatal("expected error on too-short message") + } + if !strings.Contains(err.Error(), "too short") { + t.Errorf("err = %v", err) + } +} + +func TestSecureConnReadRejectsMessageTooLarge(t *testing.T) { + t.Parallel() + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() + go func() { sc, err := 
secure.Handshake(pb, false); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake: %v %v", rA.err, rB.err) + } + defer rA.sc.Close() + defer rB.sc.Close() + + go func() { + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], secure.MaxEncryptedMessageLen+1) + _, _ = pb.Write(hdr[:]) + }() + + buf := make([]byte, 16) + _, err := rA.sc.Read(buf) + if err == nil { + t.Fatal("expected too-large error") + } + if !strings.Contains(err.Error(), "too large") { + t.Errorf("err = %v", err) + } +} + +func TestSecureConnReadDecryptFails(t *testing.T) { + t.Parallel() + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() + go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake: %v %v", rA.err, rB.err) + } + defer rA.sc.Close() + defer rB.sc.Close() + + go func() { + const total = 32 + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], total) + _, _ = pb.Write(hdr[:]) + payload := make([]byte, total) + _, _ = rand.Read(payload) + _, _ = pb.Write(payload) + }() + + buf := make([]byte, 16) + _, err := rA.sc.Read(buf) + if err == nil { + t.Fatal("expected decrypt error") + } + if !strings.Contains(err.Error(), "decrypt") { + t.Errorf("err = %v", err) + } +} + +func TestSecureConnReadLengthPrefixError(t *testing.T) { + t.Parallel() + pa, pb := net.Pipe() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() + go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake: %v %v", rA.err, 
rB.err) + } + pb.Close() + rB.sc.Close() + defer pa.Close() + defer rA.sc.Close() + _, err := rA.sc.Read(make([]byte, 16)) + if err == nil { + t.Fatal("expected error reading length") + } +} + +func TestSecureConnReadCiphertextReadError(t *testing.T) { + t.Parallel() + pa, pb := net.Pipe() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() + go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake: %v %v", rA.err, rB.err) + } + + go func() { + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], 32) + _, _ = pb.Write(hdr[:]) + pb.Close() + }() + defer pa.Close() + defer rA.sc.Close() + _, err := rA.sc.Read(make([]byte, 16)) + if err == nil { + t.Fatal("expected error reading ciphertext body") + } + rB.sc.Close() +} + +// --------------------------------------------------------------------------- +// SecureConn.Write error paths +// --------------------------------------------------------------------------- + +func TestSecureConnWriteErrorOnClosedConn(t *testing.T) { + t.Parallel() + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + type res struct { + sc *secure.SecureConn + err error + } + chA := make(chan res, 1) + chB := make(chan res, 1) + go func() { sc, err := secure.Handshake(pa, true); chA <- res{sc, err} }() + go func() { sc, err := secure.Handshake(pb, false); chB <- res{sc, err} }() + rA := <-chA + rB := <-chB + if rA.err != nil || rB.err != nil { + t.Fatalf("handshake: %v %v", rA.err, rB.err) + } + pa.Close() + _, err := rA.sc.Write([]byte("oops")) + if err == nil { + t.Fatal("expected write error after raw conn close") + } + rB.sc.Close() +} + +// --------------------------------------------------------------------------- +// performAuth* error paths — VerifyAuthFrame failures on each side +// 
--------------------------------------------------------------------------- + +func TestAuthServerVerifyFailsClientPasses(t *testing.T) { + secure.ResetReplayCache() + srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + _, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: wrongPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { _, err := secure.Handshake(pa, true, cfgServer); errA <- err }() + go func() { _, err := secure.Handshake(pb, false, cfgClient); errB <- err }() + <-errA + <-errB +} + +func TestAuthClientVerifyFailsServerPasses(t *testing.T) { + secure.ResetReplayCache() + _, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: wrongPub} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { _, err := secure.Handshake(pa, true, cfgServer); errA <- err }() + go func() { _, err := secure.Handshake(pb, false, cfgClient); errB <- err }() + <-errA + <-errB +} + +func TestAuthOffsetClientVerifyFails(t *testing.T) { + secure.ResetReplayCache() + _, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, 
Signer: cliPriv, PeerPubKey: wrongPub} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0) + errB <- err + }() + <-errA + <-errB +} + +func TestAuthOffsetServerVerifyFails(t *testing.T) { + secure.ResetReplayCache() + srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + _, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: wrongPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0) + errB <- err + }() + <-errA + <-errB +} + +func TestAuthLookupServerVerifyFails(t *testing.T) { + secure.ResetReplayCache() + serverPub, serverPriv, _ := ed25519.GenerateKey(rand.Reader) + _, clientPriv, _ := ed25519.GenerateKey(rand.Reader) + wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) + const srvID, cliID = uint32(101), uint32(202) + + srvLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == cliID { + return wrongPub + } + return nil + } + cliLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == srvID { + return serverPub + } + return nil + } + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithLookup(pa, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithLookup(pb, 
false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) + errB <- err + }() + <-errA + <-errB +} + +func TestAuthLookupClientVerifyFails(t *testing.T) { + secure.ResetReplayCache() + _, serverPriv, _ := ed25519.GenerateKey(rand.Reader) + clientPub, clientPriv, _ := ed25519.GenerateKey(rand.Reader) + wrongPub, _, _ := ed25519.GenerateKey(rand.Reader) + const srvID, cliID = uint32(901), uint32(902) + + srvLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == cliID { + return clientPub + } + return nil + } + cliLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == srvID { + return wrongPub + } + return nil + } + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithLookup(pa, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithLookup(pb, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) + errB <- err + }() + <-errA + cliErr := <-errB + if cliErr == nil { + t.Fatal("expected client verify error") + } +} + +// --------------------------------------------------------------------------- +// performAuth* — post-ECDH auth-frame write failures +// --------------------------------------------------------------------------- + +func TestAuthServerPostECDHWriteFails(t *testing.T) { + secure.ResetReplayCache() + srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + wrapped := &failAfterNWrites{Conn: pa, failAfter: 1} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { _, err := 
secure.Handshake(wrapped, true, cfgServer); errA <- err }() + go func() { _, err := secure.Handshake(pb, false, cfgClient); errB <- err }() + srvErr := <-errA + <-errB + if srvErr == nil { + t.Fatal("expected server auth-write error") + } +} + +func TestAuthClientPostVerifyWriteFails(t *testing.T) { + secure.ResetReplayCache() + srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + wrapped := &failAfterNWrites{Conn: pb, failAfter: 1} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { _, err := secure.Handshake(pa, true, cfgServer); errA <- err }() + go func() { _, err := secure.Handshake(wrapped, false, cfgClient); errB <- err }() + <-errA + cliErr := <-errB + if cliErr == nil { + t.Fatal("expected client post-verify write error") + } +} + +func TestAuthOffsetServerPostECDHWriteFails(t *testing.T) { + secure.ResetReplayCache() + srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + wrapped := &failAfterNWrites{Conn: pa, failAfter: 1} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithTimestampOffset(wrapped, true, cfgServer, 0) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithTimestampOffset(pb, false, cfgClient, 0) + errB <- err + }() + srvErr := <-errA + <-errB + if srvErr == nil { + t.Fatal("expected server auth-write error") + } +} + +func 
TestAuthOffsetClientPostVerifyWriteFails(t *testing.T) { + secure.ResetReplayCache() + srvPub, srvPriv, _ := ed25519.GenerateKey(rand.Reader) + cliPub, cliPriv, _ := ed25519.GenerateKey(rand.Reader) + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub} + cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub} + + wrapped := &failAfterNWrites{Conn: pb, failAfter: 1} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithTimestampOffset(pa, true, cfgServer, 0) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithTimestampOffset(wrapped, false, cfgClient, 0) + errB <- err + }() + <-errA + cliErr := <-errB + if cliErr == nil { + t.Fatal("expected client post-verify write error") + } +} + +func TestAuthLookupServerPostECDHWriteFails(t *testing.T) { + secure.ResetReplayCache() + serverPub, serverPriv, _ := ed25519.GenerateKey(rand.Reader) + clientPub, clientPriv, _ := ed25519.GenerateKey(rand.Reader) + const srvID, cliID = uint32(701), uint32(702) + srvLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == cliID { + return clientPub + } + return nil + } + cliLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == srvID { + return serverPub + } + return nil + } + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + wrapped := &failAfterNWrites{Conn: pa, failAfter: 1} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithLookup(wrapped, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithLookup(pb, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) + errB <- err + }() + srvErr := <-errA + <-errB + if srvErr == nil { + t.Fatal("expected server auth-write error") + } +} + +func 
TestAuthLookupClientPostVerifyWriteFails(t *testing.T) { + secure.ResetReplayCache() + serverPub, serverPriv, _ := ed25519.GenerateKey(rand.Reader) + clientPub, clientPriv, _ := ed25519.GenerateKey(rand.Reader) + const srvID, cliID = uint32(801), uint32(802) + srvLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == cliID { + return clientPub + } + return nil + } + cliLookup := func(nodeID uint32) ed25519.PublicKey { + if nodeID == srvID { + return serverPub + } + return nil + } + + pa, pb := net.Pipe() + defer pa.Close() + defer pb.Close() + + wrapped := &failAfterNWrites{Conn: pb, failAfter: 1} + + errA := make(chan error, 1) + errB := make(chan error, 1) + go func() { + _, err := secure.HandshakeWithLookup(pa, true, &secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv}, srvLookup) + errA <- err + }() + go func() { + _, err := secure.HandshakeWithLookup(wrapped, false, &secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv}, cliLookup) + errB <- err + }() + <-errA + cliErr := <-errB + if cliErr == nil { + t.Fatal("expected client post-verify write error") + } +} + +// --------------------------------------------------------------------------- +// CheckAndRecordNonce cap — fill the cache to maxReplayCacheEntries then +// confirm the next insert errors. +// --------------------------------------------------------------------------- + +func TestCheckAndRecordNonceCacheFull(t *testing.T) { + secure.ResetReplayCache() + for i := 0; i < 100000; i++ { + var n [16]byte + binary.BigEndian.PutUint64(n[:8], uint64(i)) + secure.InjectReplayNonce(n) + } + var fresh [16]byte + binary.BigEndian.PutUint64(fresh[:8], 0xFFFFFFFFFFFFFFFF) + err := secure.CheckAndRecordNonce(fresh) + if err == nil || !strings.Contains(err.Error(), "cache full") { + t.Fatalf("expected cache-full error, got %v", err) + } + secure.ResetReplayCache() +} + +// --------------------------------------------------------------------------- +// Sanity: AES-GCM nonce size assumption matches. 
+// --------------------------------------------------------------------------- + +func TestAesGcmNonceSizeIs12(t *testing.T) { + t.Parallel() + key := make([]byte, 32) + block, err := aes.NewCipher(key) + if err != nil { + t.Fatal(err) + } + g, err := cipher.NewGCM(block) + if err != nil { + t.Fatal(err) + } + if g.NonceSize() != 12 { + t.Errorf("nonce size = %d, want 12", g.NonceSize()) + } +} + +var _ = errors.New +var _ = secure.AuthFrameLen diff --git a/secure/zz_handshake_lookup_test.go b/secure/zz_handshake_lookup_test.go new file mode 100644 index 0000000..9d044a0 --- /dev/null +++ b/secure/zz_handshake_lookup_test.go @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package secure_test + +import ( + "crypto/ed25519" + "crypto/rand" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/pilot-protocol/common/secure" +) + +// runLookupHandshake connects two net.Pipe ends, runs HandshakeWithLookup on +// each side concurrently, and returns both resulting SecureConns (or errors). +// NOTE: callers must NOT mark tests t.Parallel() since HandshakeWithLookup +// mutates the global replay cache (see iter 12 lesson). 
+func runLookupHandshake(t *testing.T, serverCfg, clientCfg *secure.HandshakeConfig, serverLookup, clientLookup secure.PeerPubKeyLookup) (*secure.SecureConn, error, *secure.SecureConn, error) {
+	t.Helper()
+	s, c := net.Pipe()
+	type result struct {
+		sc  *secure.SecureConn
+		err error
+	}
+	srvCh := make(chan result, 1)
+	cliCh := make(chan result, 1)
+	go func() {
+		sc, err := secure.HandshakeWithLookup(s, true, serverCfg, serverLookup)
+		srvCh <- result{sc, err}
+	}()
+	go func() {
+		sc, err := secure.HandshakeWithLookup(c, false, clientCfg, clientLookup)
+		cliCh <- result{sc, err}
+	}()
+	// Watchdog: a stalled handshake would block the receives below forever,
+	// so after 5s close both pipe ends to unblock any pending pipe I/O.
+	watchdog := time.AfterFunc(5*time.Second, func() { s.Close(); c.Close() })
+	srv := <-srvCh
+	cli := <-cliCh
+	// Stop reports false once the watchdog has already fired.
+	if !watchdog.Stop() {
+		t.Fatal("handshake timed out")
+	}
+	return srv.sc, srv.err, cli.sc, cli.err
+}
+
+func newEd25519KeyPair(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
+	t.Helper()
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return pub, priv
+}
+
+func TestHandshakeWithLookupHappyPath(t *testing.T) {
+	secure.ResetReplayCache()
+	serverPub, serverPriv := newEd25519KeyPair(t)
+	clientPub, clientPriv := newEd25519KeyPair(t)
+	const srvID, cliID = uint32(0x10001), uint32(0x20002)
+
+	srvLookup := func(nodeID uint32) ed25519.PublicKey {
+		if nodeID == cliID {
+			return clientPub
+		}
+		return nil
+	}
+	cliLookup := func(nodeID uint32) ed25519.PublicKey {
+		if nodeID == srvID {
+			return serverPub
+		}
+		return nil
+	}
+
+	srvSC, srvErr, cliSC, cliErr := runLookupHandshake(t,
+		&secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv},
+		&secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv},
+		srvLookup, cliLookup)
+
+	if srvErr != nil {
+		t.Fatalf("server handshake: %v", srvErr)
+	}
+	if cliErr != nil {
+		t.Fatalf("client handshake: %v", cliErr)
+	}
+	defer srvSC.Close()
+	defer cliSC.Close()
+	if srvSC.PeerNodeID != cliID {
+		t.Errorf("server saw peer=%d, want %d", srvSC.PeerNodeID, cliID)
+	}
+	if cliSC.PeerNodeID != srvID {
+		t.Errorf("client saw peer=%d, want %d", cliSC.PeerNodeID, srvID)
+	}
+
+	// End-to-end data exchange proves the derived keys match on both sides.
+	done := make(chan error, 1)
+	go func() {
+		buf := make([]byte, 5)
+		if _, err := io.ReadFull(cliSC, buf); err != nil {
+			done <- err
+			return
+		}
+		if string(buf) != "ping!" {
+			done <- errors.New("bad payload: " + string(buf))
+			return
+		}
+		done <- nil
+	}()
+	if _, err := srvSC.Write([]byte("ping!")); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+	select {
+	case err := <-done:
+		if err != nil {
+			t.Fatalf("read: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Fatal("timeout reading from encrypted stream")
+	}
+}
+
+func TestHandshakeWithLookupServerRejectsUnknownPeer(t *testing.T) {
+	secure.ResetReplayCache()
+	serverPub, serverPriv := newEd25519KeyPair(t)
+	_, clientPriv := newEd25519KeyPair(t)
+	const srvID, cliID = uint32(0x30003), uint32(0x40004)
+
+	srvLookup := func(_ uint32) ed25519.PublicKey { return nil } // unknown
+	cliLookup := func(nodeID uint32) ed25519.PublicKey {
+		if nodeID == srvID {
+			return serverPub
+		}
+		return nil
+	}
+
+	_, srvErr, _, _ := runLookupHandshake(t,
+		&secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv},
+		&secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv},
+		srvLookup, cliLookup)
+
+	// Server reads client's auth frame AFTER writing its own, then looks up
+	// the client's pubkey and rejects on nil. Client has already completed
+	// its side (read server frame, verified, wrote its own) so it returns
+	// without error — only the server-side error surfaces.
+	if srvErr == nil {
+		t.Fatal("server should have rejected unknown peer")
+	}
+}
+
+func TestHandshakeWithLookupClientRejectsUnknownServer(t *testing.T) {
+	secure.ResetReplayCache()
+	_, serverPriv := newEd25519KeyPair(t)
+	clientPub, clientPriv := newEd25519KeyPair(t)
+	const srvID, cliID = uint32(0x50005), uint32(0x60006)
+
+	srvLookup := func(nodeID uint32) ed25519.PublicKey {
+		if nodeID == cliID {
+			return clientPub
+		}
+		return nil
+	}
+	cliLookup := func(_ uint32) ed25519.PublicKey { return nil }
+
+	_, srvErr, _, cliErr := runLookupHandshake(t,
+		&secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv},
+		&secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv},
+		srvLookup, cliLookup)
+
+	if cliErr == nil {
+		t.Fatal("client should have rejected unknown server")
+	}
+	if srvErr == nil {
+		t.Fatal("server should have failed after client closed")
+	}
+}
+
+func TestHandshakeWithLookupBadSignatureRejected(t *testing.T) {
+	secure.ResetReplayCache()
+	serverPub, serverPriv := newEd25519KeyPair(t)
+	_, clientPriv := newEd25519KeyPair(t)
+	// Third unrelated pubkey — server will look up client by nodeID but
+	// get a key that doesn't match the client's actual signer.
+	wrongPub, _ := newEd25519KeyPair(t)
+	const srvID, cliID = uint32(0x70007), uint32(0x80008)
+
+	srvLookup := func(nodeID uint32) ed25519.PublicKey {
+		if nodeID == cliID {
+			return wrongPub // signature will fail to verify
+		}
+		return nil
+	}
+	cliLookup := func(nodeID uint32) ed25519.PublicKey {
+		if nodeID == srvID {
+			return serverPub
+		}
+		return nil
+	}
+
+	_, srvErr, _, _ := runLookupHandshake(t,
+		&secure.HandshakeConfig{NodeID: srvID, Signer: serverPriv},
+		&secure.HandshakeConfig{NodeID: cliID, Signer: clientPriv},
+		srvLookup, cliLookup)
+
+	if srvErr == nil {
+		t.Fatal("server should have rejected bad signature")
+	}
+}
+
+func TestHandshakeWithLookupNoAuthSkipsLookup(t *testing.T) {
+	secure.ResetReplayCache()
+	s, c := net.Pipe()
+	defer s.Close()
+	defer c.Close()
+	srvCh := make(chan error, 1)
+	cliCh := make(chan error, 1)
+	// No signer in cfg — auth is skipped, lookup is never called.
+	go func() {
+		_, err := secure.HandshakeWithLookup(s, true, nil, nil)
+		srvCh <- err
+	}()
+	go func() {
+		_, err := secure.HandshakeWithLookup(c, false, nil, nil)
+		cliCh <- err
+	}()
+	select {
+	case err := <-srvCh:
+		if err != nil {
+			t.Fatalf("server: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Fatal("server handshake timed out")
+	}
+	select {
+	case err := <-cliCh:
+		if err != nil {
+			t.Fatalf("client: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Fatal("client handshake timed out")
+	}
+}
+
+// ---------- Server constructors ----------
+
+func TestNewServerSetsFields(t *testing.T) {
+	t.Parallel()
+	called := false
+	h := func(_ net.Conn) { called = true }
+	s := secure.NewServer(nil, h)
+	if s.Driver() != nil {
+		t.Error("driver should be nil")
+	}
+	if s.Handler() == nil {
+		t.Fatal("handler nil")
+	}
+	if s.AuthSigner() != nil || s.AuthNodeID() != 0 || s.PeerLookup() != nil {
+		t.Error("unauth server should not populate auth fields")
+	}
+	// Sanity: handler invocable.
+	s.Handler()(nil)
+	if !called {
+		t.Error("handler not invoked")
+	}
+}
+
+func TestNewAuthServerSetsFields(t *testing.T) {
+	t.Parallel()
+	_, priv := newEd25519KeyPair(t)
+	lookup := func(_ uint32) ed25519.PublicKey { return nil }
+	h := func(_ net.Conn) {}
+	s := secure.NewAuthServer(nil, h, 0xABCD1234, priv, lookup)
+	if s.AuthNodeID() != 0xABCD1234 {
+		t.Errorf("authNodeID = %#x", s.AuthNodeID())
+	}
+	if s.AuthSigner() == nil {
+		t.Error("authSigner nil")
+	}
+	if s.PeerLookup() == nil {
+		t.Error("peerLookup nil")
+	}
+}
diff --git a/secure/zz_secure_test.go b/secure/zz_secure_test.go
new file mode 100644
index 0000000..21d83e4
--- /dev/null
+++ b/secure/zz_secure_test.go
@@ -0,0 +1,586 @@
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package secure_test
+
+import (
+	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"net"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/pilot-protocol/common/secure"
+)
+
+// pipePair returns two connected net.Conn endpoints (in-process pipe).
+func pipePair() (net.Conn, net.Conn) {
+	return net.Pipe()
+}
+
+// genIdentity returns an Ed25519 keypair.
+func genIdentity(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
+	t.Helper()
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return pub, priv
+}
+
+// handshakeBoth runs both ends of Handshake concurrently and returns the server and client SecureConns, failing the test on any handshake error.
func handshakeBoth(t *testing.T, a, b net.Conn, cfgA, cfgB *secure.HandshakeConfig) (*secure.SecureConn, *secure.SecureConn) {
	// Runs secure.Handshake on both ends of the pair concurrently and returns
	// the two resulting connections (a's first, b's second). A nil config
	// selects the two-argument Handshake call for that side. Side a is
	// invoked with `true` and reported as "server" on failure; side b is
	// invoked with `false` and reported as "client".
	t.Helper()
	type result struct {
		sc  *secure.SecureConn
		err error
	}
	// Buffered so each goroutine can deliver its result and exit even though
	// the results are collected one after the other below.
	chA := make(chan result, 1)
	chB := make(chan result, 1)
	go func() {
		var sc *secure.SecureConn
		var err error
		if cfgA != nil {
			sc, err = secure.Handshake(a, true, cfgA)
		} else {
			sc, err = secure.Handshake(a, true)
		}
		chA <- result{sc, err}
	}()
	go func() {
		var sc *secure.SecureConn
		var err error
		if cfgB != nil {
			sc, err = secure.Handshake(b, false, cfgB)
		} else {
			sc, err = secure.Handshake(b, false)
		}
		chB <- result{sc, err}
	}()
	// Wait for both sides before judging either, so a failure on one side
	// is reported only after the other side has also finished.
	rA := <-chA
	rB := <-chB
	if rA.err != nil {
		t.Fatalf("server handshake: %v", rA.err)
	}
	if rB.err != nil {
		t.Fatalf("client handshake: %v", rB.err)
	}
	return rA.sc, rB.sc
}

// ---------------------------------------------------------------------------
// Unauthenticated handshake + Read/Write round-trip
// ---------------------------------------------------------------------------

// TestUnauthenticatedHandshakeRoundTrip handshakes with no config on either
// side, writes one message from the client and checks that the server reads
// back the same bytes.
func TestUnauthenticatedHandshakeRoundTrip(t *testing.T) {
	t.Parallel()
	a, b := pipePair()
	defer a.Close()
	defer b.Close()

	server, client := handshakeBoth(t, a, b, nil, nil)

	msg := []byte("hello secure world")
	// Write from a goroutine so the Read below can run concurrently with it.
	go func() { _, _ = client.Write(msg) }()

	got := make([]byte, len(msg))
	n, err := server.Read(got)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got[:n], msg) {
		t.Errorf("got %q, want %q", got[:n], msg)
	}
}

// TestEncryptedReadBuffersLeftover sends 1024 bytes in a single Write and
// reads them with a 100-byte buffer first, then drains the rest; it expects
// exactly 100 bytes from the first Read and 924 bytes from the following ones.
func TestEncryptedReadBuffersLeftover(t *testing.T) {
	t.Parallel()
	a, b := pipePair()
	defer a.Close()
	defer b.Close()
	server, client := handshakeBoth(t, a, b, nil, nil)

	msg := bytes.Repeat([]byte("X"), 1024)
	go func() { _, _ = client.Write(msg) }()

	// Read with a small buffer to force leftover-buffering path
	small := make([]byte, 100)
	n, err := server.Read(small)
	if err != nil {
		t.Fatal(err)
	}
	if n != 100 {
		t.Errorf("first read = %d, want 100", n)
	}
	// Drain remaining 924 bytes via subsequent reads (should hit readBuf)
	rest := make([]byte, 1024)
	off := 0
	for off < 924 {
		k, err := server.Read(rest[off:])
		if err != nil {
			t.Fatal(err)
		}
		off += k
	}
	if off != 924 {
		t.Errorf("drained %d, want 924", off)
	}
}

// TestEncryptedHidesPlaintextOnWire writes a marker string from the client
// and checks that the server's Read returns the same bytes.
//
// NOTE(review): despite the name, nothing here captures or inspects the raw
// bytes on the underlying connection; the test only proves the
// encrypt/decrypt round-trip. A real wire tap would be needed to assert that
// the marker is absent from the ciphertext.
func TestEncryptedHidesPlaintextOnWire(t *testing.T) {
	a, b := pipePair()
	defer a.Close()
	defer b.Close()
	server, client := handshakeBoth(t, a, b, nil, nil)

	// Send a recognizable plaintext via client and verify that the
	// server-side decryption returns the same bytes (i.e., the encrypted
	// path is exercised end to end).
	plain := []byte("PLAINTEXT-MARKER-12345")
	go func() { _, _ = client.Write(plain) }()
	got := make([]byte, len(plain))
	if _, err := server.Read(got); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, plain) {
		t.Errorf("decryption mismatch: got %q want %q", got, plain)
	}
	// No further assertion: the test ends after the single successful Read.
}

// TestNonceUniquenessAcrossWrites sends N one-byte messages sequentially and
// checks that the server reads them back in order.
//
// NOTE(review): nonces are never inspected directly; the check is indirect
// (every message must decrypt to the expected byte).
func TestNonceUniquenessAcrossWrites(t *testing.T) {
	a, b := pipePair()
	defer a.Close()
	defer b.Close()
	server, client := handshakeBoth(t, a, b, nil, nil)

	// Send N distinct messages from client; server reads them. Presumably
	// each Write advances the nonce counter inside secure.SecureConn, so a
	// reused nonce would surface as a failed or mismatched Read.
	const N = 5
	go func() {
		for i := 0; i < N; i++ {
			_, _ = client.Write([]byte{byte(i)})
		}
	}()
	for i := 0; i < N; i++ {
		buf := make([]byte, 1)
		if _, err := server.Read(buf); err != nil {
			t.Fatal(err)
		}
		if buf[0] != byte(i) {
			t.Errorf("msg %d: got %d", i, buf[0])
		}
	}
}

// ---------------------------------------------------------------------------
// Authenticated handshake
// ---------------------------------------------------------------------------

// TestAuthenticatedHandshakeMutual handshakes with matching key pairs on both
// sides (node IDs 100 and 200) and checks that each side's PeerNodeID is the
// other side's configured NodeID.
func TestAuthenticatedHandshakeMutual(t *testing.T) {
	// Start from an empty replay cache so earlier tests cannot interfere.
	secure.ResetReplayCache()

	srvPub, srvPriv := genIdentity(t)
	cliPub, cliPriv := genIdentity(t)

	a, b := pipePair()
	defer a.Close()
	defer b.Close()

	cfgServer := &secure.HandshakeConfig{NodeID: 100, Signer: srvPriv, PeerPubKey: cliPub}
	cfgClient := &secure.HandshakeConfig{NodeID: 200, Signer: cliPriv, PeerPubKey: srvPub}

	server, client := handshakeBoth(t, a, b, cfgServer, cfgClient)

	if server.PeerNodeID != 200 {
		t.Errorf("server PeerNodeID = %d, want 200", server.PeerNodeID)
	}
	if client.PeerNodeID != 100 {
		t.Errorf("client PeerNodeID = %d, want 100", client.PeerNodeID)
	}
}

// TestAuthenticatedHandshakeWrongPeerKeyFails configures both sides to expect
// a public key that matches neither signer and requires at least one side of
// the handshake to return an error.
func TestAuthenticatedHandshakeWrongPeerKeyFails(t *testing.T) {
	secure.ResetReplayCache()

	_, srvPriv := genIdentity(t)
	_, cliPriv := genIdentity(t)
	wrongPub, _ := genIdentity(t) // both sides expect this key; neither side signs with it

	a, b := pipePair()
	defer a.Close()
	defer b.Close()

	cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: wrongPub}
	cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: wrongPub}

	// handshakeBoth is not used here because it fails the test on any
	// handshake error, and an error is the expected outcome.
	type r struct{ err error }
	chA := make(chan r, 1)
	chB := make(chan r, 1)
	go func() { _, err := secure.Handshake(a, true, cfgServer); chA <- r{err} }()
	go func() { _, err := secure.Handshake(b, false, cfgClient); chB <- r{err} }()

	rA := <-chA
	rB := <-chB
	if rA.err == nil && rB.err == nil {
		t.Fatal("expected at least one side to fail with wrong PeerPubKey")
	}
}

// ---------------------------------------------------------------------------
// Replay cache and timestamp skew
// ---------------------------------------------------------------------------

// TestHandshakeWithTimestampOffsetExpiredFails runs the client side with a
// 10-second timestamp offset and expects the server side to fail with an
// error mentioning "timestamp expired" or "skew".
func TestHandshakeWithTimestampOffsetExpiredFails(t *testing.T) {
	secure.ResetReplayCache()
	srvPub, srvPriv := genIdentity(t)
	cliPub, cliPriv := genIdentity(t)

	a, b := pipePair()
	defer a.Close()
	defer b.Close()

	cfgServer := &secure.HandshakeConfig{NodeID: 1, Signer: srvPriv, PeerPubKey: cliPub}
	cfgClient := &secure.HandshakeConfig{NodeID: 2, Signer: cliPriv, PeerPubKey: srvPub}

	// Server uses normal timestamp; client uses 10s offset → exceeds 5s skew.
	// NOTE(review): the 5s limit is taken from this comment; the constant
	// itself lives in the secure package and is not visible here.
	chA := make(chan error, 1)
	chB := make(chan error, 1)
	go func() {
		_, err := secure.HandshakeWithTimestampOffset(a, true, cfgServer, 0)
		chA <- err
	}()
	go func() {
		_, err := secure.HandshakeWithTimestampOffset(b, false, cfgClient, 10*time.Second)
		chB <- err
	}()

	errA := <-chA
	errB := <-chB
	// Server reads client's frame and finds it exceeds skew.
	if errA == nil {
		t.Fatal("expected server to reject expired auth")
	}
	if !strings.Contains(errA.Error(), "timestamp expired") &&
		!strings.Contains(errA.Error(), "skew") {
		t.Errorf("unexpected err: %v", errA)
	}
	_ = errB // client side may also error or close
}

// TestReplayCacheRejectsRepeat records a random nonce once and expects the
// second CheckAndRecordNonce call with the same nonce to return an error
// containing "replay".
func TestReplayCacheRejectsRepeat(t *testing.T) {
	secure.ResetReplayCache()

	var nonce [16]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		t.Fatal(err)
	}
	if err := secure.CheckAndRecordNonce(nonce); err != nil {
		t.Fatalf("first record: %v", err)
	}
	// Same nonce again → replay
	err := secure.CheckAndRecordNonce(nonce)
	if err == nil || !strings.Contains(err.Error(), "replay") {
		t.Fatalf("expected replay error, got %v", err)
	}
}

// TestCheckReplayNonceDoesNotRecord checks that CheckReplayNonce is a pure
// query: it returns nil for a fresh nonce, does not prevent a subsequent
// CheckAndRecordNonce from succeeding, and returns an error once the nonce
// has been recorded.
func TestCheckReplayNonceDoesNotRecord(t *testing.T) {
	secure.ResetReplayCache()

	var nonce [16]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		t.Fatal(err)
	}
	// CheckReplayNonce should report not-present (nil err) without inserting.
	if err := secure.CheckReplayNonce(nonce); err != nil {
		t.Fatalf("expected fresh nonce nil err, got %v", err)
	}
	// Recording must still succeed, which shows the check above did not insert.
	if err := secure.CheckAndRecordNonce(nonce); err != nil {
		t.Fatal(err)
	}
	// Now CheckReplayNonce should report replay
	if err := secure.CheckReplayNonce(nonce); err == nil {
		t.Fatal("expected replay error from CheckReplayNonce")
	}
}

// TestInjectReplayNonceTriggersReplay checks that a nonce inserted with
// InjectReplayNonce makes CheckAndRecordNonce return an error.
func TestInjectReplayNonceTriggersReplay(t *testing.T) {
	secure.ResetReplayCache()

	var nonce [16]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		t.Fatal(err)
	}
	secure.InjectReplayNonce(nonce)
	if err := secure.CheckAndRecordNonce(nonce); err == nil {
		t.Fatal("expected replay after inject")
	}
}

// ---------------------------------------------------------------------------
// BuildAuthSignMessage
// ---------------------------------------------------------------------------

// TestBuildAuthSignMessageStable pins the byte layout of the signed message:
// the 18-byte "pilot-secure-auth:" prefix, a big-endian uint32 node ID, the
// 32-byte key, a big-endian uint64 timestamp and the 16-byte nonce.
func TestBuildAuthSignMessageStable(t *testing.T) {
	x25519 := bytes.Repeat([]byte{0xAB}, 32)
	var nonce [16]byte
	for i := range nonce {
		nonce[i] = byte(i)
	}
	got := secure.BuildAuthSignMessage(0xDEADBEEF, x25519, 0x1122334455667788, nonce)
	// Layout: domain(18) + nodeID(4) + pub(32) + ts(8) + nonce(16) = 78
	if len(got) != 18+4+32+8+16 {
		t.Errorf("len = %d, want 78", len(got))
	}
	if !bytes.HasPrefix(got, []byte("pilot-secure-auth:")) {
		t.Errorf("missing domain prefix: %q", got[:18])
	}
	if id := binary.BigEndian.Uint32(got[18:22]); id != 0xDEADBEEF {
		t.Errorf("nodeID encoding wrong: %x", id)
	}
	if !bytes.Equal(got[22:54], x25519) {
		t.Errorf("pubkey not embedded correctly")
	}
	if ts := binary.BigEndian.Uint64(got[54:62]); ts != 0x1122334455667788 {
		t.Errorf("timestamp encoding wrong: %x", ts)
	}
	if !bytes.Equal(got[62:78], nonce[:]) {
		t.Errorf("nonce not embedded correctly")
	}
}

// TestBuildAuthSignMessageDifferentInputsDiffer checks that two messages that
// differ only in the first nonce byte are not byte-equal.
func TestBuildAuthSignMessageDifferentInputsDiffer(t *testing.T) {
	x := bytes.Repeat([]byte{0x00}, 32)
	var n1, n2 [16]byte
	n2[0] = 1
	a := secure.BuildAuthSignMessage(1, x, 100, n1)
	b := secure.BuildAuthSignMessage(1, x, 100, n2)
	if bytes.Equal(a, b) {
		t.Fatal("messages with different nonces should differ")
	}
}

// ---------------------------------------------------------------------------
// VerifyAuthFrame
// ---------------------------------------------------------------------------

// TestVerifyAuthFrameWrongSize passes a 10-byte frame and expects an error
// containing "wrong size".
func TestVerifyAuthFrameWrongSize(t *testing.T) {
	_, err := secure.VerifyAuthFrame(make([]byte, 10), nil, nil, time.Now())
	if err == nil || !strings.Contains(err.Error(), "wrong size") {
		t.Fatalf("expected wrong-size err, got %v", err)
	}
}

// TestVerifyAuthFrameExpiredTimestamp builds a correctly signed frame whose
// timestamp is one hour in the past and expects an error containing
// "expired".
func TestVerifyAuthFrameExpiredTimestamp(t *testing.T) {
	secure.ResetReplayCache()
	pub, priv := genIdentity(t)
	x25519 := bytes.Repeat([]byte{0xAB}, 32)
	expiredTS := uint64(time.Now().Add(-time.Hour).Unix())
	var nonce [16]byte
	rand.Read(nonce[:])

	// Frame layout as written here: nodeID[0:4] | timestamp[4:12] |
	// nonce[12:28] | signature[28:92].
	frame := make([]byte, secure.AuthFrameLen)
	binary.BigEndian.PutUint32(frame[0:4], 42)
	binary.BigEndian.PutUint64(frame[4:12], expiredTS)
	copy(frame[12:28], nonce[:])
	sig := ed25519.Sign(priv, secure.BuildAuthSignMessage(42, x25519, expiredTS, nonce))
	copy(frame[28:92], sig)

	_, err := secure.VerifyAuthFrame(frame, pub, x25519, time.Now())
	if err == nil || !strings.Contains(err.Error(), "expired") {
		t.Fatalf("expected expired err, got %v", err)
	}
}

// TestVerifyAuthFrameReplayDetected verifies the same validly signed frame
// twice; the first call must succeed and the second must return an error
// containing "replay".
func TestVerifyAuthFrameReplayDetected(t *testing.T) {
	secure.ResetReplayCache()
	pub, priv := genIdentity(t)
	x25519 := bytes.Repeat([]byte{0xAB}, 32)
	now := time.Now()
	ts := uint64(now.Unix())
	var nonce [16]byte
	rand.Read(nonce[:])

	// build returns a fresh copy of the same frame (same node ID, timestamp
	// and nonce) on every call.
	build := func() []byte {
		frame := make([]byte, secure.AuthFrameLen)
		binary.BigEndian.PutUint32(frame[0:4], 42)
		binary.BigEndian.PutUint64(frame[4:12], ts)
		copy(frame[12:28], nonce[:])
		sig := ed25519.Sign(priv, secure.BuildAuthSignMessage(42, x25519, ts, nonce))
		copy(frame[28:92], sig)
		return frame
	}

	if _, err := secure.VerifyAuthFrame(build(), pub, x25519, now); err != nil {
		t.Fatalf("first verify: %v", err)
	}
	// Second verify with the SAME nonce → replay
	_, err := secure.VerifyAuthFrame(build(), pub, x25519, now)
	if err == nil || !strings.Contains(err.Error(), "replay") {
		t.Fatalf("expected replay error, got %v", err)
	}
}

// TestVerifyAuthFrameBadSignature signs the frame with a key other than the
// one handed to VerifyAuthFrame and expects an error containing
// "verification failed".
func TestVerifyAuthFrameBadSignature(t *testing.T) {
	secure.ResetReplayCache()
	pub, _ := genIdentity(t) // verifier key
	_, otherPriv := genIdentity(t)
	x25519 := bytes.Repeat([]byte{0xAB}, 32)
	now := time.Now()
	ts := uint64(now.Unix())
	var nonce [16]byte
	rand.Read(nonce[:])

	frame := make([]byte, secure.AuthFrameLen)
	binary.BigEndian.PutUint32(frame[0:4], 42)
	binary.BigEndian.PutUint64(frame[4:12], ts)
	copy(frame[12:28], nonce[:])
	// Sign with a DIFFERENT key → verification must fail
	sig := ed25519.Sign(otherPriv, secure.BuildAuthSignMessage(42, x25519, ts, nonce))
	copy(frame[28:92], sig)

	_, err := secure.VerifyAuthFrame(frame, pub, x25519, now)
	if err == nil || !strings.Contains(err.Error(), "verification failed") {
		t.Fatalf("expected sig verify err, got %v", err)
	}
}

// ---------------------------------------------------------------------------
// ReadExact
// ---------------------------------------------------------------------------

// TestReadExactSuccess reads 5 bytes from an 11-byte source and expects
// "hello".
func TestReadExactSuccess(t *testing.T) {
	t.Parallel()
	got, err := secure.ReadExact(bytes.NewReader([]byte("hello world")), 5)
	if err != nil {
		t.Fatal(err)
	}
	if string(got) != "hello" {
		t.Errorf("got %q", got)
	}
}

// TestReadExactShortFails asks for 5 bytes from a 2-byte source and expects
// an error.
func TestReadExactShortFails(t *testing.T) {
	t.Parallel()
	_, err := secure.ReadExact(bytes.NewReader([]byte("hi")), 5)
	if err == nil {
		t.Fatal("expected error reading 5 from 2-byte source")
	}
	// NOTE(review): errors.New returns a distinct value on every call, so
	// errors.Is(err, errors.New("")) can never be true and this condition
	// reduces to err.Error() == "". The check therefore only rejects an
	// error with an empty message; consider asserting a specific sentinel
	// once ReadExact's error contract is confirmed.
	if !errors.Is(err, errors.New("")) && err.Error() == "" {
		t.Fatalf("expected non-empty error, got %v", err)
	}
}

// ---------------------------------------------------------------------------
// secure.SecureConn passthrough methods
// ---------------------------------------------------------------------------

// TestSecureConnAddrAndDeadlinePassthrough checks that LocalAddr and
// RemoteAddr are non-nil and that the three deadline setters return no error
// on a freshly handshaken connection.
func TestSecureConnAddrAndDeadlinePassthrough(t *testing.T) {
	t.Parallel()
	a, b := pipePair()
	defer a.Close()
	defer b.Close()
	server, _ := handshakeBoth(t, a, b, nil, nil)

	if server.LocalAddr() == nil {
		t.Error("LocalAddr nil")
	}
	if server.RemoteAddr() == nil {
		t.Error("RemoteAddr nil")
	}

	dl := time.Now().Add(time.Second)
	if err := server.SetDeadline(dl); err != nil {
		t.Errorf("SetDeadline: %v", err)
	}
	if err := server.SetReadDeadline(dl); err != nil {
		t.Errorf("SetReadDeadline: %v", err)
	}
	if err := server.SetWriteDeadline(dl); err != nil {
		t.Errorf("SetWriteDeadline: %v", err)
	}
}

// TestSecureConnCloseClosesUnderlying closes the SecureConn and then expects
// a write on the raw connection it wraps to fail.
func TestSecureConnCloseClosesUnderlying(t *testing.T) {
	t.Parallel()
	a, b := pipePair()
	server, _ := handshakeBoth(t, a, b, nil, nil)

	if err := server.Close(); err != nil {
		t.Errorf("Close: %v", err)
	}
	// After Close the underlying conn rejects further writes.
	if _, err := a.Write([]byte("x")); err == nil {
		t.Error("expected raw write to fail after Close")
	}
	// a is closed through server.Close above; only b still needs closing.
	b.Close()
}

// ---------------------------------------------------------------------------
// Handshake error: unparseable peer key
// ---------------------------------------------------------------------------

// TestHandshakeRejectsBadPeerKey feeds the server side 32 bytes of 0xFF in
// place of a peer public key. The test is skipped, not failed, when the
// handshake accepts that input.
func TestHandshakeRejectsBadPeerKey(t *testing.T) {
	t.Parallel()
	a, b := pipePair()
	defer a.Close()
	defer b.Close()

	// Server expects 32-byte X25519 pub from client. Send 32 bytes of 0xFF
	// which is an invalid (non-canonical) curve point.
	go func() {
		// Client side: write garbage instead of running Handshake
		junk := bytes.Repeat([]byte{0xFF}, 32)
		_, _ = b.Write(junk)
		// Read server's pubkey to unblock its Write
		buf := make([]byte, 32)
		_, _ = b.Read(buf)
	}()

	_, err := secure.Handshake(a, true)
	// Either ECDH or NewPublicKey may reject — both valid.
	if err == nil {
		t.Skip("server accepted; some Go versions accept all-1s as pubkey — skip")
	}
}

// ---------------------------------------------------------------------------
// Concurrent writes serialise (no nonce reuse / corruption)
// ---------------------------------------------------------------------------

// TestConcurrentWritesSerialise issues N one-byte writes from N goroutines
// and reads until N distinct byte values have been seen; any Read error
// fails the test.
func TestConcurrentWritesSerialise(t *testing.T) {
	t.Parallel()
	a, b := pipePair()
	defer a.Close()
	defer b.Close()
	server, client := handshakeBoth(t, a, b, nil, nil)

	const N = 20
	var wg sync.WaitGroup
	for i := 0; i < N; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			_, _ = client.Write([]byte{byte(i)})
		}(i)
	}

	// Read all messages on server side; ensure decryption succeeds for each.
	// Arrival order is not asserted, only that every value 0..N-1 shows up.
	got := make(map[byte]bool)
	for len(got) < N {
		buf := make([]byte, 1)
		_, err := server.Read(buf)
		if err != nil {
			t.Fatalf("decrypt err during concurrent writes: %v", err)
		}
		got[buf[0]] = true
	}
	wg.Wait()
}

// diff --git a/urlvalidate/validate.go b/urlvalidate/validate.go
// new file mode 100644
// index 0000000..081a9c1
// --- /dev/null
// +++ b/urlvalidate/validate.go
// @@ -0,0 +1,68 @@

// SPDX-License-Identifier: AGPL-3.0-or-later

// Package urlvalidate provides SSRF-prevention checks shared across packages
// that accept operator-supplied URLs (webhook endpoints, audit export sinks,
// identity provider verification callbacks, etc.).
//
// The rules are intentionally conservative:
//   - Only http and https schemes are allowed.
//   - Link-local addresses (IPv4 169.254.0.0/16, IPv6 fe80::/10) are blocked
//     because they include cloud metadata services and host-local adjacencies.
//   - A small blocklist of cloud metadata hostnames is rejected outright. DNS
//     is case-insensitive, so the comparison lowercases the hostname before
//     matching — "Metadata.Google.Internal" must not bypass the blocklist.
//
// Placing this in a neutral package lets both pkg/daemon and pkg/registry
// (which cannot import pkg/daemon) share exactly one implementation.
package urlvalidate

import (
	"fmt"
	"net"
	"net/url"
	"strings"
)

// Validate returns nil if rawURL is an acceptable http(s) endpoint that does
// not point at a link-local or well-known cloud-metadata target. Callers are
// responsible for deciding whether an empty URL (which returns an error here)
// should be interpreted as "disable" before calling.
// Validate canonicalizes the host (lowercase, trailing root dot removed)
// before the link-local and metadata checks, so "metadata.google.internal."
// is rejected exactly like "metadata.google.internal".
//
// It returns a non-nil error when rawURL fails to parse, uses a scheme other
// than http/https, has no host, targets a link-local unicast or multicast IP
// literal, or names one of the blocked cloud-metadata hosts.
func Validate(rawURL string) error {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return fmt.Errorf("URL must use http or https scheme, got %q", parsed.Scheme)
	}
	host := parsed.Hostname()
	if host == "" {
		return fmt.Errorf("URL must have a host")
	}

	// A fully-qualified name may carry a trailing root dot ("example.com.")
	// and resolves to the same target as the dot-less form. Compare in one
	// canonical form so neither the dot nor letter case can be used to slip
	// past the checks below. Error messages keep the host as supplied.
	canonical := strings.TrimSuffix(strings.ToLower(host), ".")
	if canonical == "" {
		// The host was just "." — there is no name left to connect to.
		return fmt.Errorf("URL must have a host")
	}

	// Strip IPv6 zone identifier (e.g. "fe80::1%eth0") before parsing.
	// net.ParseIP does not handle zone suffixes, so without this a
	// link-local address with a zone ID would pass the check unnoticed.
	ipStr := canonical
	if i := strings.IndexByte(ipStr, '%'); i != -1 {
		ipStr = ipStr[:i]
	}
	if ip := net.ParseIP(ipStr); ip != nil {
		if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
			return fmt.Errorf("URL cannot target link-local address %s", host)
		}
	}

	switch canonical {
	case
		// GCP
		"metadata.google.internal",
		"metadata.google.com",
		// AWS (DNS names that reach the EC2 instance metadata service
		// without traversing the link-local IP path)
		"ec2.internal",
		"instance-data-service.ec2.internal",
		// Azure (IMDS DNS endpoint)
		"metadata.azure.com":
		return fmt.Errorf("URL cannot target cloud metadata endpoint %s", host)
	}
	return nil
}

// The lines below are the patch preamble of the next file in this patch. They
// share the original source line with Validate and are carried over verbatim
// (diff markers included) as comment text so that nothing is dropped.
//
// diff --git a/urlvalidate/zz_cloud_metadata_test.go b/urlvalidate/zz_cloud_metadata_test.go
// new file mode 100644
// index 0000000..6edaac0
// --- /dev/null
// +++ b/urlvalidate/zz_cloud_metadata_test.go
// @@ -0,0 +1,65 @@
// +// SPDX-License-Identifier: AGPL-3.0-or-later
// +
// +package urlvalidate_test
// +
// +// Regression for SSRF allowlist gaps: the original implementation
// +// blocked GCP metadata.google.{internal,com} + link-local IPs (which
// +// covers 169.254.169.254 reaching the EC2/Azure metadata services by
// +// IP).
But the AWS DNS-name path (ec2.internal, +// instance-data-service.ec2.internal) and Azure DNS-name path +// (metadata.azure.com) reached the metadata service without hitting +// the link-local check, leaving an SSRF vector for a webhook +// destination targeting `http://ec2.internal/...`. + +import ( + "strings" + "testing" + + "github.com/pilot-protocol/common/urlvalidate" +) + +func TestValidate_BlocksAWSMetadataHostnames(t *testing.T) { + t.Parallel() + + cases := []string{ + "http://ec2.internal/latest/meta-data/iam/security-credentials/", + "http://instance-data-service.ec2.internal/", + "http://EC2.Internal/", // case-insensitive + } + for _, in := range cases { + err := urlvalidate.Validate(in) + if err == nil { + t.Errorf("Validate(%q) returned nil — AWS metadata hostname not blocked", in) + continue + } + if !strings.Contains(err.Error(), "metadata") { + t.Errorf("Validate(%q) error %q does not mention 'metadata'", in, err.Error()) + } + } +} + +func TestValidate_BlocksAzureMetadataHostname(t *testing.T) { + t.Parallel() + + err := urlvalidate.Validate("http://metadata.azure.com/metadata/instance?api-version=2021-02-01") + if err == nil { + t.Fatal("Azure metadata.azure.com not blocked — SSRF vector") + } + if !strings.Contains(err.Error(), "metadata") { + t.Errorf("expected error to mention 'metadata', got: %v", err) + } +} + +func TestValidate_StillAllowsLegitimateHosts(t *testing.T) { + t.Parallel() + + for _, in := range []string{ + "https://example.com/webhook", + "https://hooks.slack.com/services/T00/B00/abc", + "https://internal-api.example.com/", + } { + if err := urlvalidate.Validate(in); err != nil { + t.Errorf("Validate(%q) wrongly rejected: %v", in, err) + } + } +} diff --git a/urlvalidate/zz_validate_edge_test.go b/urlvalidate/zz_validate_edge_test.go new file mode 100644 index 0000000..6d560a8 --- /dev/null +++ b/urlvalidate/zz_validate_edge_test.go @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package urlvalidate_test + 
+import ( + "strings" + "testing" + + "github.com/pilot-protocol/common/urlvalidate" +) + +func TestValidate_ParseError(t *testing.T) { + t.Parallel() + // %ZZ is an invalid percent-encoding → url.Parse returns an error. + err := urlvalidate.Validate("http://example.com/%ZZ") + if err == nil || !strings.Contains(err.Error(), "invalid URL") { + t.Fatalf("want 'invalid URL', got %v", err) + } +} + +func TestValidate_NoHost(t *testing.T) { + t.Parallel() + // "http:" parses but Hostname() returns "". + err := urlvalidate.Validate("http:") + if err == nil || !strings.Contains(err.Error(), "URL must have a host") { + t.Fatalf("want 'URL must have a host', got %v", err) + } +} + +func TestValidate_LinkLocalIPv6WithZone(t *testing.T) { + t.Parallel() + // The code strips %zoneid before passing to net.ParseIP. Cover that branch. + err := urlvalidate.Validate("http://[fe80::1%25eth0]/") + if err == nil || !strings.Contains(err.Error(), "link-local") { + t.Fatalf("want 'link-local', got %v", err) + } +} + +func TestValidate_LinkLocalIPv4Multicast(t *testing.T) { + t.Parallel() + // 224.0.0.1 is in IPv4 link-local multicast block 224.0.0.0/24. + err := urlvalidate.Validate("http://224.0.0.1/") + if err == nil || !strings.Contains(err.Error(), "link-local") { + t.Fatalf("want 'link-local', got %v", err) + } +} + +func TestValidate_NormalPublicHostsAllowed(t *testing.T) { + t.Parallel() + // Spot-checks for non-error happy paths beyond what the table covers. 
+ for _, u := range []string{ + "https://hooks.example.com/webhook", + "http://example.org:8080/path?x=1", + "https://api.example.io/audit", + } { + if err := urlvalidate.Validate(u); err != nil { + t.Errorf("%s: unexpected error: %v", u, err) + } + } +} diff --git a/urlvalidate/zz_validate_test.go b/urlvalidate/zz_validate_test.go new file mode 100644 index 0000000..05bb50e --- /dev/null +++ b/urlvalidate/zz_validate_test.go @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package urlvalidate_test + +import ( + "strings" + "testing" + + "github.com/pilot-protocol/common/urlvalidate" +) + +func TestValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + url string + wantErr bool + errMsg string + }{ + {"valid http", "http://example.com/hook", false, ""}, + {"valid https", "https://hooks.example.com/pilot", false, ""}, + {"valid with port", "https://example.com:8443/hook", false, ""}, + {"valid routable IPv4", "http://192.168.1.100:9000/hook", false, ""}, + + {"ftp scheme", "ftp://example.com/hook", true, "http or https"}, + {"file scheme", "file:///etc/passwd", true, "http or https"}, + {"no scheme", "example.com/hook", true, "http or https"}, + {"empty", "", true, "http or https"}, + + {"link-local ipv4", "http://169.254.169.254/metadata", true, "link-local"}, + {"link-local ipv6", "http://[fe80::1]/hook", true, "link-local"}, + + {"gcp metadata", "http://metadata.google.internal/", true, "cloud metadata"}, + {"gcp metadata alt", "http://metadata.google.com/", true, "cloud metadata"}, + {"gcp metadata mixed case", "http://Metadata.Google.Internal/", true, "cloud metadata"}, + {"gcp metadata upper case", "http://METADATA.GOOGLE.INTERNAL/", true, "cloud metadata"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := urlvalidate.Validate(tc.url) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error for URL %q", tc.url) + } + if tc.errMsg != "" && !strings.Contains(err.Error(), tc.errMsg) 
{ + t.Fatalf("expected error containing %q, got: %v", tc.errMsg, err) + } + } else if err != nil { + t.Fatalf("unexpected error for URL %q: %v", tc.url, err) + } + }) + } +}