diff --git a/temporalio/client/_client.py b/temporalio/client/_client.py index 1d8b8e4f2..438043f3d 100644 --- a/temporalio/client/_client.py +++ b/temporalio/client/_client.py @@ -1216,6 +1216,8 @@ def on_start_error( input = StartWorkflowUpdateWithStartInput( start_workflow_input=start_workflow_operation._start_workflow_input, update_workflow_input=update_input, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, _on_start=on_start, _on_start_error=on_start_error, ) diff --git a/temporalio/client/_impl.py b/temporalio/client/_impl.py index e62f0b4a2..1481c8327 100644 --- a/temporalio/client/_impl.py +++ b/temporalio/client/_impl.py @@ -852,7 +852,11 @@ def on_start( try: return await self._start_workflow_update_with_start( - input.start_workflow_input, input.update_workflow_input, on_start + input.start_workflow_input, + input.update_workflow_input, + input.rpc_metadata, + input.rpc_timeout, + on_start, ) except asyncio.CancelledError as _err: err = _err @@ -914,6 +918,8 @@ async def _start_workflow_update_with_start( self, start_input: UpdateWithStartStartWorkflowInput, update_input: UpdateWithStartUpdateWorkflowInput, + rpc_metadata: Mapping[str, str | bytes], + rpc_timeout: timedelta | None, on_start: Callable[ [temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None ], @@ -941,7 +947,12 @@ async def _start_workflow_update_with_start( # Repeatedly try to invoke ExecuteMultiOperation until the update is durable while True: multiop_response = ( - await self._client.workflow_service.execute_multi_operation(multiop_req) + await self._client.workflow_service.execute_multi_operation( + multiop_req, + retry=True, + metadata=rpc_metadata, + timeout=rpc_timeout, + ) ) start_response = multiop_response.responses[0].start_workflow update_response = multiop_response.responses[1].update_workflow diff --git a/temporalio/client/_interceptor.py b/temporalio/client/_interceptor.py index 0e780146d..1f6dfa8b9 100644 --- a/temporalio/client/_interceptor.py +++ b/temporalio/client/_interceptor.py @@ -374,10 +374,20 @@ class UpdateWithStartStartWorkflowInput: @dataclass class StartWorkflowUpdateWithStartInput: - """Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`.""" + """Input for :py:meth:`OutboundInterceptor.start_update_with_start_workflow`. + + The top-level ``rpc_metadata`` and ``rpc_timeout`` fields are authoritative + for the ``execute_multi_operation`` gRPC call. The sub-inputs + (``start_workflow_input`` and ``update_workflow_input``) also carry their own + ``rpc_metadata`` / ``rpc_timeout`` for interceptor introspection, but those + values are **not** forwarded to the gRPC call. Interceptors that wish to set + RPC metadata should modify :py:attr:`rpc_metadata` on this object. + """ start_workflow_input: UpdateWithStartStartWorkflowInput update_workflow_input: UpdateWithStartUpdateWorkflowInput + rpc_metadata: Mapping[str, str | bytes] + rpc_timeout: timedelta | None _on_start: Callable[ [temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse], None ] diff --git a/tests/worker/test_update_with_start.py b/tests/worker/test_update_with_start.py index 4ed625960..2ceb5e91b 100644 --- a/tests/worker/test_update_with_start.py +++ b/tests/worker/test_update_with_start.py @@ -1104,3 +1104,84 @@ async def _do_update() -> Any: elif id_reuse_policy == WorkflowIDReusePolicy.REJECT_DUPLICATE: with pytest.raises(WorkflowAlreadyStartedError): await _do_update() + + +class MetadataCapturingInterceptor(Interceptor): + """Interceptor that sets rpc_metadata on update-with-start calls.""" + + def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: + return MetadataCapturingOutboundInterceptor(super().intercept_client(next)) + + +class MetadataCapturingOutboundInterceptor(OutboundInterceptor): + def __init__(self, next: OutboundInterceptor) -> None: + super().__init__(next) + + async def start_update_with_start_workflow( + self, input: StartWorkflowUpdateWithStartInput + ) -> WorkflowUpdateHandle[Any]: + input.rpc_metadata = { + **input.rpc_metadata, + "test-header-key": "test-header-value", + } + return await super().start_update_with_start_workflow(input) + + +# Verify fix for https://github.com/temporalio/sdk-python/issues/1582 +async def test_update_with_start_rpc_metadata_and_timeout_forwarded(client: Client): + """Test that rpc_metadata and rpc_timeout on StartWorkflowUpdateWithStartInput + are forwarded to the execute_multi_operation gRPC call.""" + captured_metadata: dict[str, str | bytes] = {} + captured_timeout: list[timedelta | None] = [] + + class execute_multi_operation: + err = RPCError("intentional", RPCStatusCode.INTERNAL, b"") + err._grpc_status = temporalio.api.common.v1.GrpcStatus(details=[]) + + def __init__(self) -> None: # type: ignore[reportMissingSuperCall] + pass + + async def __call__( + self, + req: temporalio.api.workflowservice.v1.ExecuteMultiOperationRequest, + *, + retry: bool = False, + metadata: Mapping[str, str | bytes] = {}, + timeout: timedelta | None = None, + ) -> temporalio.api.workflowservice.v1.ExecuteMultiOperationResponse: + captured_metadata.update(metadata) + captured_timeout.append(timeout) + raise self.err + + interceptor = MetadataCapturingInterceptor() + intercepted_client = Client( + **{**client.config(), "interceptors": [interceptor]} # type: ignore + ) + + with patch.object( + intercepted_client.workflow_service, + "execute_multi_operation", + execute_multi_operation(), + ): + start_workflow_operation = WithStartWorkflowOperation( + UpdateWithStartInterceptorWorkflow.run, + "wf-arg", + id=f"wf-{uuid.uuid4()}", + task_queue="tq", + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + ) + with pytest.raises(RPCError): + await intercepted_client.start_update_with_start_workflow( + UpdateWithStartInterceptorWorkflow.my_update, + "update-arg", + start_workflow_operation=start_workflow_operation, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + rpc_metadata={"original-key": "original-value"}, + rpc_timeout=timedelta(seconds=42), + ) + + # The interceptor should have added its metadata on top of the caller's + assert captured_metadata.get("test-header-key") == "test-header-value" + assert captured_metadata.get("original-key") == "original-value" + # The caller's timeout should have been forwarded + assert captured_timeout == [timedelta(seconds=42)]