From fe2e1e7ca946fd76c56d8a05e55a46c8851276eb Mon Sep 17 00:00:00 2001 From: Tanmay Sardesai Date: Wed, 11 Feb 2026 10:16:07 -0800 Subject: [PATCH 1/2] fix: scale-to-zero and shutdown improvements - middleware should skip disabling stz for internal requests - add logging for stz file writes - add support for stz.Drain() which is meant to be called upon shutdown - shutdown now kills all running processes and calls stz.Drain() --- server/cmd/api/api/api.go | 46 +++++++- server/cmd/api/main.go | 3 + server/go.mod | 2 + server/go.sum | 4 + server/lib/scaletozero/middleware.go | 22 ++++ server/lib/scaletozero/middleware_test.go | 114 ++++++++++++++++++++ server/lib/scaletozero/scaletozero.go | 76 ++++++++++++- server/lib/scaletozero/scaletozero_test.go | 118 +++++++++++++++++++++ 8 files changed, 382 insertions(+), 3 deletions(-) create mode 100644 server/lib/scaletozero/middleware_test.go diff --git a/server/cmd/api/api/api.go b/server/cmd/api/api/api.go index 910410b9..78cf5aba 100644 --- a/server/cmd/api/api/api.go +++ b/server/cmd/api/api/api.go @@ -6,9 +6,11 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/onkernel/kernel-images/server/lib/devtoolsproxy" "github.com/onkernel/kernel-images/server/lib/logger" "github.com/onkernel/kernel-images/server/lib/nekoclient" @@ -297,6 +299,48 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ return oapi.ListRecorders200JSONResponse(infos), nil } +// killAllProcesses sends SIGKILL to every tracked process that is still running. +func (s *ApiService) killAllProcesses(ctx context.Context) error { + log := logger.FromContext(ctx) + s.procMu.RLock() + defer s.procMu.RUnlock() + + var result *multierror.Error + for id, h := range s.procs { + if h.state() != "running" { + continue + } + if h.cmd.Process == nil { + continue + } + // supervisorctl handles the lifecycle of long running processes so we don't want to kill + // any active supervisorctl processes. For example it is used to restart kernel-images-api + // and killing that process would break the restart process. 
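+	// Note that the exemption below matches on the executable's base name only; a
+	// supervisorctl call wrapped in a shell (e.g. "sh -c 'supervisorctl ...'") would
+	// not be exempted, because h.cmd.Path would point at the shell instead.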
+ if filepath.Base(h.cmd.Path) == "supervisorctl" { + continue + } + if err := h.cmd.Process.Kill(); err != nil { + result = multierror.Append(result, fmt.Errorf("process %s: %w", id, err)) + log.Error("failed to kill process", "process_id", id, "err", err) + } + } + return result.ErrorOrNil() +} + func (s *ApiService) Shutdown(ctx context.Context) error { - return s.recordManager.StopAll(ctx) + var wg sync.WaitGroup + var killErr, stopErr error + + wg.Add(2) + go func() { + defer wg.Done() + killErr = s.killAllProcesses(ctx) + }() + go func() { + defer wg.Done() + stopErr = s.recordManager.StopAll(ctx) + }() + wg.Wait() + + return multierror.Append(killErr, stopErr).ErrorOrNil() } diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index 7e47a0b6..45e14362 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -274,6 +274,9 @@ func main() { defer shutdownCancel() g, _ := errgroup.WithContext(shutdownCtx) + g.Go(func() error { + return stz.Drain(shutdownCtx) + }) g.Go(func() error { return srv.Shutdown(shutdownCtx) }) diff --git a/server/go.mod b/server/go.mod index f0270958..e5764cb6 100644 --- a/server/go.mod +++ b/server/go.mod @@ -14,6 +14,7 @@ require ( github.com/glebarez/sqlite v1.11.0 github.com/go-chi/chi/v5 v5.2.1 github.com/google/uuid v1.6.0 + github.com/hashicorp/go-multierror v1.1.1 github.com/kelseyhightower/envconfig v1.4.0 github.com/klauspost/compress v1.18.3 github.com/m1k1o/neko/server v0.0.0-20251008185748-46e2fc7d3866 @@ -51,6 +52,7 @@ require ( github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect diff --git a/server/go.sum b/server/go.sum index 559fc057..ff71a071 100644 --- a/server/go.sum +++ b/server/go.sum @@ -82,6 +82,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= diff --git a/server/lib/scaletozero/middleware.go b/server/lib/scaletozero/middleware.go index f67c06e6..181e43d0 100644 --- a/server/lib/scaletozero/middleware.go +++ b/server/lib/scaletozero/middleware.go @@ -2,6 +2,7 @@ package scaletozero import ( "context" + "net" "net/http" "github.com/onkernel/kernel-images/server/lib/logger" @@ -9,9 +10,16 @@ import ( // Middleware returns a standard net/http middleware that disables scale-to-zero // at the start of each request and re-enables it after the handler completes. 
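+// A typical wiring (illustrative, not part of this change) is
+// r.Use(scaletozero.Middleware(ctrl)) on the server's chi router.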
+// Connections from loopback addresses are ignored and do not affect the +// scale-to-zero state. func Middleware(ctrl Controller) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isLoopbackAddr(r.RemoteAddr) { + next.ServeHTTP(w, r) + return + } + if err := ctrl.Disable(r.Context()); err != nil { logger.FromContext(r.Context()).Error("failed to disable scale-to-zero", "error", err) http.Error(w, "failed to disable scale-to-zero", http.StatusInternalServerError) @@ -23,3 +31,17 @@ func Middleware(ctrl Controller) func(http.Handler) http.Handler { }) } } + +// isLoopbackAddr reports whether addr is a loopback address. +// addr may be an "ip:port" pair or a bare IP. +func isLoopbackAddr(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return ip.IsLoopback() +} diff --git a/server/lib/scaletozero/middleware_test.go b/server/lib/scaletozero/middleware_test.go new file mode 100644 index 00000000..c48b6122 --- /dev/null +++ b/server/lib/scaletozero/middleware_test.go @@ -0,0 +1,114 @@ +package scaletozero + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.50:12345" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 1, mock.disableCalls) + assert.Equal(t, 1, mock.enableCalls) +} + +func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) { + t.Parallel() + + loopbackAddrs := []struct { + name string + addr string + }{ + {"loopback-v4", "127.0.0.1:8080"}, + {"loopback-v6", "[::1]:8080"}, + } + + for _, tc := range loopbackAddrs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + var called bool + handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = tc.addr + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.True(t, called, "handler should still be called") + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 0, mock.disableCalls, "should not disable for loopback addr") + assert.Equal(t, 0, mock.enableCalls, "should not enable for loopback addr") + }) + } +} + +func TestMiddlewareDisableError(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{disableErr: assert.AnError} + var called bool + handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.50:12345" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.False(t, called, "handler should not be called on disable error") + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, 0, mock.enableCalls) +} + +func TestIsLoopbackAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + addr string + loopback bool + }{ + 
// Loopback + {"127.0.0.1:80", true}, + {"[::1]:80", true}, + {"127.0.0.1", true}, + {"::1", true}, + // Non-loopback + {"10.0.0.1:80", false}, + {"172.16.0.1:80", false}, + {"192.168.1.1:80", false}, + {"203.0.113.50:80", false}, + {"8.8.8.8:53", false}, + {"[2001:db8::1]:80", false}, + // Unparseable + {"not-an-ip:80", false}, + {"", false}, + } + + for _, tc := range tests { + t.Run(tc.addr, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.loopback, isLoopbackAddr(tc.addr)) + }) + } +} diff --git a/server/lib/scaletozero/scaletozero.go b/server/lib/scaletozero/scaletozero.go index 96b67281..446d8385 100644 --- a/server/lib/scaletozero/scaletozero.go +++ b/server/lib/scaletozero/scaletozero.go @@ -17,27 +17,62 @@ type Controller interface { Disable(ctx context.Context) error // Enable re-enables scale-to-zero after it has previously been disabled. Enable(ctx context.Context) error + // Drain resets the active count to zero and re-enables scale-to-zero. + // After Drain is called the controller is frozen: subsequent Disable and + // Enable calls become no-ops. The frozen state is in-memory only and is + // cleared on process restart. This is intended for graceful shutdown / + // restart scenarios where we want to guarantee scale-to-zero stays enabled. + Drain(ctx context.Context) error } type unikraftCloudController struct { - path string + path string + mu sync.Mutex + drained bool } func NewUnikraftCloudController() Controller { - return &unikraftCloudController{path: unikraftScaleToZeroFile} + return &unikraftCloudController{path: unikraftScaleToZeroFile, drained: false} } func (c *unikraftCloudController) Disable(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.drained { + return nil + } return c.write(ctx, "+") } func (c *unikraftCloudController) Enable(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.drained { + return nil + } return c.write(ctx, "-") } +func (c *unikraftCloudController) Drain(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.drained { + err := c.write(ctx, "=0") + if err != nil { + return err + } + } + c.drained = true + return nil +} + func (c *unikraftCloudController) write(ctx context.Context, char string) error { if _, err := os.Stat(c.path); err != nil { if os.IsNotExist(err) { + logger.FromContext(ctx).Info("scale-to-zero control file not found, skipping write", "path", c.path, "value", char) return nil } logger.FromContext(ctx).Error("failed to stat scale-to-zero control file", "path", c.path, "err", err) @@ -54,6 +89,7 @@ func (c *unikraftCloudController) write(ctx context.Context, char string) error logger.FromContext(ctx).Error("failed to write scale-to-zero control file", "path", c.path, "err", err) return err } + logger.FromContext(ctx).Info("scale-to-zero control file written", "path", c.path, "value", char) return nil } @@ -63,14 +99,17 @@ func NewNoopController() *NoopController { return &NoopController{} } func (NoopController) Disable(context.Context) error { return nil } func (NoopController) Enable(context.Context) error { return nil } +func (NoopController) Drain(context.Context) error { return nil } // Oncer wraps a Controller and ensures that Disable and Enable are called at most once. 
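+// Illustrative use (not part of this change): o := NewOncer(ctrl); repeated calls to
+// o.Disable or o.Enable return the first call's result without reaching the wrapped
+// controller again.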
type Oncer struct { ctrl Controller disableOnce sync.Once enableOnce sync.Once + drainedOnce sync.Once disableErr error enableErr error + drainedErr error } func NewOncer(c Controller) *Oncer { return &Oncer{ctrl: c} } @@ -85,10 +124,16 @@ func (o *Oncer) Enable(ctx context.Context) error { return o.enableErr } +func (o *Oncer) Drain(ctx context.Context) error { + o.drainedOnce.Do(func() { o.drainedErr = o.ctrl.Drain(ctx) }) + return o.drainedErr +} + type DebouncedController struct { ctrl Controller mu sync.Mutex disabled bool + drained bool activeCount int } @@ -100,6 +145,10 @@ func (c *DebouncedController) Disable(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() + if c.drained { + return nil + } + c.activeCount++ if c.disabled { return nil @@ -118,6 +167,10 @@ func (c *DebouncedController) Enable(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() + if c.drained { + return nil + } + if c.activeCount > 0 { c.activeCount-- } @@ -134,3 +187,22 @@ func (c *DebouncedController) Enable(ctx context.Context) error { c.disabled = false return nil } + +func (c *DebouncedController) Drain(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + c.drained = true + c.activeCount = 0 + + if !c.disabled { + return nil + } + + if err := c.ctrl.Drain(ctx); err != nil { + return err + } + + c.disabled = false + return nil +} diff --git a/server/lib/scaletozero/scaletozero_test.go b/server/lib/scaletozero/scaletozero_test.go index 368b5c9b..c1b532ff 100644 --- a/server/lib/scaletozero/scaletozero_test.go +++ b/server/lib/scaletozero/scaletozero_test.go @@ -110,8 +110,10 @@ type mockScaleToZeroer struct { mu sync.Mutex disableCalls int enableCalls int + drainCalls int disableErr error enableErr error + drainErr error } func (m *mockScaleToZeroer) Disable(ctx context.Context) error { @@ -127,6 +129,71 @@ func (m *mockScaleToZeroer) Enable(ctx context.Context) error { m.enableCalls++ return m.enableErr } + +func (m *mockScaleToZeroer) Drain(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + m.drainCalls++ + return m.drainErr +} +func TestDebouncedControllerDrainResetsCount(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + c := NewDebouncedController(mock) + + // Build up multiple holders. + require.NoError(t, c.Disable(t.Context())) + require.NoError(t, c.Disable(t.Context())) + require.NoError(t, c.Disable(t.Context())) + assert.Equal(t, 1, mock.disableCalls) // debounced + + // Drain should call Drain on the underlying controller regardless of count. + require.NoError(t, c.Drain(t.Context())) + assert.Equal(t, 1, mock.drainCalls) + assert.Equal(t, 0, mock.enableCalls) +} + +func TestDebouncedControllerDrainFreezesController(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + c := NewDebouncedController(mock) + + require.NoError(t, c.Disable(t.Context())) + require.NoError(t, c.Drain(t.Context())) + assert.Equal(t, 1, mock.drainCalls) + + // After drain, Disable and Enable should be no-ops. + require.NoError(t, c.Disable(t.Context())) + require.NoError(t, c.Enable(t.Context())) + assert.Equal(t, 1, mock.disableCalls, "Disable should not reach underlying controller after Drain") + assert.Equal(t, 0, mock.enableCalls, "Enable should not reach underlying controller after Drain") +} + +func TestDebouncedControllerDrainWhenAlreadyEnabled(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + c := NewDebouncedController(mock) + + // Drain without any prior Disable — already enabled, so no Enable write needed. 
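+	// In this case the underlying Drain write is skipped as well, since there is
+	// nothing to reset; only the wrapper is marked drained.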
+ require.NoError(t, c.Drain(t.Context())) + assert.Equal(t, 0, mock.enableCalls) + + // Controller should still be frozen. + require.NoError(t, c.Disable(t.Context())) + assert.Equal(t, 0, mock.disableCalls) +} + +func TestDebouncedControllerDrainError(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{drainErr: assert.AnError} + c := NewDebouncedController(mock) + + require.NoError(t, c.Disable(t.Context())) + err := c.Drain(t.Context()) + require.Error(t, err) + assert.Equal(t, 1, mock.drainCalls) +} + func TestUnikraftCloudControllerNoFileNoError(t *testing.T) { t.Parallel() p := filepath.Join(t.TempDir(), "scale_to_zero_disable") @@ -157,6 +224,57 @@ func TestUnikraftCloudControllerWritesPlusAndMinus(t *testing.T) { assert.Equal(t, []byte("-"), b) } +func TestUnikraftCloudControllerDrainWritesEqualsZero(t *testing.T) { + t.Parallel() + dir := t.TempDir() + p := filepath.Join(dir, "scale_to_zero_disable") + require.NoError(t, os.WriteFile(p, []byte{}, 0o600)) + c := &unikraftCloudController{path: p} + + require.NoError(t, c.Disable(t.Context())) + require.NoError(t, c.Drain(t.Context())) + + b, err := os.ReadFile(p) + require.NoError(t, err) + assert.Equal(t, []byte("=0"), b) +} + +func TestUnikraftCloudControllerDrainFreezesDisableEnable(t *testing.T) { + t.Parallel() + dir := t.TempDir() + p := filepath.Join(dir, "scale_to_zero_disable") + require.NoError(t, os.WriteFile(p, []byte{}, 0o600)) + c := &unikraftCloudController{path: p} + + require.NoError(t, c.Drain(t.Context())) + + // Disable and Enable should be no-ops after drain. + require.NoError(t, c.Disable(t.Context())) + b, err := os.ReadFile(p) + require.NoError(t, err) + assert.Equal(t, []byte("=0"), b, "Disable should not overwrite after Drain") + + require.NoError(t, c.Enable(t.Context())) + b, err = os.ReadFile(p) + require.NoError(t, err) + assert.Equal(t, []byte("=0"), b, "Enable should not overwrite after Drain") +} + +func TestUnikraftCloudControllerDrainIdempotent(t *testing.T) { + t.Parallel() + dir := t.TempDir() + p := filepath.Join(dir, "scale_to_zero_disable") + require.NoError(t, os.WriteFile(p, []byte{}, 0o600)) + c := &unikraftCloudController{path: p} + + require.NoError(t, c.Drain(t.Context())) + require.NoError(t, c.Drain(t.Context())) + + b, err := os.ReadFile(p) + require.NoError(t, err) + assert.Equal(t, []byte("=0"), b) +} + func TestUnikraftCloudControllerTruncatesExistingContent(t *testing.T) { t.Parallel() dir := t.TempDir() From 03e7c296698b8639b5ed7cba55ed634656e12f7f Mon Sep 17 00:00:00 2001 From: Tanmay Sardesai Date: Wed, 11 Feb 2026 13:23:29 -0800 Subject: [PATCH 2/2] address bugbot comments --- server/cmd/api/api/api.go | 20 ++++++--- server/cmd/api/api/process.go | 9 +++- server/lib/scaletozero/scaletozero.go | 41 +++++++++++++----- server/lib/scaletozero/scaletozero_test.go | 48 ++++++++++++++++++++++ 4 files changed, 99 insertions(+), 19 deletions(-) diff --git a/server/cmd/api/api/api.go b/server/cmd/api/api/api.go index 78cf5aba..4c3d3663 100644 --- a/server/cmd/api/api/api.go +++ b/server/cmd/api/api/api.go @@ -31,8 +31,9 @@ type ApiService struct { watches map[string]*fsWatch // Process management - procMu sync.RWMutex - procs map[string]*processHandle + procMu sync.RWMutex + procs map[string]*processHandle + shuttingDown bool // Neko authenticated client nekoAuthClient *nekoclient.AuthClient @@ -300,10 +301,13 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ } // killAllProcesses sends SIGKILL to every tracked process that is still running. 
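+// Tracked supervisorctl invocations are exempt (see the base-name check in the loop body).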
+// It acquires the write lock and sets shuttingDown so that ProcessSpawn rejects +// new processes once the kill pass begins. func (s *ApiService) killAllProcesses(ctx context.Context) error { log := logger.FromContext(ctx) - s.procMu.RLock() - defer s.procMu.RUnlock() + s.procMu.Lock() + defer s.procMu.Unlock() + s.shuttingDown = true var result *multierror.Error for id, h := range s.procs { @@ -320,8 +324,12 @@ func (s *ApiService) killAllProcesses(ctx context.Context) error { continue } if err := h.cmd.Process.Kill(); err != nil { - result = multierror.Append(result, fmt.Errorf("process %s: %w", id, err)) - log.Error("failed to kill process", "process_id", id, "err", err) + // A process may already have exited between the state check and the + // kill call; treat that as a benign race rather than a fatal error. + if !errors.Is(err, os.ErrProcessDone) { + result = multierror.Append(result, fmt.Errorf("process %s: %w", id, err)) + log.Error("failed to kill process", "process_id", id, "err", err) + } } } return result.ErrorOrNil() diff --git a/server/cmd/api/api/process.go b/server/cmd/api/api/process.go index 3dd6aeb4..5d5045db 100644 --- a/server/cmd/api/api/process.go +++ b/server/cmd/api/api/process.go @@ -323,8 +323,14 @@ func (s *ApiService) ProcessSpawn(ctx context.Context, request oapi.ProcessSpawn doneCh: make(chan struct{}), } - // Store handle + // Store handle; reject if the server is shutting down. s.procMu.Lock() + if s.shuttingDown { + s.procMu.Unlock() + // The process was already started; kill it immediately. + _ = cmd.Process.Kill() + return oapi.ProcessSpawn500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "server is shutting down"}}, nil + } if s.procs == nil { s.procs = make(map[string]*processHandle) } @@ -624,7 +630,6 @@ func (s *ApiService) ProcessResize(ctx context.Context, request oapi.ProcessResi return oapi.ProcessResize200JSONResponse(oapi.OkResponse{Ok: true}), nil } - // writeJSON writes a JSON response with the given status code. // Unlike http.Error, this sets the correct Content-Type for JSON. func writeJSON(w http.ResponseWriter, status int, body string) { diff --git a/server/lib/scaletozero/scaletozero.go b/server/lib/scaletozero/scaletozero.go index 446d8385..ba3b0fcb 100644 --- a/server/lib/scaletozero/scaletozero.go +++ b/server/lib/scaletozero/scaletozero.go @@ -102,8 +102,11 @@ func (NoopController) Enable(context.Context) error { return nil } func (NoopController) Drain(context.Context) error { return nil } // Oncer wraps a Controller and ensures that Disable and Enable are called at most once. +// After a successful Drain, Disable and Enable become no-ops per the Controller contract. 
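+// A failed Drain does not freeze the wrapper, but because it is guarded by a
+// sync.Once it is not retried; later Drain calls return the original error.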
type Oncer struct { ctrl Controller + mu sync.Mutex + drained bool disableOnce sync.Once enableOnce sync.Once drainedOnce sync.Once @@ -115,17 +118,36 @@ type Oncer struct { func NewOncer(c Controller) *Oncer { return &Oncer{ctrl: c} } func (o *Oncer) Disable(ctx context.Context) error { + o.mu.Lock() + if o.drained { + o.mu.Unlock() + return nil + } + o.mu.Unlock() o.disableOnce.Do(func() { o.disableErr = o.ctrl.Disable(ctx) }) return o.disableErr } func (o *Oncer) Enable(ctx context.Context) error { + o.mu.Lock() + if o.drained { + o.mu.Unlock() + return nil + } + o.mu.Unlock() o.enableOnce.Do(func() { o.enableErr = o.ctrl.Enable(ctx) }) return o.enableErr } func (o *Oncer) Drain(ctx context.Context) error { - o.drainedOnce.Do(func() { o.drainedErr = o.ctrl.Drain(ctx) }) + o.drainedOnce.Do(func() { + o.drainedErr = o.ctrl.Drain(ctx) + if o.drainedErr == nil { + o.mu.Lock() + o.drained = true + o.mu.Unlock() + } + }) return o.drainedErr } @@ -192,17 +214,14 @@ func (c *DebouncedController) Drain(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() - c.drained = true - c.activeCount = 0 - - if !c.disabled { - return nil - } - - if err := c.ctrl.Drain(ctx); err != nil { - return err + if c.disabled { + if err := c.ctrl.Drain(ctx); err != nil { + return err + } + c.disabled = false } - c.disabled = false + c.drained = true + c.activeCount = 0 return nil } diff --git a/server/lib/scaletozero/scaletozero_test.go b/server/lib/scaletozero/scaletozero_test.go index c1b532ff..66aeddf4 100644 --- a/server/lib/scaletozero/scaletozero_test.go +++ b/server/lib/scaletozero/scaletozero_test.go @@ -192,6 +192,24 @@ func TestDebouncedControllerDrainError(t *testing.T) { err := c.Drain(t.Context()) require.Error(t, err) assert.Equal(t, 1, mock.drainCalls) + + // Controller must NOT be frozen after a failed drain. + // The prior Disable is still active (activeCount=1, disabled=true). + // Enable should reach the underlying controller once activeCount drops to 0. + require.NoError(t, c.Enable(t.Context())) + assert.Equal(t, 1, mock.enableCalls, "Enable should work after failed Drain") + + // After re-enabling, a new Disable should reach the underlying controller. + require.NoError(t, c.Disable(t.Context())) + assert.Equal(t, 2, mock.disableCalls, "Disable should work after failed Drain") + + // A subsequent successful drain should freeze the controller. + mock.drainErr = nil + require.NoError(t, c.Drain(t.Context())) + assert.Equal(t, 2, mock.drainCalls) + + require.NoError(t, c.Disable(t.Context())) + assert.Equal(t, 2, mock.disableCalls, "Disable should be a no-op after successful Drain") } func TestUnikraftCloudControllerNoFileNoError(t *testing.T) { @@ -275,6 +293,36 @@ func TestUnikraftCloudControllerDrainIdempotent(t *testing.T) { assert.Equal(t, []byte("=0"), b) } +func TestOncerDrainFreezesDisableEnable(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + o := NewOncer(mock) + + require.NoError(t, o.Drain(t.Context())) + assert.Equal(t, 1, mock.drainCalls) + + // After a successful Drain, Disable and Enable must be no-ops. 
+ require.NoError(t, o.Disable(t.Context())) + assert.Equal(t, 0, mock.disableCalls, "Disable should be a no-op after Drain") + + require.NoError(t, o.Enable(t.Context())) + assert.Equal(t, 0, mock.enableCalls, "Enable should be a no-op after Drain") +} + +func TestOncerDrainErrorDoesNotFreeze(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{drainErr: assert.AnError} + o := NewOncer(mock) + + err := o.Drain(t.Context()) + require.Error(t, err) + assert.Equal(t, 1, mock.drainCalls) + + // Failed Drain should not freeze the controller; Disable should still work. + require.NoError(t, o.Disable(t.Context())) + assert.Equal(t, 1, mock.disableCalls, "Disable should work after failed Drain") +} + func TestUnikraftCloudControllerTruncatesExistingContent(t *testing.T) { t.Parallel() dir := t.TempDir()
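
Below is a minimal, illustrative sketch of the shutdown path these two patches assume,
pulled together in one place for review. It is not part of the series; "gracefulShutdown",
the package name, the 10-second timeout, and the parameter names are placeholders, while
the errgroup wiring mirrors the main.go hunk above and ApiService.Shutdown / Controller.Drain
come from the patches themselves.

package shutdown

import (
	"context"
	"time"

	"golang.org/x/sync/errgroup"
)

// gracefulShutdown fans out the three shutdown steps introduced or touched by this
// series: drain scale-to-zero, stop the HTTP server, and stop processes/recorders.
func gracefulShutdown(
	ctx context.Context,
	srv interface{ Shutdown(context.Context) error }, // e.g. *http.Server
	stz interface{ Drain(context.Context) error },    // scaletozero.Controller
	api interface{ Shutdown(context.Context) error }, // *ApiService
) error {
	shutdownCtx, cancel := context.WithTimeout(ctx, 10*time.Second) // timeout is illustrative
	defer cancel()

	g, _ := errgroup.WithContext(shutdownCtx)
	// Reset the scale-to-zero disable count ("=0") and freeze the controller so the
	// instance can scale down once the process exits.
	g.Go(func() error { return stz.Drain(shutdownCtx) })
	// Stop accepting new HTTP requests.
	g.Go(func() error { return srv.Shutdown(shutdownCtx) })
	// Kill tracked processes (supervisorctl exempted) and stop active recorders.
	g.Go(func() error { return api.Shutdown(shutdownCtx) })
	return g.Wait()
}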