From db9086c65b1bd87d5cbaa732c4bdb227da3194cc Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 12 Jun 2026 15:10:15 +0100 Subject: [PATCH] refactor: Stop DeviceRef extending str This was originally to make serialization behave correctly when the task was passed to requests. As we are using pydantic to dump the model to JSON before passing it, we can add the custom serialization there to get the device-to-string conversion that we need. --- src/blueapi/client/client.py | 16 ++++++++-------- src/blueapi/client/rest.py | 10 +++++++++- tests/unit_tests/client/test_client.py | 4 ++-- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 68637594f..034b9f6a9 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -120,23 +120,23 @@ def __repr__(self) -> str: return f"DeviceCache({len(self._cache)} devices)" -class DeviceRef(str): +class DeviceRef: + name: str model: DeviceModel _cache: DeviceCache - def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): - instance = super().__new__(cls, name) - instance.model = model - instance._cache = cache - return instance + def __init__(self, name: str, cache: DeviceCache, model: DeviceModel): + self.name = name + self.model = model + self._cache = cache def __getattr__(self, name) -> "DeviceRef": if name.startswith("_"): raise AttributeError(f"No child device named {name}") - return self._cache[f"{self}.{name}"] + return self._cache[f"{self.name}.{name}"] def __repr__(self): - return f"Device({self})" + return f"Device({self.name})" class Plan: diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index cacbc1656..739c62309 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, TypeAdapter, ValidationError from blueapi import __version__ +from blueapi.client import client from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager from blueapi.service.model import ( @@ -241,7 +242,7 @@ def create_task(self, task: TaskRequest) -> TaskResponse: TaskResponse, method="POST", get_exception=_create_task_exceptions, - data=task.model_dump(), + data=task.model_dump(fallback=_task_model_fallback), ) def clear_task(self, task_id: str) -> TaskResponse: @@ -363,3 +364,10 @@ def __getattr__(name: str): class ServiceUnavailableError(Exception): pass + + +def _task_model_fallback(obj: Any) -> Any: + """Fallback method for serializing TaskRequests""" + if isinstance(obj, client.DeviceRef): + return obj.name + raise ValueError() diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 169d727f6..125fc165c 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -202,9 +202,9 @@ def test_get_child_device(mock_rest: Mock, client: BlueapiClient): else None ) foo = client.devices.foo - assert foo == "foo" + assert foo.name == "foo" x = client.devices.foo.x - assert x == "foo.x" + assert x.name == "foo.x" def test_state_property(client: BlueapiClient):