diff --git a/aisdk/ai/provider/openai/internal/codec/jsonschema.go b/aisdk/ai/provider/openai/internal/codec/jsonschema.go index 72cc1e2c..692782e0 100644 --- a/aisdk/ai/provider/openai/internal/codec/jsonschema.go +++ b/aisdk/ai/provider/openai/internal/codec/jsonschema.go @@ -16,6 +16,17 @@ func encodeSchema(schema *jsonschema.Schema) (map[string]any, error) { return nil, nil } + // Enforce OpenAI restrictions + // https://platform.openai.com/docs/guides/structured-outputs#root-objects-must-not-be-anyof-and-must-be-an-object + // NOTE: we could simply encode the input schema, pass it through to OpenAI and let it return an error, but there are + // other encoding rules we want to enforce later, and limiting the scope here allows us to limit the scope later. + if schema.Type != "object" { + return nil, fmt.Errorf("schema root must be of type object, got: %s", schema.Type) + } + if schema.AnyOf != nil { + return nil, fmt.Errorf("schema root cannot use AnyOf") + } + // Marshal to JSON and unmarshal back to interface{} to convert the types data, err := json.Marshal(schema) if err != nil { @@ -32,6 +43,12 @@ func encodeSchema(schema *jsonschema.Schema) (map[string]any, error) { return nil, fmt.Errorf("failed to unmarshal properties: %w\n\n%s", err, data) } + // Ensure properties field is set, even if it's empty. It's unclear whether OpenAI requires + // this to be set for nested schema objects too. For now we only set it at the top-level. + if _, ok := result["properties"]; !ok { + result["properties"] = map[string]any{} + } + // Convert {"not": {}} patterns to false throughout the schema normalizeSchemaMap(result) diff --git a/aisdk/ai/provider/openai/internal/codec/jsonschema_test.go b/aisdk/ai/provider/openai/internal/codec/jsonschema_test.go index 166e38c8..88c192a2 100644 --- a/aisdk/ai/provider/openai/internal/codec/jsonschema_test.go +++ b/aisdk/ai/provider/openai/internal/codec/jsonschema_test.go @@ -143,20 +143,42 @@ func TestEncodeSchema(t *testing.T) { }`, }, { - name: "schema with allOf containing additionalProperties", + name: "schema with nested AnyOf", input: &jsonschema.Schema{ - AllOf: []*jsonschema.Schema{ - { - Type: "object", - AdditionalProperties: api.FalseSchema(), + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "numeric": { + AnyOf: []*jsonschema.Schema{ + { + Type: "string", + }, + { + Type: "number", + }, + }, }, }, }, want: `{ - "allOf": [{ - "type": "object", - "additionalProperties": false - }] + "type": "object", + "properties": { + "numeric": { + "anyOf": [ + { "type": "string" }, + { "type": "number" } + ] + } + } + }`, + }, + { + name: "schema without properties gets empty properties map", + input: &jsonschema.Schema{ + Type: "object", + }, + want: `{ + "type": "object", + "properties": {} }`, }, { @@ -210,6 +232,44 @@ func TestEncodeSchema(t *testing.T) { "required": ["id"] }`, }, + + // Edge/error cases + { + name: "schema with non-object root", + input: &jsonschema.Schema{ + Properties: map[string]*jsonschema.Schema{ + "name": { + Type: "string", + Description: "The name", + }, + }, + }, + wantErr: true, + }, + { + name: "empty schema", + input: &jsonschema.Schema{}, + wantErr: true, + }, + { + name: "schema with only additional properties", + input: &jsonschema.Schema{ + AdditionalProperties: api.FalseSchema(), + }, + wantErr: true, + }, + { + name: "schema with AnyOf at rool level", + input: &jsonschema.Schema{ + AnyOf: []*jsonschema.Schema{ + { + Type: "object", + AdditionalProperties: api.FalseSchema(), + }, + }, + }, + wantErr: true, + }, } for _, tt := range tests { @@ -451,35 +511,3 @@ func TestNormalizeSchemaMap(t *testing.T) { }) } } - -func TestEncodeSchema_EdgeCases(t *testing.T) { - t.Run("schema with only additionalProperties", func(t *testing.T) { - schema := &jsonschema.Schema{ - AdditionalProperties: api.FalseSchema(), - } - - got, err := encodeSchema(schema) - require.NoError(t, err) - - gotJSON, err := json.Marshal(got) - require.NoError(t, err) - - expectedJSON := `{ - "additionalProperties": false - }` - - assert.JSONEq(t, expectedJSON, string(gotJSON)) - }) - - t.Run("empty schema", func(t *testing.T) { - schema := &jsonschema.Schema{} - - got, err := encodeSchema(schema) - require.NoError(t, err) - - gotJSON, err := json.Marshal(got) - require.NoError(t, err) - - assert.JSONEq(t, "{}", string(gotJSON)) - }) -}