Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion internal/http/refresh_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {

type credentialStore interface {
Set(org, token string) error
Get(org string) (string, error)
GetRefreshToken(org string) (string, error)
SetRefreshToken(org, token string) error
DeleteRefreshToken(org string) error
Expand Down Expand Up @@ -186,7 +187,13 @@ func (t *RefreshTransport) doRefresh(ctx context.Context, failedToken string) (s
// (invalid/expired/revoked). Transient failures (network, 5xx)
// should not destroy the user's session.
if isTerminalRefreshError(err) {
_ = t.Keyring.DeleteRefreshToken(t.Org)
token, recovered, rotated := t.recoverRotatedCredentials(refreshToken, failedToken)
if recovered {
return token, nil
}
if !rotated {
_ = t.Keyring.DeleteRefreshToken(t.Org)
}
}
return "", err
}
Expand All @@ -208,6 +215,20 @@ func (t *RefreshTransport) doRefresh(ctx context.Context, failedToken string) (s
return tokenResp.AccessToken, nil
}

func (t *RefreshTransport) recoverRotatedCredentials(failedRefreshToken, failedToken string) (string, bool, bool) {
currentRefreshToken, err := t.Keyring.GetRefreshToken(t.Org)
if err != nil || currentRefreshToken == "" || currentRefreshToken == failedRefreshToken {
return "", false, false
}

accessToken, err := t.Keyring.Get(t.Org)
if err != nil || accessToken == "" || accessToken == failedToken {
return "", false, true
}
t.TokenSource.SetToken(accessToken)
return accessToken, true, true
}

// isTerminalRefreshError returns true for OAuth errors that indicate the
// refresh token is permanently invalid and should be cleared.
func isTerminalRefreshError(err error) bool {
Expand Down
142 changes: 142 additions & 0 deletions internal/http/refresh_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,50 @@ import (
)

type stubCredentialStore struct {
mu sync.Mutex

accessToken string
refreshToken string
refreshTokenReads []string
setAccessErr error
setRefreshTokenErr error
deleteRefreshCalls int
}

func (s *stubCredentialStore) Set(_ string, token string) error {
s.mu.Lock()
defer s.mu.Unlock()

if s.setAccessErr != nil {
return s.setAccessErr
}
s.accessToken = token
return nil
}

func (s *stubCredentialStore) Get(string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()

return s.accessToken, nil
}

func (s *stubCredentialStore) GetRefreshToken(string) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()

if len(s.refreshTokenReads) > 0 {
token := s.refreshTokenReads[0]
s.refreshTokenReads = s.refreshTokenReads[1:]
return token, nil
}
return s.refreshToken, nil
}

func (s *stubCredentialStore) SetRefreshToken(_ string, token string) error {
s.mu.Lock()
defer s.mu.Unlock()

if s.setRefreshTokenErr != nil {
return s.setRefreshTokenErr
}
Expand All @@ -41,6 +66,10 @@ func (s *stubCredentialStore) SetRefreshToken(_ string, token string) error {
}

func (s *stubCredentialStore) DeleteRefreshToken(string) error {
s.mu.Lock()
defer s.mu.Unlock()

s.deleteRefreshCalls++
s.refreshToken = ""
return nil
}
Expand Down Expand Up @@ -224,6 +253,119 @@ func TestRefreshTransport_RotatedRefreshTokenStoreFailureReturnsError(t *testing
}
}

func TestRefreshTransport_InvalidGrantRotationHandling(t *testing.T) {
tests := []struct {
name string
storedAccessToken string
storedRefreshToken string
refreshTokenReads []string
wantStatus int
wantTokenSource string
wantRefreshToken string
wantDeleteCalls int
}{
{
name: "recovers with access token rotated by another process",
storedAccessToken: "new-token",
storedRefreshToken: "new-refresh-token",
refreshTokenReads: []string{"old-refresh-token", "old-refresh-token", "new-refresh-token"},
wantStatus: http.StatusOK,
wantTokenSource: "new-token",
wantRefreshToken: "new-refresh-token",
},
{
name: "preserves rotated refresh token when access token is not stored yet",
storedAccessToken: "old-token",
storedRefreshToken: "new-refresh-token",
refreshTokenReads: []string{"old-refresh-token", "old-refresh-token", "new-refresh-token"},
wantStatus: http.StatusUnauthorized,
wantTokenSource: "old-token",
wantRefreshToken: "new-refresh-token",
},
{
name: "deletes unchanged invalid refresh token",
storedAccessToken: "old-token",
storedRefreshToken: "old-refresh-token",
refreshTokenReads: []string{"old-refresh-token", "old-refresh-token", "old-refresh-token"},
wantStatus: http.StatusUnauthorized,
wantTokenSource: "old-token",
wantDeleteCalls: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var refreshRequests atomic.Int32
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oauth/token":
refreshRequests.Add(1)
if got := r.FormValue("refresh_token"); got != "old-refresh-token" {
t.Errorf("expected refresh request to use old refresh token, got %q", got)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"Invalid refresh token"}`))
default:
switch r.Header.Get("Authorization") {
case "Bearer old-token":
w.WriteHeader(http.StatusUnauthorized)
case "Bearer new-token":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
default:
w.WriteHeader(http.StatusUnauthorized)
}
}
}))
defer server.Close()

origTransport := http.DefaultTransport
http.DefaultTransport = server.Client().Transport
defer func() { http.DefaultTransport = origTransport }()

store := &stubCredentialStore{
accessToken: tt.storedAccessToken,
refreshToken: tt.storedRefreshToken,
refreshTokenReads: append([]string(nil), tt.refreshTokenReads...),
}
transport := &RefreshTransport{
Base: server.Client().Transport,
Org: "test-org",
Keyring: store,
TokenSource: NewTokenSource("old-token"),
}

t.Setenv("BUILDKITE_HOST", strings.TrimPrefix(server.URL, "https://"))

req, _ := http.NewRequest(http.MethodGet, server.URL+"/test", nil)
req.Header.Set("Authorization", "Bearer old-token")

resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tt.wantStatus {
t.Fatalf("expected status %d, got %d", tt.wantStatus, resp.StatusCode)
}
if got := transport.TokenSource.Token(); got != tt.wantTokenSource {
t.Fatalf("expected token source %q, got %q", tt.wantTokenSource, got)
}
if got := store.refreshToken; got != tt.wantRefreshToken {
t.Fatalf("expected stored refresh token %q, got %q", tt.wantRefreshToken, got)
}
if got := store.deleteRefreshCalls; got != tt.wantDeleteCalls {
t.Fatalf("expected %d refresh token deletes, got %d", tt.wantDeleteCalls, got)
}
if got := refreshRequests.Load(); got != 1 {
t.Fatalf("expected 1 refresh request, got %d", got)
}
})
}
}

func TestRefreshTransport_DoesNotDeleteRefreshTokenOnTransientError(t *testing.T) {
keyring.MockForTesting()
defer keyring.ResetForTesting()
Expand Down