diff --git a/src/ypywidgets/comm.py b/src/ypywidgets/comm.py index 10f004c..4982d2a 100644 --- a/src/ypywidgets/comm.py +++ b/src/ypywidgets/comm.py @@ -2,6 +2,7 @@ import comm from pycrdt import ( + Awareness, Doc, Text, TransactionEvent, @@ -10,6 +11,7 @@ create_sync_message, create_update_message, handle_sync_message, + read_message, ) from .widget import Widget @@ -48,10 +50,15 @@ def __init__( ) -> None: self._ydoc = ydoc self._comm = comm + self._awareness = Awareness(ydoc) msg = create_sync_message(ydoc) self._comm.send(buffers=[msg]) self._comm.on_msg(self._receive) + @property + def awareness(self) -> Awareness: + return self._awareness + def _receive(self, msg): message = bytes(msg["buffers"][0]) match message[0]: @@ -61,6 +68,9 @@ def _receive(self, msg): self._comm.send(buffers=[reply]) if message[1] == YSyncMessageType.SYNC_STEP2: self._ydoc.observe(self._send) + case YMessageType.AWARENESS: + payload = read_message(message[1:]) + self._awareness.apply_awareness_update(payload, None) def _send(self, event: TransactionEvent): update = event.update @@ -86,7 +96,11 @@ def __init__( create_ydoc=not ydoc, ) self._comm = create_widget_comm(comm_data, comm_metadata, comm_id) - CommProvider(self.ydoc, self._comm) + self._comm_provider = CommProvider(self.ydoc, self._comm) + + @property + def awareness(self) -> Awareness: + return self._comm_provider.awareness def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover plaintext = repr(self) diff --git a/tests/conftest.py b/tests/conftest.py index ede26f5..b1e117c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,12 +9,14 @@ from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pycrdt import ( + Awareness, YMessageType, YSyncMessageType, TransactionEvent, create_sync_message, create_update_message, handle_sync_message, + read_message, ) from ypywidgets import Widget from ypywidgets.comm import CommWidget @@ -112,6 +114,7 @@ def __init__( self.remote_widget: Widget | None = None self.local_widget_created = Event() self.remote_widget_created = Event() + self._remote_awareness: Awareness | None = None context.add_task(self.receive) def send(self, event: TransactionEvent) -> None: @@ -150,6 +153,13 @@ async def receive(self) -> None: if message[1] == YSyncMessageType.SYNC_STEP2: self.sub = self.remote_widget.ydoc.observe(self.send) self.remote_widget_created.set() + case YMessageType.AWARENESS: # pragma: nocover + if self._remote_awareness is None: + self._remote_awareness = Awareness( + self.remote_widget.ydoc # pragma: no + ) + payload = read_message(bytes(message[1:])) + self._remote_awareness.apply_awareness_update(payload, None) async def get_local_widget(self, timeout: float = 0.2) -> CommWidget: with fail_after(timeout): diff --git a/tests/test_comm_awareness.py b/tests/test_comm_awareness.py new file mode 100644 index 0000000..4211eac --- /dev/null +++ b/tests/test_comm_awareness.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import pytest +from pycrdt import Awareness, Doc, YMessageType, create_awareness_message + +pytestmark = pytest.mark.anyio + + +async def test_comm_provider_applies_awareness_frame(synced_widgets, context): + async with context: + local_widget = await synced_widgets.get_local_widget() + remote_awareness = Awareness(Doc()) + remote_awareness.set_local_state({"role": "remote"}) + payload = remote_awareness.encode_awareness_update([remote_awareness.client_id]) + frame = create_awareness_message(payload) + + assert frame[0] == YMessageType.AWARENESS + + local_widget._comm_provider._receive({"buffers": [frame]}) + + remote_state = local_widget.awareness.states.get(remote_awareness.client_id) + assert remote_state is not None + assert remote_state.get("role") == "remote" + + +async def test_comm_widget_exposes_provider_awareness(synced_widgets, context): + async with context: + widget = await synced_widgets.get_local_widget() + assert widget.awareness is widget._comm_provider.awareness + + +async def test_comm_widget_awareness_observe_and_unobserve(synced_widgets, context): + async with context: + widget = await synced_widgets.get_local_widget() + + events: list[str] = [] + sub_id = widget.awareness.observe(lambda topic, _: events.append(topic)) + + widget.awareness.set_local_state({"ping": 1}) + assert events + + widget.awareness.unobserve(sub_id) + events.clear() + widget.awareness.set_local_state({"ping": 2}) + assert events == []