diff --git a/README.md b/README.md index f1e128f..21829f5 100644 --- a/README.md +++ b/README.md @@ -351,9 +351,18 @@ io.Copy(dst, artifact.Body) The fetcher uses DNS caching (5-minute refresh), connection pooling, and a 5-minute timeout suited for large artifacts. It retries on rate limits and server errors with exponential backoff and jitter. +### Per-request headers + +Use `FetchWithHeaders` to pass HTTP headers for a single request. This is useful when the auth token varies per request or is obtained dynamically (e.g. Docker Hub token exchange): + +```go +headers := http.Header{"Authorization": {"Bearer " + token}} +artifact, err := f.FetchWithHeaders(ctx, url, headers) +``` + ### Authentication -Pass a function that returns auth headers per URL: +For static credentials that apply to all requests matching a URL pattern, pass a function at construction time: ```go f := fetch.NewFetcher( @@ -366,6 +375,8 @@ f := fetch.NewFetcher( ) ``` +When both `WithAuthFunc` and `FetchWithHeaders` set the same header, `WithAuthFunc` takes precedence. + ### Circuit breaker Wrap a fetcher with per-host circuit breakers to avoid hammering a failing registry. The breaker trips after 5 consecutive failures and resets with exponential backoff (30s initial, 5min max). diff --git a/fetch/circuit_breaker.go b/fetch/circuit_breaker.go index 0a9a387..dfe3c34 100644 --- a/fetch/circuit_breaker.go +++ b/fetch/circuit_breaker.go @@ -3,6 +3,7 @@ package fetch import ( "context" "fmt" + "net/http" "net/url" "sync" "time" @@ -71,6 +72,11 @@ func (cbf *CircuitBreakerFetcher) getBreaker(registry string) *circuit.Breaker { // Fetch wraps the underlying fetcher's Fetch with circuit breaker logic. func (cbf *CircuitBreakerFetcher) Fetch(ctx context.Context, fetchURL string) (*Artifact, error) { + return cbf.FetchWithHeaders(ctx, fetchURL, nil) +} + +// FetchWithHeaders wraps the underlying fetcher's FetchWithHeaders with circuit breaker logic. +func (cbf *CircuitBreakerFetcher) FetchWithHeaders(ctx context.Context, fetchURL string, headers http.Header) (*Artifact, error) { // Extract registry from URL for circuit breaker selection registry := extractRegistry(fetchURL) breaker := cbf.getBreaker(registry) @@ -84,7 +90,7 @@ func (cbf *CircuitBreakerFetcher) Fetch(ctx context.Context, fetchURL string) (* var artifact *Artifact err := breaker.Call(func() error { var fetchErr error - artifact, fetchErr = cbf.fetcher.Fetch(ctx, fetchURL) + artifact, fetchErr = cbf.fetcher.FetchWithHeaders(ctx, fetchURL, headers) return fetchErr }, 0) diff --git a/fetch/fetcher.go b/fetch/fetcher.go index 805db60..f57c2b5 100644 --- a/fetch/fetcher.go +++ b/fetch/fetcher.go @@ -51,6 +51,7 @@ type Artifact struct { // FetcherInterface defines the interface for artifact fetchers. type FetcherInterface interface { Fetch(ctx context.Context, url string) (*Artifact, error) + FetchWithHeaders(ctx context.Context, url string, headers http.Header) (*Artifact, error) Head(ctx context.Context, url string) (size int64, contentType string, err error) } @@ -162,6 +163,12 @@ func NewFetcher(opts ...Option) *Fetcher { // Fetch downloads an artifact from the given URL. // The caller must close the returned Artifact.Body when done. func (f *Fetcher) Fetch(ctx context.Context, url string) (*Artifact, error) { + return f.FetchWithHeaders(ctx, url, nil) +} + +// FetchWithHeaders downloads an artifact from the given URL with additional HTTP headers. +// The caller must close the returned Artifact.Body when done. +func (f *Fetcher) FetchWithHeaders(ctx context.Context, url string, headers http.Header) (*Artifact, error) { var lastErr error for attempt := 0; attempt <= f.maxRetries; attempt++ { @@ -178,7 +185,7 @@ func (f *Fetcher) Fetch(ctx context.Context, url string) (*Artifact, error) { } } - artifact, err := f.doFetch(ctx, url) + artifact, err := f.doFetch(ctx, url, headers) if err == nil { return artifact, nil } @@ -202,7 +209,7 @@ func (f *Fetcher) Fetch(ctx context.Context, url string) (*Artifact, error) { return nil, lastErr } -func (f *Fetcher) doFetch(ctx context.Context, url string) (*Artifact, error) { +func (f *Fetcher) doFetch(ctx context.Context, url string, headers http.Header) (*Artifact, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("creating request: %w", err) @@ -211,7 +218,14 @@ func (f *Fetcher) doFetch(ctx context.Context, url string) (*Artifact, error) { req.Header.Set("User-Agent", f.userAgent) req.Header.Set("Accept", "*/*") - // Add authentication header if configured + // Add caller-provided headers + for key, values := range headers { + for _, v := range values { + req.Header.Set(key, v) + } + } + + // Add authentication header if configured (overrides caller headers) if f.authFn != nil { if name, value := f.authFn(url); name != "" && value != "" { req.Header.Set(name, value) diff --git a/fetch/fetcher_test.go b/fetch/fetcher_test.go index 34b92e3..0d6c29f 100644 --- a/fetch/fetcher_test.go +++ b/fetch/fetcher_test.go @@ -286,6 +286,74 @@ func TestFetchRetryWithJitter(t *testing.T) { } } +func TestFetchWithHeaders(t *testing.T) { + var receivedAuth string + var receivedCustom string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + receivedCustom = r.Header.Get("X-Custom") + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + f := NewFetcher() + headers := http.Header{ + "Authorization": {"Bearer test-token"}, + "X-Custom": {"custom-value"}, + } + artifact, err := f.FetchWithHeaders(context.Background(), server.URL+"/test.tgz", headers) + if err != nil { + t.Fatalf("FetchWithHeaders failed: %v", err) + } + defer func() { _ = artifact.Body.Close() }() + + if receivedAuth != "Bearer test-token" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token") + } + if receivedCustom != "custom-value" { + t.Errorf("X-Custom = %q, want %q", receivedCustom, "custom-value") + } +} + +func TestFetchWithHeadersNil(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + f := NewFetcher() + artifact, err := f.FetchWithHeaders(context.Background(), server.URL+"/test.tgz", nil) + if err != nil { + t.Fatalf("FetchWithHeaders with nil headers failed: %v", err) + } + _ = artifact.Body.Close() +} + +func TestFetchWithHeadersAuthFnOverrides(t *testing.T) { + var receivedAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + f := NewFetcher(WithAuthFunc(func(url string) (string, string) { + return "Authorization", "Bearer from-authfn" + })) + headers := http.Header{ + "Authorization": {"Bearer from-headers"}, + } + artifact, err := f.FetchWithHeaders(context.Background(), server.URL+"/test.tgz", headers) + if err != nil { + t.Fatalf("FetchWithHeaders failed: %v", err) + } + defer func() { _ = artifact.Body.Close() }() + + if receivedAuth != "Bearer from-authfn" { + t.Errorf("Authorization = %q, want %q (authFn should override)", receivedAuth, "Bearer from-authfn") + } +} + func TestFetchDNSCaching(t *testing.T) { requestCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {