diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 68637594f..034b9f6a9 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -120,23 +120,23 @@ def __repr__(self) -> str: return f"DeviceCache({len(self._cache)} devices)" -class DeviceRef(str): +class DeviceRef: + name: str model: DeviceModel _cache: DeviceCache - def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): - instance = super().__new__(cls, name) - instance.model = model - instance._cache = cache - return instance + def __init__(self, name: str, cache: DeviceCache, model: DeviceModel): + self.name = name + self.model = model + self._cache = cache def __getattr__(self, name) -> "DeviceRef": if name.startswith("_"): raise AttributeError(f"No child device named {name}") - return self._cache[f"{self}.{name}"] + return self._cache[f"{self.name}.{name}"] def __repr__(self): - return f"Device({self})" + return f"Device({self.name})" class Plan: diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index cacbc1656..739c62309 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, TypeAdapter, ValidationError from blueapi import __version__ +from blueapi.client import client from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( @@ -241,7 +242,7 @@ def create_task(self, task: TaskRequest) -> TaskResponse: TaskResponse, method="POST", get_exception=_create_task_exceptions, - data=task.model_dump(), + data=task.model_dump(fallback=_task_model_fallback), ) def clear_task(self, task_id: str) -> TaskResponse: @@ -363,3 +364,10 @@ def __getattr__(name: str): class ServiceUnavailableError(Exception): pass + + +def _task_model_fallback(obj: Any) -> Any: + """Fallback method for serializing TaskRequests""" + if isinstance(obj, client.DeviceRef): + return obj.name + raise ValueError() diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 169d727f6..125fc165c 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -202,9 +202,9 @@ def test_get_child_device(mock_rest: Mock, client: BlueapiClient): else None ) foo = client.devices.foo - assert foo == "foo" + assert foo.name == "foo" x = client.devices.foo.x - assert x == "foo.x" + assert x.name == "foo.x" def test_state_property(client: BlueapiClient):