diff --git a/go.mod b/go.mod index 4ac7b16..0f5f271 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,10 @@ go 1.25.6 require ( github.com/git-pkgs/archives v0.2.0 github.com/git-pkgs/enrichment v0.2.1 - github.com/git-pkgs/purl v0.1.9 - github.com/git-pkgs/registries v0.3.0 - github.com/git-pkgs/spdx v0.1.1 - github.com/git-pkgs/vers v0.2.3 + github.com/git-pkgs/purl v0.1.10 + github.com/git-pkgs/registries v0.4.0 + github.com/git-pkgs/spdx v0.1.2 + github.com/git-pkgs/vers v0.2.4 github.com/git-pkgs/vulns v0.1.3 github.com/go-chi/chi/v5 v5.2.5 github.com/jmoiron/sqlx v1.4.0 @@ -276,7 +276,7 @@ require ( golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect - golang.org/x/sync v0.19.0 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.34.0 // indirect golang.org/x/tools v0.42.0 // indirect diff --git a/go.sum b/go.sum index 5ee370b..63f37e1 100644 --- a/go.sum +++ b/go.sum @@ -230,14 +230,14 @@ github.com/git-pkgs/enrichment v0.2.1 h1:mJJt4YQBzl9aOfu4226ylnC9H6YO9YZDjGpbSPV github.com/git-pkgs/enrichment v0.2.1/go.mod h1:q9eDZpRrUbYwzD4Mtg/T6LRdBMlt2DYRIvVRDULFnKg= github.com/git-pkgs/packageurl-go v0.3.1 h1:WM3RBABQZLaRBxgKyYughc3cVBE8KyQxbSC6Jt5ak7M= github.com/git-pkgs/packageurl-go v0.3.1/go.mod h1:rcIxiG37BlQLB6FZfgdj9Fm7yjhRQd3l+5o7J0QPAk4= -github.com/git-pkgs/purl v0.1.9 h1:zSHKBVwRTJiMGwiYIiHgoIUfJTdtC7kVQ0+0RHckwxc= -github.com/git-pkgs/purl v0.1.9/go.mod h1:6YX25yhztts1Byktw4pOlykru57GOJaanA+WmOBFtdU= -github.com/git-pkgs/registries v0.3.0 h1:eIM78ry7l1CfwbPMXQ/vCsN9xJNWN1uDmkl76MS+OT8= -github.com/git-pkgs/registries v0.3.0/go.mod h1:RAqG9XyGLV56F8tBXXyzmEaHTBkub7MWFD9KGjt4WtQ= -github.com/git-pkgs/spdx v0.1.1 h1:jjchxLhvTnTR7fLcdXdNVDh/tLq6B2S6LnaKEzBjhRQ= -github.com/git-pkgs/spdx v0.1.1/go.mod h1:nbZdJ09OuZg9/bgRnnyEM5F5uR8K7Iwf5oDHQvK3WcE= -github.com/git-pkgs/vers v0.2.3 h1:elyuJZ2mBRIncRUF6SjpnwIwSuRRnPdAEJBZcVgU450= -github.com/git-pkgs/vers v0.2.3/go.mod h1:biTbSQK1qdbrsxDEKnqe3Jzclxz8vW6uDcwKjfUGcOo= +github.com/git-pkgs/purl v0.1.10 h1:NMjeF10nzFn3tdQlz6rbmHB+i+YkyrFQxho3e33ePTQ= +github.com/git-pkgs/purl v0.1.10/go.mod h1:C5Vp/kyZ/wGckCLexx4wPVfUxEiToRkdsOPh5Z7ig/I= +github.com/git-pkgs/registries v0.4.0 h1:GO7fQ8/jot0ulSQHBdxLSNSX/p8eB3gEXWO+98fmoEo= +github.com/git-pkgs/registries v0.4.0/go.mod h1:49UCPFWQmwNV7rBEr9TrTDWKR7vYxFcxp3VfdkeFbdE= +github.com/git-pkgs/spdx v0.1.2 h1:wHSK+CqFsO5N7yDTPvxDmer5LgNEa7vAsiZhi5Aci0A= +github.com/git-pkgs/spdx v0.1.2/go.mod h1:V98MgZapNgYw54/pdGR82d7RU93qzJoybahbpZqTfw8= +github.com/git-pkgs/vers v0.2.4 h1:Zr3jR/Xf1i/6cvBaJKPxhCwjzqz7uvYHE0Fhid/GPBk= +github.com/git-pkgs/vers v0.2.4/go.mod h1:biTbSQK1qdbrsxDEKnqe3Jzclxz8vW6uDcwKjfUGcOo= github.com/git-pkgs/vulns v0.1.3 h1:Q9GixxhAYpP5vVDetKNMACHxGnWwB8aE5c9kbE8xxqU= github.com/git-pkgs/vulns v0.1.3/go.mod h1:/PVy7S1oZNVF9X8yVOZ9SX5MFpyVWCtLnIX0kAfPjY0= github.com/github/go-spdx/v2 v2.4.0 h1:+4IwVwJJbm3rzvrQ6P1nI9BDMcy3la4RchRy5uehV/M= @@ -738,8 +738,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/handler/container.go b/internal/handler/container.go index fc5f98c..8aa82eb 100644 --- a/internal/handler/container.go +++ b/internal/handler/container.go @@ -103,20 +103,22 @@ func (h *ContainerHandler) handleBlobDownload(w http.ResponseWriter, r *http.Req return } - // Try to get from cache first + // Try to get from cache, or fetch from upstream with auth filename := digest - result, err := h.proxy.GetOrFetchArtifactFromURL( + headers := http.Header{"Authorization": {"Bearer " + token}} + result, err := h.proxy.GetOrFetchArtifactFromURLWithHeaders( r.Context(), "oci", name, digest, // use digest as version filename, fmt.Sprintf("%s/v2/%s/blobs/%s", h.registryURL, name, digest), + headers, ) if err != nil { - // Fetch directly with auth - h.proxyBlobWithAuth(w, r, name, digest, token) + h.proxy.Logger.Error("failed to fetch blob", "error", err) + h.containerError(w, http.StatusBadGateway, "BLOB_UNKNOWN", "failed to fetch blob") return } @@ -304,34 +306,6 @@ func (h *ContainerHandler) proxyBlobHead(w http.ResponseWriter, r *http.Request, w.WriteHeader(resp.StatusCode) } -// proxyBlobWithAuth proxies a blob download with authentication. -func (h *ContainerHandler) proxyBlobWithAuth(w http.ResponseWriter, r *http.Request, name, digest, token string) { - upstreamURL := fmt.Sprintf("%s/v2/%s/blobs/%s", h.registryURL, name, digest) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - h.containerError(w, http.StatusInternalServerError, "INTERNAL_ERROR", "failed to create request") - return - } - - req.Header.Set("Authorization", "Bearer "+token) - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.containerError(w, http.StatusBadGateway, "INTERNAL_ERROR", "failed to fetch from upstream") - return - } - defer func() { _ = resp.Body.Close() }() - - for _, header := range []string{"Content-Type", "Content-Length", "Docker-Content-Digest"} { - if v := resp.Header.Get(header); v != "" { - w.Header().Set(header, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) -} // containerError writes an OCI-compliant error response. func (h *ContainerHandler) containerError(w http.ResponseWriter, status int, code, message string) { diff --git a/internal/handler/container_test.go b/internal/handler/container_test.go index b84adfd..b34a250 100644 --- a/internal/handler/container_test.go +++ b/internal/handler/container_test.go @@ -1,9 +1,17 @@ package handler import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" "net/http" "net/http/httptest" "testing" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/registries/fetch" ) func TestContainerHandler_parseBlobPath(t *testing.T) { @@ -127,6 +135,92 @@ func TestContainerHandler_parseTagsListPath(t *testing.T) { } } +func TestContainerHandler_BlobDownload_CachesWithAuth(t *testing.T) { + // Set up a mock auth server that returns a token + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"token": "test-token-123"}) + })) + defer authServer.Close() + + // Set up mock fetcher that captures headers + var capturedHeaders http.Header + mf := &mockFetcherWithHeaders{ + fetchFn: func(_ context.Context, _ string, headers http.Header) (*fetch.Artifact, error) { + capturedHeaders = headers + return &fetch.Artifact{ + Body: io.NopCloser(bytes.NewReader([]byte("blob-content"))), + Size: 12, + ContentType: "application/octet-stream", + }, nil + }, + } + + dir := t.TempDir() + db, err := database.Create(dir + "/test.db") + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + store := newMockStorage() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + proxy := &Proxy{ + DB: db, + Storage: store, + Fetcher: mf, + Logger: logger, + HTTPClient: &http.Client{}, + } + + h := &ContainerHandler{ + proxy: proxy, + registryURL: "https://registry-1.docker.io", + authURL: authServer.URL, + proxyURL: "http://localhost:8080", + } + + handler := h.Routes() + req := httptest.NewRequest(http.MethodGet, "/library/nginx/blobs/sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abcd", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("got status %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + // Verify auth header was passed to the fetcher + if capturedHeaders == nil { + t.Fatal("expected headers to be passed to fetcher, got nil") + } + auth := capturedHeaders.Get("Authorization") + if auth != "Bearer test-token-123" { + t.Errorf("Authorization = %q, want %q", auth, "Bearer test-token-123") + } + + // Verify response headers + if got := w.Header().Get("Docker-Content-Digest"); got != "sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abcd" { + t.Errorf("Docker-Content-Digest = %q, want digest", got) + } +} + +// mockFetcherWithHeaders captures headers passed to FetchWithHeaders. +type mockFetcherWithHeaders struct { + fetchFn func(ctx context.Context, url string, headers http.Header) (*fetch.Artifact, error) +} + +func (f *mockFetcherWithHeaders) Fetch(ctx context.Context, url string) (*fetch.Artifact, error) { + return f.FetchWithHeaders(ctx, url, nil) +} + +func (f *mockFetcherWithHeaders) FetchWithHeaders(ctx context.Context, url string, headers http.Header) (*fetch.Artifact, error) { + return f.fetchFn(ctx, url, headers) +} + +func (f *mockFetcherWithHeaders) Head(_ context.Context, _ string) (int64, string, error) { + return 0, "", nil +} + func TestContainerHandler_Routes_VersionCheck(t *testing.T) { h := NewContainerHandler(nil, "http://localhost:8080") diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 91d8960..109eacd 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -380,6 +380,13 @@ func JSONError(w http.ResponseWriter, status int, message string) { // GetOrFetchArtifactFromURL retrieves an artifact from cache or fetches from a specific URL. // This is useful for registries where download URLs are determined from metadata. func (p *Proxy) GetOrFetchArtifactFromURL(ctx context.Context, ecosystem, name, version, filename, downloadURL string) (*CacheResult, error) { + return p.GetOrFetchArtifactFromURLWithHeaders(ctx, ecosystem, name, version, filename, downloadURL, nil) +} + +// GetOrFetchArtifactFromURLWithHeaders retrieves an artifact from cache or fetches from a URL +// with additional HTTP headers. This is needed for registries that require authentication +// (e.g. Docker Hub requires a Bearer token even for public images). +func (p *Proxy) GetOrFetchArtifactFromURLWithHeaders(ctx context.Context, ecosystem, name, version, filename, downloadURL string, headers http.Header) (*CacheResult, error) { pkgPURL := purl.MakePURLString(ecosystem, name, "") versionPURL := purl.MakePURLString(ecosystem, name, version) @@ -389,14 +396,14 @@ func (p *Proxy) GetOrFetchArtifactFromURL(ctx context.Context, ecosystem, name, return cached, nil } - return p.fetchAndCacheFromURL(ctx, ecosystem, name, version, filename, pkgPURL, versionPURL, downloadURL) + return p.fetchAndCacheFromURL(ctx, ecosystem, name, version, filename, pkgPURL, versionPURL, downloadURL, headers) } -func (p *Proxy) fetchAndCacheFromURL(ctx context.Context, ecosystem, name, version, filename, pkgPURL, versionPURL, downloadURL string) (*CacheResult, error) { +func (p *Proxy) fetchAndCacheFromURL(ctx context.Context, ecosystem, name, version, filename, pkgPURL, versionPURL, downloadURL string, headers http.Header) (*CacheResult, error) { p.Logger.Info("fetching from upstream", "ecosystem", ecosystem, "name", name, "version", version, "url", downloadURL) - artifact, err := p.Fetcher.Fetch(ctx, downloadURL) + artifact, err := p.Fetcher.FetchWithHeaders(ctx, downloadURL, headers) if err != nil { return nil, fmt.Errorf("fetching from upstream: %w", err) } diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index dd85a17..5c433d6 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -87,7 +87,11 @@ type mockFetcher struct { fetchedURL string } -func (f *mockFetcher) Fetch(_ context.Context, url string) (*fetch.Artifact, error) { +func (f *mockFetcher) Fetch(ctx context.Context, url string) (*fetch.Artifact, error) { + return f.FetchWithHeaders(ctx, url, nil) +} + +func (f *mockFetcher) FetchWithHeaders(_ context.Context, url string, _ http.Header) (*fetch.Artifact, error) { f.fetchCalled = true f.fetchedURL = url if f.fetchErr != nil {