From d99dbec0f195b3974e5c1455b27f2b7aa6204c2e Mon Sep 17 00:00:00 2001 From: appleboy Date: Fri, 27 Feb 2026 16:59:07 +0800 Subject: [PATCH 1/2] feat: sanitize authentication errors for safer browser messages - Introduce error sanitization for OAuth and token exchange errors to prevent sensitive information disclosure in browser messages - Show user-friendly, sanitized error messages in the browser, while retaining full error details in the terminal output - Add tests verifying that browser messages do not leak backend details and remain generic - Add utility functions for error sanitization and testing their correctness Signed-off-by: appleboy --- callback.go | 50 ++++++----- callback_test.go | 66 ++++++++++++-- error_sanitizer.go | 44 ++++++++++ error_sanitizer_test.go | 186 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 317 insertions(+), 29 deletions(-) create mode 100644 error_sanitizer.go create mode 100644 error_sanitizer_test.go diff --git a/callback.go b/callback.go index 162bad6..a887bd6 100644 --- a/callback.go +++ b/callback.go @@ -23,9 +23,10 @@ var ErrCallbackTimeout = errors.New("browser authorization timed out") // callbackResult holds the outcome of the local callback round-trip. type callbackResult struct { - Storage *TokenStorage - Error string - Desc string + Storage *TokenStorage + Error string + Desc string // Detailed description (for terminal only) + SanitizedMsg string // User-friendly message (for browser only) } // startCallbackServer starts a local HTTP server on the given port and waits @@ -50,36 +51,48 @@ func startCallbackServer(ctx context.Context, port int, expectedState string, if oauthErr := q.Get("error"); oauthErr != "" { desc := q.Get("error_description") - writeCallbackPage(w, false, oauthErr, desc) - sendResult(callbackResult{Error: oauthErr, Desc: desc}) + sanitized := sanitizeOAuthError(oauthErr, desc) + writeCallbackPage(w, false, sanitized) + sendResult(callbackResult{Error: oauthErr, Desc: desc, SanitizedMsg: sanitized}) return } state := q.Get("state") if state != expectedState { - writeCallbackPage(w, false, "state_mismatch", - "State parameter does not match. Possible CSRF attack.") + sanitized := "Authorization failed. Possible security issue detected." + writeCallbackPage(w, false, sanitized) sendResult(callbackResult{ - Error: "state_mismatch", - Desc: "state parameter mismatch", + Error: "state_mismatch", + Desc: "State parameter does not match. Possible CSRF attack.", + SanitizedMsg: sanitized, }) return } code := q.Get("code") if code == "" { - writeCallbackPage(w, false, "missing_code", "No authorization code in callback.") - sendResult(callbackResult{Error: "missing_code", Desc: "code parameter missing"}) + sanitized := "Authorization failed. Missing authorization code." + writeCallbackPage(w, false, sanitized) + sendResult(callbackResult{ + Error: "missing_code", + Desc: "code parameter missing", + SanitizedMsg: sanitized, + }) return } storage, exchangeErr := exchangeFn(r.Context(), code) if exchangeErr != nil { - writeCallbackPage(w, false, "token_exchange_failed", exchangeErr.Error()) - sendResult(callbackResult{Error: "token_exchange_failed", Desc: exchangeErr.Error()}) + sanitized := sanitizeTokenExchangeError(exchangeErr) + writeCallbackPage(w, false, sanitized) + sendResult(callbackResult{ + Error: "token_exchange_failed", + Desc: exchangeErr.Error(), + SanitizedMsg: sanitized, + }) return } - writeCallbackPage(w, true, "", "") + writeCallbackPage(w, true, "") sendResult(callbackResult{Storage: storage}) }) @@ -121,7 +134,8 @@ func startCallbackServer(ctx context.Context, port int, expectedState string, } // writeCallbackPage writes a minimal HTML response to the browser tab. -func writeCallbackPage(w http.ResponseWriter, success bool, errCode, errDesc string) { +// The message parameter should be pre-sanitized for security. +func writeCallbackPage(w http.ResponseWriter, success bool, message string) { w.Header().Set("Content-Type", "text/html; charset=utf-8") if success { @@ -137,10 +151,6 @@ func writeCallbackPage(w http.ResponseWriter, success bool, errCode, errDesc str return } - msg := errCode - if errDesc != "" { - msg = errDesc - } fmt.Fprintf(w, ` Authorization Failed @@ -149,5 +159,5 @@ func writeCallbackPage(w http.ResponseWriter, success bool, errCode, errDesc str

%s

You can close this tab and check your terminal for details.

-`, html.EscapeString(msg)) +`, html.EscapeString(message)) } diff --git a/callback_test.go b/callback_test.go index 755d9f2..0d4d56f 100644 --- a/callback_test.go +++ b/callback_test.go @@ -114,8 +114,19 @@ func TestCallbackServer_StateMismatch(t *testing.T) { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "Authorization Failed") { - t.Errorf("expected failure page for state mismatch, got: %s", string(body)) + bodyStr := string(body) + + // Verify browser shows sanitized message + if !strings.Contains(bodyStr, "Authorization Failed") { + t.Errorf("expected failure page for state mismatch, got: %s", bodyStr) + } + if !strings.Contains(bodyStr, "security issue") { + t.Errorf("expected sanitized security message in browser, got: %s", bodyStr) + } + + // Verify browser does NOT show CSRF attack details + if strings.Contains(bodyStr, "CSRF") { + t.Errorf("browser should not mention CSRF attack details, got: %s", bodyStr) } select { @@ -123,6 +134,10 @@ func TestCallbackServer_StateMismatch(t *testing.T) { if result.err == nil { t.Error("expected error for state mismatch, got nil") } + // Terminal error should contain state_mismatch + if !strings.Contains(result.err.Error(), "state_mismatch") { + t.Errorf("expected terminal error to mention state_mismatch, got: %v", result.err) + } case <-time.After(3 * time.Second): t.Fatal("timed out waiting for callback result") } @@ -145,8 +160,19 @@ func TestCallbackServer_OAuthError(t *testing.T) { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "Authorization Failed") { - t.Errorf("expected failure page for access_denied, got: %s", string(body)) + bodyStr := string(body) + + // Verify browser shows sanitized message + if !strings.Contains(bodyStr, "Authorization Failed") { + t.Errorf("expected failure page for access_denied, got: %s", bodyStr) + } + if !strings.Contains(bodyStr, "Authorization was denied") { + t.Errorf("expected sanitized message in browser, got: %s", bodyStr) + } + + // Verify browser does NOT show detailed description + if strings.Contains(bodyStr, "User denied") { + t.Errorf("browser should not contain detailed error description, got: %s", bodyStr) } select { @@ -154,8 +180,12 @@ func TestCallbackServer_OAuthError(t *testing.T) { if result.err == nil { t.Error("expected error for access_denied, got nil") } + // Verify terminal error still contains full details if !strings.Contains(result.err.Error(), "access_denied") { - t.Errorf("expected error to mention access_denied, got: %v", result.err) + t.Errorf("expected terminal error to mention access_denied, got: %v", result.err) + } + if !strings.Contains(result.err.Error(), "User denied") { + t.Errorf("expected terminal error to contain detailed description, got: %v", result.err) } case <-time.After(3 * time.Second): t.Fatal("timed out waiting for callback result") @@ -168,7 +198,7 @@ func TestCallbackServer_ExchangeFailure(t *testing.T) { ch := startCallbackServerAsync(t, context.Background(), port, state, func(_ context.Context, _ string) (*TokenStorage, error) { - return nil, errors.New("unauthorized_client: unauthorized_client") + return nil, errors.New("unauthorized_client: backend service authentication failed") }) callbackURL := fmt.Sprintf( @@ -182,8 +212,22 @@ func TestCallbackServer_ExchangeFailure(t *testing.T) { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "Authorization Failed") { - t.Errorf("expected failure page for exchange error, got: %s", string(body)) + bodyStr := string(body) + + // Verify browser shows generic message + if !strings.Contains(bodyStr, "Authorization Failed") { + t.Errorf("expected failure page for exchange error, got: %s", bodyStr) + } + if !strings.Contains(bodyStr, "Token exchange failed") { + t.Errorf("expected sanitized message in browser, got: %s", bodyStr) + } + + // Verify browser does NOT show backend error details + if strings.Contains(bodyStr, "unauthorized_client") { + t.Errorf("browser should not contain backend error code, got: %s", bodyStr) + } + if strings.Contains(bodyStr, "backend service") { + t.Errorf("browser should not contain backend error details, got: %s", bodyStr) } select { @@ -191,8 +235,12 @@ func TestCallbackServer_ExchangeFailure(t *testing.T) { if result.err == nil { t.Error("expected error for exchange failure, got nil") } + // Verify terminal error still contains full backend error if !strings.Contains(result.err.Error(), "unauthorized_client") { - t.Errorf("expected error to mention unauthorized_client, got: %v", result.err) + t.Errorf("expected terminal error to mention unauthorized_client, got: %v", result.err) + } + if !strings.Contains(result.err.Error(), "backend service") { + t.Errorf("expected terminal error to contain full details, got: %v", result.err) } case <-time.After(3 * time.Second): t.Fatal("timed out waiting for callback result") diff --git a/error_sanitizer.go b/error_sanitizer.go new file mode 100644 index 0000000..47dde7c --- /dev/null +++ b/error_sanitizer.go @@ -0,0 +1,44 @@ +package main + +import "strings" + +// sanitizeOAuthError maps standard OAuth error codes to user-friendly messages +// that are safe to display in the browser. This prevents information disclosure +// while maintaining a good user experience. +// The errorDescription parameter is intentionally ignored to prevent leaking details. +func sanitizeOAuthError(errorCode, _ string) string { + switch errorCode { + case "access_denied": + return "Authorization was denied. You may close this window." + case "invalid_request": + return "Invalid request. Please contact support." + case "unauthorized_client": + return "Client is not authorized." + case "server_error": + return "Server error. Please try again later." + case "temporarily_unavailable": + return "Service is temporarily unavailable. Please try again later." + default: + return "Authentication failed. Please check your terminal for details." + } +} + +// sanitizeTokenExchangeError sanitizes backend token exchange errors to prevent +// leaking sensitive implementation details such as service names, internal error +// codes, or validation mechanisms. +func sanitizeTokenExchangeError(err error) string { + // Always return a generic message to prevent information disclosure. + // The full error is still logged to the terminal for debugging. + return "Token exchange failed. Please try again." +} + +// containsAny checks if string s contains any of the specified substrings. +func containsAny(s string, substrs []string) bool { + s = strings.ToLower(s) + for _, substr := range substrs { + if strings.Contains(s, strings.ToLower(substr)) { + return true + } + } + return false +} diff --git a/error_sanitizer_test.go b/error_sanitizer_test.go new file mode 100644 index 0000000..3d71e23 --- /dev/null +++ b/error_sanitizer_test.go @@ -0,0 +1,186 @@ +package main + +import ( + "errors" + "strings" + "testing" +) + +func TestSanitizeOAuthError(t *testing.T) { + tests := []struct { + name string + errorCode string + errorDescription string + wantContains string + wantNotContains string + }{ + { + name: "access_denied", + errorCode: "access_denied", + errorDescription: "User denied the request", + wantContains: "Authorization was denied", + wantNotContains: "User denied", + }, + { + name: "invalid_request", + errorCode: "invalid_request", + errorDescription: "Missing required parameter: redirect_uri", + wantContains: "Invalid request", + wantNotContains: "redirect_uri", + }, + { + name: "unauthorized_client", + errorCode: "unauthorized_client", + errorDescription: "Client authentication failed", + wantContains: "Client is not authorized", + wantNotContains: "authentication failed", + }, + { + name: "server_error", + errorCode: "server_error", + errorDescription: "Internal database connection failed", + wantContains: "Server error", + wantNotContains: "database", + }, + { + name: "temporarily_unavailable", + errorCode: "temporarily_unavailable", + errorDescription: "Service overloaded", + wantContains: "temporarily unavailable", + wantNotContains: "overloaded", + }, + { + name: "unknown_error", + errorCode: "custom_error_code", + errorDescription: "Some internal error details", + wantContains: "Authentication failed", + wantNotContains: "internal", + }, + { + name: "empty_description", + errorCode: "access_denied", + errorDescription: "", + wantContains: "Authorization was denied", + wantNotContains: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeOAuthError(tt.errorCode, tt.errorDescription) + + if tt.wantContains != "" && !strings.Contains(got, tt.wantContains) { + t.Errorf("sanitizeOAuthError() = %q, want to contain %q", got, tt.wantContains) + } + + if tt.wantNotContains != "" && strings.Contains(got, tt.wantNotContains) { + t.Errorf( + "sanitizeOAuthError() = %q, should not contain %q", + got, + tt.wantNotContains, + ) + } + }) + } +} + +func TestSanitizeTokenExchangeError(t *testing.T) { + tests := []struct { + name string + err error + wantContains string + wantNotContains []string + }{ + { + name: "generic_error", + err: errors.New("unauthorized_client: client authentication failed"), + wantContains: "Token exchange failed", + wantNotContains: []string{"unauthorized_client", "authentication"}, + }, + { + name: "backend_service_error", + err: errors.New("backend service error: database connection failed"), + wantContains: "Token exchange failed", + wantNotContains: []string{"backend", "database", "service"}, + }, + { + name: "internal_error", + err: errors.New("internal error: validation failed for user account"), + wantContains: "Token exchange failed", + wantNotContains: []string{"internal", "validation", "account"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeTokenExchangeError(tt.err) + + if !strings.Contains(got, tt.wantContains) { + t.Errorf( + "sanitizeTokenExchangeError() = %q, want to contain %q", + got, + tt.wantContains, + ) + } + + for _, notWant := range tt.wantNotContains { + if strings.Contains(strings.ToLower(got), strings.ToLower(notWant)) { + t.Errorf( + "sanitizeTokenExchangeError() = %q, should not contain %q", + got, + notWant, + ) + } + } + }) + } +} + +func TestContainsAny(t *testing.T) { + tests := []struct { + name string + s string + substrs []string + want bool + }{ + { + name: "contains_one", + s: "This is a test string with database keyword", + substrs: []string{"database", "internal", "service"}, + want: true, + }, + { + name: "contains_multiple", + s: "Internal backend service error", + substrs: []string{"database", "internal", "service"}, + want: true, + }, + { + name: "contains_none", + s: "Simple error message", + substrs: []string{"database", "internal", "service"}, + want: false, + }, + { + name: "case_insensitive", + s: "DATABASE connection failed", + substrs: []string{"database", "internal"}, + want: true, + }, + { + name: "empty_substrs", + s: "Any string", + substrs: []string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := containsAny(tt.s, tt.substrs) + if got != tt.want { + t.Errorf("containsAny() = %v, want %v", got, tt.want) + } + }) + } +} From 82c8140b9db418c17fd47a36fe76f5d9b8c562f7 Mon Sep 17 00:00:00 2001 From: appleboy Date: Fri, 27 Feb 2026 17:06:07 +0800 Subject: [PATCH 2/2] refactor: improve error handling and remove unused helpers - Remove the containsAny helper function and its associated tests - Change sanitizeTokenExchangeError to ignore the error parameter entirely for improved security Signed-off-by: appleboy --- error_sanitizer.go | 16 ++------------ error_sanitizer_test.go | 49 ----------------------------------------- 2 files changed, 2 insertions(+), 63 deletions(-) diff --git a/error_sanitizer.go b/error_sanitizer.go index 47dde7c..1384de4 100644 --- a/error_sanitizer.go +++ b/error_sanitizer.go @@ -1,7 +1,5 @@ package main -import "strings" - // sanitizeOAuthError maps standard OAuth error codes to user-friendly messages // that are safe to display in the browser. This prevents information disclosure // while maintaining a good user experience. @@ -26,19 +24,9 @@ func sanitizeOAuthError(errorCode, _ string) string { // sanitizeTokenExchangeError sanitizes backend token exchange errors to prevent // leaking sensitive implementation details such as service names, internal error // codes, or validation mechanisms. -func sanitizeTokenExchangeError(err error) string { +// The err parameter is intentionally ignored to prevent leaking any details. +func sanitizeTokenExchangeError(_ error) string { // Always return a generic message to prevent information disclosure. // The full error is still logged to the terminal for debugging. return "Token exchange failed. Please try again." } - -// containsAny checks if string s contains any of the specified substrings. -func containsAny(s string, substrs []string) bool { - s = strings.ToLower(s) - for _, substr := range substrs { - if strings.Contains(s, strings.ToLower(substr)) { - return true - } - } - return false -} diff --git a/error_sanitizer_test.go b/error_sanitizer_test.go index 3d71e23..7317666 100644 --- a/error_sanitizer_test.go +++ b/error_sanitizer_test.go @@ -135,52 +135,3 @@ func TestSanitizeTokenExchangeError(t *testing.T) { }) } } - -func TestContainsAny(t *testing.T) { - tests := []struct { - name string - s string - substrs []string - want bool - }{ - { - name: "contains_one", - s: "This is a test string with database keyword", - substrs: []string{"database", "internal", "service"}, - want: true, - }, - { - name: "contains_multiple", - s: "Internal backend service error", - substrs: []string{"database", "internal", "service"}, - want: true, - }, - { - name: "contains_none", - s: "Simple error message", - substrs: []string{"database", "internal", "service"}, - want: false, - }, - { - name: "case_insensitive", - s: "DATABASE connection failed", - substrs: []string{"database", "internal"}, - want: true, - }, - { - name: "empty_substrs", - s: "Any string", - substrs: []string{}, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := containsAny(tt.s, tt.substrs) - if got != tt.want { - t.Errorf("containsAny() = %v, want %v", got, tt.want) - } - }) - } -}