diff --git a/pyproject.toml b/pyproject.toml index 317b378cb..99b86cfcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,14 +99,17 @@ format = [ { cmd = "cargo fmt", cwd = "temporalio/bridge" }, ] gen-docs = "uv run scripts/gen_docs.py" +gen-nexus-system-api = "uv run scripts/gen_nexus_system_api.py" gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, + { ref = "gen-nexus-system-api" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, { ref = "format" }, ] gen-protos-docker = [ { cmd = "uv run scripts/gen_protos_docker.py" }, + { ref = "gen-nexus-system-api" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, { ref = "format" }, @@ -170,7 +173,7 @@ exclude = [ [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 -match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match_dir = "^(?!(docs|scripts|tests|api|proto|system|\\.)).*" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. # https://github.com/PyCQA/pydocstyle/issues/184 diff --git a/scripts/gen_nexus_system_api.py b/scripts/gen_nexus_system_api.py new file mode 100644 index 000000000..9baea590f --- /dev/null +++ b/scripts/gen_nexus_system_api.py @@ -0,0 +1,205 @@ +import os +import shutil +import subprocess +import sys +import tempfile +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path +from typing import cast + +import gen_protos + +base_dir = Path(__file__).parent.parent +sys.path.insert(0, str(base_dir)) +wit_input_dir = ( + base_dir + / "temporalio" + / "bridge" + / "sdk-core" + / "crates" + / "protos" + / "protos" + / "api_upstream" + / "nexus" +) +wit_path = wit_input_dir / "workflow-service.wit" +wit_deps_dir = wit_input_dir / "deps" +python_support_path = ( + base_dir + / "temporalio" + / "nexus" + / "system" + / "_generation_support" + / "temporal_model_converters.py" +) +output_dir = base_dir / "temporalio" / "nexus" / "system" / "workflow_service" +workflow_init_path = base_dir / "temporalio" / "workflow" / "__init__.py" +workflowservice_request_response_proto = ( + gen_protos.api_proto_dir + / "temporal" + / "api" + / "workflowservice" + / "v1" + / "request_response.proto" +) + + +def nex_gen_command() -> list[str]: + if bin_path := os.environ.get("NEX_GEN_BIN"): + return [bin_path] + + if shutil.which("nex-gen") is None: + subprocess.check_call(["cargo", "install", "--locked", "nex-gen"]) + return ["nex-gen"] + + +def build_descriptor_set(descriptor_path: Path) -> None: + subprocess.check_call( + [ + sys.executable, + "-mgrpc_tools.protoc", + f"--proto_path={gen_protos.api_proto_dir}", + f"--proto_path={gen_protos.proto_dir}", + "--include_imports", + f"--descriptor_set_out={descriptor_path}", + str(workflowservice_request_response_proto), + ] + ) + + +def strip_unsupported_pyright_comments() -> None: + for path in output_dir.rglob("*.py"): + content = path.read_text() + content = content.replace("# pyright: reportAny=false\n", "") + content = content.replace( + "# pyright: reportAny=false, reportExplicitAny=false\n", "" + ) + path.write_text(content) + + +def generate_workflow_exports() -> None: + spec = spec_from_file_location( + "temporalio_nexus_system_workflow_service_exports", + output_dir / "__init__.py", + submodule_search_locations=[str(output_dir)], + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load generated workflow service from {output_dir}") + module = module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + exports = cast(list[str], module.__all__) + + import_block = [ + "# BEGIN GENERATED NEXUS SYSTEM EXPORTS\n", + "from temporalio.nexus.system.workflow_service import (\n", + *[f" {export},\n" for export in exports], + ")\n", + "# END GENERATED NEXUS SYSTEM EXPORTS\n", + ] + all_block = [ + " # BEGIN GENERATED NEXUS SYSTEM __ALL__\n", + *[f' "{export}",\n' for export in exports], + " # END GENERATED NEXUS SYSTEM __ALL__\n", + ] + content = workflow_init_path.read_text() + start = content.index("# BEGIN GENERATED NEXUS SYSTEM EXPORTS") + end = content.index("# END GENERATED NEXUS SYSTEM EXPORTS", start) + end = content.index("\n", end) + 1 + content = content[:start] + "".join(import_block) + content[end:] + start = content.index(" # BEGIN GENERATED NEXUS SYSTEM __ALL__") + end = content.index(" # END GENERATED NEXUS SYSTEM __ALL__", start) + end = content.index("\n", end) + 1 + workflow_init_path.write_text(content[:start] + "".join(all_block) + content[end:]) + + +def prepare_wit_workspace(temp_dir: Path) -> tuple[Path, Path]: + workspace_input_dir = temp_dir / "nexus" + shutil.copytree(wit_input_dir, workspace_input_dir) + + model_path = workspace_input_dir / "deps" / "nexus-temporal-types" / "model.wit" + model_content = model_path.read_text() + if "@nexus.support" not in model_content: + support_path = ( + workspace_input_dir + / "deps" + / "nexus-temporal-types" + / "python" + / python_support_path.name + ) + support_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(python_support_path, support_path) + model_path.write_text( + '/// @nexus.support python="python/temporal_model_converters.py"\n' + + model_content + ) + + return workspace_input_dir / "workflow-service.wit", workspace_input_dir / "deps" + + +def generate_nexus_system_api() -> None: + if not wit_path.exists(): + raise RuntimeError(f"missing WIT source: {wit_path}") + if not wit_deps_dir.exists(): + raise RuntimeError(f"missing WIT dependency directory: {wit_deps_dir}") + if not python_support_path.exists(): + raise RuntimeError(f"missing Python support source: {python_support_path}") + + with tempfile.TemporaryDirectory(dir=base_dir) as temp_dir: + temp_path = Path(temp_dir) + descriptor_path = temp_path / "temporal_api.bin" + workspace_wit_path, workspace_wit_deps_dir = prepare_wit_workspace(temp_path) + build_descriptor_set(descriptor_path) + command = nex_gen_command() + + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.parent.mkdir(parents=True, exist_ok=True) + subprocess.check_call( + [ + *command, + "generate", + "--lang", + "python", + "--input", + str(workspace_wit_path), + "--input", + str(workspace_wit_deps_dir), + "--descriptors", + str(descriptor_path), + "--output", + str(output_dir), + ] + ) + + (output_dir.parent / "__init__.py").touch() + strip_unsupported_pyright_comments() + generate_workflow_exports() + subprocess.check_call( + [ + sys.executable, + "-m", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_dir), + str(workflow_init_path), + ] + ) + subprocess.check_call( + [ + sys.executable, + "-m", + "ruff", + "format", + str(output_dir), + str(workflow_init_path), + ] + ) + + +if __name__ == "__main__": + print("Generating Nexus system API...", file=sys.stderr) + generate_nexus_system_api() + print("Done", file=sys.stderr) diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index 928be03e5..4e0d780aa 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -1,9 +1,16 @@ import subprocess import sys +from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path +from typing import cast +import google.protobuf.message +import nexusrpc from google.protobuf.descriptor import Descriptor, FieldDescriptor +base_dir = Path(__file__).parent.parent +sys.path.insert(0, str(base_dir)) + from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( WorkflowActivation, @@ -12,7 +19,36 @@ WorkflowActivationCompletion, ) -base_dir = Path(__file__).parent.parent + +def discover_system_nexus_roots() -> list[Descriptor]: + module_path = ( + base_dir / "temporalio" / "nexus" / "system" / "workflow_service" / "service.py" + ) + spec = spec_from_file_location( + "temporalio_nexus_system_workflow_service", module_path + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Cannot load generated system service from {module_path}") + module = module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + roots: list[Descriptor] = [] + for operation in vars(module.WorkflowService).values(): + if not isinstance(operation, nexusrpc.Operation): + continue + for proto_type in (operation.input_type, operation.output_type): + if isinstance(proto_type, type) and issubclass( + proto_type, google.protobuf.message.Message + ): + roots.append(cast(Descriptor, proto_type.DESCRIPTOR)) + deduped: list[Descriptor] = [] + seen: set[str] = set() + for root in roots: + if root.full_name not in seen: + seen.add(root.full_name) + deduped.append(root) + return deduped def name_for(desc: Descriptor) -> str: @@ -80,28 +116,15 @@ def generate(self, roots: list[Descriptor]) -> str: self.walk(r) header = """ +from __future__ import annotations + # This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc import asyncio -from typing import Any, MutableSequence +from typing import Any +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload - - -class VisitorFunctions(abc.ABC): - \"\"\"Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - \"\"\" - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - \"\"\"Called when encountering a single payload.\"\"\" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - \"\"\"Called when encountering multiple payloads together.\"\"\" - raise NotImplementedError() +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions class _BoundedVisitorFunctions(VisitorFunctions): @@ -126,7 +149,7 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: await self._sem.acquire() async def _run() -> None: @@ -137,6 +160,9 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + await self._inner.visit_system_nexus_envelope(payload) + async def drain(self) -> None: \"\"\"Wait for all in-flight background tasks to complete. @@ -199,6 +225,28 @@ async def visit( finally: await bounded.drain() + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + """ return header + "\n".join(self.methods) @@ -212,15 +260,15 @@ def __init__(self): self.in_progress: set[str] = set() self.methods: list[str] = [ """\ - async def _visit_temporal_api_common_v1_Payload(self, fs, o): + async def _visit_temporal_api_common_v1_Payload(self, fs: VisitorFunctions, o: Payload): await fs.visit_payload(o) """, """\ - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + async def _visit_temporal_api_common_v1_Payloads(self, fs: VisitorFunctions, o: Any): await fs.visit_payloads(o.payloads) """, """\ - async def _visit_payload_container(self, fs, o): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) """, ] @@ -275,6 +323,22 @@ def walk(self, desc: Descriptor) -> bool: # Process regular fields first for field in regular_fields: + if ( + desc.full_name == "coresdk.workflow_commands.ScheduleNexusOperation" + and field.name == "input" + ): + has_payload = True + emit_items.append( + ( + "system_nexus", + field.name, + "o.service", + "o.operation", + "o.input", + ) + ) + continue + # Repeated fields (including maps which are represented as repeated messages) if field.label == FieldDescriptor.LABEL_REPEATED: if ( @@ -359,7 +423,10 @@ def walk(self, desc: Descriptor) -> bool: self.in_progress.discard(key) if has_payload: - lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] + lines: list[str] = [ + f" async def _visit_{name_for(desc)}" + "(self, fs: VisitorFunctions, o: Any):" + ] if is_search_attrs: lines.append(" if self.skip_search_attributes:") lines.append(" return") @@ -375,6 +442,14 @@ def walk(self, desc: Descriptor) -> bool: field_name, access_expr, child_method, presence_word ) ) + elif item[0] == "system_nexus": + _, field_name, service_expr, operation_expr, payload_expr = item + lines.append( + f' if o.HasField("{field_name}"):\n' + " await self._visit_system_nexus_payload(\n" + f" fs, {service_expr}, {operation_expr}, {payload_expr}\n" + " )" + ) else: # oneof_group for field_name, access_expr, child_method, presence_word in item[1]: lines.append( @@ -387,8 +462,7 @@ def walk(self, desc: Descriptor) -> bool: return has_payload -def write_generated_visitors_into_visitor_generated_py() -> None: - """Write the generated visitor code into _visitor.py.""" +def write_bridge_visitors() -> None: out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, @@ -402,7 +476,41 @@ def write_generated_visitors_into_visitor_generated_py() -> None: out_path.write_text(code) +def write_system_nexus_payload_visitors() -> None: + out_path = base_dir / "temporalio" / "nexus" / "system" / "_payload_visitor.py" + code = VisitorGenerator().generate(discover_system_nexus_roots()) + out_path.write_text(code) + + if __name__ == "__main__": print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr) - write_generated_visitors_into_visitor_generated_py() - subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"]) + write_bridge_visitors() + print("Generating temporalio/nexus/system/_payload_visitor.py...", file=sys.stderr) + write_system_nexus_payload_visitors() + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_payload_visitor.py", + ], + cwd=base_dir, + check=True, + ) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 0f030ac01..40bbceb8f 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,25 +1,12 @@ +from __future__ import annotations + # This file is generated by gen_payload_visitor.py. Changes should be made there. -import abc import asyncio -from typing import Any, MutableSequence +from typing import Any +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload - - -class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. - Allows handling payloads as a sequence. - """ - - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - """Called when encountering a single payload.""" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - """Called when encountering multiple payloads together.""" - raise NotImplementedError() +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions class _BoundedVisitorFunctions(VisitorFunctions): @@ -44,7 +31,7 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: await self._sem.acquire() async def _run() -> None: @@ -55,6 +42,9 @@ async def _run() -> None: self._tasks.append(asyncio.create_task(_run())) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + await self._inner.visit_system_nexus_envelope(payload) + async def drain(self) -> None: """Wait for all in-flight background tasks to complete. @@ -117,36 +107,72 @@ async def visit(self, fs: VisitorFunctions, root: Any) -> None: finally: await bounded.drain() - async def _visit_temporal_api_common_v1_Payload(self, fs, o): + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): await fs.visit_payload(o) - async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): await fs.visit_payloads(o.payloads) - async def _visit_payload_container(self, fs, o): + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): await fs.visit_payloads(o) - async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_CanceledFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("details"): await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_Failure(self, fs, o): + async def _visit_temporal_api_failure_v1_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("encoded_attributes"): await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) if o.HasField("cause"): @@ -168,17 +194,21 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): fs, o.reset_workflow_failure_info ) - async def _visit_temporal_api_common_v1_Memo(self, fs, o): + async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): for v in o.fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: Any + ): if self.skip_search_attributes: return for v in o.indexed_fields.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_InitializeWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): @@ -196,31 +226,43 @@ async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): fs, o.search_attributes ) - async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_QueryWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): + async def _visit_coresdk_workflow_activation_SignalWorkflow( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_activity_result_Success(self, fs, o): + async def _visit_coresdk_activity_result_Success( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_activity_result_Failure(self, fs, o): + async def _visit_coresdk_activity_result_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_Cancellation(self, fs, o): + async def _visit_coresdk_activity_result_Cancellation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): + async def _visit_coresdk_activity_result_ActivityResolution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_coresdk_activity_result_Success(fs, o.completed) elif o.HasField("failed"): @@ -228,37 +270,43 @@ async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): elif o.HasField("cancelled"): await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveActivity( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("cancelled"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( fs, o.cancelled ) - async def _visit_coresdk_child_workflow_Success(self, fs, o): + async def _visit_coresdk_child_workflow_Success(self, fs: VisitorFunctions, o: Any): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_child_workflow_Failure(self, fs, o): + async def _visit_coresdk_child_workflow_Failure(self, fs: VisitorFunctions, o: Any): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_Cancellation(self, fs, o): + async def _visit_coresdk_child_workflow_Cancellation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): + async def _visit_coresdk_child_workflow_ChildWorkflowResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_coresdk_child_workflow_Success(fs, o.completed) elif o.HasField("failed"): @@ -267,36 +315,40 @@ async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("result"): await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): + async def _visit_coresdk_workflow_activation_DoUpdate( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): + async def _visit_coresdk_nexus_NexusOperationResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) elif o.HasField("failed"): @@ -306,11 +358,15 @@ async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): elif o.HasField("timed_out"): await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) - async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) - async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivationJob( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("initialize_workflow"): await self._visit_coresdk_workflow_activation_InitializeWorkflow( fs, o.initialize_workflow @@ -354,42 +410,56 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): fs, o.resolve_nexus_operation ) - async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): + async def _visit_coresdk_workflow_activation_WorkflowActivation( + self, fs: VisitorFunctions, o: Any + ): for v in o.jobs: await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) - async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("summary"): await self._visit_temporal_api_common_v1_Payload(fs, o.summary) if o.HasField("details"): await self._visit_temporal_api_common_v1_Payload(fs, o.details) - async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleActivity( + self, fs: VisitorFunctions, o: Any + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): + async def _visit_coresdk_workflow_commands_QuerySuccess( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("response"): await self._visit_temporal_api_common_v1_Payload(fs, o.response) - async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): + async def _visit_coresdk_workflow_commands_QueryResult( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("succeeded"): await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) elif o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("result"): await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_FailWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): @@ -402,7 +472,9 @@ async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution( + self, fs: VisitorFunctions, o: Any + ): await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): @@ -415,42 +487,52 @@ async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, ) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, fs, o + self, fs: VisitorFunctions, o: Any ): await self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity( + self, fs: VisitorFunctions, o: Any + ): if not self.skip_headers: for v in o.headers.values(): await self._visit_temporal_api_common_v1_Payload(fs, v) await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("search_attributes"): await self._visit_temporal_api_common_v1_SearchAttributes( fs, o.search_attributes ) - async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("upserted_memo"): await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) - async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): + async def _visit_coresdk_workflow_commands_UpdateResponse( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("rejected"): await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) elif o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) - async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("input"): - await self._visit_temporal_api_common_v1_Payload(fs, o.input) + await self._visit_system_nexus_payload(fs, o.service, o.operation, o.input) - async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): + async def _visit_coresdk_workflow_commands_WorkflowCommand( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("user_metadata"): await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): @@ -502,16 +584,20 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): fs, o.schedule_nexus_operation ) - async def _visit_coresdk_workflow_completion_Success(self, fs, o): + async def _visit_coresdk_workflow_completion_Success( + self, fs: VisitorFunctions, o: Any + ): for v in o.commands: await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) - async def _visit_coresdk_workflow_completion_Failure(self, fs, o): + async def _visit_coresdk_workflow_completion_Failure( + self, fs: VisitorFunctions, o: Any + ): if o.HasField("failure"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, fs, o + self, fs: VisitorFunctions, o: Any ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) diff --git a/temporalio/bridge/_visitor_functions.py b/temporalio/bridge/_visitor_functions.py new file mode 100644 index 000000000..6014f8e75 --- /dev/null +++ b/temporalio/bridge/_visitor_functions.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Protocol + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer + +from temporalio.api.common.v1.message_pb2 import Payload + +PayloadSequence = list[Payload] | RepeatedCompositeFieldContainer[Payload] + + +class VisitorFunctions(Protocol): + """Functions invoked by generated payload visitors.""" + + async def visit_payload(self, payload: Payload) -> None: + """Visit a single payload.""" + ... + + async def visit_payloads(self, payloads: PayloadSequence) -> None: + """Visit a sequence of payloads together.""" + ... + + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + """Visit a recognized system Nexus envelope payload.""" + return None diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index a9c857373..e1e23dd89 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -5,7 +5,7 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, MutableSequence, Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from typing import ( TypeAlias, @@ -22,7 +22,7 @@ import temporalio.converter import temporalio.converter._extstore from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.bridge._visitor import VisitorFunctions +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) @@ -281,15 +281,20 @@ async def finalize_shutdown(self) -> None: class _Visitor(VisitorFunctions): - def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]]): + def __init__( + self, + f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + visit_system_nexus_envelope: Callable[[Payload], Awaitable[None]] | None = None, + ): self._f = f + self._visit_system_nexus_envelope = visit_system_nexus_envelope async def visit_payload(self, payload: Payload) -> None: new_payload = (await self._f([payload]))[0] if new_payload is not payload: payload.CopyFrom(new_payload) - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + async def visit_payloads(self, payloads: PayloadSequence) -> None: if len(payloads) == 0: return new_payloads = await self._f(payloads) @@ -298,6 +303,10 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: del payloads[:] payloads.extend(new_payloads) + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + if self._visit_system_nexus_envelope is not None: + await self._visit_system_nexus_envelope(payload) + async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, @@ -339,10 +348,20 @@ async def encode_completion( Returns: Metrics from any external storage store operations that occurred. """ + + async def _validate_system_nexus_envelope(payload: Payload) -> None: + data_converter._validate_payload_limits([payload]) + await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers, - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + ).visit( + _Visitor( + data_converter._encode_payload_sequence, + visit_system_nexus_envelope=_validate_system_nexus_envelope, + ), + completion, + ) async def _store_and_validate( payloads: Sequence[Payload], @@ -357,6 +376,12 @@ async def _store_and_validate( skip_search_attributes=True, skip_headers=not encode_headers, concurrency_limit=storage_concurrency_limit, - ).visit(_Visitor(_store_and_validate), completion) + ).visit( + _Visitor( + _store_and_validate, + visit_system_nexus_envelope=_validate_system_nexus_envelope, + ), + completion, + ) return metrics diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py new file mode 100644 index 000000000..0187e0291 --- /dev/null +++ b/temporalio/nexus/system/__init__.py @@ -0,0 +1,74 @@ +"""System Nexus operation helpers.""" + +from __future__ import annotations + +import typing + +import google.protobuf.message +import nexusrpc + +import temporalio.api.common.v1 +import temporalio.converter +from temporalio.bridge._visitor_functions import VisitorFunctions +from temporalio.converter import BinaryProtoPayloadConverter, CompositePayloadConverter +from temporalio.nexus.system import workflow_service + + +class SystemNexusPayloadConverter(CompositePayloadConverter): + """Payload converter for system Nexus outer envelopes.""" + + def __init__(self) -> None: + """Create a payload converter for system Nexus outer envelopes.""" + super().__init__(BinaryProtoPayloadConverter()) + + +def _operation( + service: str, operation: str +) -> nexusrpc.Operation[typing.Any, typing.Any] | None: + return workflow_service.__nexus_operation_registry__.get((service, operation)) + + +async def visit_payload( + service: str, + operation: str, + payload: temporalio.api.common.v1.Payload, + visitor_functions: VisitorFunctions, + skip_search_attributes: bool, +) -> temporalio.api.common.v1.Payload | None: + """Visit nested payloads inside a recognized system Nexus envelope.""" + operation_def = _operation(service, operation) + if operation_def is None: + return None + input_type = operation_def.input_type + if not ( + isinstance(input_type, type) + and issubclass(input_type, google.protobuf.message.Message) + ): + return None + + payload_converter = get_payload_converter() + value = payload_converter.from_payload(payload, input_type) + from ._payload_visitor import PayloadVisitor + + await PayloadVisitor(skip_search_attributes=skip_search_attributes).visit( + visitor_functions, value + ) + return payload_converter.to_payload(value) + + +def is_system_operation(service: str, operation: str) -> bool: + """Return whether a Nexus operation uses a generated system envelope.""" + return _operation(service, operation) is not None + + +def get_payload_converter() -> temporalio.converter.PayloadConverter: + """Return the fixed payload converter for system Nexus outer envelopes.""" + return SystemNexusPayloadConverter() + + +__all__ = [ + "get_payload_converter", + "is_system_operation", + "SystemNexusPayloadConverter", + "visit_payload", +] diff --git a/temporalio/nexus/system/_generation_support/temporal_model_converters.py b/temporalio/nexus/system/_generation_support/temporal_model_converters.py new file mode 100644 index 000000000..58b51e263 --- /dev/null +++ b/temporalio/nexus/system/_generation_support/temporal_model_converters.py @@ -0,0 +1,195 @@ +import collections.abc +import typing +from datetime import timedelta + +import google.protobuf.duration_pb2 + +import temporalio.api.common.v1.message_pb2 as common_pb2 +import temporalio.api.enums.v1.workflow_pb2 as workflow_enums_pb2 +import temporalio.api.taskqueue.v1.message_pb2 as taskqueue_pb2 +import temporalio.api.workflow.v1 +import temporalio.common +import temporalio.converter + + +def retry_policy_from_proto( + proto: common_pb2.RetryPolicy, +) -> temporalio.common.RetryPolicy: + return temporalio.common.RetryPolicy.from_proto(proto) + + +def retry_policy_to_proto( + retry_policy: temporalio.common.RetryPolicy, +) -> common_pb2.RetryPolicy: + proto = common_pb2.RetryPolicy() + retry_policy.apply_to_proto(proto) + return proto + + +def workflow_function_name( + value: str | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> str: + from temporalio.workflow import _Definition # pyright: ignore[reportPrivateUsage] + + name, _result_type = _Definition.get_name_and_result_type(value) + return name + + +def signal_function_to_proto( + value: str | collections.abc.Callable[..., typing.Any], +) -> str: + from temporalio.workflow import ( + _SignalDefinition, # pyright: ignore[reportPrivateUsage] + ) + + return _SignalDefinition.must_name_from_fn_or_str(value) # pyright: ignore[reportUnknownMemberType] + + +def workflow_type_to_proto( + workflow_type: str + | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> common_pb2.WorkflowType: + return common_pb2.WorkflowType(name=workflow_function_name(workflow_type)) + + +def task_queue_from_proto( + proto: taskqueue_pb2.TaskQueue, +) -> str: + return proto.name + + +def task_queue_to_proto( + task_queue: str, +) -> taskqueue_pb2.TaskQueue: + return taskqueue_pb2.TaskQueue(name=task_queue) + + +def workflow_namespace() -> str: + from temporalio.workflow import info + + return info().namespace + + +def payloads_to_proto( + values: collections.abc.Sequence[typing.Any], +) -> common_pb2.Payloads: + from temporalio.workflow import payload_converter + + return payload_converter().to_payloads_wrapper(values) + + +def _clone_payload(payload: common_pb2.Payload) -> common_pb2.Payload: + clone = common_pb2.Payload() + clone.CopyFrom(payload) + return clone + + +def _value_to_payload(value: object | common_pb2.Payload) -> common_pb2.Payload: + if isinstance(value, common_pb2.Payload): + return _clone_payload(value) + from temporalio.workflow import payload_converter + + payloads = payload_converter().to_payloads_wrapper([value]) + return _clone_payload(payloads.payloads[0]) + + +def _payload_to_value(payload: common_pb2.Payload) -> object: + wrapper = common_pb2.Payloads() + wrapper.payloads.add().CopyFrom(payload) + from temporalio.workflow import payload_converter + + return typing.cast( + object, + payload_converter().from_payloads_wrapper(wrapper)[0], + ) + + +def payload_from_proto( + proto: common_pb2.Payload, +) -> object: + return _payload_to_value(proto) + + +def payload_to_proto( + payload: object, +) -> common_pb2.Payload: + return _value_to_payload(payload) + + +def memo_from_proto( + proto: common_pb2.Memo, +) -> collections.abc.Mapping[str, object]: + return {key: _payload_to_value(value) for key, value in proto.fields.items()} + + +def memo_to_proto( + memo: collections.abc.Mapping[str, object], +) -> common_pb2.Memo: + message = common_pb2.Memo() + for key, value in memo.items(): + message.fields[key].CopyFrom(_value_to_payload(value)) + return message + + +def duration_from_proto(proto: google.protobuf.duration_pb2.Duration) -> timedelta: + return proto.ToTimedelta() + + +def duration_to_proto( + duration: timedelta, +) -> google.protobuf.duration_pb2.Duration: + proto = google.protobuf.duration_pb2.Duration() + proto.FromTimedelta(duration) + return proto + + +def workflow_id_reuse_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, +) -> temporalio.common.WorkflowIDReusePolicy: + return temporalio.common.WorkflowIDReusePolicy(int(policy)) + + +def workflow_id_reuse_policy_to_proto( + policy: temporalio.common.WorkflowIDReusePolicy, +) -> workflow_enums_pb2.WorkflowIdReusePolicy.ValueType: + return typing.cast(workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, int(policy)) + + +def workflow_id_conflict_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, +) -> temporalio.common.WorkflowIDConflictPolicy: + return temporalio.common.WorkflowIDConflictPolicy(int(policy)) + + +def workflow_id_conflict_policy_to_proto( + policy: temporalio.common.WorkflowIDConflictPolicy, +) -> workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType: + return typing.cast( + workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, int(policy) + ) + + +def search_attributes_to_proto( + search_attributes: temporalio.common.TypedSearchAttributes, +) -> common_pb2.SearchAttributes: + proto = common_pb2.SearchAttributes() + temporalio.converter.encode_search_attributes(search_attributes, proto) + return proto + + +def priority_from_proto( + proto: common_pb2.Priority, +) -> temporalio.common.Priority: + return temporalio.common.Priority._from_proto(proto) # pyright: ignore[reportPrivateUsage] + + +def priority_to_proto( + priority: temporalio.common.Priority, +) -> common_pb2.Priority: + return priority._to_proto() # pyright: ignore[reportPrivateUsage] + + +def versioning_override_to_proto( + versioning_override: temporalio.common.VersioningOverride, +) -> temporalio.api.workflow.v1.VersioningOverride: + return versioning_override._to_proto() # pyright: ignore[reportPrivateUsage] diff --git a/temporalio/nexus/system/_payload_visitor.py b/temporalio/nexus/system/_payload_visitor.py new file mode 100644 index 000000000..ecc51e2c4 --- /dev/null +++ b/temporalio/nexus/system/_payload_visitor.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +# This file is generated by gen_payload_visitor.py. Changes should be made there. +import asyncio +from typing import Any + +import temporalio.nexus.system +from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.bridge._visitor_functions import PayloadSequence, VisitorFunctions + + +class _BoundedVisitorFunctions(VisitorFunctions): + """Wraps VisitorFunctions to cap concurrent payload visits via a semaphore. + + After the full traversal, call drain() to await all in-flight tasks. + """ + + def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None: + self._inner = inner + self._sem = sem + self._tasks: list[asyncio.Task[None]] = [] + + async def visit_payload(self, payload: Payload) -> None: + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payload(payload) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) + + async def visit_payloads(self, payloads: PayloadSequence) -> None: + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payloads(payloads) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) + + async def visit_system_nexus_envelope(self, payload: Payload) -> None: + await self._inner.visit_system_nexus_envelope(payload) + + async def drain(self) -> None: + """Wait for all in-flight background tasks to complete. + + On cancellation or error, cancels all remaining tasks and awaits + them so their finally blocks run before this coroutine returns. + """ + if not self._tasks: + return + try: + await asyncio.gather(*self._tasks) + except BaseException: + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) + raise + + +class PayloadVisitor: + """A visitor for payloads. + Applies a function to every payload in a tree of messages. + """ + + def __init__( + self, + *, + skip_search_attributes: bool = False, + skip_headers: bool = False, + concurrency_limit: int = 1, + ): + """Creates a new payload visitor. + + Args: + skip_search_attributes: If True, search attributes are not visited. + skip_headers: If True, headers are not visited. + concurrency_limit: Maximum number of payload visits that may run + concurrently during a single call to visit(). Defaults to 1 + (sequential). + """ + if concurrency_limit < 1: + raise ValueError("concurrency_limit must be positive") + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + self._concurrency_limit = concurrency_limit + + async def visit(self, fs: VisitorFunctions, root: Any) -> None: + """Visits the given root message with the given function.""" + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is None: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + if self._concurrency_limit == 1: + await method(fs, root) + return + + bounded = _BoundedVisitorFunctions( + fs, asyncio.Semaphore(self._concurrency_limit) + ) + try: + await method(bounded, root) + finally: + await bounded.drain() + + async def _visit_system_nexus_payload( + self, + fs: VisitorFunctions, + service: str, + operation: str, + payload: Payload, + ) -> None: + new_payload = await temporalio.nexus.system.visit_payload( + service, + operation, + payload, + fs, + self.skip_search_attributes, + ) + if new_payload is None: + await self._visit_temporal_api_common_v1_Payload(fs, payload) + return + + if new_payload is not payload: + payload.CopyFrom(new_payload) + await fs.visit_system_nexus_envelope(payload) + + async def _visit_temporal_api_common_v1_Payload( + self, fs: VisitorFunctions, o: Payload + ): + await fs.visit_payload(o) + + async def _visit_temporal_api_common_v1_Payloads( + self, fs: VisitorFunctions, o: Any + ): + await fs.visit_payloads(o.payloads) + + async def _visit_payload_container(self, fs: VisitorFunctions, o: PayloadSequence): + await fs.visit_payloads(o) + + async def _visit_temporal_api_common_v1_Memo(self, fs: VisitorFunctions, o: Any): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_SearchAttributes( + self, fs: VisitorFunctions, o: Any + ): + if self.skip_search_attributes: + return + for v in o.indexed_fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_Header(self, fs: VisitorFunctions, o: Any): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_sdk_v1_UserMetadata( + self, fs: VisitorFunctions, o: Any + ): + if o.HasField("summary"): + await self._visit_temporal_api_common_v1_Payload(fs, o.summary) + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payload(fs, o.details) + + async def _visit_temporal_api_workflowservice_v1_SignalWithStartWorkflowExecutionRequest( + self, fs: VisitorFunctions, o: Any + ): + if o.HasField("input"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.input) + if o.HasField("signal_input"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.signal_input) + if o.HasField("memo"): + await self._visit_temporal_api_common_v1_Memo(fs, o.memo) + if o.HasField("search_attributes"): + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) + if o.HasField("header"): + await self._visit_temporal_api_common_v1_Header(fs, o.header) + if o.HasField("user_metadata"): + await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) diff --git a/temporalio/nexus/system/workflow_service/__init__.py b/temporalio/nexus/system/workflow_service/__init__.py new file mode 100644 index 000000000..7c24fa125 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/__init__.py @@ -0,0 +1,18 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +from . import service as _service +from .operations.signal_with_start_workflow import signal_with_start_workflow + +__all__ = [ + "signal_with_start_workflow", +] + + +__nexus_operation_registry__ = { + ( + "temporal.api.workflowservice.v1.WorkflowService", + "SignalWithStartWorkflowExecution", + ): _service.WorkflowService.signal_with_start_workflow, +} diff --git a/temporalio/nexus/system/workflow_service/_resources/__init__.py b/temporalio/nexus/system/workflow_service/_resources/__init__.py new file mode 100644 index 000000000..373efbd33 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/_resources/__init__.py @@ -0,0 +1,5 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +__all__ = [] diff --git a/temporalio/nexus/system/workflow_service/_support/__init__.py b/temporalio/nexus/system/workflow_service/_support/__init__.py new file mode 100644 index 000000000..166261d16 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/_support/__init__.py @@ -0,0 +1,5 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +from .temporal_model_converters import * # noqa: F401,F403 diff --git a/temporalio/nexus/system/workflow_service/_support/temporal_model_converters.py b/temporalio/nexus/system/workflow_service/_support/temporal_model_converters.py new file mode 100644 index 000000000..58b51e263 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/_support/temporal_model_converters.py @@ -0,0 +1,195 @@ +import collections.abc +import typing +from datetime import timedelta + +import google.protobuf.duration_pb2 + +import temporalio.api.common.v1.message_pb2 as common_pb2 +import temporalio.api.enums.v1.workflow_pb2 as workflow_enums_pb2 +import temporalio.api.taskqueue.v1.message_pb2 as taskqueue_pb2 +import temporalio.api.workflow.v1 +import temporalio.common +import temporalio.converter + + +def retry_policy_from_proto( + proto: common_pb2.RetryPolicy, +) -> temporalio.common.RetryPolicy: + return temporalio.common.RetryPolicy.from_proto(proto) + + +def retry_policy_to_proto( + retry_policy: temporalio.common.RetryPolicy, +) -> common_pb2.RetryPolicy: + proto = common_pb2.RetryPolicy() + retry_policy.apply_to_proto(proto) + return proto + + +def workflow_function_name( + value: str | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> str: + from temporalio.workflow import _Definition # pyright: ignore[reportPrivateUsage] + + name, _result_type = _Definition.get_name_and_result_type(value) + return name + + +def signal_function_to_proto( + value: str | collections.abc.Callable[..., typing.Any], +) -> str: + from temporalio.workflow import ( + _SignalDefinition, # pyright: ignore[reportPrivateUsage] + ) + + return _SignalDefinition.must_name_from_fn_or_str(value) # pyright: ignore[reportUnknownMemberType] + + +def workflow_type_to_proto( + workflow_type: str + | collections.abc.Callable[..., collections.abc.Awaitable[object]], +) -> common_pb2.WorkflowType: + return common_pb2.WorkflowType(name=workflow_function_name(workflow_type)) + + +def task_queue_from_proto( + proto: taskqueue_pb2.TaskQueue, +) -> str: + return proto.name + + +def task_queue_to_proto( + task_queue: str, +) -> taskqueue_pb2.TaskQueue: + return taskqueue_pb2.TaskQueue(name=task_queue) + + +def workflow_namespace() -> str: + from temporalio.workflow import info + + return info().namespace + + +def payloads_to_proto( + values: collections.abc.Sequence[typing.Any], +) -> common_pb2.Payloads: + from temporalio.workflow import payload_converter + + return payload_converter().to_payloads_wrapper(values) + + +def _clone_payload(payload: common_pb2.Payload) -> common_pb2.Payload: + clone = common_pb2.Payload() + clone.CopyFrom(payload) + return clone + + +def _value_to_payload(value: object | common_pb2.Payload) -> common_pb2.Payload: + if isinstance(value, common_pb2.Payload): + return _clone_payload(value) + from temporalio.workflow import payload_converter + + payloads = payload_converter().to_payloads_wrapper([value]) + return _clone_payload(payloads.payloads[0]) + + +def _payload_to_value(payload: common_pb2.Payload) -> object: + wrapper = common_pb2.Payloads() + wrapper.payloads.add().CopyFrom(payload) + from temporalio.workflow import payload_converter + + return typing.cast( + object, + payload_converter().from_payloads_wrapper(wrapper)[0], + ) + + +def payload_from_proto( + proto: common_pb2.Payload, +) -> object: + return _payload_to_value(proto) + + +def payload_to_proto( + payload: object, +) -> common_pb2.Payload: + return _value_to_payload(payload) + + +def memo_from_proto( + proto: common_pb2.Memo, +) -> collections.abc.Mapping[str, object]: + return {key: _payload_to_value(value) for key, value in proto.fields.items()} + + +def memo_to_proto( + memo: collections.abc.Mapping[str, object], +) -> common_pb2.Memo: + message = common_pb2.Memo() + for key, value in memo.items(): + message.fields[key].CopyFrom(_value_to_payload(value)) + return message + + +def duration_from_proto(proto: google.protobuf.duration_pb2.Duration) -> timedelta: + return proto.ToTimedelta() + + +def duration_to_proto( + duration: timedelta, +) -> google.protobuf.duration_pb2.Duration: + proto = google.protobuf.duration_pb2.Duration() + proto.FromTimedelta(duration) + return proto + + +def workflow_id_reuse_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, +) -> temporalio.common.WorkflowIDReusePolicy: + return temporalio.common.WorkflowIDReusePolicy(int(policy)) + + +def workflow_id_reuse_policy_to_proto( + policy: temporalio.common.WorkflowIDReusePolicy, +) -> workflow_enums_pb2.WorkflowIdReusePolicy.ValueType: + return typing.cast(workflow_enums_pb2.WorkflowIdReusePolicy.ValueType, int(policy)) + + +def workflow_id_conflict_policy_from_proto( + policy: workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, +) -> temporalio.common.WorkflowIDConflictPolicy: + return temporalio.common.WorkflowIDConflictPolicy(int(policy)) + + +def workflow_id_conflict_policy_to_proto( + policy: temporalio.common.WorkflowIDConflictPolicy, +) -> workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType: + return typing.cast( + workflow_enums_pb2.WorkflowIdConflictPolicy.ValueType, int(policy) + ) + + +def search_attributes_to_proto( + search_attributes: temporalio.common.TypedSearchAttributes, +) -> common_pb2.SearchAttributes: + proto = common_pb2.SearchAttributes() + temporalio.converter.encode_search_attributes(search_attributes, proto) + return proto + + +def priority_from_proto( + proto: common_pb2.Priority, +) -> temporalio.common.Priority: + return temporalio.common.Priority._from_proto(proto) # pyright: ignore[reportPrivateUsage] + + +def priority_to_proto( + priority: temporalio.common.Priority, +) -> common_pb2.Priority: + return priority._to_proto() # pyright: ignore[reportPrivateUsage] + + +def versioning_override_to_proto( + versioning_override: temporalio.common.VersioningOverride, +) -> temporalio.api.workflow.v1.VersioningOverride: + return versioning_override._to_proto() # pyright: ignore[reportPrivateUsage] diff --git a/temporalio/nexus/system/workflow_service/models.py b/temporalio/nexus/system/workflow_service/models.py new file mode 100644 index 000000000..88d77c35b --- /dev/null +++ b/temporalio/nexus/system/workflow_service/models.py @@ -0,0 +1,141 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +import collections.abc +import dataclasses +import datetime +import typing + +import temporalio.api.sdk.v1.user_metadata_pb2 +import temporalio.api.workflowservice.v1.request_response_pb2 +import temporalio.common + +from ._support import ( + duration_to_proto, + memo_to_proto, + payload_from_proto, + payload_to_proto, + payloads_to_proto, + priority_to_proto, + retry_policy_to_proto, + search_attributes_to_proto, + signal_function_to_proto, + task_queue_to_proto, + versioning_override_to_proto, + workflow_id_conflict_policy_to_proto, + workflow_id_reuse_policy_to_proto, + workflow_namespace, + workflow_type_to_proto, +) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class SignalWithStartWorkflowRequest: + """ + .. warning:: + This API is experimental and subject to change. + """ + + workflow: str | collections.abc.Callable[..., collections.abc.Awaitable[typing.Any]] + args: list[typing.Any] | None = None + id: str + task_queue: str + signal: str | collections.abc.Callable[..., None | collections.abc.Awaitable[None]] + signal_args: list[typing.Any] | None = None + execution_timeout: datetime.timedelta | None = None + run_timeout: datetime.timedelta | None = None + task_timeout: datetime.timedelta | None = None + request_id: str | None = None + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ( + temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE + ) + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = None + retry_policy: temporalio.common.RetryPolicy | None = None + cron_schedule: str | None = None + memo: collections.abc.Mapping[str, typing.Any] | None = None + search_attributes: temporalio.common.TypedSearchAttributes | None = None + priority: temporalio.common.Priority | None = None + versioning_override: temporalio.common.VersioningOverride | None = None + start_delay: datetime.timedelta | None = None + user_metadata: UserMetadata | None = None + + def to_proto( + self, + ) -> temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest: + message = temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest() + message.workflow_type.CopyFrom(workflow_type_to_proto(self.workflow)) + if self.args is not None: + message.input.CopyFrom(payloads_to_proto(self.args)) + message.workflow_id = self.id + message.task_queue.CopyFrom(task_queue_to_proto(self.task_queue)) + message.signal_name = signal_function_to_proto(self.signal) + if self.signal_args is not None: + message.signal_input.CopyFrom(payloads_to_proto(self.signal_args)) + if self.execution_timeout is not None: + message.workflow_execution_timeout.CopyFrom( + duration_to_proto(self.execution_timeout) + ) + if self.run_timeout is not None: + message.workflow_run_timeout.CopyFrom(duration_to_proto(self.run_timeout)) + if self.task_timeout is not None: + message.workflow_task_timeout.CopyFrom(duration_to_proto(self.task_timeout)) + if self.request_id is not None: + message.request_id = self.request_id + message.workflow_id_reuse_policy = workflow_id_reuse_policy_to_proto( + self.id_reuse_policy + ) + if self.id_conflict_policy is not None: + message.workflow_id_conflict_policy = workflow_id_conflict_policy_to_proto( + self.id_conflict_policy + ) + if self.retry_policy is not None: + message.retry_policy.CopyFrom(retry_policy_to_proto(self.retry_policy)) + if self.cron_schedule is not None: + message.cron_schedule = self.cron_schedule + if self.memo is not None: + message.memo.CopyFrom(memo_to_proto(self.memo)) + if self.search_attributes is not None: + message.search_attributes.CopyFrom( + search_attributes_to_proto(self.search_attributes) + ) + if self.priority is not None: + message.priority.CopyFrom(priority_to_proto(self.priority)) + if self.versioning_override is not None: + message.versioning_override.CopyFrom( + versioning_override_to_proto(self.versioning_override) + ) + if self.start_delay is not None: + message.workflow_start_delay.CopyFrom(duration_to_proto(self.start_delay)) + if self.user_metadata is not None: + message.user_metadata.CopyFrom(self.user_metadata.to_proto()) + message.namespace = workflow_namespace() + return message + + +@dataclasses.dataclass(slots=True) +class UserMetadata: + static_summary: typing.Any | None = None + static_details: typing.Any | None = None + + @classmethod + def from_proto( + cls, + proto: temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata, + ) -> UserMetadata: + return cls( + static_summary=payload_from_proto(proto.summary) + if proto.HasField("summary") + else None, + static_details=payload_from_proto(proto.details) + if proto.HasField("details") + else None, + ) + + def to_proto(self) -> temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata: + message = temporalio.api.sdk.v1.user_metadata_pb2.UserMetadata() + if self.static_summary is not None: + message.summary.CopyFrom(payload_to_proto(self.static_summary)) + if self.static_details is not None: + message.details.CopyFrom(payload_to_proto(self.static_details)) + return message diff --git a/temporalio/nexus/system/workflow_service/operations/__init__.py b/temporalio/nexus/system/workflow_service/operations/__init__.py new file mode 100644 index 000000000..67c9cc56b --- /dev/null +++ b/temporalio/nexus/system/workflow_service/operations/__init__.py @@ -0,0 +1,3 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations diff --git a/temporalio/nexus/system/workflow_service/operations/signal_with_start_workflow.py b/temporalio/nexus/system/workflow_service/operations/signal_with_start_workflow.py new file mode 100644 index 000000000..ded496898 --- /dev/null +++ b/temporalio/nexus/system/workflow_service/operations/signal_with_start_workflow.py @@ -0,0 +1,507 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +import collections.abc +import datetime +import typing + +import typing_extensions + +import temporalio.api.workflowservice.v1.request_response_pb2 +import temporalio.common + +if typing.TYPE_CHECKING: + from temporalio.workflow import ExternalWorkflowHandle + +from ..models import ( + SignalWithStartWorkflowRequest, + UserMetadata, +) + +SignalArg = typing.TypeVar("SignalArg") +WorkflowResult = typing.TypeVar("WorkflowResult") +WorkflowArgs = typing_extensions.TypeVarTuple("WorkflowArgs") + + +async def _signal_with_start_workflow( + request: SignalWithStartWorkflowRequest, +) -> ExternalWorkflowHandle[typing.Any]: + from temporalio.workflow import ( + create_nexus_client, + get_external_workflow_handle, + ) + + request_proto = request.to_proto() + nexus_client = create_nexus_client( + service="temporal.api.workflowservice.v1.WorkflowService", + endpoint="__temporal_system", + ) + handle = await nexus_client.start_operation( + operation="SignalWithStartWorkflowExecution", + input=request_proto, + output_type=temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionResponse, + ) + result = await handle + return get_external_workflow_handle(request.id, run_id=result.run_id) + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: str, + signal_args: list[typing.Any] | None = ..., + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: str, + signal_args: list[typing.Any] | None = ..., + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: str, + signal_args: list[typing.Any] | None = ..., + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any], None | collections.abc.Awaitable[None] + ], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any], None | collections.abc.Awaitable[None] + ], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any], None | collections.abc.Awaitable[None] + ], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any, SignalArg], None | collections.abc.Awaitable[None] + ], + signal_args: SignalArg, + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any, SignalArg], None | collections.abc.Awaitable[None] + ], + signal_args: SignalArg, + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: collections.abc.Callable[ + [typing.Any, SignalArg], None | collections.abc.Awaitable[None] + ], + signal_args: SignalArg, + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: str, + *positional_args: object, + args: list[typing.Any] | None = ..., + id: str, + task_queue: str, + signal: collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: list[typing.Any], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[typing.Any]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[ + [typing.Any, typing_extensions.Unpack[WorkflowArgs]], + collections.abc.Awaitable[WorkflowResult], + ], + *positional_args: typing_extensions.Unpack[WorkflowArgs], + id: str, + task_queue: str, + signal: collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: list[typing.Any], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +@typing.overload +async def signal_with_start_workflow( + workflow: collections.abc.Callable[..., collections.abc.Awaitable[WorkflowResult]], + *, + args: list[typing.Any], + id: str, + task_queue: str, + signal: collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: list[typing.Any], + execution_timeout: datetime.timedelta | None = ..., + run_timeout: datetime.timedelta | None = ..., + task_timeout: datetime.timedelta | None = ..., + request_id: str | None = ..., + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ..., + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = ..., + retry_policy: temporalio.common.RetryPolicy | None = ..., + cron_schedule: str | None = ..., + memo: collections.abc.Mapping[str, typing.Any] | None = ..., + search_attributes: temporalio.common.TypedSearchAttributes | None = ..., + priority: temporalio.common.Priority | None = ..., + versioning_override: temporalio.common.VersioningOverride | None = ..., + start_delay: datetime.timedelta | None = ..., + static_summary: str | None = ..., + static_details: str | None = ..., +) -> ExternalWorkflowHandle[WorkflowResult]: ... + + +async def signal_with_start_workflow( + workflow: str + | collections.abc.Callable[..., collections.abc.Awaitable[typing.Any]], + *positional_args: object, + args: list[typing.Any] | None = None, + id: str, + task_queue: str, + signal: str | collections.abc.Callable[..., None | collections.abc.Awaitable[None]], + signal_args: object | list[typing.Any] | None = None, + execution_timeout: datetime.timedelta | None = None, + run_timeout: datetime.timedelta | None = None, + task_timeout: datetime.timedelta | None = None, + request_id: str | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = ( + temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE + ), + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy | None = None, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str | None = None, + memo: collections.abc.Mapping[str, typing.Any] | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + priority: temporalio.common.Priority | None = None, + versioning_override: temporalio.common.VersioningOverride | None = None, + start_delay: datetime.timedelta | None = None, + static_summary: str | None = None, + static_details: str | None = None, +) -> ExternalWorkflowHandle[typing.Any]: + """Signal a workflow, starting it first if needed. + + .. warning:: + This API is experimental and subject to change. + + Args: + workflow: Workflow type name or callable identifying the workflow to start. + positional_args: Positional arguments for workflow. Cannot be set if args is + set. + args: List-form arguments for workflow. Cannot be set if positional_args are + set. For typed workflow callables, list contents are not statically + typechecked; pass workflow arguments positionally for precise typechecking. + id: Unique identifier for the workflow execution. + task_queue: Task queue to run the workflow on. + signal: Signal name or callable to send with the start request. + signal_args: Argument value, or list of argument values, for signal. For typed + single-argument signals, scalar signal_args values are statically + typechecked. List-form signal_args values are not precisely typechecked. To + pass a single signal argument that is itself a list, wrap it in another + list; otherwise the list is interpreted as multiple signal arguments. + execution_timeout: Total workflow execution timeout, including retries and + continue-as-new. + run_timeout: Timeout of a single workflow run. + task_timeout: Timeout of a single workflow task. + request_id: Request ID used to deduplicate workflow start requests. + id_reuse_policy: Behavior when a closed workflow with the same ID exists. + Default is allow-duplicate. + id_conflict_policy: Behavior when a workflow is currently running with the same + ID. Set to use-existing for idempotent deduplication on workflow ID. Cannot + be set if id-reuse-policy is terminate-if-running. + retry_policy: Retry policy for the workflow. + cron_schedule: Cron schedule for recurring workflow executions. See + https://docs.temporal.io/cron-job. + memo: Memo for the workflow. + search_attributes: Typed search attributes for the workflow. + priority: Priority of the workflow execution. + versioning_override: Override for workflow versioning behavior. + start_delay: Amount of time to wait before starting the workflow. This does not + work with cron-schedule. + static_summary: Single-line fixed summary for the workflow execution that may + appear in UI and CLI. This can be in single-line Temporal Markdown format. + static_details: General fixed details for the workflow execution that may appear + in UI and CLI. This can be in Temporal Markdown format and can span multiple + lines. This value is fixed on the workflow execution and cannot be updated. + + Returns: + A workflow handle to the started workflow. + """ + normalized_signal_args: list[typing.Any] | None + if signal_args is None: + normalized_signal_args = None + elif isinstance(signal_args, list): + normalized_signal_args = typing.cast(list[typing.Any], signal_args) + else: + normalized_signal_args = [signal_args] + if positional_args and args is not None: + raise TypeError("cannot specify both positional arguments and args") + normalized_args: list[typing.Any] | None = ( + list(positional_args) if positional_args else args + ) + user_metadata = ( + None + if static_summary is None and static_details is None + else UserMetadata( + static_summary=static_summary, + static_details=static_details, + ) + ) + request = SignalWithStartWorkflowRequest( + workflow=workflow, + args=normalized_args, + id=id, + task_queue=task_queue, + signal=signal, + signal_args=normalized_signal_args, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + request_id=request_id, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + priority=priority, + versioning_override=versioning_override, + start_delay=start_delay, + user_metadata=user_metadata, + ) + return await _signal_with_start_workflow(request) diff --git a/temporalio/nexus/system/workflow_service/service.py b/temporalio/nexus/system/workflow_service/service.py new file mode 100644 index 000000000..7ce5849ca --- /dev/null +++ b/temporalio/nexus/system/workflow_service/service.py @@ -0,0 +1,21 @@ +# Generated by nex-gen. DO NOT EDIT! + +from __future__ import annotations + +from nexusrpc import Operation, service + +import temporalio.api.workflowservice.v1.request_response_pb2 + + +@service(name="temporal.api.workflowservice.v1.WorkflowService") +class WorkflowService: + """ + .. warning:: + This API is experimental and subject to change. + """ + + # .. warning:: This API is experimental and subject to change. + signal_with_start_workflow: Operation[ + temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionRequest, + temporalio.api.workflowservice.v1.request_response_pb2.SignalWithStartWorkflowExecutionResponse, + ] = Operation(name="SignalWithStartWorkflowExecution") diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index f77bea042..500fc4db5 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -6,7 +6,8 @@ from dataclasses import dataclass from temporalio.api.enums.v1.command_type_pb2 import CommandType -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge._visitor_functions import VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( ResolveActivity, ResolveChildWorkflowExecution, diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 76ccdb2e3..deefb5ad3 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.system import temporalio.workflow from temporalio.converter import StorageDriverStoreContext, StorageDriverWorkflowInfo from temporalio.service import __version__ @@ -2085,8 +2086,19 @@ async def operation_handle_fn() -> OutputT: ): t.uncancel() # type: ignore[union-attr] + payload_converter = ( + temporalio.nexus.system.get_payload_converter() + if temporalio.nexus.system.is_system_operation( + input.service, input.operation_name + ) + else self._context_free_payload_converter + ) handle = _NexusOperationHandle( - self, self._next_seq("nexus_operation"), input, operation_handle_fn() + self, + self._next_seq("nexus_operation"), + input, + operation_handle_fn(), + payload_converter, ) handle._apply_schedule_command() self._pending_nexus_operations[handle._seq] = handle @@ -3453,6 +3465,7 @@ def __init__( seq: int, input: StartNexusOperationInput[Any, OutputT], fn: Coroutine[Any, Any, OutputT], + payload_converter: temporalio.converter.PayloadConverter, ): self._instance = instance self._seq = seq @@ -3460,7 +3473,7 @@ def __init__( self._task = asyncio.Task(fn) self._start_fut: asyncio.Future[str | None] = instance.create_future() self._result_fut: asyncio.Future[OutputT | None] = instance.create_future() - self._payload_converter = self._instance._context_free_payload_converter + self._payload_converter = payload_converter self._failure_converter = self._instance._context_free_failure_converter @property diff --git a/temporalio/workflow/__init__.py b/temporalio/workflow/__init__.py index 8b8b0fb6f..fedc31b10 100644 --- a/temporalio/workflow/__init__.py +++ b/temporalio/workflow/__init__.py @@ -1,5 +1,7 @@ """Utilities that can decorate or be called inside workflows.""" +# ruff: noqa: I001 + from __future__ import annotations from ..types import ( @@ -167,6 +169,12 @@ start_child_workflow, ) +# BEGIN GENERATED NEXUS SYSTEM EXPORTS +from temporalio.nexus.system.workflow_service import ( + signal_with_start_workflow, +) +# END GENERATED NEXUS SYSTEM EXPORTS + __all__ = [ "ActivityCancellationType", "ActivityConfig", @@ -314,4 +322,7 @@ "ProtocolReturnType", "ReturnType", "SelfType", + # BEGIN GENERATED NEXUS SYSTEM __ALL__ + "signal_with_start_workflow", + # END GENERATED NEXUS SYSTEM __ALL__ ] diff --git a/tests/__init__.py b/tests/__init__.py index d62129b39..4725d3a7e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -DEV_SERVER_DOWNLOAD_VERSION = "v1.7.1-standalone-nexus-operations" +DEV_SERVER_DOWNLOAD_VERSION = "v1.7.1-system-nexus-operations" diff --git a/tests/conftest.py b/tests/conftest.py index 1e1db3730..9eaa1ff47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,6 +136,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "nexusoperation.enableStandalone=true", "--dynamic-config-value", 'system.system.refreshNexusEndpointsMinWait="0s"', + "--dynamic-config-value", + "history.enableSignalWithStartFromWorkflow=true", ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py new file mode 100644 index 000000000..c7d9319ca --- /dev/null +++ b/tests/nexus/test_temporal_system_nexus.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import dataclasses +import uuid +from collections.abc import Sequence +from datetime import timedelta +from typing import Any, cast + +import pytest +from google.protobuf.descriptor import FieldDescriptor +from google.protobuf.message import Message + +import temporalio.api.common.v1 +import temporalio.api.workflowservice.v1.request_response_pb2 as workflowservice_pb2 +import temporalio.converter +import temporalio.nexus.system as nexus_system +from temporalio import workflow +from temporalio.client import Client +from temporalio.converter import ExternalStorage, PayloadCodec +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import ( + Interceptor, + StartNexusOperationInput, + Worker, + WorkflowInboundInterceptor, + WorkflowInterceptorClassInput, + WorkflowOutboundInterceptor, +) +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner +from tests.test_extstore import InMemoryTestDriver + +interceptor_traces: list[tuple[str, object]] = [] + + +@workflow.defn +class ExternalHandleSignalWithStartWorkflowCaller: + @workflow.run + async def run(self, task_queue: str) -> str: + started_handle = await workflow.signal_with_start_workflow( + "test-workflow", + "workflow-input", + id="system-nexus-workflow-id", + task_queue=task_queue, + signal="test-signal", + signal_args=["signal-input"], + memo={"memo-key": "memo-value"}, + static_summary="summary-value", + static_details="details-value", + ) + return started_handle.id + + +class RejectOuterSystemNexusCodec(PayloadCodec): + def __init__(self) -> None: + self.encode_count = 0 + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + encoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + if ( + payload.metadata.get("encoding") == b"binary/protobuf" + and payload.metadata.get("messageType") + == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" + ): + raise RuntimeError( + "outer system nexus envelope should not be codec encoded" + ) + self.encode_count += 1 + encoded.append( + temporalio.api.common.v1.Payload( + metadata={**payload.metadata, "test-codec": b"true"}, + data=payload.data, + ) + ) + return encoded + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + decoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + if ( + payload.metadata.get("encoding") == b"binary/protobuf" + and payload.metadata.get("messageType") + == b"temporal.api.workflowservice.v1.SignalWithStartWorkflowExecutionRequest" + ): + raise RuntimeError( + "outer system nexus envelope should not be codec decoded" + ) + decoded.append(payload) + return decoded + + +class TracingWorkflowInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> type[WorkflowInboundInterceptor] | None: + return _TracingWorkflowInboundInterceptor + + +class _TracingWorkflowInboundInterceptor(WorkflowInboundInterceptor): + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + super().init(_TracingWorkflowOutboundInterceptor(outbound)) + + +class _TracingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): + async def start_nexus_operation( + self, input: StartNexusOperationInput[Any, Any] + ) -> workflow.NexusOperationHandle[Any]: + interceptor_traces.append(("workflow.start_nexus_operation", input)) + return await super().start_nexus_operation(input) + + +def _assert_stored_payloads_include( + driver: InMemoryTestDriver, expected_payload_data: set[bytes] +) -> None: + stored_payload_data: set[bytes] = set() + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data.add(stored_payload.data) + assert expected_payload_data.issubset(stored_payload_data) + + +def _assert_start_nexus_operation_interceptor_trace() -> None: + assert len(interceptor_traces) == 1 + trace_name, trace_value = interceptor_traces.pop() + assert trace_name == "workflow.start_nexus_operation" + trace_input = cast(StartNexusOperationInput[Any, Any], trace_value) + request = cast( + workflowservice_pb2.SignalWithStartWorkflowExecutionRequest, + trace_input.input, + ) + assert request.workflow_id == "system-nexus-workflow-id" + assert request.signal_name == "test-signal" + assert request.workflow_type.name == "test-workflow" + + +def _build_proto_sample(message_type: type[Message]) -> Message: + message = message_type() + _populate_proto_sample(message) + return message + + +def _populate_proto_sample(message: Message, *, path: str = "value") -> None: + seen_oneofs: set[str] = set() + for field in message.DESCRIPTOR.fields: + if field.containing_oneof is not None: + if field.containing_oneof.name in seen_oneofs: + continue + seen_oneofs.add(field.containing_oneof.name) + if field.label == FieldDescriptor.LABEL_REPEATED: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): + _populate_proto_map_entry(message, field, path=path) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name).add(), + path=f"{path}.{field.name}[0]", + ) + else: + getattr(message, field.name).append( + _proto_scalar_sample(field, path=f"{path}.{field.name}[0]") + ) + elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + getattr(message, field.name), + path=f"{path}.{field.name}", + ) + else: + setattr( + message, + field.name, + _proto_scalar_sample(field, path=f"{path}.{field.name}"), + ) + + +def _populate_proto_map_entry( + message: Message, + field: FieldDescriptor, + *, + path: str, +) -> None: + key_field = field.message_type.fields_by_name["key"] + value_field = field.message_type.fields_by_name["value"] + key = _proto_scalar_sample(key_field, path=f"{path}.{field.name}.key") + container = getattr(message, field.name) + if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + _populate_proto_sample( + container[key], + path=f"{path}.{field.name}[{key!r}]", + ) + else: + container[key] = _proto_scalar_sample( + value_field, + path=f"{path}.{field.name}[{key!r}]", + ) + + +def _proto_scalar_sample(field: FieldDescriptor, *, path: str) -> Any: + if field.type == FieldDescriptor.TYPE_BYTES: + return b"test" + if field.cpp_type == FieldDescriptor.CPPTYPE_STRING: + return f"{path}-value" + if field.cpp_type == FieldDescriptor.CPPTYPE_BOOL: + return True + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_INT32, + FieldDescriptor.CPPTYPE_INT64, + FieldDescriptor.CPPTYPE_UINT32, + FieldDescriptor.CPPTYPE_UINT64, + ): + return 1 + if field.cpp_type in ( + FieldDescriptor.CPPTYPE_FLOAT, + FieldDescriptor.CPPTYPE_DOUBLE, + ): + return 1.5 + if field.cpp_type == FieldDescriptor.CPPTYPE_ENUM: + for enum_value in field.enum_type.values: + if enum_value.number != 0: + return enum_value.number + return field.enum_type.values[0].number + raise TypeError(f"Unhandled proto scalar sample at {path}: {field!r}") + + +@pytest.mark.parametrize( + "message_type", + [ + workflowservice_pb2.SignalWithStartWorkflowExecutionRequest, + workflowservice_pb2.SignalWithStartWorkflowExecutionResponse, + ], +) +def test_system_nexus_proto_roundtrip(message_type: type[Message]) -> None: + payload_converter = nexus_system.get_payload_converter() + proto_value = _build_proto_sample(message_type) + payload = payload_converter.to_payload(proto_value) + assert payload is not None + assert payload.metadata["encoding"] == b"binary/protobuf" + assert payload.metadata["messageType"] == message_type.DESCRIPTOR.full_name.encode() + roundtripped = payload_converter.from_payload(payload, message_type) + assert isinstance(roundtripped, message_type) + assert roundtripped == proto_value + + +async def test_external_workflow_handle_signal_with_start_workflow_uses_system_nexus( + env: WorkflowEnvironment, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + interceptor_traces.clear() + driver = InMemoryTestDriver() + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), + ) + caller_client = Client(**caller_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) + + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, + workflows=[ExternalHandleSignalWithStartWorkflowCaller], + workflow_runner=UnsandboxedWorkflowRunner(), + interceptors=[TracingWorkflowInterceptor()], + ) + + async with caller_worker: + result = await caller_client.execute_workflow( + ExternalHandleSignalWithStartWorkflowCaller.run, + args=[handler_task_queue], + id=str(uuid.uuid4()), + task_queue=caller_task_queue, + execution_timeout=timedelta(seconds=5), + ) + + assert result == "system-nexus-workflow-id" + assert codec.encode_count >= 5 + _assert_stored_payloads_include( + driver, + { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"summary-value"', + b'"details-value"', + }, + ) + _assert_start_nexus_operation_interceptor_trace() diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 876387393..a815f3135 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -15,7 +15,8 @@ SearchAttributes, ) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata -from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge._visitor_functions import VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, WorkflowActivation,