Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion internal/client/tokensource/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,29 @@ package tokensource
import (
"context"
"fmt"
"net/url"
"strings"

"cloud.google.com/go/compute/metadata"
"google.golang.org/api/idtoken"
)

var (
newIDTokenSource = idtoken.NewTokenSource
metadataGet = metadata.GetWithContext
)

// FromGCP fetches an OIDC identity token using GCP Application Default
// Credentials. This works on GCE, Cloud Run, GKE, and anywhere a service
// account key or workload identity federation is configured.
func FromGCP(ctx context.Context, audience string) (string, error) {
ts, err := idtoken.NewTokenSource(ctx, audience)
ts, err := newIDTokenSource(ctx, audience)

if err != nil {
if isUnsupportedAuthorizedUser(err) {
return fromGCPMetadata(ctx, audience)
}

return "", fmt.Errorf("creating GCP token source: %w", err)
}

Expand All @@ -30,3 +42,26 @@ func FromGCP(ctx context.Context, audience string) (string, error) {

return token.AccessToken, nil
}

func isUnsupportedAuthorizedUser(err error) bool {
errString := err.Error()
return strings.Contains(errString, "unsupported credentials type") && strings.Contains(errString, "authorized_user")
}

func fromGCPMetadata(ctx context.Context, audience string) (string, error) {
v := url.Values{}
v.Set("audience", audience)
v.Set("format", "full")

token, err := metadataGet(ctx, "instance/service-accounts/default/identity?"+v.Encode())

if err != nil {
return "", fmt.Errorf("fetching GCP identity token from metadata server: %w", err)
}

if token == "" {
return "", fmt.Errorf("GCP metadata server returned an empty token")
}

return token, nil
}
82 changes: 82 additions & 0 deletions internal/client/tokensource/gcp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package tokensource

import (
"context"
"errors"
"testing"

"golang.org/x/oauth2"
"google.golang.org/api/idtoken"
)

func TestFromGCPFallsBackToMetadataForAuthorizedUserCredentials(t *testing.T) {
tests := []struct {
name string
err string
}{
{"unquoted", "idtoken: unsupported credentials type: authorized_user"},
{"quoted", "idtoken: unsupported credentials type: \"authorized_user\""},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resetGCPTestHooks(t)

newIDTokenSource = func(_ context.Context, _ string, _ ...idtoken.ClientOption) (oauth2.TokenSource, error) {
return nil, errors.New(tt.err)
}

var gotSuffix string
metadataGet = func(_ context.Context, suffix string) (string, error) {
gotSuffix = suffix
return "metadata-token", nil
}

token, err := FromGCP(context.Background(), "https://sts.example/token")

if err != nil {
t.Fatalf("FromGCP returned error: %v", err)
}

if token != "metadata-token" {
t.Fatalf("FromGCP token = %q, want %q", token, "metadata-token")
}

wantSuffix := "instance/service-accounts/default/identity?audience=https%3A%2F%2Fsts.example%2Ftoken&format=full"
if gotSuffix != wantSuffix {
t.Fatalf("metadata suffix = %q, want %q", gotSuffix, wantSuffix)
}
})
}
}

func TestFromGCPDoesNotFallBackForOtherTokenSourceErrors(t *testing.T) {
resetGCPTestHooks(t)

newIDTokenSource = func(_ context.Context, _ string, _ ...idtoken.ClientOption) (oauth2.TokenSource, error) {
return nil, errors.New("boom")
}

metadataGet = func(_ context.Context, _ string) (string, error) {
t.Fatal("metadataGet should not be called")
return "", nil
}

_, err := FromGCP(context.Background(), "https://sts.example/token")

if err == nil {
t.Fatal("FromGCP returned nil error")
}
}

func resetGCPTestHooks(t *testing.T) {
t.Helper()

originalNewIDTokenSource := newIDTokenSource
originalMetadataGet := metadataGet

t.Cleanup(func() {
newIDTokenSource = originalNewIDTokenSource
metadataGet = originalMetadataGet
})
}