Skip to content

Commit db8d490

Browse files
authored
refactor: Move numtracker into the BlueskyContext (#1200)
Moving the numtracker into the bluesky context means we don't need to do the 'create context -> create numtracker -> configure context -> check path provider` shenanigans in the top level setup function and the path provider can be available to pass to device factories if we move to a dodal-less approach to devices.
1 parent 0447ea1 commit db8d490

File tree

3 files changed

+86
-105
lines changed

3 files changed

+86
-105
lines changed

src/blueapi/core/context.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import logging
22
from collections.abc import Callable
3-
from dataclasses import dataclass, field
3+
from dataclasses import InitVar, dataclass, field
44
from importlib import import_module
55
from inspect import Parameter, signature
66
from types import ModuleType, NoneType, UnionType
77
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
88

99
from bluesky.protocols import HasName
1010
from bluesky.run_engine import RunEngine
11+
from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider
1112
from dodal.utils import make_all_devices
1213
from ophyd_async.core import NotConnected
1314
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
@@ -16,12 +17,15 @@
1617
from pydantic_core import CoreSchema, core_schema
1718

1819
from blueapi import utils
19-
from blueapi.config import EnvironmentConfig, SourceKind
20+
from blueapi.client.numtracker import NumtrackerClient
21+
from blueapi.config import ApplicationConfig, EnvironmentConfig, SourceKind
2022
from blueapi.utils import (
2123
BlueapiPlanModelConfig,
2224
is_function_sourced_from_module,
2325
load_module_all,
2426
)
27+
from blueapi.utils.invalid_config_error import InvalidConfigError
28+
from blueapi.utils.path_provider import StartDocumentPathProvider
2529

2630
from .bluesky_types import (
2731
BLUESKY_PROTOCOLS,
@@ -86,15 +90,57 @@ class BlueskyContext:
8690
The context holds the RunEngine and any plans/devices that you may want to use.
8791
"""
8892

93+
configuration: InitVar[ApplicationConfig | None] = None
94+
8995
run_engine: RunEngine = field(
9096
default_factory=lambda: RunEngine(context_managers=[])
9197
)
98+
numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False)
9299
plans: dict[str, Plan] = field(default_factory=dict)
93100
devices: dict[str, Device] = field(default_factory=dict)
94101
plan_functions: dict[str, PlanGenerator] = field(default_factory=dict)
95102

96103
_reference_cache: dict[type, type] = field(default_factory=dict)
97104

105+
def __post_init__(self, configuration: ApplicationConfig | None):
106+
if not configuration:
107+
return
108+
109+
if configuration.numtracker is not None:
110+
if configuration.env.metadata is not None:
111+
self.numtracker = NumtrackerClient(url=configuration.numtracker.url)
112+
else:
113+
raise InvalidConfigError(
114+
"Numtracker url has been configured, but there is no instrument or"
115+
" instrument_session in the environment metadata"
116+
)
117+
118+
if self.numtracker is not None:
119+
numtracker = self.numtracker
120+
121+
path_provider = StartDocumentPathProvider()
122+
set_path_provider(path_provider)
123+
self.run_engine.subscribe(path_provider.update_run, "start")
124+
125+
def _update_scan_num(md: dict[str, Any]) -> int:
126+
scan = numtracker.create_scan(
127+
md["instrument_session"], md["instrument"]
128+
)
129+
md["data_session_directory"] = str(scan.scan.directory.path)
130+
md["scan_file"] = scan.scan.scan_file
131+
return scan.scan.scan_number
132+
133+
self.run_engine.scan_id_source = _update_scan_num
134+
135+
self.with_config(configuration.env)
136+
if self.numtracker and not isinstance(
137+
get_path_provider(), StartDocumentPathProvider
138+
):
139+
raise InvalidConfigError(
140+
"Numtracker has been configured but a path provider was imported with "
141+
"the devices. Remove this path provider to use numtracker."
142+
)
143+
98144
def find_device(self, addr: str | list[str]) -> Device | None:
99145
"""
100146
Find a device in this context, allows for recursive search.

src/blueapi/service/interface.py

Lines changed: 3 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,8 @@
44

55
from bluesky_stomp.messaging import StompClient
66
from bluesky_stomp.models import Broker, DestinationBase, MessageTopic
7-
from dodal.common.beamlines.beamline_utils import (
8-
get_path_provider,
9-
set_path_provider,
10-
)
117

128
from blueapi.cli.scratch import get_python_environment
13-
from blueapi.client.numtracker import NumtrackerClient
149
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig
1510
from blueapi.core.context import BlueskyContext
1611
from blueapi.core.event import EventStream
@@ -23,8 +18,6 @@
2318
TaskRequest,
2419
WorkerTask,
2520
)
26-
from blueapi.utils.invalid_config_error import InvalidConfigError
27-
from blueapi.utils.path_provider import StartDocumentPathProvider
2821
from blueapi.worker.event import TaskStatusEnum, WorkerState
2922
from blueapi.worker.task import Task
3023
from blueapi.worker.task_worker import TaskWorker, TrackableTask
@@ -48,14 +41,10 @@ def set_config(new_config: ApplicationConfig):
4841

4942
@cache
5043
def context() -> BlueskyContext:
51-
ctx = BlueskyContext()
44+
ctx = BlueskyContext(config())
5245
return ctx
5346

5447

55-
def configure_context() -> None:
56-
context().with_config(config().env)
57-
58-
5948
@cache
6049
def worker() -> TaskWorker:
6150
worker = TaskWorker(
@@ -96,76 +85,23 @@ def stomp_client() -> StompClient | None:
9685
return None
9786

9887

99-
@cache
100-
def numtracker_client() -> NumtrackerClient | None:
101-
conf = config()
102-
if conf.numtracker is not None:
103-
if conf.env.metadata is not None:
104-
return NumtrackerClient(url=conf.numtracker.url)
105-
else:
106-
raise InvalidConfigError(
107-
"Numtracker url has been configured, but there is no instrument or"
108-
" instrument_session in the environment metadata"
109-
)
110-
else:
111-
return None
112-
113-
114-
def _update_scan_num(md: dict[str, Any]) -> int:
115-
numtracker = numtracker_client()
116-
if numtracker is not None:
117-
scan = numtracker.create_scan(md["instrument_session"], md["instrument"])
118-
md["data_session_directory"] = str(scan.scan.directory.path)
119-
md["scan_file"] = scan.scan.scan_file
120-
return scan.scan.scan_number
121-
else:
122-
raise InvalidConfigError(
123-
"Blueapi was configured to talk to numtracker but numtracker is not"
124-
"configured, this should not happen, please contact the DAQ team"
125-
)
126-
127-
12888
def setup(config: ApplicationConfig) -> None:
12989
"""Creates and starts a worker with supplied config"""
13090
set_config(config)
13191
set_up_logging(config.logging)
13292

13393
# Eagerly initialize worker and messaging connection
13494
worker()
135-
136-
# if numtracker is configured, use a StartDocumentPathProvider
137-
if numtracker_client() is not None:
138-
context().run_engine.scan_id_source = _update_scan_num
139-
_hook_run_engine_and_path_provider()
140-
141-
configure_context()
142-
143-
if numtracker_client() is not None and not isinstance(
144-
get_path_provider(), StartDocumentPathProvider
145-
):
146-
raise InvalidConfigError(
147-
"Numtracker has been configured but a path provider was imported"
148-
" with the devices. Remove this path provider to use numtracker."
149-
)
150-
15195
stomp_client()
15296

15397

154-
def _hook_run_engine_and_path_provider() -> None:
155-
path_provider = StartDocumentPathProvider()
156-
set_path_provider(path_provider)
157-
run_engine = context().run_engine
158-
run_engine.subscribe(path_provider.update_run, "start")
159-
160-
16198
def teardown() -> None:
16299
worker().stop()
163100
if (stomp_client_ref := stomp_client()) is not None:
164101
stomp_client_ref.disconnect()
165102
context.cache_clear()
166103
worker.cache_clear()
167104
stomp_client.cache_clear()
168-
numtracker_client.cache_clear()
169105

170106

171107
def _publish_event_streams(
@@ -224,19 +160,13 @@ def begin_task(
224160
task: WorkerTask, pass_through_headers: Mapping[str, str] | None = None
225161
) -> WorkerTask:
226162
"""Trigger a task. Will fail if the worker is busy"""
227-
_try_configure_numtracker(pass_through_headers or {})
228-
163+
if nt := context().numtracker:
164+
nt.set_headers(pass_through_headers or {})
229165
if task.task_id is not None:
230166
worker().begin_task(task.task_id)
231167
return task
232168

233169

234-
def _try_configure_numtracker(pass_through_headers: Mapping[str, str]) -> None:
235-
numtracker = numtracker_client()
236-
if numtracker is not None:
237-
numtracker.set_headers(pass_through_headers)
238-
239-
240170
def get_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]:
241171
"""Retrieve a list of tasks based on their status."""
242172
return worker().get_tasks_by_status(status)

tests/unit_tests/service/test_interface.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -256,15 +256,15 @@ def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]:
256256
assert interface.get_tasks_by_status(TaskStatusEnum.COMPLETE) == []
257257

258258

259-
@patch("blueapi.service.interface._try_configure_numtracker")
259+
@patch("blueapi.service.interface.BlueskyContext.numtracker")
260260
@patch("blueapi.service.interface.TaskWorker.begin_task")
261-
def test_begin_task_with_headers(worker_mock: MagicMock, mock_configure: MagicMock):
261+
def test_begin_task_with_headers(worker_mock: MagicMock, mock_numtracker: MagicMock):
262262
uuid_value = "350043fd-597e-41a7-9a92-5d5478232cf7"
263263
task = WorkerTask(task_id=uuid_value)
264264
headers = {"a": "b"}
265265

266266
returned_task = interface.begin_task(task, headers)
267-
mock_configure.assert_called_once_with(headers)
267+
mock_numtracker.set_headers.assert_called_once_with(headers)
268268

269269
assert task == returned_task
270270
worker_mock.assert_called_once_with(uuid_value)
@@ -406,10 +406,10 @@ def test_configure_numtracker():
406406
)
407407
interface.set_config(conf)
408408
headers = {"a": "b"}
409-
interface._try_configure_numtracker(headers)
410-
nt = interface.numtracker_client()
409+
nt = interface.context().numtracker
411410

412411
assert isinstance(nt, NumtrackerClient)
412+
nt.set_headers(headers)
413413
assert nt._headers == {"a": "b"}
414414
assert nt._url.unicode_string() == "https://numtracker-example.com/graphql"
415415

@@ -443,37 +443,36 @@ def test_headers_are_cleared(mock_post):
443443
headers = {"foo": "bar"}
444444

445445
interface.begin_task(task=WorkerTask(task_id=None), pass_through_headers=headers)
446-
interface._update_scan_num({"instrument_session": "cm12345-1", "instrument": "p46"})
446+
ctx = interface.context()
447+
assert ctx.run_engine.scan_id_source is not None
448+
ctx.run_engine.scan_id_source(
449+
{"instrument_session": "cm12345-1", "instrument": "p46"}
450+
)
447451
mock_post.assert_called_once()
448452
assert mock_post.call_args.kwargs["headers"] == headers
449453

450454
interface.begin_task(task=WorkerTask(task_id=None))
451-
interface._update_scan_num({"instrument_session": "cm12345-1", "instrument": "p46"})
455+
ctx.run_engine.scan_id_source(
456+
{"instrument_session": "cm12345-1", "instrument": "p46"}
457+
)
452458
assert mock_post.call_count == 2
453459
assert mock_post.call_args.kwargs["headers"] == {}
454460

455461

456-
def test_configure_numtracker_with_no_numtracker_config_fails():
462+
def test_numtracker_requires_instrument_metadata():
457463
conf = ApplicationConfig(
458-
env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")),
464+
numtracker=NumtrackerConfig(
465+
url=HttpUrl("https://numtracker-example.com/graphql"),
466+
)
459467
)
460468
interface.set_config(conf)
461-
headers = {"a": "b"}
462-
interface._try_configure_numtracker(headers)
463-
nt = interface.numtracker_client()
464-
465-
assert nt is None
466-
467-
468-
def test_configure_numtracker_with_no_metadata_fails():
469-
conf = ApplicationConfig(numtracker=NumtrackerConfig())
470-
interface.set_config(conf)
471-
headers = {"a": "b"}
472-
473-
assert conf.env.metadata is None
474-
469+
print("Post config")
475470
with pytest.raises(InvalidConfigError):
476-
interface._try_configure_numtracker(headers)
471+
interface.context()
472+
473+
# Clearing the config here prevents the same exception as above being
474+
# raised in the ensure_worker_stopped fixture
475+
interface.set_config(ApplicationConfig())
477476

478477

479478
def test_setup_without_numtracker_with_existing_provider_does_not_overwrite_provider():
@@ -506,7 +505,6 @@ def test_setup_with_numtracker_makes_start_document_provider():
506505
path_provider = get_path_provider()
507506

508507
assert isinstance(path_provider, StartDocumentPathProvider)
509-
assert interface.context().run_engine.scan_id_source == interface._update_scan_num
510508

511509
clear_path_provider()
512510

@@ -545,12 +543,15 @@ def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_
545543
)
546544
interface.set_config(conf)
547545
ctx = interface.context()
548-
interface.configure_context()
549546

550547
headers = {"a": "b"}
551-
interface._try_configure_numtracker(headers)
548+
549+
assert ctx.numtracker is not None
550+
assert ctx.run_engine.scan_id_source is not None
551+
552+
ctx.numtracker.set_headers(headers)
552553
ctx.run_engine.md["instrument_session"] = "ab123"
553-
interface._update_scan_num(ctx.run_engine.md)
554+
ctx.run_engine.scan_id_source(ctx.run_engine.md)
554555

555556
mock_create_scan.assert_called_once_with("ab123", "p46")
556557

@@ -567,8 +568,10 @@ def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md(
567568
interface.setup(conf)
568569
ctx = interface.context()
569570

571+
assert ctx.run_engine.scan_id_source is not None
572+
570573
ctx.run_engine.md["instrument_session"] = "ab123"
571-
interface._update_scan_num(ctx.run_engine.md)
574+
ctx.run_engine.scan_id_source(ctx.run_engine.md)
572575

573576
assert (
574577
ctx.run_engine.md["data_session_directory"] == "/exports/mybeamline/data/2025"
@@ -587,7 +590,9 @@ def test_update_scan_num_side_effect_sets_scan_file_in_re_md(
587590
interface.setup(conf)
588591
ctx = interface.context()
589592

593+
assert ctx.run_engine.scan_id_source is not None
594+
590595
ctx.run_engine.md["instrument_session"] = "ab123"
591-
interface._update_scan_num(ctx.run_engine.md)
596+
ctx.run_engine.scan_id_source(ctx.run_engine.md)
592597

593598
assert ctx.run_engine.md["scan_file"] == "p46-11"

0 commit comments

Comments
 (0)