Skip to content

Commit be89102

Browse files
committed
Better annotate RW.wrap_subscribe
The signature of the callback changes based on whether the subscription can accept non-workflows messages.
1 parent 688e76b commit be89102

1 file changed

Lines changed: 35 additions & 6 deletions

File tree

src/workflows/recipe/__init__.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import functools
44
import logging
55
from collections.abc import Callable
6-
from contextlib import ExitStack
7-
from typing import Any
6+
from contextlib import AbstractContextManager, ExitStack
7+
from typing import Any, Literal, overload
88

99
from opentelemetry import trace
1010

1111
from workflows.recipe.recipe import Recipe
1212
from workflows.recipe.validate import validate_recipe
1313
from workflows.recipe.wrapper import RecipeWrapper
14+
from workflows.transport.common_transport import CommonTransport
1415

1516
__all__ = [
1617
"Recipe",
@@ -30,6 +31,8 @@ def _wrap_subscription(
3031
callback,
3132
*args,
3233
mangle_for_receiving: Callable[[Any], Any] | None = None,
34+
allow_non_recipe_messages: bool = False,
35+
log_extender=None,
3336
**kwargs,
3437
):
3538
"""Internal method to create an intercepting function for incoming messages
@@ -55,9 +58,6 @@ def _wrap_subscription(
5558
:return: Return value of call to subscription_call.
5659
"""
5760

58-
allow_non_recipe_messages = kwargs.pop("allow_non_recipe_messages", False)
59-
log_extender = kwargs.pop("log_extender", None)
60-
6161
@functools.wraps(callback)
6262
def unwrap_recipe(header, message):
6363
"""This is a helper function unpacking incoming messages when they are
@@ -113,12 +113,40 @@ def unwrap_recipe(header, message):
113113
return subscription_call(channel, unwrap_recipe, *args, **kwargs)
114114

115115

116+
@overload
117+
def wrap_subscribe(
118+
transport_layer: CommonTransport,
119+
channel: str,
120+
callback: Callable[[RecipeWrapper, dict, dict], None],
121+
*args,
122+
allow_non_recipe_messages: Literal[False] = False,
123+
mangle_for_receiving: Callable[[Any], Any] | None = None,
124+
log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None,
125+
**kwargs,
126+
) -> int: ...
127+
128+
129+
@overload
130+
def wrap_subscribe(
131+
transport_layer: CommonTransport,
132+
channel: str,
133+
callback: Callable[[RecipeWrapper | None, dict, dict | bytes], None],
134+
*args,
135+
allow_non_recipe_messages: Literal[True],
136+
mangle_for_receiving: Callable[[Any], Any] | None = None,
137+
log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None,
138+
**kwargs,
139+
) -> int: ...
140+
141+
116142
def wrap_subscribe(
117143
transport_layer,
118144
channel,
119145
callback,
120146
*args,
121-
mangle_for_receiving: Callable[[Any], Any] | None = None,
147+
allow_non_recipe_messages=False,
148+
mangle_for_receiving=None,
149+
log_extender=None,
122150
**kwargs,
123151
):
124152
"""Listen to a queue on the transport layer, similar to the subscribe call in
@@ -141,6 +169,7 @@ def wrap_subscribe(
141169
callback,
142170
*args,
143171
mangle_for_receiving=mangle_for_receiving,
172+
allow_non_recipe_messages=allow_non_recipe_messages,
144173
**kwargs,
145174
)
146175

0 commit comments

Comments
 (0)