From 8e0348bbe6a10738f21c1067154976bf4d07eac0 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 10 Apr 2026 13:59:26 +0200 Subject: [PATCH] fix: serialize concurrent RunSession calls to prevent tool_use/tool_result mismatch When two HTTP requests target the same session concurrently, the second can inject user messages while the first is mid-tool-call, producing a tool_use without a matching tool_result that causes Anthropic API errors. Add a per-session streaming mutex to activeRuntimes. RunSession uses TryLock to fail fast with ErrSessionBusy (HTTP 409) when the session is already streaming. Message addition is deferred until after the lock is acquired, so a rejected request never mutates the session. TryLock is called on the calling goroutine inside RunSession; Unlock is deferred in the background goroutine after RunStream completes. The lock is held continuously from before message addition through the entire stream including all tool-call processing. Assisted-By: docker-agent --- pkg/server/server.go | 4 + pkg/server/session_manager.go | 50 +++++-- pkg/server/session_manager_test.go | 224 +++++++++++++++++++++++++++++ 3 files changed, 263 insertions(+), 15 deletions(-) create mode 100644 pkg/server/session_manager_test.go diff --git a/pkg/server/server.go b/pkg/server/server.go index b9cf3b626..c81ce54bd 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -4,6 +4,7 @@ import ( "cmp" "context" "encoding/json" + "errors" "fmt" "log/slog" "net" @@ -285,6 +286,9 @@ func (s *Server) runAgent(c echo.Context) error { streamChan, err := s.sm.RunSession(c.Request().Context(), sessionID, agentFilename, currentAgent, messages) if err != nil { + if errors.Is(err, ErrSessionBusy) { + return echo.NewHTTPError(http.StatusConflict, err.Error()) + } return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to run session: %v", err)) } diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 6c26d3b58..98eddfb8b 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -27,6 +27,8 @@ type activeRuntimes struct { cancel context.CancelFunc session *session.Session // The actual session object used by the runtime titleGen *sessiontitle.Generator // Title generator (includes fallback models) + + streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests } // SessionManager manages sessions for HTTP and Connect-RPC servers. @@ -134,6 +136,9 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e return nil } +// ErrSessionBusy is returned when a session is already processing a request. +var ErrSessionBusy = errors.New("session is already processing a request") + // RunSession runs a session with the given messages. func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) { sm.mux.Lock() @@ -146,19 +151,6 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena rc := sm.runConfig.Clone() rc.WorkingDir = sess.WorkingDir - // Collect user messages for potential title generation - var userMessages []string - for _, msg := range messages { - sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...)) - if msg.Content != "" { - userMessages = append(userMessages, msg.Content) - } - } - - if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { - return nil, err - } - runtimeSession, exists := sm.runtimeSessions.Load(sessionID) streamCtx, cancel := context.WithCancel(ctx) var titleGen *sessiontitle.Generator @@ -177,17 +169,45 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena } sm.runtimeSessions.Store(sessionID, runtimeSession) } else { - // Update the session pointer in case it was reloaded - runtimeSession.session = sess titleGen = runtimeSession.titleGen } + // Reject the request immediately if the session is already streaming. + // This prevents interleaving user messages while a tool call is in + // progress, which would produce a tool_use without a matching + // tool_result and cause provider errors. + if !runtimeSession.streaming.TryLock() { + cancel() + return nil, ErrSessionBusy + } + + // Now that we hold the streaming lock, it is safe to mutate the session. + // Collect user messages for potential title generation + var userMessages []string + for _, msg := range messages { + sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...)) + if msg.Content != "" { + userMessages = append(userMessages, msg.Content) + } + } + + if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil { + runtimeSession.streaming.Unlock() + cancel() + return nil, err + } + + // Update the session pointer so the runtime sees the latest messages. + runtimeSession.session = sess + streamChan := make(chan runtime.Event) // Check if we need to generate a title needsTitle := sess.Title == "" && len(userMessages) > 0 && titleGen != nil go func() { + defer runtimeSession.streaming.Unlock() + // Start title generation in parallel if needed if needsTitle { go sm.generateTitle(ctx, sess, titleGen, userMessages, streamChan) diff --git a/pkg/server/session_manager_test.go b/pkg/server/session_manager_test.go new file mode 100644 index 000000000..39b93fd68 --- /dev/null +++ b/pkg/server/session_manager_test.go @@ -0,0 +1,224 @@ +package server + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/api" + "github.com/docker/docker-agent/pkg/concurrent" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/sessiontitle" + "github.com/docker/docker-agent/pkg/tools" +) + +// fakeRuntime is a minimal Runtime that records concurrent RunStream calls. +type fakeRuntime struct { + runtime.Runtime + + concurrentStreams atomic.Int32 + maxConcurrent atomic.Int32 + streamDelay time.Duration +} + +func (f *fakeRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event { + cur := f.concurrentStreams.Add(1) + for { + old := f.maxConcurrent.Load() + if cur <= old || f.maxConcurrent.CompareAndSwap(old, cur) { + break + } + } + + ch := make(chan runtime.Event) + go func() { + time.Sleep(f.streamDelay) + f.concurrentStreams.Add(-1) + close(ch) + }() + return ch +} + +func (f *fakeRuntime) Resume(_ context.Context, _ runtime.ResumeRequest) {} + +func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAction, _ map[string]any) error { + return nil +} + +func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager { + t.Helper() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + require.NoError(t, store.AddSession(ctx, sess)) + + sm := &SessionManager{ + runtimeSessions: concurrent.NewMap[string, *activeRuntimes](), + sessionStore: store, + Sources: config.Sources{}, + runConfig: &config.RuntimeConfig{}, + } + + // Pre-register a runtime for this session so RunSession skips agent loading. + sm.runtimeSessions.Store(sess.ID, &activeRuntimes{ + runtime: fake, + session: sess, + titleGen: (*sessiontitle.Generator)(nil), + }) + + return sm +} + +// TestRunSession_ConcurrentRequestReturnsErrSessionBusy verifies that a +// second RunSession call on a session that is already streaming returns +// ErrSessionBusy instead of silently interleaving messages. +func TestRunSession_ConcurrentRequestReturnsErrSessionBusy(t *testing.T) { + t.Parallel() + + ctx := t.Context() + sess := session.New() + fake := &fakeRuntime{streamDelay: 500 * time.Millisecond} + sm := newTestSessionManager(t, sess, fake) + + // Start the first stream. + ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "first"}, + }) + require.NoError(t, err) + + // Give the goroutine a moment to acquire the streaming lock. + time.Sleep(50 * time.Millisecond) + + // The second request should fail immediately with ErrSessionBusy. + _, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "second"}, + }) + require.ErrorIs(t, err, ErrSessionBusy) + + // Drain first stream to let it complete. + for range ch1 { + } + + // After the first stream finishes, a new request should succeed. + ch3, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "third"}, + }) + require.NoError(t, err) + for range ch3 { + } +} + +// TestRunSession_MessagesNotAddedWhenBusy verifies that when a session +// is busy, the rejected request does not mutate the session's messages. +func TestRunSession_MessagesNotAddedWhenBusy(t *testing.T) { + t.Parallel() + + ctx := t.Context() + sess := session.New() + fake := &fakeRuntime{streamDelay: 500 * time.Millisecond} + sm := newTestSessionManager(t, sess, fake) + + ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "first"}, + }) + require.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + msgCountBefore := len(sess.GetAllMessages()) + + _, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "should not be added"}, + }) + require.ErrorIs(t, err, ErrSessionBusy) + + // Messages should not have been added. + assert.Len(t, sess.GetAllMessages(), msgCountBefore) + + for range ch1 { + } +} + +// TestRunSession_SequentialRequestsSucceed verifies that sequential +// (non-overlapping) requests on the same session work normally. +func TestRunSession_SequentialRequestsSucceed(t *testing.T) { + t.Parallel() + + ctx := t.Context() + sess := session.New() + fake := &fakeRuntime{streamDelay: 10 * time.Millisecond} + sm := newTestSessionManager(t, sess, fake) + + for range 3 { + ch, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{ + {Content: "hello"}, + }) + require.NoError(t, err) + for range ch { + } + } + + assert.Equal(t, int32(1), fake.maxConcurrent.Load()) +} + +// TestRunSession_DifferentSessionsConcurrently verifies that concurrent +// requests on *different* sessions are not blocked by each other. +func TestRunSession_DifferentSessionsConcurrently(t *testing.T) { + t.Parallel() + + ctx := t.Context() + store := session.NewInMemorySessionStore() + fake1 := &fakeRuntime{streamDelay: 200 * time.Millisecond} + fake2 := &fakeRuntime{streamDelay: 200 * time.Millisecond} + + sess1 := session.New() + sess2 := session.New() + require.NoError(t, store.AddSession(ctx, sess1)) + require.NoError(t, store.AddSession(ctx, sess2)) + + sm := &SessionManager{ + runtimeSessions: concurrent.NewMap[string, *activeRuntimes](), + sessionStore: store, + Sources: config.Sources{}, + runConfig: &config.RuntimeConfig{}, + } + + sm.runtimeSessions.Store(sess1.ID, &activeRuntimes{ + runtime: fake1, session: sess1, titleGen: (*sessiontitle.Generator)(nil), + }) + sm.runtimeSessions.Store(sess2.ID, &activeRuntimes{ + runtime: fake2, session: sess2, titleGen: (*sessiontitle.Generator)(nil), + }) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + ch, err := sm.RunSession(ctx, sess1.ID, "agent", "root", []api.Message{{Content: "a"}}) + assert.NoError(t, err) + for range ch { + } + }() + + go func() { + defer wg.Done() + ch, err := sm.RunSession(ctx, sess2.ID, "agent", "root", []api.Message{{Content: "b"}}) + assert.NoError(t, err) + for range ch { + } + }() + + wg.Wait() + + // Both sessions should have streamed (1 each). + assert.Equal(t, int32(1), fake1.maxConcurrent.Load()) + assert.Equal(t, int32(1), fake2.maxConcurrent.Load()) +}