From 108245d6a72c35f8a4e1a2a8e4492879fb14b8a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A6rkeren?= <164513459+Faerkeren@users.noreply.github.com> Date: Wed, 27 May 2026 03:02:58 -0100 Subject: [PATCH] refactor: introduce typed SceneAccessor and TimerAccessor (#76) - Add `accessor_cls` field to `DomainSpec` so domains can declare a typed `DomainAccessor` subclass instead of relying on dynamic `setattr`-based operation binding. - Expose `DomainAccessor.factory` as a public property so subclasses can access services/state without reaching into `_factory` private attribute. - `HAClient` now instantiates `spec.accessor_cls` when set, falling back to the base `DomainAccessor` for domains without one. - Replace scene domain's dynamic `_create`/`_apply` functions with a `SceneAccessor` subclass that exposes properly typed `create()` and `apply()` methods. - Replace timer domain's dynamic `_create` function with a `TimerAccessor` subclass that exposes a properly typed `create()` method. - Remove all `# type: ignore[assignment]` comments from domain code; mypy now passes cleanly with no suppressions. - Add six new tests covering the `factory` property, typed accessor instantiation, and the `accessor_cls` dispatch path. --- src/haclient/api.py | 3 +- src/haclient/core/plugins.py | 37 +++++++-- src/haclient/domains/scene.py | 150 +++++++++++++++++----------------- src/haclient/domains/timer.py | 127 +++++++++++++++------------- tests/test_plugins.py | 71 ++++++++++++++++ 5 files changed, 247 insertions(+), 141 deletions(-) diff --git a/src/haclient/api.py b/src/haclient/api.py index 66dc896..64ff100 100644 --- a/src/haclient/api.py +++ b/src/haclient/api.py @@ -128,7 +128,8 @@ def __init__( active = self._select_active_domains(domains) self._accessors: dict[str, DomainAccessor[Any]] = {} for spec in active: - accessor: DomainAccessor[Any] = DomainAccessor(spec, self._factory) + cls = spec.accessor_cls if spec.accessor_cls is not None else DomainAccessor + accessor: DomainAccessor[Any] = cls(spec, self._factory) self._accessors[spec.accessor_name()] = accessor self._accessors[spec.name] = accessor for event_type in spec.event_subscriptions: diff --git a/src/haclient/core/plugins.py b/src/haclient/core/plugins.py index fc5f7f0..250c91b 100644 --- a/src/haclient/core/plugins.py +++ b/src/haclient/core/plugins.py @@ -8,7 +8,13 @@ ``ha.light`` or ``ha.scene``). It provides: * ``__call__(name)`` and ``__getitem__(name)`` for entity lookup. -* Domain-level operations registered by the spec via ``operations``. +* Domain-level operations registered by the spec via ``operations`` (legacy + third-party path) **or** via typed subclass methods (preferred path). + +Domains with collection-level operations should subclass `DomainAccessor` +and register the subclass via ``DomainSpec.accessor_cls``. This keeps the +public API statically typed without requiring ``# type: ignore`` workarounds +or private ``_factory`` access from outside the accessor. Third-party plugins can ship additional domains by exposing an entry point under the ``haclient.domains`` group; see @@ -63,10 +69,14 @@ class DomainSpec(Generic[E]): on_event : callable or None Per-domain event handler (see `DomainEventHandler`). operations : dict - Domain-level async operations registered on the - `DomainAccessor`. Each value is an async callable; it will be - bound to the accessor so the first positional argument *is* the - accessor instance. + Legacy dynamic operation dict kept for third-party plugin + compatibility. Built-in domains with collection-level operations + should prefer ``accessor_cls`` instead. + accessor_cls : type[DomainAccessor] or None + Optional typed `DomainAccessor` subclass to instantiate for this + domain. When provided, the ``HAClient`` uses this class rather than + the base `DomainAccessor`, exposing properly typed collection-level + methods (e.g. ``SceneAccessor.create``). """ name: str @@ -75,6 +85,7 @@ class DomainSpec(Generic[E]): event_subscriptions: tuple[str, ...] = () on_event: DomainEventHandler | None = None operations: dict[str, Callable[..., Any]] = field(default_factory=dict) + accessor_cls: type[DomainAccessor[Any]] | None = None def accessor_name(self) -> str: """Return the accessor attribute name (defaults to ``name``).""" @@ -87,14 +98,14 @@ class DomainAccessor(Generic[E]): Returned by ``HAClient.``. Exposes: * Lookup by short name: ``ha.light("kitchen")`` or ``ha.light["kitchen"]``. - * Domain-level operations registered on the spec, bound to this accessor: - ``await ha.scene.create("romantic", ...)``. + * Domain-level operations either via typed subclass methods (preferred) or + via legacy dynamic binding of ``spec.operations`` entries. Parameters ---------- spec : DomainSpec The spec describing this domain. - factory : EntityFactory + factory : EntityFactoryProtocol Factory used to create entity instances on demand. """ @@ -103,6 +114,7 @@ def __init__(self, spec: DomainSpec[E], factory: EntityFactoryProtocol) -> None: self._factory = factory for op_name, op in spec.operations.items(): # Bind each operation as an attribute on the instance. + # This path is kept for backward-compatible third-party plugins. setattr(self, op_name, self._bind(op)) @property @@ -110,6 +122,15 @@ def spec(self) -> DomainSpec[E]: """Return the underlying `DomainSpec`.""" return self._spec + @property + def factory(self) -> EntityFactoryProtocol: + """Return the `EntityFactoryProtocol` used to create entities. + + Subclasses use this to access ``factory.services`` and + ``factory.state`` without reaching into private internals. + """ + return self._factory + def _bind(self, op: Callable[..., Any]) -> Callable[..., Any]: """Bind a domain operation to this accessor. diff --git a/src/haclient/domains/scene.py b/src/haclient/domains/scene.py index 0d29085..7af30dc 100644 --- a/src/haclient/domains/scene.py +++ b/src/haclient/domains/scene.py @@ -6,28 +6,24 @@ Domain-level operations ----------------------- Beyond per-entity actions, the scene domain exposes two collection-level -operations on the `DomainAccessor`: +operations on the `SceneAccessor` (returned by ``ha.scene``): -* ``create(scene_id, entities, *, snapshot_entities=None) -> Scene`` — - create (or update) a runtime scene helper. -* ``apply(entities, *, transition=None) -> None`` — apply a state - combination without persisting it. +* ``await ha.scene.create(scene_id, entities, *, snapshot_entities=None)`` + — create (or update) a runtime scene helper, returning a `Scene`. +* ``await ha.scene.apply(entities, *, transition=None)`` + — apply a state combination without persisting it. -These are invoked as ``await ha.scene.create(...)`` and -``await ha.scene.apply(...)``. Per-entity access still works through the -usual ``ha.scene("name")`` / ``ha.scene["name"]`` syntax. +Per-entity access still works through the usual +``ha.scene("name")`` / ``ha.scene["name"]`` syntax. """ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any from haclient.core.plugins import DomainAccessor, DomainSpec, register_domain from haclient.entity.base import Entity, ValueChangeHandler -if TYPE_CHECKING: - from haclient.core.factory import EntityFactory - class Scene(Entity): """A Home Assistant scene entity. @@ -108,74 +104,82 @@ def on_activate(self, func: ValueChangeHandler) -> ValueChangeHandler: return self._register_state_value_listener(func) -# -- Domain-level operations -------------------------------------------- - - -async def _create( - accessor: DomainAccessor[Scene], - scene_id: str, - entities: dict[str, dict[str, Any]], - *, - snapshot_entities: list[str] | None = None, -) -> Scene: - """Create a runtime scene and return the corresponding `Scene`. - - Parameters - ---------- - accessor : DomainAccessor - The scene accessor (provided automatically by the binding). - scene_id : str - Object-id for the new scene (e.g. ``"romantic"`` → - ``scene.romantic``). - entities : dict - Mapping of entity ids to state/attribute dicts. - snapshot_entities : list of str or None, optional - Entity ids whose current state should be captured. - - Returns - ------- - Scene - The newly created scene. - """ - factory: EntityFactory = accessor._factory # type: ignore[assignment] - services = factory.services - payload: dict[str, Any] = {"scene_id": scene_id, "entities": entities} - if snapshot_entities is not None: - payload["snapshot_entities"] = snapshot_entities - await services.call("scene", "create", payload) - return accessor[scene_id] - - -async def _apply( - accessor: DomainAccessor[Scene], - entities: dict[str, dict[str, Any]], - *, - transition: float | None = None, -) -> None: - """Apply a scene-like state combination without persisting it. - - Parameters - ---------- - accessor : DomainAccessor - The scene accessor. - entities : dict - Mapping of entity ids to desired state/attribute dicts. - transition : float or None, optional - Transition seconds for entities that support it. +# -- Typed domain accessor ---------------------------------------------- + + +class SceneAccessor(DomainAccessor[Scene]): + """Typed domain accessor for the ``scene`` domain. + + Returned by ``ha.scene``. Provides statically-typed collection-level + operations in addition to the standard entity lookup methods inherited + from `DomainAccessor`. """ - factory: EntityFactory = accessor._factory # type: ignore[assignment] - services = factory.services - payload: dict[str, Any] = {"entities": entities} - if transition is not None: - payload["transition"] = transition - await services.call("scene", "apply", payload) + + async def create( + self, + scene_id: str, + entities: dict[str, dict[str, Any]], + *, + snapshot_entities: list[str] | None = None, + ) -> Scene: + """Create (or update) a runtime scene helper. + + Parameters + ---------- + scene_id : str + Object-id for the new scene (e.g. ``"romantic"`` → + ``scene.romantic``). + entities : dict + Mapping of entity ids to target state/attribute dicts. + snapshot_entities : list of str or None, optional + Entity ids whose current state should be captured instead of + providing an explicit state dict. + + Returns + ------- + Scene + The newly created (or updated) scene entity. + """ + from haclient.core.factory import EntityFactory + + factory = self.factory + assert isinstance(factory, EntityFactory) + payload: dict[str, Any] = {"scene_id": scene_id, "entities": entities} + if snapshot_entities is not None: + payload["snapshot_entities"] = snapshot_entities + await factory.services.call("scene", "create", payload) + return self[scene_id] + + async def apply( + self, + entities: dict[str, dict[str, Any]], + *, + transition: float | None = None, + ) -> None: + """Apply a scene-like state combination without persisting it. + + Parameters + ---------- + entities : dict + Mapping of entity ids to desired state/attribute dicts. + transition : float or None, optional + Transition seconds for entities that support it. + """ + from haclient.core.factory import EntityFactory + + factory = self.factory + assert isinstance(factory, EntityFactory) + payload: dict[str, Any] = {"entities": entities} + if transition is not None: + payload["transition"] = transition + await factory.services.call("scene", "apply", payload) SPEC: DomainSpec[Scene] = register_domain( DomainSpec( name="scene", entity_cls=Scene, - operations={"create": _create, "apply": _apply}, + accessor_cls=SceneAccessor, ) ) """The `DomainSpec` registered with the shared `DomainRegistry`.""" diff --git a/src/haclient/domains/timer.py b/src/haclient/domains/timer.py index 8f9a60c..dd46fa4 100644 --- a/src/haclient/domains/timer.py +++ b/src/haclient/domains/timer.py @@ -21,7 +21,6 @@ from haclient.entity.base import Entity, ValueChangeHandler if TYPE_CHECKING: - from haclient.core.factory import EntityFactory from haclient.core.services import ServiceCaller from haclient.core.state import StateStore from haclient.ports import Clock @@ -326,69 +325,79 @@ def _handle_timer_event(self, event_type: str, data: dict[str, Any]) -> None: self._schedule_value(listener, self.entity_id, data) -# -- Domain-level operations & event handler -------------------------- +# -- Typed domain accessor & event handler ---------------------------- -async def _create( - accessor: DomainAccessor[Timer], - *, - name: str | None = None, - duration: str = "00:01:00", - persistent: bool = False, -) -> Timer: - """Create a library-managed timer helper in Home Assistant. +class TimerAccessor(DomainAccessor[Timer]): + """Typed domain accessor for the ``timer`` domain. - Sends a ``timer/create`` WebSocket command and returns a `Timer`. + Returned by ``ha.timer``. Provides a statically-typed + :meth:`create` method in addition to the standard entity lookup + methods inherited from `DomainAccessor`. + """ - Parameters - ---------- - accessor : DomainAccessor - The timer accessor (provided automatically by the binding). - name : str or None, optional - Short object-id; auto-generated when omitted (only allowed for - ephemeral timers). - duration : str, optional - Initial duration for the helper. - persistent : bool, optional - If ``True``, the HA helper is **not** deleted on idle. - Requires an explicit *name*. + async def create( + self, + *, + name: str | None = None, + duration: str = "00:01:00", + persistent: bool = False, + ) -> Timer: + """Create a library-managed timer helper in Home Assistant. - Returns - ------- - Timer - The newly created timer entity. + Sends a ``timer/create`` WebSocket command and returns a `Timer`. - Raises - ------ - ValueError - If ``persistent=True`` and *name* is ``None``. - """ - if name is None: - if persistent: - raise ValueError("Persistent timers require an explicit name") - name = _generate_timer_id() - - factory: EntityFactory = accessor._factory # type: ignore[assignment] - services = factory.services - state = factory.state - entity_id = state.registry.resolve("timer", name) - existing = state.registry.get(entity_id) - timer: Timer - if existing is not None and isinstance(existing, Timer): - timer = existing - if timer._ensured: - return timer - else: - timer = accessor[name] - - timer._persistent = persistent - object_id = entity_id.split(".", 1)[1] - await services.ws.send_command( - {"type": "timer/create", "name": object_id, "duration": duration} - ) - timer._ensured = True - timer._created_by_us = True - return timer + Parameters + ---------- + name : str or None, optional + Short object-id; auto-generated when omitted (only allowed for + ephemeral timers). + duration : str, optional + Initial duration for the helper in HA format (e.g. + ``"00:01:00"``). + persistent : bool, optional + If ``True``, the HA helper is **not** deleted on idle. + Requires an explicit *name*. + + Returns + ------- + Timer + The newly created (or already-ensured) timer entity. + + Raises + ------ + ValueError + If ``persistent=True`` and *name* is ``None``. + """ + from haclient.core.factory import EntityFactory + + if name is None: + if persistent: + raise ValueError("Persistent timers require an explicit name") + name = _generate_timer_id() + + factory = self.factory + assert isinstance(factory, EntityFactory) + services = factory.services + state = factory.state + entity_id = state.registry.resolve("timer", name) + existing = state.registry.get(entity_id) + timer: Timer + if existing is not None and isinstance(existing, Timer): + timer = existing + if timer._ensured: + return timer + else: + timer = self[name] + + timer._persistent = persistent + object_id = entity_id.split(".", 1)[1] + await services.ws.send_command( + {"type": "timer/create", "name": object_id, "duration": duration} + ) + timer._ensured = True + timer._created_by_us = True + return timer def _on_timer_event(entity: Entity, event_type: str, data: dict[str, Any]) -> None: @@ -407,7 +416,7 @@ def _on_timer_event(entity: Entity, event_type: str, data: dict[str, Any]) -> No entity_cls=Timer, event_subscriptions=("timer.finished", "timer.cancelled"), on_event=_on_timer_event, - operations={"create": _create}, + accessor_cls=TimerAccessor, ) ) """The `DomainSpec` registered with the shared `DomainRegistry`.""" diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 477435c..9009c50 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -199,3 +199,74 @@ async def test_domain_accessor_spec_property() -> None: assert accessor.spec.name == "light" finally: await ha.close() + + +async def test_domain_accessor_factory_property() -> None: + """``DomainAccessor.factory`` returns the EntityFactoryProtocol instance.""" + from haclient.core.factory import EntityFactory + + ha = HAClient.from_url("http://x", token="t", load_plugins=False) + try: + accessor = ha.domain("light") + assert isinstance(accessor.factory, EntityFactory) + finally: + await ha.close() + + +async def test_scene_accessor_is_typed_subclass() -> None: + """``ha.scene`` returns a ``SceneAccessor`` with typed methods.""" + from haclient.domains.scene import SceneAccessor + + ha = HAClient.from_url("http://x", token="t", load_plugins=False) + try: + accessor = ha.domain("scene") + assert isinstance(accessor, SceneAccessor) + assert callable(accessor.create) + assert callable(accessor.apply) + finally: + await ha.close() + + +async def test_timer_accessor_is_typed_subclass() -> None: + """``ha.timer`` returns a ``TimerAccessor`` with a typed ``create`` method.""" + from haclient.domains.timer import TimerAccessor + + ha = HAClient.from_url("http://x", token="t", load_plugins=False) + try: + accessor = ha.domain("timer") + assert isinstance(accessor, TimerAccessor) + assert callable(accessor.create) + finally: + await ha.close() + + +async def test_accessor_cls_in_spec_is_used() -> None: + """When ``DomainSpec.accessor_cls`` is set, it is instantiated by the client.""" + + class _CustomAccessor(DomainAccessor[_Custom]): + def special(self) -> str: + return "typed" + + reg = DomainRegistry() + spec = DomainSpec(name="cls_test", entity_cls=_Custom, accessor_cls=_CustomAccessor) + reg.register(spec) + ha = HAClient.from_url("http://x", token="t", load_plugins=False, registry=reg) + try: + accessor = ha.domain("cls_test") + assert isinstance(accessor, _CustomAccessor) + assert accessor.special() == "typed" + finally: + await ha.close() + + +async def test_accessor_cls_none_falls_back_to_base() -> None: + """When ``DomainSpec.accessor_cls`` is ``None``, the base ``DomainAccessor`` is used.""" + reg = DomainRegistry() + spec = DomainSpec(name="base_test", entity_cls=_Custom) + reg.register(spec) + ha = HAClient.from_url("http://x", token="t", load_plugins=False, registry=reg) + try: + accessor = ha.domain("base_test") + assert type(accessor) is DomainAccessor + finally: + await ha.close()