Skip to content

Commit 83738c3

Browse files
authored
Fix command context awareness to apply to all of DataConverter (#1387)
1 parent 3788785 commit 83738c3

2 files changed

Lines changed: 163 additions & 56 deletions

File tree

temporalio/worker/_workflow.py

Lines changed: 77 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -289,18 +289,13 @@ async def _handle_activation(
289289
workflow_id=workflow_id,
290290
)
291291
data_converter = self._data_converter.with_context(workflow_context)
292-
if self._data_converter.payload_codec:
293-
assert data_converter.payload_codec
294-
if workflow:
295-
data_converter = dataclasses.replace(
296-
data_converter,
297-
payload_codec=_CommandAwarePayloadCodec(
298-
workflow.instance,
299-
context_free_payload_codec=self._data_converter.payload_codec,
300-
workflow_context_payload_codec=data_converter.payload_codec,
301-
workflow_context=workflow_context,
302-
),
303-
)
292+
if workflow:
293+
data_converter = _CommandAwareDataConverter.create(
294+
instance=workflow.instance,
295+
context_free_dc=self._data_converter,
296+
workflow_context_dc=data_converter,
297+
workflow_context=workflow_context,
298+
)
304299
download_metrics = await temporalio.bridge.worker.decode_activation(
305300
act,
306301
data_converter,
@@ -395,19 +390,16 @@ async def _handle_activation(
395390
completion.run_id = act.run_id
396391

397392
# Encode completion
398-
if self._data_converter.payload_codec and workflow:
399-
assert data_converter.payload_codec
400-
data_converter = dataclasses.replace(
401-
data_converter,
402-
payload_codec=_CommandAwarePayloadCodec(
403-
workflow.instance,
404-
context_free_payload_codec=self._data_converter.payload_codec,
405-
workflow_context_payload_codec=data_converter.payload_codec,
406-
workflow_context=temporalio.converter.WorkflowSerializationContext(
407-
namespace=self._namespace,
408-
workflow_id=workflow.workflow_id,
409-
),
410-
),
393+
if workflow:
394+
workflow_context = temporalio.converter.WorkflowSerializationContext(
395+
namespace=self._namespace,
396+
workflow_id=workflow.workflow_id,
397+
)
398+
data_converter = _CommandAwareDataConverter.create(
399+
instance=workflow.instance,
400+
context_free_dc=self._data_converter,
401+
workflow_context_dc=self._data_converter.with_context(workflow_context),
402+
workflow_context=workflow_context,
411403
)
412404

413405
upload_metrics = temporalio.converter._extstore.StorageOperationMetrics()
@@ -837,45 +829,74 @@ def attempt_deadlock_interruption(self) -> None:
837829

838830

839831
@dataclass(frozen=True)
840-
class _CommandAwarePayloadCodec(temporalio.converter.PayloadCodec):
841-
"""A payload codec that sets serialization context for the command associated with each payload.
832+
class _CommandAwareDataConverter(temporalio.converter.DataConverter):
833+
"""Data converter that resolves serialization context per-command.
842834
843-
This codec responds to the context variable set by
835+
Responds to the context variable set by
844836
:py:class:`_command_aware_visitor.CommandAwarePayloadVisitor`.
845837
"""
846838

847-
instance: WorkflowInstance
848-
context_free_payload_codec: temporalio.converter.PayloadCodec
849-
workflow_context_payload_codec: temporalio.converter.PayloadCodec
850-
workflow_context: temporalio.converter.WorkflowSerializationContext
851-
852-
async def encode(
853-
self,
854-
payloads: Sequence[temporalio.api.common.v1.Payload],
855-
) -> list[temporalio.api.common.v1.Payload]:
856-
return await self._get_current_command_codec().encode(payloads)
857-
858-
async def decode(
859-
self,
860-
payloads: Sequence[temporalio.api.common.v1.Payload],
861-
) -> list[temporalio.api.common.v1.Payload]:
862-
return await self._get_current_command_codec().decode(payloads)
839+
_ca_instance: WorkflowInstance = dataclasses.field(
840+
default=None,
841+
repr=False,
842+
compare=False, # type: ignore[assignment]
843+
)
844+
_ca_context_free_dc: temporalio.converter.DataConverter = dataclasses.field(
845+
default=None,
846+
repr=False,
847+
compare=False, # type: ignore[assignment]
848+
)
849+
_ca_workflow_context_dc: temporalio.converter.DataConverter = dataclasses.field(
850+
default=None,
851+
repr=False,
852+
compare=False, # type: ignore[assignment]
853+
)
854+
_ca_workflow_context: temporalio.converter.WorkflowSerializationContext = (
855+
dataclasses.field(
856+
default=None,
857+
repr=False,
858+
compare=False, # type: ignore[assignment]
859+
)
860+
)
863861

864-
def _get_current_command_codec(self) -> temporalio.converter.PayloadCodec:
865-
if not isinstance(
866-
self.context_free_payload_codec,
867-
temporalio.converter.WithSerializationContext,
868-
):
869-
return self.context_free_payload_codec
862+
@staticmethod
863+
def create(
864+
instance: WorkflowInstance,
865+
context_free_dc: temporalio.converter.DataConverter,
866+
workflow_context_dc: temporalio.converter.DataConverter,
867+
workflow_context: temporalio.converter.WorkflowSerializationContext,
868+
) -> _CommandAwareDataConverter:
869+
return _CommandAwareDataConverter(
870+
payload_converter_class=workflow_context_dc.payload_converter_class,
871+
payload_codec=workflow_context_dc.payload_codec,
872+
failure_converter_class=workflow_context_dc.failure_converter_class,
873+
payload_limits=workflow_context_dc.payload_limits,
874+
external_storage=workflow_context_dc.external_storage,
875+
_ca_instance=instance,
876+
_ca_context_free_dc=context_free_dc,
877+
_ca_workflow_context_dc=workflow_context_dc,
878+
_ca_workflow_context=workflow_context,
879+
)
870880

871-
if context := self.instance.get_serialization_context(
881+
def _get_current_dc(self) -> temporalio.converter.DataConverter:
882+
context = self._ca_instance.get_serialization_context(
872883
_command_aware_visitor.current_command_info.get(),
873-
):
874-
if context == self.workflow_context:
875-
return self.workflow_context_payload_codec
876-
return self.context_free_payload_codec.with_context(context)
884+
)
885+
if context is None:
886+
return self._ca_context_free_dc
887+
if context == self._ca_workflow_context:
888+
return self._ca_workflow_context_dc
889+
return self._ca_context_free_dc.with_context(context)
890+
891+
async def _encode_payload_sequence(
892+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
893+
) -> list[temporalio.api.common.v1.Payload]:
894+
return await self._get_current_dc()._encode_payload_sequence(payloads)
877895

878-
return self.context_free_payload_codec
896+
async def _decode_payload_sequence(
897+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
898+
) -> list[temporalio.api.common.v1.Payload]:
899+
return await self._get_current_dc()._decode_payload_sequence(payloads)
879900

880901

881902
class _InterruptDeadlockError(BaseException):

tests/test_serialization_context.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,15 @@
4040
DefaultFailureConverter,
4141
DefaultPayloadConverter,
4242
EncodingPayloadConverter,
43+
ExternalStorage,
4344
JSONPlainPayloadConverter,
4445
PayloadCodec,
4546
PayloadConverter,
4647
SerializationContext,
48+
StorageDriver,
49+
StorageDriverClaim,
50+
StorageDriverRetrieveContext,
51+
StorageDriverStoreContext,
4752
WithSerializationContext,
4853
WorkflowSerializationContext,
4954
)
@@ -1914,3 +1919,84 @@ async def test_user_customization_of_default_payload_converter(
19141919
id=wf_id,
19151920
task_queue=task_queue,
19161921
)
1922+
1923+
1924+
# Child workflow external storage context test
1925+
1926+
1927+
class ContextTrackingStorageDriver(StorageDriver):
1928+
"""In-memory driver that records the serialization context on each store/retrieve."""
1929+
1930+
def __init__(self) -> None:
1931+
self._storage: dict[str, bytes] = {}
1932+
self.store_contexts: list[SerializationContext | None] = []
1933+
1934+
def name(self) -> str:
1935+
return "context-tracking"
1936+
1937+
async def store(
1938+
self,
1939+
context: StorageDriverStoreContext,
1940+
payloads: Sequence[temporalio.api.common.v1.Payload],
1941+
) -> list[StorageDriverClaim]:
1942+
self.store_contexts.append(context.serialization_context)
1943+
claims: list[StorageDriverClaim] = []
1944+
for payload in payloads:
1945+
key = f"payload-{len(self._storage)}"
1946+
self._storage[key] = payload.SerializeToString()
1947+
claims.append(StorageDriverClaim(claim_data={"key": key}))
1948+
return claims
1949+
1950+
async def retrieve(
1951+
self,
1952+
context: StorageDriverRetrieveContext,
1953+
claims: Sequence[StorageDriverClaim],
1954+
) -> list[temporalio.api.common.v1.Payload]:
1955+
results: list[temporalio.api.common.v1.Payload] = []
1956+
for claim in claims:
1957+
payload = temporalio.api.common.v1.Payload()
1958+
payload.ParseFromString(self._storage[claim.claim_data["key"]])
1959+
results.append(payload)
1960+
return results
1961+
1962+
1963+
async def test_child_workflow_external_storage_with_context(client: Client):
1964+
"""External storage should receive the child workflow's context, not the parent's."""
1965+
workflow_id = str(uuid.uuid4())
1966+
child_workflow_id = f"{workflow_id}-child"
1967+
task_queue = str(uuid.uuid4())
1968+
1969+
driver = ContextTrackingStorageDriver()
1970+
config = client.config()
1971+
config["data_converter"] = dataclasses.replace(
1972+
DataConverter.default,
1973+
external_storage=ExternalStorage(
1974+
drivers=[driver],
1975+
payload_size_threshold=None,
1976+
),
1977+
)
1978+
client = Client(**config)
1979+
1980+
async with Worker(
1981+
client,
1982+
task_queue=task_queue,
1983+
workflows=[ChildWorkflowCodecTestWorkflow, EchoWorkflow],
1984+
workflow_runner=UnsandboxedWorkflowRunner(),
1985+
):
1986+
await client.execute_workflow(
1987+
ChildWorkflowCodecTestWorkflow.run,
1988+
TraceData(),
1989+
id=workflow_id,
1990+
task_queue=task_queue,
1991+
)
1992+
1993+
child_context = WorkflowSerializationContext(
1994+
namespace=client.namespace,
1995+
workflow_id=child_workflow_id,
1996+
)
1997+
# store_contexts[0]: parent input encode → parent context
1998+
# store_contexts[1]: child workflow input encode → child context
1999+
# store_contexts[2]: child workflow result encode → child context
2000+
# store_contexts[3]: parent result encode → parent context
2001+
child_context_count = sum(1 for c in driver.store_contexts if c == child_context)
2002+
assert child_context_count == 2

0 commit comments

Comments
 (0)