diff --git a/google/cloud/dataproc_spark_connect/proto/__init__.py b/google/cloud/dataproc_spark_connect/proto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto b/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto new file mode 100644 index 0000000..8e3d920 --- /dev/null +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto @@ -0,0 +1,115 @@ +syntax = "proto3"; + +package spark.connect; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +// SparkMonitor progress data delivered via the upstream extension slot on ExecutePlanResponse +// (google.protobuf.Any extension = 999). +// type_url: "type.googleapis.com/spark.connect.SparkMonitorProgress" +message SparkMonitorProgress { + optional ApplicationInfo application_info = 1; + repeated JobEvent job_events = 2; + repeated DetailedStageEvent stage_events = 3; + repeated TaskEvent task_events = 4; + repeated ExecutorEvent executor_events = 5; + optional bool stream_complete = 6; + + // Application lifecycle info (start_time present = start event, end_time present = end event) + message ApplicationInfo { + optional int64 start_time = 1; + optional int64 end_time = 2; + optional string app_id = 3; + optional string app_attempt_id = 4; + optional string app_name = 5; + optional string spark_user = 6; + } + + // Job events (JOB_START=0, JOB_END=1) + message JobEvent { + enum JobEventType { + JOB_START = 0; + JOB_END = 1; + } + JobEventType event_type = 1; + int64 job_id = 2; + string status = 3; + optional int64 submission_time = 4; + optional int64 completion_time = 5; + optional string job_group = 6; + optional string name = 7; + repeated int32 stage_ids = 8; + map stage_infos = 9; + optional int32 num_tasks = 10; + optional int32 total_cores = 11; + optional string app_id = 12; + optional int32 num_executors = 13; + } + + message JobStageInfo { + int32 attempt_id = 1; + 
string name = 2; + int32 num_tasks = 3; + int64 completion_time = 4; + int64 submission_time = 5; + } + + // Detailed stage events (STAGE_SUBMITTED=0, STAGE_ACTIVE=1, STAGE_COMPLETED=2) + message DetailedStageEvent { + enum StageEventType { + STAGE_SUBMITTED = 0; + STAGE_ACTIVE = 1; + STAGE_COMPLETED = 2; + } + StageEventType event_type = 1; + int64 stage_id = 2; + int32 stage_attempt_id = 3; + string name = 4; + int32 num_tasks = 5; + repeated int32 parent_ids = 6; + optional int64 submission_time = 7; + optional int64 completion_time = 8; + repeated int64 job_ids = 9; + optional int32 num_active_tasks = 10; + optional int32 num_failed_tasks = 11; + optional int32 num_completed_tasks = 12; + optional string status = 13; + } + + // Task events (TASK_START=0, TASK_END=1) + message TaskEvent { + enum TaskEventType { + TASK_START = 0; + TASK_END = 1; + } + TaskEventType event_type = 1; + int64 task_id = 2; + int64 stage_id = 3; + int32 stage_attempt_id = 4; + int32 index = 5; + int32 attempt_number = 6; + string executor_id = 7; + string host = 8; + string status = 9; + bool speculative = 10; + optional int64 launch_time = 11; + optional int64 finish_time = 12; + optional string task_type = 13; + optional string error_message = 14; + } + + // Executor events (EXECUTOR_ADDED=0, EXECUTOR_REMOVED=1) + message ExecutorEvent { + enum ExecutorEventType { + EXECUTOR_ADDED = 0; + EXECUTOR_REMOVED = 1; + } + ExecutorEventType event_type = 1; + string executor_id = 2; + int64 time = 3; + optional string host = 4; + optional int32 num_cores = 5; + optional int32 total_cores = 6; + } +} \ No newline at end of file diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py new file mode 100644 index 0000000..c884088 --- /dev/null +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
def _derive_sparkmonitor_msgtype(
    self, sm: "sparkmonitor_pb2.SparkMonitorProgress"
) -> str:
    """Derive the legacy ``msgtype`` string for a SparkMonitorProgress message.

    The first populated part of the message decides the result, checked in
    this order: ``stream_complete``, ``application_info``, ``job_events``,
    ``stage_events``, ``task_events``, ``executor_events``. For the repeated
    event lists only the first element is inspected.

    Args:
        sm: The decoded SparkMonitorProgress message.

    Returns:
        One of the ``spark*`` msgtype strings, or ``"unknown"`` when nothing
        is populated or a stage event carries an enum number outside the
        three known stage event types.
    """
    if sm.HasField("stream_complete"):
        return "sparkMonitorStreamComplete"

    if sm.HasField("application_info"):
        # A start event carries start_time; anything else is treated as end.
        if sm.application_info.HasField("start_time"):
            return "sparkApplicationStart"
        return "sparkApplicationEnd"

    if sm.job_events:
        return (
            "sparkJobStart"
            if sm.job_events[0].event_type == 0
            else "sparkJobEnd"
        )

    if sm.stage_events:
        stage_msgtypes = (
            "sparkStageSubmitted",
            "sparkStageActive",
            "sparkStageCompleted",
        )
        event_type = sm.stage_events[0].event_type
        # Enum numbers arrive from the wire and may be values this client
        # does not know. Indexing blindly would raise IndexError for large
        # values and silently pick the wrong name for negative ones.
        if 0 <= event_type < len(stage_msgtypes):
            return stage_msgtypes[event_type]
        return "unknown"

    if sm.task_events:
        return (
            "sparkTaskStart"
            if sm.task_events[0].event_type == 0
            else "sparkTaskEnd"
        )

    if sm.executor_events:
        return (
            "sparkExecutorAdded"
            if sm.executor_events[0].event_type == 0
            else "sparkExecutorRemoved"
        )

    return "unknown"
def _convert_string_numbers_to_int(self, obj):
    """Recursively turn integer-looking strings into ``int`` values.

    MessageToJson converts int64 fields to strings by default to avoid
    JavaScript precision issues, but the VS Code SparkMonitor extension
    expects numeric values.

    Dicts and lists are rebuilt with their values converted; dict keys are
    left untouched. A string is converted only when it is an optional single
    leading ``-`` followed by one or more ASCII digits, so inputs such as
    ``"--1"``, ``"-"`` or ``"²"`` are returned unchanged instead of making
    ``int()`` raise ValueError.

    Args:
        obj: Any JSON-like value (dict, list, str, number, bool or None).

    Returns:
        A new structure of the same shape with integer strings replaced by
        ints; values of any other type are returned as-is.
    """
    # NOTE(review): every all-digit string is converted, including values
    # that come from string-typed proto fields. Confirm the consumer accepts
    # numbers there, or restrict the conversion to known int64 fields.

    def convert(value):
        # Recursion goes through this local helper instead of back through
        # ``self``, so the method does not depend on instance state.
        if isinstance(value, dict):
            return {key: convert(item) for key, item in value.items()}
        if isinstance(value, list):
            return [convert(item) for item in value]
        if isinstance(value, str):
            # Negative numbers (like -1 for completionTime) are converted too.
            digits = value[1:] if value.startswith("-") else value
            # isascii() keeps out characters such as "²" that satisfy
            # isdigit() but are rejected by int().
            if digits.isascii() and digits.isdigit():
                return int(value)
        return value

    return convert(obj)
def _proto_to_scala_json_format(
    self, sm: "sparkmonitor_pb2.SparkMonitorProgress"
) -> dict:
    """
    Convert a SparkMonitorProgress message to the flat dict sent to VS Code.

    Events are delivered as typed sub-messages with enums (JobEvent,
    DetailedStageEvent, TaskEvent, ExecutorEvent). This method flattens one
    of them into the legacy shape:
    - 'msgtype' (lowercase) for the event type string
    - camelCase for all other fields
    - Numeric fields as JSON numbers (not strings)

    The populated part is chosen in this order: application_info, job_events,
    stage_events, task_events, executor_events. The enum 'eventType' key is
    removed from event payloads because 'msgtype' already carries it.

    Args:
        sm: The decoded SparkMonitorProgress message.

    Returns:
        ``{"msgtype": <name>, **fields}`` for a recognised event;
        ``{"msgtype": "unknown"}`` when nothing is populated or a stage event
        has an enum number outside the known stage types;
        ``{"msgtype": "unknown", "error": "conversion_failed"}`` when the
        proto cannot be rendered as JSON.
    """
    try:
        # Convert proto to JSON with camelCase field names
        try:
            # Protobuf 5.x+ uses always_print_fields_with_no_presence
            json_str = json_format.MessageToJson(
                sm,
                preserving_proto_field_name=False,
                always_print_fields_with_no_presence=True,
            )
        except TypeError:
            # Protobuf <5.x uses including_default_value_fields
            json_str = json_format.MessageToJson(
                sm,
                preserving_proto_field_name=False,
                including_default_value_fields=True,
            )
    except Exception as e:
        logger.error(f"Failed to convert proto to JSON: {e}")
        return {"msgtype": "unknown", "error": "conversion_failed"}

    # MessageToJson converts int64 to strings by default to avoid JS precision
    # issues, but the SparkMonitor extension expects numeric values.
    msg = self._convert_string_numbers_to_int(json.loads(json_str))

    def first_event(json_key):
        # NOTE(review): only the first element of a repeated event list is
        # forwarded; confirm the sender never batches several events.
        raw = msg.get(json_key, [{}])[0]
        return {k: v for k, v in raw.items() if k != "eventType"}

    # Type detection uses the proto itself (HasField / list truthiness); the
    # payload is then read from the matching camelCase JSON key.
    if sm.HasField("application_info"):
        msgtype = (
            "sparkApplicationStart"
            if sm.application_info.HasField("start_time")
            else "sparkApplicationEnd"
        )
        event_data = msg.get("applicationInfo", {})
    elif sm.job_events:
        msgtype = (
            "sparkJobStart"
            if sm.job_events[0].event_type == 0
            else "sparkJobEnd"
        )
        event_data = first_event("jobEvents")
    elif sm.stage_events:
        stage_msgtypes = (
            "sparkStageSubmitted",
            "sparkStageActive",
            "sparkStageCompleted",
        )
        event_type = sm.stage_events[0].event_type
        # Enum numbers come from the wire and may be unknown to this client;
        # an unchecked index would raise IndexError or, for negative values,
        # pick the wrong name.
        if not 0 <= event_type < len(stage_msgtypes):
            return {"msgtype": "unknown"}
        msgtype = stage_msgtypes[event_type]
        event_data = first_event("stageEvents")
    elif sm.task_events:
        msgtype = (
            "sparkTaskStart"
            if sm.task_events[0].event_type == 0
            else "sparkTaskEnd"
        )
        event_data = first_event("taskEvents")
    elif sm.executor_events:
        msgtype = (
            "sparkExecutorAdded"
            if sm.executor_events[0].event_type == 0
            else "sparkExecutorRemoved"
        )
        event_data = first_event("executorEvents")
    else:
        return {"msgtype": "unknown"}

    return {"msgtype": msgtype, **event_data}
def _send_to_vscode(self, msg: dict):
    """Publish one SparkMonitor event to the notebook front end.

    The event is emitted through IPython's display machinery under the
    ``application/vnd.sparkmonitor+json`` MIME type. To match the remote
    kernel format, the event is serialised to a JSON string and wrapped in a
    ``fromscala`` envelope. Nothing happens when IPython was not detected,
    and any failure while displaying is logged at debug level and swallowed.

    Args:
        msg: The flat event dict to forward.
    """
    if self._ipython_available:
        try:
            from IPython.display import display

            # Reuse the id of the running cell when one is tracked; otherwise
            # fall back to a fresh random id.
            target_id = self._current_cell_run_id
            if not target_id:
                target_id = str(uuid.uuid4())

            envelope = {"msgtype": "fromscala", "msg": json.dumps(msg)}
            display(
                {"application/vnd.sparkmonitor+json": envelope},
                raw=True,
                display_id=target_id,
            )
        except Exception as e:
            logger.debug(f"Error sending to VS Code: {e}")
class SparkMonitorTests(unittest.TestCase):
    """Tests for the SparkMonitor integration added to DataprocSparkSession."""

    def setUp(self):
        # Keep a copy of the real environment so tearDown can restore it, then
        # expose only the project/region variables to the code under test.
        self.original_environment = dict(os.environ)
        os.environ.clear()
        os.environ.update(
            {
                "GOOGLE_CLOUD_PROJECT": "test-project",
                "GOOGLE_CLOUD_REGION": "test-region",
            }
        )

    def tearDown(self):
        os.environ.clear()
        os.environ.update(self.original_environment)

    @staticmethod
    def _make_session_instance(**attrs):
        """Create a minimal mock DataprocSparkSession with given attributes."""
        fake_session = mock.MagicMock(spec=DataprocSparkSession)
        for attr_name in attrs:
            setattr(fake_session, attr_name, attrs[attr_name])
        return fake_session

    @staticmethod
    def _bind_real(session, *method_names):
        """Make the mock delegate the named one-argument methods to the real class."""
        for method_name in method_names:
            setattr(
                session,
                method_name,
                lambda arg, _name=method_name: getattr(
                    DataprocSparkSession, _name
                )(session, arg),
            )

    @staticmethod
    def _encode_varint(value):
        """Encode an integer as a protobuf base-128 varint."""
        encoded = bytearray()
        while value > 127:
            # Low seven bits plus the continuation flag.
            encoded.append((value & 0x7F) | 0x80)
            value >>= 7
        encoded.append(value)
        return bytes(encoded)

    def _build_fake_grpc_response(self, sm):
        """Build a fake gRPC response with SparkMonitorProgress packed in extension (Any, field 999)."""
        from google.cloud.dataproc_spark_connect.session import (
            _SPARK_MONITOR_TYPE_URL,
        )

        response = mock.MagicMock()
        response.HasField.side_effect = lambda field: field == "extension"
        response.extension.type_url = _SPARK_MONITOR_TYPE_URL
        response.extension.value = sm.SerializeToString()
        return response

    def _run_extraction(self, session, response):
        """Call _extract_and_send_sparkmonitor and return (msg_type_counts, response_count)."""
        counts = {}
        seen = [0]
        DataprocSparkSession._extract_and_send_sparkmonitor(
            session,
            response,
            1,
            counts,
            seen,
        )
        return counts, seen[0]

    def test_convert_string_numbers_to_int_positive(self):
        converted = DataprocSparkSession._convert_string_numbers_to_int(
            self._make_session_instance(), "42"
        )
        self.assertEqual(converted, 42)
        self.assertIsInstance(converted, int)

    def test_convert_string_numbers_to_int_negative(self):
        """Negative string numbers such as completionTime=-1 should be converted."""
        converted = DataprocSparkSession._convert_string_numbers_to_int(
            self._make_session_instance(), "-1"
        )
        self.assertEqual(converted, -1)
        self.assertIsInstance(converted, int)

    def test_convert_string_numbers_to_int_preserves_non_numeric(self):
        unchanged = DataprocSparkSession._convert_string_numbers_to_int(
            self._make_session_instance(), "sparkJobStart"
        )
        self.assertEqual(unchanged, "sparkJobStart")

    def test_convert_string_numbers_to_int_nested_dict_and_list(self):
        session = self._make_session_instance()
        # Wire up the recursive self-call so nested values are also converted
        self._bind_real(session, "_convert_string_numbers_to_int")

        payload = {"jobId": "5", "status": "SUCCEEDED", "stageIds": ["1", "2"]}
        expected = {"jobId": 5, "status": "SUCCEEDED", "stageIds": [1, 2]}

        self.assertEqual(
            DataprocSparkSession._convert_string_numbers_to_int(
                session, payload
            ),
            expected,
        )

    def test_convert_string_numbers_to_int_passthrough_non_string(self):
        session = self._make_session_instance()
        convert = DataprocSparkSession._convert_string_numbers_to_int
        self.assertEqual(convert(session, 99), 99)
        self.assertIsNone(convert(session, None))

    def test_proto_to_scala_json_format_job_start(self):
        from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2

        session = self._make_session_instance()
        self._bind_real(session, "_convert_string_numbers_to_int")

        progress = sparkmonitor_pb2.SparkMonitorProgress()
        progress.job_events.add(
            event_type=sparkmonitor_pb2.SparkMonitorProgress.JobEvent.JOB_START,
            job_id=3,
            num_tasks=10,
            num_executors=2,
        )

        flattened = DataprocSparkSession._proto_to_scala_json_format(
            session, progress
        )

        self.assertEqual(flattened["msgtype"], "sparkJobStart")
        self.assertEqual(flattened["jobId"], 3)
        self.assertEqual(flattened["numTasks"], 10)
        self.assertNotIn("eventType", flattened)

    def test_proto_to_scala_json_format_job_end(self):
        from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2

        session = self._make_session_instance()
        self._bind_real(session, "_convert_string_numbers_to_int")

        progress = sparkmonitor_pb2.SparkMonitorProgress()
        progress.job_events.add(
            event_type=sparkmonitor_pb2.SparkMonitorProgress.JobEvent.JOB_END,
            job_id=3,
            status="SUCCEEDED",
        )

        flattened = DataprocSparkSession._proto_to_scala_json_format(
            session, progress
        )

        self.assertEqual(flattened["msgtype"], "sparkJobEnd")
        self.assertEqual(flattened["jobId"], 3)
        self.assertEqual(flattened["status"], "SUCCEEDED")

    def test_proto_to_scala_json_format_stage_active(self):
        from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2

        session = self._make_session_instance()
        self._bind_real(session, "_convert_string_numbers_to_int")

        stage_event_cls = sparkmonitor_pb2.SparkMonitorProgress.DetailedStageEvent
        progress = sparkmonitor_pb2.SparkMonitorProgress()
        progress.stage_events.add(
            event_type=stage_event_cls.STAGE_ACTIVE,
            stage_id=7,
            num_tasks=20,
            num_completed_tasks=20,  # optional field
        )

        flattened = DataprocSparkSession._proto_to_scala_json_format(
            session, progress
        )

        self.assertEqual(flattened["msgtype"], "sparkStageActive")
        self.assertEqual(flattened["stageId"], 7)
        self.assertEqual(flattened["numTasks"], 20)
        self.assertNotIn("eventType", flattened)

    def test_send_to_vscode_skips_when_ipython_unavailable(self):
        session = self._make_session_instance(_ipython_available=False)

        with mock.patch("IPython.display.display") as mock_display:
            DataprocSparkSession._send_to_vscode(
                session, {"msgtype": "sparkJobStart"}
            )
            mock_display.assert_not_called()

    def test_send_to_vscode_calls_display_when_ipython_available(self):
        import json

        run_id = "test-run-id-1234"
        session = self._make_session_instance(
            _ipython_available=True,
            _current_cell_run_id=run_id,
        )
        event = {"msgtype": "sparkJobEnd", "jobId": 1}

        with mock.patch(
            "IPython.display.display"
        ) as mock_display, mock.patch.dict(
            "sys.modules",
            {"IPython.display": mock.MagicMock(display=mock_display)},
        ):
            DataprocSparkSession._send_to_vscode(session, event)

        mock_display.assert_called_once()
        positional_args = mock_display.call_args[0]
        bundle = positional_args[0]
        self.assertIn("application/vnd.sparkmonitor+json", bundle)
        envelope = bundle["application/vnd.sparkmonitor+json"]
        self.assertEqual(envelope["msgtype"], "fromscala")
        self.assertEqual(json.loads(envelope["msg"]), event)

    def test_extract_and_send_skips_response_without_sparkmonitor_data(self):
        session = self._make_session_instance()

        # Response that has no extension field at all
        response = mock.MagicMock()
        response.HasField.side_effect = lambda field: False

        _, response_count = self._run_extraction(session, response)

        self.assertEqual(response_count, 0)
        session._send_to_vscode.assert_not_called()

    def test_extract_and_send_skips_stream_complete_signal(self):
        from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2

        session = self._make_session_instance()
        # Wire up _derive_sparkmonitor_msgtype
        self._bind_real(session, "_derive_sparkmonitor_msgtype")

        progress = sparkmonitor_pb2.SparkMonitorProgress(stream_complete=True)
        response = self._build_fake_grpc_response(progress)

        counts, response_count = self._run_extraction(session, response)

        # Counter incremented but _send_to_vscode NOT called
        self.assertEqual(response_count, 1)
        self.assertEqual(counts["sparkMonitorStreamComplete"], 1)
        session._send_to_vscode.assert_not_called()

    def test_extract_and_send_processes_valid_job_start_payload(self):
        from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2

        session = self._make_session_instance()
        # Wire up real implementations so the full extraction pipeline runs
        self._bind_real(
            session,
            "_convert_string_numbers_to_int",
            "_proto_to_scala_json_format",
            "_derive_sparkmonitor_msgtype",
        )

        progress = sparkmonitor_pb2.SparkMonitorProgress()
        progress.job_events.add(
            event_type=sparkmonitor_pb2.SparkMonitorProgress.JobEvent.JOB_START,
            job_id=1,
            num_tasks=8,
        )
        response = self._build_fake_grpc_response(progress)

        counts, response_count = self._run_extraction(session, response)

        self.assertEqual(response_count, 1)
        self.assertEqual(counts["sparkJobStart"], 1)
        session._send_to_vscode.assert_called_once()
        forwarded = session._send_to_vscode.call_args[0][0]
        self.assertEqual(forwarded["msgtype"], "sparkJobStart")

    def test_setup_cell_tracking_sets_flag_when_ipython_present(self):
        """When IPython is available and has a live shell, _ipython_available should be True."""
        session = self._make_session_instance(
            _ipython_available=False, _current_cell_run_id=None
        )
        fake_shell = mock.MagicMock()

        with mock.patch(
            "IPython.get_ipython", return_value=fake_shell
        ), mock.patch("IPython.display.display"):
            DataprocSparkSession._setup_cell_execution_tracking(session)

        self.assertTrue(session._ipython_available)
        self.assertIsNotNone(session._current_cell_run_id)
        fake_shell.events.register.assert_called_once_with(
            "pre_run_cell", mock.ANY
        )

    def test_setup_cell_tracking_leaves_flag_false_when_no_ipython_shell(self):
        """When get_ipython() returns None, _ipython_available should remain False."""
        session = self._make_session_instance(
            _ipython_available=False, _current_cell_run_id=None
        )

        with mock.patch("IPython.get_ipython", return_value=None):
            DataprocSparkSession._setup_cell_execution_tracking(session)

        self.assertFalse(session._ipython_available)
        self.assertIsNone(session._current_cell_run_id)

    def test_setup_cell_tracking_is_resilient_to_import_error(self):
        """If IPython is not installed, the method should not raise."""
        session = self._make_session_instance(
            _ipython_available=False, _current_cell_run_id=None
        )

        # A None entry in sys.modules makes "import IPython" fail.
        with mock.patch.dict("sys.modules", {"IPython": None}):
            # Should not raise
            DataprocSparkSession._setup_cell_execution_tracking(session)

        self.assertFalse(session._ipython_available)