diff --git a/pkg/api/types.go b/pkg/api/types.go index 90f943421..bde812a8c 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -160,6 +160,13 @@ type ResumeElicitationRequest struct { Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept") } +// SteerSessionRequest represents a request to inject user messages into a +// running agent session. The messages are picked up by the agent loop between +// tool execution and the next LLM call. +type SteerSessionRequest struct { + Messages []Message `json:"messages"` +} + // UpdateSessionTitleRequest represents a request to update a session's title type UpdateSessionTitleRequest struct { Title string `json:"title"` diff --git a/pkg/app/app.go b/pkg/app/app.go index bd0637fec..1a6ba0d3e 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -532,6 +532,18 @@ func (a *App) SubscribeWith(ctx context.Context, send func(tea.Msg)) { } } +// Steer enqueues a user message for mid-turn injection into the running +// agent loop. Works with both local and remote runtimes. +func (a *App) Steer(msg runtime.QueuedMessage) error { + return a.runtime.Steer(msg) +} + +// FollowUp enqueues a message for end-of-turn processing. Each follow-up +// gets a full undivided agent turn. +func (a *App) FollowUp(msg runtime.QueuedMessage) error { + return a.runtime.FollowUp(msg) +} + // Resume resumes the runtime with the given confirmation request func (a *App) Resume(req runtime.ResumeRequest) { a.runtime.Resume(context.Background(), req) diff --git a/pkg/app/app_test.go b/pkg/app/app_test.go index eee69fca1..2f32cff20 100644 --- a/pkg/app/app_test.go +++ b/pkg/app/app_test.go @@ -67,6 +67,8 @@ func (m *mockRuntime) UpdateSessionTitle(_ context.Context, sess *session.Sessio func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil } func (m *mockRuntime) Close() error { return nil } func (m *mockRuntime) Stop() {} +func (m *mockRuntime) Steer(_ runtime.QueuedMessage) error { return nil } +func (m *mockRuntime) FollowUp(_ runtime.QueuedMessage) error { return nil } // Verify mockRuntime implements runtime.Runtime var _ runtime.Runtime = (*mockRuntime)(nil) diff --git a/pkg/cli/runner_test.go b/pkg/cli/runner_test.go index 6cd27a8bb..4f39c1d04 100644 --- a/pkg/cli/runner_test.go +++ b/pkg/cli/runner_test.go @@ -60,6 +60,8 @@ func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]strin func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil } func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil } func (m *mockRuntime) Close() error { return nil } +func (m *mockRuntime) Steer(runtime.QueuedMessage) error { return nil } +func (m *mockRuntime) FollowUp(runtime.QueuedMessage) error { return nil } func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {} func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) { diff --git a/pkg/runtime/client.go b/pkg/runtime/client.go index 8218e4eec..8ade2cc45 100644 --- a/pkg/runtime/client.go +++ b/pkg/runtime/client.go @@ -266,6 +266,18 @@ func (c *Client) ResumeSession(ctx context.Context, id, confirmation, reason, to return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+id+"/resume", req, nil) } +// SteerSession injects user messages into a running session mid-turn. +func (c *Client) SteerSession(ctx context.Context, sessionID string, messages []api.Message) error { + req := api.SteerSessionRequest{Messages: messages} + return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+sessionID+"/steer", req, nil) +} + +// FollowUpSession queues messages for end-of-turn processing. +func (c *Client) FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error { + req := api.SteerSessionRequest{Messages: messages} + return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+sessionID+"/followup", req, nil) +} + // DeleteSession deletes a session by ID func (c *Client) DeleteSession(ctx context.Context, id string) error { return c.doRequest(ctx, "DELETE", "/api/sessions/"+id, nil, nil) diff --git a/pkg/runtime/commands_test.go b/pkg/runtime/commands_test.go index fa6999233..00e4a2195 100644 --- a/pkg/runtime/commands_test.go +++ b/pkg/runtime/commands_test.go @@ -69,6 +69,8 @@ func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, stri } func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil } func (m *mockRuntime) Close() error { return nil } +func (m *mockRuntime) Steer(QueuedMessage) error { return nil } +func (m *mockRuntime) FollowUp(QueuedMessage) error { return nil } func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan Event) { } diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 44fbe00b9..0b653f6a5 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -386,13 +386,55 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Record per-toolset model override for the next LLM turn. toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools) + // Only compact proactively when the model will continue (has + // tool calls to process on the next turn). If the model stopped + // and no steered messages override that, compaction is wasteful + // because no further LLM call follows. + if !res.Stopped { + r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + } + + // --- STEERING: mid-turn injection --- + // Drain ALL pending steer messages. These are urgent course- + // corrections that the model should see on the very next + // iteration, wrapped in tags. + if steered := r.steerQueue.Drain(ctx); len(steered) > 0 { + for _, sm := range steered { + wrapped := fmt.Sprintf( + "\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n", + sm.Content, + ) + userMsg := session.UserMessage(wrapped, sm.MultiContent...) + sess.AddMessage(userMsg) + events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1) + } + + // The model must respond to the injected messages — compact + // if needed and re-enter the loop for the next LLM call. + r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + continue + } + if res.Stopped { slog.Debug("Conversation stopped", "agent", a.Name()) r.executeStopHooks(ctx, sess, a, res.Content, events) + + // --- FOLLOW-UP: end-of-turn injection --- + // Pop exactly one follow-up message. Unlike steered + // messages, follow-ups are plain user messages that start + // a new turn — the model sees them as fresh input, not a + // mid-stream interruption. Each follow-up gets a full + // undivided agent turn. + if followUp, ok := r.followUpQueue.Dequeue(ctx); ok { + userMsg := session.UserMessage(followUp.Content, followUp.MultiContent...) + sess.AddMessage(userMsg) + events <- UserMessage(followUp.Content, sess.ID, followUp.MultiContent, len(sess.Messages)-1) + r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + continue // re-enter the loop for a new turn + } + break } - - r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) } }() diff --git a/pkg/runtime/message_queue.go b/pkg/runtime/message_queue.go new file mode 100644 index 000000000..99feecb5f --- /dev/null +++ b/pkg/runtime/message_queue.go @@ -0,0 +1,85 @@ +package runtime + +import ( + "context" + + "github.com/docker/docker-agent/pkg/chat" +) + +// QueuedMessage is a user message waiting to be injected into the agent loop, +// either mid-turn (via the steer queue) or at end-of-turn (via the follow-up +// queue). +type QueuedMessage struct { + Content string + MultiContent []chat.MessagePart +} + +// MessageQueue is the interface for storing messages that are injected into +// the agent loop. Implementations must be safe for concurrent use: Enqueue +// is called from API handlers while Dequeue/Drain are called from the agent +// loop goroutine. +// +// The default implementation is NewInMemoryMessageQueue. Callers that need +// durable or distributed storage can provide their own implementation +// via the WithSteerQueue or WithFollowUpQueue options. +type MessageQueue interface { + // Enqueue adds a message to the queue. Returns false if the queue is + // full or the context is cancelled. + Enqueue(ctx context.Context, msg QueuedMessage) bool + // Dequeue removes and returns the next message from the queue. + // Returns the message and true, or a zero value and false if the + // queue is empty. Must not block. + Dequeue(ctx context.Context) (QueuedMessage, bool) + // Drain returns all pending messages and removes them from the queue. + // Must not block — if the queue is empty it returns nil. + Drain(ctx context.Context) []QueuedMessage +} + +// inMemoryMessageQueue is the default MessageQueue backed by a buffered channel. +type inMemoryMessageQueue struct { + ch chan QueuedMessage +} + +const ( + // defaultSteerQueueCapacity is the buffer size for the default in-memory steer queue. + defaultSteerQueueCapacity = 5 + // defaultFollowUpQueueCapacity is the buffer size for the default in-memory follow-up queue. + // Higher than steer because follow-ups accumulate while waiting for the turn to end. + defaultFollowUpQueueCapacity = 20 +) + +// NewInMemoryMessageQueue creates a MessageQueue backed by a buffered channel +// with the given capacity. +func NewInMemoryMessageQueue(capacity int) MessageQueue { + return &inMemoryMessageQueue{ch: make(chan QueuedMessage, capacity)} +} + +func (q *inMemoryMessageQueue) Enqueue(_ context.Context, msg QueuedMessage) bool { + select { + case q.ch <- msg: + return true + default: + return false + } +} + +func (q *inMemoryMessageQueue) Dequeue(_ context.Context) (QueuedMessage, bool) { + select { + case m := <-q.ch: + return m, true + default: + return QueuedMessage{}, false + } +} + +func (q *inMemoryMessageQueue) Drain(_ context.Context) []QueuedMessage { + var msgs []QueuedMessage + for { + select { + case m := <-q.ch: + msgs = append(msgs, m) + default: + return msgs + } + } +} diff --git a/pkg/runtime/remote_client.go b/pkg/runtime/remote_client.go index c1398afaf..993be1468 100644 --- a/pkg/runtime/remote_client.go +++ b/pkg/runtime/remote_client.go @@ -30,6 +30,12 @@ type RemoteClient interface { // RunAgentWithAgentName executes an agent with a specific agent name RunAgentWithAgentName(ctx context.Context, sessionID, agent, agentName string, messages []api.Message) (<-chan Event, error) + // SteerSession injects user messages into a running session mid-turn + SteerSession(ctx context.Context, sessionID string, messages []api.Message) error + + // FollowUpSession queues messages for end-of-turn processing + FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error + // UpdateSessionTitle updates the title of a session UpdateSessionTitle(ctx context.Context, sessionID, title string) error diff --git a/pkg/runtime/remote_runtime.go b/pkg/runtime/remote_runtime.go index 5f03297cd..fd220c9a3 100644 --- a/pkg/runtime/remote_runtime.go +++ b/pkg/runtime/remote_runtime.go @@ -211,6 +211,27 @@ func (r *RemoteRuntime) Run(ctx context.Context, sess *session.Session) ([]sessi return sess.GetAllMessages(), nil } +// Steer enqueues a user message for mid-turn injection into the running +// agent loop on the remote server. +func (r *RemoteRuntime) Steer(msg QueuedMessage) error { + if r.sessionID == "" { + return errors.New("no active session") + } + return r.client.SteerSession(context.Background(), r.sessionID, []api.Message{ + {Content: msg.Content, MultiContent: msg.MultiContent}, + }) +} + +// FollowUp enqueues a message for end-of-turn processing on the remote server. +func (r *RemoteRuntime) FollowUp(msg QueuedMessage) error { + if r.sessionID == "" { + return errors.New("no active session") + } + return r.client.FollowUpSession(context.Background(), r.sessionID, []api.Message{ + {Content: msg.Content, MultiContent: msg.MultiContent}, + }) +} + // Resume allows resuming execution after user confirmation func (r *RemoteRuntime) Resume(ctx context.Context, req ResumeRequest) { slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "type", req.Type, "reason", req.Reason, "tool_name", req.ToolName, "session_id", r.sessionID) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index f6a4de8bc..711607615 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -139,6 +139,14 @@ type Runtime interface { // if the runtime does not support local title generation (e.g. remote runtimes). TitleGenerator() *sessiontitle.Generator + // Steer enqueues a user message for urgent mid-turn injection into the + // running agent loop. Returns an error if the queue is full or steering + // is not available. + Steer(msg QueuedMessage) error + // FollowUp enqueues a message for end-of-turn processing. Each follow-up + // gets a full undivided agent turn. Returns an error if the queue is full. + FollowUp(msg QueuedMessage) error + // Close releases resources held by the runtime (e.g., session store connections). Close() error } @@ -201,6 +209,14 @@ type LocalRuntime struct { currentAgentMu sync.RWMutex + // steerQueue stores urgent mid-turn messages. The agent loop drains + // ALL pending messages after tool execution, before the stop check. + steerQueue MessageQueue + + // followUpQueue stores end-of-turn messages. The agent loop pops + // exactly ONE message after the model stops and stop-hooks have run. + followUpQueue MessageQueue + // onToolsChanged is called when an MCP toolset reports a tool list change. onToolsChanged func(Event) @@ -228,6 +244,22 @@ func WithTracer(t trace.Tracer) Opt { } } +// WithSteerQueue sets a custom MessageQueue for mid-turn message injection. +// If not provided, an in-memory buffered queue is used. +func WithSteerQueue(q MessageQueue) Opt { + return func(r *LocalRuntime) { + r.steerQueue = q + } +} + +// WithFollowUpQueue sets a custom MessageQueue for end-of-turn follow-up +// messages. If not provided, an in-memory buffered queue is used. +func WithFollowUpQueue(q MessageQueue) Opt { + return func(r *LocalRuntime) { + r.followUpQueue = q + } +} + func WithSessionCompaction(sessionCompaction bool) Opt { return func(r *LocalRuntime) { r.sessionCompaction = sessionCompaction @@ -291,6 +323,8 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { currentAgent: defaultAgent.Name(), resumeChan: make(chan ResumeRequest), elicitationRequestCh: make(chan ElicitationResult), + steerQueue: NewInMemoryMessageQueue(defaultSteerQueueCapacity), + followUpQueue: NewInMemoryMessageQueue(defaultFollowUpQueueCapacity), sessionCompaction: true, managedOAuth: true, sessionStore: session.NewInMemorySessionStore(), @@ -1015,6 +1049,26 @@ func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.Elici } } +// Steer enqueues a user message for urgent mid-turn injection into the +// running agent loop. The message will be picked up after the current batch +// of tool calls finishes but before the loop checks whether to stop. +func (r *LocalRuntime) Steer(msg QueuedMessage) error { + if !r.steerQueue.Enqueue(context.Background(), msg) { + return errors.New("steer queue full") + } + return nil +} + +// FollowUp enqueues a message to be processed after the current agent turn +// finishes. Unlike Steer, follow-ups are popped one at a time and each gets +// a full undivided agent turn. +func (r *LocalRuntime) FollowUp(msg QueuedMessage) error { + if !r.followUpQueue.Enqueue(context.Background(), msg) { + return errors.New("follow-up queue full") + } + return nil +} + // Run starts the agent's interaction loop func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { diff --git a/pkg/server/server.go b/pkg/server/server.go index b9cf3b626..ede83122e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -62,6 +62,10 @@ func New(ctx context.Context, sessionStore session.Store, runConfig *config.Runt group.POST("/sessions/:id/agent/:agent", s.runAgent) group.POST("/sessions/:id/agent/:agent/:agent_name", s.runAgent) group.POST("/sessions/:id/elicitation", s.elicitation) + // Steer: inject user messages into a running agent session mid-turn + group.POST("/sessions/:id/steer", s.steerSession) + // Follow-up: queue messages for end-of-turn processing + group.POST("/sessions/:id/followup", s.followUpSession) // Agent tool count group.GET("/agents/:id/:agent_name/tools/count", s.getAgentToolCount) @@ -317,3 +321,39 @@ func (s *Server) elicitation(c echo.Context) error { return c.JSON(http.StatusOK, nil) } + +func (s *Server) steerSession(c echo.Context) error { + sessionID := c.Param("id") + var req api.SteerSessionRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err)) + } + + if len(req.Messages) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "at least one message is required") + } + + if err := s.sm.SteerSession(c.Request().Context(), sessionID, req.Messages); err != nil { + return echo.NewHTTPError(http.StatusConflict, fmt.Sprintf("failed to steer session: %v", err)) + } + + return c.JSON(http.StatusAccepted, map[string]string{"status": "queued"}) +} + +func (s *Server) followUpSession(c echo.Context) error { + sessionID := c.Param("id") + var req api.SteerSessionRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err)) + } + + if len(req.Messages) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "at least one message is required") + } + + if err := s.sm.FollowUpSession(c.Request().Context(), sessionID, req.Messages); err != nil { + return echo.NewHTTPError(http.StatusConflict, fmt.Sprintf("failed to enqueue follow-up: %v", err)) + } + + return c.JSON(http.StatusAccepted, map[string]string{"status": "queued"}) +} diff --git a/pkg/server/session_manager.go b/pkg/server/session_manager.go index 6c26d3b58..f8685fa17 100644 --- a/pkg/server/session_manager.go +++ b/pkg/server/session_manager.go @@ -23,10 +23,11 @@ import ( ) type activeRuntimes struct { - runtime runtime.Runtime - cancel context.CancelFunc - session *session.Session // The actual session object used by the runtime - titleGen *sessiontitle.Generator // Title generator (includes fallback models) + runtime runtime.Runtime + cancel context.CancelFunc + session *session.Session // The actual session object used by the runtime + titleGen *sessiontitle.Generator // Title generator (includes fallback models) + streaming bool // True while RunStream is active; prevents concurrent runs } // SessionManager manages sessions for HTTP and Connect-RPC servers. @@ -160,6 +161,14 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena } runtimeSession, exists := sm.runtimeSessions.Load(sessionID) + + // Reject if a stream is already active for this session. The caller + // should use POST /sessions/:id/steer to inject follow-up messages + // into a running session instead of starting a second concurrent stream. + if exists && runtimeSession.streaming { + return nil, errors.New("session is already streaming; use /steer to send follow-up messages") + } + streamCtx, cancel := context.WithCancel(ctx) var titleGen *sessiontitle.Generator if !exists { @@ -182,6 +191,8 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena titleGen = runtimeSession.titleGen } + runtimeSession.streaming = true + streamChan := make(chan runtime.Event) // Check if we need to generate a title @@ -194,8 +205,17 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena } stream := runtimeSession.runtime.RunStream(streamCtx, sess) - defer cancel() - defer close(streamChan) + // Single defer to control ordering: clear the streaming flag + // BEFORE closing streamChan. When the client sees the channel + // close it may immediately call RunSession for the next queued + // message; streaming must already be false by then. + defer func() { + sm.mux.Lock() + runtimeSession.streaming = false + sm.mux.Unlock() + close(streamChan) + cancel() + }() for event := range stream { if streamCtx.Err() != nil { return @@ -230,6 +250,49 @@ func (sm *SessionManager) ResumeSession(ctx context.Context, sessionID, confirma return nil } +// SteerSession enqueues user messages for mid-turn injection into a running +// session. The messages are picked up by the agent loop after the current tool +// calls finish but before the next LLM call. Returns an error if the session +// is not actively running or if the steer buffer is full. +func (sm *SessionManager) SteerSession(_ context.Context, sessionID string, messages []api.Message) error { + rt, exists := sm.runtimeSessions.Load(sessionID) + if !exists { + return errors.New("session not found or not running") + } + + for _, msg := range messages { + if err := rt.runtime.Steer(runtime.QueuedMessage{ + Content: msg.Content, + MultiContent: msg.MultiContent, + }); err != nil { + return err + } + } + + return nil +} + +// FollowUpSession enqueues user messages for end-of-turn processing in a +// running session. Each message is popped one at a time after the current +// turn finishes, giving each follow-up a full undivided agent turn. +func (sm *SessionManager) FollowUpSession(_ context.Context, sessionID string, messages []api.Message) error { + rt, exists := sm.runtimeSessions.Load(sessionID) + if !exists { + return errors.New("session not found or not running") + } + + for _, msg := range messages { + if err := rt.runtime.FollowUp(runtime.QueuedMessage{ + Content: msg.Content, + MultiContent: msg.MultiContent, + }); err != nil { + return err + } + } + + return nil +} + // ResumeElicitation resumes an elicitation request. func (sm *SessionManager) ResumeElicitation(ctx context.Context, sessionID, action string, content map[string]any) error { sm.mux.Lock()