Skip to content

Commit 58f33c8

Browse files
committed
bug: only require CI=true and AccessToken when making API requests (#1278)
* check when inCI and whether we required Access Token * fast fail in APIClient when CI=true and no Access Token * use apiClient constructor in search_jobs * fail fast in login / auth token when CI=true and AccessToken not set * add comment * review feedback - rename `RequireAccessToken` to `checkIfCIAccessTokenRequired` - provide more context in error on why it happened * use sentinel error * use lib errors (cherry picked from commit 0e5e006)
1 parent 3bd3cfb commit 58f33c8

File tree

8 files changed

+173
-60
lines changed

8 files changed

+173
-60
lines changed

cmd/src/auth_token.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ func init() {
5050
}
5151

5252
func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
53+
if err := cfg.requireCIAccessToken(); err != nil {
54+
return "", err
55+
}
56+
5357
if cfg.AccessToken != "" {
5458
return cfg.AccessToken, nil
5559
}

cmd/src/auth_token_test.go

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ func TestResolveAuthToken(t *testing.T) {
1717
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
1818
newRefresherCalled = true
1919
return fakeOAuthTokenRefresher{}
20-
}
20+
}
2121

22-
token, err := resolveAuthToken(context.Background(), &config{
23-
AccessToken: "access-token",
24-
Endpoint: "https://example.com",
25-
})
22+
token, err := resolveAuthToken(context.Background(), &config{
23+
AccessToken: "access-token",
24+
Endpoint: "https://example.com",
25+
})
2626
if err != nil {
2727
t.Fatal(err)
2828
}
@@ -34,23 +34,45 @@ func TestResolveAuthToken(t *testing.T) {
3434
}
3535
})
3636

37+
t.Run("requires access token in CI", func(t *testing.T) {
38+
reset := stubAuthTokenDependencies(t)
39+
defer reset()
40+
41+
loadCalled := false
42+
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
43+
loadCalled = true
44+
return nil, nil
45+
}
46+
47+
_, err := resolveAuthToken(context.Background(), &config{
48+
inCI: true,
49+
Endpoint: "https://example.com",
50+
})
51+
if err != errCIAccessTokenRequired {
52+
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
53+
}
54+
if loadCalled {
55+
t.Fatal("expected OAuth token loader not to be called")
56+
}
57+
})
58+
3759
t.Run("uses stored oauth token", func(t *testing.T) {
3860
reset := stubAuthTokenDependencies(t)
3961
defer reset()
4062

41-
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
42-
return &oauth.Token{
43-
AccessToken: "oauth-token",
44-
}, nil
63+
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
64+
return &oauth.Token{
65+
AccessToken: "oauth-token",
66+
}, nil
4567
}
4668

4769
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
4870
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "oauth-token"}}
49-
}
71+
}
5072

51-
token, err := resolveAuthToken(context.Background(), &config{
52-
Endpoint: "https://example.com",
53-
})
73+
token, err := resolveAuthToken(context.Background(), &config{
74+
Endpoint: "https://example.com",
75+
})
5476
if err != nil {
5577
t.Fatal(err)
5678
}
@@ -63,17 +85,17 @@ func TestResolveAuthToken(t *testing.T) {
6385
reset := stubAuthTokenDependencies(t)
6486
defer reset()
6587

66-
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
67-
return &oauth.Token{AccessToken: "old-token"}, nil
68-
}
88+
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
89+
return &oauth.Token{AccessToken: "old-token"}, nil
90+
}
6991

7092
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
7193
return fakeOAuthTokenRefresher{token: oauth.Token{AccessToken: "new-token"}}
72-
}
94+
}
7395

74-
token, err := resolveAuthToken(context.Background(), &config{
75-
Endpoint: "https://example.com",
76-
})
96+
token, err := resolveAuthToken(context.Background(), &config{
97+
Endpoint: "https://example.com",
98+
})
7799
if err != nil {
78100
t.Fatal(err)
79101
}
@@ -86,20 +108,20 @@ func TestResolveAuthToken(t *testing.T) {
86108
reset := stubAuthTokenDependencies(t)
87109
defer reset()
88110

89-
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
90-
return &oauth.Token{AccessToken: "old-token"}, nil
91-
}
111+
loadOAuthToken = func(context.Context, string) (*oauth.Token, error) {
112+
return &oauth.Token{AccessToken: "old-token"}, nil
113+
}
92114
newOAuthTokenRefresher = func(*oauth.Token) oauthTokenRefresher {
93115
return fakeOAuthTokenRefresher{err: fmt.Errorf("refresh failed")}
94116
}
95117

96-
_, err := resolveAuthToken(context.Background(), &config{
97-
Endpoint: "https://example.com",
98-
})
99-
if err == nil {
100-
t.Fatal("expected error")
101-
}
118+
_, err := resolveAuthToken(context.Background(), &config{
119+
Endpoint: "https://example.com",
102120
})
121+
if err == nil {
122+
t.Fatal("expected error")
123+
}
124+
})
103125
}
104126

105127
func stubAuthTokenDependencies(t *testing.T) func() {

cmd/src/login.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ const (
9696
)
9797

9898
func loginCmd(ctx context.Context, p loginParams) error {
99+
if err := p.cfg.requireCIAccessToken(); err != nil {
100+
return err
101+
}
102+
99103
if p.cfg.ConfigFilePath != "" {
100104
fmt.Fprintln(p.out)
101105
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.ConfigFilePath)

cmd/src/login_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ func TestLogin(t *testing.T) {
5050
}
5151
})
5252

53+
t.Run("CI requires access token", func(t *testing.T) {
54+
out, err := check(t, &config{Endpoint: "https://example.com", inCI: true}, "https://example.com")
55+
if err != errCIAccessTokenRequired {
56+
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
57+
}
58+
if out != "" {
59+
t.Fatalf("output = %q, want empty output", out)
60+
}
61+
})
62+
5363
t.Run("warning when using config file", func(t *testing.T) {
5464
out, err := check(t, &config{Endpoint: "https://example.com", ConfigFilePath: "f"}, "https://example.com")
5565
if err == nil {

cmd/src/main.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ var (
8383

8484
errConfigMerge = errors.New("when using a configuration file, zero or all environment variables must be set")
8585
errConfigAuthorizationConflict = errors.New("when passing an 'Authorization' additional headers, SRC_ACCESS_TOKEN must never be set")
86-
errCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set in CI")
86+
errCIAccessTokenRequired = errors.New("CI is true and SRC_ACCESS_TOKEN is not set or empty. When running in CI OAuth tokens cannot be used, only SRC_ACCESS_TOKEN. Either set CI=false or define a SRC_ACCESS_TOKEN")
8787
)
8888

8989
// commands contains all registered subcommands.
@@ -122,6 +122,7 @@ type config struct {
122122
ProxyURL *url.URL
123123
ProxyPath string
124124
ConfigFilePath string
125+
inCI bool
125126
}
126127

127128
type AuthMode int
@@ -138,16 +139,31 @@ func (c *config) AuthMode() AuthMode {
138139
return AuthModeOAuth
139140
}
140141

142+
func (c *config) InCI() bool {
143+
return c.inCI
144+
}
145+
146+
func (c *config) requireCIAccessToken() error {
147+
// In CI we typically do not have access to the keyring and the machine is also
148+
// typically headless, so OAuth tokens are not a reliable fallback.
149+
if c.InCI() && c.AuthMode() != AuthModeAccessToken {
150+
return errCIAccessTokenRequired
151+
}
152+
153+
return nil
154+
}
155+
141156
// apiClient returns an api.Client built from the configuration.
142157
func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client {
143158
opts := api.ClientOpts{
144-
Endpoint: c.Endpoint,
145-
AccessToken: c.AccessToken,
146-
AdditionalHeaders: c.AdditionalHeaders,
147-
Flags: flags,
148-
Out: out,
149-
ProxyURL: c.ProxyURL,
150-
ProxyPath: c.ProxyPath,
159+
Endpoint: c.Endpoint,
160+
AccessToken: c.AccessToken,
161+
AdditionalHeaders: c.AdditionalHeaders,
162+
Flags: flags,
163+
Out: out,
164+
ProxyURL: c.ProxyURL,
165+
ProxyPath: c.ProxyPath,
166+
RequireAccessTokenInCI: c.InCI(),
151167
}
152168

153169
// Only use OAuth if we do not have SRC_ACCESS_TOKEN set
@@ -179,6 +195,7 @@ func readConfig() (*config, error) {
179195
return nil, err
180196
}
181197
var cfg config
198+
cfg.inCI = isCI()
182199
if err == nil {
183200
cfg.ConfigFilePath = cfgPath
184201
if err := json.Unmarshal(data, &cfg); err != nil {
@@ -276,10 +293,6 @@ func readConfig() (*config, error) {
276293

277294
cfg.Endpoint = cleanEndpoint(cfg.Endpoint)
278295

279-
if isCI() && cfg.AccessToken == "" {
280-
return nil, errCIAccessTokenRequired
281-
}
282-
283296
return &cfg, nil
284297
}
285298

cmd/src/main_test.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package main
22

33
import (
4+
"context"
45
"encoding/json"
6+
"errors"
7+
"io"
58
"net/url"
69
"os"
710
"path/filepath"
@@ -285,9 +288,13 @@ func TestReadConfig(t *testing.T) {
285288
wantErr: errConfigAuthorizationConflict.Error(),
286289
},
287290
{
288-
name: "CI requires access token",
289-
envCI: "1",
290-
wantErr: errCIAccessTokenRequired.Error(),
291+
name: "CI does not require access token during config read",
292+
envCI: "1",
293+
want: &config{
294+
Endpoint: "https://sourcegraph.com",
295+
AdditionalHeaders: map[string]string{},
296+
inCI: true,
297+
},
291298
},
292299
{
293300
name: "CI allows access token from config file",
@@ -300,6 +307,7 @@ func TestReadConfig(t *testing.T) {
300307
Endpoint: "https://example.com",
301308
AccessToken: "deadbeef",
302309
AdditionalHeaders: map[string]string{},
310+
inCI: true,
303311
},
304312
},
305313
}
@@ -351,8 +359,8 @@ func TestReadConfig(t *testing.T) {
351359
t.Fatal(err)
352360
}
353361

354-
config, err := readConfig()
355-
if diff := cmp.Diff(test.want, config); diff != "" {
362+
gotConfig, err := readConfig()
363+
if diff := cmp.Diff(test.want, gotConfig, cmp.AllowUnexported(config{})); diff != "" {
356364
t.Errorf("config: %v", diff)
357365
}
358366
var errMsg string
@@ -379,3 +387,36 @@ func TestConfigAuthMode(t *testing.T) {
379387
}
380388
})
381389
}
390+
391+
func TestConfigAPIClientCIAccessTokenGate(t *testing.T) {
392+
endpoint := "https://example.com"
393+
394+
t.Run("requires access token in CI", func(t *testing.T) {
395+
client := (&config{Endpoint: endpoint, inCI: true}).apiClient(nil, io.Discard)
396+
397+
_, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
398+
if !errors.Is(err, api.ErrCIAccessTokenRequired) {
399+
t.Fatalf("NewHTTPRequest() error = %v, want %v", err, api.ErrCIAccessTokenRequired)
400+
}
401+
})
402+
403+
t.Run("allows access token in CI", func(t *testing.T) {
404+
client := (&config{Endpoint: endpoint, inCI: true, AccessToken: "abc"}).apiClient(nil, io.Discard)
405+
406+
req, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
407+
if err != nil {
408+
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
409+
}
410+
if got := req.Header.Get("Authorization"); got != "token abc" {
411+
t.Fatalf("Authorization header = %q, want %q", got, "token abc")
412+
}
413+
})
414+
415+
t.Run("allows oauth mode outside CI", func(t *testing.T) {
416+
client := (&config{Endpoint: endpoint}).apiClient(nil, io.Discard)
417+
418+
if _, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil); err != nil {
419+
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
420+
}
421+
})
422+
}

cmd/src/search_jobs.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,7 @@ func parseColumns(columnsFlag string) []string {
155155

156156
// createSearchJobsClient creates a reusable API client for search jobs commands
157157
func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client {
158-
return api.NewClient(api.ClientOpts{
159-
Endpoint: cfg.Endpoint,
160-
AccessToken: cfg.AccessToken,
161-
Out: out.Output(),
162-
Flags: apiFlags,
163-
})
158+
return cfg.apiClient(apiFlags, out.Output())
164159
}
165160

166161
// parseSearchJobsArgs parses command arguments with the provided flag set

0 commit comments

Comments
 (0)