From 34d4851cfec45d8fdbd9f44f302128dc12f4b69a Mon Sep 17 00:00:00 2001 From: Lachlan Donald Date: Sat, 6 Jun 2026 07:24:31 +1000 Subject: [PATCH] fix: preserve rotated refresh tokens --- internal/http/refresh_transport.go | 23 +++- internal/http/refresh_transport_test.go | 142 ++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/internal/http/refresh_transport.go b/internal/http/refresh_transport.go index 3844d866..511c5432 100644 --- a/internal/http/refresh_transport.go +++ b/internal/http/refresh_transport.go @@ -71,6 +71,7 @@ func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { type credentialStore interface { Set(org, token string) error + Get(org string) (string, error) GetRefreshToken(org string) (string, error) SetRefreshToken(org, token string) error DeleteRefreshToken(org string) error @@ -186,7 +187,13 @@ func (t *RefreshTransport) doRefresh(ctx context.Context, failedToken string) (s // (invalid/expired/revoked). Transient failures (network, 5xx) // should not destroy the user's session. if isTerminalRefreshError(err) { - _ = t.Keyring.DeleteRefreshToken(t.Org) + token, recovered, rotated := t.recoverRotatedCredentials(refreshToken, failedToken) + if recovered { + return token, nil + } + if !rotated { + _ = t.Keyring.DeleteRefreshToken(t.Org) + } } return "", err } @@ -208,6 +215,20 @@ func (t *RefreshTransport) doRefresh(ctx context.Context, failedToken string) (s return tokenResp.AccessToken, nil } +func (t *RefreshTransport) recoverRotatedCredentials(failedRefreshToken, failedToken string) (string, bool, bool) { + currentRefreshToken, err := t.Keyring.GetRefreshToken(t.Org) + if err != nil || currentRefreshToken == "" || currentRefreshToken == failedRefreshToken { + return "", false, false + } + + accessToken, err := t.Keyring.Get(t.Org) + if err != nil || accessToken == "" || accessToken == failedToken { + return "", false, true + } + t.TokenSource.SetToken(accessToken) + return accessToken, true, true +} + // isTerminalRefreshError returns true for OAuth errors that indicate the // refresh token is permanently invalid and should be cleared. func isTerminalRefreshError(err error) bool { diff --git a/internal/http/refresh_transport_test.go b/internal/http/refresh_transport_test.go index 8d11e6bd..d1a86f4e 100644 --- a/internal/http/refresh_transport_test.go +++ b/internal/http/refresh_transport_test.go @@ -14,13 +14,20 @@ import ( ) type stubCredentialStore struct { + mu sync.Mutex + accessToken string refreshToken string + refreshTokenReads []string setAccessErr error setRefreshTokenErr error + deleteRefreshCalls int } func (s *stubCredentialStore) Set(_ string, token string) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.setAccessErr != nil { return s.setAccessErr } @@ -28,11 +35,29 @@ func (s *stubCredentialStore) Set(_ string, token string) error { return nil } +func (s *stubCredentialStore) Get(string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.accessToken, nil +} + func (s *stubCredentialStore) GetRefreshToken(string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.refreshTokenReads) > 0 { + token := s.refreshTokenReads[0] + s.refreshTokenReads = s.refreshTokenReads[1:] + return token, nil + } return s.refreshToken, nil } func (s *stubCredentialStore) SetRefreshToken(_ string, token string) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.setRefreshTokenErr != nil { return s.setRefreshTokenErr } @@ -41,6 +66,10 @@ func (s *stubCredentialStore) SetRefreshToken(_ string, token string) error { } func (s *stubCredentialStore) DeleteRefreshToken(string) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.deleteRefreshCalls++ s.refreshToken = "" return nil } @@ -224,6 +253,119 @@ func TestRefreshTransport_RotatedRefreshTokenStoreFailureReturnsError(t *testing } } +func TestRefreshTransport_InvalidGrantRotationHandling(t *testing.T) { + tests := []struct { + name string + storedAccessToken string + storedRefreshToken string + refreshTokenReads []string + wantStatus int + wantTokenSource string + wantRefreshToken string + wantDeleteCalls int + }{ + { + name: "recovers with access token rotated by another process", + storedAccessToken: "new-token", + storedRefreshToken: "new-refresh-token", + refreshTokenReads: []string{"old-refresh-token", "old-refresh-token", "new-refresh-token"}, + wantStatus: http.StatusOK, + wantTokenSource: "new-token", + wantRefreshToken: "new-refresh-token", + }, + { + name: "preserves rotated refresh token when access token is not stored yet", + storedAccessToken: "old-token", + storedRefreshToken: "new-refresh-token", + refreshTokenReads: []string{"old-refresh-token", "old-refresh-token", "new-refresh-token"}, + wantStatus: http.StatusUnauthorized, + wantTokenSource: "old-token", + wantRefreshToken: "new-refresh-token", + }, + { + name: "deletes unchanged invalid refresh token", + storedAccessToken: "old-token", + storedRefreshToken: "old-refresh-token", + refreshTokenReads: []string{"old-refresh-token", "old-refresh-token", "old-refresh-token"}, + wantStatus: http.StatusUnauthorized, + wantTokenSource: "old-token", + wantDeleteCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var refreshRequests atomic.Int32 + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + refreshRequests.Add(1) + if got := r.FormValue("refresh_token"); got != "old-refresh-token" { + t.Errorf("expected refresh request to use old refresh token, got %q", got) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"Invalid refresh token"}`)) + default: + switch r.Header.Get("Authorization") { + case "Bearer old-token": + w.WriteHeader(http.StatusUnauthorized) + case "Bearer new-token": + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + default: + w.WriteHeader(http.StatusUnauthorized) + } + } + })) + defer server.Close() + + origTransport := http.DefaultTransport + http.DefaultTransport = server.Client().Transport + defer func() { http.DefaultTransport = origTransport }() + + store := &stubCredentialStore{ + accessToken: tt.storedAccessToken, + refreshToken: tt.storedRefreshToken, + refreshTokenReads: append([]string(nil), tt.refreshTokenReads...), + } + transport := &RefreshTransport{ + Base: server.Client().Transport, + Org: "test-org", + Keyring: store, + TokenSource: NewTokenSource("old-token"), + } + + t.Setenv("BUILDKITE_HOST", strings.TrimPrefix(server.URL, "https://")) + + req, _ := http.NewRequest(http.MethodGet, server.URL+"/test", nil) + req.Header.Set("Authorization", "Bearer old-token") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tt.wantStatus { + t.Fatalf("expected status %d, got %d", tt.wantStatus, resp.StatusCode) + } + if got := transport.TokenSource.Token(); got != tt.wantTokenSource { + t.Fatalf("expected token source %q, got %q", tt.wantTokenSource, got) + } + if got := store.refreshToken; got != tt.wantRefreshToken { + t.Fatalf("expected stored refresh token %q, got %q", tt.wantRefreshToken, got) + } + if got := store.deleteRefreshCalls; got != tt.wantDeleteCalls { + t.Fatalf("expected %d refresh token deletes, got %d", tt.wantDeleteCalls, got) + } + if got := refreshRequests.Load(); got != 1 { + t.Fatalf("expected 1 refresh request, got %d", got) + } + }) + } +} + func TestRefreshTransport_DoesNotDeleteRefreshTokenOnTransientError(t *testing.T) { keyring.MockForTesting() defer keyring.ResetForTesting()