58 changes: 55 additions & 3 deletions server/cmd/api/api/api.go
@@ -6,9 +6,11 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"sync"
"time"

"github.com/hashicorp/go-multierror"
"github.com/onkernel/kernel-images/server/lib/devtoolsproxy"
"github.com/onkernel/kernel-images/server/lib/logger"
"github.com/onkernel/kernel-images/server/lib/nekoclient"
@@ -29,8 +31,9 @@ type ApiService struct {
watches map[string]*fsWatch

// Process management
procMu sync.RWMutex
procs map[string]*processHandle
procMu sync.RWMutex
procs map[string]*processHandle
shuttingDown bool

// Neko authenticated client
nekoAuthClient *nekoclient.AuthClient
@@ -297,6 +300,55 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ
return oapi.ListRecorders200JSONResponse(infos), nil
}

// killAllProcesses sends SIGKILL to every tracked process that is still running.
Contributor comment:
this is new behavior — we'll need to reach out to heavy browser pool users and make sure they don't depend on process execs carrying over between session re-use.

// It acquires the write lock and sets shuttingDown so that ProcessSpawn rejects
// new processes once the kill pass begins.
func (s *ApiService) killAllProcesses(ctx context.Context) error {
log := logger.FromContext(ctx)
s.procMu.Lock()
defer s.procMu.Unlock()
s.shuttingDown = true

var result *multierror.Error
for id, h := range s.procs {
if h.state() != "running" {
continue
}
if h.cmd.Process == nil {
continue
}
// supervisorctl handles the lifecycle of long running processes so we don't want to kill
Contributor comment:
not sure we should make an exception for supervisorctl here. if someone sent a process exec for a supervisorctl command it shouldn't matter — we're going to hard reset supervisor services anyway, right? or are we doing our own supervisorctl hard-reset of things like chromium during server shutdown? that feels a little weird but i could live with it — just want to make sure the reasoning is clear.

// any active supervisorctl processes. For example it is used to restart kernel-images-api
// and killing that process would break the restart process.
if filepath.Base(h.cmd.Path) == "supervisorctl" {
continue
}
if err := h.cmd.Process.Kill(); err != nil {
// A process may already have exited between the state check and the
// kill call; treat that as a benign race rather than a fatal error.
if !errors.Is(err, os.ErrProcessDone) {
result = multierror.Append(result, fmt.Errorf("process %s: %w", id, err))
log.Error("failed to kill process", "process_id", id, "err", err)
}
}
}
return result.ErrorOrNil()
}

func (s *ApiService) Shutdown(ctx context.Context) error {
return s.recordManager.StopAll(ctx)
var wg sync.WaitGroup
var killErr, stopErr error

wg.Add(2)
go func() {
defer wg.Done()
killErr = s.killAllProcesses(ctx)
}()
go func() {
defer wg.Done()
stopErr = s.recordManager.StopAll(ctx)
}()
wg.Wait()

return multierror.Append(killErr, stopErr).ErrorOrNil()
}
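For readers new to the go-multierror dependency added below in go.mod: ErrorOrNil is what keeps the happy paths of killAllProcesses and Shutdown returning nil. A minimal, self-contained illustration (not code from this PR):

```go
package main

import (
	"errors"
	"fmt"

	"github.com/hashicorp/go-multierror"
)

func main() {
	var result *multierror.Error

	// No errors appended: ErrorOrNil returns a plain nil error.
	fmt.Println(result.ErrorOrNil()) // <nil>

	// Two failures appended: ErrorOrNil returns one error wrapping both.
	result = multierror.Append(result, errors.New("process a: kill failed"))
	result = multierror.Append(result, errors.New("process b: kill failed"))
	fmt.Println(result.ErrorOrNil())
}
```

Append also ignores nil inputs, which is why Shutdown can pass killErr and stopErr unconditionally.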
9 changes: 7 additions & 2 deletions server/cmd/api/api/process.go
@@ -323,8 +323,14 @@ func (s *ApiService) ProcessSpawn(ctx context.Context, request oapi.ProcessSpawn
doneCh: make(chan struct{}),
}

// Store handle
// Store handle; reject if the server is shutting down.
s.procMu.Lock()
if s.shuttingDown {
s.procMu.Unlock()
// The process was already started; kill it immediately.
_ = cmd.Process.Kill()
return oapi.ProcessSpawn500JSONResponse{InternalErrorJSONResponse: oapi.InternalErrorJSONResponse{Message: "server is shutting down"}}, nil
}
Review comment (Medium Severity): Shutdown spawn path leaks process lifecycle
When ProcessSpawn hits s.shuttingDown after cmd.Start(), it returns immediately after cmd.Process.Kill() without waiting for exit or running normal cleanup. This skips the waiter path that closes pipes and balances s.stz.Enable, and it also ignores kill failure, leaving a potentially untracked process lifecycle.
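One possible follow-up to this finding, kept generic because the handler's waiter/cleanup internals aren't shown in this diff: after Kill, also Wait so pipes are closed and the child is reaped, and surface the kill error instead of discarding it. A standalone sketch of that pattern (the sleep command and the killAndReap helper name are illustrative):

```go
package main

import (
	"errors"
	"log"
	"os"
	"os/exec"
)

// killAndReap kills a started child, treating "already exited" as benign,
// then Waits so stdio pipes are closed and the process table entry is
// released rather than leaving a zombie behind.
func killAndReap(cmd *exec.Cmd) error {
	if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
		return err
	}
	// The error from Wait is expected here (the child was killed); what
	// matters is that Wait runs at all.
	_ = cmd.Wait()
	return nil
}

func main() {
	cmd := exec.Command("sleep", "60")
	if err := cmd.Start(); err != nil {
		log.Fatal(err)
	}
	if err := killAndReap(cmd); err != nil {
		log.Fatal(err)
	}
	log.Println("child killed and reaped")
}
```

Balancing s.stz.Enable would still need to happen wherever the normal waiter path does it today; the sketch only covers the kill/wait half.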

if s.procs == nil {
s.procs = make(map[string]*processHandle)
}
@@ -624,7 +630,6 @@ func (s *ApiService) ProcessResize(ctx context.Context, request oapi.ProcessResi
return oapi.ProcessResize200JSONResponse(oapi.OkResponse{Ok: true}), nil
}


// writeJSON writes a JSON response with the given status code.
// Unlike http.Error, this sets the correct Content-Type for JSON.
func writeJSON(w http.ResponseWriter, status int, body string) {
3 changes: 3 additions & 0 deletions server/cmd/api/main.go
@@ -274,6 +274,9 @@ func main() {
defer shutdownCancel()
g, _ := errgroup.WithContext(shutdownCtx)

g.Go(func() error {
Contributor comment:
stz.Drain runs concurrently with srv.Shutdown here. this works but makes the state transitions harder to reason about — Drain races with in-flight requests' deferred Enable() calls. consider sequencing: drain HTTP servers first (letting all in-flight Enables run normally), then call stz.Drain after. at that point the controller is already at rest and Drain is just a safety net freeze. same outcome, easier to verify correctness.
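For illustration, the sequencing the comment suggests could look roughly like the fragment below, reusing names already visible in this hunk (g, srv, stz, shutdownCtx); it is a sketch of the ordering, not a proposed patch:

```go
// Shut the HTTP server down first so every in-flight request finishes and
// runs its deferred Enable(); Drain then freezes a controller that is
// already at rest, acting purely as a safety net.
g.Go(func() error {
	if err := srv.Shutdown(shutdownCtx); err != nil {
		return err
	}
	return stz.Drain(shutdownCtx)
})
```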

return stz.Drain(shutdownCtx)
})
g.Go(func() error {
return srv.Shutdown(shutdownCtx)
})
2 changes: 2 additions & 0 deletions server/go.mod
@@ -14,6 +14,7 @@ require (
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/hashicorp/go-multierror v1.1.1
github.com/kelseyhightower/envconfig v1.4.0
github.com/klauspost/compress v1.18.3
github.com/m1k1o/neko/server v0.0.0-20251008185748-46e2fc7d3866
@@ -51,6 +52,7 @@ require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
4 changes: 4 additions & 0 deletions server/go.sum
@@ -82,6 +82,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
22 changes: 22 additions & 0 deletions server/lib/scaletozero/middleware.go
@@ -2,16 +2,24 @@ package scaletozero

import (
"context"
"net"
"net/http"

"github.com/onkernel/kernel-images/server/lib/logger"
)

// Middleware returns a standard net/http middleware that disables scale-to-zero
// at the start of each request and re-enables it after the handler completes.
// Connections from loopback addresses are ignored and do not affect the
// scale-to-zero state.
func Middleware(ctrl Controller) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isLoopbackAddr(r.RemoteAddr) {
next.ServeHTTP(w, r)
return
}

if err := ctrl.Disable(r.Context()); err != nil {
logger.FromContext(r.Context()).Error("failed to disable scale-to-zero", "error", err)
http.Error(w, "failed to disable scale-to-zero", http.StatusInternalServerError)
@@ -23,3 +31,17 @@ func Middleware(ctrl Controller) func(http.Handler) http.Handler {
})
}
}

// isLoopbackAddr reports whether addr is a loopback address.
// addr may be an "ip:port" pair or a bare IP.
func isLoopbackAddr(addr string) bool {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return ip.IsLoopback()
}
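For orientation, wiring this middleware into a router could look like the fragment below; chi is already a dependency of this module, but the ctrl value (any Controller implementation) and the route are illustrative assumptions, not code from this PR:

```go
r := chi.NewRouter()
// Loopback traffic (health checks, local tooling) skips the disable/enable
// cycle; external requests keep scale-to-zero disabled while in flight.
r.Use(scaletozero.Middleware(ctrl))
r.Get("/status", func(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
})
```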
114 changes: 114 additions & 0 deletions server/lib/scaletozero/middleware_test.go
@@ -0,0 +1,114 @@
package scaletozero

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) {
t.Parallel()
mock := &mockScaleToZeroer{}
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.50:12345"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, 1, mock.disableCalls)
assert.Equal(t, 1, mock.enableCalls)
}

func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
t.Parallel()

loopbackAddrs := []struct {
name string
addr string
}{
{"loopback-v4", "127.0.0.1:8080"},
{"loopback-v6", "[::1]:8080"},
}

for _, tc := range loopbackAddrs {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mock := &mockScaleToZeroer{}
var called bool
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = tc.addr
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

assert.True(t, called, "handler should still be called")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, 0, mock.disableCalls, "should not disable for loopback addr")
assert.Equal(t, 0, mock.enableCalls, "should not enable for loopback addr")
})
}
}

func TestMiddlewareDisableError(t *testing.T) {
t.Parallel()
mock := &mockScaleToZeroer{disableErr: assert.AnError}
var called bool
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.50:12345"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

assert.False(t, called, "handler should not be called on disable error")
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, 0, mock.enableCalls)
}

func TestIsLoopbackAddr(t *testing.T) {
t.Parallel()

tests := []struct {
addr string
loopback bool
}{
// Loopback
{"127.0.0.1:80", true},
{"[::1]:80", true},
{"127.0.0.1", true},
{"::1", true},
// Non-loopback
{"10.0.0.1:80", false},
{"172.16.0.1:80", false},
{"192.168.1.1:80", false},
{"203.0.113.50:80", false},
{"8.8.8.8:53", false},
{"[2001:db8::1]:80", false},
// Unparseable
{"not-an-ip:80", false},
{"", false},
}

for _, tc := range tests {
t.Run(tc.addr, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.loopback, isLoopbackAddr(tc.addr))
})
}
}