diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index efb225a..72abb3d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ] + python-version: [ '3.10', '3.11', '3.12', '3.13', '3.14' ] steps: - name: Checkout @@ -30,20 +30,22 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Upgrade pip - run: python3 -m pip install --upgrade pip - - name: Install ypywidgets in dev mode - run: pip install -e ".[dev]" + run: pip install -e . --group dev + + - name: Lint + run: | + ruff format --check src tests + ruff check src tests - name: Check types - run: mypy src + run: mypy src tests - name: Run tests run: pytest ./tests -v --color=yes - name: Run code coverage - if: ${{ (matrix.python-version == '3.12') && (matrix.os == 'ubuntu-latest') }} + if: ${{ (matrix.python-version == '3.14') && (matrix.os == 'ubuntu-latest') }} run: | coverage run -m pytest tests coverage report --fail-under=100 diff --git a/pyproject.toml b/pyproject.toml index 428c59e..d773051 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "0.9.7" description = "Y-based Jupyter widgets for Python" readme = "README.md" license = "MIT" -requires-python = ">=3.8" +requires-python = ">=3.10" authors = [ { name = "David Brochart", email = "david.brochart@gmail.com" }, ] @@ -24,14 +24,16 @@ dependencies = [ ] [project.urls] -Homepage = "https://github.com/davidbrochart/ypywidgets" +Homepage = "https://github.com/QuantStack/ypywidgets" -[project.optional-dependencies] +[dependency-groups] dev = [ + "anyio", "coverage >=7.0.0,<8.0.0", "mypy", "pytest", - "pytest-asyncio", + "ruff", + "trio", ] [tool.hatch.build.targets.wheel] diff --git a/src/ypywidgets/comm.py b/src/ypywidgets/comm.py index 89a5a89..10f004c 100644 --- a/src/ypywidgets/comm.py +++ b/src/ypywidgets/comm.py @@ -54,12 +54,13 @@ def __init__( def _receive(self, msg): message = bytes(msg["buffers"][0]) - if message[0] == YMessageType.SYNC: - reply = handle_sync_message(message[1:], self._ydoc) - if reply is not None: - self._comm.send(buffers=[reply]) - if message[1] == YSyncMessageType.SYNC_STEP2: - self._ydoc.observe(self._send) + match message[0]: + case YMessageType.SYNC: + reply = handle_sync_message(message[1:], self._ydoc) + if reply is not None: + self._comm.send(buffers=[reply]) + if message[1] == YSyncMessageType.SYNC_STEP2: + self._ydoc.observe(self._send) def _send(self, event: TransactionEvent): update = event.update @@ -69,12 +70,12 @@ def _send(self, event: TransactionEvent): class CommWidget(Widget): def __init__( - self, - ydoc: Doc | None = None, - comm_data: dict | None = None, - comm_metadata: dict | None = None, - comm_id: str | None = None, - ): + self, + ydoc: Doc | None = None, + comm_data: dict | None = None, + comm_metadata: dict | None = None, + comm_id: str | None = None, + ): super().__init__(ydoc) model_name = self.__class__.__name__ _model_name = self.ydoc["_model_name"] = Text() @@ -90,13 +91,13 @@ def __init__( def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover plaintext = repr(self) if len(plaintext) > 110: - plaintext = plaintext[:110] + '…' + plaintext = plaintext[:110] + "…" data = { "text/plain": plaintext, "application/vnd.jupyter.ywidget-view+json": { "version_major": 2, "version_minor": 0, "model_id": self._comm.comm_id, - } + }, } return data diff --git a/src/ypywidgets/reactive.py b/src/ypywidgets/reactive.py index e337321..47a12e4 100644 --- a/src/ypywidgets/reactive.py +++ b/src/ypywidgets/reactive.py @@ -10,7 +10,6 @@ class Reactive(_Reactive, Generic[ValueType]): - def __init__( self, default: ValueType, diff --git a/tests/conftest.py b/tests/conftest.py index b7a0ea5..ede26f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,13 @@ -import asyncio -import time -from typing import Optional +import math +from contextlib import AsyncExitStack +from functools import partial +from typing import Any, cast import comm import pytest +from anyio import Event, create_memory_object_stream, create_task_group, fail_after +from anyio.abc import TaskGroup +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pycrdt import ( YMessageType, YSyncMessageType, @@ -15,86 +19,163 @@ from ypywidgets import Widget from ypywidgets.comm import CommWidget +pytestmark = pytest.mark.anyio -class MockComm(comm.base_comm.BaseComm): +class MockComm(comm.base_comm.BaseComm): def __init__( - self, - comm_id=None, - target_name=None, - data=None, - metadata=None, + self, + task_group: TaskGroup, + send_send_stream: MemoryObjectSendStream, + send_recv_stream: MemoryObjectReceiveStream, + recv_send_stream: MemoryObjectSendStream, + recv_recv_stream: MemoryObjectReceiveStream, + comm_id: str = "", + target_name: str = "", + data=None, + metadata=None, + ) -> None: + self.send_send_stream = send_send_stream + self.send_recv_stream = send_recv_stream + self.recv_send_stream = recv_send_stream + self.recv_recv_stream = recv_recv_stream + super().__init__( + comm_id=comm_id, target_name=target_name, data=data, metadata=metadata + ) + task_group.start_soon(self.receive) + + def publish_msg( + self, msg_type, data, metadata, buffers, target_name=None, target_module=None ): - self.send_queue = asyncio.Queue() - self.recv_queue = asyncio.Queue() - super().__init__(comm_id=comm_id, target_name=target_name, data=data, metadata=metadata) - self.receive_task = asyncio.create_task(self.receive()) - - def publish_msg(self, msg_type, data, metadata, buffers, target_name=None, target_module=None): - self.send_queue.put_nowait((msg_type, data, metadata, buffers, target_name, target_module)) + self.send_send_stream.send_nowait( + (msg_type, data, metadata, buffers, target_name, target_module) + ) def handle_msg(self, msg): self._msg_callback(msg) - async def receive(self): + async def receive(self) -> None: while True: - msg = await self.recv_queue.get() + msg = await self.recv_recv_stream.receive() self.handle_msg(msg) -comm.create_comm = MockComm - +class Context: + def __init__(self): + self.tasks = [] + + def add_task(self, task): + self.tasks.append(task) + + async def __aenter__(self) -> "Context": + send_send_stream, send_recv_stream = create_memory_object_stream( + max_buffer_size=math.inf + ) + recv_send_stream, recv_recv_stream = create_memory_object_stream( + max_buffer_size=math.inf + ) + async with AsyncExitStack() as stack: + await stack.enter_async_context(send_send_stream) + await stack.enter_async_context(recv_send_stream) + await stack.enter_async_context(send_recv_stream) + await stack.enter_async_context(recv_recv_stream) + self.task_group = await stack.enter_async_context(create_task_group()) + comm.create_comm = partial( + MockComm, + self.task_group, + send_send_stream, + send_recv_stream, + recv_send_stream, + recv_recv_stream, + ) + for task in self.tasks: + self.task_group.start_soon(task) + self.stack = stack.pop_all() + return self + + async def __aexit__(self, *exc) -> bool | None: + self.task_group.cancel_scope.cancel() + comm.create_comm = _create_comm + return await self.stack.__aexit__(*exc) + + +def _create_comm(*args: Any, **kwargs: Any) -> comm.BaseComm: + return comm.DummyComm(*args, **kwargs) # pragma: nocover + + +class SyncedWidgets: + def __init__( + self, widget_factories: tuple[type[CommWidget], type[Widget]], context: Context + ) -> None: + self.local_widget_factory, self.remote_widget_factory = widget_factories + self.local_widget: CommWidget | None = None + self.remote_widget: Widget | None = None + self.local_widget_created = Event() + self.remote_widget_created = Event() + context.add_task(self.receive) + + def send(self, event: TransactionEvent) -> None: + update = event.update + message = create_update_message(update) + self.comm.recv_send_stream.send_nowait({"buffers": [message]}) -@pytest.fixture -def widget_factories(): - return CommWidget, Widget + async def receive(self) -> None: + self.local_widget = self.local_widget_factory() + self.local_widget_created.set() + self.comm = cast(MockComm, self.local_widget._comm) + while True: + ( + msg_type, + data, + metadata, + buffers, + target_name, + target_module, + ) = await self.comm.send_recv_stream.receive() + match msg_type: + case "comm_open": + self.remote_widget = self.remote_widget_factory() + msg = create_sync_message(self.remote_widget.ydoc) + self.comm.handle_msg({"buffers": [msg]}) + case "comm_msg": + assert self.remote_widget is not None + message = buffers[0] + match message[0]: + case YMessageType.SYNC: + reply = handle_sync_message( + message[1:], self.remote_widget.ydoc + ) + if reply is not None: + self.comm.handle_msg({"buffers": [reply]}) + if message[1] == YSyncMessageType.SYNC_STEP2: + self.sub = self.remote_widget.ydoc.observe(self.send) + self.remote_widget_created.set() + + async def get_local_widget(self, timeout: float = 0.2) -> CommWidget: + with fail_after(timeout): + await self.local_widget_created.wait() + assert self.local_widget is not None + return self.local_widget + + async def get_remote_widget(self, timeout: float = 0.2) -> Widget: + with fail_after(timeout): + await self.remote_widget_created.wait() + assert self.remote_widget is not None + return self.remote_widget @pytest.fixture -async def synced_widgets(widget_factories): - local_widget = widget_factories[0]() - remote_widget_manager = RemoteWidgetManager(widget_factories[1], local_widget._comm) - remote_widget = await remote_widget_manager.get_widget() - return local_widget, remote_widget +def context() -> Context: + return Context() -class RemoteWidgetManager: - - comm: Optional[MockComm] - widget: Optional[Widget] +@pytest.fixture +def widget_factories() -> tuple[type[CommWidget], type[Widget]]: + return CommWidget, Widget - def __init__(self, widget_factory, comm): - self.widget_factory = widget_factory - self.comm = comm - self.widget = None - self.receive_task = asyncio.create_task(self.receive()) - def send(self, event: TransactionEvent): - update = event.update - message = create_update_message(update) - self.comm.recv_queue.put_nowait({"buffers": [message]}) - - async def receive(self): - while True: - msg_type, data, metadata, buffers, target_name, target_module = await self.comm.send_queue.get() - if msg_type == "comm_open": - self.widget = self.widget_factory() - msg = create_sync_message(self.widget.ydoc) - self.comm.handle_msg({"buffers": [msg]}) - elif msg_type == "comm_msg": - message = buffers[0] - if message[0] == YMessageType.SYNC: - reply = handle_sync_message(message[1:], self.widget.ydoc) - if reply is not None: - self.comm.handle_msg({"buffers": [reply]}) - if message[1] == YSyncMessageType.SYNC_STEP2: - self.widget.ydoc.observe(self.send) - - async def get_widget(self, timeout=0.1): - t = time.monotonic() - while True: - if self.widget: - return self.widget - await asyncio.sleep(0) - if time.monotonic() - t > timeout: # pragma: nocover - raise TimeoutError("Timeout waiting for widget") +@pytest.fixture +def synced_widgets( + widget_factories: tuple[type[CommWidget], type[Widget]], context: Context +) -> SyncedWidgets: + return SyncedWidgets(widget_factories, context) diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 4bcfe85..7e762d5 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -1,71 +1,78 @@ -import asyncio -from typing import Optional +from __future__ import annotations import pytest +from anyio import sleep from pycrdt import Text -from ypywidgets import Reactive +from ypywidgets import Reactive, Widget from ypywidgets.comm import CommWidget +pytestmark = pytest.mark.anyio + class Widget1(CommWidget): - foo = Reactive[str]("foo1") - bar = Reactive[str]("bar1") - baz = Reactive[Optional[str]](None) + foo = Reactive[str | None]("foo1") + bar = Reactive[str | None]("bar1") + baz = Reactive[str | None](None) -class Widget2(CommWidget): - foo = Reactive[str]("") +class Widget2(Widget): + foo = Reactive[str | None](None) + bar = Reactive[str | None](None) + baz = Reactive[str | None](None) @foo.watch def _watch_foo(self, old, new): print(f"foo changed: '{old}'->'{new}'") -@pytest.mark.asyncio -async def test_create_ydoc(synced_widgets): - local_widget, remote_widget = await synced_widgets +async def test_create_ydoc(synced_widgets, context): + async with context: + local_widget = await synced_widgets.get_local_widget() + remote_widget = await synced_widgets.get_remote_widget() - local_text = Text() - local_widget.ydoc["text"] = local_text - text = "hello world!" - local_text += text + local_text = Text() + local_widget.ydoc["text"] = local_text + text = "hello world!" + local_text += text - remote_text = Text() - remote_widget.ydoc["text"] = remote_text - await asyncio.sleep(0.01) - assert str(remote_text) == text + remote_text = Text() + remote_widget.ydoc["text"] = remote_text + await sleep(0.01) + assert str(remote_text) == text -@pytest.mark.asyncio -@pytest.mark.parametrize("widget_factories", ((Widget1, Widget1),)) -async def test_sync_attribute(widget_factories, synced_widgets): - local_widget, remote_widget = await synced_widgets +@pytest.mark.parametrize("widget_factories", ((Widget1, Widget2),)) +async def test_sync_attribute(widget_factories, synced_widgets, context): + async with context: + local_widget = await synced_widgets.get_local_widget() + remote_widget = await synced_widgets.get_remote_widget() - with pytest.raises(AttributeError): - assert local_widget.wrong_attr1 + with pytest.raises(AttributeError): + assert local_widget.wrong_attr1 - with pytest.raises(AttributeError): - assert remote_widget.wrong_attr2 + with pytest.raises(AttributeError): + assert remote_widget.wrong_attr2 - local_widget.foo = "foo2" - assert remote_widget.foo == "foo1" # not synced yet - await asyncio.sleep(0.01) # wait for sync - assert remote_widget.foo == "foo2" + local_widget.foo = "foo2" + assert remote_widget.foo is None # not synced yet + await sleep(0.01) # wait for sync + assert remote_widget.foo == "foo2" - remote_widget.baz = "baz2" - assert local_widget.baz is None # not synced yet - await asyncio.sleep(0.01) # wait for sync - assert local_widget.baz == "baz2" + remote_widget.baz = "baz2" + assert local_widget.baz is None # not synced yet + await sleep(0.01) # wait for sync + assert local_widget.baz == "baz2" -@pytest.mark.asyncio @pytest.mark.parametrize("widget_factories", ((Widget1, Widget2),)) -async def test_watch_attribute(widget_factories, synced_widgets, capfd): - local_widget, remote_widget = await synced_widgets +async def test_watch_attribute(widget_factories, synced_widgets, capfd, context): + async with context: + local_widget = await synced_widgets.get_local_widget() + await synced_widgets.get_remote_widget() - local_widget.foo = "foo" + local_widget.foo = "foo" - # we're seeing the remote widget watch callback - await asyncio.sleep(0.01) - out, err = capfd.readouterr() - assert out == "foo changed: ''->'foo'\n" + # we're seeing the remote widget watch callback + await sleep(0.01) + out, err = capfd.readouterr() + assert out == "foo changed: 'None'->'foo'\n"