Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions github/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package github
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand All @@ -25,6 +26,7 @@ import (
"net/url"
"os"
"sort"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
Expand Down Expand Up @@ -86,9 +88,12 @@ type OIDCClient struct {
// requestURL is the GitHub URL to request a OIDC token.
requestURL *url.URL

// actionsProviderURL is the expected GitHub Actions OIDC issuer base URL.
actionsProviderURL string

// verifierFunc is a factory to generate an oidc.IDTokenVerifier for token verification.
// This is used for tests.
verifierFunc func(context.Context) (*oidc.IDTokenVerifier, error)
verifierFunc func(context.Context, string) (*oidc.IDTokenVerifier, error)

// bearerToken is used to request an ID token.
bearerToken string
Expand All @@ -107,11 +112,12 @@ func NewOIDCClient() (*OIDCClient, error) {
}

c := OIDCClient{
requestURL: parsedURL,
bearerToken: os.Getenv(requestTokenEnvKey),
requestURL: parsedURL,
actionsProviderURL: defaultActionsProviderURL,
bearerToken: os.Getenv(requestTokenEnvKey),
}
c.verifierFunc = func(ctx context.Context) (*oidc.IDTokenVerifier, error) {
provider, err := oidc.NewProvider(ctx, defaultActionsProviderURL)
c.verifierFunc = func(ctx context.Context, issuer string) (*oidc.IDTokenVerifier, error) {
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -172,10 +178,74 @@ func (c *OIDCClient) decodePayload(b []byte) (string, error) {
return payload.Value, nil
}

func tokenIssuer(payload string) (string, error) {
parts := strings.Split(payload, ".")
if len(parts) < 2 {
return "", fmt.Errorf("jwt parts: %d", len(parts))
}

b, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("jwt payload: %w", err)
}

var claims struct {
Issuer string `json:"iss"`
}
if err := json.Unmarshal(b, &claims); err != nil {
return "", fmt.Errorf("jwt payload claims: %w", err)
}
if claims.Issuer == "" {
return "", errors.New("issuer is empty")
}

return claims.Issuer, nil
}

func validActionsIssuer(issuer, providerURL string) (string, error) {
parsedIssuer, err := url.Parse(issuer)
if err != nil {
return "", fmt.Errorf("parse issuer %q: %w", issuer, err)
}
parsedProvider, err := url.Parse(providerURL)
if err != nil {
return "", fmt.Errorf("parse provider %q: %w", providerURL, err)
}

if parsedIssuer.Scheme != parsedProvider.Scheme || parsedIssuer.Host != parsedProvider.Host {
return "", fmt.Errorf("issuer %q does not match provider %q", issuer, providerURL)
}
if parsedIssuer.User != nil || parsedIssuer.RawQuery != "" || parsedIssuer.Fragment != "" {
return "", fmt.Errorf("issuer %q contains unsupported URL components", issuer)
}

issuerPath := parsedIssuer.Path
providerPath := strings.TrimSuffix(parsedProvider.Path, "/")
if issuerPath == providerPath || issuerPath == "" && providerPath == "" {
return issuer, nil
}

suffix, ok := strings.CutPrefix(issuerPath, providerPath+"/")
if !ok || suffix == "" || strings.Contains(suffix, "/") {
return "", fmt.Errorf("issuer %q is not a supported GitHub Actions issuer", issuer)
}

return issuer, nil
}

// verifyToken verifies the token contents and signature.
func (c *OIDCClient) verifyToken(ctx context.Context, audience []string, payload string) (*oidc.IDToken, error) {
issuer, err := tokenIssuer(payload)
if err != nil {
return nil, fmt.Errorf("%w: extracting issuer: %w", errVerify, err)
}
issuer, err = validActionsIssuer(issuer, c.actionsProviderURL)
if err != nil {
return nil, fmt.Errorf("%w: validating issuer: %w", errVerify, err)
}

// Verify the token.
verifier, err := c.verifierFunc(ctx)
verifier, err := c.verifierFunc(ctx, issuer)
if err != nil {
return nil, fmt.Errorf("%w: creating verifier: %w", errVerify, err)
}
Expand Down
94 changes: 93 additions & 1 deletion github/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -130,6 +131,19 @@ func TestToken(t *testing.T) {
ActorID: "4567",
},
},
{
name: "enterprise issuer slug",
audience: []string{"hoge"},
token: &OIDCToken{
Issuer: "/octocat-inc",
Audience: []string{"hoge"},
Expiry: now.Add(1 * time.Hour),
JobWorkflowRef: "pico",
RepositoryID: "1234",
RepositoryOwnerID: "4321",
ActorID: "4567",
},
},
{
name: "no repository id claim",
audience: []string{"hoge"},
Expand Down Expand Up @@ -285,7 +299,11 @@ func TestToken(t *testing.T) {
tc.err(err)
} else {
// Successful response, as expected. Check token.
if want, got := tc.token, token; !tokenEqual(s.URL, want, got) {
expectedIssuer := s.URL
if strings.HasPrefix(tc.token.Issuer, "/") {
expectedIssuer += tc.token.Issuer
}
if want, got := tc.token, token; !tokenEqual(expectedIssuer, want, got) {
t.Errorf("unexpected workflow ref\nwant: %#v\ngot: %#v\ndiff:\n%v", want, got, cmp.Diff(want, got))
}
}
Expand All @@ -294,6 +312,80 @@ func TestToken(t *testing.T) {
}
}

func Test_validActionsIssuer(t *testing.T) {
testCases := []struct {
name string
issuer string
providerURL string
wantErr bool
}{
{
name: "default issuer",
issuer: "https://token.actions.githubusercontent.com",
providerURL: defaultActionsProviderURL,
},
{
name: "enterprise slug issuer",
issuer: "https://token.actions.githubusercontent.com/octocat-inc",
providerURL: defaultActionsProviderURL,
},
{
name: "wrong host",
issuer: "https://token.actions.githubusercontent.com.evil/octocat-inc",
providerURL: defaultActionsProviderURL,
wantErr: true,
},
{
name: "userinfo",
issuer: "https://user:pass@token.actions.githubusercontent.com/octocat-inc",
providerURL: defaultActionsProviderURL,
wantErr: true,
},
{
name: "query string",
issuer: "https://token.actions.githubusercontent.com/octocat-inc?foo=bar",
providerURL: defaultActionsProviderURL,
wantErr: true,
},
{
name: "multiple path segments",
issuer: "https://token.actions.githubusercontent.com/octocat-inc/extra",
providerURL: defaultActionsProviderURL,
wantErr: true,
},
{
name: "encoded slash",
issuer: "https://token.actions.githubusercontent.com/octocat-inc%2Fextra",
providerURL: defaultActionsProviderURL,
wantErr: true,
},
{
name: "empty path segment",
issuer: "https://token.actions.githubusercontent.com/",
providerURL: defaultActionsProviderURL,
wantErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := validActionsIssuer(tc.issuer, tc.providerURL)
if err != nil {
if !tc.wantErr {
t.Fatalf("unexpected error: %v", err)
}
return
}
if tc.wantErr {
t.Fatal("expected error")
}
if got != tc.issuer {
t.Fatalf("unexpected issuer, got: %q, want: %q", got, tc.issuer)
}
})
}
}

func Test_compareStringSlice(t *testing.T) {
testCases := []struct {
name string
Expand Down
13 changes: 9 additions & 4 deletions github/oidctest.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ func NewTestOIDCServer(t *testing.T, now time.Time, token *OIDCToken) (*httptest
// Allow the token to override the issuer for verification testing.
issuer := issuerURL
if token.Issuer != "" {
issuer = token.Issuer
if strings.HasPrefix(token.Issuer, "/") {
issuer += token.Issuer
} else {
issuer = token.Issuer
}
}

b, err := json.Marshal(jsonToken{
Expand Down Expand Up @@ -143,9 +147,10 @@ func newTestOIDCServer(t *testing.T, now time.Time, f http.HandlerFunc) (*httpte
t.Fatalf("unexpected error: %v", err)
}
c := OIDCClient{
requestURL: requestURL,
verifierFunc: func(_ context.Context) (*oidc.IDTokenVerifier, error) {
return oidc.NewVerifier(s.URL, &testKeySet{}, &oidc.Config{
requestURL: requestURL,
actionsProviderURL: issuerURL,
verifierFunc: func(_ context.Context, issuer string) (*oidc.IDTokenVerifier, error) {
return oidc.NewVerifier(issuer, &testKeySet{}, &oidc.Config{
Now: func() time.Time { return now },
SkipClientIDCheck: true,
}), nil
Expand Down