Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ type ResumeElicitationRequest struct {
Content map[string]any `json:"content"` // The submitted form data (only present when action is "accept")
}

// SteerSessionRequest represents a request to inject user messages into a
// running agent session. The messages are picked up by the agent loop between
// tool execution and the next LLM call.
//
// NOTE: this same request shape is reused as the body for the follow-up
// endpoint as well (see Client.FollowUpSession).
type SteerSessionRequest struct {
	// Messages are the user messages to inject into the running session.
	Messages []Message `json:"messages"`
}

// UpdateSessionTitleRequest represents a request to update a session's title
type UpdateSessionTitleRequest struct {
Title string `json:"title"`
Expand Down
12 changes: 12 additions & 0 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,18 @@ func (a *App) SubscribeWith(ctx context.Context, send func(tea.Msg)) {
}
}

// Steer enqueues a user message for mid-turn injection into the running
// agent loop. Works with both local and remote runtimes.
//
// Returns an error when the runtime cannot accept the message (for the
// local runtime, a full steer queue; for the remote runtime, no active
// session or a failed HTTP request).
func (a *App) Steer(msg runtime.QueuedMessage) error {
	return a.runtime.Steer(msg)
}

// FollowUp enqueues a message for end-of-turn processing. Each follow-up
// gets a full undivided agent turn.
//
// Returns an error when the runtime cannot accept the message (for the
// local runtime, a full follow-up queue; for the remote runtime, no active
// session or a failed HTTP request).
func (a *App) FollowUp(msg runtime.QueuedMessage) error {
	return a.runtime.FollowUp(msg)
}

// Resume resumes the runtime with the given confirmation request
func (a *App) Resume(req runtime.ResumeRequest) {
a.runtime.Resume(context.Background(), req)
Expand Down
2 changes: 2 additions & 0 deletions pkg/app/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func (m *mockRuntime) UpdateSessionTitle(_ context.Context, sess *session.Sessio
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
func (m *mockRuntime) Close() error { return nil }
func (m *mockRuntime) Stop() {}
// Steer is a no-op stub so mockRuntime satisfies runtime.Runtime.
func (m *mockRuntime) Steer(_ runtime.QueuedMessage) error { return nil }
// FollowUp is a no-op stub so mockRuntime satisfies runtime.Runtime.
func (m *mockRuntime) FollowUp(_ runtime.QueuedMessage) error { return nil }

// Verify mockRuntime implements runtime.Runtime
var _ runtime.Runtime = (*mockRuntime)(nil)
Expand Down
2 changes: 2 additions & 0 deletions pkg/cli/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ func (m *mockRuntime) ExecuteMCPPrompt(context.Context, string, map[string]strin
func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, string) error { return nil }
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
func (m *mockRuntime) Close() error { return nil }
// Steer is a no-op stub so mockRuntime satisfies runtime.Runtime.
func (m *mockRuntime) Steer(runtime.QueuedMessage) error { return nil }
// FollowUp is a no-op stub so mockRuntime satisfies runtime.Runtime.
func (m *mockRuntime) FollowUp(runtime.QueuedMessage) error { return nil }
func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan runtime.Event) {}

func (m *mockRuntime) Resume(_ context.Context, req runtime.ResumeRequest) {
Expand Down
12 changes: 12 additions & 0 deletions pkg/runtime/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,18 @@ func (c *Client) ResumeSession(ctx context.Context, id, confirmation, reason, to
return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+id+"/resume", req, nil)
}

// SteerSession injects user messages into a running session mid-turn.
// It POSTs the messages to the session's /steer endpoint and returns any
// transport or server error from the request.
func (c *Client) SteerSession(ctx context.Context, sessionID string, messages []api.Message) error {
	endpoint := "/api/sessions/" + sessionID + "/steer"
	payload := api.SteerSessionRequest{Messages: messages}
	return c.doRequest(ctx, http.MethodPost, endpoint, payload, nil)
}

// FollowUpSession queues messages for end-of-turn processing.
// It POSTs to the session's /followup endpoint; the request body reuses
// api.SteerSessionRequest, which both endpoints share.
func (c *Client) FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error {
	endpoint := "/api/sessions/" + sessionID + "/followup"
	payload := api.SteerSessionRequest{Messages: messages}
	return c.doRequest(ctx, http.MethodPost, endpoint, payload, nil)
}

// DeleteSession deletes a session by ID
func (c *Client) DeleteSession(ctx context.Context, id string) error {
return c.doRequest(ctx, "DELETE", "/api/sessions/"+id, nil, nil)
Expand Down
2 changes: 2 additions & 0 deletions pkg/runtime/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ func (m *mockRuntime) UpdateSessionTitle(context.Context, *session.Session, stri
}
func (m *mockRuntime) TitleGenerator() *sessiontitle.Generator { return nil }
func (m *mockRuntime) Close() error { return nil }
// Steer is a no-op stub so mockRuntime satisfies the Runtime interface.
func (m *mockRuntime) Steer(QueuedMessage) error { return nil }
// FollowUp is a no-op stub so mockRuntime satisfies the Runtime interface.
func (m *mockRuntime) FollowUp(QueuedMessage) error { return nil }

func (m *mockRuntime) RegenerateTitle(context.Context, *session.Session, chan Event) {
}
Expand Down
46 changes: 44 additions & 2 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,55 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
// Record per-toolset model override for the next LLM turn.
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)

// Only compact proactively when the model will continue (has
// tool calls to process on the next turn). If the model stopped
// and no steered messages override that, compaction is wasteful
// because no further LLM call follows.
if !res.Stopped {
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}

// --- STEERING: mid-turn injection ---
// Drain ALL pending steer messages. These are urgent course-
// corrections that the model should see on the very next
// iteration, wrapped in <system-reminder> tags.
if steered := r.steerQueue.Drain(ctx); len(steered) > 0 {
for _, sm := range steered {
wrapped := fmt.Sprintf(
"<system-reminder>\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n</system-reminder>",
sm.Content,
)
userMsg := session.UserMessage(wrapped, sm.MultiContent...)
sess.AddMessage(userMsg)
events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1)
}

// The model must respond to the injected messages — compact
// if needed and re-enter the loop for the next LLM call.
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
continue
}

if res.Stopped {
slog.Debug("Conversation stopped", "agent", a.Name())
r.executeStopHooks(ctx, sess, a, res.Content, events)

// --- FOLLOW-UP: end-of-turn injection ---
// Pop exactly one follow-up message. Unlike steered
// messages, follow-ups are plain user messages that start
// a new turn — the model sees them as fresh input, not a
// mid-stream interruption. Each follow-up gets a full
// undivided agent turn.
if followUp, ok := r.followUpQueue.Dequeue(ctx); ok {
userMsg := session.UserMessage(followUp.Content, followUp.MultiContent...)
sess.AddMessage(userMsg)
events <- UserMessage(followUp.Content, sess.ID, followUp.MultiContent, len(sess.Messages)-1)
r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
continue // re-enter the loop for a new turn
}

break
}

r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events)
}
}()

Expand Down
85 changes: 85 additions & 0 deletions pkg/runtime/message_queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package runtime

import (
"context"

"github.com/docker/docker-agent/pkg/chat"
)

// QueuedMessage is a user message waiting to be injected into the agent loop,
// either mid-turn (via the steer queue) or at end-of-turn (via the follow-up
// queue).
type QueuedMessage struct {
	// Content is the plain-text body of the user message.
	Content string
	// MultiContent holds additional structured message parts; it is passed
	// through unchanged to session.UserMessage when the message is injected.
	MultiContent []chat.MessagePart
}

// MessageQueue is the interface for storing messages that are injected into
// the agent loop. Implementations must be safe for concurrent use: Enqueue
// is called from API handlers while Dequeue/Drain are called from the agent
// loop goroutine.
//
// The ctx parameters exist so that alternative (e.g. network-backed)
// implementations can honor cancellation; the in-memory implementation
// ignores them because none of its operations block.
//
// The default implementation is NewInMemoryMessageQueue. Callers that need
// durable or distributed storage can provide their own implementation
// via the WithSteerQueue or WithFollowUpQueue options.
type MessageQueue interface {
	// Enqueue adds a message to the queue. Returns false if the queue is
	// full or the context is cancelled.
	Enqueue(ctx context.Context, msg QueuedMessage) bool
	// Dequeue removes and returns the next message from the queue.
	// Returns the message and true, or a zero value and false if the
	// queue is empty. Must not block.
	Dequeue(ctx context.Context) (QueuedMessage, bool)
	// Drain returns all pending messages and removes them from the queue.
	// Must not block — if the queue is empty it returns nil.
	Drain(ctx context.Context) []QueuedMessage
}

// inMemoryMessageQueue is the default MessageQueue backed by a buffered channel.
// The channel is what provides the required concurrency safety: all three
// operations are non-blocking channel sends/receives.
type inMemoryMessageQueue struct {
	// ch buffers queued messages; its capacity bounds the queue size.
	ch chan QueuedMessage
}

const (
	// defaultSteerQueueCapacity is the buffer size for the default in-memory steer queue.
	// Steer messages are fully drained each loop iteration, so the buffer can stay
	// small; once it is full, further Steer calls fail until the loop drains it.
	defaultSteerQueueCapacity = 5
	// defaultFollowUpQueueCapacity is the buffer size for the default in-memory follow-up queue.
	// Higher than steer because follow-ups accumulate while waiting for the turn to end.
	defaultFollowUpQueueCapacity = 20
)

// NewInMemoryMessageQueue creates a MessageQueue backed by a buffered channel
// with the given capacity.
//
// A negative capacity is clamped to zero rather than panicking (make on a
// channel panics for negative sizes). Note that a zero-capacity queue rejects
// every Enqueue, because both Enqueue and Dequeue are non-blocking and an
// unbuffered channel can never rendezvous between them.
func NewInMemoryMessageQueue(capacity int) MessageQueue {
	if capacity < 0 {
		// Treat invalid input as "no buffer" instead of crashing the caller.
		capacity = 0
	}
	return &inMemoryMessageQueue{ch: make(chan QueuedMessage, capacity)}
}

// Enqueue performs a non-blocking send onto the buffered channel. It returns
// false when the buffer is full. The context is ignored because the send
// never blocks.
func (q *inMemoryMessageQueue) Enqueue(_ context.Context, msg QueuedMessage) bool {
	select {
	case q.ch <- msg:
		return true
	default:
		// Buffer full — report failure rather than blocking the API handler.
		return false
	}
}

// Dequeue performs a non-blocking receive from the buffered channel. It
// returns the zero value and false when the queue is empty. The context is
// ignored because the receive never blocks.
func (q *inMemoryMessageQueue) Dequeue(_ context.Context) (QueuedMessage, bool) {
	select {
	case m := <-q.ch:
		return m, true
	default:
		return QueuedMessage{}, false
	}
}

// Drain empties the queue without blocking, returning every message it can
// receive immediately. Returns nil when the queue is empty. The context is
// unused because no operation here can block.
func (q *inMemoryMessageQueue) Drain(_ context.Context) []QueuedMessage {
	var drained []QueuedMessage
loop:
	for {
		select {
		case m := <-q.ch:
			drained = append(drained, m)
		default:
			// Channel exhausted — stop receiving.
			break loop
		}
	}
	return drained
}
6 changes: 6 additions & 0 deletions pkg/runtime/remote_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ type RemoteClient interface {
// RunAgentWithAgentName executes an agent with a specific agent name
RunAgentWithAgentName(ctx context.Context, sessionID, agent, agentName string, messages []api.Message) (<-chan Event, error)

// SteerSession injects user messages into a running session mid-turn
SteerSession(ctx context.Context, sessionID string, messages []api.Message) error

// FollowUpSession queues messages for end-of-turn processing
FollowUpSession(ctx context.Context, sessionID string, messages []api.Message) error

// UpdateSessionTitle updates the title of a session
UpdateSessionTitle(ctx context.Context, sessionID, title string) error

Expand Down
21 changes: 21 additions & 0 deletions pkg/runtime/remote_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,27 @@ func (r *RemoteRuntime) Run(ctx context.Context, sess *session.Session) ([]sessi
return sess.GetAllMessages(), nil
}

// Steer enqueues a user message for mid-turn injection into the running
// agent loop on the remote server. Fails when no session is active yet.
func (r *RemoteRuntime) Steer(msg QueuedMessage) error {
	if r.sessionID == "" {
		return errors.New("no active session")
	}
	payload := []api.Message{{
		Content:      msg.Content,
		MultiContent: msg.MultiContent,
	}}
	return r.client.SteerSession(context.Background(), r.sessionID, payload)
}

// FollowUp enqueues a message for end-of-turn processing on the remote
// server. Fails when no session is active yet.
func (r *RemoteRuntime) FollowUp(msg QueuedMessage) error {
	if r.sessionID == "" {
		return errors.New("no active session")
	}
	payload := []api.Message{{
		Content:      msg.Content,
		MultiContent: msg.MultiContent,
	}}
	return r.client.FollowUpSession(context.Background(), r.sessionID, payload)
}

// Resume allows resuming execution after user confirmation
func (r *RemoteRuntime) Resume(ctx context.Context, req ResumeRequest) {
slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "type", req.Type, "reason", req.Reason, "tool_name", req.ToolName, "session_id", r.sessionID)
Expand Down
54 changes: 54 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ type Runtime interface {
// if the runtime does not support local title generation (e.g. remote runtimes).
TitleGenerator() *sessiontitle.Generator

// Steer enqueues a user message for urgent mid-turn injection into the
// running agent loop. Returns an error if the queue is full or steering
// is not available.
Steer(msg QueuedMessage) error
// FollowUp enqueues a message for end-of-turn processing. Each follow-up
// gets a full undivided agent turn. Returns an error if the queue is full.
FollowUp(msg QueuedMessage) error

// Close releases resources held by the runtime (e.g., session store connections).
Close() error
}
Expand Down Expand Up @@ -201,6 +209,14 @@ type LocalRuntime struct {

currentAgentMu sync.RWMutex

// steerQueue stores urgent mid-turn messages. The agent loop drains
// ALL pending messages after tool execution, before the stop check.
steerQueue MessageQueue

// followUpQueue stores end-of-turn messages. The agent loop pops
// exactly ONE message after the model stops and stop-hooks have run.
followUpQueue MessageQueue

// onToolsChanged is called when an MCP toolset reports a tool list change.
onToolsChanged func(Event)

Expand Down Expand Up @@ -228,6 +244,22 @@ func WithTracer(t trace.Tracer) Opt {
}
}

// WithSteerQueue sets a custom MessageQueue for mid-turn message injection.
// If not provided, an in-memory buffered queue is used
// (NewInMemoryMessageQueue with defaultSteerQueueCapacity).
func WithSteerQueue(q MessageQueue) Opt {
	return func(r *LocalRuntime) {
		r.steerQueue = q
	}
}

// WithFollowUpQueue sets a custom MessageQueue for end-of-turn follow-up
// messages. If not provided, an in-memory buffered queue is used
// (NewInMemoryMessageQueue with defaultFollowUpQueueCapacity).
func WithFollowUpQueue(q MessageQueue) Opt {
	return func(r *LocalRuntime) {
		r.followUpQueue = q
	}
}

func WithSessionCompaction(sessionCompaction bool) Opt {
return func(r *LocalRuntime) {
r.sessionCompaction = sessionCompaction
Expand Down Expand Up @@ -291,6 +323,8 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
currentAgent: defaultAgent.Name(),
resumeChan: make(chan ResumeRequest),
elicitationRequestCh: make(chan ElicitationResult),
steerQueue: NewInMemoryMessageQueue(defaultSteerQueueCapacity),
followUpQueue: NewInMemoryMessageQueue(defaultFollowUpQueueCapacity),
sessionCompaction: true,
managedOAuth: true,
sessionStore: session.NewInMemorySessionStore(),
Expand Down Expand Up @@ -1015,6 +1049,26 @@ func (r *LocalRuntime) ResumeElicitation(ctx context.Context, action tools.Elici
}
}

// Steer enqueues a user message for urgent mid-turn injection into the
// running agent loop. The message will be picked up after the current batch
// of tool calls finishes but before the loop checks whether to stop.
// Returns an error when the steer queue has no room left.
func (r *LocalRuntime) Steer(msg QueuedMessage) error {
	ok := r.steerQueue.Enqueue(context.Background(), msg)
	if ok {
		return nil
	}
	return errors.New("steer queue full")
}

// FollowUp enqueues a message to be processed after the current agent turn
// finishes. Unlike Steer, follow-ups are popped one at a time and each gets
// a full undivided agent turn.
// Returns an error when the follow-up queue has no room left.
func (r *LocalRuntime) FollowUp(msg QueuedMessage) error {
	ok := r.followUpQueue.Enqueue(context.Background(), msg)
	if ok {
		return nil
	}
	return errors.New("follow-up queue full")
}

// Run starts the agent's interaction loop

func (r *LocalRuntime) startSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
Expand Down
Loading
Loading