From 50dcf119f7853aa33e60bf0cfac6e26059014075 Mon Sep 17 00:00:00 2001 From: Luke Date: Sun, 4 Jan 2026 22:06:41 -0500 Subject: [PATCH] feat: add protocol updates --- .gitignore | 4 ++ roborock/devices/device.py | 81 ++++++++++++++++++++++++++-- roborock/devices/rpc/v1_channel.py | 14 ++--- roborock/devices/traits/v1/common.py | 21 ++++++++ 4 files changed, 111 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index de3ccf9c..f724ac37 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,7 @@ docs/_build/ # GitHub App credentials gha-creds-*.json + +# pickle files +*.p +*.pickle diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 9026c4a7..40f581e6 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -6,16 +6,22 @@ import asyncio import datetime +import json import logging from abc import ABC from collections.abc import Callable from typing import Any from roborock.callbacks import CallbackList -from roborock.data import HomeDataDevice, HomeDataProduct +from roborock.data import HomeDataDevice, HomeDataProduct, RoborockErrorCode, RoborockStateCode from roborock.diagnostics import redact_device_data from roborock.exceptions import RoborockException -from roborock.roborock_message import RoborockMessage +from roborock.roborock_message import ( + ROBOROCK_DATA_STATUS_PROTOCOL, + RoborockDataProtocol, + RoborockMessage, + RoborockMessageProtocol, +) from roborock.util import RoborockLoggerAdapter from .traits import Trait @@ -219,8 +225,77 @@ async def close(self) -> None: self._unsub = None def _on_message(self, message: RoborockMessage) -> None: - """Handle incoming messages from the device.""" + """Handle incoming messages from the device. + + Note: Protocol updates (data points) are only sent via cloud/MQTT, not local connection. + """ self._logger.debug("Received message from device: %s", message) + if self.v1_properties is None: + # Ensure we are only doing below logic for set-up V1 devices. + return + + # Only process messages that can contain protocol updates + # RPC_RESPONSE (102), GENERAL_REQUEST (4), and GENERAL_RESPONSE (5) + if message.protocol not in { + RoborockMessageProtocol.RPC_RESPONSE, + RoborockMessageProtocol.GENERAL_RESPONSE, + }: + return + + if not message.payload: + return + + try: + payload = json.loads(message.payload.decode()) + dps = payload.get("dps", {}) + + if not dps: + return + + # Process each data point in the message + for data_point_number, data_point in dps.items(): + # Skip RPC responses (102) as they're handled by the RPC channel + if data_point_number == "102": + continue + + try: + data_protocol = RoborockDataProtocol(int(data_point_number)) + self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}") + self._handle_protocol_update(data_protocol, data_point) + except ValueError: + # Unknown protocol number + self._logger.debug( + f"Got unknown data protocol {data_point_number}, data: {data_point}. " + f"This may allow for faster updates in the future." + ) + except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex: + self._logger.debug(f"Failed to parse protocol message: {ex}") + + def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None: + """Handle a protocol update for a specific data protocol. + + Args: + protocol: The data protocol number. + data_point: The data value for this protocol. + """ + # Handle status protocol updates + if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.v1_properties and self.v1_properties.status: + # Update the specific field in the status trait + match protocol: + case RoborockDataProtocol.ERROR_CODE: + self.v1_properties.status.error_code = RoborockErrorCode(data_point) + case RoborockDataProtocol.STATE: + self.v1_properties.status.state = RoborockStateCode(data_point) + case RoborockDataProtocol.BATTERY: + self.v1_properties.status.battery = data_point + case RoborockDataProtocol.CHARGE_STATUS: + self.v1_properties.status.charge_status = data_point + case _: + # There is also fan power and water box mode, but for now those are skipped + return + + self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point) + self.v1_properties.status.notify_update() def diagnostic_data(self) -> dict[str, Any]: """Return diagnostics information about the device.""" diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index d1b4ee24..608c9ecc 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -305,12 +305,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - if not self.is_local_connected: - # We were not able to connect locally, so fallback to MQTT and at least - # establish that connection explicitly. If this fails then raise an - # error and let the caller know we failed to subscribe. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - self._logger.debug("V1Channel connected to device via MQTT") + # Always subscribe to MQTT to receive protocol updates (data points) + # even if we have a local connection. Protocol updates only come via cloud/MQTT. + # Local connection is used for RPC commands, but push notifications come via MQTT. + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + if self.is_local_connected: + self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") + else: + self._logger.debug("V1Channel connected via MQTT only") def unsub() -> None: """Unsubscribe from all messages.""" diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 63ae2e20..7313e245 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -3,11 +3,15 @@ This is an internal library and should not be used directly by consumers. """ +from __future__ import annotations + import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, fields from typing import ClassVar, Self +from roborock.callbacks import CallbackList from roborock.data import RoborockBase from roborock.protocols.v1_protocol import V1RpcChannel from roborock.roborock_typing import RoborockCommand @@ -15,6 +19,7 @@ _LOGGER = logging.getLogger(__name__) V1ResponseData = dict | list | int | str +V1TraitUpdateCallback = Callable[["V1TraitMixin"], None] @dataclass @@ -74,6 +79,7 @@ def __post_init__(self) -> None: device setup code. """ self._rpc_channel = None + self._update_callbacks: CallbackList[V1TraitMixin] = CallbackList() @property def rpc_channel(self) -> V1RpcChannel: @@ -97,6 +103,21 @@ def _update_trait_values(self, new_data: RoborockBase) -> None: new_value = getattr(new_data, field.name, None) setattr(self, field.name, new_value) + def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]: + """Add a callback to be notified when the trait is updated. + + The callback will be called with the updated trait instance whenever + a protocol message updates the trait. + + Returns: + A callable that can be used to remove the callback. + """ + return self._update_callbacks.add_callback(callback) + + def notify_update(self) -> None: + """Notify all registered callbacks that the trait has been updated.""" + self._update_callbacks(self) + def _get_value_field(clazz: type[V1TraitMixin]) -> str: """Get the name of the field marked as the main value of the RoborockValueBase."""