diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 977c85a3..1c095295 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -457,16 +457,35 @@ def {_field_name}( case _: encoder_parts.append((None, "x")) - # Build the ternary chain from encoder_parts + # Build the ternary chain from encoder_parts. + # + # Every entry that has a `type_check` (isinstance / `x is None`) gets + # its own guard, including the last one. Falling off the end means the + # input did not match any declared anyOf variant, which should not + # happen for a well-formed value; we emit a `cast(Any, x)` so mypy + # doesn't try to narrow the value through the chain. + # + # Previously the last entry was emitted unconditionally as the `else` + # branch. That works for simple unions (object | str | list), but + # breaks down when the last variant's encoder requires iteration + # (e.g. `[encode_X(y) for y in x]` when the variant is an array) and + # mypy fails to fully narrow `x` through the prior `isinstance` + # checks. The unguarded final branch then triggers `union-attr` + # errors like "Item 'float' has no attribute '__iter__'". typeddict_encoder = list[str]() - for i, (type_check, encoder_expr) in enumerate(encoder_parts): - is_last = i == len(encoder_parts) - 1 - if is_last or type_check is None: - # Last item or no type check - just the expression + has_unguarded_terminal = False + for type_check, encoder_expr in encoder_parts: + if type_check is None: + # No type check available — emit the bare expression and stop; + # nothing after it could be reached anyway. typeddict_encoder.append(encoder_expr) - else: - # Add expression with type check - typeddict_encoder.append(f"{encoder_expr} if {type_check} else") + has_unguarded_terminal = True + break + typeddict_encoder.append(f"{encoder_expr} if {type_check} else") + if not has_unguarded_terminal and encoder_parts: + # Unreachable in practice (every declared variant was guarded + # above), but mypy needs a concrete final expression. + typeddict_encoder.append("cast(Any, x)") if permit_unknown_members: union = _make_open_union_type_expr(any_of) else: diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/__init__.py new file mode 100644 index 00000000..e935fb7a --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/__init__.py @@ -0,0 +1,13 @@ +# Code generated by river.codegen. DO NOT EDIT. +from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .test_service import Test_ServiceService + + +class AnyOfArrayInUnionClient: + def __init__(self, client: river.Client[Literal[None]]): + self.test_service = Test_ServiceService(client) diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/test_service/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/test_service/__init__.py new file mode 100644 index 00000000..172ff4c2 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/test_service/__init__.py @@ -0,0 +1,42 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .exec_sql_method import ( + Exec_Sql_MethodInput, + Exec_Sql_MethodOutput, + Exec_Sql_MethodOutputTypeAdapter, + encode_Exec_Sql_MethodInput, + encode_Exec_Sql_MethodInputParams, +) + + +class Test_ServiceService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def exec_sql_method( + self, + input: Exec_Sql_MethodInput, + timeout: datetime.timedelta, + ) -> Exec_Sql_MethodOutput: + return await self.client.send_rpc( + "test_service", + "exec_sql_method", + input, + encode_Exec_Sql_MethodInput, + lambda x: Exec_Sql_MethodOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: RiverErrorTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/test_service/exec_sql_method.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/test_service/exec_sql_method.py new file mode 100644 index 00000000..ad2240b2 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_array_in_union/test_service/exec_sql_method.py @@ -0,0 +1,82 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, + cast, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +Exec_Sql_MethodInputParamsAnyOf_4 = str | float | bool | None + + +def encode_Exec_Sql_MethodInputParamsAnyOf_4( + x: "Exec_Sql_MethodInputParamsAnyOf_4", +) -> Any: + return x + + +Exec_Sql_MethodInputParams = ( + str | float | bool | list[Exec_Sql_MethodInputParamsAnyOf_4] | None +) + + +def encode_Exec_Sql_MethodInputParams(x: "Exec_Sql_MethodInputParams") -> Any: + return ( + x + if isinstance(x, str) + else x + if isinstance(x, (int, float)) + else x + if isinstance(x, bool) + else None + if x is None + else [encode_Exec_Sql_MethodInputParamsAnyOf_4(y) for y in x] + if isinstance(x, list) + else cast(Any, x) + ) + + +def encode_Exec_Sql_MethodInput( + x: "Exec_Sql_MethodInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "params": [encode_Exec_Sql_MethodInputParams(y) for y in x["params"]] + if "params" in x and x["params"] is not None + else None, + } + ).items() + if v is not None + } + + +class Exec_Sql_MethodInput(TypedDict): + params: NotRequired[list[Exec_Sql_MethodInputParams] | None] + + +class Exec_Sql_MethodOutput(BaseModel): + ok: bool + + +Exec_Sql_MethodOutputTypeAdapter: TypeAdapter[Exec_Sql_MethodOutput] = TypeAdapter( + Exec_Sql_MethodOutput +) diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py index dce862c0..e0b41728 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py @@ -66,6 +66,8 @@ def encode_Anyof_Mixed_MethodInputRun_Command( else x if isinstance(x, str) else list(x) + if isinstance(x, list) + else cast(Any, x) ) diff --git a/tests/v1/codegen/snapshot/test_anyof_array_in_union.py b/tests/v1/codegen/snapshot/test_anyof_array_in_union.py new file mode 100644 index 00000000..fe4cf305 --- /dev/null +++ b/tests/v1/codegen/snapshot/test_anyof_array_in_union.py @@ -0,0 +1,31 @@ +from pytest_snapshot.plugin import Snapshot + +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen + + +async def test_anyof_array_in_union(snapshot: Snapshot) -> None: + """Test codegen for an array field whose item type is a non-discriminated + anyOf union that itself contains an `array` variant. + + Concretely this mirrors the PostgreSQL `executeSqlCommand.params` schema: + `array>`. The inner union encoder ends in an + iteration over `x` (for the array variant), and historically that branch + was emitted as the unguarded `else` of a ternary chain. When mypy failed + to fully narrow `x` to `list[...]` through the preceding `isinstance` + checks, it complained that scalar items of the union have no + `__iter__` attribute (`union-attr`). + + The fix emits an explicit `isinstance(x, list)` guard for the array + branch and a `cast(Any, x)` fallback, so mypy never has to negative- + narrow into the iterating branch. + """ + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open( + "tests/v1/codegen/types/anyof_array_in_union_schema.json" + ), + target_path="test_anyof_array_in_union", + client_name="AnyOfArrayInUnionClient", + protocol_version="v1.1", + ) diff --git a/tests/v1/codegen/types/anyof_array_in_union_schema.json b/tests/v1/codegen/types/anyof_array_in_union_schema.json new file mode 100644 index 00000000..54e3c357 --- /dev/null +++ b/tests/v1/codegen/types/anyof_array_in_union_schema.json @@ -0,0 +1,49 @@ +{ + "services": { + "test_service": { + "procedures": { + "exec_sql_method": { + "input": { + "type": "object", + "properties": { + "params": { + "description": "Parameterized query values. Each entry is either a scalar or an array of scalars (for ANY($1::text[]) etc.).", + "type": "array", + "items": { + "anyOf": [ + { "type": "string" }, + { "type": "number" }, + { "type": "boolean" }, + { "type": "null" }, + { + "type": "array", + "items": { + "anyOf": [ + { "type": "string" }, + { "type": "number" }, + { "type": "boolean" }, + { "type": "null" } + ] + } + } + ] + } + } + } + }, + "output": { + "type": "object", + "properties": { + "ok": { "type": "boolean" } + }, + "required": ["ok"] + }, + "errors": { + "not": {} + }, + "type": "rpc" + } + } + } + } +}