Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,23 @@ def __repr__(self) -> str:
return f"DeviceCache({len(self._cache)} devices)"


class DeviceRef(str):
class DeviceRef:
name: str
model: DeviceModel
_cache: DeviceCache

def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel):
instance = super().__new__(cls, name)
instance.model = model
instance._cache = cache
return instance
def __init__(self, name: str, cache: DeviceCache, model: DeviceModel):
self.name = name
self.model = model
self._cache = cache

def __getattr__(self, name) -> "DeviceRef":
if name.startswith("_"):
raise AttributeError(f"No child device named {name}")
return self._cache[f"{self}.{name}"]
return self._cache[f"{self.name}.{name}"]

def __repr__(self):
return f"Device({self})"
return f"Device({self.name})"


class Plan:
Expand Down
10 changes: 9 additions & 1 deletion src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel, TypeAdapter, ValidationError

from blueapi import __version__
from blueapi.client import client
from blueapi.config import RestConfig
from blueapi.service.authentication import JWTAuth, SessionManager
from blueapi.service.model import (
Expand Down Expand Up @@ -241,7 +242,7 @@ def create_task(self, task: TaskRequest) -> TaskResponse:
TaskResponse,
method="POST",
get_exception=_create_task_exceptions,
data=task.model_dump(),
data=task.model_dump(fallback=_task_model_fallback),
)

def clear_task(self, task_id: str) -> TaskResponse:
Expand Down Expand Up @@ -363,3 +364,10 @@ def __getattr__(name: str):

class ServiceUnavailableError(Exception):
pass


def _task_model_fallback(obj: Any) -> Any:
"""Fallback method for serializing TaskRequests"""
if isinstance(obj, client.DeviceRef):
return obj.name
raise ValueError()
4 changes: 2 additions & 2 deletions tests/unit_tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def test_get_child_device(mock_rest: Mock, client: BlueapiClient):
else None
)
foo = client.devices.foo
assert foo == "foo"
assert foo.name == "foo"
x = client.devices.foo.x
assert x == "foo.x"
assert x.name == "foo.x"


def test_state_property(client: BlueapiClient):
Expand Down
Loading