diff --git a/cmd/src/login.go b/cmd/src/login.go index ab5a097c71..1c5c3092bc 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -65,7 +65,7 @@ Examples: } func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg string, out io.Writer) error { - endpointArg = cleanEndpoint(endpointArg) + endpointArg = strings.TrimSuffix(endpointArg, "/") printProblem := func(problem string) { fmt.Fprintf(out, "❌ Problem: %s\n", problem) diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 37fbf7a703..3069b1cdb9 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" @@ -63,7 +64,8 @@ func TestLogin(t *testing.T) { defer s.Close() endpoint := s.URL - out, err := check(t, &config{Endpoint: endpoint, AccessToken: "x"}, endpoint) + u, _ := url.ParseRequestURI(endpoint) + out, err := check(t, &config{Endpoint: endpoint, EndpointURL: u, AccessToken: "x"}, endpoint) if err != cmderrors.ExitCode1 { t.Fatal(err) } @@ -82,7 +84,8 @@ func TestLogin(t *testing.T) { defer s.Close() endpoint := s.URL - out, err := check(t, &config{Endpoint: endpoint, AccessToken: "x"}, endpoint) + u, _ := url.ParseRequestURI(endpoint) + out, err := check(t, &config{Endpoint: endpoint, EndpointURL: u, AccessToken: "x"}, endpoint) if err != nil { t.Fatal(err) } diff --git a/cmd/src/main.go b/cmd/src/main.go index edfb1073d7..42c5c0349c 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -3,9 +3,11 @@ package main import ( "encoding/json" "flag" + "fmt" "io" "log" "net" + "net/http" "net/url" "os" "path/filepath" @@ -107,6 +109,20 @@ func normalizeDashHelp(args []string) []string { return args } +func parseEndpoint(endpoint string) (*url.URL, error) { + u, err := url.ParseRequestURI(strings.TrimSuffix(endpoint, "/")) + if err != nil { + return nil, err + } + if !(u.Scheme == "http" || u.Scheme == "https") { + return nil, errors.Newf("Invalid scheme %s. Require http or https", u.Scheme) + } + if u.Host == "" { + return nil, errors.Newf("Empty host") + } + return u, nil +} + var cfg *config // config represents the config format. @@ -118,12 +134,13 @@ type config struct { ProxyURL *url.URL ProxyPath string ConfigFilePath string + EndpointURL *url.URL } // apiClient returns an api.Client built from the configuration. func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { return api.NewClient(api.ClientOpts{ - Endpoint: c.Endpoint, + EndpointURL: c.EndpointURL, AccessToken: c.AccessToken, AdditionalHeaders: c.AdditionalHeaders, Flags: flags, @@ -133,7 +150,8 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { }) } -// readConfig reads the config file from the given path. +// readConfig reads the config from the standard config file, the (deprecated) user-specified config file, +// the environment variables, and the (deprecated) command-line flags. func readConfig() (*config, error) { cfgFile := *configPath userSpecified := *configPath != "" @@ -189,9 +207,21 @@ func readConfig() (*config, error) { cfg.Proxy = envProxy } + // Lastly, apply endpoint flag if set + if endpoint != nil && *endpoint != "" { + cfg.Endpoint = *endpoint + } + + if endpointURL, err := parseEndpoint(cfg.Endpoint); err != nil { + return nil, errors.Newf("invalid endpoint: %s", cfg.Endpoint) + } else { + cfg.EndpointURL = endpointURL + cfg.Endpoint = endpointURL.String() + } + if cfg.Proxy != "" { - parseEndpoint := func(endpoint string) (scheme string, address string) { + parseProxyEndpoint := func(endpoint string) (scheme string, address string) { parts := strings.SplitN(endpoint, "://", 2) if len(parts) == 2 { return parts[0], parts[1] @@ -205,7 +235,7 @@ func readConfig() (*config, error) { return slices.Contains(urlSchemes, scheme) } - scheme, address := parseEndpoint(cfg.Proxy) + scheme, address := parseProxyEndpoint(cfg.Proxy) if isURLScheme(scheme) { endpoint := cfg.Proxy @@ -227,11 +257,19 @@ func readConfig() (*config, error) { return nil, errors.Newf("Invalid proxy configuration: %w", err) } if !isValidUDS { - return nil, errors.Newf("invalid proxy socket: %s", path) + return nil, errors.Newf("Invalid proxy socket: %s", path) } cfg.ProxyPath = path } else { - return nil, errors.Newf("invalid proxy endpoint: %s", cfg.Proxy) + return nil, errors.Newf("Invalid proxy endpoint: %s", cfg.Proxy) + } + } else { + // no SRC_PROXY; check for the standard proxy env variables HTTP_PROXY, HTTPS_PROXY, and NO_PROXY + if u, err := http.ProxyFromEnvironment(&http.Request{URL: cfg.EndpointURL}); err != nil { + // when there's an error, the value for the env variable is not a legit URL + return nil, fmt.Errorf("Invalid HTTP_PROXY or HTTPS_PROXY value: %w", err) + } else { + cfg.ProxyURL = u } } @@ -242,20 +280,9 @@ func readConfig() (*config, error) { return nil, errConfigAuthorizationConflict } - // Lastly, apply endpoint flag if set - if endpoint != nil && *endpoint != "" { - cfg.Endpoint = *endpoint - } - - cfg.Endpoint = cleanEndpoint(cfg.Endpoint) - return &cfg, nil } -func cleanEndpoint(urlStr string) string { - return strings.TrimSuffix(urlStr, "/") -} - // isValidUnixSocket checks if the given path is a valid Unix socket. // // Parameters: diff --git a/cmd/src/main_test.go b/cmd/src/main_test.go index c37c36792a..686500ca01 100644 --- a/cmd/src/main_test.go +++ b/cmd/src/main_test.go @@ -42,7 +42,11 @@ func TestReadConfig(t *testing.T) { { name: "defaults", want: &config{ - Endpoint: "https://sourcegraph.com", + Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, AdditionalHeaders: map[string]string{}, }, }, @@ -54,7 +58,11 @@ func TestReadConfig(t *testing.T) { Proxy: "https://proxy.com:8080", }, want: &config{ - Endpoint: "https://example.com", + Endpoint: "https://example.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, AccessToken: "deadbeef", AdditionalHeaders: map[string]string{}, Proxy: "https://proxy.com:8080", @@ -95,6 +103,10 @@ func TestReadConfig(t *testing.T) { envProxy: "socks5://other.proxy.com:9999", want: &config{ Endpoint: "https://example.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, AccessToken: "deadbeef", Proxy: "socks5://other.proxy.com:9999", ProxyPath: "", @@ -117,6 +129,10 @@ func TestReadConfig(t *testing.T) { envProxy: "socks5://other.proxy.com:9999", want: &config{ Endpoint: "https://override.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, AccessToken: "abc", Proxy: "socks5://other.proxy.com:9999", ProxyPath: "", @@ -132,6 +148,10 @@ func TestReadConfig(t *testing.T) { envToken: "abc", want: &config{ Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, AccessToken: "abc", AdditionalHeaders: map[string]string{}, }, @@ -141,6 +161,10 @@ func TestReadConfig(t *testing.T) { envEndpoint: "https://example.com", want: &config{ Endpoint: "https://example.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, AccessToken: "", AdditionalHeaders: map[string]string{}, }, @@ -150,6 +174,10 @@ func TestReadConfig(t *testing.T) { envProxy: "https://proxy.com:8080", want: &config{ Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, AccessToken: "", Proxy: "https://proxy.com:8080", ProxyPath: "", @@ -167,6 +195,10 @@ func TestReadConfig(t *testing.T) { envProxy: "https://proxy.com:8080", want: &config{ Endpoint: "https://example.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "example.com", + }, AccessToken: "abc", Proxy: "https://proxy.com:8080", ProxyPath: "", @@ -182,6 +214,10 @@ func TestReadConfig(t *testing.T) { envProxy: "unix://" + socketPath, want: &config{ Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, Proxy: "unix://" + socketPath, ProxyPath: socketPath, ProxyURL: nil, @@ -193,6 +229,10 @@ func TestReadConfig(t *testing.T) { envProxy: socketPath, want: &config{ Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, Proxy: socketPath, ProxyPath: socketPath, ProxyURL: nil, @@ -204,6 +244,10 @@ func TestReadConfig(t *testing.T) { envProxy: "socks://localhost:1080", want: &config{ Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, Proxy: "socks://localhost:1080", ProxyPath: "", ProxyURL: &url.URL{ @@ -218,6 +262,10 @@ func TestReadConfig(t *testing.T) { envProxy: "socks5h://localhost:1080", want: &config{ Endpoint: "https://sourcegraph.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "sourcegraph.com", + }, Proxy: "socks5h://localhost:1080", ProxyPath: "", ProxyURL: &url.URL{ @@ -237,6 +285,10 @@ func TestReadConfig(t *testing.T) { }, want: &config{ Endpoint: "https://override.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, AccessToken: "deadbeef", AdditionalHeaders: map[string]string{}, }, @@ -248,6 +300,10 @@ func TestReadConfig(t *testing.T) { envToken: "abc", want: &config{ Endpoint: "https://override.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, AccessToken: "abc", AdditionalHeaders: map[string]string{}, }, @@ -260,6 +316,10 @@ func TestReadConfig(t *testing.T) { envFooHeader: "bar", want: &config{ Endpoint: "https://override.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, AccessToken: "abc", AdditionalHeaders: map[string]string{"foo": "bar"}, }, @@ -272,6 +332,10 @@ func TestReadConfig(t *testing.T) { envHeaders: "foo:bar\nfoo-bar:bar-baz", want: &config{ Endpoint: "https://override.com", + EndpointURL: &url.URL{ + Scheme: "https", + Host: "override.com", + }, AccessToken: "abc", AdditionalHeaders: map[string]string{"foo-bar": "bar-baz", "foo": "bar"}, }, diff --git a/cmd/src/search_jobs.go b/cmd/src/search_jobs.go index 96c5d9070a..5c3ea72de1 100644 --- a/cmd/src/search_jobs.go +++ b/cmd/src/search_jobs.go @@ -156,7 +156,7 @@ func parseColumns(columnsFlag string) []string { // createSearchJobsClient creates a reusable API client for search jobs commands func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client { return api.NewClient(api.ClientOpts{ - Endpoint: cfg.Endpoint, + EndpointURL: cfg.EndpointURL, AccessToken: cfg.AccessToken, Out: out.Output(), Flags: apiFlags, @@ -259,11 +259,11 @@ func init() { usage := `'src search-jobs' is a tool that manages search jobs on a Sourcegraph instance. Usage: - + src search-jobs command [command options] - + The commands are: - + cancel cancels a search job by ID create creates a search job delete deletes a search job by ID @@ -272,11 +272,11 @@ func init() { logs fetches logs for a search job by ID restart restarts a search job by ID results fetches results for a search job by ID - + Common options for all commands: -c Select columns to display (e.g., -c id,query,state,username) -json Output results in JSON format - + Use "src search-jobs [command] -h" for more information about a command. ` diff --git a/cmd/src/search_jobs_logs.go b/cmd/src/search_jobs_logs.go index 2fe9b6bfee..ba90072b51 100644 --- a/cmd/src/search_jobs_logs.go +++ b/cmd/src/search_jobs_logs.go @@ -12,7 +12,7 @@ import ( ) // fetchJobLogs retrieves logs for a search job from its log URL -func fetchJobLogs(jobID string, logURL string) (io.ReadCloser, error) { +func fetchJobLogs(client api.Client, jobID string, logURL string) (io.ReadCloser, error) { if logURL == "" { return nil, fmt.Errorf("no logs URL found for search job %s", jobID) } @@ -24,7 +24,7 @@ func fetchJobLogs(jobID string, logURL string) (io.ReadCloser, error) { req.Header.Add("Authorization", "token "+cfg.AccessToken) - resp, err := http.DefaultClient.Do(req) + resp, err := client.Do(req) if err != nil { return nil, err } @@ -88,7 +88,7 @@ func init() { return fmt.Errorf("no job found with ID %s", jobID) } - logsData, err := fetchJobLogs(jobID, job.LogURL) + logsData, err := fetchJobLogs(client, jobID, job.LogURL) if err != nil { return err } diff --git a/cmd/src/search_jobs_results.go b/cmd/src/search_jobs_results.go index 2e45f8219f..22800c6831 100644 --- a/cmd/src/search_jobs_results.go +++ b/cmd/src/search_jobs_results.go @@ -12,7 +12,7 @@ import ( ) // fetchJobResults retrieves results for a search job from its results URL -func fetchJobResults(jobID string, resultsURL string) (io.ReadCloser, error) { +func fetchJobResults(client api.Client, jobID string, resultsURL string) (io.ReadCloser, error) { if resultsURL == "" { return nil, fmt.Errorf("no results URL found for search job %s", jobID) } @@ -24,7 +24,7 @@ func fetchJobResults(jobID string, resultsURL string) (io.ReadCloser, error) { req.Header.Add("Authorization", "token "+cfg.AccessToken) - resp, err := http.DefaultClient.Do(req) + resp, err := client.Do(req) if err != nil { return nil, err } @@ -90,7 +90,7 @@ func init() { return fmt.Errorf("no job found with ID %s", jobID) } - resultsData, err := fetchJobResults(jobID, job.URL) + resultsData, err := fetchJobResults(client, jobID, job.URL) if err != nil { return err } diff --git a/cmd/src/search_stream_test.go b/cmd/src/search_stream_test.go index 1653b273ac..f815fcbe7e 100644 --- a/cmd/src/search_stream_test.go +++ b/cmd/src/search_stream_test.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -126,8 +127,10 @@ func TestSearchStream(t *testing.T) { s := testServer(t, http.HandlerFunc(mockStreamHandler)) defer s.Close() + u, _ := url.ParseRequestURI(s.URL) cfg = &config{ - Endpoint: s.URL, + Endpoint: s.URL, + EndpointURL: u, } defer func() { cfg = nil }() diff --git a/internal/api/api.go b/internal/api/api.go index 5f750c1d4a..23af64c0b1 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -71,7 +71,7 @@ type request struct { // ClientOpts encapsulates the options given to NewClient. type ClientOpts struct { - Endpoint string + EndpointURL *url.URL AccessToken string AdditionalHeaders map[string]string @@ -98,10 +98,17 @@ func buildTransport(opts ClientOpts, flags *Flags) *http.Transport { transport.TLSClientConfig = &tls.Config{} } - if opts.ProxyURL != nil || opts.ProxyPath != "" { + if opts.ProxyPath != "" || (opts.ProxyURL != nil && opts.ProxyURL.Scheme == "https") { + // Use our custom dialer for: + // - unix socket proxies + // - TLS=enabled proxies, to force HTTP/1.1 for the CONNECT tunnel. + // Many TLS-enabled proxy servers don't support HTTP/2 CONNECT, + // which Go may negotiate via ALPN, resulting in connection errors. transport = withProxyTransport(transport, opts.ProxyURL, opts.ProxyPath) } + // For http:// and socks5:// proxies, the cloned + // transport's default Proxy handles them correctly without intervention. return transport } @@ -124,7 +131,7 @@ func NewClient(opts ClientOpts) Client { return &client{ opts: ClientOpts{ - Endpoint: opts.Endpoint, + EndpointURL: opts.EndpointURL, AccessToken: opts.AccessToken, AdditionalHeaders: opts.AdditionalHeaders, Flags: flags, @@ -159,7 +166,7 @@ func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.R } func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(c.opts.Endpoint, "/")+"/"+p, body) + req, err := http.NewRequestWithContext(ctx, method, strings.TrimRight(c.opts.EndpointURL.String(), "/")+"/"+p, body) if err != nil { return nil, err } @@ -334,6 +341,6 @@ func (r *request) curlCmd() (string, error) { s += fmt.Sprintf(" %s \\\n", shellquote.Join("-H", k+": "+v)) } s += fmt.Sprintf(" %s \\\n", shellquote.Join("-d", string(data))) - s += fmt.Sprintf(" %s", shellquote.Join(r.client.opts.Endpoint+"/.api/graphql")) + s += fmt.Sprintf(" %s", shellquote.Join(r.client.opts.EndpointURL.String()+"/.api/graphql")) return s, nil } diff --git a/internal/api/proxy.go b/internal/api/proxy.go index 9589b9beb5..3f53328f29 100644 --- a/internal/api/proxy.go +++ b/internal/api/proxy.go @@ -6,17 +6,40 @@ import ( "crypto/tls" "encoding/base64" "fmt" + "io" "net" "net/http" "net/url" ) +type connWithBufferedReader struct { + net.Conn + r *bufio.Reader +} + +func (c *connWithBufferedReader) Read(p []byte) (int, error) { + return c.r.Read(p) +} + +// proxyDialAddr returns proxyURL.Host with a default port appended if one is +// not already present (443 for https, 80 for http). +func proxyDialAddr(proxyURL *url.URL) string { + // net.SplitHostPort returns an error when the input doesn't contain a port + if _, _, err := net.SplitHostPort(proxyURL.Host); err == nil { + return proxyURL.Host + } + if proxyURL.Scheme == "https" { + return net.JoinHostPort(proxyURL.Hostname(), "443") + } + return net.JoinHostPort(proxyURL.Hostname(), "80") +} + // withProxyTransport modifies the given transport to handle proxying of unix, socks5 and http connections. // // Note: baseTransport is considered to be a clone created with transport.Clone() // -// - If a the proxyPath is not empty, a unix socket proxy is created. -// - Otherwise, the proxyURL is used to determine if we should proxy socks5 / http connections +// - If proxyPath is not empty, a unix socket proxy is created. +// - Otherwise, proxyURL is used to determine if we should proxy socks5 / http connections func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyPath string) *http.Transport { handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { // Extract the hostname (without the port) for TLS SNI @@ -24,12 +47,17 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP if err != nil { return nil, err } - tlsConn := tls.Client(conn, &tls.Config{ - ServerName: host, - // Pull InsecureSkipVerify from the target host transport - // so that insecure-skip-verify flag settings are honored for the proxy server - InsecureSkipVerify: baseTransport.TLSClientConfig.InsecureSkipVerify, - }) + cfg := baseTransport.TLSClientConfig.Clone() + if cfg.ServerName == "" { + cfg.ServerName = host + } + // Preserve HTTP/2 negotiation to the origin when ForceAttemptHTTP2 + // is enabled. Without this, the manual TLS handshake would not + // advertise h2 via ALPN, silently forcing HTTP/1.1. + if baseTransport.ForceAttemptHTTP2 && len(cfg.NextProtos) == 0 { + cfg.NextProtos = []string{"h2", "http/1.1"} + } + tlsConn := tls.Client(conn, cfg) if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } @@ -59,62 +87,74 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP baseTransport.Proxy = http.ProxyURL(proxyURL) case "http", "https": dial := func(ctx context.Context, network, addr string) (net.Conn, error) { - // Dial the proxy - d := net.Dialer{} - conn, err := d.DialContext(ctx, "tcp", proxyURL.Host) + // Dial the proxy. For https:// proxies, we TLS-connect to the + // proxy itself and force ALPN to HTTP/1.1 to prevent Go from + // negotiating HTTP/2 for the CONNECT tunnel. Many proxy servers + // don't support HTTP/2 CONNECT, and Go's default Transport.Proxy + // would negotiate h2 via ALPN when TLS-connecting to an https:// + // proxy, causing "bogus greeting" errors. For http:// proxies, + // CONNECT is always HTTP/1.1 over plain TCP so this isn't needed. + // The target connection (e.g. to sourcegraph.com) still negotiates + // HTTP/2 normally through the established tunnel. + proxyAddr := proxyDialAddr(proxyURL) + + var conn net.Conn + var err error + if proxyURL.Scheme == "https" { + raw, dialErr := (&net.Dialer{}).DialContext(ctx, "tcp", proxyAddr) + if dialErr != nil { + return nil, dialErr + } + cfg := baseTransport.TLSClientConfig.Clone() + cfg.NextProtos = []string{"http/1.1"} + if cfg.ServerName == "" { + cfg.ServerName = proxyURL.Hostname() + } + tlsConn := tls.Client(raw, cfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + raw.Close() + return nil, err + } + conn = tlsConn + } else { + conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", proxyAddr) + } if err != nil { return nil, err } - // this is the whole point of manually dialing the HTTP(S) proxy: - // being able to force HTTP/1. - // When relying on Transport.Proxy, the protocol is always HTTP/2, - // but many proxy servers don't support HTTP/2. - // We don't want to disable HTTP/2 in general because we want to use it when - // connecting to the Sourcegraph API, using HTTP/1 for the proxy connection only. - protocol := "HTTP/1.1" - - // CONNECT is the HTTP method used to set up a tunneling connection with a proxy - method := "CONNECT" - - // Manually writing out the HTTP commands because it's not complicated, - // and http.Request has some janky behavior: - // - ignores the Proto field and hard-codes the protocol to HTTP/1.1 - // - ignores the Host Header (Header.Set("Host", host)) and uses URL.Host instead. - // - When the Host field is set, overrides the URL field - connectReq := fmt.Sprintf("%s %s %s\r\n", method, addr, protocol) - - // A Host header is required per RFC 2616, section 14.23 - connectReq += fmt.Sprintf("Host: %s\r\n", addr) - - // use authentication if proxy credentials are present + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } if proxyURL.User != nil { password, _ := proxyURL.User.Password() auth := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.Username() + ":" + password)) - connectReq += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", auth) + connectReq.Header.Set("Proxy-Authorization", "Basic "+auth) } - - // finish up with an extra carriage return + newline, as per RFC 7230, section 3 - connectReq += "\r\n" - - // Send the CONNECT request to the proxy to establish the tunnel - if _, err := conn.Write([]byte(connectReq)); err != nil { + if err := connectReq.Write(conn); err != nil { conn.Close() return nil, err } - // Read and check the response from the proxy - resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, nil) if err != nil { conn.Close() return nil, err } if resp.StatusCode != http.StatusOK { + // For non-200, it's safe/appropriate to close the body (it’s a real response body here). + // Try to read a bit (4k bytes) to include in the error message. + b, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + resp.Body.Close() conn.Close() - return nil, fmt.Errorf("failed to connect to proxy %v: %v", proxyURL, resp.Status) + return nil, fmt.Errorf("failed to connect to proxy %s: %s: %q", proxyURL.Redacted(), resp.Status, b) } - resp.Body.Close() - return conn, nil + // 200 CONNECT: do NOT resp.Body.Close(); it would interfere with the tunnel. + return &connWithBufferedReader{Conn: conn, r: br}, nil } dialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { // Dial the underlying connection through the proxy diff --git a/internal/api/proxy_test.go b/internal/api/proxy_test.go new file mode 100644 index 0000000000..d1d9ee6348 --- /dev/null +++ b/internal/api/proxy_test.go @@ -0,0 +1,441 @@ +package api + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// waitForServerReady polls the server until it's ready to accept connections +func waitForServerReady(t *testing.T, addr string, useTLS bool, timeout time.Duration) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + t.Fatalf("server at %s did not become ready within %v", addr, timeout) + default: + } + + var conn net.Conn + var err error + + if useTLS { + conn, err = tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true}) + } else { + conn, err = net.Dial("tcp", addr) + } + + if err == nil { + conn.Close() + return // Server is ready + } + + time.Sleep(1 * time.Millisecond) + } +} + +// startCONNECTProxy starts an HTTP or HTTPS CONNECT proxy on a random port. +// It returns the proxy URL and a channel that receives the protocol observed by +// the proxy handler for each CONNECT request. +func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-chan string) { + t.Helper() + + ch := make(chan string, 10) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case ch <- r.Proto: + default: + } + + if r.Method != http.MethodConnect { + http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed) + return + } + + destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer destConn.Close() + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + clientConn, _, err := hijacker.Hijack() + if err != nil { + return + } + defer clientConn.Close() + + done := make(chan struct{}, 2) + go func() { io.Copy(destConn, clientConn); done <- struct{}{} }() + go func() { io.Copy(clientConn, destConn); done <- struct{}{} }() + <-done + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + + srv := &http.Server{Handler: handler} + + if useTLS { + cert := generateTestCert(t, "127.0.0.1") + srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + go srv.ServeTLS(ln, "", "") + } else { + go srv.Serve(ln) + } + t.Cleanup(func() { srv.Close() }) + + // Wait for the server to be ready + waitForServerReady(t, ln.Addr().String(), useTLS, 5*time.Second) + + scheme := "http" + if useTLS { + scheme = "https" + } + pURL, _ := url.Parse(fmt.Sprintf("%s://%s", scheme, ln.Addr().String())) + return pURL, ch +} + +// startCONNECTProxyWithAuth is like startCONNECTProxy but requires +// Proxy-Authorization with the given username and password. +func startCONNECTProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass string) (proxyURL *url.URL) { + t.Helper() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed) + return + } + + authHeader := r.Header.Get("Proxy-Authorization") + wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(wantUser+":"+wantPass)) + if authHeader != wantAuth { + http.Error(w, "proxy auth required", http.StatusProxyAuthRequired) + return + } + + destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer destConn.Close() + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + clientConn, _, err := hijacker.Hijack() + if err != nil { + return + } + defer clientConn.Close() + + done := make(chan struct{}, 2) + go func() { io.Copy(destConn, clientConn); done <- struct{}{} }() + go func() { io.Copy(clientConn, destConn); done <- struct{}{} }() + <-done + }) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + + srv := &http.Server{Handler: handler} + + if useTLS { + cert := generateTestCert(t, "127.0.0.1") + srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + go srv.ServeTLS(ln, "", "") + } else { + go srv.Serve(ln) + } + t.Cleanup(func() { srv.Close() }) + + // Wait for the server to be ready + waitForServerReady(t, ln.Addr().String(), useTLS, 5*time.Second) + + scheme := "http" + if useTLS { + scheme = "https" + } + pURL, _ := url.Parse(fmt.Sprintf("%s://%s@%s", scheme, url.UserPassword(wantUser, wantPass).String(), ln.Addr().String())) + return pURL +} + +func generateTestCert(t *testing.T, host string) tls.Certificate { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: host}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP(host)}, + } + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: key, + } +} + +// newTestTransport creates a base transport suitable for proxy tests. +func newTestTransport() *http.Transport { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + return transport +} + +// startTargetServer starts an HTTPS server (with HTTP/2 enabled) that +// responds with "ok" to GET /. +func startTargetServer(t *testing.T) *httptest.Server { + t.Helper() + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "ok") + })) + srv.EnableHTTP2 = true + srv.StartTLS() + t.Cleanup(srv.Close) + return srv +} + +func TestWithProxyTransport_HTTPProxy(t *testing.T) { + target := startTargetServer(t) + proxyURL, obsCh := startCONNECTProxy(t, false) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through http proxy: %v", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Errorf("expected body 'ok', got %q", got) + } + + select { + case proto := <-obsCh: + if proto != "HTTP/1.1" { + t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto) + } + case <-time.After(2 * time.Second): + t.Fatal("proxy handler was never invoked") + } +} + +func TestWithProxyTransport_HTTPSProxy(t *testing.T) { + target := startTargetServer(t) + proxyURL, obsCh := startCONNECTProxy(t, true) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through https proxy: %v", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Errorf("expected body 'ok', got %q", got) + } + + select { + case proto := <-obsCh: + if proto != "HTTP/1.1" { + t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto) + } + case <-time.After(2 * time.Second): + t.Fatal("proxy handler was never invoked") + } +} + +func TestWithProxyTransport_ProxyAuth(t *testing.T) { + target := startTargetServer(t) + + t.Run("http proxy with auth", func(t *testing.T) { + proxyURL := startCONNECTProxyWithAuth(t, false, "user", "pass") + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through authenticated http proxy: %v", err) + } + defer resp.Body.Close() + io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("https proxy with auth", func(t *testing.T) { + proxyURL := startCONNECTProxyWithAuth(t, true, "user", "s3cret") + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through authenticated https proxy: %v", err) + } + defer resp.Body.Close() + io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) +} + +func TestWithProxyTransport_HTTPSProxy_HTTP2ToOrigin(t *testing.T) { + // Verify that when tunneling through an HTTPS proxy, the connection to + // the origin target still negotiates HTTP/2 (not downgraded to HTTP/1.1). + target := startTargetServer(t) + proxyURL, _ := startCONNECTProxy(t, true) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through https proxy: %v", err) + } + defer resp.Body.Close() + io.ReadAll(resp.Body) + + if resp.Proto != "HTTP/2.0" { + t.Errorf("expected HTTP/2.0 to origin, got %s", resp.Proto) + } +} + +func TestWithProxyTransport_ProxyRejectsConnect(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantStatus string + }{ + {"407 proxy auth required", http.StatusProxyAuthRequired, "proxy auth required", "407 Proxy Authentication Required"}, + {"403 forbidden", http.StatusForbidden, "access denied by policy", "403 Forbidden"}, + {"502 bad gateway", http.StatusBadGateway, "upstream unreachable", "502 Bad Gateway"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Start a proxy that always rejects CONNECT with the given status. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, tt.body, tt.statusCode) + })} + go srv.Serve(ln) + t.Cleanup(func() { srv.Close() }) + + proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", ln.Addr().String())) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + _, err = client.Get("https://example.com") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantStatus) { + t.Errorf("error should contain status %q, got: %v", tt.wantStatus, err) + } + if !strings.Contains(err.Error(), tt.body) { + t.Errorf("error should contain body %q, got: %v", tt.body, err) + } + }) + } +} + +func TestProxyDialAddr(t *testing.T) { + tests := []struct { + name string + url string + want string + }{ + {"https with port", "https://proxy.example.com:8443", "proxy.example.com:8443"}, + {"https without port", "https://proxy.example.com", "proxy.example.com:443"}, + {"http with port", "http://proxy.example.com:8080", "proxy.example.com:8080"}, + {"http without port", "http://proxy.example.com", "proxy.example.com:80"}, + {"ipv4 with port", "http://192.168.1.100:3128", "192.168.1.100:3128"}, + {"ipv4 without port https", "https://10.0.0.1", "10.0.0.1:443"}, + {"ipv4 without port http", "http://172.16.0.5", "172.16.0.5:80"}, + {"ipv6 with port", "http://[::1]:8080", "[::1]:8080"}, + {"ipv6 without port https", "https://[2001:db8::1]", "[2001:db8::1]:443"}, + {"ipv6 without port http", "http://[fe80::1]", "[fe80::1]:80"}, + {"localhost with port", "http://localhost:9090", "localhost:9090"}, + {"localhost without port https", "https://localhost", "localhost:443"}, + {"localhost without port http", "http://localhost", "localhost:80"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.ParseRequestURI(tt.url) + if err != nil { + t.Fatalf("parse URL: %v", err) + } + got := proxyDialAddr(u) + if got != tt.want { + t.Errorf("proxyHostPort(%s) = %q, want %q", tt.url, got, tt.want) + } + }) + } +} diff --git a/internal/batches/executor/executor_test.go b/internal/batches/executor/executor_test.go index 9fc96d927d..03f25e08a8 100644 --- a/internal/batches/executor/executor_test.go +++ b/internal/batches/executor/executor_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "runtime" @@ -406,7 +407,8 @@ func TestExecutor_Integration(t *testing.T) { // Setup an api.Client that points to this test server var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) // Temp dir for log files and downloaded archives testTempDir := t.TempDir() @@ -827,7 +829,8 @@ func testExecuteTasks(t *testing.T, tasks []*Task, archives ...mock.RepoArchive) t.Cleanup(ts.Close) var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) // Prepare images // diff --git a/internal/batches/repozip/fetcher_test.go b/internal/batches/repozip/fetcher_test.go index f871237e45..56d03d85a6 100644 --- a/internal/batches/repozip/fetcher_test.go +++ b/internal/batches/repozip/fetcher_test.go @@ -5,6 +5,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "os" "path" "path/filepath" @@ -44,7 +45,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -89,7 +91,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -153,7 +156,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -193,7 +197,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client, @@ -262,7 +267,8 @@ func TestArchive_Ensure(t *testing.T) { defer ts.Close() var clientBuffer bytes.Buffer - client := api.NewClient(api.ClientOpts{Endpoint: ts.URL, Out: &clientBuffer}) + u, _ := url.ParseRequestURI(ts.URL) + client := api.NewClient(api.ClientOpts{EndpointURL: u, Out: &clientBuffer}) rf := &archiveRegistry{ client: client,