diff --git a/pkg/fetch/celestia_node.go b/pkg/fetch/celestia_node.go index 8275f67..f7035bd 100644 --- a/pkg/fetch/celestia_node.go +++ b/pkg/fetch/celestia_node.go @@ -41,25 +41,35 @@ type CelestiaNodeFetcher struct { blob blobAPI headerCloser jsonrpc.ClientCloser blobCloser jsonrpc.ClientCloser + addr string // original address for creating WS subscription clients + authHeader http.Header log zerolog.Logger mu sync.Mutex closed bool + subCloser jsonrpc.ClientCloser // WS client for subscriptions, if any } const ( defaultRPCTimeout = 8 * time.Second defaultRPCMaxRetries = 2 defaultRPCRetryDelay = 100 * time.Millisecond + wsDialTimeout = 10 * time.Second ) -// NewCelestiaNodeFetcher connects to a Celestia node at the given WebSocket address. +// NewCelestiaNodeFetcher connects to a Celestia node at the given address. +// Regular RPC calls use the provided URL scheme (typically HTTP). +// Subscriptions automatically upgrade to WebSocket when needed. func NewCelestiaNodeFetcher(ctx context.Context, addr, token string, log zerolog.Logger) (*CelestiaNodeFetcher, error) { headers := http.Header{} if token != "" { headers.Set("Authorization", "Bearer "+token) } - f := &CelestiaNodeFetcher{log: log} + f := &CelestiaNodeFetcher{ + addr: addr, + authHeader: headers, + log: log, + } var err error f.headerCloser, err = jsonrpc.NewClient(ctx, addr, "header", &f.header, headers) @@ -76,6 +86,19 @@ func NewCelestiaNodeFetcher(ctx context.Context, addr, token string, log zerolog return f, nil } +// httpToWS converts http:// to ws:// and https:// to wss://. +// Returns the address unchanged if it already uses a WS scheme. 
// httpToWS rewrites an HTTP(S) address to its WebSocket equivalent:
// http:// becomes ws:// and https:// becomes wss://.
//
// The scheme is matched case-insensitively (URI schemes are case-insensitive
// per RFC 3986), so "HTTP://host" is upgraded as well. Everything after the
// scheme is preserved byte-for-byte. Any address that does not start with an
// HTTP(S) scheme (including ws://, wss://, bare host:port, or the empty
// string) is returned unchanged.
func httpToWS(addr string) string {
	// "http://" can never shadow "https://": the fifth byte must be ':'.
	schemes := [...]struct{ from, to string }{
		{from: "http://", to: "ws://"},
		{from: "https://", to: "wss://"},
	}
	for _, s := range schemes {
		if len(addr) < len(s.from) {
			continue
		}
		if strings.EqualFold(addr[:len(s.from)], s.from) {
			return s.to + addr[len(s.from):]
		}
	}
	return addr
}
+func (f *CelestiaNodeFetcher) subscribeViaWS(ctx context.Context) (<-chan json.RawMessage, error) { + wsAddr := httpToWS(f.addr) + if wsAddr == f.addr { + return nil, fmt.Errorf("address %q is not HTTP; cannot upgrade to WebSocket", f.addr) + } + + f.log.Info().Str("ws_addr", wsAddr).Msg("upgrading to WebSocket for header subscription") + + // Run the WS dial + subscribe in a goroutine so we can timeout if the + // node doesn't accept WebSocket connections. The parent ctx is passed to + // NewClient because it controls the WS connection lifetime (not just dial). + done := make(chan wsSubscribeResult, 1) + go func() { + var subAPI headerAPI + closer, err := jsonrpc.NewClient(ctx, wsAddr, "header", &subAPI, f.authHeader) + if err != nil { + done <- wsSubscribeResult{err: fmt.Errorf("connect WS header client: %w", err)} + return + } + ch, err := subAPI.Subscribe(ctx) + if err != nil { + closer() + done <- wsSubscribeResult{err: fmt.Errorf("header.Subscribe via WS: %w", err)} + return + } + done <- wsSubscribeResult{ch: ch, closer: closer} + }() + + select { + case r := <-done: + if r.err != nil { + return nil, r.err + } + f.mu.Lock() + old := f.subCloser + f.subCloser = r.closer + f.mu.Unlock() + if old != nil { + old() + } + return r.ch, nil + case <-time.After(wsDialTimeout): + return nil, fmt.Errorf("WS connection to %s timed out after %s", wsAddr, wsDialTimeout) + case <-ctx.Done(): + return nil, ctx.Err() } +} +// forwardHeaders maps raw JSON headers from a subscription channel to typed headers. +func (f *CelestiaNodeFetcher) forwardHeaders(ctx context.Context, rawCh <-chan json.RawMessage) <-chan *types.Header { out := make(chan *types.Header, 64) go func() { defer close(out) @@ -176,8 +270,51 @@ func (f *CelestiaNodeFetcher) SubscribeHeaders(ctx context.Context) (<-chan *typ } } }() + return out +} + +// pollHeaders polls GetNetworkHead at 1s intervals, emitting new headers when +// the height advances. Used as a fallback when header.Subscribe is unavailable. 
+// NOTE: only the current chain tip is emitted; intermediate heights produced +// between ticks are skipped. The sync coordinator handles this via gap detection +// and re-backfill, so no data is lost — but this path is higher latency than +// a true subscription. +func (f *CelestiaNodeFetcher) pollHeaders(ctx context.Context) <-chan *types.Header { + out := make(chan *types.Header, 64) + go func() { + defer close(out) + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + var lastHeight uint64 - return out, nil + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + hdr, err := f.GetNetworkHead(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + f.log.Warn().Err(err).Msg("poll network head failed") + continue + } + if hdr.Height <= lastHeight { + continue + } + lastHeight = hdr.Height + select { + case out <- hdr: + case <-ctx.Done(): + return + } + } + } + }() + return out } // GetProof forwards a blob proof request to the upstream Celestia node. @@ -210,6 +347,9 @@ func (f *CelestiaNodeFetcher) Close() error { return nil } f.closed = true + if f.subCloser != nil { + f.subCloser() + } f.headerCloser() f.blobCloser() return nil diff --git a/pkg/sync/subscription.go b/pkg/sync/subscription.go index 301e01d..ecb48ed 100644 --- a/pkg/sync/subscription.go +++ b/pkg/sync/subscription.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "time" "github.com/rs/zerolog" @@ -12,6 +13,8 @@ import ( "github.com/evstack/apex/pkg/types" ) +const streamingLogInterval = 30 * time.Second + // SubscriptionManager processes new headers from a live subscription. type SubscriptionManager struct { store store.Store @@ -45,10 +48,20 @@ func (sm *SubscriptionManager) Run(ctx context.Context) error { networkHeight = ss.NetworkHeight } + ticker := time.NewTicker(streamingLogInterval) + defer ticker.Stop() + var processed uint64 + for { select { case <-ctx.Done(): return nil + case <-ticker.C: + sm.log.Info(). + Uint64("height", lastHeight). 
+ Uint64("blocks", processed). + Msg("streaming progress") + processed = 0 case hdr, ok := <-ch: if !ok { // Channel closed (disconnect or ctx cancelled). @@ -72,6 +85,7 @@ func (sm *SubscriptionManager) Run(ctx context.Context) error { } lastHeight = hdr.Height + processed++ } } }