Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ]
python-version: [ '3.10', '3.11', '3.12', '3.13', '3.14' ]

steps:
- name: Checkout
Expand All @@ -30,20 +30,22 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Upgrade pip
run: python3 -m pip install --upgrade pip

- name: Install ypywidgets in dev mode
run: pip install -e ".[dev]"
run: pip install -e . --group dev

- name: Lint
run: |
ruff format --check src tests
ruff check src tests

- name: Check types
run: mypy src
run: mypy src tests

- name: Run tests
run: pytest ./tests -v --color=yes

- name: Run code coverage
if: ${{ (matrix.python-version == '3.12') && (matrix.os == 'ubuntu-latest') }}
if: ${{ (matrix.python-version == '3.14') && (matrix.os == 'ubuntu-latest') }}
run: |
coverage run -m pytest tests
coverage report --fail-under=100
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = "0.9.7"
description = "Y-based Jupyter widgets for Python"
readme = "README.md"
license = "MIT"
requires-python = ">=3.8"
requires-python = ">=3.10"
authors = [
{ name = "David Brochart", email = "david.brochart@gmail.com" },
]
Expand All @@ -24,14 +24,16 @@ dependencies = [
]

[project.urls]
Homepage = "https://github.com/davidbrochart/ypywidgets"
Homepage = "https://github.com/QuantStack/ypywidgets"

[project.optional-dependencies]
[dependency-groups]
dev = [
"anyio",
"coverage >=7.0.0,<8.0.0",
"mypy",
"pytest",
"pytest-asyncio",
"ruff",
"trio",
]

[tool.hatch.build.targets.wheel]
Expand Down
29 changes: 15 additions & 14 deletions src/ypywidgets/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ def __init__(

def _receive(self, msg):
message = bytes(msg["buffers"][0])
if message[0] == YMessageType.SYNC:
reply = handle_sync_message(message[1:], self._ydoc)
if reply is not None:
self._comm.send(buffers=[reply])
if message[1] == YSyncMessageType.SYNC_STEP2:
self._ydoc.observe(self._send)
match message[0]:
case YMessageType.SYNC:
reply = handle_sync_message(message[1:], self._ydoc)
if reply is not None:
self._comm.send(buffers=[reply])
if message[1] == YSyncMessageType.SYNC_STEP2:
self._ydoc.observe(self._send)

def _send(self, event: TransactionEvent):
update = event.update
Expand All @@ -69,12 +70,12 @@ def _send(self, event: TransactionEvent):

class CommWidget(Widget):
def __init__(
self,
ydoc: Doc | None = None,
comm_data: dict | None = None,
comm_metadata: dict | None = None,
comm_id: str | None = None,
):
self,
ydoc: Doc | None = None,
comm_data: dict | None = None,
comm_metadata: dict | None = None,
comm_id: str | None = None,
):
super().__init__(ydoc)
model_name = self.__class__.__name__
_model_name = self.ydoc["_model_name"] = Text()
Expand All @@ -90,13 +91,13 @@ def __init__(
def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover
plaintext = repr(self)
if len(plaintext) > 110:
plaintext = plaintext[:110] + '…'
plaintext = plaintext[:110] + "…"
data = {
"text/plain": plaintext,
"application/vnd.jupyter.ywidget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": self._comm.comm_id,
}
},
}
return data
1 change: 0 additions & 1 deletion src/ypywidgets/reactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@


class Reactive(_Reactive, Generic[ValueType]):

def __init__(
self,
default: ValueType,
Expand Down
213 changes: 147 additions & 66 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import asyncio
import time
from typing import Optional
import math
from contextlib import AsyncExitStack
from functools import partial
from typing import Any, cast

import comm
import pytest
from anyio import Event, create_memory_object_stream, create_task_group, fail_after
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pycrdt import (
YMessageType,
YSyncMessageType,
Expand All @@ -15,86 +19,163 @@
from ypywidgets import Widget
from ypywidgets.comm import CommWidget

pytestmark = pytest.mark.anyio

class MockComm(comm.base_comm.BaseComm):

class MockComm(comm.base_comm.BaseComm):
def __init__(
self,
comm_id=None,
target_name=None,
data=None,
metadata=None,
self,
task_group: TaskGroup,
send_send_stream: MemoryObjectSendStream,
send_recv_stream: MemoryObjectReceiveStream,
recv_send_stream: MemoryObjectSendStream,
recv_recv_stream: MemoryObjectReceiveStream,
comm_id: str = "",
target_name: str = "",
data=None,
metadata=None,
) -> None:
self.send_send_stream = send_send_stream
self.send_recv_stream = send_recv_stream
self.recv_send_stream = recv_send_stream
self.recv_recv_stream = recv_recv_stream
super().__init__(
comm_id=comm_id, target_name=target_name, data=data, metadata=metadata
)
task_group.start_soon(self.receive)

def publish_msg(
self, msg_type, data, metadata, buffers, target_name=None, target_module=None
):
self.send_queue = asyncio.Queue()
self.recv_queue = asyncio.Queue()
super().__init__(comm_id=comm_id, target_name=target_name, data=data, metadata=metadata)
self.receive_task = asyncio.create_task(self.receive())

def publish_msg(self, msg_type, data, metadata, buffers, target_name=None, target_module=None):
self.send_queue.put_nowait((msg_type, data, metadata, buffers, target_name, target_module))
self.send_send_stream.send_nowait(
(msg_type, data, metadata, buffers, target_name, target_module)
)

def handle_msg(self, msg):
self._msg_callback(msg)

async def receive(self):
async def receive(self) -> None:
while True:
msg = await self.recv_queue.get()
msg = await self.recv_recv_stream.receive()
self.handle_msg(msg)


comm.create_comm = MockComm

class Context:
def __init__(self):
self.tasks = []

def add_task(self, task):
self.tasks.append(task)

async def __aenter__(self) -> "Context":
send_send_stream, send_recv_stream = create_memory_object_stream(
max_buffer_size=math.inf
)
recv_send_stream, recv_recv_stream = create_memory_object_stream(
max_buffer_size=math.inf
)
async with AsyncExitStack() as stack:
await stack.enter_async_context(send_send_stream)
await stack.enter_async_context(recv_send_stream)
await stack.enter_async_context(send_recv_stream)
await stack.enter_async_context(recv_recv_stream)
self.task_group = await stack.enter_async_context(create_task_group())
comm.create_comm = partial(
MockComm,
self.task_group,
send_send_stream,
send_recv_stream,
recv_send_stream,
recv_recv_stream,
)
for task in self.tasks:
self.task_group.start_soon(task)
self.stack = stack.pop_all()
return self

async def __aexit__(self, *exc) -> bool | None:
self.task_group.cancel_scope.cancel()
comm.create_comm = _create_comm
return await self.stack.__aexit__(*exc)


def _create_comm(*args: Any, **kwargs: Any) -> comm.BaseComm:
return comm.DummyComm(*args, **kwargs) # pragma: nocover


class SyncedWidgets:
def __init__(
self, widget_factories: tuple[type[CommWidget], type[Widget]], context: Context
) -> None:
self.local_widget_factory, self.remote_widget_factory = widget_factories
self.local_widget: CommWidget | None = None
self.remote_widget: Widget | None = None
self.local_widget_created = Event()
self.remote_widget_created = Event()
context.add_task(self.receive)

def send(self, event: TransactionEvent) -> None:
update = event.update
message = create_update_message(update)
self.comm.recv_send_stream.send_nowait({"buffers": [message]})

@pytest.fixture
def widget_factories():
return CommWidget, Widget
async def receive(self) -> None:
self.local_widget = self.local_widget_factory()
self.local_widget_created.set()
self.comm = cast(MockComm, self.local_widget._comm)
while True:
(
msg_type,
data,
metadata,
buffers,
target_name,
target_module,
) = await self.comm.send_recv_stream.receive()
match msg_type:
case "comm_open":
self.remote_widget = self.remote_widget_factory()
msg = create_sync_message(self.remote_widget.ydoc)
self.comm.handle_msg({"buffers": [msg]})
case "comm_msg":
assert self.remote_widget is not None
message = buffers[0]
match message[0]:
case YMessageType.SYNC:
reply = handle_sync_message(
message[1:], self.remote_widget.ydoc
)
if reply is not None:
self.comm.handle_msg({"buffers": [reply]})
if message[1] == YSyncMessageType.SYNC_STEP2:
self.sub = self.remote_widget.ydoc.observe(self.send)
self.remote_widget_created.set()

async def get_local_widget(self, timeout: float = 0.2) -> CommWidget:
with fail_after(timeout):
await self.local_widget_created.wait()
assert self.local_widget is not None
return self.local_widget

async def get_remote_widget(self, timeout: float = 0.2) -> Widget:
with fail_after(timeout):
await self.remote_widget_created.wait()
assert self.remote_widget is not None
return self.remote_widget


@pytest.fixture
async def synced_widgets(widget_factories):
local_widget = widget_factories[0]()
remote_widget_manager = RemoteWidgetManager(widget_factories[1], local_widget._comm)
remote_widget = await remote_widget_manager.get_widget()
return local_widget, remote_widget
def context() -> Context:
return Context()


class RemoteWidgetManager:

comm: Optional[MockComm]
widget: Optional[Widget]
@pytest.fixture
def widget_factories() -> tuple[type[CommWidget], type[Widget]]:
return CommWidget, Widget

def __init__(self, widget_factory, comm):
self.widget_factory = widget_factory
self.comm = comm
self.widget = None
self.receive_task = asyncio.create_task(self.receive())

def send(self, event: TransactionEvent):
update = event.update
message = create_update_message(update)
self.comm.recv_queue.put_nowait({"buffers": [message]})

async def receive(self):
while True:
msg_type, data, metadata, buffers, target_name, target_module = await self.comm.send_queue.get()
if msg_type == "comm_open":
self.widget = self.widget_factory()
msg = create_sync_message(self.widget.ydoc)
self.comm.handle_msg({"buffers": [msg]})
elif msg_type == "comm_msg":
message = buffers[0]
if message[0] == YMessageType.SYNC:
reply = handle_sync_message(message[1:], self.widget.ydoc)
if reply is not None:
self.comm.handle_msg({"buffers": [reply]})
if message[1] == YSyncMessageType.SYNC_STEP2:
self.widget.ydoc.observe(self.send)

async def get_widget(self, timeout=0.1):
t = time.monotonic()
while True:
if self.widget:
return self.widget
await asyncio.sleep(0)
if time.monotonic() - t > timeout: # pragma: nocover
raise TimeoutError("Timeout waiting for widget")
@pytest.fixture
def synced_widgets(
widget_factories: tuple[type[CommWidget], type[Widget]], context: Context
) -> SyncedWidgets:
return SyncedWidgets(widget_factories, context)
Loading