From 3660022c7d1416e2739080dcfdf2dd42bd478114 Mon Sep 17 00:00:00 2001 From: bmendonca3 <208517100+bmendonca3@users.noreply.github.com> Date: Thu, 28 May 2026 18:08:21 -0700 Subject: [PATCH] Support enterprise GitHub Actions OIDC issuers Signed-off-by: bmendonca3 <208517100+bmendonca3@users.noreply.github.com> --- github/oidc.go | 82 ++++++++++++++++++++++++++++++++++++--- github/oidc_test.go | 94 ++++++++++++++++++++++++++++++++++++++++++++- github/oidctest.go | 13 +++++-- 3 files changed, 178 insertions(+), 11 deletions(-) diff --git a/github/oidc.go b/github/oidc.go index 330817dd75..125a33ebd9 100644 --- a/github/oidc.go +++ b/github/oidc.go @@ -17,6 +17,7 @@ package github import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -25,6 +26,7 @@ import ( "net/url" "os" "sort" + "strings" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -86,9 +88,12 @@ type OIDCClient struct { // requestURL is the GitHub URL to request a OIDC token. requestURL *url.URL + // actionsProviderURL is the expected GitHub Actions OIDC issuer base URL. + actionsProviderURL string + // verifierFunc is a factory to generate an oidc.IDTokenVerifier for token verification. // This is used for tests. - verifierFunc func(context.Context) (*oidc.IDTokenVerifier, error) + verifierFunc func(context.Context, string) (*oidc.IDTokenVerifier, error) // bearerToken is used to request an ID token. bearerToken string @@ -107,11 +112,12 @@ func NewOIDCClient() (*OIDCClient, error) { } c := OIDCClient{ - requestURL: parsedURL, - bearerToken: os.Getenv(requestTokenEnvKey), + requestURL: parsedURL, + actionsProviderURL: defaultActionsProviderURL, + bearerToken: os.Getenv(requestTokenEnvKey), } - c.verifierFunc = func(ctx context.Context) (*oidc.IDTokenVerifier, error) { - provider, err := oidc.NewProvider(ctx, defaultActionsProviderURL) + c.verifierFunc = func(ctx context.Context, issuer string) (*oidc.IDTokenVerifier, error) { + provider, err := oidc.NewProvider(ctx, issuer) if err != nil { return nil, err } @@ -172,10 +178,74 @@ func (c *OIDCClient) decodePayload(b []byte) (string, error) { return payload.Value, nil } +func tokenIssuer(payload string) (string, error) { + parts := strings.Split(payload, ".") + if len(parts) < 2 { + return "", fmt.Errorf("jwt parts: %d", len(parts)) + } + + b, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("jwt payload: %w", err) + } + + var claims struct { + Issuer string `json:"iss"` + } + if err := json.Unmarshal(b, &claims); err != nil { + return "", fmt.Errorf("jwt payload claims: %w", err) + } + if claims.Issuer == "" { + return "", errors.New("issuer is empty") + } + + return claims.Issuer, nil +} + +func validActionsIssuer(issuer, providerURL string) (string, error) { + parsedIssuer, err := url.Parse(issuer) + if err != nil { + return "", fmt.Errorf("parse issuer %q: %w", issuer, err) + } + parsedProvider, err := url.Parse(providerURL) + if err != nil { + return "", fmt.Errorf("parse provider %q: %w", providerURL, err) + } + + if parsedIssuer.Scheme != parsedProvider.Scheme || parsedIssuer.Host != parsedProvider.Host { + return "", fmt.Errorf("issuer %q does not match provider %q", issuer, providerURL) + } + if parsedIssuer.User != nil || parsedIssuer.RawQuery != "" || parsedIssuer.Fragment != "" { + return "", fmt.Errorf("issuer %q contains unsupported URL components", issuer) + } + + issuerPath := parsedIssuer.Path + providerPath := strings.TrimSuffix(parsedProvider.Path, "/") + if issuerPath == providerPath || issuerPath == "" && providerPath == "" { + return issuer, nil + } + + suffix, ok := strings.CutPrefix(issuerPath, providerPath+"/") + if !ok || suffix == "" || strings.Contains(suffix, "/") { + return "", fmt.Errorf("issuer %q is not a supported GitHub Actions issuer", issuer) + } + + return issuer, nil +} + // verifyToken verifies the token contents and signature. func (c *OIDCClient) verifyToken(ctx context.Context, audience []string, payload string) (*oidc.IDToken, error) { + issuer, err := tokenIssuer(payload) + if err != nil { + return nil, fmt.Errorf("%w: extracting issuer: %w", errVerify, err) + } + issuer, err = validActionsIssuer(issuer, c.actionsProviderURL) + if err != nil { + return nil, fmt.Errorf("%w: validating issuer: %w", errVerify, err) + } + // Verify the token. - verifier, err := c.verifierFunc(ctx) + verifier, err := c.verifierFunc(ctx, issuer) if err != nil { return nil, fmt.Errorf("%w: creating verifier: %w", errVerify, err) } diff --git a/github/oidc_test.go b/github/oidc_test.go index efc103b78e..48860566a5 100644 --- a/github/oidc_test.go +++ b/github/oidc_test.go @@ -22,6 +22,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "time" @@ -130,6 +131,19 @@ func TestToken(t *testing.T) { ActorID: "4567", }, }, + { + name: "enterprise issuer slug", + audience: []string{"hoge"}, + token: &OIDCToken{ + Issuer: "/octocat-inc", + Audience: []string{"hoge"}, + Expiry: now.Add(1 * time.Hour), + JobWorkflowRef: "pico", + RepositoryID: "1234", + RepositoryOwnerID: "4321", + ActorID: "4567", + }, + }, { name: "no repository id claim", audience: []string{"hoge"}, @@ -285,7 +299,11 @@ func TestToken(t *testing.T) { tc.err(err) } else { // Successful response, as expected. Check token. - if want, got := tc.token, token; !tokenEqual(s.URL, want, got) { + expectedIssuer := s.URL + if strings.HasPrefix(tc.token.Issuer, "/") { + expectedIssuer += tc.token.Issuer + } + if want, got := tc.token, token; !tokenEqual(expectedIssuer, want, got) { t.Errorf("unexpected workflow ref\nwant: %#v\ngot: %#v\ndiff:\n%v", want, got, cmp.Diff(want, got)) } } @@ -294,6 +312,80 @@ func TestToken(t *testing.T) { } } +func Test_validActionsIssuer(t *testing.T) { + testCases := []struct { + name string + issuer string + providerURL string + wantErr bool + }{ + { + name: "default issuer", + issuer: "https://token.actions.githubusercontent.com", + providerURL: defaultActionsProviderURL, + }, + { + name: "enterprise slug issuer", + issuer: "https://token.actions.githubusercontent.com/octocat-inc", + providerURL: defaultActionsProviderURL, + }, + { + name: "wrong host", + issuer: "https://token.actions.githubusercontent.com.evil/octocat-inc", + providerURL: defaultActionsProviderURL, + wantErr: true, + }, + { + name: "userinfo", + issuer: "https://user:pass@token.actions.githubusercontent.com/octocat-inc", + providerURL: defaultActionsProviderURL, + wantErr: true, + }, + { + name: "query string", + issuer: "https://token.actions.githubusercontent.com/octocat-inc?foo=bar", + providerURL: defaultActionsProviderURL, + wantErr: true, + }, + { + name: "multiple path segments", + issuer: "https://token.actions.githubusercontent.com/octocat-inc/extra", + providerURL: defaultActionsProviderURL, + wantErr: true, + }, + { + name: "encoded slash", + issuer: "https://token.actions.githubusercontent.com/octocat-inc%2Fextra", + providerURL: defaultActionsProviderURL, + wantErr: true, + }, + { + name: "empty path segment", + issuer: "https://token.actions.githubusercontent.com/", + providerURL: defaultActionsProviderURL, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := validActionsIssuer(tc.issuer, tc.providerURL) + if err != nil { + if !tc.wantErr { + t.Fatalf("unexpected error: %v", err) + } + return + } + if tc.wantErr { + t.Fatal("expected error") + } + if got != tc.issuer { + t.Fatalf("unexpected issuer, got: %q, want: %q", got, tc.issuer) + } + }) + } +} + func Test_compareStringSlice(t *testing.T) { testCases := []struct { name string diff --git a/github/oidctest.go b/github/oidctest.go index fc6b40e1fd..7db3db857a 100644 --- a/github/oidctest.go +++ b/github/oidctest.go @@ -80,7 +80,11 @@ func NewTestOIDCServer(t *testing.T, now time.Time, token *OIDCToken) (*httptest // Allow the token to override the issuer for verification testing. issuer := issuerURL if token.Issuer != "" { - issuer = token.Issuer + if strings.HasPrefix(token.Issuer, "/") { + issuer += token.Issuer + } else { + issuer = token.Issuer + } } b, err := json.Marshal(jsonToken{ @@ -143,9 +147,10 @@ func newTestOIDCServer(t *testing.T, now time.Time, f http.HandlerFunc) (*httpte t.Fatalf("unexpected error: %v", err) } c := OIDCClient{ - requestURL: requestURL, - verifierFunc: func(_ context.Context) (*oidc.IDTokenVerifier, error) { - return oidc.NewVerifier(s.URL, &testKeySet{}, &oidc.Config{ + requestURL: requestURL, + actionsProviderURL: issuerURL, + verifierFunc: func(_ context.Context, issuer string) (*oidc.IDTokenVerifier, error) { + return oidc.NewVerifier(issuer, &testKeySet{}, &oidc.Config{ Now: func() time.Time { return now }, SkipClientIDCheck: true, }), nil