diff --git a/packages/reflex-base/src/reflex_base/compiler/templates.py b/packages/reflex-base/src/reflex_base/compiler/templates.py index be3e3f6eee4..9091b8edfc5 100644 --- a/packages/reflex-base/src/reflex_base/compiler/templates.py +++ b/packages/reflex-base/src/reflex_base/compiler/templates.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from reflex.compiler.utils import _ImportDict - from reflex_base.components.component import Component, StatefulComponent + from reflex_base.components.component import Component def _sort_hooks( @@ -417,7 +417,7 @@ def context_template( }}""" -def component_template(component: Component | StatefulComponent): +def component_template(component: Component): """Template to render a component tag. Args: @@ -618,24 +618,23 @@ def vite_config_template( }}));""" -def stateful_component_template( - tag_name: str, memo_trigger_hooks: list[str], component: Component, export: bool -): - """Template for stateful component. +def dynamic_component_template( + tag_name: str, component: Component, export: bool +) -> str: + """Template for a dynamic SSR component function declaration. Args: tag_name: The tag name for the component. - memo_trigger_hooks: The memo trigger hooks for the component. component: The component to render. export: Whether to export the component. Returns: - Rendered stateful component code as string. + Rendered dynamic component code as string. """ all_hooks = component._get_all_hooks() return f""" {"export " if export else ""}function {tag_name} () {{ - {_render_hooks(all_hooks, memo_trigger_hooks)} + {_render_hooks(all_hooks)} return ( {_RenderUtils.render(component.render())} ) @@ -643,15 +642,17 @@ def stateful_component_template( """ -def stateful_components_template(imports: list[_ImportDict], memoized_code: str) -> str: - """Template for stateful components. +def dynamic_components_module_template( + imports: list[_ImportDict], memoized_code: str +) -> str: + """Template for a dynamic-SSR components module. Args: imports: List of import statements. - memoized_code: Memoized code for stateful components. + memoized_code: Code for the module body. Returns: - Rendered stateful components code as string. + Rendered module code as string. """ imports_str = "\n".join([_RenderUtils.get_import(imp) for imp in imports]) return f"{imports_str}\n{memoized_code}" diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index 8f47f447c7d..b69d99497f2 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -3,7 +3,6 @@ from __future__ import annotations import contextlib -import copy import dataclasses import enum import functools @@ -22,7 +21,6 @@ from reflex_base import constants from reflex_base.breakpoints import Breakpoints -from reflex_base.compiler.templates import stateful_component_template from reflex_base.components.dynamic import load_dynamic_serializer from reflex_base.components.field import BaseField, FieldBasedMeta from reflex_base.components.tags import Tag @@ -33,7 +31,6 @@ Imports, MemoizationDisposition, MemoizationMode, - PageNames, ) from reflex_base.constants.compiler import SpecialAttributes from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER @@ -1300,7 +1297,7 @@ def _add_style_recursive( # Recursively add style to the children. for child in self.children: - # Skip BaseComponent and StatefulComponent children. + # Skip non-Component children. if not isinstance(child, Component): continue child._add_style_recursive(style, theme) @@ -1327,6 +1324,10 @@ def render(self) -> dict: Returns: The dictionary for template of component. """ + try: + return self._cached_render_result + except AttributeError: + pass tag = self._render() rendered_dict = dict( tag.set( @@ -1334,6 +1335,7 @@ def render(self) -> dict: ) ) self._replace_prop_names(rendered_dict) + self._cached_render_result = rendered_dict return rendered_dict def _replace_prop_names(self, rendered_dict: dict) -> None: @@ -1457,11 +1459,16 @@ def _get_vars( Yields: Each var referenced by the component (props, styles, event handlers). """ + # Default-args fast path is cached per instance. Invalidated by the + # auto-memoize plugin when fix_event_triggers_for_memo mutates event_triggers. + if not include_children and ignore_ids is None: + cached = self.__dict__.get("_vars_cache") + if cached is not None: + yield from cached + return + ignore_ids = ignore_ids or set() - vars: list[Var] | None = getattr(self, "__vars", None) - if vars is not None: - yield from vars - vars = self.__vars = [] + vars: list[Var] = [] # Get Vars associated with event trigger arguments. for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): vars.extend(event_vars) @@ -1500,7 +1507,6 @@ def _get_vars( if var._get_all_var_data() is not None: vars.append(var) - # Get Vars associated with children. if include_children: for child in self.children: if not isinstance(child, Component) or id(child) in ignore_ids: @@ -1510,7 +1516,11 @@ def _get_vars( include_children=include_children, ignore_ids=ignore_ids ) vars.extend(child_vars) + yield from vars + return + # Freeze and cache the default-args result. + self._vars_cache = tuple(vars) yield from vars def _event_trigger_values_use_state(self) -> bool: @@ -1555,6 +1565,7 @@ def _iter_parent_classes_names(cls) -> Iterator[str]: yield clz.__name__ @classmethod + @functools.cache def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Component]]: """Iterate through parent classes that define a given method. @@ -1581,7 +1592,7 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen continue seen_methods.add(method_func) clzs.append(clz) - return clzs + return tuple(clzs) def _get_custom_code(self) -> str | None: """Get custom code for the component. @@ -1704,6 +1715,10 @@ def _get_imports(self) -> ParsedImportDict: Returns: The imports needed by the component. """ + cached = self.__dict__.get("_imports_cache") + if cached is not None: + return cached + imports_ = ( {self.library: [self.import_var]} if self.library is not None and self.tag is not None @@ -1731,7 +1746,7 @@ def _get_imports(self) -> ParsedImportDict: imports.parse_imports(item) for item in list_of_import_dict ]) - return imports.merge_parsed_imports( + result = imports.merge_parsed_imports( self._get_dependencies_imports(), self._get_hooks_imports(), imports_, @@ -1739,6 +1754,8 @@ def _get_imports(self) -> ParsedImportDict: *var_imports, *added_import_dicts, ) + self._imports_cache = result + return result def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Get all the libraries and fields that are used by the component and its children. @@ -1840,7 +1857,11 @@ def _get_hooks_internal(self) -> dict[str, VarData | None]: Returns: The internally managed hooks. """ - return { + cached = self.__dict__.get("_hooks_internal_cache") + if cached is not None: + return cached + + result = { **{ str(hook): VarData(position=Hooks.HookPosition.INTERNAL) for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()] @@ -1849,6 +1870,8 @@ def _get_hooks_internal(self) -> dict[str, VarData | None]: **self._get_vars_hooks(), **self._get_events_hooks(), } + self._hooks_internal_cache = result + return result def _get_added_hooks(self) -> dict[str, VarData | None]: """Get the hooks added via `add_hooks` method. @@ -2005,7 +2028,7 @@ def _get_all_app_wrap_components( # Add the app wrap components for the children. for child in self.children: child_id = id(child) - # Skip BaseComponent and StatefulComponent children. + # Skip non-Component children. if not isinstance(child, Component) or child_id in ignore_ids: continue ignore_ids.add(child_id) @@ -2382,440 +2405,6 @@ def _get_dynamic_imports(self) -> str: ) -class StatefulComponent(BaseComponent): - """A component that depends on state and is rendered outside of the page component. - - If a StatefulComponent is used in multiple pages, it will be rendered to a common file and - imported into each page that uses it. - - A stateful component has a tag name that includes a hash of the code that it renders - to. This tag name refers to the specific component with the specific props that it - was created with. - """ - - # Reference to the original component that was memoized into this component. - component: Component = field( - default_factory=Component, is_javascript_property=False - ) - - references: int = field( - doc="How many times this component is referenced in the app.", - default=0, - is_javascript_property=False, - ) - - rendered_as_shared: bool = field( - doc="Whether the component has already been rendered to a shared file.", - default=False, - is_javascript_property=False, - ) - - memo_trigger_hooks: list[str] = field( - default_factory=list, is_javascript_property=False - ) - - @classmethod - def create(cls, component: Component) -> StatefulComponent | None: - """Create a stateful component from a component. - - Args: - component: The component to memoize. - - Returns: - The stateful component or None if the component should not be memoized. - """ - from reflex_components_core.core.foreach import Foreach - - from reflex_base.registry import RegistrationContext - - if component._memoization_mode.disposition == MemoizationDisposition.NEVER: - # Never memoize this component. - return None - - if component.tag is None: - # Only memoize components with a tag. - return None - - # If _var_data is found in this component, it is a candidate for auto-memoization. - should_memoize = False - - # If the component requests to be memoized, then ignore other checks. - if component._memoization_mode.disposition == MemoizationDisposition.ALWAYS: - should_memoize = True - - if not should_memoize: - # Determine if any Vars have associated data. - for prop_var in component._get_vars(include_children=True): - if prop_var._get_all_var_data(): - should_memoize = True - break - - if not should_memoize: - # Check for special-cases in child components. - for child in component.children: - # Skip BaseComponent and StatefulComponent children. - if not isinstance(child, Component): - continue - # Always consider Foreach something that must be memoized by the parent. - if isinstance(child, Foreach): - should_memoize = True - break - child = cls._child_var(child) - if isinstance(child, Var) and child._get_all_var_data(): - should_memoize = True - break - - if should_memoize or component.event_triggers: - # Render the component to determine tag+hash based on component code. - tag_name = cls._get_tag_name(component) - if tag_name is None: - return None - - # Look up the tag in the cache - ctx = RegistrationContext.get() - stateful_component = ctx.tag_to_stateful_component.get(tag_name) - if stateful_component is None: - memo_trigger_hooks = cls._fix_event_triggers(component) - # Set the stateful component in the cache for the given tag. - stateful_component = ctx.tag_to_stateful_component.setdefault( - tag_name, - cls( - children=component.children, - component=component, - tag=tag_name, - memo_trigger_hooks=memo_trigger_hooks, - ), - ) - # Bump the reference count -- multiple pages referencing the same component - # will result in writing it to a common file. - stateful_component.references += 1 - return stateful_component - - # Return None to indicate this component should not be memoized. - return None - - @staticmethod - def _child_var(child: Component) -> Var | Component: - """Get the Var from a child component. - - This method is used for special cases when the StatefulComponent should actually - wrap the parent component of the child instead of recursing into the children - and memoizing them independently. - - Args: - child: The child component. - - Returns: - The Var from the child component or the child itself (for regular cases). - """ - from reflex_components_core.base.bare import Bare - from reflex_components_core.core.cond import Cond - from reflex_components_core.core.foreach import Foreach - from reflex_components_core.core.match import Match - - if isinstance(child, Bare): - return child.contents - if isinstance(child, Cond): - return child.cond - if isinstance(child, Foreach): - return child.iterable - if isinstance(child, Match): - return child.cond - return child - - @classmethod - def _get_tag_name(cls, component: Component) -> str | None: - """Get the tag based on rendering the given component. - - Args: - component: The component to render. - - Returns: - The tag for the stateful component. - """ - # Get the render dict for the component. - rendered_code = component.render() - if not rendered_code: - # Never memoize non-visual components. - return None - - # Compute the hash based on the rendered code. - code_hash = _hash_str(_deterministic_hash(rendered_code)) - - # Format the tag name including the hash. - return format.format_state_name( - f"{component.tag or 'Comp'}_{code_hash}" - ).capitalize() - - def _render_stateful_code( - self, - export: bool = False, - ) -> str: - if not self.tag: - return "" - # Render the code for this component and hooks. - return stateful_component_template( - tag_name=self.tag, - memo_trigger_hooks=self.memo_trigger_hooks, - component=self.component, - export=export, - ) - - @classmethod - def _fix_event_triggers( - cls, - component: Component, - ) -> list[str]: - """Render the code for a stateful component. - - Args: - component: The component to render. - - Returns: - The memoized event trigger hooks for the component. - """ - # Memoize event triggers useCallback to avoid unnecessary re-renders. - memo_event_triggers = tuple(cls._get_memoized_event_triggers(component).items()) - - # Trigger hooks stored separately to write after the normal hooks (see stateful_component.js.jinja2) - memo_trigger_hooks: list[str] = [] - - if memo_event_triggers: - # Copy the component to avoid mutating the original. - component = copy.copy(component) - - for event_trigger, ( - memo_trigger, - memo_trigger_hook, - ) in memo_event_triggers: - # Replace the event trigger with the memoized version. - memo_trigger_hooks.append(memo_trigger_hook) - component.event_triggers[event_trigger] = memo_trigger - - return memo_trigger_hooks - - @staticmethod - def _get_hook_deps(hook: str) -> list[str]: - """Extract var deps from a hook. - - Args: - hook: The hook line to extract deps from. - - Returns: - A list of var names created by the hook declaration. - """ - # Ensure that the hook is a var declaration. - var_decl = hook.partition("=")[0].strip() - if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]): - return [] - - # Extract the var name from the declaration. - _, _, var_name = var_decl.partition(" ") - var_name = var_name.strip() - - # Break up array and object destructuring if used. - if var_name.startswith(("[", "{")): - return [ - v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",") - ] - return [var_name] - - @staticmethod - def _get_deps_from_event_trigger( - event: EventChain | EventSpec | Var, - ) -> dict[str, None]: - """Get the dependencies accessed by event triggers. - - Args: - event: The event trigger to extract deps from. - - Returns: - The dependencies accessed by the event triggers. - """ - events: list = [event] - deps = {} - - if isinstance(event, EventChain): - events.extend(event.events) - - for ev in events: - if isinstance(ev, EventSpec): - for arg in ev.args: - for a in arg: - var_datas = VarData.merge(a._get_all_var_data()) - if var_datas and var_datas.deps is not None: - deps |= {str(dep): None for dep in var_datas.deps} - return deps - - @classmethod - def _get_memoized_event_triggers( - cls, - component: Component, - ) -> dict[str, tuple[Var, str]]: - """Memoize event handler functions with useCallback to avoid unnecessary re-renders. - - Args: - component: The component with events to memoize. - - Returns: - A dict of event trigger name to a tuple of the memoized event trigger Var and - the hook code that memoizes the event handler. - """ - trigger_memo = {} - for event_trigger, event_args in component._get_vars_from_event_triggers( - component.event_triggers - ): - if event_trigger in { - EventTriggers.ON_MOUNT, - EventTriggers.ON_UNMOUNT, - EventTriggers.ON_SUBMIT, - }: - # Do not memoize lifecycle or submit events. - continue - - # Get the actual EventSpec and render it. - event = component.event_triggers[event_trigger] - rendered_chain = str(LiteralVar.create(event)) - - # Hash the rendered EventChain to get a deterministic function name. - chain_hash = md5(str(rendered_chain).encode("utf-8")).hexdigest() - memo_name = f"{event_trigger}_{chain_hash}" - - # Calculate Var dependencies accessed by the handler for useCallback dep array. - var_deps = ["addEvents", "ReflexEvent"] - - # Get deps from event trigger var data. - var_deps.extend(cls._get_deps_from_event_trigger(event)) - - # Get deps from hooks. - for arg in event_args: - var_data = arg._get_all_var_data() - if var_data is None: - continue - for hook in var_data.hooks: - var_deps.extend(cls._get_hook_deps(hook)) - memo_var_data = VarData.merge( - *[var._get_all_var_data() for var in event_args], - VarData( - imports={"react": [ImportVar(tag="useCallback")]}, - ), - ) - - # Store the memoized function name and hook code for this event trigger. - trigger_memo[event_trigger] = ( - Var(_js_expr=memo_name)._replace( - _var_type=EventChain, merge_var_data=memo_var_data - ), - f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])", - ) - return trigger_memo - - def _get_all_hooks_internal(self) -> dict[str, VarData | None]: - """Get the reflex internal hooks for the component and its children. - - Returns: - The code that should appear just before user-defined hooks. - """ - return {} - - def _get_all_hooks(self) -> dict[str, VarData | None]: - """Get the React hooks for this component. - - Returns: - The code that should appear just before returning the rendered component. - """ - return {} - - def _get_all_imports(self) -> ParsedImportDict: - """Get all the libraries and fields that are used by the component. - - Returns: - The import dict with the required imports. - """ - if self.rendered_as_shared: - return { - f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ - ImportVar(tag=self.tag) - ] - } - return self.component._get_all_imports() - - def _get_all_dynamic_imports(self) -> set[str]: - """Get dynamic imports for the component. - - Returns: - The dynamic imports. - """ - if self.rendered_as_shared: - return set() - return self.component._get_all_dynamic_imports() - - def _get_all_custom_code(self, export: bool = False) -> dict[str, None]: - """Get custom code for the component. - - Args: - export: Whether to export the component. - - Returns: - The custom code. - """ - if self.rendered_as_shared: - return {} - return self.component._get_all_custom_code() | ({ - self._render_stateful_code(export=export): None - }) - - def _get_all_refs(self) -> dict[str, None]: - """Get the refs for the children of the component. - - Returns: - The refs for the children. - """ - if self.rendered_as_shared: - return {} - return self.component._get_all_refs() - - def render(self) -> dict: - """Define how to render the component in React. - - Returns: - The tag to render. - """ - return dict(Tag(name=self.tag or "")) - - def __str__(self) -> str: - """Represent the component in React. - - Returns: - The code to render the component. - """ - from reflex.compiler.compiler import _compile_component - - return _compile_component(self) - - @classmethod - def compile_from(cls, component: BaseComponent) -> BaseComponent: - """Walk through the component tree and memoize all stateful components. - - Args: - component: The component to memoize. - - Returns: - The memoized component tree. - """ - if isinstance(component, Component): - if component._memoization_mode.recursive: - # Recursively memoize stateful children (default). - component.children = [ - cls.compile_from(child) for child in component.children - ] - # Memoize this component if it depends on state. - stateful_component = cls.create(component) - if stateful_component is not None: - return stateful_component - return component - - class MemoizationLeaf(Component): """A component that does not separately memoize its children. diff --git a/packages/reflex-base/src/reflex_base/components/dynamic.py b/packages/reflex-base/src/reflex_base/components/dynamic.py index 6c2100a40e8..0386167198d 100644 --- a/packages/reflex-base/src/reflex_base/components/dynamic.py +++ b/packages/reflex-base/src/reflex_base/components/dynamic.py @@ -85,9 +85,8 @@ def make_component(component: Component) -> str: rendered_components.update(component._get_all_custom_code()) rendered_components[ - templates.stateful_component_template( + templates.dynamic_component_template( tag_name="MySSRComponent", - memo_trigger_hooks=[], component=component, export=True, ) @@ -110,7 +109,7 @@ def make_component(component: Component) -> str: else: imports[lib] = names - module_code_lines = templates.stateful_components_template( + module_code_lines = templates.dynamic_components_module_template( imports=utils.compile_imports(imports), memoized_code="\n".join(rendered_components), ).splitlines() diff --git a/packages/reflex-base/src/reflex_base/components/memoize_helpers.py b/packages/reflex-base/src/reflex_base/components/memoize_helpers.py new file mode 100644 index 00000000000..c7494ba6b8c --- /dev/null +++ b/packages/reflex-base/src/reflex_base/components/memoize_helpers.py @@ -0,0 +1,175 @@ +"""Event-trigger memoization helpers for auto-memoized and pseudo-stateful components. + +These helpers wrap a component's non-lifecycle event triggers in ``useCallback`` +so that React can skip re-renders of subtrees whose event handlers have stable +identities. They are used by both the compiler auto-memoization plugin (see +``reflex.compiler.plugins.memoize``) and by component-creation-time consumers +in ``reflex-components-core`` (e.g. ``WindowEventListener``, ``upload``). +""" + +from __future__ import annotations + +import contextlib +from hashlib import md5 + +from reflex_base.components.component import Component +from reflex_base.constants import EventTriggers +from reflex_base.event import EventChain, EventSpec +from reflex_base.utils.imports import ImportVar +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var + + +def _get_hook_deps(hook: str) -> list[str]: + """Extract Var deps from a hook declaration line. + + Args: + hook: The hook line (e.g. ``"const foo = useState(...)"``). + + Returns: + The names of variables created by the declaration. + """ + var_decl = hook.partition("=")[0].strip() + if not any(var_decl.startswith(kw) for kw in ["const ", "let ", "var "]): + return [] + _, _, var_name = var_decl.partition(" ") + var_name = var_name.strip() + if var_name.startswith(("[", "{")): + return [v.strip().replace("...", "") for v in var_name.strip("[]{}").split(",")] + return [var_name] + + +def _get_deps_from_event_trigger( + event: EventChain | EventSpec | Var, +) -> dict[str, None]: + """Get the dependencies accessed by an event trigger value. + + Args: + event: The event trigger value. + + Returns: + Dependency names, insertion-ordered. + """ + events: list = [event] + deps: dict[str, None] = {} + + if isinstance(event, EventChain): + events.extend(event.events) + + for ev in events: + if isinstance(ev, EventSpec): + for arg in ev.args: + for a in arg: + var_datas = VarData.merge(a._get_all_var_data()) + if var_datas and var_datas.deps is not None: + deps |= {str(dep): None for dep in var_datas.deps} + return deps + + +def get_memoized_event_triggers( + component: Component, +) -> dict[str, tuple[Var, str]]: + """Generate ``useCallback`` wrappers for the component's event triggers. + + Args: + component: The component whose event triggers should be memoized. + + Returns: + A dict mapping event trigger name to + ``(memoized_var, useCallback_hook_line)``. + """ + trigger_memo: dict[str, tuple[Var, str]] = {} + for event_trigger, event_args in component._get_vars_from_event_triggers( + component.event_triggers + ): + if event_trigger in { + EventTriggers.ON_MOUNT, + EventTriggers.ON_UNMOUNT, + EventTriggers.ON_SUBMIT, + }: + # Do not memoize lifecycle or submit events. + continue + + event = component.event_triggers[event_trigger] + rendered_chain = str(LiteralVar.create(event)) + + chain_hash = md5( + str(rendered_chain).encode("utf-8"), usedforsecurity=False + ).hexdigest() + memo_name = f"{event_trigger}_{chain_hash}" + + var_deps = ["addEvents", "ReflexEvent"] + var_deps.extend(_get_deps_from_event_trigger(event)) + + for arg in event_args: + var_data = arg._get_all_var_data() + if var_data is None: + continue + for hook in var_data.hooks: + var_deps.extend(_get_hook_deps(hook)) + + memo_var_data = VarData.merge( + *[var._get_all_var_data() for var in event_args], + VarData(imports={"react": [ImportVar(tag="useCallback")]}), + ) + + trigger_memo[event_trigger] = ( + Var(_js_expr=memo_name)._replace( + _var_type=EventChain, merge_var_data=memo_var_data + ), + f"const {memo_name} = useCallback({rendered_chain}, [{', '.join(var_deps)}])", + ) + return trigger_memo + + +def fix_event_triggers_for_memo(component: Component) -> list[str]: + """Memoize ``component.event_triggers`` in place and return hook code. + + Replaces each (non-lifecycle) event-trigger value on ``component`` with a + ``Var`` naming a memoized ``useCallback`` wrapper, and returns the + ``useCallback`` hook lines in trigger order. + + Args: + component: The component whose event triggers to memoize. + + Returns: + The ``useCallback`` hook lines to emit at the top of the page body. + """ + memo_event_triggers = tuple(get_memoized_event_triggers(component).items()) + memo_trigger_hooks: list[str] = [] + + if memo_event_triggers: + component.event_triggers = dict( + component.event_triggers + ) # isolate so original dict is not mutated + for event_trigger, (memo_trigger, memo_trigger_hook) in memo_event_triggers: + memo_trigger_hooks.append(memo_trigger_hook) + component.event_triggers[event_trigger] = memo_trigger + + return memo_trigger_hooks + + +def invalidate_event_trigger_caches(component: Component) -> None: + """Drop caches that depend on ``component.event_triggers``. + + After :func:`fix_event_triggers_for_memo` mutates the shared event-triggers + dict, cached derivatives become stale. + + Args: + component: The original (pre-mutation) component. + """ + for attr in ( + "_cached_render_result", + "_vars_cache", + "_imports_cache", + "_hooks_internal_cache", + ): + with contextlib.suppress(AttributeError): + delattr(component, attr) + + +__all__ = [ + "fix_event_triggers_for_memo", + "get_memoized_event_triggers", + "invalidate_event_trigger_caches", +] diff --git a/packages/reflex-base/src/reflex_base/environment.py b/packages/reflex-base/src/reflex_base/environment.py index 31ebe795998..f2d4ce44cb9 100644 --- a/packages/reflex-base/src/reflex_base/environment.py +++ b/packages/reflex-base/src/reflex_base/environment.py @@ -2,14 +2,11 @@ from __future__ import annotations -import concurrent.futures import dataclasses import enum import importlib -import multiprocessing import os -import platform -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import lru_cache from pathlib import Path from typing import ( @@ -529,97 +526,6 @@ class PerformanceMode(enum.Enum): OFF = "off" -class ExecutorType(enum.Enum): - """Executor for compiling the frontend.""" - - THREAD = "thread" - PROCESS = "process" - MAIN_THREAD = "main_thread" - - @classmethod - def get_executor_from_environment(cls): - """Get the executor based on the environment variables. - - Returns: - The executor. - """ - from reflex_base.utils import console - - executor_type = environment.REFLEX_COMPILE_EXECUTOR.get() - - reflex_compile_processes = environment.REFLEX_COMPILE_PROCESSES.get() - reflex_compile_threads = environment.REFLEX_COMPILE_THREADS.get() - # By default, use the main thread. Unless the user has specified a different executor. - # Using a process pool is much faster, but not supported on all platforms. It's gated behind a flag. - if executor_type is None: - if ( - platform.system() not in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - console.warn("Multiprocessing is only supported on Linux and MacOS.") - - if ( - platform.system() in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - if reflex_compile_processes == 0: - console.warn( - "Number of processes must be greater than 0. If you want to use the default number of processes, set REFLEX_COMPILE_EXECUTOR to 'process'. Defaulting to None." - ) - reflex_compile_processes = None - elif reflex_compile_processes < 0: - console.warn( - "Number of processes must be greater than 0. Defaulting to None." - ) - reflex_compile_processes = None - executor_type = ExecutorType.PROCESS - elif reflex_compile_threads is not None: - if reflex_compile_threads == 0: - console.warn( - "Number of threads must be greater than 0. If you want to use the default number of threads, set REFLEX_COMPILE_EXECUTOR to 'thread'. Defaulting to None." - ) - reflex_compile_threads = None - elif reflex_compile_threads < 0: - console.warn( - "Number of threads must be greater than 0. Defaulting to None." - ) - reflex_compile_threads = None - executor_type = ExecutorType.THREAD - else: - executor_type = ExecutorType.MAIN_THREAD - - match executor_type: - case ExecutorType.PROCESS: - executor = concurrent.futures.ProcessPoolExecutor( - max_workers=reflex_compile_processes, - mp_context=multiprocessing.get_context("fork"), - ) - case ExecutorType.THREAD: - executor = concurrent.futures.ThreadPoolExecutor( - max_workers=reflex_compile_threads - ) - case ExecutorType.MAIN_THREAD: - FUTURE_RESULT_TYPE = TypeVar("FUTURE_RESULT_TYPE") - - class MainThreadExecutor: - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def submit( - self, fn: Callable[..., FUTURE_RESULT_TYPE], *args, **kwargs - ) -> concurrent.futures.Future[FUTURE_RESULT_TYPE]: - future_job = concurrent.futures.Future() - future_job.set_result(fn(*args, **kwargs)) - return future_job - - executor = MainThreadExecutor() - - return executor - - class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" @@ -660,14 +566,6 @@ class EnvironmentVariables: Path(constants.Dirs.UPLOADED_FILES) ) - REFLEX_COMPILE_EXECUTOR: EnvVar[ExecutorType | None] = env_var(None) - - # Whether to use separate processes to compile the frontend and how many. If not set, defaults to thread executor. - REFLEX_COMPILE_PROCESSES: EnvVar[int | None] = env_var(None) - - # Whether to use separate threads to compile the frontend and how many. Defaults to `min(32, os.cpu_count() + 4)`. - REFLEX_COMPILE_THREADS: EnvVar[int | None] = env_var(None) - # The directory to store reflex dependencies. REFLEX_DIR: EnvVar[Path] = env_var(constants.Reflex.DIR) diff --git a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py index 864d3f40d79..3150af45ef7 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -380,6 +380,7 @@ async def enqueue_stream_delta( self, token: str, event: Event, + on_task_future: Callable[[EventFuture], None] | None = None, ) -> AsyncGenerator[Mapping[str, Any]]: """Enqueue an event to be processed and yield deltas emitted by the event handler. @@ -397,6 +398,8 @@ async def enqueue_stream_delta( Args: token: The client token associated with the event. event: The event to be enqueued. + on_task_future: Optional callback invoked with the EventFuture for the + enqueued handler as soon as it is created. Yields: Deltas emitted by the event handler for the specified token. @@ -429,6 +432,8 @@ async def _emit_delta_impl( emit_delta_impl=_emit_delta_impl, ), ) + if on_task_future is not None: + on_task_future(task_future) all_task_futures = asyncio.create_task(task_future.wait_all()) waiting_for = {all_task_futures, asyncio.create_task(deltas.get())} try: diff --git a/packages/reflex-base/src/reflex_base/plugins/__init__.py b/packages/reflex-base/src/reflex_base/plugins/__init__.py index b4320489b08..f3ef5aa971c 100644 --- a/packages/reflex-base/src/reflex_base/plugins/__init__.py +++ b/packages/reflex-base/src/reflex_base/plugins/__init__.py @@ -3,12 +3,26 @@ from . import sitemap, tailwind_v3, tailwind_v4 from ._screenshot import ScreenshotPlugin as _ScreenshotPlugin from .base import CommonContext, Plugin, PreCompileContext +from .compiler import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, + PageDefinition, +) from .sitemap import SitemapPlugin from .tailwind_v3 import TailwindV3Plugin from .tailwind_v4 import TailwindV4Plugin __all__ = [ + "BaseContext", "CommonContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", + "PageDefinition", "Plugin", "PreCompileContext", "SitemapPlugin", diff --git a/packages/reflex-base/src/reflex_base/plugins/base.py b/packages/reflex-base/src/reflex_base/plugins/base.py index 52dfa8d7805..fdd8911a7f5 100644 --- a/packages/reflex-base/src/reflex_base/plugins/base.py +++ b/packages/reflex-base/src/reflex_base/plugins/base.py @@ -2,12 +2,14 @@ from collections.abc import Callable, Sequence from pathlib import Path -from typing import TYPE_CHECKING, ParamSpec, Protocol, TypedDict +from typing import TYPE_CHECKING, Any, ParamSpec, Protocol, TypedDict from typing_extensions import Unpack if TYPE_CHECKING: from reflex.app import App, UnevaluatedPage + from reflex_base.components.component import BaseComponent + from reflex_base.plugins.compiler import ComponentAndChildren, PageContext class CommonContext(TypedDict): @@ -117,6 +119,80 @@ def post_compile(self, **context: Unpack[PostCompileContext]) -> None: context: The context for the plugin. """ + def eval_page( + self, + page_fn: Any, + /, + **kwargs: Any, + ) -> "PageContext | None": + """Evaluate a page-like object into a page context. + + Args: + page_fn: The page-like object to evaluate. + kwargs: Additional compiler-specific context. + + Returns: + A page context when the plugin can evaluate the page, otherwise ``None``. + """ + del page_fn, kwargs + return None + + def compile_page( + self, + page_ctx: "PageContext", + /, + **kwargs: Any, + ) -> None: + """Finalize a page context after its component tree has been traversed.""" + del page_ctx, kwargs + return + + def enter_component( + self, + comp: "BaseComponent", + /, + *, + page_context: "PageContext", + compile_context: Any, + in_prop_tree: bool = False, + ) -> "BaseComponent | ComponentAndChildren | None": + """Inspect or transform a component before visiting its descendants. + + Args: + comp: The component being compiled. + page_context: The active page compilation state. + compile_context: The active compile-run state. + in_prop_tree: Whether the component is being visited through a prop subtree. + + Returns: + An optional replacement component and/or structural children. + """ + return None + + def leave_component( + self, + comp: "BaseComponent", + children: tuple["BaseComponent", ...], + /, + *, + page_context: "PageContext", + compile_context: Any, + in_prop_tree: bool = False, + ) -> "BaseComponent | ComponentAndChildren | None": + """Inspect or transform a component after visiting its descendants. + + Args: + comp: The component being compiled. + children: The compiled structural children for the component. + page_context: The active page compilation state. + compile_context: The active compile-run state. + in_prop_tree: Whether the component is being visited through a prop subtree. + + Returns: + An optional replacement component and/or structural children. + """ + return None + def __repr__(self): """Return a string representation of the plugin. diff --git a/packages/reflex-base/src/reflex_base/plugins/compiler.py b/packages/reflex-base/src/reflex_base/plugins/compiler.py new file mode 100644 index 00000000000..548fc9516a1 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/plugins/compiler.py @@ -0,0 +1,808 @@ +"""Compiler plugin infrastructure: protocols, contexts, and dispatch.""" + +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import Callable, Sequence +from contextvars import ContextVar, Token +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, cast + +from typing_extensions import Self + +from reflex_base.components.component import BaseComponent, Component +from reflex_base.utils.imports import ParsedImportDict, collapse_imports, merge_imports +from reflex_base.vars import VarData + +from .base import Plugin + +if TYPE_CHECKING: + from reflex.app import App, ComponentCallable + + PageComponent: TypeAlias = Component | ComponentCallable +else: + PageComponent: TypeAlias = ( + Component + | Callable[ + [], + Component | tuple[Component, ...] | str, + ] + ) + + +class PageDefinition(Protocol): + """Protocol for page-like objects compiled by :class:`CompileContext`.""" + + @property + def route(self) -> str: + """Return the route for this page definition.""" + ... + + @property + def component(self) -> PageComponent: + """Return the component or callable for this page definition.""" + ... + + +ComponentAndChildren: TypeAlias = tuple[BaseComponent, tuple[BaseComponent, ...]] +ComponentReplacement: TypeAlias = BaseComponent | ComponentAndChildren | None +CompiledEnterHook: TypeAlias = Callable[ + [BaseComponent, bool], + ComponentReplacement, +] +CompiledLeaveHook: TypeAlias = Callable[ + [BaseComponent, tuple[BaseComponent, ...], bool], + ComponentReplacement, +] +EnterHookBinder: TypeAlias = Callable[ + ["PageContext", "CompileContext"], + CompiledEnterHook, +] +LeaveHookBinder: TypeAlias = Callable[ + ["PageContext", "CompileContext"], + CompiledLeaveHook, +] + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class CompilerHooks: + """Dispatch compiler hooks across an ordered plugin chain.""" + + plugins: tuple[Plugin, ...] = () + _eval_page_hooks: tuple[Callable[..., Any], ...] = dataclasses.field( + init=False, + repr=False, + ) + _compile_page_hooks: tuple[Callable[..., Any], ...] = dataclasses.field( + init=False, + repr=False, + ) + _enter_component_hook_binders: tuple[EnterHookBinder, ...] = dataclasses.field( + init=False, + repr=False, + ) + _leave_component_hook_binders: tuple[LeaveHookBinder, ...] = dataclasses.field( + init=False, + repr=False, + ) + _component_hooks_can_replace: bool = dataclasses.field( + init=False, + repr=False, + ) + + def __post_init__(self) -> None: + """Resolve the active compiler hook callables once.""" + object.__setattr__(self, "_eval_page_hooks", self._resolve_hooks("eval_page")) + object.__setattr__( + self, + "_compile_page_hooks", + self._resolve_hooks("compile_page"), + ) + enter_hook_binders: list[EnterHookBinder] = [] + leave_hook_binders: list[LeaveHookBinder] = [] + component_hooks_can_replace = False + + for plugin in self.plugins: + if ( + hook_impl := self._get_hook_impl(plugin, "enter_component") + ) is not None: + enter_hook_binders.append( + self._get_enter_hook_binder(plugin, hook_impl) + ) + component_hooks_can_replace = component_hooks_can_replace or bool( + getattr( + type(plugin), + "_compiler_can_replace_enter_component", + True, + ) + ) + + if ( + hook_impl := self._get_hook_impl(plugin, "leave_component") + ) is not None: + leave_hook_binders.append( + self._get_leave_hook_binder(plugin, hook_impl) + ) + component_hooks_can_replace = component_hooks_can_replace or bool( + getattr( + type(plugin), + "_compiler_can_replace_leave_component", + True, + ) + ) + + object.__setattr__( + self, + "_enter_component_hook_binders", + tuple(enter_hook_binders), + ) + object.__setattr__( + self, + "_leave_component_hook_binders", + tuple(reversed(tuple(leave_hook_binders))), + ) + object.__setattr__( + self, + "_component_hooks_can_replace", + component_hooks_can_replace, + ) + + @staticmethod + def _get_hook_impl( + plugin: Plugin, + hook_name: str, + ) -> Callable[..., Any] | None: + """Return the concrete hook implementation for a plugin, if any. + + Args: + plugin: The plugin to inspect. + hook_name: The hook attribute name. + + Returns: + The bound hook implementation, or ``None`` when the hook is inherited + unchanged from the default base implementation. + """ + plugin_impl = inspect.getattr_static(type(plugin), hook_name, None) + if plugin_impl is None: + return None + + if plugin_impl is inspect.getattr_static(Plugin, hook_name, None): + return None + + return cast(Callable[..., Any], getattr(plugin, hook_name, None)) + + def _resolve_hooks(self, hook_name: str) -> tuple[Callable[..., Any], ...]: + """Resolve concrete hook implementations for the plugin chain. + + Args: + hook_name: The hook attribute name. + + Returns: + The ordered concrete hook implementations for the hook. + """ + return tuple( + hook_impl + for plugin in self.plugins + if (hook_impl := self._get_hook_impl(plugin, hook_name)) is not None + ) + + @staticmethod + def _get_enter_hook_binder( + plugin: Plugin, + hook_impl: Callable[..., Any], + ) -> EnterHookBinder: + """Return a binder that produces a compiled enter-component hook.""" + if ( + binder := getattr(plugin, "_compiler_bind_enter_component", None) + ) is not None: + return cast(EnterHookBinder, binder) + + def bind( + page_context: PageContext, compile_context: CompileContext + ) -> CompiledEnterHook: + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> ComponentReplacement: + return cast( + ComponentReplacement, + hook_impl( + comp, + page_context=page_context, + compile_context=compile_context, + in_prop_tree=in_prop_tree, + ), + ) + + return enter_component + + return bind + + @staticmethod + def _get_leave_hook_binder( + plugin: Plugin, + hook_impl: Callable[..., Any], + ) -> LeaveHookBinder: + """Return a binder that produces a compiled leave-component hook.""" + if ( + binder := getattr(plugin, "_compiler_bind_leave_component", None) + ) is not None: + return cast(LeaveHookBinder, binder) + + def bind( + page_context: PageContext, compile_context: CompileContext + ) -> CompiledLeaveHook: + def leave_component( + comp: BaseComponent, + children: tuple[BaseComponent, ...], + in_prop_tree: bool, + ) -> ComponentReplacement: + return cast( + ComponentReplacement, + hook_impl( + comp, + children, + page_context=page_context, + compile_context=compile_context, + in_prop_tree=in_prop_tree, + ), + ) + + return leave_component + + return bind + + def eval_page( + self, + page_fn: PageComponent, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Return the first page context produced by the plugin chain.""" + for hook_impl in self._eval_page_hooks: + result = hook_impl(page_fn, page=page, **kwargs) + if result is not None: + return cast(PageContext, result) + return None + + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Run all ``compile_page`` hooks in plugin order.""" + for hook_impl in self._compile_page_hooks: + hook_impl(page_ctx, **kwargs) + + def compile_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree once while dispatching cached enter/leave hooks. + + Returns: + The compiled component root for this subtree. + """ + enter_hooks = tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._enter_component_hook_binders + ) + + if not self._component_hooks_can_replace: + leave_hooks = tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._leave_component_hook_binders + ) + + if len(enter_hooks) == 1 and not leave_hooks: + return self._compile_component_single_enter_fast_path( + comp, + enter_hook=enter_hooks[0], + in_prop_tree=in_prop_tree, + ) + + return self._compile_component_without_replacements( + comp, + enter_hooks=enter_hooks, + leave_hooks=leave_hooks, + in_prop_tree=in_prop_tree, + ) + + return self._compile_component_with_replacements( + comp, + enter_hooks=enter_hooks, + leave_hooks=tuple( + hook_binder(page_context, compile_context) + for hook_binder in self._leave_component_hook_binders + ), + in_prop_tree=in_prop_tree, + ) + + def _compile_component_without_replacements( + self, + comp: BaseComponent, + /, + *, + enter_hooks: tuple[CompiledEnterHook, ...], + leave_hooks: tuple[CompiledLeaveHook, ...], + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree when hook plans only observe state. + + Returns: + The compiled component root for this subtree. + """ + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + for hook_impl in enter_hooks: + hook_impl( + current_comp, + current_in_prop_tree, + ) + + updated_children: list[BaseComponent] | None = None + children = current_comp.children + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is not None: + current_comp.children = updated_children + + if isinstance(current_comp, Component): + for prop_component in current_comp._get_components_in_props(): + visit( + prop_component, + True, + ) + + if leave_hooks: + compiled_children = tuple(current_comp.children) + for hook_impl in leave_hooks: + hook_impl( + current_comp, + compiled_children, + current_in_prop_tree, + ) + + return current_comp + + return visit( + comp, + in_prop_tree, + ) + + def _compile_component_single_enter_fast_path( + self, + comp: BaseComponent, + /, + *, + enter_hook: CompiledEnterHook, + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree for the common one-enter-hook fast path. + + Returns: + The compiled component root for this subtree. + """ + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + enter_hook( + current_comp, + current_in_prop_tree, + ) + + updated_children: list[BaseComponent] | None = None + children = current_comp.children + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is not None: + current_comp.children = updated_children + + if isinstance(current_comp, Component): + for prop_component in current_comp._get_components_in_props(): + visit( + prop_component, + True, + ) + + return current_comp + + return visit( + comp, + in_prop_tree, + ) + + def _compile_component_with_replacements( + self, + comp: BaseComponent, + /, + *, + enter_hooks: tuple[CompiledEnterHook, ...], + leave_hooks: tuple[CompiledLeaveHook, ...], + in_prop_tree: bool = False, + ) -> BaseComponent: + """Walk a component tree while honoring hook replacements. + + Returns: + The compiled component root for this subtree. + """ + apply_replacement = self._apply_replacement + + def visit_children( + children: Sequence[BaseComponent], + current_in_prop_tree: bool, + ) -> tuple[BaseComponent, ...]: + if not children: + return () + + updated_children: list[BaseComponent] | None = None + for index, child in enumerate(children): + compiled_child = visit( + child, + current_in_prop_tree, + ) + if updated_children is None: + if compiled_child is child: + continue + updated_children = list(children[:index]) + updated_children.append(compiled_child) + if updated_children is None: + return children if isinstance(children, tuple) else tuple(children) + return tuple(updated_children) + + def visit( + current_comp: BaseComponent, + current_in_prop_tree: bool, + ) -> BaseComponent: + compiled_component = current_comp + structural_children: tuple[BaseComponent, ...] | None = None + + for hook_impl in enter_hooks: + compiled_component, structural_children = apply_replacement( + compiled_component, + structural_children, + hook_impl( + compiled_component, + current_in_prop_tree, + ), + ) + + if structural_children is None: + structural_children = tuple(compiled_component.children) + compiled_children = visit_children( + structural_children, + current_in_prop_tree, + ) + if isinstance(compiled_component, Component): + for prop_component in compiled_component._get_components_in_props(): + visit( + prop_component, + True, + ) + + for hook_impl in leave_hooks: + compiled_component, replacement_children = apply_replacement( + compiled_component, + compiled_children, + hook_impl( + compiled_component, + compiled_children, + current_in_prop_tree, + ), + ) + if replacement_children is not compiled_children: + assert replacement_children is not None + compiled_children = visit_children( + replacement_children, + current_in_prop_tree, + ) + + compiled_component.children = list(compiled_children) + return compiled_component + + return visit( + comp, + in_prop_tree, + ) + + @staticmethod + def _apply_replacement( + comp: BaseComponent, + children: tuple[BaseComponent, ...] | None, + replacement: ComponentReplacement, + ) -> tuple[BaseComponent, tuple[BaseComponent, ...] | None]: + """Apply a plugin replacement to the current component state. + + Args: + comp: The current component. + children: The current structural children. + replacement: The plugin-supplied replacement. + + Returns: + The updated component and structural children pair. + """ + if replacement is None: + return comp, children + if isinstance(replacement, tuple): + return replacement + return replacement, children + + +@dataclasses.dataclass(kw_only=True) +class BaseContext: + """Context manager that exposes itself through a class-local context var.""" + + __context_var__: ClassVar[ContextVar[Self | None]] + + _attached_context_token: Token[Self | None] | None = dataclasses.field( + default=None, + init=False, + repr=False, + ) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + """Initialize a dedicated context variable for each subclass.""" + super().__init_subclass__(**kwargs) + cls.__context_var__ = ContextVar(cls.__name__, default=None) + + @classmethod + def get(cls) -> Self: + """Return the active context instance for the current task. + + Returns: + The active context instance for the current task. + """ + context = cls.__context_var__.get() + if context is None: + msg = f"No active {cls.__name__} is attached to the current context." + raise RuntimeError(msg) + return context + + def __enter__(self) -> Self: + """Attach this context to the current task. + + Returns: + The attached context instance. + """ + if self._attached_context_token is not None: + msg = "Context is already attached and cannot be entered twice." + raise RuntimeError(msg) + self._attached_context_token = type(self).__context_var__.set(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task.""" + del exc_type, exc_val, exc_tb + if self._attached_context_token is None: + return + try: + type(self).__context_var__.reset(self._attached_context_token) + finally: + self._attached_context_token = None + + async def __aenter__(self) -> Self: + """Attach this context to the current task asynchronously. + + Returns: + The attached context instance. + """ + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task asynchronously.""" + self.__exit__(exc_type, exc_val, exc_tb) + + def ensure_context_attached(self) -> None: + """Ensure this instance is the active context for the current task.""" + try: + current = type(self).get() + except RuntimeError as err: + msg = ( + f"{type(self).__name__} must be entered with 'with' or 'async with' " + "before calling this method." + ) + raise RuntimeError(msg) from err + if current is not self: + msg = f"{type(self).__name__} is not attached to the current task context." + raise RuntimeError(msg) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class PageContext(BaseContext): + """Mutable compilation state for a single page.""" + + name: str + route: str + root_component: BaseComponent + imports: list[ParsedImportDict] = dataclasses.field(default_factory=list) + module_code: dict[str, None] = dataclasses.field(default_factory=dict) + hooks: dict[str, VarData | None] = dataclasses.field(default_factory=dict) + dynamic_imports: set[str] = dataclasses.field(default_factory=set) + refs: dict[str, None] = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + frontend_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + output_path: str | None = None + output_code: str | None = None + + def merged_imports(self, *, collapse: bool = False) -> ParsedImportDict: + """Return the imports accumulated for this page. + + Args: + collapse: Whether to collapse duplicate imports. + + Returns: + The merged page imports. + """ + imports = merge_imports(*self.imports) if self.imports else {} + return collapse_imports(imports) if collapse else imports + + def custom_code_dict(self) -> dict[str, None]: + """Return custom-code snippets keyed like legacy collectors. + + Returns: + The page custom code keyed by snippet. + """ + return dict(self.module_code) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class CompileContext(BaseContext): + """Mutable compilation state for an entire compile run.""" + + app: App | None = None + pages: Sequence[PageDefinition] + hooks: CompilerHooks = dataclasses.field(default_factory=CompilerHooks) + compiled_pages: dict[str, PageContext] = dataclasses.field(default_factory=dict) + all_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + stateful_routes: dict[str, None] = dataclasses.field(default_factory=dict) + # Auto-memoize wrapper tags seen during the tree walk (populated by + # ``MemoizeStatefulPlugin``). + memoize_wrappers: dict[str, None] = dataclasses.field(default_factory=dict) + # Compiler-generated experimental memo definitions for auto-memoized + # stateful wrappers. Stored as ``Any`` to keep ``reflex_base`` decoupled + # from ``reflex.experimental.memo``. + auto_memo_components: dict[str, Any] = dataclasses.field(default_factory=dict) + + def compile( + self, + *, + evaluate_progress: Callable[[], None] | None = None, + render_progress: Callable[[], None] | None = None, + **kwargs: Any, + ) -> dict[str, PageContext]: + """Compile all configured pages through the plugin pipeline. + + Args: + evaluate_progress: Callback invoked after each page evaluation. + render_progress: Callback invoked after each page render. + kwargs: Additional compiler-specific context. + + Returns: + The compiled page contexts keyed by route. + """ + from reflex.compiler import compiler + from reflex.state import all_base_state_classes + + self.ensure_context_attached() + self.compiled_pages.clear() + self.all_imports.clear() + self.app_wrap_components.clear() + self.stateful_routes.clear() + self.memoize_wrappers.clear() + self.auto_memo_components.clear() + + for page in self.pages: + page_fn = page.component + n_states_before = len(all_base_state_classes) + page_ctx = self.hooks.eval_page( + page_fn, + page=page, + compile_context=self, + **kwargs, + ) + if page_ctx is None: + page_name = getattr(page_fn, "__name__", repr(page_fn)) + msg = ( + f"No compiler plugin was able to evaluate page {page.route!r} " + f"({page_name})." + ) + raise RuntimeError(msg) + if page_ctx.route in self.compiled_pages: + msg = f"Duplicate compiled page route {page_ctx.route!r}." + raise RuntimeError(msg) + + if len(all_base_state_classes) > n_states_before: + self.stateful_routes[page.route] = None + + self.compiled_pages[page_ctx.route] = page_ctx + + if evaluate_progress is not None: + evaluate_progress() + + for page, page_ctx in zip( + self.pages, + self.compiled_pages.values(), + strict=True, + ): + with page_ctx: + page_ctx.root_component = self.hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=self, + ) + self.hooks.compile_page( + page_ctx, + page=page, + compile_context=self, + **kwargs, + ) + + page_ctx.frontend_imports = page_ctx.merged_imports(collapse=True) + self.all_imports = merge_imports( + self.all_imports, page_ctx.frontend_imports + ) + self.app_wrap_components.update(page_ctx.app_wrap_components) + page_ctx.output_path, page_ctx.output_code = ( + compiler.compile_page_from_context(page_ctx) + ) + + if render_progress is not None: + render_progress() + + return self.compiled_pages + + +__all__ = [ + "BaseContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", + "PageDefinition", +] diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py index 8caa1d2b2c3..71b4d723e5e 100644 --- a/packages/reflex-base/src/reflex_base/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from reflex.state import BaseState - from reflex_base.components.component import StatefulComponent from reflex_base.event import EventHandler @@ -40,10 +39,6 @@ class RegistrationContext(BaseContext): default_factory=dict, repr=False, ) - tag_to_stateful_component: dict[str, StatefulComponent] = dataclasses.field( - default_factory=dict, - repr=False, - ) @classmethod def ensure_context(cls) -> Self: diff --git a/packages/reflex-base/src/reflex_base/utils/console.py b/packages/reflex-base/src/reflex_base/utils/console.py index b7a9539f77b..962008e42e0 100644 --- a/packages/reflex-base/src/reflex_base/utils/console.py +++ b/packages/reflex-base/src/reflex_base/utils/console.py @@ -479,6 +479,18 @@ def advance(self, task: TaskID, advance: int = 1): self.progress += advance _console.print(f"Progress: {self.progress}/{self.total}") + def update(self, task: TaskID, total: int | None = None): + """Update properties of a task. + + Args: + task: The task ID. + total: New total for the task. + """ + if total is not None and task in self.tasks: + previous_total = self.tasks[task]["total"] + self.tasks[task]["total"] = total + self.total += total - previous_total + def start(self): """Start the progress bar.""" diff --git a/packages/reflex-base/src/reflex_base/utils/streaming_response.py b/packages/reflex-base/src/reflex_base/utils/streaming_response.py index d9907379cef..66c6ab7fc62 100644 --- a/packages/reflex-base/src/reflex_base/utils/streaming_response.py +++ b/packages/reflex-base/src/reflex_base/utils/streaming_response.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import builtins import contextlib import sys @@ -60,14 +59,16 @@ def _collapse_excgroups() -> Generator[None, None, None]: class DisconnectAwareStreamingResponse(StreamingResponse): - """Streaming response that cancels its body task on disconnect.""" + """Streaming response with a guaranteed finish callback.""" _on_finish: Callable[[], Awaitable[None]] + _on_disconnect: Callable[[], None] | None def __init__( self, *args: Any, on_finish: Callable[[], Awaitable[None]], + on_disconnect: Callable[[], None] | None = None, **kwargs: Any, ) -> None: """Initialize the response. @@ -75,17 +76,17 @@ def __init__( Args: args: Positional args forwarded to ``StreamingResponse``. on_finish: Cleanup callback to run exactly once when the response ends. + on_disconnect: Sync callback invoked when the client disconnects. kwargs: Keyword args forwarded to ``StreamingResponse``. """ super().__init__(*args, **kwargs) self._on_finish = on_finish + self._on_disconnect = on_disconnect - async def _watch_disconnect(self, receive: Receive) -> None: - """Wait for the client connection to close.""" - while True: - message = await receive() - if message["type"] == "http.disconnect": - return + def _notify_disconnect(self) -> None: + """Invoke the on_disconnect callback if one was provided.""" + if self._on_disconnect is not None: + self._on_disconnect() async def _close_body_iterator(self) -> None: """Close the body iterator if it supports ``aclose``.""" @@ -94,7 +95,7 @@ async def _close_body_iterator(self) -> None: await aclose() async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Serve the response and cancel the body task on disconnect.""" + """Serve the response and always run the finish callback.""" spec_version = _parse_asgi_spec_version(scope) try: @@ -107,47 +108,24 @@ async def wrap(func: Callable[[], Awaitable[None]]) -> None: task_group.cancel_scope.cancel() task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) - else: - # Verified against Starlette 0.52.1: the ASGI >= 2.4 path in - # StreamingResponse.__call__ delegates straight to - # stream_response(send) and does not read from receive(). - # Keep calling stream_response(send) directly here so the - # disconnect watcher remains the only receive() consumer; if - # Starlette changes that contract, re-check this logic. - stream_task = asyncio.create_task(self.stream_response(send)) - disconnect_task = asyncio.create_task(self._watch_disconnect(receive)) - should_close_body_iterator = False + if self._on_disconnect is not None: + + async def _disconnect_then_notify() -> None: + await self.listen_for_disconnect(receive) + self._notify_disconnect() + + await wrap(_disconnect_then_notify) + else: + await wrap(partial(self.listen_for_disconnect, receive)) + else: try: - done, _ = await asyncio.wait( - {stream_task, disconnect_task}, - return_when=asyncio.FIRST_COMPLETED, - ) - if disconnect_task in done and not stream_task.done(): - should_close_body_iterator = True - stream_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stream_task - else: - try: - await stream_task - except OSError as err: - should_close_body_iterator = True - raise ClientDisconnect from err - finally: - if not disconnect_task.done(): - disconnect_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await disconnect_task - if not stream_task.done(): - should_close_body_iterator = True - stream_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stream_task - if should_close_body_iterator: - await self._close_body_iterator() + await self.stream_response(send) + except OSError as err: + self._notify_disconnect() + raise ClientDisconnect from err finally: + await self._close_body_iterator() await self._on_finish() if self.background is not None: diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 0674e071fbd..9179bb01f84 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -23,6 +23,7 @@ from typing_extensions import Self if TYPE_CHECKING: + from reflex_base.event.processor import EventFuture from reflex_base.utils.types import ASGIApp, Receive, Scope, Send from reflex.app import App @@ -495,20 +496,50 @@ def _create_upload_event() -> Event: msg = "Upload event was not created." raise RuntimeError(msg) + task_future: EventFuture | None = None + disconnect_seen = False + + def _try_cancel() -> None: + """Cancel the task future if it exists and is still running.""" + if task_future is not None and not task_future.done(): + task_future.cancel() + + def _remember_task_future(future: EventFuture) -> None: + """Keep a handle to the upload task for disconnect cancellation.""" + nonlocal task_future + task_future = future + if disconnect_seen: + _try_cancel() + + def _cancel_upload_task() -> None: + """Cancel the queued upload handler when the client disconnects.""" + nonlocal disconnect_seen + disconnect_seen = True + _try_cancel() + async def _ndjson_updates(): """Process the upload event, generating ndjson updates. Yields: Each state update as newline-delimited JSON. """ + # Let the disconnect watcher run before we enqueue the upload handler. + await asyncio.sleep(0) + if disconnect_seen: + return # Enqueue the task on the main event loop, but emit deltas to the local queue. - async for delta in app.event_processor.enqueue_stream_delta(token, event): + async for delta in app.event_processor.enqueue_stream_delta( + token, + event, + on_task_future=_remember_task_future, + ): yield json_dumps(StateUpdate(delta=delta)) + "\n" return DisconnectAwareStreamingResponse( _ndjson_updates(), media_type="application/x-ndjson", on_finish=_close_form_data, + on_disconnect=_cancel_upload_task, ) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/upload.py b/packages/reflex-components-core/src/reflex_components_core/core/upload.py index 84b3ef06f06..ab020eb9884 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/upload.py @@ -10,9 +10,9 @@ Component, ComponentNamespace, MemoizationLeaf, - StatefulComponent, field, ) +from reflex_base.components.memoize_helpers import get_memoized_event_triggers from reflex_base.constants import Dirs from reflex_base.constants.compiler import Hooks, Imports from reflex_base.environment import environment @@ -357,7 +357,7 @@ def create(cls, *children, **props) -> Component: ), ) - event_triggers = StatefulComponent._get_memoized_event_triggers( + event_triggers = get_memoized_event_triggers( GhostUpload.create( on_drop=upload_props["on_drop"], on_drop_rejected=upload_props["on_drop_rejected"], diff --git a/packages/reflex-components-core/src/reflex_components_core/core/window_events.py b/packages/reflex-components-core/src/reflex_components_core/core/window_events.py index debb4c3dc37..10a7362188d 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/window_events.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/window_events.py @@ -4,7 +4,7 @@ from typing import Any, cast -from reflex_base.components.component import StatefulComponent, field +from reflex_base.components.component import field from reflex_base.constants.compiler import Hooks from reflex_base.event import EventHandler, key_event, no_args_event_spec from reflex_base.vars.base import Var, VarData @@ -95,8 +95,10 @@ def create(cls, **props) -> WindowEventListener: Returns: The created component. """ + from reflex_base.components.memoize_helpers import fix_event_triggers_for_memo + real_component = cast("WindowEventListener", super().create(**props)) - hooks = StatefulComponent._fix_event_triggers(real_component) + hooks = fix_event_triggers_for_memo(real_component) real_component.hooks = hooks return real_component diff --git a/pyi_hashes.json b/pyi_hashes.json index 6f93a2554e3..88eb9573581 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -20,8 +20,8 @@ "packages/reflex-components-core/src/reflex_components_core/core/helmet.pyi": "7fd81a99bde5b0ff94bb52523597fd5c", "packages/reflex-components-core/src/reflex_components_core/core/html.pyi": "753d6ae315369530dad450ed643f5be6", "packages/reflex-components-core/src/reflex_components_core/core/sticky.pyi": "ba60a7d9cba75b27a1133bd63a9fbd59", - "packages/reflex-components-core/src/reflex_components_core/core/upload.pyi": "17775edb94cc804686ae4cd873584810", - "packages/reflex-components-core/src/reflex_components_core/core/window_events.pyi": "cab827931770be082cd1598a9908abbc", + "packages/reflex-components-core/src/reflex_components_core/core/upload.pyi": "ac589d6237fe51414d536b9d70de5dec", + "packages/reflex-components-core/src/reflex_components_core/core/window_events.pyi": "5e1dcb1130bc8af282783fae329ae6a6", "packages/reflex-components-core/src/reflex_components_core/datadisplay/__init__.pyi": "c96fed4da42a13576d64f84e3c7cb25c", "packages/reflex-components-core/src/reflex_components_core/el/__init__.pyi": "f09129ddefb57ab4c7769c86dc9a3153", "packages/reflex-components-core/src/reflex_components_core/el/element.pyi": "ff68d843c5987d3f0d773a6367eb9c63", @@ -120,5 +120,5 @@ "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7", "reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", - "reflex/experimental/memo.pyi": "2c119a0dfea362dcd8193786363cbc02" + "reflex/experimental/memo.pyi": "792f2ffe75f3acce94af31bd8458a061" } diff --git a/reflex/app.py b/reflex/app.py index 389d18f246a..cc6bef7b99d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -import concurrent.futures import contextlib import copy import dataclasses @@ -15,23 +14,21 @@ import time import traceback import urllib.parse -from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence -from datetime import datetime -from itertools import chain -from pathlib import Path -from timeit import default_timer as timer +from collections.abc import ( + AsyncIterator, + Callable, + Collection, + Coroutine, + Mapping, + Sequence, +) from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ParamSpec, overload +from typing import TYPE_CHECKING, Any, overload from reflex_base import constants -from reflex_base.components.component import ( - CUSTOM_COMPONENTS, - Component, - ComponentStyle, - evaluate_style_namespaces, -) +from reflex_base.components.component import Component, ComponentStyle from reflex_base.config import get_config -from reflex_base.environment import ExecutorType, environment +from reflex_base.environment import environment from reflex_base.event import ( _EVENT_FIELDS, Event, @@ -45,10 +42,8 @@ from reflex_base.utils import console from reflex_base.utils.imports import ImportVar from reflex_base.utils.types import ASGIApp, Message, Receive, Scope, Send -from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.error_boundary import ErrorBoundary from reflex_components_core.base.fragment import Fragment -from reflex_components_core.base.strict_mode import StrictMode from reflex_components_core.core.banner import ( backend_disabled, connection_pulser, @@ -58,7 +53,6 @@ from reflex_components_core.core.sticky import sticky from reflex_components_radix import themes from reflex_components_sonner.toast import toast -from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp from socketio import AsyncNamespace, AsyncServer from starlette.applications import Starlette @@ -73,13 +67,7 @@ from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.compiler import compiler -from reflex.compiler import utils as compiler_utils -from reflex.compiler.compiler import ( - ExecutorSafeFunctions, - compile_theme, - readable_name_from_component, -) -from reflex.experimental.memo import EXPERIMENTAL_MEMOS +from reflex.compiler.compiler import readable_name_from_component from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES @@ -94,17 +82,8 @@ State, StateUpdate, all_base_state_classes, - code_uses_state_contexts, -) -from reflex.utils import ( - codespaces, - exceptions, - format, - frontend_skeleton, - js_runtimes, - path_ops, - prerequisites, ) +from reflex.utils import codespaces, exceptions, format, js_runtimes, prerequisites from reflex.utils.exec import ( get_compile_context, is_prod_mode, @@ -204,10 +183,10 @@ def extra_overlay_function() -> Component | None: def default_overlay_component() -> Component: - """Default overlay_component attribute for App. + """Default overlay component included in the app wraps. Returns: - The default overlay_component, which is a connection_modal. + The default overlay component, which is a connection banner/toaster set. """ from reflex_base.components.component import memo @@ -251,12 +230,12 @@ class UnevaluatedPage: component: Component | ComponentCallable route: str - title: Var | str | None - description: Var | str | None - image: str - on_load: EventType[()] | None - meta: Sequence[Mapping[str, Any] | Component] - context: Mapping[str, Any] + title: Var | str | None = None + description: Var | str | None = None + image: str = "" + on_load: EventType[()] | None = None + meta: Sequence[Mapping[str, Any] | Component] = () + context: Mapping[str, Any] = dataclasses.field(default_factory=dict) def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage: """Merge the other page into this one. @@ -278,9 +257,6 @@ def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage: ) -P = ParamSpec("P") - - @dataclasses.dataclass() class App(MiddlewareMixin, LifespanMixin): """The main Reflex app that encapsulates the backend and frontend. @@ -307,7 +283,6 @@ class App(MiddlewareMixin, LifespanMixin): style: The [global style](https://reflex.dev/docs/styling/overview/#global-styles}) for the app. stylesheets: A list of URLs to [stylesheets](https://reflex.dev/docs/styling/custom-stylesheets/) to include in the app. reset_style: Whether to include CSS reset for margin and padding. Defaults to True. - overlay_component: A component that is present on every page. Defaults to the Connection Error banner. app_wraps: App wraps to be applied to the whole app. Expected to be a dictionary of (order, name) to a function that takes whether the state is enabled and optionally returns a component. extra_app_wraps: Extra app wraps to be applied to the whole app. head_components: Components to add to the head of every page. @@ -885,7 +860,10 @@ def _compile_page(self, route: str, save_page: bool = True): """ n_states_before = len(all_base_state_classes) component = compiler.compile_unevaluated_page( - route, self._unevaluated_pages[route], self.style, self.theme + route, + self._unevaluated_pages[route], + self.style, + self.theme, ) # Indicate that evaluating this page creates one or more state classes. @@ -997,7 +975,10 @@ def _setup_admin_dash(self): admin.mount_to(self._api) - def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]): + def _get_frontend_packages( + self, + imports: Mapping[str, Collection[ImportVar]], + ) -> None: """Gets the frontend packages to be installed and filters out the unnecessary ones. Args: @@ -1071,16 +1052,6 @@ def _should_compile(self) -> bool: # By default, compile the app. return True - def _add_overlay_to_component( - self, component: Component, overlay_component: Component - ) -> Component: - children = component.children - - if children[0] == overlay_component: - return component - - return Fragment.create(overlay_component, *children) - def _setup_sticky_badge(self): """Add the sticky badge to the app.""" from reflex_base.components.component import memo @@ -1149,390 +1120,13 @@ def _compile( ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined. FileNotFoundError: When a plugin requires a file that does not exist. """ - from reflex_base.utils.exceptions import ReflexRuntimeError - - self._apply_decorated_pages() - - self._pages = {} - - def get_compilation_time() -> str: - return str(datetime.now().time()).split(".")[0] - - should_compile = self._should_compile() - backend_dir = prerequisites.get_backend_dir() - if not dry_run and not should_compile and backend_dir.exists(): - stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES - if stateful_pages_marker.exists(): - with stateful_pages_marker.open("r") as f: - stateful_pages = json.load(f) - for route in stateful_pages: - console.debug(f"BE Evaluating stateful page: {route}") - self._compile_page(route, save_page=False) - self._add_optional_endpoints() - return - - # Render a default 404 page if the user didn't supply one - if constants.Page404.SLUG not in self._unevaluated_pages: - self.add_page(route=constants.Page404.SLUG) - - # Fix up the style. - self.style = evaluate_style_namespaces(self.style) - - # Add the app wrappers. - app_wrappers: dict[tuple[int, str], Component] = { - # Default app wrap component renders {children} - (0, "AppWrap"): AppWrap.create() - } - - if self.theme is not None: - # If a theme component was provided, wrap the app with it - app_wrappers[20, "Theme"] = self.theme - - # Get the env mode. - config = get_config() - - if config.react_strict_mode: - app_wrappers[200, "StrictMode"] = StrictMode.create() - - if not should_compile and not dry_run: - with console.timing("Evaluate Pages (Backend)"): - for route in self._unevaluated_pages: - console.debug(f"Evaluating page: {route}") - self._compile_page(route, save_page=should_compile) - - # Save the pages which created new states at eval time. - self._write_stateful_pages_marker() - - # Add the optional endpoints (_upload) - self._add_optional_endpoints() - - return - - # Create a progress bar. - progress = ( - Progress( - *Progress.get_default_columns()[:-1], - MofNCompleteColumn(), - TimeElapsedColumn(), - ) - if use_rich - else console.PoorProgress() - ) - - # try to be somewhat accurate - but still not 100% - adhoc_steps_without_executor = 7 - fixed_pages_within_executor = 4 - plugin_count = len(config.plugins) - progress.start() - task = progress.add_task( - f"[{get_compilation_time()}] Compiling:", - total=len(self._unevaluated_pages) - + ((len(self._unevaluated_pages) + len(self._pages)) * 3) - + fixed_pages_within_executor - + adhoc_steps_without_executor - + plugin_count, - ) - - with console.timing("Evaluate Pages (Frontend)"): - performance_metrics: list[tuple[str, float]] = [] - for route in self._unevaluated_pages: - console.debug(f"Evaluating page: {route}") - start = timer() - self._compile_page(route, save_page=should_compile) - end = timer() - performance_metrics.append((route, end - start)) - progress.advance(task) - console.debug( - "Slowest pages:\n" - + "\n".join( - f"{route}: {time * 1000:.1f}ms" - for route, time in sorted( - performance_metrics, key=operator.itemgetter(1), reverse=True - )[:10] - ) - ) - # Save the pages which created new states at eval time. - self._write_stateful_pages_marker() - - # Add the optional endpoints (_upload) - self._add_optional_endpoints() - - self._validate_var_dependencies() - - if config.show_built_with_reflex is None: - if ( - get_compile_context() == constants.CompileContext.DEPLOY - and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] - ): - config.show_built_with_reflex = False - else: - config.show_built_with_reflex = True - - if is_prod_mode() and config.show_built_with_reflex: - self._setup_sticky_badge() - - progress.advance(task) - - # Store the compile results. - compile_results: list[tuple[str, str]] = [] - - progress.advance(task) - - # Track imports found. - all_imports = {} - - if (toaster := self.toaster) is not None: - from reflex_base.components.component import memo - - @memo - def memoized_toast_provider(): - return toaster - - toast_provider = Fragment.create(memoized_toast_provider()) - - app_wrappers[44, "ToasterProvider"] = toast_provider - - # Add the app wraps to the app. - for key, app_wrap in chain( - self.app_wraps.items(), self.extra_app_wraps.items() - ): - # If the app wrap is a callable, generate the component - component = app_wrap(self._state is not None) - if component is not None: - app_wrappers[key] = component - - # Compile custom components. - ( - memo_components_output, - memo_components_result, - memo_components_imports, - ) = compiler.compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), - tuple(EXPERIMENTAL_MEMOS.values()), - ) - compile_results.append((memo_components_output, memo_components_result)) - all_imports.update(memo_components_imports) - progress.advance(task) - - with console.timing("Collect all imports and app wraps"): - # This has to happen before compiling stateful components as that - # prevents recursive functions from reaching all components. - for component in self._pages.values(): - # Add component._get_all_imports() to all_imports. - all_imports.update(component._get_all_imports()) - - # Add the app wrappers from this component. - app_wrappers.update(component._get_all_app_wrap_components()) - - progress.advance(task) - - # Perform auto-memoization of stateful components. - with console.timing("Auto-memoize StatefulComponents"): - ( - stateful_components_path, - stateful_components_code, - page_components, - ) = compiler.compile_stateful_components( - self._pages.values(), - progress_function=lambda task=task: progress.advance(task), - ) - progress.advance(task) - - # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State. - if code_uses_state_contexts(stateful_components_code) and self._state is None: - msg = ( - "To access rx.State in frontend components, at least one " - "subclass of rx.State must be defined in the app." - ) - raise ReflexRuntimeError(msg) - compile_results.append((stateful_components_path, stateful_components_code)) - - progress.advance(task) - - # Compile the root document before fork. - compile_results.append( - compiler.compile_document_root( - self.head_components, - html_lang=self.html_lang, - html_custom_attrs=( - {"suppressHydrationWarning": True, **self.html_custom_attrs} - if self.html_custom_attrs - else {"suppressHydrationWarning": True} - ), - ) - ) - - progress.advance(task) - - # Copy the assets. - assets_src = Path.cwd() / constants.Dirs.APP_ASSETS - if assets_src.is_dir() and not dry_run: - with console.timing("Copy assets"): - path_ops.update_directory_tree( - src=assets_src, - dest=( - Path.cwd() / prerequisites.get_web_dir() / constants.Dirs.PUBLIC - ), - ) - - executor = ExecutorType.get_executor_from_environment() - - for route, component in zip(self._pages, page_components, strict=True): - ExecutorSafeFunctions.COMPONENTS[route] = component - - modify_files_tasks: list[tuple[str, str, Callable[[str], str]]] = [] - - with console.timing("Compile to Javascript"), executor as executor: - result_futures: list[ - concurrent.futures.Future[ - list[tuple[str, str]] | tuple[str, str] | None - ] - ] = [] - - def _submit_work( - fn: Callable[P, list[tuple[str, str]] | tuple[str, str] | None], - *args: P.args, - **kwargs: P.kwargs, - ): - f = executor.submit(fn, *args, **kwargs) - f.add_done_callback(lambda _: progress.advance(task)) - result_futures.append(f) - - # Compile the pre-compiled pages. - for route in self._pages: - _submit_work( - ExecutorSafeFunctions.compile_page, - route, - ) - - # Compile the root stylesheet with base styles. - _submit_work( - compiler.compile_root_stylesheet, self.stylesheets, self.reset_style - ) - - # Compile the theme. - _submit_work(compile_theme, self.style) - - def _submit_work_without_advancing( - fn: Callable[P, list[tuple[str, str]] | tuple[str, str] | None], - *args: P.args, - **kwargs: P.kwargs, - ): - f = executor.submit(fn, *args, **kwargs) - result_futures.append(f) - - for plugin in config.plugins: - plugin.pre_compile( - add_save_task=_submit_work_without_advancing, - add_modify_task=( - lambda *args, plugin=plugin: modify_files_tasks.append(( - plugin.__class__.__module__ + plugin.__class__.__name__, - *args, - )) - ), - unevaluated_pages=list(self._unevaluated_pages.values()), - ) - - # Wait for all compilation tasks to complete. - for future in concurrent.futures.as_completed(result_futures): - if (result := future.result()) is not None: - if isinstance(result, list): - compile_results.extend(result) - else: - compile_results.append(result) - - progress.advance(task, advance=len(config.plugins)) - - app_root = self._app_root(app_wrappers=app_wrappers) - - # Get imports from AppWrap components. - all_imports.update(app_root._get_all_imports()) - - progress.advance(task) - - # Compile the contexts. - compile_results.append( - compiler.compile_contexts(self._state, self.theme), - ) - if self.theme is not None: - # Fix #2992 by removing the top-level appearance prop - self.theme.appearance = None # pyright: ignore[reportAttributeAccessIssue] - progress.advance(task) - - # Compile the app root. - compile_results.append( - compiler.compile_app(app_root), - ) - progress.advance(task) - - progress.stop() - - if dry_run: - return - - # Install frontend packages. - with console.timing("Install Frontend Packages"): - self._get_frontend_packages(all_imports) - - # Setup the react-router.config.js - frontend_skeleton.update_react_router_config( + compiler.compile_app( + self, prerender_routes=prerender_routes, + dry_run=dry_run, + use_rich=use_rich, ) - if is_prod_mode(): - # Empty the .web pages directory. - compiler.purge_web_pages_dir() - else: - # In dev mode, delete removed pages and update existing pages. - keep_files = [Path(output_path) for output_path, _ in compile_results] - for p in Path( - prerequisites.get_web_dir() - / constants.Dirs.PAGES - / constants.Dirs.ROUTES - ).rglob("*"): - if p.is_file() and p not in keep_files: - # Remove pages that are no longer in the app. - p.unlink() - - output_mapping: dict[Path, str] = {} - for output_path, code in compile_results: - path = compiler_utils.resolve_path_of_web_dir(output_path) - if path in output_mapping: - console.warn( - f"Path {path} has two different outputs. The first one will be used." - ) - else: - output_mapping[path] = code - - for plugin in config.plugins: - for static_file_path, content in plugin.get_static_assets(): - path = compiler_utils.resolve_path_of_web_dir(static_file_path) - if path in output_mapping: - console.warn( - f"Plugin {plugin.__class__.__name__} is trying to write to {path} but it already exists. The plugin file will be ignored." - ) - else: - output_mapping[path] = ( - content.decode("utf-8") - if isinstance(content, bytes) - else content - ) - - for plugin_name, file_path, modify_fn in modify_files_tasks: - path = compiler_utils.resolve_path_of_web_dir(file_path) - file_content = output_mapping.get(path) - if file_content is None: - if path.exists(): - file_content = path.read_text() - else: - msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." - raise FileNotFoundError(msg) - output_mapping[path] = modify_fn(file_content) - - with console.timing("Write to Disk"): - for output_path, code in output_mapping.items(): - compiler_utils.write_file(output_path, code) - def _write_stateful_pages_marker(self): """Write list of routes that create dynamic states for the backend to use later.""" if self._state is not None: diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index feec49ee69a..de941cd4a0e 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import sys from collections.abc import Callable, Iterable, Sequence from inspect import getmodule @@ -10,35 +11,50 @@ from reflex_base import constants from reflex_base.components.component import ( + CUSTOM_COMPONENTS, BaseComponent, Component, ComponentStyle, CustomComponent, - StatefulComponent, + evaluate_style_namespaces, ) from reflex_base.config import get_config from reflex_base.constants.compiler import PageNames, ResetStylesheet from reflex_base.constants.state import FIELD_MARKER from reflex_base.environment import environment +from reflex_base.plugins import CompileContext, CompilerHooks, PageContext from reflex_base.style import SYSTEM_COLOR_MODE from reflex_base.utils.exceptions import ReflexError from reflex_base.utils.format import to_title_case -from reflex_base.utils.imports import ImportVar, ParsedImportDict +from reflex_base.utils.imports import ImportVar from reflex_base.vars.base import LiteralVar, Var +from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.fragment import Fragment +from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from reflex.compiler import templates, utils +from reflex.compiler.plugins import default_page_plugins from reflex.experimental.memo import ( + EXPERIMENTAL_MEMOS, ExperimentalMemoComponentDefinition, ExperimentalMemoDefinition, ExperimentalMemoFunctionDefinition, ) -from reflex.state import BaseState -from reflex.utils import console, path_ops -from reflex.utils.exec import is_prod_mode +from reflex.state import BaseState, code_uses_state_contexts +from reflex.utils import console, frontend_skeleton, path_ops, prerequisites +from reflex.utils.exec import get_compile_context, is_prod_mode from reflex.utils.prerequisites import get_web_dir +def _set_progress_total( + progress: Progress | console.PoorProgress, + task: Any, + total: int, +) -> None: + """Update a task total for either rich or fallback progress bars.""" + progress.update(task, total=total) + + def _apply_common_imports( imports: dict[str, list[ImportVar]], ): @@ -331,7 +347,7 @@ def _compile_root_stylesheet(stylesheets: list[str], reset_style: bool = True) - return templates.styles_template(stylesheets=sheets) -def _compile_component(component: Component | StatefulComponent) -> str: +def _compile_component(component: Component) -> str: """Compile a single component. Args: @@ -412,85 +428,6 @@ def _compile_memo_components( ) -def _get_shared_components_recursive( - component: BaseComponent, - rendered_components: dict[str, None], - all_import_dicts: list[ParsedImportDict], -): - """Get the shared components for a component and its children. - - A shared component is a StatefulComponent that appears in 2 or more - pages and is a candidate for writing to a common file and importing - into each page where it is used. - - Args: - component: The component to collect shared StatefulComponents for. - rendered_components: A dict to store the rendered shared components in. - all_import_dicts: A list to store the imports of all shared components in. - """ - for child in component.children: - # Depth-first traversal. - _get_shared_components_recursive(child, rendered_components, all_import_dicts) - - # When the component is referenced by more than one page, render it - # to be included in the STATEFUL_COMPONENTS module. - # Skip this step in dev mode, thereby avoiding potential hot reload errors for larger apps - if isinstance(component, StatefulComponent) and component.references > 1: - # Reset this flag to render the actual component. - component.rendered_as_shared = False - - # Include dynamic imports in the shared component. - if dynamic_imports := component._get_all_dynamic_imports(): - rendered_components.update(dict.fromkeys(dynamic_imports)) - - # Include custom code in the shared component. - rendered_components.update(component._get_all_custom_code(export=True)) - - # Include all imports in the shared component. - all_import_dicts.append(component._get_all_imports()) - - # Indicate that this component now imports from the shared file. - component.rendered_as_shared = True - - -def _compile_stateful_components( - page_components: list[BaseComponent], -) -> str: - """Walk the page components and extract shared stateful components. - - Any StatefulComponent that is shared by more than one page will be rendered - to a separate module and marked rendered_as_shared so subsequent - renderings will import the component from the shared module instead of - directly including the code for it. - - Args: - page_components: The Components or StatefulComponents to compile. - - Returns: - The rendered stateful components code. - """ - all_import_dicts = [] - rendered_components = {} - - for page_component in page_components: - _get_shared_components_recursive( - page_component, rendered_components, all_import_dicts - ) - - # Don't import from the file that we're about to create. - all_imports = utils.merge_imports(*all_import_dicts) - all_imports.pop( - f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None - ) - if rendered_components: - _apply_common_imports(all_imports) - - return templates.stateful_components_template( - imports=utils.compile_imports(all_imports), - memoized_code="\n".join(rendered_components), - ) - - def compile_document_root( head_components: list[Component], html_lang: str | None = None, @@ -521,7 +458,7 @@ def compile_document_root( return output_path, code -def compile_app(app_root: Component) -> tuple[str, str]: +def compile_app_root(app_root: Component) -> tuple[str, str]: """Compile the app root. Args: @@ -596,6 +533,34 @@ def compile_page(path: str, component: BaseComponent) -> tuple[str, str]: return output_path, code +def compile_page_from_context(page_ctx: PageContext) -> tuple[str, str]: + """Compile a single page from a collected page context. + + Args: + page_ctx: The collected page context to render. + + Returns: + The path and code of the compiled page. + """ + output_path = utils.get_page_path(page_ctx.route) + imports = { + lib: list(fields) + for lib, fields in ( + page_ctx.frontend_imports or page_ctx.merged_imports(collapse=True) + ).items() + } + _apply_common_imports(imports) + + code = templates.page_template( + imports=utils.compile_imports(imports), + dynamic_imports=sorted(page_ctx.dynamic_imports), + custom_codes=page_ctx.custom_code_dict(), + hooks=page_ctx.hooks, + render=page_ctx.root_component.render(), + ) + return output_path, code + + def compile_memo_components( components: Iterable[CustomComponent], experimental_memos: Iterable[ExperimentalMemoDefinition] = (), @@ -617,36 +582,6 @@ def compile_memo_components( return output_path, code, imports -def compile_stateful_components( - pages: Iterable[Component], - progress_function: Callable[[], None], -) -> tuple[str, str, list[BaseComponent]]: - """Separately compile components that depend on State vars. - - StatefulComponents are compiled as their own component functions with their own - useContext declarations, which allows page components to be stateless and avoid - re-rendering along with parts of the page that actually depend on state. - - Args: - pages: The pages to extract stateful components from. - progress_function: A function to call to indicate progress, called once per page. - - Returns: - The path and code of the compiled stateful components. - """ - output_path = utils.get_stateful_components_path() - - page_components = [] - for page in pages: - # Compile the stateful components - page_component = StatefulComponent.compile_from(page) or page - progress_function() - page_components.append(page_component) - - code = _compile_stateful_components(page_components) if is_prod_mode() else "" - return output_path, code, page_components - - def purge_web_pages_dir(): """Empty out .web/pages directory.""" if not is_prod_mode() and environment.REFLEX_PERSIST_WEB_DIR.get(): @@ -661,7 +596,7 @@ def purge_web_pages_dir(): if TYPE_CHECKING: - from reflex.app import ComponentCallable, UnevaluatedPage + from reflex.app import App, ComponentCallable, UnevaluatedPage def _into_component_once( @@ -871,82 +806,321 @@ def compile_unevaluated_page( return component -class ExecutorSafeFunctions: - """Helper class to allow parallelisation of parts of the compilation process. +def _resolve_app_wrap_components( + app: App, + page_app_wrap_components: dict[tuple[int, str], Component], +) -> dict[tuple[int, str], Component]: + """Build the full app-wrap registry for compilation. - This class (and its class attributes) are available at global scope. + Args: + app: The app being compiled. + page_app_wrap_components: App-wrap components collected from pages. - In a multiprocessing context (like when using a ProcessPoolExecutor), the content of this - global class is logically replicated to any FORKED process. + Returns: + The merged app-wrap component registry. + """ + config = get_config() - How it works: - * Before the child process is forked, ensure that we stash any input data required by any future - function call in the child process. - * After the child process is forked, the child process will have a copy of the global class, which - includes the previously stashed input data. - * Any task submitted to the child process simply needs a way to communicate which input data the - requested function call requires. + app_wrappers: dict[tuple[int, str], Component] = { + (0, "AppWrap"): AppWrap.create(), + } + app_wrappers.update(page_app_wrap_components) + + if app.theme is not None: + app_wrappers[20, "Theme"] = app.theme + + if config.react_strict_mode: + from reflex_components_core.base.strict_mode import StrictMode + + app_wrappers[200, "StrictMode"] = StrictMode.create() + + if (toaster := app.toaster) is not None: + from reflex_base.components.component import memo + + @memo + def memoized_toast_provider(): + return toaster + + app_wrappers[44, "ToasterProvider"] = Fragment.create(memoized_toast_provider()) + + for wrap_mapping in (app.app_wraps, app.extra_app_wraps): + for key, app_wrap in wrap_mapping.items(): + component = app_wrap(app._state is not None) + if component is not None: + app_wrappers[key] = component + + return app_wrappers + + +def compile_app( + app: App, + *, + prerender_routes: bool = False, + dry_run: bool = False, + use_rich: bool = True, +) -> None: + """Compile an app using the compiler plugin pipeline.""" + from reflex_base.utils.exceptions import ReflexRuntimeError + + app._apply_decorated_pages() + app._pages = {} + + should_compile = app._should_compile() + backend_dir = prerequisites.get_backend_dir() + if not dry_run and not should_compile and backend_dir.exists(): + stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES + if stateful_pages_marker.exists(): + with stateful_pages_marker.open("r") as file: + stateful_pages = json.load(file) + for route in stateful_pages: + console.debug(f"BE Evaluating stateful page: {route}") + app._compile_page(route, save_page=False) + app._add_optional_endpoints() + return - Why do we need this? Passing input data directly to child process often not possible because the input data is not picklable. - The mechanic described here removes the need to pickle the input data at all. + if constants.Page404.SLUG not in app._unevaluated_pages: + app.add_page(route=constants.Page404.SLUG) - Limitations: - * This can never support returning unpicklable OUTPUT data. - * Any object mutations done by the child process will not propagate back to the parent process (fork goes one way!). + app.style = evaluate_style_namespaces(app.style) + config = get_config() - """ + if not should_compile and not dry_run: + with console.timing("Evaluate Pages (Backend)"): + for route in app._unevaluated_pages: + console.debug(f"Evaluating page: {route}") + app._compile_page(route, save_page=False) + + app._write_stateful_pages_marker() + app._add_optional_endpoints() + return + + progress = ( + Progress( + *Progress.get_default_columns()[:-1], + MofNCompleteColumn(), + TimeElapsedColumn(), + ) + if use_rich + else console.PoorProgress() + ) + fixed_steps = 7 + base_total = (len(app._unevaluated_pages) * 2) + fixed_steps + len(config.plugins) + progress.start() + task = progress.add_task("Compiling:", total=base_total) + + compile_ctx = CompileContext( + app=app, + pages=list(app._unevaluated_pages.values()), + hooks=CompilerHooks( + plugins=default_page_plugins(style=app.style, theme=app.theme) + ), + ) + + with console.timing("Compile pages"), compile_ctx: + compile_ctx.compile( + evaluate_progress=lambda: progress.advance(task), + render_progress=lambda: progress.advance(task), + ) + + for route, page_ctx in compile_ctx.compiled_pages.items(): + app._check_routes_conflict(route) + if not isinstance(page_ctx.root_component, Component): + msg = ( + f"Compiled page {route!r} root must be a Component before it can " + "be registered on the app." + ) + raise TypeError(msg) + app._pages[route] = page_ctx.root_component + + app._stateful_pages.update(compile_ctx.stateful_routes) + app._write_stateful_pages_marker() + app._add_optional_endpoints() + app._validate_var_dependencies() + + if config.show_built_with_reflex is None: + if ( + get_compile_context() == constants.CompileContext.DEPLOY + and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] + ): + config.show_built_with_reflex = False + else: + config.show_built_with_reflex = True + + if is_prod_mode() and config.show_built_with_reflex: + app._setup_sticky_badge() + + progress.advance(task) + + compile_results = [ + (page_ctx.output_path, page_ctx.output_code) + for page_ctx in compile_ctx.compiled_pages.values() + if page_ctx.output_path is not None and page_ctx.output_code is not None + ] + all_imports = compile_ctx.all_imports + + if app._state is None and any( + code_uses_state_contexts(page_ctx.output_code or "") + for page_ctx in compile_ctx.compiled_pages.values() + ): + msg = ( + "To access rx.State in frontend components, at least one " + "subclass of rx.State must be defined in the app." + ) + raise ReflexRuntimeError(msg) + progress.advance(task) + + app_wrappers = _resolve_app_wrap_components(app, compile_ctx.app_wrap_components) + app_root = app._app_root(app_wrappers) + all_imports = utils.merge_imports(all_imports, app_root._get_all_imports()) + + ( + memo_components_output, + memo_components_result, + memo_components_imports, + ) = compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + ( + *tuple(EXPERIMENTAL_MEMOS.values()), + *tuple(compile_ctx.auto_memo_components.values()), + ), + ) + compile_results.append((memo_components_output, memo_components_result)) + all_imports = utils.merge_imports(all_imports, memo_components_imports) + progress.advance(task) + + compile_results.append( + compile_document_root( + app.head_components, + html_lang=app.html_lang, + html_custom_attrs=( + {"suppressHydrationWarning": True, **app.html_custom_attrs} + if app.html_custom_attrs + else {"suppressHydrationWarning": True} + ), + ) + ) + progress.advance(task) + + assets_src = Path.cwd() / constants.Dirs.APP_ASSETS + if assets_src.is_dir() and not dry_run: + with console.timing("Copy assets"): + path_ops.update_directory_tree( + src=assets_src, + dest=Path.cwd() / prerequisites.get_web_dir() / constants.Dirs.PUBLIC, + ) - COMPONENTS: dict[str, BaseComponent] = {} - UNCOMPILED_PAGES: dict[str, UnevaluatedPage] = {} - - @classmethod - def compile_page(cls, route: str) -> tuple[str, str]: - """Compile a page. - - Args: - route: The route of the page to compile. - - Returns: - The path and code of the compiled page. - """ - return compile_page(route, cls.COMPONENTS[route]) - - @classmethod - def compile_unevaluated_page( - cls, - route: str, - style: ComponentStyle, - theme: Component | None, - ) -> tuple[str, Component, tuple[str, str]]: - """Compile an unevaluated page. - - Args: - route: The route of the page to compile. - style: The style of the page. - theme: The theme of the page. - - Returns: - The route, compiled component, and compiled page. - """ - component = compile_unevaluated_page( - route, cls.UNCOMPILED_PAGES[route], style, theme + save_tasks: list[ + tuple[ + Callable[..., list[tuple[str, str]] | tuple[str, str] | None], + tuple[Any, ...], + dict[str, Any], + ] + ] = [] + modify_files_tasks: list[tuple[str, str, Callable[[str], str]]] = [] + + def add_save_task( + task_fn: Callable[..., list[tuple[str, str]] | tuple[str, str] | None], + /, + *args: Any, + **kwargs: Any, + ) -> None: + save_tasks.append((task_fn, args, kwargs)) + + for plugin in config.plugins: + plugin.pre_compile( + add_save_task=add_save_task, + add_modify_task=lambda *args, plugin=plugin: modify_files_tasks.append(( + plugin.__class__.__module__ + plugin.__class__.__name__, + *args, + )), + unevaluated_pages=list(app._unevaluated_pages.values()), ) - return route, component, compile_page(route, component) - @classmethod - def compile_theme(cls, style: ComponentStyle | None) -> tuple[str, str]: - """Compile the theme. + if save_tasks: + _set_progress_total(progress, task, base_total + len(save_tasks)) + + progress.advance(task, advance=len(config.plugins)) + + compile_results.append(compile_root_stylesheet(app.stylesheets, app.reset_style)) + progress.advance(task) + + compile_results.append(compile_theme(app.style)) + progress.advance(task) + + for task_fn, args, kwargs in save_tasks: + result = task_fn(*args, **kwargs) + if result is None: + progress.advance(task) + continue + if isinstance(result, list): + compile_results.extend(result) + else: + compile_results.append(result) + progress.advance(task) + + compile_results.append(compile_contexts(app._state, app.theme)) + if app.theme is not None: + app.theme.appearance = None # pyright: ignore[reportAttributeAccessIssue] + progress.advance(task) + + compile_results.append(compile_app_root(app_root)) + progress.advance(task) + + progress.stop() - Args: - style: The style to compile. + if dry_run: + return + + with console.timing("Install Frontend Packages"): + app._get_frontend_packages(all_imports) + + frontend_skeleton.update_react_router_config( + prerender_routes=prerender_routes, + ) - Returns: - The path and code of the compiled theme. + if is_prod_mode(): + purge_web_pages_dir() + else: + keep_files = [Path(output_path) for output_path, _ in compile_results] + for page_file in Path( + prerequisites.get_web_dir() / constants.Dirs.PAGES / constants.Dirs.ROUTES + ).rglob("*"): + if page_file.is_file() and page_file not in keep_files: + page_file.unlink() + + output_mapping: dict[Path, str] = {} + for output_path, code in compile_results: + path = utils.resolve_path_of_web_dir(output_path) + if path in output_mapping: + console.warn( + f"Path {path} has two different outputs. The first one will be used." + ) + else: + output_mapping[path] = code + + for plugin in config.plugins: + for static_file_path, content in plugin.get_static_assets(): + path = utils.resolve_path_of_web_dir(static_file_path) + if path in output_mapping: + console.warn( + f"Plugin {plugin.__class__.__name__} is trying to write to {path} but it already exists. The plugin file will be ignored." + ) + else: + output_mapping[path] = ( + content.decode("utf-8") if isinstance(content, bytes) else content + ) + + for plugin_name, file_path, modify_fn in modify_files_tasks: + path = utils.resolve_path_of_web_dir(file_path) + file_content = output_mapping.get(path) + if file_content is None: + if path.exists(): + file_content = path.read_text() + else: + msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." + raise FileNotFoundError(msg) + output_mapping[path] = modify_fn(file_content) - Raises: - ValueError: If the style is not set. - """ - if style is None: - msg = "STYLE should be set" - raise ValueError(msg) - return compile_theme(style) + with console.timing("Write to Disk"): + for output_path, code in output_mapping.items(): + utils.write_file(output_path, code) diff --git a/reflex/compiler/plugins/__init__.py b/reflex/compiler/plugins/__init__.py new file mode 100644 index 00000000000..92e34115e3e --- /dev/null +++ b/reflex/compiler/plugins/__init__.py @@ -0,0 +1,30 @@ +"""Built-in compiler plugins for single-pass page compilation.""" + +from reflex_base.plugins import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, +) + +from .builtin import ( + ApplyStylePlugin, + DefaultCollectorPlugin, + DefaultPagePlugin, + default_page_plugins, +) +from .memoize import MemoizeStatefulPlugin + +__all__ = [ + "ApplyStylePlugin", + "BaseContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "DefaultCollectorPlugin", + "DefaultPagePlugin", + "MemoizeStatefulPlugin", + "PageContext", + "default_page_plugins", +] diff --git a/reflex/compiler/plugins/builtin.py b/reflex/compiler/plugins/builtin.py new file mode 100644 index 00000000000..0fcfdcef647 --- /dev/null +++ b/reflex/compiler/plugins/builtin.py @@ -0,0 +1,430 @@ +"""Built-in compiler plugins and the default plugin pipeline.""" + +from __future__ import annotations + +import dataclasses +from collections.abc import Callable +from typing import Any + +from reflex_base.components.component import BaseComponent, Component, ComponentStyle +from reflex_base.config import get_config +from reflex_base.plugins import CompileContext, PageContext, PageDefinition, Plugin +from reflex_base.utils.format import make_default_page_title +from reflex_base.utils.imports import collapse_imports, merge_imports +from reflex_base.vars import VarData +from reflex_components_core.base.fragment import Fragment + +from reflex.compiler import utils + + +@dataclasses.dataclass(frozen=True, slots=True) +class DefaultPagePlugin(Plugin): + """Evaluate an unevaluated page into a mutable page context.""" + + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + """Evaluate the page function and attach legacy page metadata. + + Returns: + The evaluated page context. + """ + from reflex.compiler import compiler + + del kwargs + + try: + component = compiler.into_component(page_fn) + component = Fragment.create(component) + + title = getattr(page, "title", None) + meta_args = { + "title": ( + title + if title is not None + else make_default_page_title(get_config().app_name, page.route) + ), + "image": getattr(page, "image", ""), + "meta": getattr(page, "meta", ()), + } + if (description := getattr(page, "description", None)) is not None: + meta_args["description"] = description + + utils.add_meta(component, **meta_args) + except Exception as err: + if hasattr(err, "add_note"): + err.add_note(f"Happened while evaluating page {page.route!r}") + raise + + return PageContext( + name=getattr(page_fn, "__name__", page.route), + route=page.route, + root_component=component, + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class ApplyStylePlugin(Plugin): + """Apply app-level styles in the descending phase of the walk.""" + + _compiler_can_replace_enter_component = False + style: ComponentStyle | None = None + theme: Component | None = None + + @staticmethod + def _apply_style(comp: Component, style: ComponentStyle) -> None: + """Apply app-level styles to a single component. + + Args: + comp: The component to style. + style: The app-level component style map. + """ + if type(comp)._add_style != Component._add_style: + msg = "Do not override _add_style directly. Use add_style instead." + raise UserWarning(msg) + + new_style = comp._add_style() + style_vars = [new_style._var_data] + + component_style = comp._get_component_style(style) + if component_style: + new_style.update(component_style) + style_vars.append(component_style._var_data) + + new_style.update(comp.style) + style_vars.append(comp.style._var_data) + new_style._var_data = VarData.merge(*style_vars) + comp.style = new_style + + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> None: + """Apply the non-recursive portion of ``_add_style_recursive``.""" + del page_context, compile_context + + if self.style is not None and isinstance(comp, Component) and not in_prop_tree: + self._apply_style(comp, self.style) + + def _compiler_bind_enter_component( + self, + page_context: PageContext, + compile_context: CompileContext, + ) -> Callable[[BaseComponent, bool], None]: + """Bind a positional fast-path enter hook for style application. + + Returns: + A compiled enter hook that only takes hot-loop positional state. + """ + del page_context, compile_context + + style = self.style + if style is None: + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> None: + del comp, in_prop_tree + + return enter_component + + apply_style = self._apply_style + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> None: + if not isinstance(comp, Component) or in_prop_tree: + return + + apply_style(comp, style) + + return enter_component + + +@dataclasses.dataclass(frozen=True, slots=True) +class DefaultCollectorPlugin(Plugin): + """Collect page artifacts in one fused enter/leave hook pair.""" + + _compiler_can_replace_enter_component = False + _compiler_can_replace_leave_component = False + + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> None: + """Collect imports and page artifacts for the active component node.""" + del compile_context + + if not isinstance(comp, Component): + return + + imports = comp._get_imports() + if imports: + self._extend_imports(page_context.frontend_imports, imports) + + if not in_prop_tree: + self._collect_component_custom_code(page_context.module_code, comp) + + self._collect_component_hooks(page_context.hooks, comp) + + if ( + type(comp)._get_app_wrap_components + is not Component._get_app_wrap_components + ): + self._collect_app_wrap_components( + page_context.app_wrap_components, + comp, + ) + + if (dynamic_import := comp._get_dynamic_imports()) is not None: + page_context.dynamic_imports.add(dynamic_import) + + if (ref := comp.get_ref()) is not None: + page_context.refs[ref] = None + + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Collapse collected imports into a single legacy-shaped entry.""" + del kwargs + if page_ctx.frontend_imports: + collapsed_imports = collapse_imports( + merge_imports(page_ctx.frontend_imports, *page_ctx.imports) + if page_ctx.imports + else page_ctx.frontend_imports + ) + page_ctx.frontend_imports = collapsed_imports + page_ctx.imports = [collapsed_imports] + return + + page_ctx.imports = ( + [collapse_imports(merge_imports(*page_ctx.imports))] + if page_ctx.imports + else [] + ) + + def _compiler_bind_enter_component( + self, + page_context: PageContext, + compile_context: CompileContext, + ) -> Callable[[BaseComponent, bool], None]: + """Bind a positional fast-path enter hook for artifact collection. + + Returns: + A compiled enter hook that only takes hot-loop positional state. + """ + del compile_context + + frontend_imports = page_context.frontend_imports + module_code = page_context.module_code + hooks = page_context.hooks + dynamic_imports = page_context.dynamic_imports + refs = page_context.refs + app_wrap_components = page_context.app_wrap_components + extend_imports = self._extend_imports + collect_component_hooks = self._collect_component_hooks + collect_component_custom_code = self._collect_component_custom_code + collect_app_wrap_components = self._collect_app_wrap_components + base_get_app_wrap_components = Component._get_app_wrap_components + seen_app_wrap_methods: set[object] = set() + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> None: + if not isinstance(comp, Component): + return + + imports_for_component = comp._get_imports() + if imports_for_component: + extend_imports(frontend_imports, imports_for_component) + + if not in_prop_tree: + collect_component_custom_code(module_code, comp) + + collect_component_hooks(hooks, comp) + + app_wrap_method = type(comp)._get_app_wrap_components + if ( + app_wrap_method is not base_get_app_wrap_components + and app_wrap_method not in seen_app_wrap_methods + ): + seen_app_wrap_methods.add(app_wrap_method) + collect_app_wrap_components(app_wrap_components, comp) + + dynamic_import = comp._get_dynamic_imports() + if dynamic_import is not None: + dynamic_imports.add(dynamic_import) + + ref = comp.get_ref() + if ref is not None: + refs[ref] = None + + return enter_component + + @staticmethod + def _collect_component_hooks( + page_hooks: dict[str, VarData | None], + component: Component, + ) -> None: + """Collect hooks for one structural-tree component in legacy order.""" + page_hooks.update(component._get_hooks_internal()) + if (user_hooks := component._get_hooks()) is not None: + page_hooks[user_hooks] = None + page_hooks.update(component._get_added_hooks()) + + @staticmethod + def _extend_imports( + target: dict[str, list[Any]], + source: dict[str, list[Any]], + ) -> None: + """Extend a parsed import mapping in place.""" + for lib, fields in source.items(): + target.setdefault(lib, []).extend(fields) + + @staticmethod + def _collect_component_custom_code( + module_code: dict[str, None], + component: Component, + ) -> None: + """Collect custom code for one structural-tree component in legacy order.""" + if (custom_code := component._get_custom_code()) is not None: + module_code[custom_code] = None + + for prop_component in component._get_components_in_props(): + DefaultCollectorPlugin._collect_prop_custom_code_into( + prop_component, + module_code, + ) + + for clz in component._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(component): + module_code[item] = None + + @staticmethod + def _collect_prop_custom_code_into( + component: BaseComponent, + module_code: dict[str, None], + ) -> None: + """Recursively collect prop-tree custom code directly into ``module_code``.""" + if not isinstance(component, Component): + module_code.update(component._get_all_custom_code()) + return + + if (custom_code := component._get_custom_code()) is not None: + module_code[custom_code] = None + + for prop_component in component._get_components_in_props(): + DefaultCollectorPlugin._collect_prop_custom_code_into( + prop_component, + module_code, + ) + + for clz in component._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(component): + module_code[item] = None + + for child in component.children: + DefaultCollectorPlugin._collect_prop_custom_code_into( + child, + module_code, + ) + + def _collect_app_wrap_components( + self, + page_app_wrap_components: dict[tuple[int, str], Component], + component: Component, + ) -> None: + """Collect app-wrap components for a structural-tree component.""" + direct_wrappers = component._get_app_wrap_components() + if not direct_wrappers: + return + + ignore_ids = {id(wrapper) for wrapper in page_app_wrap_components.values()} + page_app_wrap_components.update(direct_wrappers) + for wrapper in direct_wrappers.values(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + self._collect_wrapper_subtree_into( + wrapper, + ignore_ids, + page_app_wrap_components, + ) + + @staticmethod + def _collect_wrapper_subtree_into( + component: Component, + ignore_ids: set[int], + components: dict[tuple[int, str], Component], + ) -> None: + """Collect nested app-wrap components into ``components``.""" + direct_wrappers = component._get_app_wrap_components() + for key, wrapper in direct_wrappers.items(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + components[key] = wrapper + DefaultCollectorPlugin._collect_wrapper_subtree_into( + wrapper, + ignore_ids, + components, + ) + + for child in component.children: + if not isinstance(child, Component): + continue + child_id = id(child) + if child_id in ignore_ids: + continue + ignore_ids.add(child_id) + DefaultCollectorPlugin._collect_wrapper_subtree_into( + child, + ignore_ids, + components, + ) + + +def default_page_plugins( + *, + style: ComponentStyle | None = None, + theme: Component | None = None, +) -> tuple[Plugin, ...]: + """Return the default compiler plugin ordering for page compilation.""" + from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin + + plugins: list[Plugin] = [DefaultPagePlugin()] + if style is not None: + plugins.append(ApplyStylePlugin(style=style, theme=theme)) + plugins.extend((MemoizeStatefulPlugin(), DefaultCollectorPlugin())) + return tuple(plugins) + + +__all__ = [ + "ApplyStylePlugin", + "DefaultCollectorPlugin", + "DefaultPagePlugin", + "default_page_plugins", +] diff --git a/reflex/compiler/plugins/memoize.py b/reflex/compiler/plugins/memoize.py new file mode 100644 index 00000000000..c2ee5337afa --- /dev/null +++ b/reflex/compiler/plugins/memoize.py @@ -0,0 +1,289 @@ +"""MemoizeStatefulPlugin — auto-memoize stateful components with ``rx._x.memo``. + +This plugin replaces the legacy ``StatefulComponent`` wrapping pass. It +participates in the normal single-pass walk via ``enter_component`` and inserts +per-subtree ``{children}``-pass-through wrappers built on the experimental +memo infrastructure. The wrapped subtree remains in the tree for the normal +walker descent, so downstream plugins (e.g. ``DefaultCollectorPlugin``) still +see the original components and collect their imports/hooks as usual. + +Each unique subtree shape contributes: + +- One generated experimental memo component definition, compiled into the + shared ``$/utils/components`` module. +- ``useCallback`` hook lines for each non-lifecycle event trigger, emitted into + ``page_context.hooks`` so the declarations live at the top of the page body. + +No shared ``stateful_components`` file is produced. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +from functools import cache +from typing import Any + +from reflex_base.components.component import ( + BaseComponent, + Component, + _deterministic_hash, + _hash_str, +) +from reflex_base.components.memoize_helpers import ( + fix_event_triggers_for_memo, + invalidate_event_trigger_caches, +) +from reflex_base.constants.compiler import MemoizationDisposition +from reflex_base.plugins import ComponentAndChildren, PageContext +from reflex_base.plugins.base import Plugin +from reflex_base.utils import format +from reflex_base.vars.base import Var + +from reflex.experimental.memo import create_passthrough_component_memo + +# --------------------------------------------------------------------------- # +# Tag naming + memoize-eligibility # +# --------------------------------------------------------------------------- # + + +def _child_var(child: Component) -> Var | Component: + """Return the core Var of a structural child, for memoize-eligibility checks. + + For special wrappers (``Bare``/``Cond``/``Foreach``/``Match``) we peek at + the contained Var instead of recursing into the wrapper component itself. + + Args: + child: The child component to inspect. + + Returns: + The contained Var if ``child`` is a special wrapper, else ``child``. + """ + from reflex_components_core.base.bare import Bare + from reflex_components_core.core.cond import Cond + from reflex_components_core.core.foreach import Foreach + from reflex_components_core.core.match import Match + + if isinstance(child, Bare): + return child.contents + if isinstance(child, Cond): + return child.cond + if isinstance(child, Foreach): + return child.iterable + if isinstance(child, Match): + return child.cond + return child + + +def _compute_memo_tag(component: Component) -> str | None: + """Compute a stable tag name for a memoizable component. + + Returns ``None`` for components that render empty (non-visual components + are never memoized). + + Args: + component: The component to name. + + Returns: + The stable tag name, or ``None`` if the component renders empty. + """ + rendered_code = component.render() + if not rendered_code: + return None + code_hash = _hash_str(_deterministic_hash(rendered_code)) + return format.format_state_name( + f"{component.tag or 'Comp'}_{code_hash}" + ).capitalize() + + +def _should_memoize(component: Component) -> bool: + """Decide whether ``component`` is a candidate for auto-memoization. + + Checks for DIRECT triggers only (not walking into descendants): the + component's own Vars with var_data, event_triggers, or special child + types (Bare/Cond/Foreach/Match) whose probe Var carries var_data. + + Args: + component: The candidate component. + + Returns: + True if the component should be wrapped in a memo definition. + """ + from reflex_components_core.core.foreach import Foreach + + if component._memoization_mode.disposition == MemoizationDisposition.NEVER: + return False + if component.tag is None: + return False + if component._memoization_mode.disposition == MemoizationDisposition.ALWAYS: + return True + + # Direct Vars only (component's own props, style, class_name, id, etc.). + for prop_var in component._get_vars(include_children=False): + if prop_var._get_all_var_data(): + return True + + # Special-case structural children that are Var wrappers (Bare/Cond/ + # Foreach/Match). Foreach is always memoized because it produces dynamic + # child trees that React must reconcile by key. + for child in component.children: + if not isinstance(child, Component): + continue + if isinstance(child, Foreach): + return True + probe = _child_var(child) + if isinstance(probe, Var) and probe._get_all_var_data(): + return True + + # Components with event triggers are always memoized (to wrap callbacks). + return bool(component.event_triggers) + + +@cache +def _get_passthrough_memo_component(tag: str) -> tuple[Any, Any]: + """Return the generated experimental memo wrapper callable and definition. + + Args: + tag: The wrapper's exported component name. + + Returns: + The memo wrapper callable and its definition. + """ + return create_passthrough_component_memo(tag) + + +# --------------------------------------------------------------------------- # +# The plugin # +# --------------------------------------------------------------------------- # + + +@dataclasses.dataclass(frozen=True, slots=True) +class MemoizeStatefulPlugin(Plugin): + """Auto-memoize stateful components with ``{children}``-pass-through memos. + + Registered in ``default_page_plugins`` between ``ApplyStylePlugin`` and + ``DefaultCollectorPlugin``. On ``enter_component`` it decides whether a + component should be memoized, and if so wraps it in a generated + experimental memo component whose single child is the original. The walker + then descends into the original component normally so + ``DefaultCollectorPlugin`` still sees its subtree. + + A ``_memoize_wrapped`` attribute marks the original component so the + recursive visit doesn't re-wrap it. + """ + + _compiler_can_replace_enter_component = True + _compiler_can_replace_leave_component = False + + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + """Wrap eligible stateful components in an experimental memo component. + + Args: + comp: The component being visited. + page_context: The active page context. + compile_context: The active compile context. + in_prop_tree: Whether the component is in a prop subtree. + + Returns: + A ``(wrapper, (comp,))`` tuple replacement when ``comp`` is + memoizable, else ``None``. + """ + if in_prop_tree: + return None + if not isinstance(comp, Component): + return None + + # Re-entry guard: when the walker descends into our wrapped child, it + # calls enter_component on the original comp again. Clear the marker + # and pass through. + if getattr(comp, "_memoize_wrapped", False): + with contextlib.suppress(AttributeError): + del comp._memoize_wrapped # pyright: ignore[reportAttributeAccessIssue] + return None + + # Inside a MemoizationLeaf subtree, do not independently wrap + # descendants (the leaf owns the wrapping decision for its subtree). + if getattr(page_context, "_memoize_suppress_depth", 0) > 0: + return None + + is_memoization_leaf = not comp._memoization_mode.recursive + + if not _should_memoize(comp): + if is_memoization_leaf: + # Leaf that wasn't memoized still suppresses descendants. + page_context._memoize_suppress_depth = ( # type: ignore[attr-defined] + getattr(page_context, "_memoize_suppress_depth", 0) + 1 + ) + comp._memoize_pushed_suppression = True # type: ignore[attr-defined] + return None + + tag = _compute_memo_tag(comp) + if tag is None: + return None + + # Memoize event triggers, collect useCallback hooks for the page body. + memo_trigger_hooks = fix_event_triggers_for_memo(comp) + if memo_trigger_hooks: + invalidate_event_trigger_caches(comp) + for hook in memo_trigger_hooks: + page_context.hooks[hook] = None + + compile_context.memoize_wrappers[tag] = None + wrapper_factory, definition = _get_passthrough_memo_component(tag) + compile_context.auto_memo_components[tag] = definition + + # If comp is a MemoizationLeaf that IS being wrapped, suppress + # descendant wrapping for its subtree. + if is_memoization_leaf: + page_context._memoize_suppress_depth = ( # type: ignore[attr-defined] + getattr(page_context, "_memoize_suppress_depth", 0) + 1 + ) + comp._memoize_pushed_suppression = True # type: ignore[attr-defined] + + # Mark the original so the recursive re-enter skips wrapping. + comp._memoize_wrapped = True # type: ignore[attr-defined] + + wrapper = wrapper_factory(comp) + return (wrapper, (comp,)) + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: Any, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + """Pop the ``MemoizationLeaf`` suppression counter if we pushed one. + + Args: + comp: The component being visited. + children: Its compiled children (unused). + page_context: The active page context. + compile_context: The active compile context (unused). + in_prop_tree: Whether the component is in a prop subtree (unused). + + Returns: + Always ``None``. + """ + del children, compile_context, in_prop_tree + if getattr(comp, "_memoize_pushed_suppression", False): + page_context._memoize_suppress_depth = ( # type: ignore[attr-defined] + getattr(page_context, "_memoize_suppress_depth", 1) - 1 + ) + with contextlib.suppress(AttributeError): + del comp._memoize_pushed_suppression # pyright: ignore[reportAttributeAccessIssue] + return None + + +__all__ = ["MemoizeStatefulPlugin"] diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index bb812c67c5f..3dad8c3cd1f 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -656,19 +656,6 @@ def get_components_path() -> str: ) -def get_stateful_components_path() -> str: - """Get the path of the compiled stateful components. - - Returns: - The path of the compiled stateful components. - """ - return str( - get_web_dir() - / constants.Dirs.UTILS - / (constants.PageNames.STATEFUL_COMPONENTS + constants.Ext.JSX) - ) - - def add_meta( page: Component, title: str, diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 7dee0c72eea..b6904e6955c 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -76,6 +76,19 @@ class ExperimentalMemoComponent(Component): library = f"$/{constants.Dirs.COMPONENTS_PATH}" + def _validate_component_children(self, children: list[Component]) -> None: + """Skip direct parent/child validation for memo wrapper instances. + + Experimental memos wrap an underlying compiled component definition. + The runtime wrapper should not interpose on `_valid_parents` checks for + the authored subtree because the wrapper itself is not the semantic + parent in the user-authored component tree. + + Args: + children: The children of the component (ignored). + """ + del children + def _post_init(self, **kwargs): """Initialize the experimental memo component. @@ -950,6 +963,40 @@ def _create_component_wrapper( return _ExperimentalMemoComponentWrapper(definition) +@cache +def create_passthrough_component_memo( + export_name: str, +) -> tuple[ + Callable[..., ExperimentalMemoComponent], + ExperimentalMemoComponentDefinition, +]: + """Create an unregistered ``@rx._x.memo``-style passthrough component memo. + + This is used by compiler auto-memoization so generated wrappers compile + through the experimental memo pipeline instead of emitting ad-hoc page-local + ``React.memo`` declarations. + + Args: + export_name: The exported memo component name. + + Returns: + The callable memo wrapper and its component definition. + """ + + def passthrough(children: Var[Component]) -> Component: + return Bare.create(children) + + passthrough.__name__ = format.to_snake_case(export_name) + passthrough.__qualname__ = passthrough.__name__ + passthrough.__module__ = __name__ + + definition = _create_component_definition(passthrough, Component) + if definition.export_name != export_name: + definition = dataclasses.replace(definition, export_name=export_name) + + return _create_component_wrapper(definition), definition + + def memo(fn: Callable[..., Any]) -> Callable[..., Any]: """Create an experimental memo from a function. @@ -986,3 +1033,14 @@ def memo(fn: Callable[..., Any]) -> Callable[..., Any]: f"got `{return_annotation}`." ) raise TypeError(msg) + + +__all__ = [ + "EXPERIMENTAL_MEMOS", + "ExperimentalMemoComponent", + "ExperimentalMemoComponentDefinition", + "ExperimentalMemoDefinition", + "ExperimentalMemoFunctionDefinition", + "create_passthrough_component_memo", + "memo", +] diff --git a/reflex/plugins/__init__.py b/reflex/plugins/__init__.py index ac50297d040..114646fd0b1 100644 --- a/reflex/plugins/__init__.py +++ b/reflex/plugins/__init__.py @@ -1,7 +1,12 @@ """Re-export from reflex_base.plugins.""" from reflex_base.plugins import ( + BaseContext, CommonContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, Plugin, PreCompileContext, SitemapPlugin, @@ -14,7 +19,12 @@ ) __all__ = [ + "BaseContext", "CommonContext", + "CompileContext", + "CompilerHooks", + "ComponentAndChildren", + "PageContext", "Plugin", "PreCompileContext", "SitemapPlugin", diff --git a/tests/benchmarks/fixtures.py b/tests/benchmarks/fixtures.py index c20ac177660..63469330109 100644 --- a/tests/benchmarks/fixtures.py +++ b/tests/benchmarks/fixtures.py @@ -1,10 +1,14 @@ +from collections.abc import Callable from dataclasses import dataclass -from typing import cast +from typing import Any, cast import pytest from pydantic import BaseModel +from reflex_base.components.component import BaseComponent, Component +from reflex_base.plugins import CompileContext, PageContext import reflex as rx +from reflex.compiler.plugins import DefaultCollectorPlugin class SideBarState(rx.State): @@ -221,6 +225,53 @@ class NestedElement(BaseModel): value: list[int] +@dataclass(frozen=True, slots=True) +class ImportOnlyCollectorPlugin(DefaultCollectorPlugin): + """Collect only imports — same scope as Component._get_all_imports. + + Inherits import collection from DefaultCollectorPlugin but disables + hooks, custom code, app_wrap, and stateful code rendering. + """ + + _compiler_stateful_only_leave_component = False + + def leave_component(self, *_args: Any, **_kwargs: Any) -> None: + """No-op: skip stateful code rendering.""" + + def _compiler_bind_leave_component( + self, *_args: Any, **_kwargs: Any + ) -> Callable[..., None]: + """Return a no-op leave hook.""" + + def _noop(*_a: Any, **_kw: Any) -> None: + pass + + return _noop + + def _compiler_bind_enter_component( + self, + page_context: PageContext, + compile_context: CompileContext, + ) -> Callable[[BaseComponent, bool], None]: + del compile_context + + frontend_imports = page_context.frontend_imports + extend_imports = self._extend_imports + + def enter_component( + comp: BaseComponent, + in_prop_tree: bool, + ) -> None: + if not isinstance(comp, Component) or in_prop_tree: + return + + imports = comp._get_imports() + if imports: + extend_imports(frontend_imports, imports) + + return enter_component + + class BenchmarkState(rx.State): """State for the benchmark.""" diff --git a/tests/benchmarks/test_compilation.py b/tests/benchmarks/test_compilation.py index 7ad0f666f76..f9b1f134e5b 100644 --- a/tests/benchmarks/test_compilation.py +++ b/tests/benchmarks/test_compilation.py @@ -1,7 +1,16 @@ +import copy + from pytest_codspeed import BenchmarkFixture from reflex_base.components.component import Component +from reflex_base.plugins import CompileContext, CompilerHooks, PageContext + +from reflex.app import UnevaluatedPage +from reflex.compiler import compiler +from reflex.compiler.compiler import _compile_page +from reflex.compiler.plugins import DefaultCollectorPlugin, default_page_plugins +from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin -from reflex.compiler.compiler import _compile_page, _compile_stateful_components +from .fixtures import ImportOnlyCollectorPlugin def import_templates(): @@ -9,17 +18,126 @@ def import_templates(): import reflex.compiler.templates # noqa: F401 +def _compile_single_pass_page_ctx(component: Component) -> PageContext: + # The single-pass compiler mutates the tree in place when it inserts memo + # wrappers, so benchmark iterations need an isolated copy of the input. + component = copy.deepcopy(component) + page_ctx = PageContext( + name="benchmark", + route="/benchmark", + root_component=component, + ) + hooks = CompilerHooks( + plugins=(MemoizeStatefulPlugin(), DefaultCollectorPlugin()), + ) + compile_ctx = CompileContext(pages=[], hooks=hooks) + + with compile_ctx, page_ctx: + page_ctx.root_component = hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + hooks.compile_page(page_ctx, compile_context=compile_ctx) + + return page_ctx + + +def _get_imports_single_pass(component: Component) -> dict: + """Collect only imports via a single-pass walk — comparable to _get_all_imports. + + Returns: + The collapsed import dict for the page. + """ + page_ctx = PageContext( + name="benchmark", + route="/benchmark", + root_component=component, + ) + hooks = CompilerHooks(plugins=(ImportOnlyCollectorPlugin(),)) + compile_ctx = CompileContext(pages=[], hooks=hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + hooks.compile_page(page_ctx, compile_context=compile_ctx) + + return page_ctx.frontend_imports + + +def _compile_page_single_pass(component: Component) -> str: + page_ctx = _compile_single_pass_page_ctx(component) + page_ctx.frontend_imports = page_ctx.merged_imports(collapse=True) + return compiler.compile_page_from_context(page_ctx)[1] + + +def _compile_page_full_context(unevaluated_page) -> str: + page = UnevaluatedPage(route="/benchmark", component=unevaluated_page) + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + + with compile_ctx: + compiled_pages = compile_ctx.compile() + + output_code = compiled_pages["/benchmark"].output_code + if output_code is None: + msg = "CompileContext did not produce output code for the benchmark page." + raise RuntimeError(msg) + return output_code + + def test_compile_page(evaluated_page: Component, benchmark: BenchmarkFixture): import_templates() benchmark(lambda: _compile_page(evaluated_page)) -def test_compile_stateful(evaluated_page: Component, benchmark: BenchmarkFixture): +def test_compile_page_single_pass( + evaluated_page: Component, + benchmark: BenchmarkFixture, +): import_templates() - benchmark(lambda: _compile_stateful_components([evaluated_page])) + benchmark(lambda: _compile_page_single_pass(evaluated_page)) + + +def test_compile_page_full_context( + unevaluated_page, + benchmark: BenchmarkFixture, +): + import_templates() + + benchmark(lambda: _compile_page_full_context(unevaluated_page)) def test_get_all_imports(evaluated_page: Component, benchmark: BenchmarkFixture): benchmark(lambda: evaluated_page._get_all_imports()) + + +def test_get_all_imports_single_pass( + evaluated_page: Component, + benchmark: BenchmarkFixture, +): + benchmark(lambda: _get_imports_single_pass(evaluated_page)) + + +def test_compile_single_pass_all_artifacts( + evaluated_page: Component, + benchmark: BenchmarkFixture, +): + """Full single-pass collecting all artifacts (imports, hooks, code, app_wrap). + + This is the fair comparison for the total work the old multi-pass approach + did across _get_all_imports + _get_all_hooks + _get_all_custom_code + + _get_all_app_wrap_components. + """ + benchmark( + lambda: _compile_single_pass_page_ctx(evaluated_page).merged_imports( + collapse=True + ) + ) diff --git a/tests/benchmarks/test_evaluate.py b/tests/benchmarks/test_evaluate.py index 7af08a3592e..b533b34c415 100644 --- a/tests/benchmarks/test_evaluate.py +++ b/tests/benchmarks/test_evaluate.py @@ -2,9 +2,22 @@ from pytest_codspeed import BenchmarkFixture from reflex_base.components.component import Component +from reflex_base.plugins import CompilerHooks + +from reflex.app import UnevaluatedPage +from reflex.compiler.plugins import DefaultPagePlugin def test_evaluate_page( unevaluated_page: Callable[[], Component], benchmark: BenchmarkFixture ): benchmark(unevaluated_page) + + +def test_evaluate_page_single_pass( + unevaluated_page: Callable[[], Component], + benchmark: BenchmarkFixture, +): + hooks = CompilerHooks(plugins=(DefaultPagePlugin(),)) + page = UnevaluatedPage(route="/benchmark", component=unevaluated_page) + benchmark(lambda: hooks.eval_page(page.component, page=page)) diff --git a/tests/integration/test_auto_memo.py b/tests/integration/test_auto_memo.py new file mode 100644 index 00000000000..121184f493e --- /dev/null +++ b/tests/integration/test_auto_memo.py @@ -0,0 +1,73 @@ +"""Integration tests for compiler-generated experimental memos.""" + +from collections.abc import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness + +from .utils import poll_for_navigation + + +def AutoMemoAcrossPagesApp(): + """Reflex app that shares one stateful subtree across two pages.""" + import reflex as rx + + def shared_counter() -> rx.Component: + return rx.text(rx.State.router.page.raw_path, id="shared-value") + + def index() -> rx.Component: + return rx.vstack( + shared_counter(), + rx.link("Other", href="/other", id="to-other"), + ) + + def other() -> rx.Component: + return rx.vstack( + shared_counter(), + rx.link("Home", href="/", id="to-home"), + ) + + app = rx.App() + app.add_page(index) + app.add_page(other, route="/other") + + +@pytest.fixture +def auto_memo_app(tmp_path) -> Generator[AppHarness, None, None]: + """Start AutoMemoAcrossPagesApp app at tmp_path via AppHarness. + + Yields: + A running AppHarness instance. + """ + with AppHarness.create( + root=tmp_path, + app_source=AutoMemoAcrossPagesApp, + ) as harness: + yield harness + + +def test_auto_memo_shared_across_pages(auto_memo_app: AppHarness): + """Shared stateful subtrees compile once and render correctly on both pages.""" + assert auto_memo_app.app_instance is not None, "app is not running" + + web_sources = "\n".join( + path.read_text() for path in (auto_memo_app.app_path / ".web").rglob("*.jsx") + ) + assert "$/utils/components" in web_sources + assert "$/utils/stateful_components" not in web_sources + + driver = auto_memo_app.frontend() + shared_value = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "shared-value") + ) + assert auto_memo_app.poll_for_content(shared_value, exp_not_equal="") == "/" + + with poll_for_navigation(driver): + driver.find_element(By.ID, "to-other").click() + + shared_value = AppHarness.poll_for_or_raise_timeout( + lambda: driver.find_element(By.ID, "shared-value") + ) + assert "other" in auto_memo_app.poll_for_content(shared_value, exp_not_equal="") diff --git a/tests/units/compiler/test_memoize_plugin.py b/tests/units/compiler/test_memoize_plugin.py new file mode 100644 index 00000000000..50d0a5d50e1 --- /dev/null +++ b/tests/units/compiler/test_memoize_plugin.py @@ -0,0 +1,216 @@ +# ruff: noqa: D101 + +import dataclasses +from collections.abc import Callable +from typing import Any + +from reflex_base.components.component import Component, field +from reflex_base.constants.compiler import MemoizationDisposition, MemoizationMode +from reflex_base.plugins import CompileContext, CompilerHooks, PageContext +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_components_core.base.fragment import Fragment + +from reflex.compiler.plugins import DefaultCollectorPlugin, default_page_plugins +from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin, _should_memoize +from reflex.experimental.memo import ( + ExperimentalMemoComponent, + create_passthrough_component_memo, +) + +STATE_VAR = LiteralVar.create("value")._replace( + merge_var_data=VarData(hooks={"useTestState": None}, state="TestState") +) + + +class Plain(Component): + tag = "Plain" + library = "plain-lib" + + +class WithProp(Component): + tag = "WithProp" + library = "with-prop-lib" + + label: Var[str] = field(default=LiteralVar.create("")) + + +class LeafComponent(Component): + tag = "LeafComponent" + library = "leaf-lib" + _memoization_mode = MemoizationMode(recursive=False) + + +@dataclasses.dataclass(slots=True) +class FakePage: + route: str + component: Callable[[], Component] + title: Any = None + description: Any = None + image: str = "" + meta: tuple[dict[str, Any], ...] = () + + +def _compile_single_page( + component_factory: Callable[[], Component], +) -> tuple[CompileContext, PageContext]: + ctx = CompileContext( + pages=[FakePage(route="/p", component=component_factory)], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + return ctx, ctx.compiled_pages["/p"] + + +def test_should_memoize_catches_direct_state_var_in_prop() -> None: + """A component whose own prop carries state VarData should memoize.""" + comp = WithProp.create(label=STATE_VAR) + assert _should_memoize(comp) + + +def test_should_memoize_catches_state_var_in_child_bare() -> None: + """A component whose Bare child contains state VarData should memoize.""" + comp = Plain.create(STATE_VAR) + assert _should_memoize(comp) + + +def test_should_not_memoize_plain_component() -> None: + """A component with no state vars and no event triggers is not memoized.""" + comp = Plain.create(LiteralVar.create("static-content")) + assert not _should_memoize(comp) + + +def test_should_not_memoize_when_disposition_never() -> None: + """``MemoizationDisposition.NEVER`` overrides heuristic eligibility.""" + comp = Plain.create(STATE_VAR) + object.__setattr__( + comp, + "_memoization_mode", + dataclasses.replace( + comp._memoization_mode, disposition=MemoizationDisposition.NEVER + ), + ) + assert not _should_memoize(comp) + + +def test_memoize_wrapper_uses_experimental_memo_component_and_call_site() -> None: + """Memoizable component imports a generated ``rx._x.memo`` wrapper.""" + ctx, page_ctx = _compile_single_page(lambda: Plain.create(STATE_VAR)) + + assert len(ctx.memoize_wrappers) == 1 + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert wrapper_tag in ctx.auto_memo_components + output = page_ctx.output_code or "" + assert f'import {{{wrapper_tag}}} from "$/utils/components"' in output + assert f"jsx({wrapper_tag}," in (page_ctx.output_code or "") + assert f"const {wrapper_tag} = memo" not in output + + +def test_memoize_wrapper_deduped_across_repeated_subtrees() -> None: + """Two identical memoizable call-sites collapse to one memo definition.""" + ctx, page_ctx = _compile_single_page( + lambda: Fragment.create( + Plain.create(STATE_VAR), + Plain.create(STATE_VAR), + ) + ) + assert len(ctx.memoize_wrappers) == 1 + wrapper_tag = next(iter(ctx.memoize_wrappers)) + assert list(ctx.auto_memo_components) == [wrapper_tag] + assert (page_ctx.output_code or "").count( + f'import {{{wrapper_tag}}} from "$/utils/components"' + ) == 1 + + +def test_memoization_leaf_suppresses_descendant_wrapping() -> None: + """A MemoizationLeaf suppresses independent wrappers for its descendants. + + Even when a descendant (``Plain(STATE_VAR)``) would otherwise be wrapped, + being inside a leaf's subtree suppresses that wrapping. Whether or not the + leaf itself gets wrapped, descendants do not produce their own wrappers. + """ + ctx, _page_ctx = _compile_single_page( + lambda: LeafComponent.create( + Plain.create(STATE_VAR), # would otherwise be independently memoized + ) + ) + # The inner Plain(STATE_VAR) is suppressed because it's inside the leaf's + # subtree. The leaf itself has no direct state dependency so no wrapper + # is emitted for it either. + assert len(ctx.memoize_wrappers) == 0 + + +def test_generated_memo_component_is_not_itself_memoized() -> None: + """The generated memo component instance itself is skipped by the heuristic.""" + wrapper_factory, _definition = create_passthrough_component_memo("MyTag") + wrapper = wrapper_factory(Plain.create()) + assert isinstance(wrapper, ExperimentalMemoComponent) + assert not _should_memoize(wrapper) + + +def test_event_trigger_memoization_emits_usecallback_in_page_hooks() -> None: + """Components with event triggers get useCallback wrappers at the page level.""" + from reflex_base.event import EventChain + + # Construct an event chain referencing state so _get_memoized_event_triggers + # emits a useCallback. + event_var = Var(_js_expr="test_event")._replace( + _var_type=EventChain, + merge_var_data=VarData(state="TestState"), + ) + comp = Plain.create() + comp.event_triggers["on_click"] = event_var + + _ctx, page_ctx = _compile_single_page(lambda: comp) + + # Check that a useCallback hook line was added to the page hooks dict. + hook_lines = list(page_ctx.hooks.keys()) + assert any( + "useCallback" in hook_line and "on_click_" in hook_line + for hook_line in hook_lines + ), f"Expected on_click useCallback hook in {hook_lines!r}" + + +def test_generated_memo_component_renders_as_its_exported_tag() -> None: + """The generated experimental memo component renders as its exported tag.""" + wrapper_factory, definition = create_passthrough_component_memo("MyWrapper_abc") + wrapper = wrapper_factory(Plain.create()) + assert isinstance(wrapper, ExperimentalMemoComponent) + assert wrapper.tag == "MyWrapper_abc" + assert definition.export_name == "MyWrapper_abc" + assert wrapper.render()["name"] == "MyWrapper_abc" + + +def test_shared_subtree_across_pages_uses_same_tag() -> None: + """The same memoizable subtree on multiple pages gets one shared tag.""" + ctx = CompileContext( + pages=[ + FakePage(route="/a", component=lambda: Plain.create(STATE_VAR)), + FakePage(route="/b", component=lambda: Plain.create(STATE_VAR)), + ], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + + assert len(ctx.memoize_wrappers) == 1 + tag = next(iter(ctx.memoize_wrappers)) + assert list(ctx.auto_memo_components) == [tag] + for route in ("/a", "/b"): + output = ctx.compiled_pages[route].output_code or "" + assert f'import {{{tag}}} from "$/utils/components"' in output + assert f"jsx({tag}," in output + + +def test_plugin_only_registered_once_in_default_page_plugins() -> None: + """MemoizeStatefulPlugin appears exactly once in the default plugin pipeline.""" + plugins = default_page_plugins() + memoize_plugins = [p for p in plugins if isinstance(p, MemoizeStatefulPlugin)] + assert len(memoize_plugins) == 1 + # And it is registered before the DefaultCollectorPlugin. + collector_index = next( + i for i, p in enumerate(plugins) if isinstance(p, DefaultCollectorPlugin) + ) + memoize_index = plugins.index(memoize_plugins[0]) + assert memoize_index < collector_index diff --git a/tests/units/compiler/test_plugins.py b/tests/units/compiler/test_plugins.py new file mode 100644 index 00000000000..86587d39be2 --- /dev/null +++ b/tests/units/compiler/test_plugins.py @@ -0,0 +1,988 @@ +# ruff: noqa: D101, D102 + +import dataclasses +from collections.abc import Callable +from typing import Any + +import pytest +from reflex_base.components.component import ( + BaseComponent, + Component, + ComponentStyle, + field, +) +from reflex_base.plugins import ( + BaseContext, + CompileContext, + CompilerHooks, + ComponentAndChildren, + PageContext, + PageDefinition, + Plugin, +) +from reflex_base.utils import format as format_utils +from reflex_base.utils.imports import ImportVar, collapse_imports, merge_imports +from reflex_base.vars import VarData +from reflex_base.vars.base import LiteralVar, Var +from reflex_components_core.base.fragment import Fragment + +from reflex.app import UnevaluatedPage +from reflex.compiler import compiler +from reflex.compiler.plugins import ( + ApplyStylePlugin, + DefaultCollectorPlugin, + DefaultPagePlugin, + default_page_plugins, +) + + +@dataclasses.dataclass(slots=True) +class FakePage: + route: str + component: Callable[[], Component] + title: Var | str | None = None + description: Var | str | None = None + image: str = "" + meta: tuple[dict[str, Any], ...] = () + + +class WrapperComponent(Component): + tag = "WrapperComponent" + library = "wrapper-lib" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(20, "NestedWrap"): Fragment.create()} + + +class RootComponent(Component): + tag = "RootComponent" + library = "root-lib" + + slot: Component | None = field(default=None) + + def add_style(self) -> dict[str, Any] | None: + return {"display": "flex"} + + def add_custom_code(self) -> list[str]: + return ["const rootAddedCode = 1;"] + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(10, "Wrap"): WrapperComponent.create()} + + +class ChildComponent(Component): + tag = "ChildComponent" + library = "child-lib" + + def add_style(self) -> dict[str, Any] | None: + return {"align_items": "center"} + + def add_custom_code(self) -> list[str]: + return ["const childAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const childCustomCode = 1;" + + def _get_hooks(self) -> str | None: + return "const childHook = useChildHook();" + + +class PropComponent(Component): + tag = "PropComponent" + library = "prop-lib" + + def add_custom_code(self) -> list[str]: + return ["const propAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const propCustomCode = 1;" + + def _get_dynamic_imports(self) -> str | None: + return "dynamic(() => import('prop-lib'))" + + def _get_hooks(self) -> str | None: + return "const propHook = usePropHook();" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(15, "PropWrap"): Fragment.create()} + + +class SharedLibraryComponent(Component): + tag = "SharedLibraryComponent" + library = "react-moment" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(25, "SharedLibraryWrap"): Fragment.create()} + + +class InlineStatefulComponent(Component): + tag = "InlineStatefulComponent" + library = "inline-lib" + + +class StubPlugin(Plugin): + pass + + +SHARED_STATEFUL_VAR = LiteralVar.create("shared")._replace( + merge_var_data=VarData( + hooks={"useSharedStatefulValue": None}, + state="SharedState", + ) +) + +INLINE_STATEFUL_VAR = LiteralVar.create("inline")._replace( + merge_var_data=VarData( + hooks={"useInlineStatefulValue": None}, + state="InlineState", + ) +) + + +def create_component_tree() -> RootComponent: + return RootComponent.create( + ChildComponent.create(id="child-id", style={"color": "red"}), + slot=PropComponent.create(id="prop-id", style={"opacity": "0.5"}), + style={"margin": "0"}, + ) + + +def create_shared_stateful_component() -> SharedLibraryComponent: + return SharedLibraryComponent.create(SHARED_STATEFUL_VAR) + + +def create_inline_stateful_component() -> InlineStatefulComponent: + return InlineStatefulComponent.create(INLINE_STATEFUL_VAR) + + +def page_style() -> ComponentStyle: + return { + RootComponent: {"padding": "1rem"}, + ChildComponent: {"font_size": "12px"}, + PropComponent: {"border": "1px solid green"}, + } + + +def normalize_style(component: BaseComponent) -> dict[str, str]: + assert isinstance(component, Component) + return {key: str(value) for key, value in component.style.items()} + + +def create_compile_context(hooks: CompilerHooks | None = None) -> CompileContext: + return CompileContext(pages=[], hooks=hooks or CompilerHooks()) + + +def collect_page_context( + component: BaseComponent, + *, + plugins: tuple[Any, ...], +) -> PageContext: + page_ctx = PageContext( + name="page", + route="/page", + root_component=component, + ) + hooks = CompilerHooks(plugins=plugins) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + page_ctx.root_component = hooks.compile_component( + page_ctx.root_component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + hooks.compile_page(page_ctx, compile_context=compile_ctx) + + return page_ctx + + +def test_eval_page_uses_first_non_none_result() -> None: + calls: list[str] = [] + page = FakePage(route="/demo", component=lambda: Fragment.create()) + + class NoMatchPlugin(StubPlugin): + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> None: + del page_fn, page, kwargs + calls.append("no-match") + + class MatchPlugin(StubPlugin): + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + del kwargs + calls.append("match") + return PageContext( + name="page", + route=page.route, + root_component=page_fn(), + ) + + class UnreachablePlugin(StubPlugin): + def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + del page_fn, page, kwargs + calls.append("unreachable") + msg = "eval_page should stop at the first page context" + raise AssertionError(msg) + + hooks = CompilerHooks(plugins=(NoMatchPlugin(), MatchPlugin(), UnreachablePlugin())) + + page_ctx = hooks.eval_page(page.component, page=page, compile_context=None) + + assert page_ctx is not None + assert page_ctx.route == "/demo" + assert calls == ["no-match", "match"] + + +def test_compile_page_runs_plugins_in_registration_order() -> None: + calls: list[str] = [] + page_ctx = PageContext( + name="page", + route="/ordered", + root_component=Fragment.create(), + ) + + class FirstPlugin(StubPlugin): + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + del page_ctx, kwargs + calls.append("first") + + class SecondPlugin(StubPlugin): + def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + del page_ctx, kwargs + calls.append("second") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + hooks.compile_page(page_ctx, compile_context=None) + + assert calls == ["first", "second"] + + +def test_component_hook_resolution_caches_only_real_overrides() -> None: + class EnterPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del comp, page_context, compile_context, in_prop_tree + + class LeavePlugin(StubPlugin): + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del ( + comp, + children, + page_context, + compile_context, + in_prop_tree, + ) + + hooks = CompilerHooks(plugins=(Plugin(), EnterPlugin(), LeavePlugin())) + + assert len(hooks._enter_component_hook_binders) == 1 + assert len(hooks._leave_component_hook_binders) == 1 + + +def test_enter_component_skips_inherited_base_plugin_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + visited: list[str] = [] + root = RootComponent.create() + + def fail_enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del self, comp, page_context, compile_context, in_prop_tree + msg = "Inherited Plugin.enter_component hook should be skipped." + raise AssertionError(msg) + + monkeypatch.setattr(Plugin, "enter_component", fail_enter_component) + + class RealPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del page_context, compile_context, in_prop_tree + visited.append(type(comp).__name__) + + hooks = CompilerHooks(plugins=(Plugin(), RealPlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert visited == ["RootComponent"] + + +def test_enter_component_skips_inherited_protocol_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + visited: list[str] = [] + root = RootComponent.create() + + def fail_enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del self, comp, page_context, compile_context, in_prop_tree + msg = "Inherited Plugin.enter_component hook should be skipped." + raise AssertionError(msg) + + monkeypatch.setattr(Plugin, "enter_component", fail_enter_component) + + class ProtocolOnlyPlugin(Plugin): + pass + + class RealPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del page_context, compile_context, in_prop_tree + visited.append(type(comp).__name__) + + hooks = CompilerHooks(plugins=(ProtocolOnlyPlugin(), RealPlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert visited == ["RootComponent"] + + +def test_compile_component_orders_enter_and_leave_by_plugin() -> None: + events: list[str] = [] + root = RootComponent.create() + + class FirstPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del comp, page_context, compile_context, in_prop_tree + events.append("first:enter") + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del ( + comp, + children, + page_context, + compile_context, + in_prop_tree, + ) + events.append("first:leave") + + class SecondPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del comp, page_context, compile_context, in_prop_tree + events.append("second:enter") + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del ( + comp, + children, + page_context, + compile_context, + in_prop_tree, + ) + events.append("second:leave") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + compiled_root = hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert compiled_root is root + assert events == [ + "first:enter", + "second:enter", + "second:leave", + "first:leave", + ] + + +def test_compile_component_traverses_children_before_prop_components() -> None: + visited: list[str] = [] + root = RootComponent.create( + ChildComponent.create(), + slot=PropComponent.create(), + ) + + class VisitPlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> None: + del page_context, compile_context, in_prop_tree + if isinstance(comp, Component): + visited.append(comp.tag or type(comp).__name__) + + hooks = CompilerHooks(plugins=(VisitPlugin(),)) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert visited == ["RootComponent", "ChildComponent", "PropComponent"] + + +def test_enter_and_leave_replacements_match_generator_style_behavior() -> None: + child = ChildComponent.create(id="original") + root = RootComponent.create(child) + + class ReplacePlugin(StubPlugin): + def enter_component( + self, + comp: BaseComponent, + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + del page_context, compile_context + if isinstance(comp, RootComponent) and not in_prop_tree: + replacement_child = ChildComponent.create(id="replacement") + return comp, (replacement_child,) + return None + + def leave_component( + self, + comp: BaseComponent, + children: tuple[BaseComponent, ...], + /, + *, + page_context: PageContext, + compile_context: CompileContext, + in_prop_tree: bool = False, + ) -> BaseComponent | ComponentAndChildren | None: + del page_context, compile_context, in_prop_tree + if isinstance(comp, RootComponent): + return Fragment.create(comp), children + return None + + hooks = CompilerHooks(plugins=(ReplacePlugin(),)) + page_ctx = PageContext(name="page", route="/page", root_component=root) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + compiled_root = hooks.compile_component( + root, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert isinstance(compiled_root, Fragment) + assert len(compiled_root.children) == 1 + replacement_child = compiled_root.children[0] + assert isinstance(replacement_child, ChildComponent) + assert str(replacement_child.id) == "replacement" + + +def test_context_lifecycle_and_cleanup() -> None: + compile_ctx = CompileContext(pages=[], hooks=CompilerHooks()) + page_ctx = PageContext( + name="page", + route="/ctx", + root_component=Fragment.create(), + ) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + with pytest.raises( + RuntimeError, match="must be entered with 'with' or 'async with'" + ): + compile_ctx.ensure_context_attached() + + with compile_ctx: + assert CompileContext.get() is compile_ctx + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + with page_ctx: + assert CompileContext.get() is compile_ctx + assert PageContext.get() is page_ctx + page_ctx.ensure_context_attached() + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + assert CompileContext.get() is compile_ctx + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + with pytest.raises(ValueError, match="boom"), compile_ctx: + msg = "boom" + raise ValueError(msg) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + +def test_page_context_default_factories_are_isolated() -> None: + page_ctx_a = PageContext( + name="a", + route="/a", + root_component=Fragment.create(), + ) + page_ctx_b = PageContext( + name="b", + route="/b", + root_component=Fragment.create(), + ) + + page_ctx_a.imports.append({"lib-a": [ImportVar(tag="ThingA")]}) + page_ctx_a.module_code["const a = 1;"] = None + page_ctx_a.hooks["hookA"] = None + page_ctx_a.dynamic_imports.add("dynamic-a") + page_ctx_a.refs["refA"] = None + page_ctx_a.app_wrap_components[1, "WrapA"] = Fragment.create() + + assert page_ctx_b.imports == [] + assert page_ctx_b.module_code == {} + assert page_ctx_b.hooks == {} + assert page_ctx_b.dynamic_imports == set() + assert page_ctx_b.refs == {} + assert page_ctx_b.app_wrap_components == {} + + +def test_page_context_helpers_preserve_accumulated_values() -> None: + page_ctx = PageContext( + name="page", + route="/page", + root_component=Fragment.create(), + ) + page_ctx.imports.extend([ + {"lib-a": [ImportVar(tag="ThingA")]}, + {"lib-a": [ImportVar(tag="ThingB")], "lib-b": [ImportVar(tag="ThingC")]}, + ]) + page_ctx.module_code["const first = 1;"] = None + page_ctx.module_code["const second = 2;"] = None + + assert page_ctx.merged_imports() == merge_imports(*page_ctx.imports) + assert page_ctx.merged_imports(collapse=True) == collapse_imports( + merge_imports(*page_ctx.imports) + ) + assert list(page_ctx.custom_code_dict()) == [ + "const first = 1;", + "const second = 2;", + ] + + +def test_base_context_subclasses_initialize_distinct_context_vars() -> None: + class DynamicContext(BaseContext): + pass + + class AnotherDynamicContext(BaseContext): + pass + + assert DynamicContext.__context_var__ is not AnotherDynamicContext.__context_var__ + + +def test_apply_style_plugin_matches_legacy_style_behavior() -> None: + component = create_component_tree() + legacy_component = create_component_tree() + + legacy_component._add_style_recursive(page_style()) + + hooks = CompilerHooks(plugins=(ApplyStylePlugin(style=page_style()),)) + page_ctx = PageContext(name="page", route="/page", root_component=component) + compile_ctx = create_compile_context(hooks) + + with compile_ctx, page_ctx: + hooks.compile_component( + component, + page_context=page_ctx, + compile_context=compile_ctx, + ) + + assert normalize_style(component) == normalize_style(legacy_component) + assert normalize_style(component.children[0]) == normalize_style( + legacy_component.children[0] + ) + assert component.slot is not None + assert legacy_component.slot is not None + assert normalize_style(component.slot) == normalize_style(legacy_component.slot) + + +def test_default_collector_matches_legacy_collectors() -> None: + component = create_component_tree() + assert "prop-lib" in component._get_all_imports(collapse=True) + + page_ctx = collect_page_context( + component, + plugins=(DefaultCollectorPlugin(),), + ) + + assert page_ctx.imports == [component._get_all_imports(collapse=True)] + assert "prop-lib" in page_ctx.frontend_imports + assert page_ctx.hooks == component._get_all_hooks() + assert "usePropHook" not in "".join(page_ctx.hooks) + assert page_ctx.module_code == component._get_all_custom_code() + assert page_ctx.dynamic_imports == component._get_all_dynamic_imports() + assert page_ctx.refs == component._get_all_refs() + assert page_ctx.refs == { + format_utils.format_ref("child-id"): None, + format_utils.format_ref("prop-id"): None, + } + assert ( + page_ctx.app_wrap_components.keys() + == component._get_all_app_wrap_components().keys() + ) + + +def test_default_page_plugins_are_minimal_and_ordered() -> None: + from reflex.compiler.plugins.memoize import MemoizeStatefulPlugin + + plugins = default_page_plugins(style=page_style()) + + assert len(plugins) == 4 + assert isinstance(plugins[0], DefaultPagePlugin) + assert isinstance(plugins[1], ApplyStylePlugin) + assert isinstance(plugins[2], MemoizeStatefulPlugin) + assert isinstance(plugins[3], DefaultCollectorPlugin) + + +def test_compile_context_compiles_pages_and_matches_legacy_output() -> None: + page = FakePage(route="/demo", component=create_component_tree) + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks(plugins=default_page_plugins(style=page_style())), + ) + + with compile_ctx: + compiled_pages = compile_ctx.compile() + + assert compiled_pages is compile_ctx.compiled_pages + assert list(compiled_pages) == ["/demo"] + + page_ctx = compiled_pages["/demo"] + assert isinstance(page_ctx.root_component, Component) + assert page_ctx.name == "create_component_tree" + assert page_ctx.route == "/demo" + assert "prop-lib" in page_ctx.root_component._get_all_imports(collapse=True) + assert page_ctx.frontend_imports == page_ctx.merged_imports(collapse=True) + assert "prop-lib" in page_ctx.frontend_imports + compile_ctx_imports = collapse_imports(compile_ctx.all_imports) + for lib, fields in page_ctx.frontend_imports.items(): + assert lib in compile_ctx_imports + assert set(compile_ctx_imports[lib]) >= set(fields) + assert page_ctx.output_path is not None + assert page_ctx.output_code is not None + assert page_ctx.imports == [page_ctx.root_component._get_all_imports(collapse=True)] + assert page_ctx.hooks == page_ctx.root_component._get_all_hooks() + assert page_ctx.module_code == page_ctx.root_component._get_all_custom_code() + assert ( + page_ctx.dynamic_imports == page_ctx.root_component._get_all_dynamic_imports() + ) + assert page_ctx.refs == page_ctx.root_component._get_all_refs() + assert ( + page_ctx.app_wrap_components.keys() + == page_ctx.root_component._get_all_app_wrap_components().keys() + ) + + legacy_component = compiler.compile_unevaluated_page( + page.route, + UnevaluatedPage( + component=page.component, + route=page.route, + title=page.title, + description=page.description, + image=page.image, + on_load=None, + meta=page.meta, + context={}, + ), + page_style(), + None, + ) + expected_output = compiler.compile_page(page.route, legacy_component)[1] + assert page_ctx.output_code == expected_output + + +def test_default_page_plugin_handles_var_backed_title_like_legacy_compiler() -> None: + page = UnevaluatedPage( + component=lambda: Fragment.create(), + route="/var-title", + title=Var(_js_expr="pageTitle", _var_type=str), + description=None, + image="", + on_load=None, + meta=(), + context={}, + ) + hooks = CompilerHooks(plugins=(DefaultPagePlugin(),)) + compile_ctx = create_compile_context(hooks) + + with compile_ctx: + page_ctx = hooks.eval_page( + page.component, + page=page, + compile_context=compile_ctx, + ) + + assert page_ctx is not None + + legacy_component = compiler.compile_unevaluated_page( + page.route, + page, + None, + None, + ) + assert page_ctx.root_component.render() == legacy_component.render() + + +def test_compile_context_rejects_duplicate_routes() -> None: + pages = [ + FakePage(route="/duplicate", component=lambda: Fragment.create()), + FakePage(route="/duplicate", component=lambda: Fragment.create()), + ] + compile_ctx = CompileContext( + pages=pages, + hooks=CompilerHooks(plugins=(DefaultPagePlugin(),)), + ) + + with ( + compile_ctx, + pytest.raises( + RuntimeError, + match="Duplicate compiled page route", + ), + ): + compile_ctx.compile() + + +def test_compile_context_requires_attached_context() -> None: + compile_ctx = CompileContext( + pages=[], + hooks=CompilerHooks(), + ) + + with pytest.raises( + RuntimeError, match="must be entered with 'with' or 'async with'" + ): + compile_ctx.compile() + + +def test_compile_context_memoize_wrappers_registers_shared_subtree_tag() -> None: + """Shared memoizable subtree across pages registers a single wrapper tag.""" + pages = [ + FakePage(route="/a", component=create_shared_stateful_component), + FakePage(route="/b", component=create_shared_stateful_component), + ] + compile_ctx = CompileContext( + pages=pages, + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + + with compile_ctx: + compile_ctx.compile() + + # The wrapped library import still reaches the compile-context level. + assert "react-moment" in compile_ctx.all_imports + assert (25, "SharedLibraryWrap") in compile_ctx.app_wrap_components + # Both pages share the same subtree hash, so exactly one wrapper tag is registered. + assert len(compile_ctx.memoize_wrappers) == 1 + wrapper_tag = next(iter(compile_ctx.memoize_wrappers)) + assert list(compile_ctx.auto_memo_components) == [wrapper_tag] + # Each page imports the generated experimental memo component. + page_a_code = compile_ctx.compiled_pages["/a"].output_code or "" + assert f'import {{{wrapper_tag}}} from "$/utils/components"' in page_a_code + assert f"jsx({wrapper_tag}," in page_a_code + assert f"const {wrapper_tag} = memo" not in page_a_code + # The removed shared-stateful-components path should not appear anywhere. + assert "$/utils/stateful_components" not in page_a_code + + +def test_compile_context_resets_memoize_wrappers_between_runs() -> None: + """``CompileContext.memoize_wrappers`` is cleared on each compile run.""" + ctx = CompileContext( + pages=[FakePage(route="/a", component=create_shared_stateful_component)], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx: + ctx.compile() + first_tags = set(ctx.memoize_wrappers) + first_defs = set(ctx.auto_memo_components) + assert first_tags # memoize wrapper was registered + assert first_defs == first_tags + + # Re-compile with a different page set → wrappers reset, not accumulated. + ctx2 = CompileContext( + pages=[FakePage(route="/c", component=create_shared_stateful_component)], + hooks=CompilerHooks(plugins=default_page_plugins()), + ) + with ctx2: + ctx2.compile() + + # Same shared component → same tag, not a union across runs. + assert set(ctx2.memoize_wrappers) == first_tags + assert set(ctx2.auto_memo_components) == first_tags + page_ctx = ctx2.compiled_pages["/c"] + assert "react-moment" in page_ctx.frontend_imports + assert "$/utils/stateful_components" not in (page_ctx.output_code or "") + + +def test_compile_context_applies_style_before_inline_stateful_render() -> None: + compile_ctx = CompileContext( + pages=[ + FakePage( + route="/styled", + component=create_inline_stateful_component, + ) + ], + hooks=CompilerHooks( + plugins=default_page_plugins( + style={InlineStatefulComponent: {"color": "red"}} + ) + ), + ) + + with compile_ctx: + compile_ctx.compile() + + assert '["color"] : "red"' in ( + compile_ctx.compiled_pages["/styled"].output_code or "" + ) + + +def test_compile_context_applies_style_before_shared_stateful_render() -> None: + compile_ctx = CompileContext( + pages=[ + FakePage(route="/a", component=create_shared_stateful_component), + FakePage(route="/b", component=create_shared_stateful_component), + ], + hooks=CompilerHooks( + plugins=default_page_plugins( + style={SharedLibraryComponent: {"color": "red"}} + ) + ), + ) + + with compile_ctx: + compile_ctx.compile() + + assert '["color"] : "red"' in (compile_ctx.compiled_pages["/a"].output_code or "") + assert '["color"] : "red"' in (compile_ctx.compiled_pages["/b"].output_code or "") diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index de397442311..e891bd5431f 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -7,8 +7,8 @@ CUSTOM_COMPONENTS, Component, CustomComponent, - StatefulComponent, custom_component, + field, ) from reflex_base.constants import EventTriggers from reflex_base.constants.state import FIELD_MARKER @@ -522,30 +522,25 @@ def test_get_imports(component1, component2): } -def test_get_all_imports_includes_components_in_props(): - """Test that _get_all_imports collects imports from components in props.""" +def test_get_imports_includes_components_in_props(): + """Test that component-valued props contribute their imports.""" - class InnerComponent(Component): - """A component that requires a specific import.""" + class PropComponent(Component): + tag = "PropComponent" + library = "prop-lib" - def _get_imports(self) -> ParsedImportDict: - return {"some-library": [ImportVar(tag="SomeTag")]} - - class OuterComponent(Component): - """A component with a component-typed prop.""" + class ParentComponent(Component): + tag = "ParentComponent" + library = "parent-lib" - fallback: Component | None = None + slot: Component | None = field(default=None) - def _get_imports(self) -> ParsedImportDict: - return {"outer-library": [ImportVar(tag="OuterTag")]} + imports_ = ParentComponent.create(slot=PropComponent.create())._get_all_imports() - inner = InnerComponent.create() - outer = OuterComponent.create(fallback=inner) - all_imports = outer._get_all_imports() - assert "some-library" in all_imports, ( - "_get_all_imports() should collect imports from components in props" - ) - assert "outer-library" in all_imports + assert imports_ == parse_imports({ + "parent-lib": ["ParentComponent"], + "prop-lib": ["PropComponent"], + }) def test_get_custom_code(component1: Component, component2: Component): @@ -1191,47 +1186,6 @@ def test_format_component(component, rendered): assert str(component) == rendered -def test_stateful_component(test_state: type[TestState]): - """Test that a stateful component is created correctly. - - Args: - test_state: A test state. - """ - text_component = rx.text(test_state.num) - stateful_component = StatefulComponent.compile_from(text_component) - assert isinstance(stateful_component, StatefulComponent) - assert stateful_component.tag is not None - assert stateful_component.tag.startswith("Text_") - assert stateful_component.references == 1 - sc2 = StatefulComponent.compile_from(rx.text(test_state.num)) - assert isinstance(sc2, StatefulComponent) - assert stateful_component.references == 2 - assert sc2.references == 2 - - -def test_stateful_component_memoize_event_trigger(test_state: type[TestState]): - """Test that a stateful component is created correctly with events. - - Args: - test_state: A test state. - """ - button_component = rx.button("Click me", on_blur=test_state.do_something) - stateful_component = StatefulComponent.compile_from(button_component) - assert isinstance(stateful_component, StatefulComponent) - - # No event trigger? No StatefulComponent - assert not isinstance( - StatefulComponent.compile_from(rx.button("Click me")), StatefulComponent - ) - - -def test_stateful_banner(): - """Test that a stateful component is created correctly with events.""" - connection_modal_component = rx.connection_modal() - stateful_component = StatefulComponent.compile_from(connection_modal_component) - assert isinstance(stateful_component, StatefulComponent) - - TEST_VAR = LiteralVar.create("p")._replace( merge_var_data=VarData( hooks={"useTest": None}, diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index f202ecf05d8..236b6b115ad 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -415,6 +415,29 @@ def wrapper() -> rx.Component: assert definition.component.style == Style() +def test_component_returning_memo_is_transparent_for_child_validation(): + """Experimental memo wrappers should not break `_valid_parents` checks.""" + + class ValidParent(Component): + tag = "ValidParent" + library = "valid-parent" + + class RestrictedChild(Component): + tag = "RestrictedChild" + library = "restricted-child" + _valid_parents = ["ValidParent"] + + @rx._x.memo + def transparent(children: rx.Var[rx.Component]) -> rx.Component: + return children # type: ignore[return-value] + + wrapped_child = transparent(RestrictedChild.create()) + parent = ValidParent.create(wrapped_child) + + assert isinstance(wrapped_child, ExperimentalMemoComponent) + assert parent.children == [wrapped_child] + + def test_compile_memo_components_includes_experimental_custom_code(): """Experimental component memos should include custom code in compiled output.""" diff --git a/tests/units/reflex_base/event/processor/test_event_processor.py b/tests/units/reflex_base/event/processor/test_event_processor.py index de1ea4dcb23..ee7d0c9b5c8 100644 --- a/tests/units/reflex_base/event/processor/test_event_processor.py +++ b/tests/units/reflex_base/event/processor/test_event_processor.py @@ -518,6 +518,30 @@ async def test_stream_delta_not_configured_raises(): pass +async def test_stream_delta_calls_on_task_future(token: str): + """enqueue_stream_delta exposes the tracked EventFuture immediately. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + captured = [] + async with ep: + event = Event.from_event_type(noop_event())[0] + collected = [ + d + async for d in ep.enqueue_stream_delta( + token, + event, + on_task_future=captured.append, + ) + ] + assert collected == [] + assert len(captured) == 1 + assert captured[0].done() + + async def test_sequential_chained_events_run_in_order(token: str): """Chained events enqueued by a handler run in the order they were enqueued. diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 2a3eb5556f4..138f9aad047 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1307,72 +1307,6 @@ async def send(_message): assert bio.closed -@pytest.mark.asyncio -async def test_upload_file_cancels_buffered_handler_on_disconnect(token: str): - """Buffered uploads cancel the streaming handler on client disconnect. - - Args: - token: A token. - """ - request_mock = unittest.mock.Mock() - request_mock.headers = { - "reflex-client-token": token, - "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", - } - - bio = io.BytesIO(b"contents of image one") - file1 = UploadFile(filename="image1.jpg", file=bio) - form_data = FormData([("files", file1)]) - original_close = form_data.close - form_close = AsyncMock(side_effect=original_close) - form_data.close = form_close - - async def form(): # noqa: RUF029 - return form_data - - request_mock.form = form - - stream_started = asyncio.Event() - stream_closed = asyncio.Event() - - async def enqueue_stream_delta(_token, _event): - try: - stream_started.set() - yield {"state": {"ok": True}} - await asyncio.Event().wait() - finally: - stream_closed.set() - - app = Mock( - event_processor=Mock(enqueue_stream_delta=enqueue_stream_delta), - ) - - upload_fn = upload(app) - streaming_response = await upload_fn(request_mock) - - assert isinstance(streaming_response, StreamingResponse) - - async def receive(): - await stream_started.wait() - return {"type": "http.disconnect"} - - async def send(_message): # noqa: RUF029 - return None - - await asyncio.wait_for( - streaming_response( - {"type": "http", "asgi": {"spec_version": "2.4"}}, - receive, - send, - ), - timeout=1, - ) - - await asyncio.wait_for(stream_closed.wait(), timeout=1) - assert form_close.await_count == 1 - assert bio.closed - - @pytest.mark.asyncio async def test_upload_file_raises_client_disconnect_when_stream_send_fails( token: str, @@ -1402,7 +1336,7 @@ async def form(): # noqa: RUF029 stream_closed = asyncio.Event() - async def enqueue_stream_delta(_token, _event): + async def enqueue_stream_delta(_token, _event, on_task_future=None): try: yield {"state": {"ok": True}} await asyncio.Event().wait() @@ -2162,6 +2096,54 @@ def test_app_wrap_compile_theme( assert expected.split(",") == function_app_definition.split(",") +def test_compile_writes_app_wrap_memo_components( + compilable_app: tuple[App, Path], + mocker, +) -> None: + """App-wrap memo components are emitted to the shared components module.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + + app.add_page(rx.box("Index"), route="/") + app._compile() + + components_js = ( + web_dir + / constants.Dirs.UTILS + / f"{constants.PageNames.COMPONENTS}{constants.Ext.JSX}" + ).read_text() + + assert "export const DefaultOverlayComponents" in components_js + assert "export const MemoizedToastProvider" in components_js + + +def test_compile_writes_upload_files_provider_app_wrap( + compilable_app: tuple[App, Path], + mocker, +) -> None: + """Upload pages emit the UploadFilesProvider app wrap into the app root.""" + conf = rx.Config(app_name="testing") + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + + app.add_page( + lambda: rx.upload.root( + rx.vstack( + rx.button("Select File"), + rx.text("Drag and drop files here or click to select files"), + ), + ), + route="/", + ) + app._compile() + + root_js = web_dir / constants.Dirs.PAGES / constants.PageNames.APP_ROOT + root_contents = root_js.read_text() + + assert "UploadFilesProvider" in root_contents + + @pytest.mark.parametrize( "react_strict_mode", [True, False], diff --git a/tests/units/test_environment.py b/tests/units/test_environment.py index 59e3c5ed438..8e7796b3ebe 100644 --- a/tests/units/test_environment.py +++ b/tests/units/test_environment.py @@ -12,7 +12,6 @@ from reflex_base.environment import ( EnvironmentVariables, EnvVar, - ExecutorType, ExistingPath, PerformanceMode, SequenceOptions, @@ -408,47 +407,6 @@ class TestEnv: assert env_var_instance.default == "default" -class TestExecutorType: - """Test the ExecutorType enum and related functionality.""" - - def test_executor_type_values(self): - """Test ExecutorType enum values.""" - assert ExecutorType.THREAD.value == "thread" - assert ExecutorType.PROCESS.value == "process" - assert ExecutorType.MAIN_THREAD.value == "main_thread" - - def test_get_executor_main_thread_mode(self): - """Test executor selection in main thread mode.""" - with ( - patch.object( - environment.REFLEX_COMPILE_EXECUTOR, - "get", - return_value=ExecutorType.MAIN_THREAD, - ), - patch.object( - environment.REFLEX_COMPILE_PROCESSES, "get", return_value=None - ), - patch.object(environment.REFLEX_COMPILE_THREADS, "get", return_value=None), - ): - executor = ExecutorType.get_executor_from_environment() - - # Test the main thread executor functionality - with executor: - future = executor.submit(lambda x: x * 2, 5) - assert future.result() == 10 - - def test_get_executor_returns_executor(self): - """Test that get_executor_from_environment returns an executor.""" - # Test with default values - should return some kind of executor - executor = ExecutorType.get_executor_from_environment() - assert executor is not None - - # Test that we can use it as a context manager - with executor: - future = executor.submit(lambda: "test") - assert future.result() == "test" - - class TestUtilityFunctions: """Test utility functions.""" diff --git a/tests/units/utils/test_streaming_response.py b/tests/units/utils/test_streaming_response.py index 122af46a9b7..9ee09a7801c 100644 --- a/tests/units/utils/test_streaming_response.py +++ b/tests/units/utils/test_streaming_response.py @@ -9,47 +9,11 @@ from starlette.requests import ClientDisconnect -@pytest.mark.asyncio -async def test_disconnect_cancels_stream_task_and_runs_finish(): - """A receive-side disconnect cancels the body stream and cleanup runs once.""" - body_closed = asyncio.Event() - body_started = asyncio.Event() - on_finish = AsyncMock() - - async def body(): - try: - body_started.set() - yield b"payload" - await asyncio.Event().wait() - finally: - body_closed.set() - - async def receive(): - await body_started.wait() - return {"type": "http.disconnect"} - - async def send(_message): - await asyncio.sleep(0) - - response = DisconnectAwareStreamingResponse( - body(), - media_type="application/x-ndjson", - on_finish=on_finish, - ) - - await asyncio.wait_for( - response({"type": "http", "asgi": {"spec_version": "2.4"}}, receive, send), - timeout=1, - ) - - await asyncio.wait_for(body_closed.wait(), timeout=1) - on_finish.assert_awaited_once() - - @pytest.mark.asyncio async def test_send_oserror_raises_client_disconnect_and_closes_body(): """A send-side disconnect still raises ClientDisconnect and closes the stream.""" body_closed = asyncio.Event() + disconnect_notified = asyncio.Event() on_finish = AsyncMock() async def body(): @@ -74,12 +38,14 @@ async def send(message): body(), media_type="application/x-ndjson", on_finish=on_finish, + on_disconnect=disconnect_notified.set, ) with pytest.raises(ClientDisconnect): await response({"type": "http", "asgi": {"spec_version": "2.4"}}, receive, send) await asyncio.wait_for(body_closed.wait(), timeout=1) + assert disconnect_notified.is_set() on_finish.assert_awaited_once()