Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
Expand Down Expand Up @@ -285,6 +286,9 @@ func (s *Server) runAgent(c echo.Context) error {

streamChan, err := s.sm.RunSession(c.Request().Context(), sessionID, agentFilename, currentAgent, messages)
if err != nil {
if errors.Is(err, ErrSessionBusy) {
return echo.NewHTTPError(http.StatusConflict, err.Error())
}
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to run session: %v", err))
}

Expand Down
50 changes: 35 additions & 15 deletions pkg/server/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type activeRuntimes struct {
cancel context.CancelFunc
session *session.Session // The actual session object used by the runtime
titleGen *sessiontitle.Generator // Title generator (includes fallback models)

streaming sync.Mutex // Held while a RunStream is in progress; serialises concurrent requests
}

// SessionManager manages sessions for HTTP and Connect-RPC servers.
Expand Down Expand Up @@ -134,6 +136,9 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e
return nil
}

// ErrSessionBusy is returned when a session is already processing a request.
var ErrSessionBusy = errors.New("session is already processing a request")

// RunSession runs a session with the given messages.
func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) {
sm.mux.Lock()
Expand All @@ -146,19 +151,6 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
rc := sm.runConfig.Clone()
rc.WorkingDir = sess.WorkingDir

// Collect user messages for potential title generation
var userMessages []string
for _, msg := range messages {
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
if msg.Content != "" {
userMessages = append(userMessages, msg.Content)
}
}

if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
return nil, err
}

runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
streamCtx, cancel := context.WithCancel(ctx)
var titleGen *sessiontitle.Generator
Expand All @@ -177,17 +169,45 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
}
sm.runtimeSessions.Store(sessionID, runtimeSession)
} else {
// Update the session pointer in case it was reloaded
runtimeSession.session = sess
titleGen = runtimeSession.titleGen
}

// Reject the request immediately if the session is already streaming.
// This prevents interleaving user messages while a tool call is in
// progress, which would produce a tool_use without a matching
// tool_result and cause provider errors.
if !runtimeSession.streaming.TryLock() {
cancel()
return nil, ErrSessionBusy
}

// Now that we hold the streaming lock, it is safe to mutate the session.
// Collect user messages for potential title generation
var userMessages []string
for _, msg := range messages {
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
if msg.Content != "" {
userMessages = append(userMessages, msg.Content)
}
}

if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
runtimeSession.streaming.Unlock()
cancel()
return nil, err
}

// Update the session pointer so the runtime sees the latest messages.
runtimeSession.session = sess

streamChan := make(chan runtime.Event)

// Check if we need to generate a title
needsTitle := sess.Title == "" && len(userMessages) > 0 && titleGen != nil

go func() {
defer runtimeSession.streaming.Unlock()

// Start title generation in parallel if needed
if needsTitle {
go sm.generateTitle(ctx, sess, titleGen, userMessages, streamChan)
Expand Down
224 changes: 224 additions & 0 deletions pkg/server/session_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package server

import (
"context"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/api"
"github.com/docker/docker-agent/pkg/concurrent"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/runtime"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/sessiontitle"
"github.com/docker/docker-agent/pkg/tools"
)

// fakeRuntime is a minimal Runtime that records concurrent RunStream calls.
type fakeRuntime struct {
runtime.Runtime

concurrentStreams atomic.Int32
maxConcurrent atomic.Int32
streamDelay time.Duration
}

func (f *fakeRuntime) RunStream(_ context.Context, _ *session.Session) <-chan runtime.Event {
cur := f.concurrentStreams.Add(1)
for {
old := f.maxConcurrent.Load()
if cur <= old || f.maxConcurrent.CompareAndSwap(old, cur) {
break
}
}

ch := make(chan runtime.Event)
go func() {
time.Sleep(f.streamDelay)
f.concurrentStreams.Add(-1)
close(ch)
}()
return ch
}

func (f *fakeRuntime) Resume(_ context.Context, _ runtime.ResumeRequest) {}

func (f *fakeRuntime) ResumeElicitation(_ context.Context, _ tools.ElicitationAction, _ map[string]any) error {
return nil
}

func newTestSessionManager(t *testing.T, sess *session.Session, fake *fakeRuntime) *SessionManager {
t.Helper()

ctx := t.Context()
store := session.NewInMemorySessionStore()
require.NoError(t, store.AddSession(ctx, sess))

sm := &SessionManager{
runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
sessionStore: store,
Sources: config.Sources{},
runConfig: &config.RuntimeConfig{},
}

// Pre-register a runtime for this session so RunSession skips agent loading.
sm.runtimeSessions.Store(sess.ID, &activeRuntimes{
runtime: fake,
session: sess,
titleGen: (*sessiontitle.Generator)(nil),
})

return sm
}

// TestRunSession_ConcurrentRequestReturnsErrSessionBusy verifies that a
// second RunSession call on a session that is already streaming returns
// ErrSessionBusy instead of silently interleaving messages.
func TestRunSession_ConcurrentRequestReturnsErrSessionBusy(t *testing.T) {
t.Parallel()

ctx := t.Context()
sess := session.New()
fake := &fakeRuntime{streamDelay: 500 * time.Millisecond}
sm := newTestSessionManager(t, sess, fake)

// Start the first stream.
ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
{Content: "first"},
})
require.NoError(t, err)

// Give the goroutine a moment to acquire the streaming lock.
time.Sleep(50 * time.Millisecond)

// The second request should fail immediately with ErrSessionBusy.
_, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
{Content: "second"},
})
require.ErrorIs(t, err, ErrSessionBusy)

// Drain first stream to let it complete.
for range ch1 {
}

// After the first stream finishes, a new request should succeed.
ch3, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
{Content: "third"},
})
require.NoError(t, err)
for range ch3 {
}
}

// TestRunSession_MessagesNotAddedWhenBusy verifies that when a session
// is busy, the rejected request does not mutate the session's messages.
func TestRunSession_MessagesNotAddedWhenBusy(t *testing.T) {
t.Parallel()

ctx := t.Context()
sess := session.New()
fake := &fakeRuntime{streamDelay: 500 * time.Millisecond}
sm := newTestSessionManager(t, sess, fake)

ch1, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
{Content: "first"},
})
require.NoError(t, err)

time.Sleep(50 * time.Millisecond)

msgCountBefore := len(sess.GetAllMessages())

_, err = sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
{Content: "should not be added"},
})
require.ErrorIs(t, err, ErrSessionBusy)

// Messages should not have been added.
assert.Len(t, sess.GetAllMessages(), msgCountBefore)

for range ch1 {
}
}

// TestRunSession_SequentialRequestsSucceed verifies that sequential
// (non-overlapping) requests on the same session work normally.
func TestRunSession_SequentialRequestsSucceed(t *testing.T) {
t.Parallel()

ctx := t.Context()
sess := session.New()
fake := &fakeRuntime{streamDelay: 10 * time.Millisecond}
sm := newTestSessionManager(t, sess, fake)

for range 3 {
ch, err := sm.RunSession(ctx, sess.ID, "agent", "root", []api.Message{
{Content: "hello"},
})
require.NoError(t, err)
for range ch {
}
}

assert.Equal(t, int32(1), fake.maxConcurrent.Load())
}

// TestRunSession_DifferentSessionsConcurrently verifies that concurrent
// requests on *different* sessions are not blocked by each other.
func TestRunSession_DifferentSessionsConcurrently(t *testing.T) {
t.Parallel()

ctx := t.Context()
store := session.NewInMemorySessionStore()
fake1 := &fakeRuntime{streamDelay: 200 * time.Millisecond}
fake2 := &fakeRuntime{streamDelay: 200 * time.Millisecond}

sess1 := session.New()
sess2 := session.New()
require.NoError(t, store.AddSession(ctx, sess1))
require.NoError(t, store.AddSession(ctx, sess2))

sm := &SessionManager{
runtimeSessions: concurrent.NewMap[string, *activeRuntimes](),
sessionStore: store,
Sources: config.Sources{},
runConfig: &config.RuntimeConfig{},
}

sm.runtimeSessions.Store(sess1.ID, &activeRuntimes{
runtime: fake1, session: sess1, titleGen: (*sessiontitle.Generator)(nil),
})
sm.runtimeSessions.Store(sess2.ID, &activeRuntimes{
runtime: fake2, session: sess2, titleGen: (*sessiontitle.Generator)(nil),
})

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
ch, err := sm.RunSession(ctx, sess1.ID, "agent", "root", []api.Message{{Content: "a"}})
assert.NoError(t, err)
for range ch {
}
}()

go func() {
defer wg.Done()
ch, err := sm.RunSession(ctx, sess2.ID, "agent", "root", []api.Message{{Content: "b"}})
assert.NoError(t, err)
for range ch {
}
}()

wg.Wait()

// Both sessions should have streamed (1 each).
assert.Equal(t, int32(1), fake1.maxConcurrent.Load())
assert.Equal(t, int32(1), fake2.maxConcurrent.Load())
}
Loading