Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ var ErrCallbackTimeout = errors.New("browser authorization timed out")

// callbackResult holds the outcome of the local callback round-trip.
type callbackResult struct {
Storage *TokenStorage
Error string
Desc string
Storage *TokenStorage
Error string
Desc string // Detailed description (for terminal only)
SanitizedMsg string // User-friendly message (for browser only)
}

// startCallbackServer starts a local HTTP server on the given port and waits
Expand All @@ -50,36 +51,48 @@ func startCallbackServer(ctx context.Context, port int, expectedState string,

if oauthErr := q.Get("error"); oauthErr != "" {
desc := q.Get("error_description")
writeCallbackPage(w, false, oauthErr, desc)
sendResult(callbackResult{Error: oauthErr, Desc: desc})
sanitized := sanitizeOAuthError(oauthErr, desc)
writeCallbackPage(w, false, sanitized)
sendResult(callbackResult{Error: oauthErr, Desc: desc, SanitizedMsg: sanitized})
return
}

state := q.Get("state")
if state != expectedState {
writeCallbackPage(w, false, "state_mismatch",
"State parameter does not match. Possible CSRF attack.")
sanitized := "Authorization failed. Possible security issue detected."
writeCallbackPage(w, false, sanitized)
sendResult(callbackResult{
Error: "state_mismatch",
Desc: "state parameter mismatch",
Error: "state_mismatch",
Desc: "State parameter does not match. Possible CSRF attack.",
SanitizedMsg: sanitized,
})
return
}

code := q.Get("code")
if code == "" {
writeCallbackPage(w, false, "missing_code", "No authorization code in callback.")
sendResult(callbackResult{Error: "missing_code", Desc: "code parameter missing"})
sanitized := "Authorization failed. Missing authorization code."
writeCallbackPage(w, false, sanitized)
sendResult(callbackResult{
Error: "missing_code",
Desc: "code parameter missing",
SanitizedMsg: sanitized,
})
return
}

storage, exchangeErr := exchangeFn(r.Context(), code)
if exchangeErr != nil {
writeCallbackPage(w, false, "token_exchange_failed", exchangeErr.Error())
sendResult(callbackResult{Error: "token_exchange_failed", Desc: exchangeErr.Error()})
sanitized := sanitizeTokenExchangeError(exchangeErr)
writeCallbackPage(w, false, sanitized)
sendResult(callbackResult{
Error: "token_exchange_failed",
Desc: exchangeErr.Error(),
SanitizedMsg: sanitized,
})
return
}
writeCallbackPage(w, true, "", "")
writeCallbackPage(w, true, "")
sendResult(callbackResult{Storage: storage})
})

Expand Down Expand Up @@ -121,7 +134,8 @@ func startCallbackServer(ctx context.Context, port int, expectedState string,
}

// writeCallbackPage writes a minimal HTML response to the browser tab.
func writeCallbackPage(w http.ResponseWriter, success bool, errCode, errDesc string) {
// The message parameter should be pre-sanitized for security.
func writeCallbackPage(w http.ResponseWriter, success bool, message string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")

if success {
Expand All @@ -137,10 +151,6 @@ func writeCallbackPage(w http.ResponseWriter, success bool, errCode, errDesc str
return
}

msg := errCode
if errDesc != "" {
msg = errDesc
}
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>Authorization Failed</title></head>
Expand All @@ -149,5 +159,5 @@ func writeCallbackPage(w http.ResponseWriter, success bool, errCode, errDesc str
<p>%s</p>
<p>You can close this tab and check your terminal for details.</p>
</body>
</html>`, html.EscapeString(msg))
</html>`, html.EscapeString(message))
}
66 changes: 57 additions & 9 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,30 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "Authorization Failed") {
t.Errorf("expected failure page for state mismatch, got: %s", string(body))
bodyStr := string(body)

// Verify browser shows sanitized message
if !strings.Contains(bodyStr, "Authorization Failed") {
t.Errorf("expected failure page for state mismatch, got: %s", bodyStr)
}
if !strings.Contains(bodyStr, "security issue") {
t.Errorf("expected sanitized security message in browser, got: %s", bodyStr)
}

// Verify browser does NOT show CSRF attack details
if strings.Contains(bodyStr, "CSRF") {
t.Errorf("browser should not mention CSRF attack details, got: %s", bodyStr)
}

select {
case result := <-ch:
if result.err == nil {
t.Error("expected error for state mismatch, got nil")
}
// Terminal error should contain state_mismatch
if !strings.Contains(result.err.Error(), "state_mismatch") {
t.Errorf("expected terminal error to mention state_mismatch, got: %v", result.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for callback result")
}
Expand All @@ -145,17 +160,32 @@ func TestCallbackServer_OAuthError(t *testing.T) {
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "Authorization Failed") {
t.Errorf("expected failure page for access_denied, got: %s", string(body))
bodyStr := string(body)

// Verify browser shows sanitized message
if !strings.Contains(bodyStr, "Authorization Failed") {
t.Errorf("expected failure page for access_denied, got: %s", bodyStr)
}
if !strings.Contains(bodyStr, "Authorization was denied") {
t.Errorf("expected sanitized message in browser, got: %s", bodyStr)
}

// Verify browser does NOT show detailed description
if strings.Contains(bodyStr, "User denied") {
t.Errorf("browser should not contain detailed error description, got: %s", bodyStr)
}

select {
case result := <-ch:
if result.err == nil {
t.Error("expected error for access_denied, got nil")
}
// Verify terminal error still contains full details
if !strings.Contains(result.err.Error(), "access_denied") {
t.Errorf("expected error to mention access_denied, got: %v", result.err)
t.Errorf("expected terminal error to mention access_denied, got: %v", result.err)
}
if !strings.Contains(result.err.Error(), "User denied") {
t.Errorf("expected terminal error to contain detailed description, got: %v", result.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for callback result")
Expand All @@ -168,7 +198,7 @@ func TestCallbackServer_ExchangeFailure(t *testing.T) {

ch := startCallbackServerAsync(t, context.Background(), port, state,
func(_ context.Context, _ string) (*TokenStorage, error) {
return nil, errors.New("unauthorized_client: unauthorized_client")
return nil, errors.New("unauthorized_client: backend service authentication failed")
})

callbackURL := fmt.Sprintf(
Expand All @@ -182,17 +212,35 @@ func TestCallbackServer_ExchangeFailure(t *testing.T) {
defer resp.Body.Close()

body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "Authorization Failed") {
t.Errorf("expected failure page for exchange error, got: %s", string(body))
bodyStr := string(body)

// Verify browser shows generic message
if !strings.Contains(bodyStr, "Authorization Failed") {
t.Errorf("expected failure page for exchange error, got: %s", bodyStr)
}
if !strings.Contains(bodyStr, "Token exchange failed") {
t.Errorf("expected sanitized message in browser, got: %s", bodyStr)
}

// Verify browser does NOT show backend error details
if strings.Contains(bodyStr, "unauthorized_client") {
t.Errorf("browser should not contain backend error code, got: %s", bodyStr)
}
if strings.Contains(bodyStr, "backend service") {
t.Errorf("browser should not contain backend error details, got: %s", bodyStr)
}

select {
case result := <-ch:
if result.err == nil {
t.Error("expected error for exchange failure, got nil")
}
// Verify terminal error still contains full backend error
if !strings.Contains(result.err.Error(), "unauthorized_client") {
t.Errorf("expected error to mention unauthorized_client, got: %v", result.err)
t.Errorf("expected terminal error to mention unauthorized_client, got: %v", result.err)
}
if !strings.Contains(result.err.Error(), "backend service") {
t.Errorf("expected terminal error to contain full details, got: %v", result.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for callback result")
Expand Down
32 changes: 32 additions & 0 deletions error_sanitizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package main

// sanitizeOAuthError maps standard OAuth error codes to user-friendly messages
// that are safe to display in the browser. This prevents information disclosure
// while maintaining a good user experience.
// The errorDescription parameter is intentionally ignored to prevent leaking details.
func sanitizeOAuthError(errorCode, _ string) string {
switch errorCode {
case "access_denied":
return "Authorization was denied. You may close this window."
case "invalid_request":
return "Invalid request. Please contact support."
case "unauthorized_client":
return "Client is not authorized."
case "server_error":
return "Server error. Please try again later."
case "temporarily_unavailable":
return "Service is temporarily unavailable. Please try again later."
default:
return "Authentication failed. Please check your terminal for details."
}
}

// sanitizeTokenExchangeError sanitizes backend token exchange errors to prevent
// leaking sensitive implementation details such as service names, internal error
// codes, or validation mechanisms.
// The err parameter is intentionally ignored to prevent leaking any details.
func sanitizeTokenExchangeError(_ error) string {
// Always return a generic message to prevent information disclosure.
// The full error is still logged to the terminal for debugging.
return "Token exchange failed. Please try again."
}
137 changes: 137 additions & 0 deletions error_sanitizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package main

import (
"errors"
"strings"
"testing"
)

func TestSanitizeOAuthError(t *testing.T) {
tests := []struct {
name string
errorCode string
errorDescription string
wantContains string
wantNotContains string
}{
{
name: "access_denied",
errorCode: "access_denied",
errorDescription: "User denied the request",
wantContains: "Authorization was denied",
wantNotContains: "User denied",
},
{
name: "invalid_request",
errorCode: "invalid_request",
errorDescription: "Missing required parameter: redirect_uri",
wantContains: "Invalid request",
wantNotContains: "redirect_uri",
},
{
name: "unauthorized_client",
errorCode: "unauthorized_client",
errorDescription: "Client authentication failed",
wantContains: "Client is not authorized",
wantNotContains: "authentication failed",
},
{
name: "server_error",
errorCode: "server_error",
errorDescription: "Internal database connection failed",
wantContains: "Server error",
wantNotContains: "database",
},
{
name: "temporarily_unavailable",
errorCode: "temporarily_unavailable",
errorDescription: "Service overloaded",
wantContains: "temporarily unavailable",
wantNotContains: "overloaded",
},
{
name: "unknown_error",
errorCode: "custom_error_code",
errorDescription: "Some internal error details",
wantContains: "Authentication failed",
wantNotContains: "internal",
},
{
name: "empty_description",
errorCode: "access_denied",
errorDescription: "",
wantContains: "Authorization was denied",
wantNotContains: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sanitizeOAuthError(tt.errorCode, tt.errorDescription)

if tt.wantContains != "" && !strings.Contains(got, tt.wantContains) {
t.Errorf("sanitizeOAuthError() = %q, want to contain %q", got, tt.wantContains)
}

if tt.wantNotContains != "" && strings.Contains(got, tt.wantNotContains) {
t.Errorf(
"sanitizeOAuthError() = %q, should not contain %q",
got,
tt.wantNotContains,
)
}
})
}
}

func TestSanitizeTokenExchangeError(t *testing.T) {
tests := []struct {
name string
err error
wantContains string
wantNotContains []string
}{
{
name: "generic_error",
err: errors.New("unauthorized_client: client authentication failed"),
wantContains: "Token exchange failed",
wantNotContains: []string{"unauthorized_client", "authentication"},
},
{
name: "backend_service_error",
err: errors.New("backend service error: database connection failed"),
wantContains: "Token exchange failed",
wantNotContains: []string{"backend", "database", "service"},
},
{
name: "internal_error",
err: errors.New("internal error: validation failed for user account"),
wantContains: "Token exchange failed",
wantNotContains: []string{"internal", "validation", "account"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sanitizeTokenExchangeError(tt.err)

if !strings.Contains(got, tt.wantContains) {
t.Errorf(
"sanitizeTokenExchangeError() = %q, want to contain %q",
got,
tt.wantContains,
)
}

for _, notWant := range tt.wantNotContains {
if strings.Contains(strings.ToLower(got), strings.ToLower(notWant)) {
t.Errorf(
"sanitizeTokenExchangeError() = %q, should not contain %q",
got,
notWant,
)
}
}
})
}
}
Loading