diff --git a/internal/model/model.go b/internal/model/model.go
index 873d8e6..4614c4c 100644
--- a/internal/model/model.go
+++ b/internal/model/model.go
@@ -886,6 +886,15 @@ func AddCustomEndpointWithOptions(cfg *config.Config, u *ui.UI, endpoint, modelN
 	return nil
 }
 
+// probeBackoffSleep is the sleep used between ValidateCustomEndpoint inference-probe
+// retries. Overridable in tests to keep them fast; it is unsynchronized package
+// state, so tests that replace it must not run in parallel.
+var probeBackoffSleep = time.Sleep
+
+// ValidateCustomEndpoint validates that a custom OpenAI-compatible endpoint works.
+// It runs a 2-step validation: reachability check, then inference probe.
+// The inference probe is the definitive test — some servers (e.g., mlx-lm) don't
+// list the loaded model in /models but accept it for inference.
 func ValidateCustomEndpoint(endpoint, modelName, apiKey string) error {
 	return ValidateCustomEndpointWithOptions(endpoint, modelName, apiKey, CustomEndpointOptions{})
 }
@@ -957,9 +966,43 @@ func ValidateCustomEndpointWithOptions(endpoint, modelName, apiKey string, optio
 		probeReq.Header.Set("Authorization", authHeader)
 	}
 
-	probeResp, err := client.Do(probeReq)
-	if err != nil {
-		return fmt.Errorf("inference probe failed — cannot reach %s: %w", completionsURL, err)
+	// Retry on transient network errors (DNS flake, TCP reset, route loss).
+	// Only client.Do errors are retried — non-200 HTTP responses are real
+	// upstream signals (4xx = config bug, 5xx = upstream broken) and fail fast.
+	//
+	// probeBackoffs[i] is the pause after failed attempt i+1. The final attempt
+	// is not followed by a pause, so there is one fewer delay than attempts.
+	probeBackoffs := []time.Duration{
+		250 * time.Millisecond,
+		1 * time.Second,
+	}
+	probeMaxAttempts := len(probeBackoffs) + 1
+
+	var probeResp *http.Response
+	var probeErr error
+	attempts := 0
+	for attempts < probeMaxAttempts {
+		// Bodies are single-use; re-attach the payload for each attempt.
+		attemptReq := probeReq.Clone(probeReq.Context())
+		attemptReq.Body = io.NopCloser(bytes.NewReader(probePayload))
+
+		probeResp, probeErr = client.Do(attemptReq)
+		attempts++
+		if probeErr == nil {
+			break
+		}
+		// A cancelled or expired request context cannot recover; stop
+		// instead of sleeping and retrying a request that is doomed.
+		if probeReq.Context().Err() != nil {
+			break
+		}
+		if attempts < probeMaxAttempts {
+			probeBackoffSleep(probeBackoffs[attempts-1])
+		}
+	}
+	if probeErr != nil {
+		return fmt.Errorf("inference probe failed after %d attempts — cannot reach %s: %w",
+			attempts, completionsURL, probeErr)
 	}
 	defer probeResp.Body.Close()
 
diff --git a/internal/model/model_test.go b/internal/model/model_test.go
index ef0556e..210ef67 100644
--- a/internal/model/model_test.go
+++ b/internal/model/model_test.go
@@ -8,7 +8,9 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync/atomic"
 	"testing"
+	"time"
 )
 
 func TestBuildModelEntries(t *testing.T) {
@@ -597,6 +599,150 @@ func TestValidateCustomEndpoint(t *testing.T) {
 	})
 }
 
+// withNoSleep replaces probeBackoffSleep with a no-op while the test runs so
+// retry tests don't spend real seconds on backoff.
+func withNoSleep(t *testing.T) {
+	t.Helper()
+	orig := probeBackoffSleep
+	probeBackoffSleep = func(time.Duration) {}
+	t.Cleanup(func() { probeBackoffSleep = orig })
+}
+
+// abortAfterNHandler aborts the connection on the first n POST hits to
+// /chat/completions (forcing client.Do to return a Go-level network error),
+// then serves a normal 200 with valid choices. Reachability hits
+// (GET /models, /health, /) are always served normally.
+func abortAfterNHandler(n int) (http.Handler, *int32) {
+	var posts int32
+	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions") {
+			cur := atomic.AddInt32(&posts, 1)
+			if int(cur) <= n {
+				panic(http.ErrAbortHandler)
+			}
+			fmt.Fprint(w, `{"choices":[{"message":{"content":"pong"}}]}`)
+			return
+		}
+		// Reachability paths.
+		if strings.HasSuffix(r.URL.Path, "/models") {
+			fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	})
+	return h, &posts
+}
+
+func TestValidateCustomEndpoint_RetriesOnNetworkError(t *testing.T) {
+	withNoSleep(t)
+
+	cases := []struct {
+		name             string
+		errorsBeforeOK   int
+		wantErr          bool
+		wantPostAttempts int32
+	}{
+		{name: "success first try", errorsBeforeOK: 0, wantErr: false, wantPostAttempts: 1},
+		{name: "one transient error, then ok", errorsBeforeOK: 1, wantErr: false, wantPostAttempts: 2},
+		{name: "two transient errors, then ok (at limit)", errorsBeforeOK: 2, wantErr: false, wantPostAttempts: 3},
+		{name: "three transient errors — budget exhausted", errorsBeforeOK: 3, wantErr: true, wantPostAttempts: 3},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			handler, posts := abortAfterNHandler(tc.errorsBeforeOK)
+			srv := httptest.NewServer(handler)
+			defer srv.Close()
+
+			// No log silencing needed: net/http does not log a stack trace
+			// for panics whose value is http.ErrAbortHandler.
+
+			err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "")
+			if tc.wantErr && err == nil {
+				t.Fatalf("expected error, got nil")
+			}
+			if !tc.wantErr && err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+			if tc.wantErr && !strings.Contains(err.Error(), "inference probe failed after 3 attempts") {
+				t.Errorf("error message should reference attempt count, got: %v", err)
+			}
+			if got := atomic.LoadInt32(posts); got != tc.wantPostAttempts {
+				t.Errorf("POST attempts: got %d, want %d", got, tc.wantPostAttempts)
+			}
+		})
+	}
+}
+
+func TestValidateCustomEndpoint_NoRetryOnNon2xx(t *testing.T) {
+	withNoSleep(t)
+
+	statuses := []int{http.StatusUnauthorized, http.StatusNotFound, http.StatusServiceUnavailable}
+	for _, code := range statuses {
+		t.Run(fmt.Sprintf("HTTP %d", code), func(t *testing.T) {
+			var posts int32
+			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.Header().Set("Content-Type", "application/json")
+				if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions") {
+					atomic.AddInt32(&posts, 1)
+					w.WriteHeader(code)
+					fmt.Fprint(w, `{"error":"nope"}`)
+					return
+				}
+				if strings.HasSuffix(r.URL.Path, "/models") {
+					fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`)
+					return
+				}
+				w.WriteHeader(http.StatusOK)
+			}))
+			defer srv.Close()
+
+			err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "")
+			if err == nil {
+				t.Fatalf("expected error for HTTP %d", code)
+			}
+			if !strings.Contains(err.Error(), fmt.Sprintf("returned %d", code)) {
+				t.Errorf("error should reference returned status %d, got: %v", code, err)
+			}
+			if got := atomic.LoadInt32(&posts); got != 1 {
+				t.Errorf("non-2xx must not retry: got %d POSTs, want 1", got)
+			}
+		})
+	}
+}
+
+func TestValidateCustomEndpoint_NoRetryOnInvalidResponseBody(t *testing.T) {
+	withNoSleep(t)
+
+	var posts int32
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions") {
+			atomic.AddInt32(&posts, 1)
+			fmt.Fprint(w, `not json {{{`)
+			return
+		}
+		if strings.HasSuffix(r.URL.Path, "/models") {
+			fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "")
+	if err == nil {
+		t.Fatal("expected JSON decode error")
+	}
+	if !strings.Contains(err.Error(), "invalid response") {
+		t.Errorf("error should mention 'invalid response', got: %v", err)
+	}
+	if got := atomic.LoadInt32(&posts); got != 1 {
+		t.Errorf("malformed body must not retry: got %d POSTs, want 1", got)
+	}
+}
+
 func TestFormatBytes(t *testing.T) {
 	tests := []struct {
 		input int64