diff --git a/agent-schema.json b/agent-schema.json index ed7fba093..091764c0c 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -365,6 +365,15 @@ "type": "boolean", "description": "Whether to add a 'description' parameter to tool calls, allowing the LLM to provide context about why it is calling a tool" }, + "tool_choice": { + "type": "string", + "description": "Controls how the model selects tools. 'auto' (default) lets the model decide, 'required' forces the model to always call a tool, 'none' prevents tool use.", + "enum": [ + "auto", + "required", + "none" + ] + }, "hooks": { "$ref": "#/definitions/HooksConfig", "description": "Lifecycle hooks for executing shell commands at various points in the agent's execution" diff --git a/examples/tool_choice_required.yaml b/examples/tool_choice_required.yaml new file mode 100644 index 000000000..8aec65190 --- /dev/null +++ b/examples/tool_choice_required.yaml @@ -0,0 +1,11 @@ +agents: + root: + model: anthropic/claude-opus-4-6 + instruction: > + You are a coding agent. Use your tools to read, write, and modify files, + and to run shell commands. Always use tools to accomplish tasks rather than + just describing what to do. + tool_choice: required + toolsets: + - type: filesystem + - type: shell diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 8dfd60619..8e04a45cb 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -32,6 +32,7 @@ type Agent struct { addDate bool addEnvironmentInfo bool addDescriptionParameter bool + toolChoice string maxIterations int maxConsecutiveToolCalls int maxOldToolCallTokens int @@ -206,6 +207,11 @@ func (a *Agent) Hooks() *latest.HooksConfig { return a.hooks } +// ToolChoice returns the tool choice mode for this agent (e.g., "auto", "required", "none"). +func (a *Agent) ToolChoice() string { + return a.toolChoice +} + // Tools returns the tools available to this agent func (a *Agent) Tools(ctx context.Context) ([]tools.Tool, error) { a.ensureToolSetsAreStarted(ctx) diff --git a/pkg/agent/opts.go b/pkg/agent/opts.go index e66a598c6..8c697ed2a 100644 --- a/pkg/agent/opts.go +++ b/pkg/agent/opts.go @@ -115,6 +115,12 @@ func WithAddDescriptionParameter(addDescriptionParameter bool) Opt { } } +func WithToolChoice(toolChoice string) Opt { + return func(a *Agent) { + a.toolChoice = toolChoice + } +} + func WithAddPromptFiles(addPromptFiles []string) Opt { return func(a *Agent) { a.addPromptFiles = addPromptFiles diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 11fa3d6bc..44827314d 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -366,6 +366,7 @@ type AgentConfig struct { AddEnvironmentInfo bool `json:"add_environment_info,omitempty"` CodeModeTools bool `json:"code_mode_tools,omitempty"` AddDescriptionParameter bool `json:"add_description_parameter,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` MaxIterations int `json:"max_iterations,omitempty"` MaxConsecutiveToolCalls int `json:"max_consecutive_tool_calls,omitempty"` MaxOldToolCallTokens int `json:"max_old_tool_call_tokens,omitempty"` diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 57329b67e..a0a1502e8 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -23,6 +23,11 @@ func (t *Config) validate() error { return err } + // Validate tool_choice + if err := agent.validateToolChoice(); err != nil { + return err + } + for j := range agent.Toolsets { if err := agent.Toolsets[j].validate(); err != nil { return err @@ -38,6 +43,19 @@ func (t *Config) validate() error { return nil } +// validateToolChoice validates the tool_choice configuration for an agent +func (a *AgentConfig) validateToolChoice() error { + if a.ToolChoice == "" { + return nil + } + switch a.ToolChoice { + case "auto", "required", "none": + return nil + default: + return errors.New("tool_choice must be one of: auto, required, none") + } +} + // validateFallback validates the fallback configuration for an agent func (a *AgentConfig) validateFallback() error { if a.Fallback == nil { diff --git a/pkg/config/latest/validate_test.go b/pkg/config/latest/validate_test.go index 6c90b9be7..e13939fb7 100644 --- a/pkg/config/latest/validate_test.go +++ b/pkg/config/latest/validate_test.go @@ -115,3 +115,79 @@ agents: }) } } + +func TestAgentConfig_Validate_ToolChoice(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "valid tool_choice auto", + config: ` +agents: + root: + model: "openai/gpt-4" + tool_choice: auto +`, + wantErr: "", + }, + { + name: "valid tool_choice required", + config: ` +agents: + root: + model: "openai/gpt-4" + tool_choice: required +`, + wantErr: "", + }, + { + name: "valid tool_choice none", + config: ` +agents: + root: + model: "openai/gpt-4" + tool_choice: none +`, + wantErr: "", + }, + { + name: "no tool_choice set", + config: ` +agents: + root: + model: "openai/gpt-4" +`, + wantErr: "", + }, + { + name: "invalid tool_choice value", + config: ` +agents: + root: + model: "openai/gpt-4" + tool_choice: force +`, + wantErr: "tool_choice must be one of: auto, required, none", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/model/provider/anthropic/beta_client.go b/pkg/model/provider/anthropic/beta_client.go index 249b7ab8a..2d36ddf9c 100644 --- a/pkg/model/provider/anthropic/beta_client.go +++ b/pkg/model/provider/anthropic/beta_client.go @@ -109,6 +109,19 @@ func (c *Client) createBetaStream( if len(requestTools) > 0 { slog.Debug("Anthropic Beta API: Adding tools to request", "tool_count", len(requestTools)) + + // Apply tool_choice from agent config + if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" { + switch toolChoice { + case "required": + params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfAny: &anthropic.BetaToolChoiceAnyParam{}} + case "none": + params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfNone: &anthropic.BetaToolChoiceNoneParam{}} + default: // "auto" or any other value + params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfAuto: &anthropic.BetaToolChoiceAutoParam{}} + } + slog.Debug("Anthropic Beta API request using tool_choice", "tool_choice", toolChoice) + } } slog.Debug("Anthropic Beta API chat completion stream request", diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 4194a4c19..d1b52732c 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -330,6 +330,19 @@ func (c *Client) CreateChatCompletionStream( if len(requestTools) > 0 { slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools)) + + // Apply tool_choice from agent config + if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" { + switch toolChoice { + case "required": + params.ToolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}} + case "none": + params.ToolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}} + default: // "auto" or any other value + params.ToolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}} + } + slog.Debug("Anthropic request using tool_choice", "tool_choice", toolChoice) + } } // Log the request details for debugging diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index ffed161b6..1585b054d 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -249,7 +249,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools // Convert and set tools if len(requestTools) > 0 { - input.ToolConfig = convertToolConfig(requestTools, enableCaching) + input.ToolConfig = convertToolConfig(requestTools, enableCaching, c.ModelOptions.ToolChoice()) } return input diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 7f91a19b0..66dad863c 100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -128,7 +128,7 @@ func TestConvertToolConfig(t *testing.T) { }, }} - config := convertToolConfig(requestTools, false) + config := convertToolConfig(requestTools, false, "") require.NotNil(t, config) require.Len(t, config.Tools, 1) @@ -142,10 +142,10 @@ func TestConvertToolConfig(t *testing.T) { func TestConvertToolConfig_Empty(t *testing.T) { t.Parallel() - config := convertToolConfig(nil, false) + config := convertToolConfig(nil, false, "") assert.Nil(t, config) - config = convertToolConfig([]tools.Tool{}, false) + config = convertToolConfig([]tools.Tool{}, false, "") assert.Nil(t, config) } @@ -1179,7 +1179,7 @@ func TestConvertToolConfig_WithCaching(t *testing.T) { Description: "A test tool", }} - config := convertToolConfig(requestTools, true) + config := convertToolConfig(requestTools, true, "") require.NotNil(t, config) require.Len(t, config.Tools, 2) // tool spec + cache point @@ -1197,12 +1197,42 @@ func TestConvertToolConfig_WithoutCaching(t *testing.T) { Description: "A test tool", }} - config := convertToolConfig(requestTools, false) + config := convertToolConfig(requestTools, false, "") require.NotNil(t, config) require.Len(t, config.Tools, 1) // just tool spec, no cache point } +func TestConvertToolConfig_ToolChoiceRequired(t *testing.T) { + t.Parallel() + + requestTools := []tools.Tool{{ + Name: "test_tool", + Description: "A test tool", + }} + + config := convertToolConfig(requestTools, false, "required") + + require.NotNil(t, config) + _, isAny := config.ToolChoice.(*types.ToolChoiceMemberAny) + assert.True(t, isAny, "expected ToolChoiceMemberAny for required") +} + +func TestConvertToolConfig_ToolChoiceAuto(t *testing.T) { + t.Parallel() + + requestTools := []tools.Tool{{ + Name: "test_tool", + Description: "A test tool", + }} + + config := convertToolConfig(requestTools, false, "auto") + + require.NotNil(t, config) + _, isAuto := config.ToolChoice.(*types.ToolChoiceMemberAuto) + assert.True(t, isAuto, "expected ToolChoiceMemberAuto for auto") +} + func TestPromptCachingEnabled_TypeMismatch(t *testing.T) { t.Parallel() diff --git a/pkg/model/provider/bedrock/convert.go b/pkg/model/provider/bedrock/convert.go index 060b7e137..03512680c 100644 --- a/pkg/model/provider/bedrock/convert.go +++ b/pkg/model/provider/bedrock/convert.go @@ -247,7 +247,7 @@ func mapToDocument(m map[string]any) document.Interface { return document.NewLazyDocument(m) } -func convertToolConfig(requestTools []tools.Tool, enableCaching bool) *types.ToolConfiguration { +func convertToolConfig(requestTools []tools.Tool, enableCaching bool, toolChoice string) *types.ToolConfiguration { if len(requestTools) == 0 { return nil } @@ -273,9 +273,21 @@ func convertToolConfig(requestTools []tools.Tool, enableCaching bool) *types.Too }) } + var choice types.ToolChoice + switch toolChoice { + case "required": + choice = &types.ToolChoiceMemberAny{Value: types.AnyToolChoice{}} + case "none": + // Bedrock does not have a "none" tool choice; omit tool config entirely + // would be ideal, but if tools are present we default to auto. + choice = &types.ToolChoiceMemberAuto{Value: types.AutoToolChoice{}} + default: // "auto" or empty + choice = &types.ToolChoiceMemberAuto{Value: types.AutoToolChoice{}} + } + return &types.ToolConfiguration{ Tools: toolSpecs, - ToolChoice: &types.ToolChoiceMemberAuto{Value: types.AutoToolChoice{}}, + ToolChoice: choice, } } diff --git a/pkg/model/provider/dmr/client.go b/pkg/model/provider/dmr/client.go index 439e31918..653d718c9 100644 --- a/pkg/model/provider/dmr/client.go +++ b/pkg/model/provider/dmr/client.go @@ -212,6 +212,14 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat if c.ModelConfig.ParallelToolCalls != nil { params.ParallelToolCalls = openai.Bool(*c.ModelConfig.ParallelToolCalls) } + + // Apply tool_choice from agent config + if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" { + params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.Opt(toolChoice), + } + slog.Debug("DMR request using tool_choice", "tool_choice", toolChoice) + } } // Log the request in JSON format for debugging diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index e48e6f6e4..2b3fdc1b6 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -569,9 +569,22 @@ func (c *Client) CreateChatCompletionStream( config.Tools = append(config.Tools, allTools...) // Enable function calling + mode := genai.FunctionCallingConfigModeAuto + if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" { + switch toolChoice { + case "required": + mode = genai.FunctionCallingConfigModeAny + case "none": + mode = genai.FunctionCallingConfigModeNone + default: // "auto" + mode = genai.FunctionCallingConfigModeAuto + } + slog.Debug("Gemini request using tool_choice", "tool_choice", toolChoice, "mode", mode) + } + config.ToolConfig = &genai.ToolConfig{ FunctionCallingConfig: &genai.FunctionCallingConfig{ - Mode: genai.FunctionCallingConfigModeAuto, + Mode: mode, }, } diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index c5176daa6..6c0f03ab3 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -272,6 +272,14 @@ func (c *Client) CreateChatCompletionStream( if c.ModelConfig.ParallelToolCalls != nil { params.ParallelToolCalls = openai.Bool(*c.ModelConfig.ParallelToolCalls) } + + // Apply tool_choice from agent config + if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" { + params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.Opt(toolChoice), + } + slog.Debug("OpenAI request using tool_choice", "tool_choice", toolChoice) + } } // Apply thinking budget: set reasoning_effort for reasoning models (o-series, gpt-5) @@ -381,6 +389,14 @@ func (c *Client) CreateResponseStream( if c.ModelConfig.ParallelToolCalls != nil { params.ParallelToolCalls = param.NewOpt(*c.ModelConfig.ParallelToolCalls) } + + // Apply tool_choice from agent config + if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" { + params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ + OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptions(toolChoice)), + } + slog.Debug("OpenAI responses request using tool_choice", "tool_choice", toolChoice) + } } // Configure reasoning for models that support it (o-series, gpt-5). diff --git a/pkg/model/provider/options/options.go b/pkg/model/provider/options/options.go index 8ac2866bc..e476ebe6c 100644 --- a/pkg/model/provider/options/options.go +++ b/pkg/model/provider/options/options.go @@ -11,6 +11,7 @@ type ModelOptions struct { noThinking bool maxTokens int64 providers map[string]latest.ProviderConfig + toolChoice string } func (c *ModelOptions) Gateway() string { @@ -37,6 +38,10 @@ func (c *ModelOptions) Providers() map[string]latest.ProviderConfig { return c.providers } +func (c *ModelOptions) ToolChoice() string { + return c.toolChoice +} + type Opt func(*ModelOptions) func WithGateway(gateway string) Opt { @@ -75,6 +80,12 @@ func WithProviders(providers map[string]latest.ProviderConfig) Opt { } } +func WithToolChoice(toolChoice string) Opt { + return func(cfg *ModelOptions) { + cfg.toolChoice = toolChoice + } +} + // FromModelOptions converts a concrete ModelOptions value into a slice of // Opt configuration functions. Later Opts override earlier ones when applied. func FromModelOptions(m ModelOptions) []Opt { @@ -97,5 +108,8 @@ func FromModelOptions(m ModelOptions) []Opt { if len(m.providers) > 0 { out = append(out, WithProviders(m.providers)) } + if m.toolChoice != "" { + out = append(out, WithToolChoice(m.toolChoice)) + } return out } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 028c43513..eb8a4fb71 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -160,6 +160,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c agent.WithAddDate(agentConfig.AddDate), agent.WithAddEnvironmentInfo(agentConfig.AddEnvironmentInfo), agent.WithAddDescriptionParameter(agentConfig.AddDescriptionParameter), + agent.WithToolChoice(agentConfig.ToolChoice), agent.WithAddPromptFiles(promptFiles), agent.WithMaxIterations(agentConfig.MaxIterations), agent.WithMaxConsecutiveToolCalls(agentConfig.MaxConsecutiveToolCalls), @@ -302,6 +303,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC opts := []options.Opt{ options.WithGateway(runConfig.ModelsGateway), options.WithStructuredOutput(a.StructuredOutput), + options.WithToolChoice(a.ToolChoice), options.WithProviders(cfg.Providers), } if maxTokens != nil { @@ -362,6 +364,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates opts := []options.Opt{ options.WithGateway(runConfig.ModelsGateway), options.WithStructuredOutput(a.StructuredOutput), + options.WithToolChoice(a.ToolChoice), options.WithProviders(cfg.Providers), } if maxTokens != nil {