@@ -289,18 +289,13 @@ async def _handle_activation(
289289 workflow_id = workflow_id ,
290290 )
291291 data_converter = self ._data_converter .with_context (workflow_context )
292- if self ._data_converter .payload_codec :
293- assert data_converter .payload_codec
294- if workflow :
295- data_converter = dataclasses .replace (
296- data_converter ,
297- payload_codec = _CommandAwarePayloadCodec (
298- workflow .instance ,
299- context_free_payload_codec = self ._data_converter .payload_codec ,
300- workflow_context_payload_codec = data_converter .payload_codec ,
301- workflow_context = workflow_context ,
302- ),
303- )
292+ if workflow :
293+ data_converter = _CommandAwareDataConverter .create (
294+ instance = workflow .instance ,
295+ context_free_dc = self ._data_converter ,
296+ workflow_context_dc = data_converter ,
297+ workflow_context = workflow_context ,
298+ )
304299 download_metrics = await temporalio .bridge .worker .decode_activation (
305300 act ,
306301 data_converter ,
@@ -395,19 +390,16 @@ async def _handle_activation(
395390 completion .run_id = act .run_id
396391
397392 # Encode completion
398- if self ._data_converter .payload_codec and workflow :
399- assert data_converter .payload_codec
400- data_converter = dataclasses .replace (
401- data_converter ,
402- payload_codec = _CommandAwarePayloadCodec (
403- workflow .instance ,
404- context_free_payload_codec = self ._data_converter .payload_codec ,
405- workflow_context_payload_codec = data_converter .payload_codec ,
406- workflow_context = temporalio .converter .WorkflowSerializationContext (
407- namespace = self ._namespace ,
408- workflow_id = workflow .workflow_id ,
409- ),
410- ),
393+ if workflow :
394+ workflow_context = temporalio .converter .WorkflowSerializationContext (
395+ namespace = self ._namespace ,
396+ workflow_id = workflow .workflow_id ,
397+ )
398+ data_converter = _CommandAwareDataConverter .create (
399+ instance = workflow .instance ,
400+ context_free_dc = self ._data_converter ,
401+ workflow_context_dc = self ._data_converter .with_context (workflow_context ),
402+ workflow_context = workflow_context ,
411403 )
412404
413405 upload_metrics = temporalio .converter ._extstore .StorageOperationMetrics ()
@@ -837,45 +829,74 @@ def attempt_deadlock_interruption(self) -> None:
837829
838830
839831@dataclass (frozen = True )
840- class _CommandAwarePayloadCodec (temporalio .converter .PayloadCodec ):
841- """A payload codec that sets serialization context for the command associated with each payload .
832+ class _CommandAwareDataConverter (temporalio .converter .DataConverter ):
833+ """Data converter that resolves serialization context per- command.
842834
843- This codec responds to the context variable set by
835+ Responds to the context variable set by
844836 :py:class:`_command_aware_visitor.CommandAwarePayloadVisitor`.
845837 """
846838
847- instance : WorkflowInstance
848- context_free_payload_codec : temporalio .converter .PayloadCodec
849- workflow_context_payload_codec : temporalio .converter .PayloadCodec
850- workflow_context : temporalio .converter .WorkflowSerializationContext
851-
852- async def encode (
853- self ,
854- payloads : Sequence [temporalio .api .common .v1 .Payload ],
855- ) -> list [temporalio .api .common .v1 .Payload ]:
856- return await self ._get_current_command_codec ().encode (payloads )
857-
858- async def decode (
859- self ,
860- payloads : Sequence [temporalio .api .common .v1 .Payload ],
861- ) -> list [temporalio .api .common .v1 .Payload ]:
862- return await self ._get_current_command_codec ().decode (payloads )
839+ _ca_instance : WorkflowInstance = dataclasses .field (
840+ default = None ,
841+ repr = False ,
842+ compare = False , # type: ignore[assignment]
843+ )
844+ _ca_context_free_dc : temporalio .converter .DataConverter = dataclasses .field (
845+ default = None ,
846+ repr = False ,
847+ compare = False , # type: ignore[assignment]
848+ )
849+ _ca_workflow_context_dc : temporalio .converter .DataConverter = dataclasses .field (
850+ default = None ,
851+ repr = False ,
852+ compare = False , # type: ignore[assignment]
853+ )
854+ _ca_workflow_context : temporalio .converter .WorkflowSerializationContext = (
855+ dataclasses .field (
856+ default = None ,
857+ repr = False ,
858+ compare = False , # type: ignore[assignment]
859+ )
860+ )
863861
864- def _get_current_command_codec (self ) -> temporalio .converter .PayloadCodec :
865- if not isinstance (
866- self .context_free_payload_codec ,
867- temporalio .converter .WithSerializationContext ,
868- ):
869- return self .context_free_payload_codec
862+ @staticmethod
863+ def create (
864+ instance : WorkflowInstance ,
865+ context_free_dc : temporalio .converter .DataConverter ,
866+ workflow_context_dc : temporalio .converter .DataConverter ,
867+ workflow_context : temporalio .converter .WorkflowSerializationContext ,
868+ ) -> _CommandAwareDataConverter :
869+ return _CommandAwareDataConverter (
870+ payload_converter_class = workflow_context_dc .payload_converter_class ,
871+ payload_codec = workflow_context_dc .payload_codec ,
872+ failure_converter_class = workflow_context_dc .failure_converter_class ,
873+ payload_limits = workflow_context_dc .payload_limits ,
874+ external_storage = workflow_context_dc .external_storage ,
875+ _ca_instance = instance ,
876+ _ca_context_free_dc = context_free_dc ,
877+ _ca_workflow_context_dc = workflow_context_dc ,
878+ _ca_workflow_context = workflow_context ,
879+ )
870880
871- if context := self .instance .get_serialization_context (
881+ def _get_current_dc (self ) -> temporalio .converter .DataConverter :
882+ context = self ._ca_instance .get_serialization_context (
872883 _command_aware_visitor .current_command_info .get (),
873- ):
874- if context == self .workflow_context :
875- return self .workflow_context_payload_codec
876- return self .context_free_payload_codec .with_context (context )
884+ )
885+ if context is None :
886+ return self ._ca_context_free_dc
887+ if context == self ._ca_workflow_context :
888+ return self ._ca_workflow_context_dc
889+ return self ._ca_context_free_dc .with_context (context )
890+
891+ async def _encode_payload_sequence (
892+ self , payloads : Sequence [temporalio .api .common .v1 .Payload ]
893+ ) -> list [temporalio .api .common .v1 .Payload ]:
894+ return await self ._get_current_dc ()._encode_payload_sequence (payloads )
877895
878- return self .context_free_payload_codec
896+ async def _decode_payload_sequence (
897+ self , payloads : Sequence [temporalio .api .common .v1 .Payload ]
898+ ) -> list [temporalio .api .common .v1 .Payload ]:
899+ return await self ._get_current_dc ()._decode_payload_sequence (payloads )
879900
880901
881902class _InterruptDeadlockError (BaseException ):
0 commit comments