Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions agent-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,15 @@
"type": "boolean",
"description": "Whether to add a 'description' parameter to tool calls, allowing the LLM to provide context about why it is calling a tool"
},
"tool_choice": {
"type": "string",
"description": "Controls how the model selects tools. 'auto' (default) lets the model decide, 'required' forces the model to always call a tool, 'none' prevents tool use.",
"enum": [
"auto",
"required",
"none"
]
},
"hooks": {
"$ref": "#/definitions/HooksConfig",
"description": "Lifecycle hooks for executing shell commands at various points in the agent's execution"
Expand Down
11 changes: 11 additions & 0 deletions examples/tool_choice_required.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
agents:
root:
model: anthropic/claude-opus-4-6
instruction: >
You are a coding agent. Use your tools to read, write, and modify files,
and to run shell commands. Always use tools to accomplish tasks rather than
just describing what to do.
tool_choice: required
toolsets:
- type: filesystem
- type: shell
6 changes: 6 additions & 0 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Agent struct {
addDate bool
addEnvironmentInfo bool
addDescriptionParameter bool
toolChoice string
maxIterations int
maxConsecutiveToolCalls int
maxOldToolCallTokens int
Expand Down Expand Up @@ -206,6 +207,11 @@ func (a *Agent) Hooks() *latest.HooksConfig {
return a.hooks
}

// ToolChoice returns the tool choice mode for this agent (e.g., "auto", "required", "none").
func (a *Agent) ToolChoice() string {
return a.toolChoice
}

// Tools returns the tools available to this agent
func (a *Agent) Tools(ctx context.Context) ([]tools.Tool, error) {
a.ensureToolSetsAreStarted(ctx)
Expand Down
6 changes: 6 additions & 0 deletions pkg/agent/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ func WithAddDescriptionParameter(addDescriptionParameter bool) Opt {
}
}

func WithToolChoice(toolChoice string) Opt {
return func(a *Agent) {
a.toolChoice = toolChoice
}
}

func WithAddPromptFiles(addPromptFiles []string) Opt {
return func(a *Agent) {
a.addPromptFiles = addPromptFiles
Expand Down
1 change: 1 addition & 0 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ type AgentConfig struct {
AddEnvironmentInfo bool `json:"add_environment_info,omitempty"`
CodeModeTools bool `json:"code_mode_tools,omitempty"`
AddDescriptionParameter bool `json:"add_description_parameter,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
MaxIterations int `json:"max_iterations,omitempty"`
MaxConsecutiveToolCalls int `json:"max_consecutive_tool_calls,omitempty"`
MaxOldToolCallTokens int `json:"max_old_tool_call_tokens,omitempty"`
Expand Down
18 changes: 18 additions & 0 deletions pkg/config/latest/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ func (t *Config) validate() error {
return err
}

// Validate tool_choice
if err := agent.validateToolChoice(); err != nil {
return err
}

for j := range agent.Toolsets {
if err := agent.Toolsets[j].validate(); err != nil {
return err
Expand All @@ -38,6 +43,19 @@ func (t *Config) validate() error {
return nil
}

// validateToolChoice validates the tool_choice configuration for an agent
func (a *AgentConfig) validateToolChoice() error {
if a.ToolChoice == "" {
return nil
}
switch a.ToolChoice {
case "auto", "required", "none":
return nil
default:
return errors.New("tool_choice must be one of: auto, required, none")
}
}

// validateFallback validates the fallback configuration for an agent
func (a *AgentConfig) validateFallback() error {
if a.Fallback == nil {
Expand Down
76 changes: 76 additions & 0 deletions pkg/config/latest/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,79 @@ agents:
})
}
}

func TestAgentConfig_Validate_ToolChoice(t *testing.T) {
t.Parallel()

tests := []struct {
name string
config string
wantErr string
}{
{
name: "valid tool_choice auto",
config: `
agents:
root:
model: "openai/gpt-4"
tool_choice: auto
`,
wantErr: "",
},
{
name: "valid tool_choice required",
config: `
agents:
root:
model: "openai/gpt-4"
tool_choice: required
`,
wantErr: "",
},
{
name: "valid tool_choice none",
config: `
agents:
root:
model: "openai/gpt-4"
tool_choice: none
`,
wantErr: "",
},
{
name: "no tool_choice set",
config: `
agents:
root:
model: "openai/gpt-4"
`,
wantErr: "",
},
{
name: "invalid tool_choice value",
config: `
agents:
root:
model: "openai/gpt-4"
tool_choice: force
`,
wantErr: "tool_choice must be one of: auto, required, none",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var cfg Config
err := yaml.Unmarshal([]byte(tt.config), &cfg)

if tt.wantErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
13 changes: 13 additions & 0 deletions pkg/model/provider/anthropic/beta_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ func (c *Client) createBetaStream(

if len(requestTools) > 0 {
slog.Debug("Anthropic Beta API: Adding tools to request", "tool_count", len(requestTools))

// Apply tool_choice from agent config
if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" {
switch toolChoice {
case "required":
params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfAny: &anthropic.BetaToolChoiceAnyParam{}}
case "none":
params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfNone: &anthropic.BetaToolChoiceNoneParam{}}
default: // "auto" or any other value
params.ToolChoice = anthropic.BetaToolChoiceUnionParam{OfAuto: &anthropic.BetaToolChoiceAutoParam{}}
}
slog.Debug("Anthropic Beta API request using tool_choice", "tool_choice", toolChoice)
}
}

slog.Debug("Anthropic Beta API chat completion stream request",
Expand Down
13 changes: 13 additions & 0 deletions pkg/model/provider/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ func (c *Client) CreateChatCompletionStream(

if len(requestTools) > 0 {
slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools))

// Apply tool_choice from agent config
if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" {
switch toolChoice {
case "required":
params.ToolChoice = anthropic.ToolChoiceUnionParam{OfAny: &anthropic.ToolChoiceAnyParam{}}
case "none":
params.ToolChoice = anthropic.ToolChoiceUnionParam{OfNone: &anthropic.ToolChoiceNoneParam{}}
default: // "auto" or any other value
params.ToolChoice = anthropic.ToolChoiceUnionParam{OfAuto: &anthropic.ToolChoiceAutoParam{}}
}
slog.Debug("Anthropic request using tool_choice", "tool_choice", toolChoice)
}
}

// Log the request details for debugging
Expand Down
2 changes: 1 addition & 1 deletion pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools

// Convert and set tools
if len(requestTools) > 0 {
input.ToolConfig = convertToolConfig(requestTools, enableCaching)
input.ToolConfig = convertToolConfig(requestTools, enableCaching, c.ModelOptions.ToolChoice())
}

return input
Expand Down
40 changes: 35 additions & 5 deletions pkg/model/provider/bedrock/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestConvertToolConfig(t *testing.T) {
},
}}

config := convertToolConfig(requestTools, false)
config := convertToolConfig(requestTools, false, "")

require.NotNil(t, config)
require.Len(t, config.Tools, 1)
Expand All @@ -142,10 +142,10 @@ func TestConvertToolConfig(t *testing.T) {
func TestConvertToolConfig_Empty(t *testing.T) {
t.Parallel()

config := convertToolConfig(nil, false)
config := convertToolConfig(nil, false, "")
assert.Nil(t, config)

config = convertToolConfig([]tools.Tool{}, false)
config = convertToolConfig([]tools.Tool{}, false, "")
assert.Nil(t, config)
}

Expand Down Expand Up @@ -1179,7 +1179,7 @@ func TestConvertToolConfig_WithCaching(t *testing.T) {
Description: "A test tool",
}}

config := convertToolConfig(requestTools, true)
config := convertToolConfig(requestTools, true, "")

require.NotNil(t, config)
require.Len(t, config.Tools, 2) // tool spec + cache point
Expand All @@ -1197,12 +1197,42 @@ func TestConvertToolConfig_WithoutCaching(t *testing.T) {
Description: "A test tool",
}}

config := convertToolConfig(requestTools, false)
config := convertToolConfig(requestTools, false, "")

require.NotNil(t, config)
require.Len(t, config.Tools, 1) // just tool spec, no cache point
}

func TestConvertToolConfig_ToolChoiceRequired(t *testing.T) {
t.Parallel()

requestTools := []tools.Tool{{
Name: "test_tool",
Description: "A test tool",
}}

config := convertToolConfig(requestTools, false, "required")

require.NotNil(t, config)
_, isAny := config.ToolChoice.(*types.ToolChoiceMemberAny)
assert.True(t, isAny, "expected ToolChoiceMemberAny for required")
}

func TestConvertToolConfig_ToolChoiceAuto(t *testing.T) {
t.Parallel()

requestTools := []tools.Tool{{
Name: "test_tool",
Description: "A test tool",
}}

config := convertToolConfig(requestTools, false, "auto")

require.NotNil(t, config)
_, isAuto := config.ToolChoice.(*types.ToolChoiceMemberAuto)
assert.True(t, isAuto, "expected ToolChoiceMemberAuto for auto")
}

func TestPromptCachingEnabled_TypeMismatch(t *testing.T) {
t.Parallel()

Expand Down
16 changes: 14 additions & 2 deletions pkg/model/provider/bedrock/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func mapToDocument(m map[string]any) document.Interface {
return document.NewLazyDocument(m)
}

func convertToolConfig(requestTools []tools.Tool, enableCaching bool) *types.ToolConfiguration {
func convertToolConfig(requestTools []tools.Tool, enableCaching bool, toolChoice string) *types.ToolConfiguration {
if len(requestTools) == 0 {
return nil
}
Expand All @@ -273,9 +273,21 @@ func convertToolConfig(requestTools []tools.Tool, enableCaching bool) *types.Too
})
}

var choice types.ToolChoice
switch toolChoice {
case "required":
choice = &types.ToolChoiceMemberAny{Value: types.AnyToolChoice{}}
case "none":
// Bedrock does not have a "none" tool choice; omit tool config entirely
// would be ideal, but if tools are present we default to auto.
choice = &types.ToolChoiceMemberAuto{Value: types.AutoToolChoice{}}
default: // "auto" or empty
choice = &types.ToolChoiceMemberAuto{Value: types.AutoToolChoice{}}
}

return &types.ToolConfiguration{
Tools: toolSpecs,
ToolChoice: &types.ToolChoiceMemberAuto{Value: types.AutoToolChoice{}},
ToolChoice: choice,
}
}

Expand Down
8 changes: 8 additions & 0 deletions pkg/model/provider/dmr/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
if c.ModelConfig.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*c.ModelConfig.ParallelToolCalls)
}

// Apply tool_choice from agent config
if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" {
params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: openai.Opt(toolChoice),
}
slog.Debug("DMR request using tool_choice", "tool_choice", toolChoice)
}
}

// Log the request in JSON format for debugging
Expand Down
15 changes: 14 additions & 1 deletion pkg/model/provider/gemini/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,22 @@ func (c *Client) CreateChatCompletionStream(
config.Tools = append(config.Tools, allTools...)

// Enable function calling
mode := genai.FunctionCallingConfigModeAuto
if toolChoice := c.ModelOptions.ToolChoice(); toolChoice != "" {
switch toolChoice {
case "required":
mode = genai.FunctionCallingConfigModeAny
case "none":
mode = genai.FunctionCallingConfigModeNone
default: // "auto"
mode = genai.FunctionCallingConfigModeAuto
}
slog.Debug("Gemini request using tool_choice", "tool_choice", toolChoice, "mode", mode)
}

config.ToolConfig = &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeAuto,
Mode: mode,
},
}

Expand Down
Loading
Loading