From 91e10dc3e79836fd1467db21284871baa901f268 Mon Sep 17 00:00:00 2001 From: bussyjd Date: Mon, 25 May 2026 13:02:28 +0400 Subject: [PATCH] fix(model): retry inference probe on transient network errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit obol model setup custom validates a candidate LLM endpoint by POSTing a 1-token chat completion. The probe was one-shot: any client.Do error (DNS flake, TCP reset, momentary route loss) failed the whole validation, surfacing in release-smoke flow-04 step 2 as: ✗ endpoint validation failed: inference probe failed — cannot reach http://silvermesh.v1337.lan:8081/v1/chat/completions: Post ...: cannot reach Reproduced 2026-05-25 — the exact same POST returned HTTP 200 minutes later from the same host. No code bug on either side, just a transient route flake. Add a bounded retry around client.Do (3 attempts, sleeping 250ms and then 1s between them; the backoff table also lists 4s, which is only reached if the attempt count is raised). Retry ONLY on Go-level network errors. Non-2xx HTTP responses are real upstream signals (4xx = config bug, 5xx = upstream broken) and still fail fast — retry won't help. Tests inject a no-op sleep via package-level probeBackoffSleep var. Three new tests cover the retry table, non-2xx no-retry, and invalid response body no-retry. --- internal/model/model.go | 35 ++++++++- internal/model/model_test.go | 146 +++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 3 deletions(-) diff --git a/internal/model/model.go b/internal/model/model.go index 1b8042ea..2ef7a278 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -855,6 +855,10 @@ func AddCustomEndpoint(cfg *config.Config, u *ui.UI, endpoint, modelName, apiKey return nil } +// probeBackoffSleep is the sleep used between ValidateCustomEndpoint inference-probe +// retries. Overridable in tests to keep them fast. +var probeBackoffSleep = time.Sleep + // ValidateCustomEndpoint validates that a custom OpenAI-compatible endpoint works. 
// It runs a 2-step validation: reachability check, then inference probe. // The inference probe is the definitive test — some servers (e.g., mlx-lm) don't @@ -917,9 +921,34 @@ func ValidateCustomEndpoint(endpoint, modelName, apiKey string) error { probeReq.Header.Set("Authorization", authHeader) } - probeResp, err := client.Do(probeReq) - if err != nil { - return fmt.Errorf("inference probe failed — cannot reach %s: %w", completionsURL, err) + // Retry on transient network errors (DNS flake, TCP reset, route loss). + // Only client.Do errors are retried — non-200 HTTP responses are real + // upstream signals (4xx = config bug, 5xx = upstream broken) and fail fast. + const probeMaxAttempts = 3 + probeBackoffs := []time.Duration{ + 250 * time.Millisecond, + 1 * time.Second, + 4 * time.Second, + } + + var probeResp *http.Response + var probeErr error + for attempt := 0; attempt < probeMaxAttempts; attempt++ { + // Bodies are single-use; re-attach the payload for each attempt. + attemptReq := probeReq.Clone(probeReq.Context()) + attemptReq.Body = io.NopCloser(bytes.NewReader(probePayload)) + + probeResp, probeErr = client.Do(attemptReq) + if probeErr == nil { + break + } + if attempt < probeMaxAttempts-1 { + probeBackoffSleep(probeBackoffs[attempt]) + } + } + if probeErr != nil { + return fmt.Errorf("inference probe failed after %d attempts — cannot reach %s: %w", + probeMaxAttempts, completionsURL, probeErr) } defer probeResp.Body.Close() diff --git a/internal/model/model_test.go b/internal/model/model_test.go index 405989ca..4ce440d1 100644 --- a/internal/model/model_test.go +++ b/internal/model/model_test.go @@ -8,7 +8,9 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "testing" + "time" ) func TestBuildModelEntries(t *testing.T) { @@ -550,6 +552,150 @@ func TestValidateCustomEndpoint(t *testing.T) { }) } +// withNoSleep replaces probeBackoffSleep with a no-op while the test runs so +// retry tests don't spend real seconds on backoff. 
+func withNoSleep(t *testing.T) { + t.Helper() + orig := probeBackoffSleep + probeBackoffSleep = func(time.Duration) {} + t.Cleanup(func() { probeBackoffSleep = orig }) +} + +// abortAfterNHandler aborts the connection on the first n POST hits to +// /chat/completions (forcing client.Do to return a Go-level network error), +// then serves a normal 200 with valid choices. Reachability hits +// (GET /models, /health, /) are always served normally. +func abortAfterNHandler(n int) (http.Handler, *int32) { + var posts int32 + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions") { + cur := atomic.AddInt32(&posts, 1) + if int(cur) <= n { + panic(http.ErrAbortHandler) + } + fmt.Fprint(w, `{"choices":[{"message":{"content":"pong"}}]}`) + return + } + // Reachability paths. + if strings.HasSuffix(r.URL.Path, "/models") { + fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`) + return + } + w.WriteHeader(http.StatusOK) + }) + return h, &posts +} + +func TestValidateCustomEndpoint_RetriesOnNetworkError(t *testing.T) { + withNoSleep(t) + + cases := []struct { + name string + errorsBeforeOK int + wantErr bool + wantPostAttempts int32 + }{ + {name: "success first try", errorsBeforeOK: 0, wantErr: false, wantPostAttempts: 1}, + {name: "one transient error, then ok", errorsBeforeOK: 1, wantErr: false, wantPostAttempts: 2}, + {name: "two transient errors, then ok (at limit)", errorsBeforeOK: 2, wantErr: false, wantPostAttempts: 3}, + {name: "three transient errors — budget exhausted", errorsBeforeOK: 3, wantErr: true, wantPostAttempts: 3}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + handler, posts := abortAfterNHandler(tc.errorsBeforeOK) + srv := httptest.NewServer(handler) + defer srv.Close() + + // Silence ErrAbortHandler panics — they're expected. 
+ srv.Config.ErrorLog = nil + + err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "") + if tc.wantErr && err == nil { + t.Fatalf("expected error, got nil") + } + if !tc.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantErr && !strings.Contains(err.Error(), "inference probe failed after 3 attempts") { + t.Errorf("error message should reference attempt count, got: %v", err) + } + if got := atomic.LoadInt32(posts); got != tc.wantPostAttempts { + t.Errorf("POST attempts: got %d, want %d", got, tc.wantPostAttempts) + } + }) + } +} + +func TestValidateCustomEndpoint_NoRetryOnNon2xx(t *testing.T) { + withNoSleep(t) + + statuses := []int{http.StatusUnauthorized, http.StatusNotFound, http.StatusServiceUnavailable} + for _, code := range statuses { + t.Run(fmt.Sprintf("HTTP %d", code), func(t *testing.T) { + var posts int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions") { + atomic.AddInt32(&posts, 1) + w.WriteHeader(code) + fmt.Fprint(w, `{"error":"nope"}`) + return + } + if strings.HasSuffix(r.URL.Path, "/models") { + fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`) + return + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "") + if err == nil { + t.Fatalf("expected error for HTTP %d", code) + } + if !strings.Contains(err.Error(), fmt.Sprintf("returned %d", code)) { + t.Errorf("error should reference returned status %d, got: %v", code, err) + } + if got := atomic.LoadInt32(&posts); got != 1 { + t.Errorf("non-2xx must not retry: got %d POSTs, want 1", got) + } + }) + } +} + +func TestValidateCustomEndpoint_NoRetryOnInvalidResponseBody(t *testing.T) { + withNoSleep(t) + + var posts int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/chat/completions") { + atomic.AddInt32(&posts, 1) + fmt.Fprint(w, `not json {{{`) + return + } + if strings.HasSuffix(r.URL.Path, "/models") { + fmt.Fprint(w, `{"data":[{"id":"test-model"}]}`) + return + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + err := ValidateCustomEndpoint(srv.URL+"/v1", "test-model", "") + if err == nil { + t.Fatal("expected JSON decode error") + } + if !strings.Contains(err.Error(), "invalid response") { + t.Errorf("error should mention 'invalid response', got: %v", err) + } + if got := atomic.LoadInt32(&posts); got != 1 { + t.Errorf("malformed body must not retry: got %d POSTs, want 1", got) + } +} + func TestFormatBytes(t *testing.T) { tests := []struct { input int64