Skip to content

Commit 2fa367b

Browse files
CasperGNsicoyle
andauthored
fix: ensure response format passed in aio. Ensure reconnecting ratherthan breaking. Correctly bubble up workflow name if set (#951)
* fix: ensure response format passed in aio. Ensure reconnecting rather than breaking. Correctly bubble up workflow name if set Signed-off-by: Casper Nielsen <casper@diagrid.io> * chore: fix f-string Signed-off-by: Casper Nielsen <casper@diagrid.io> * chore: add todo string Signed-off-by: Casper Nielsen <casper@diagrid.io> * chore(deps): bump durabletask-dapr Signed-off-by: Casper Nielsen <casper@diagrid.io> * chore(test): increase coverage Signed-off-by: Casper Nielsen <casper@diagrid.io> * chore(format): ruff Signed-off-by: Casper Nielsen <casper@diagrid.io> --------- Signed-off-by: Casper Nielsen <casper@diagrid.io> Co-authored-by: Sam <sam@diagrid.io>
1 parent 99a5c40 commit 2fa367b

7 files changed

Lines changed: 369 additions & 19 deletions

File tree

dapr/aio/clients/grpc/client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323

2424
import grpc.aio # type: ignore
2525
from google.protobuf.any_pb2 import Any as GrpcAny
26+
from google.protobuf.duration_pb2 import Duration as GrpcDuration
2627
from google.protobuf.empty_pb2 import Empty as GrpcEmpty
2728
from google.protobuf.message import Message as GrpcMessage
29+
from google.protobuf.struct_pb2 import Struct as GrpcStruct
2830
from grpc import StatusCode # type: ignore
2931
from grpc.aio import ( # type: ignore
3032
AioRpcError,
@@ -1563,6 +1565,8 @@ async def converse_alpha2(
15631565
temperature: Optional[float] = None,
15641566
tools: Optional[List[conversation.ConversationTools]] = None,
15651567
tool_choice: Optional[str] = None,
1568+
response_format: Optional[GrpcStruct] = None,
1569+
prompt_cache_retention: Optional[GrpcDuration] = None,
15661570
) -> conversation.ConversationResponseAlpha2:
15671571
"""Invoke an LLM using the conversation API (Alpha2) with tool calling support.
15681572
@@ -1576,6 +1580,8 @@ async def converse_alpha2(
15761580
temperature: Optional temperature setting for the LLM to optimize for creativity or predictability
15771581
tools: Optional list of tools available for the LLM to call
15781582
tool_choice: Optional control over which tools can be called ('none', 'auto', 'required', or specific tool name)
1583+
response_format: Optional response format (google.protobuf.struct_pb2.Struct, ex: json_schema for structured output)
1584+
prompt_cache_retention: Optional retention for prompt cache (google.protobuf.duration_pb2.Duration)
15791585
15801586
Returns:
15811587
ConversationResponseAlpha2 containing the conversation results with choices and tool calls
@@ -1631,6 +1637,10 @@ async def converse_alpha2(
16311637
request.temperature = temperature
16321638
if tool_choice is not None:
16331639
request.tool_choice = tool_choice
1640+
if response_format is not None and hasattr(request, 'response_format'):
1641+
request.response_format.CopyFrom(response_format)
1642+
if prompt_cache_retention is not None and hasattr(request, 'prompt_cache_retention'):
1643+
request.prompt_cache_retention.CopyFrom(prompt_cache_retention)
16341644

16351645
try:
16361646
response, call = await self.retry_policy.run_rpc_async(

dapr/clients/grpc/client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,16 @@ def stream_messages(sub):
635635
break
636636
except StreamCancelledError:
637637
break
638+
except Exception:
639+
# Stream died — reconnect via the subscription's own
640+
# reconnect logic (which waits for the sidecar to be healthy).
641+
try:
642+
sub.reconnect_stream()
643+
except Exception:
644+
# Sidecar still unavailable — back off before retrying
645+
# TODO: Make this configurable
646+
time.sleep(5)
647+
continue
638648

639649
def close_subscription():
640650
subscription.close()

dapr/clients/grpc/subscription.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import queue
23
import threading
34
from typing import Optional
@@ -13,6 +14,8 @@
1314
)
1415
from dapr.proto import api_v1, appcallback_v1
1516

17+
logger = logging.getLogger(__name__)
18+
1619

1720
class Subscription:
1821
def __init__(self, stub, pubsub_name, topic, metadata=None, dead_letter_topic=None):
@@ -67,7 +70,7 @@ def outgoing_request_iterator():
6770
def reconnect_stream(self):
6871
self.close()
6972
DaprHealth.wait_for_sidecar()
70-
print('Attempting to reconnect...')
73+
logger.info('Subscription stream reconnecting...')
7174
self.start()
7275

7376
def next_message(self):
@@ -84,10 +87,17 @@ def next_message(self):
8487
message = next(self._stream)
8588
return SubscriptionMessage(message.event_message)
8689
except RpcError as e:
87-
# If Dapr can't be reached, wait until it's ready and reconnect the stream
88-
if e.code() == StatusCode.UNAVAILABLE or e.code() == StatusCode.UNKNOWN:
89-
print(
90-
f'gRPC error while reading from stream: {e.details()}, Status Code: {e.code()}'
90+
# If Dapr can't be reached, wait until it's ready and reconnect the stream.
91+
# INTERNAL covers RST_STREAM from cloud proxies (e.g. Diagrid Cloud).
92+
if e.code() in (
93+
StatusCode.UNAVAILABLE,
94+
StatusCode.UNKNOWN,
95+
StatusCode.INTERNAL,
96+
):
97+
logger.warning(
98+
'Subscription stream error (%s): %s — reconnecting',
99+
e.code(),
100+
e.details(),
91101
)
92102
self.reconnect_stream()
93103
elif e.code() == StatusCode.CANCELLED:
@@ -111,7 +121,7 @@ def respond(self, message, status):
111121
raise StreamInactiveError('Stream is not active')
112122
self._send_queue.put(msg)
113123
except Exception as e:
114-
print(f"Can't send message on inactive stream: {e}")
124+
logger.warning(f"Can't send message on inactive stream: {e}")
115125

116126
def respond_success(self, message):
117127
self.respond(message, TopicEventResponse('success').status)
@@ -135,15 +145,12 @@ def _is_stream_active(self):
135145
return self._stream_active
136146

137147
def close(self):
148+
self._set_stream_inactive()
138149
if self._stream:
139150
try:
140151
self._stream.cancel()
141-
self._set_stream_inactive()
142-
except RpcError as e:
143-
if e.code() != StatusCode.CANCELLED:
144-
raise Exception(f'Error while closing stream: {e}')
145-
except Exception as e:
146-
raise Exception(f'Error while closing stream: {e}')
152+
except Exception:
153+
pass # Stream already dead — safe to ignore
147154

148155
def __iter__(self):
149156
return self

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Flask>=1.1
1414
# needed for auto fix
1515
ruff===0.14.1
1616
# needed for dapr-ext-workflow
17-
durabletask-dapr >= 0.17.1
17+
durabletask-dapr >= 0.17.2
1818
# needed for .env file loading in examples
1919
python-dotenv>=1.0.0
2020
# needed for enhanced schema generation from function features

ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def __init__(
8686
)
8787

8888
def register_workflow(self, fn: Workflow, *, name: Optional[str] = None):
89-
self._logger.info(f"Registering workflow '{fn.__name__}' with runtime")
89+
effective_name = name or fn.__name__
90+
self._logger.info(f"Registering workflow '{effective_name}' with runtime")
9091

9192
def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None):
9293
"""Responsible to call Workflow function in orchestrationWrapper"""
@@ -125,8 +126,9 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] =
125126
def register_versioned_workflow(
126127
self, fn: Workflow, *, name: str, version_name: Optional[str] = None, is_latest: bool
127128
):
129+
effective_name = name or fn.__name__
128130
self._logger.info(
129-
f"Registering version {version_name} of workflow '{fn.__name__}' with runtime"
131+
f"Registering version {version_name} of workflow '{effective_name}' with runtime"
130132
)
131133

132134
def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None):
@@ -162,7 +164,8 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None):
162164
"""Registers a workflow activity as a function that takes
163165
a specified input type and returns a specified output type.
164166
"""
165-
self._logger.info(f"Registering activity '{fn.__name__}' with runtime")
167+
effective_name = name or fn.__name__
168+
self._logger.info(f"Registering activity '{effective_name}' with runtime")
166169

167170
def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None):
168171
"""Responsible to call Activity function in activityWrapper"""
@@ -176,7 +179,7 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None):
176179
result = fn(wfActivityContext, inp)
177180
return result
178181
except Exception as e:
179-
self._logger.exception(
182+
self._logger.warning(
180183
f'Activity execution failed - task_id: {activity_id}, error: {e}'
181184
)
182185
raise

ext/dapr-ext-workflow/setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ packages = find_namespace:
2525
include_package_data = True
2626
install_requires =
2727
dapr >= 1.17.0.dev
28-
durabletask-dapr >= 0.17.1
28+
durabletask-dapr >= 0.17.2
2929

3030
[options.packages.find]
3131
include =

0 commit comments

Comments
 (0)