33import functools
44import logging
55from collections .abc import Callable
6- from contextlib import ExitStack
7- from typing import Any
6+ from contextlib import AbstractContextManager , ExitStack
7+ from typing import Any , Literal , overload
88
99from opentelemetry import trace
1010
1111from workflows .recipe .recipe import Recipe
1212from workflows .recipe .validate import validate_recipe
1313from workflows .recipe .wrapper import RecipeWrapper
14+ from workflows .transport .common_transport import CommonTransport
1415
1516__all__ = [
1617 "Recipe" ,
@@ -30,6 +31,8 @@ def _wrap_subscription(
3031 callback ,
3132 * args ,
3233 mangle_for_receiving : Callable [[Any ], Any ] | None = None ,
34+ allow_non_recipe_messages : bool = False ,
35+ log_extender = None ,
3336 ** kwargs ,
3437):
3538 """Internal method to create an intercepting function for incoming messages
@@ -55,9 +58,6 @@ def _wrap_subscription(
5558 :return: Return value of call to subscription_call.
5659 """
5760
58- allow_non_recipe_messages = kwargs .pop ("allow_non_recipe_messages" , False )
59- log_extender = kwargs .pop ("log_extender" , None )
60-
6161 @functools .wraps (callback )
6262 def unwrap_recipe (header , message ):
6363 """This is a helper function unpacking incoming messages when they are
@@ -113,12 +113,40 @@ def unwrap_recipe(header, message):
113113 return subscription_call (channel , unwrap_recipe , * args , ** kwargs )
114114
115115
116+ @overload
117+ def wrap_subscribe (
118+ transport_layer : CommonTransport ,
119+ channel : str ,
120+ callback : Callable [[RecipeWrapper , dict , dict ], None ],
121+ * args ,
122+ allow_non_recipe_messages : Literal [False ] = False ,
123+ mangle_for_receiving : Callable [[Any ], Any ] | None = None ,
124+ log_extender : Callable [[str , Any ], AbstractContextManager [Any ]] | None = None ,
125+ ** kwargs ,
126+ ) -> int : ...
127+
128+
129+ @overload
130+ def wrap_subscribe (
131+ transport_layer : CommonTransport ,
132+ channel : str ,
133+ callback : Callable [[RecipeWrapper | None , dict , dict | bytes ], None ],
134+ * args ,
135+ allow_non_recipe_messages : Literal [True ],
136+ mangle_for_receiving : Callable [[Any ], Any ] | None = None ,
137+ log_extender : Callable [[str , Any ], AbstractContextManager [Any ]] | None = None ,
138+ ** kwargs ,
139+ ) -> int : ...
140+
141+
116142def wrap_subscribe (
117143 transport_layer ,
118144 channel ,
119145 callback ,
120146 * args ,
121- mangle_for_receiving : Callable [[Any ], Any ] | None = None ,
147+ allow_non_recipe_messages = False ,
148+ mangle_for_receiving = None ,
149+ log_extender = None ,
122150 ** kwargs ,
123151):
124152 """Listen to a queue on the transport layer, similar to the subscribe call in
@@ -141,6 +169,7 @@ def wrap_subscribe(
141169 callback ,
142170 * args ,
143171 mangle_for_receiving = mangle_for_receiving ,
172+ allow_non_recipe_messages = allow_non_recipe_messages ,
144173 ** kwargs ,
145174 )
146175
0 commit comments