diff --git a/packages/reflex-base/src/reflex_base/event/__init__.py b/packages/reflex-base/src/reflex_base/event/__init__.py index ca0347745f9..3cd37b65e3e 100644 --- a/packages/reflex-base/src/reflex_base/event/__init__.py +++ b/packages/reflex-base/src/reflex_base/event/__init__.py @@ -23,7 +23,14 @@ overload, ) -from typing_extensions import Self, TypeAliasType, TypedDict, TypeVarTuple, Unpack +from typing_extensions import ( + Self, + TypeAliasType, + TypedDict, + TypeVarTuple, + Unpack, + is_typeddict, +) from reflex_base import constants from reflex_base.components.field import BaseField @@ -58,6 +65,10 @@ if TYPE_CHECKING: from reflex.state import BaseState + BASE_STATE = TypeVar("BASE_STATE", bound=BaseState) +else: + BASE_STATE = TypeVar("BASE_STATE") + @dataclasses.dataclass( init=True, @@ -1683,6 +1694,37 @@ def _values_returned_from_event(event_spec_annotations: list[Any]) -> list[Any]: ] +def _is_on_submit_mapping_event_arg_compatible_with_typed_dict( + provided_event_arg_type: Any, + callback_param_type: Any, + key: str, +) -> bool: + """Check whether an on_submit mapping payload can satisfy a TypedDict callback. + + This keeps the compatibility relaxation scoped to form submission payloads + rather than applying to unrelated mapping-based event triggers. + + Args: + provided_event_arg_type: The type produced by the event trigger. + callback_param_type: The callback parameter annotation. + key: The event trigger key being validated. + + Returns: + Whether the provided event payload should be treated as compatible. + """ + if key != constants.EventTriggers.ON_SUBMIT or not is_typeddict( + callback_param_type + ): + return False + + mapping_type = get_origin(provided_event_arg_type) or provided_event_arg_type + if not safe_issubclass(mapping_type, Mapping): + return False + + key_type = get_args(provided_event_arg_type)[:1] + return not key_type or typehint_issubclass(key_type[0], str) + + def _check_event_args_subclass_of_callback( callback_params_names: list[str], provided_event_types: list[Any], @@ -1724,15 +1766,18 @@ def _check_event_args_subclass_of_callback( continue type_match_found.setdefault(arg, False) + callback_param_type = callback_param_name_to_type[arg] try: compare_result = typehint_issubclass( - args_types_without_vars[i], callback_param_name_to_type[arg] + args_types_without_vars[i], callback_param_type + ) or _is_on_submit_mapping_event_arg_compatible_with_typed_dict( + args_types_without_vars[i], callback_param_type, key ) except TypeError as te: callback_name_context = f" of {callback_name}" if callback_name else "" key_context = f" for {key}" if key else "" - msg = f"Could not compare types {args_types_without_vars[i]} and {callback_param_name_to_type[arg]} for argument {arg}{callback_name_context}{key_context}." + msg = f"Could not compare types {args_types_without_vars[i]} and {callback_param_type} for argument {arg}{callback_name_context}{key_context}." raise TypeError(msg) from te if compare_result: @@ -1744,7 +1789,7 @@ def _check_event_args_subclass_of_callback( ) delayed_exceptions.append( EventHandlerArgTypeMismatchError( - f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_name_to_type[arg]}{as_annotated_in} instead." + f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {callback_param_type}{as_annotated_in} instead." ) ) @@ -2557,10 +2602,6 @@ def __call__(self, *args: Var) -> Any: if TYPE_CHECKING: from reflex.state import BaseState - BASE_STATE = TypeVar("BASE_STATE", bound=BaseState) -else: - BASE_STATE = TypeVar("BASE_STATE") - class EventNamespace: """A namespace for event related classes.""" diff --git a/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py b/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py index 8a2422081a5..f460363e403 100644 --- a/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py +++ b/packages/reflex-components-core/src/reflex_components_core/el/elements/forms.py @@ -2,17 +2,19 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Iterator, Mapping +from functools import partial from hashlib import md5 -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, TypeVar, get_origin, get_type_hints -from reflex_base.components.component import field +from reflex_base.components.component import BaseComponent, Component, field from reflex_base.components.tags.tag import Tag from reflex_base.constants import Dirs, EventTriggers from reflex_base.event import ( FORM_DATA, EventChain, EventHandler, + EventSpec, checked_input_event, float_input_event, input_event, @@ -21,16 +23,34 @@ on_submit_event, on_submit_string_event, prevent_default, + unwrap_var_annotation, ) +from reflex_base.utils.exceptions import EventHandlerValueError from reflex_base.utils.imports import ImportDict from reflex_base.vars import VarData from reflex_base.vars.base import LiteralVar, Var from reflex_base.vars.number import ternary_operation +from typing_extensions import is_typeddict from reflex_components_core.el.element import Element from .base import BaseHTML +_DYNAMIC_FORM_FIELD = object() + +_KNOWN_SUBMIT_CONTROL_TYPES = { + "reflex_components_radix.primitives.slider.SliderRoot", + "reflex_components_radix.themes.components.checkbox.Checkbox", + "reflex_components_radix.themes.components.checkbox_group.CheckboxGroupRoot", + "reflex_components_radix.themes.components.radio_cards.RadioCardsRoot", + "reflex_components_radix.themes.components.radio_group.RadioGroupRoot", + "reflex_components_radix.themes.components.select.SelectRoot", + "reflex_components_radix.themes.components.slider.Slider", + "reflex_components_radix.themes.components.switch.Switch", +} + +FORM_SUBMIT_MAPPING = TypeVar("FORM_SUBMIT_MAPPING", bound=Mapping[str, Any]) + def _handle_submit_js_template( handle_submit_unique_name: str, @@ -66,6 +86,206 @@ def _handle_submit_js_template( """ +def on_submit_mapping_event( + form_data: Var[FORM_SUBMIT_MAPPING], +) -> tuple[Var[FORM_SUBMIT_MAPPING]]: + """Provide a generic mapping-style submit event spec for type checkers. + + Args: + form_data: The form submission payload. + + Returns: + The form data payload. + """ + return (form_data,) + + +def _iter_form_components(component: BaseComponent) -> Iterator[BaseComponent]: + """Yield a component and all nested components that may contribute form data. + + Args: + component: The component to walk. + + Yields: + The component and its nested component descendants. + """ + yield component + for child in component.children: + if isinstance(child, BaseComponent): + yield from _iter_form_components(child) + if isinstance(component, Component): + for component_in_props in component._get_components_in_props(): + yield from _iter_form_components(component_in_props) + + +def _get_static_string_prop( + component: BaseComponent, + prop_name: str, +) -> str | object | None: + """Resolve a component prop when it is statically known to be a string. + + Args: + component: The component being inspected. + prop_name: The prop to resolve. + + Returns: + The resolved string, ``_DYNAMIC_FORM_FIELD`` for dynamic vars, or ``None``. + """ + value = getattr(component, prop_name, None) + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, LiteralVar): + decoded = value._decode() + if isinstance(decoded, str): + return decoded + return None + if isinstance(value, Var): + return _DYNAMIC_FORM_FIELD + return None + + +def _is_submit_participating_control(component: BaseComponent) -> bool: + """Check whether a component can contribute a named field to form data. + + Args: + component: The component to inspect. + + Returns: + Whether the component is a submit-participating control. + """ + if isinstance(component, (BaseInput, Select, Textarea)): + return True + component_type_name = ( + f"{component.__class__.__module__}.{component.__class__.__name__}" + ) + return component_type_name in _KNOWN_SUBMIT_CONTROL_TYPES + + +def _is_form_data_payload_arg(value: Var) -> bool: + """Check whether an event arg value is the form submission payload. + + Args: + value: The event arg value. + + Returns: + Whether the arg is the ``form_data`` payload. + """ + return isinstance(value, Var) and value._js_expr == FORM_DATA._js_expr + + +def _get_handler_name(handler: EventHandler) -> str: + """Get a stable fully qualified handler name for errors. + + Args: + handler: The handler to name. + + Returns: + The fully qualified handler name. + """ + return handler.fn.__qualname__ + + +def _resolve_on_submit_typed_dict_contract( + event_spec: EventSpec, +) -> tuple[str, type[Any], frozenset[str]] | None: + """Resolve the TypedDict contract for an on_submit handler, if any. + + Args: + event_spec: The finalized event spec in the on_submit chain. + + Returns: + The handler name, TypedDict annotation, and required keys, or ``None``. + """ + form_data_param_name = next( + ( + param._js_expr + for param, value in event_spec.args + if _is_form_data_payload_arg(value) + ), + None, + ) + if form_data_param_name is None: + return None + + func = ( + event_spec.handler.fn.func + if isinstance(event_spec.handler.fn, partial) + else event_spec.handler.fn + ) + try: + type_hints = get_type_hints(func) + except Exception: + return None + + annotation = type_hints.get(form_data_param_name) + if annotation is None: + return None + + annotation = unwrap_var_annotation(annotation) + if not is_typeddict(annotation): + return None + + required_fields = _get_required_typed_dict_fields(annotation) + return _get_handler_name(event_spec.handler), annotation, required_fields + + +def _get_required_typed_dict_fields(typed_dict_type: type[Any]) -> frozenset[str]: + """Resolve required TypedDict keys in a cross-version-safe way. + + Args: + typed_dict_type: The TypedDict class to inspect. + + Returns: + The required field names for the TypedDict. + """ + try: + field_type_hints = get_type_hints(typed_dict_type, include_extras=True) + except Exception: + field_type_hints = getattr(typed_dict_type, "__annotations__", {}) + + total = getattr(typed_dict_type, "__total__", True) + required_fields = { + field_name + for field_name, annotation in field_type_hints.items() + if _is_required_typed_dict_field(annotation, total=total) + } + return frozenset(required_fields) + + +def _is_required_typed_dict_field(annotation: Any, *, total: bool) -> bool: + """Check whether a TypedDict field annotation is required. + + Args: + annotation: The field annotation to inspect. + total: Whether the TypedDict defaults to required fields. + + Returns: + Whether the field is required. + """ + marker_name = getattr(get_origin(annotation), "__name__", None) + if marker_name == "NotRequired": + return False + if marker_name == "Required": + return True + return total + + +def _format_field_list(fields: tuple[str, ...]) -> str: + """Format field names as a bullet list. + + Args: + fields: The fields to format. + + Returns: + A human-readable bullet list. + """ + if not fields: + return ' - "(none)"' + return "\n".join(f' - "{field}"' for field in fields) + + ButtonType = Literal["submit", "reset", "button"] @@ -172,9 +392,9 @@ class Form(BaseHTML): doc="The name used to make this form's submit handler function unique." ) - on_submit: EventHandler[on_submit_event, on_submit_string_event] = field( - doc="Fired when the form is submitted" - ) + on_submit: EventHandler[ + on_submit_event, on_submit_mapping_event, on_submit_string_event + ] = field(doc="Fired when the form is submitted") @classmethod def create(cls, *children, **props): @@ -196,6 +416,7 @@ def create(cls, *children, **props): # Render the form hooks and use the hash of the resulting code to create a unique name. props["handle_submit_unique_name"] = "" form = super().create(*children, **props) + form._validate_on_submit_typed_dict_fields() # pyright: ignore[reportAttributeAccessIssue] form.handle_submit_unique_name = md5( # pyright: ignore[reportAttributeAccessIssue] str(form._get_all_hooks()).encode("utf-8") ).hexdigest() @@ -263,6 +484,83 @@ def _get_form_refs(self) -> dict[str, Any]: ) return form_refs + def _get_static_form_field_keys(self) -> tuple[set[str], bool]: + """Collect statically known form-data keys and whether any are dynamic. + + Returns: + The known keys and whether any name/id identifiers are dynamic. + """ + form_keys = set(self._get_form_refs()) + has_dynamic_identifiers = False + + for component in _iter_form_components(self): + if component is self or not _is_submit_participating_control(component): + continue + + name = _get_static_string_prop(component, "name") + if name is _DYNAMIC_FORM_FIELD: + has_dynamic_identifiers = True + elif isinstance(name, str): + form_keys.add(name) + + if _get_static_string_prop(component, "id") is _DYNAMIC_FORM_FIELD: + has_dynamic_identifiers = True + + return form_keys, has_dynamic_identifiers + + def _validate_on_submit_typed_dict_fields(self) -> None: + """Validate statically knowable form fields against TypedDict submit handlers. + + Raises: + EventHandlerValueError: If a required TypedDict field is missing. + """ + on_submit = self.event_triggers.get(EventTriggers.ON_SUBMIT) + if not isinstance(on_submit, EventChain): + return + + if any(not isinstance(event, EventSpec) for event in on_submit.events): + return + + event_specs = tuple( + event for event in on_submit.events if isinstance(event, EventSpec) + ) + typed_dict_contracts = [ + contract + for event in event_specs + if (contract := _resolve_on_submit_typed_dict_contract(event)) is not None + ] + if not typed_dict_contracts: + return + + form_keys, has_dynamic_identifiers = self._get_static_form_field_keys() + for handler_name, typed_dict_type, required_fields in typed_dict_contracts: + required_field_names = tuple(sorted(required_fields)) + if not required_field_names: + continue + + missing_fields = tuple( + field for field in required_field_names if field not in form_keys + ) + if not missing_fields or has_dynamic_identifiers: + continue + + present_fields = tuple( + field for field in required_field_names if field in form_keys + ) + msg = ( + f"Form field mismatch for on_submit handler `{handler_name}`.\n\n" + f"The handler expects form data matching `{typed_dict_type.__name__}` " + "with required fields:\n" + f"{_format_field_list(required_field_names)}\n\n" + "Fields missing from the form:\n" + f"{_format_field_list(missing_fields)}\n\n" + "Matching fields present in the form:\n" + f"{_format_field_list(present_fields)}\n\n" + "Hint: Add controls with matching static `name` or `id` values, or " + "make the TypedDict fields optional." + ) + raise EventHandlerValueError(msg) + def _get_vars( self, include_children: bool = True, ignore_ids: set[int] | None = None ) -> Iterator[Var]: diff --git a/pyi_hashes.json b/pyi_hashes.json index f7900836abd..c3538f3de19 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -27,7 +27,7 @@ "packages/reflex-components-core/src/reflex_components_core/el/element.pyi": "1a8824cdd243efc876157b97f9f1b714", "packages/reflex-components-core/src/reflex_components_core/el/elements/__init__.pyi": "e6c845f2f29eb079697a2e31b0c2f23a", "packages/reflex-components-core/src/reflex_components_core/el/elements/base.pyi": "7c74980207dc1a5cac14083f2edd31ba", - "packages/reflex-components-core/src/reflex_components_core/el/elements/forms.pyi": "da7ef00fd67699eeeb55e33279c2eb8d", + "packages/reflex-components-core/src/reflex_components_core/el/elements/forms.pyi": "1ff521334b753d334dbfbe309a8e2327", "packages/reflex-components-core/src/reflex_components_core/el/elements/inline.pyi": "0ea0058ea7b6ae03138c7c85df963c32", "packages/reflex-components-core/src/reflex_components_core/el/elements/media.pyi": "97f7f6c66533bb3947a43ceefe160d49", "packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.pyi": "7ea09671a42d75234a0464fc3601577c", diff --git a/tests/units/components/forms/test_form.py b/tests/units/components/forms/test_form.py index f3ea3c03228..6f2d61e5a64 100644 --- a/tests/units/components/forms/test_form.py +++ b/tests/units/components/forms/test_form.py @@ -1,6 +1,15 @@ +from typing import TypedDict + +import pytest from reflex_base.event import EventChain, prevent_default +from reflex_base.utils.exceptions import EventHandlerValueError from reflex_base.vars.base import Var +from reflex_components_core.el.elements.forms import Form as HTMLForm +from reflex_components_core.el.elements.forms import Input from reflex_components_radix.primitives.form import Form +from typing_extensions import NotRequired + +import reflex as rx def test_render_on_submit(): @@ -20,3 +29,151 @@ def test_render_no_on_submit(): assert isinstance(f.event_triggers["on_submit"], EventChain) assert len(f.event_triggers["on_submit"].events) == 1 assert f.event_triggers["on_submit"].events[0] == prevent_default + + +@pytest.mark.parametrize("form_factory", [HTMLForm.create, Form.create]) +def test_on_submit_accepts_typed_dict_form_data(form_factory): + """TypedDict-annotated submit handlers should be accepted.""" + + class SignupData(TypedDict): + name: str + email: str + + class SignupState(rx.State): + @rx.event + def on_submit(self, form_data: SignupData): + pass + + form = form_factory( + Input.create(name="name"), + Input.create(name="email"), + on_submit=SignupState.on_submit, + ) + + assert isinstance(form.event_triggers["on_submit"], EventChain) + + +def test_on_submit_accepts_id_backed_typed_dict_form_data(): + """Static ids that are mirrored into form_data should satisfy TypedDict keys.""" + + class SignupData(TypedDict): + email_input: str + + class SignupState(rx.State): + @rx.event + def on_submit(self, form_data: SignupData): + pass + + form = HTMLForm.create( + Input.create(id="email_input"), + on_submit=SignupState.on_submit, + ) + + assert isinstance(form.event_triggers["on_submit"], EventChain) + + +def test_on_submit_accepts_typed_dict_with_optional_fields(): + """Optional TypedDict keys should not be required in the form.""" + + class SignupData(TypedDict): + email: str + nickname: NotRequired[str] + + class SignupState(rx.State): + @rx.event + def on_submit(self, form_data: SignupData): + pass + + form = HTMLForm.create( + Input.create(name="email"), + on_submit=SignupState.on_submit, + ) + + assert isinstance(form.event_triggers["on_submit"], EventChain) + + +def test_on_submit_allows_extra_typed_dict_form_fields(): + """Forms may include more fields than the TypedDict requires.""" + + class SignupData(TypedDict): + email: str + + class SignupState(rx.State): + @rx.event + def on_submit(self, form_data: SignupData): + pass + + form = HTMLForm.create( + Input.create(name="email"), + Input.create(name="nickname"), + on_submit=SignupState.on_submit, + ) + + assert isinstance(form.event_triggers["on_submit"], EventChain) + + +def test_on_submit_resolves_typed_dict_after_bound_args(): + """The final submit payload parameter should still resolve after binding args.""" + + class SignupData(TypedDict): + email: str + + class SignupState(rx.State): + @rx.event + def on_submit(self, source: str, form_data: SignupData): + pass + + form = HTMLForm.create( + Input.create(name="email"), + on_submit=SignupState.on_submit("marketing"), # pyright: ignore [reportCallIssue] + ) + + assert isinstance(form.event_triggers["on_submit"], EventChain) + + +def test_on_submit_typed_dict_missing_fields_raises_helpful_error(): + """Missing required TypedDict keys should produce a focused compile-time error.""" + + class SignupData(TypedDict): + fname: str + lname: str + email: str + + class SignupState(rx.State): + @rx.event + def on_submit(self, form_data: SignupData): + pass + + with pytest.raises(EventHandlerValueError) as err: + HTMLForm.create( + Input.create(name="email"), + on_submit=SignupState.on_submit, + ) + + error = str(err.value) + assert "Form field mismatch for on_submit handler" in error + assert "SignupState.on_submit" in error + assert "SignupData" in error + assert '"fname"' in error + assert '"lname"' in error + assert '"email"' in error + assert "Matching fields present in the form" in error + + +def test_on_submit_typed_dict_skips_dynamic_field_identifiers(): + """Dynamic field names should skip strict validation instead of raising.""" + + class SignupData(TypedDict): + email: str + + class SignupState(rx.State): + @rx.event + def on_submit(self, form_data: SignupData): + pass + + form = HTMLForm.create( + Input.create(name=Var(_js_expr="dynamic_name", _var_type=str)), + on_submit=SignupState.on_submit, + ) + + assert isinstance(form.event_triggers["on_submit"], EventChain) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 5b3e601f176..0323bd6f7d2 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -1,6 +1,6 @@ from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any, ClassVar, TypedDict import pytest from reflex_base.components.component import ( @@ -839,6 +839,34 @@ def get_event_triggers(cls) -> dict[str, Any]: C1.create(on_foo=C1State.mock_handler) +def test_non_submit_mapping_events_do_not_accept_typed_dict_handlers(): + """TypedDict relaxation should stay scoped to form submission handlers.""" + + class Payload(TypedDict): + email: str + + class C1State(BaseState): + def mock_handler(self, payload: Payload): + """Mock handler.""" + + def on_foo_spec(payload: Var[dict[str, int]]) -> tuple[Var[dict[str, int]]]: + return (payload,) + + class C1(Component): + library = "/local" + tag = "C1" + + @classmethod + def get_event_triggers(cls) -> dict[str, Any]: + return { + **super().get_event_triggers(), + "on_foo": on_foo_spec, + } + + with pytest.raises(EventHandlerArgTypeMismatchError): + C1.create(on_foo=C1State.mock_handler) + + def test_create_custom_component(my_component): """Test that we can create a custom component.