diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index aceaf3eb9..c24ffd55c 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -418,7 +418,7 @@ async def dispatch_shell(self, msg, /, subshell_id: str | None = None): assert msg["header"].get("subshell_id") == subshell_id if self._supports_kernel_subshells: - stream = self.shell_channel_thread.manager.get_subshell_to_shell_channel_socket( + stream = self.shell_channel_thread.manager.get_subshell_to_shell_channel_stream( subshell_id ) else: diff --git a/ipykernel/socket_pair.py b/ipykernel/socket_pair.py index e2669b8c0..33043e8fc 100644 --- a/ipykernel/socket_pair.py +++ b/ipykernel/socket_pair.py @@ -19,12 +19,13 @@ class SocketPair: """ from_socket: zmq.Socket[Any] + from_stream: ZMQStream | None = None to_socket: zmq.Socket[Any] to_stream: ZMQStream | None = None on_recv_callback: Any on_recv_copy: bool - def __init__(self, context: zmq.Context[Any], name: str): + def __init__(self, context: zmq.Context[Any], name: str, from_io_loop: IOLoop | None = None): """Initialize the inproc socker pair.""" self.from_socket = context.socket(zmq.PAIR) self.to_socket = context.socket(zmq.PAIR) @@ -32,8 +33,14 @@ def __init__(self, context: zmq.Context[Any], name: str): self.from_socket.bind(address) self.to_socket.connect(address) # Or do I need to do this in another thread? + # Optional from_stream, only created if from_io_loop is specified. + if from_io_loop is not None: + self.from_stream = ZMQStream(self.from_socket, from_io_loop) + def close(self): """Close the inproc socker pair.""" + if self.from_stream is not None: + self.from_stream.close() self.from_socket.close() if self.to_stream is not None: diff --git a/ipykernel/subshell.py b/ipykernel/subshell.py index 911a9521c..32acd5917 100644 --- a/ipykernel/subshell.py +++ b/ipykernel/subshell.py @@ -24,7 +24,7 @@ def __init__( super().__init__(name=f"subshell-{subshell_id}", **kwargs) self.shell_channel_to_subshell = SocketPair(context, subshell_id) - self.subshell_to_shell_channel = SocketPair(context, subshell_id + "-reverse") + self.subshell_to_shell_channel = SocketPair(context, subshell_id + "-reverse", self.io_loop) # When aborting flag is set, execute_request messages to this subshell will be aborted. self.aborting = False diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index 24d683523..61ffe4ee7 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -10,6 +10,7 @@ import zmq from tornado.ioloop import IOLoop +from zmq.eventloop.zmqstream import ZMQStream from .socket_pair import SocketPair from .subshell import SubshellThread @@ -62,7 +63,8 @@ def __init__( # Inproc socket pair for communication from main thread to shell channel thread. # such as for execute_reply messages. - self._main_to_shell_channel = SocketPair(self._context, "main-reverse") + main_io_loop = IOLoop.current() + self._main_to_shell_channel = SocketPair(self._context, "main-reverse", main_io_loop) self._main_to_shell_channel.on_recv( self._shell_channel_io_loop, self._send_on_shell_channel ) @@ -90,14 +92,17 @@ def get_shell_channel_to_subshell_pair(self, subshell_id: str | None) -> SocketP with self._lock_cache: return self._cache[subshell_id].shell_channel_to_subshell - def get_subshell_to_shell_channel_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]: - """Return the socket used by a particular subshell or main shell to send + def get_subshell_to_shell_channel_stream(self, subshell_id: str | None) -> ZMQStream: + """Return the stream used by a particular subshell or main shell to send messages to the shell channel. """ if subshell_id is None: - return self._main_to_shell_channel.from_socket - with self._lock_cache: - return self._cache[subshell_id].subshell_to_shell_channel.from_socket + from_stream = self._main_to_shell_channel.from_stream + else: + with self._lock_cache: + from_stream = self._cache[subshell_id].subshell_to_shell_channel.from_stream + assert from_stream is not None + return from_stream def get_shell_channel_to_subshell_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]: """Return the socket used by the shell channel to send messages to a particular