Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions gigl/common/services/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name
import datetime
import time
from dataclasses import dataclass
from typing import Final, Optional, Sequence, Union, cast
from typing import Any, Final, Optional, Sequence, Union, cast

from google.cloud import aiplatform
from google.cloud.aiplatform_v1.types import (
Expand Down Expand Up @@ -370,7 +370,7 @@ def run_pipeline(
self,
display_name: str,
template_path: Uri,
run_keyword_args: dict[str, str],
run_keyword_args: dict[str, Any],
job_id: Optional[str] = None,
labels: Optional[dict[str, str]] = None,
experiment: Optional[str] = None,
Expand All @@ -383,7 +383,7 @@ def run_pipeline(
Args:
display_name (str): The display name of the pipeline.
template_path (Uri): The path to the compiled pipeline YAML.
run_keyword_args (dict[str, str]): Runtime arguements passed to your pipeline.
run_keyword_args (dict[str, Any]): Runtime arguments passed to your pipeline.
job_id (Optional[str]): The ID of the job. If not provided will be the *pipeline_name* + datetime.
Note: The pipeline_name and display_name are *not* the same.
Note: pipeline_name is defined in the `template_path` and ultimately comes from the Python pipeline definition.
Expand Down
4 changes: 4 additions & 0 deletions gigl/orchestration/kubeflow/kfp_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gigl.common.types.resource_config import CommonPipelineComponentConfigs
from gigl.env.pipelines_config import get_resource_config
from gigl.orchestration.kubeflow.kfp_pipeline import generate_pipeline
from gigl.orchestration.kubeflow.utils.glt_backend import resolve_should_use_glt_backend
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.utils.file_loader import FileLoader
Expand Down Expand Up @@ -148,6 +149,9 @@ def run(
"start_at": start_at,
"template_or_frozen_config_uri": task_config_uri.uri,
"resource_config_uri": resource_config_uri.uri,
"should_use_glt_backend": resolve_should_use_glt_backend(
task_config_uri=task_config_uri
),
}
# We need to provide *some* notification emails, otherwise the cleanup component will fail.
# Ideally, we'd be able to provide None and have it handle it, but for whatever reason
Expand Down
14 changes: 5 additions & 9 deletions gigl/orchestration/kubeflow/kfp_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from gigl.common import LocalUri
from gigl.common.logger import Logger
from gigl.common.types.resource_config import CommonPipelineComponentConfigs
from gigl.orchestration.kubeflow.utils.glt_backend import (
check_glt_backend_eligibility_component,
)
from gigl.orchestration.kubeflow.utils.log_metrics import log_metrics_to_ui
from gigl.orchestration.kubeflow.utils.resource import add_task_resource_requirements
from gigl.src.common.constants.components import GiGLComponents
Expand Down Expand Up @@ -135,6 +132,7 @@ def _generate_component_tasks(
common_pipeline_component_configs: CommonPipelineComponentConfigs,
start_at: Optional[str] = None,
stop_after: Optional[str] = None,
should_use_glt_backend: bool = False,
):
validation_check_task = _generate_component_task(
component=GiGLComponents.ConfigValidator,
Expand All @@ -145,18 +143,14 @@ def _generate_component_tasks(
resource_config_uri=resource_config_uri,
common_pipeline_component_configs=common_pipeline_component_configs,
)
should_use_glt = check_glt_backend_eligibility_component(
task_config_uri=template_or_frozen_config_uri,
base_image=common_pipeline_component_configs.cpu_container_image,
)

with kfp.dsl.Condition(start_at == GiGLComponents.ConfigPopulator.value):
config_populator_task = _create_config_populator_task_op(
job_name=job_name,
task_config_uri=template_or_frozen_config_uri,
resource_config_uri=resource_config_uri,
common_pipeline_component_configs=common_pipeline_component_configs,
should_use_glt_runtime_param=should_use_glt,
should_use_glt_runtime_param=should_use_glt_backend,
stop_after=stop_after,
)
config_populator_task.after(validation_check_task)
Expand All @@ -168,7 +162,7 @@ def _generate_component_tasks(
resource_config_uri=resource_config_uri,
common_pipeline_component_configs=common_pipeline_component_configs,
stop_after=stop_after,
should_use_glt_runtime_param=should_use_glt,
should_use_glt_runtime_param=should_use_glt_backend,
)
data_preprocessor_task.after(validation_check_task)

Expand Down Expand Up @@ -248,6 +242,7 @@ def pipeline(
start_at: str = GiGLComponents.ConfigPopulator.value,
stop_after: Optional[str] = None,
notification_emails: Optional[List[str]] = None,
should_use_glt_backend: bool = False,
):
with kfp.dsl.ExitHandler(
VertexNotificationEmailOp(recipients=notification_emails),
Expand All @@ -260,6 +255,7 @@ def pipeline(
common_pipeline_component_configs=common_pipeline_component_configs,
start_at=start_at,
stop_after=stop_after,
should_use_glt_backend=should_use_glt_backend,
)

return pipeline
Expand Down
36 changes: 4 additions & 32 deletions gigl/orchestration/kubeflow/utils/glt_backend.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,9 @@
from kfp import dsl
from gigl.common import Uri
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper


def check_glt_backend_eligibility_component(
    task_config_uri: str, base_image: str
) -> bool:
    """Wrap the GLT eligibility check in a KFP component and return its output.

    Args:
        task_config_uri (str): Task config uri forwarded to the component.
        base_image (str): Image handed to ``dsl.component`` as ``base_image``.

    Returns:
        The ``.output`` attribute of the invoked component.
        NOTE(review): annotated as ``bool``; presumably this is a KFP output
        reference resolved at pipeline runtime rather than a Python bool --
        confirm against the KFP version in use.
    """
    eligibility_op = dsl.component(
        func=_check_glt_backend_eligibility_component,
        base_image=base_image,
    )
    eligibility_op.description = "Check whether to use GLT Backend"
    eligibility_task = eligibility_op(task_config_uri=task_config_uri)
    return eligibility_task.output


def _check_glt_backend_eligibility_component(
task_config_uri: str,
) -> bool:
"""
Used by KFP to check if GLT should be used as a backend for current run.
Args:
task_config_uri (str): Task config uri for current run
Returns:
bool: Whether to use GLT as a backend for current run ('True' or 'False')
"""

# This is required to resolve below packages when containerized by KFP.
import os
import sys

sys.path.append(os.getcwd())

from gigl.common import UriFactory
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper

def resolve_should_use_glt_backend(task_config_uri: Uri) -> bool:
config = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=UriFactory.create_uri(task_config_uri)
gbml_config_uri=task_config_uri
)
return config.should_use_glt_backend
25 changes: 25 additions & 0 deletions tests/unit/orchestration/kubeflow/glt_backend_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from unittest.mock import patch

from absl.testing import absltest

from gigl.common import GcsUri
from gigl.orchestration.kubeflow.utils.glt_backend import resolve_should_use_glt_backend
from tests.test_assets.test_case import TestCase


class GltBackendTest(TestCase):
    """Unit tests for ``resolve_should_use_glt_backend``.

    ``GbmlConfigPbWrapper`` is patched in the module under test, so no task
    config is actually loaded from the URI.
    """

    @patch("gigl.orchestration.kubeflow.utils.glt_backend.GbmlConfigPbWrapper")
    def test_resolve_should_use_glt_backend_reads_task_config(self, MockGbmlConfig):
        """The flag comes from the wrapper loaded for the given task config URI."""
        mock_config = MockGbmlConfig.get_gbml_config_pb_wrapper_from_uri.return_value
        mock_config.should_use_glt_backend = True

        task_config_uri = GcsUri("gs://test-bucket/task_config.yaml")

        # assertIs rather than assertTrue: a MagicMock is truthy, so assertTrue
        # would also pass if the wrapper itself were returned instead of the flag.
        self.assertIs(
            resolve_should_use_glt_backend(task_config_uri=task_config_uri), True
        )
        MockGbmlConfig.get_gbml_config_pb_wrapper_from_uri.assert_called_once_with(
            gbml_config_uri=task_config_uri
        )

    @patch("gigl.orchestration.kubeflow.utils.glt_backend.GbmlConfigPbWrapper")
    def test_resolve_should_use_glt_backend_returns_false_when_disabled(
        self, MockGbmlConfig
    ):
        """A config with the flag disabled must resolve to ``False``."""
        mock_config = MockGbmlConfig.get_gbml_config_pb_wrapper_from_uri.return_value
        mock_config.should_use_glt_backend = False

        task_config_uri = GcsUri("gs://test-bucket/task_config.yaml")

        self.assertIs(
            resolve_should_use_glt_backend(task_config_uri=task_config_uri), False
        )
        MockGbmlConfig.get_gbml_config_pb_wrapper_from_uri.assert_called_once_with(
            gbml_config_uri=task_config_uri
        )


# Run this module's tests through absltest when the file is executed directly.
if __name__ == "__main__":
absltest.main()
77 changes: 76 additions & 1 deletion tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from types import SimpleNamespace
from unittest.mock import ANY, patch

from absl.testing import absltest

from gigl.common import GcsUri
from gigl.common import GcsUri, LocalUri
from gigl.common.logger import Logger
from gigl.orchestration.kubeflow.kfp_orchestrator import KfpOrchestrator
from gigl.src.common.types import AppliedTaskIdentifier
from tests.test_assets.test_case import TestCase

logger = Logger()
Expand All @@ -29,6 +33,77 @@ def test_compile_uploads_compiled_yaml(self, MockFileLoader):
file_uri_src=ANY, file_uri_dst=dst_compiled_pipeline_path
)

def test_compile_uses_glt_backend_pipeline_parameter(self):
    """Compile to a temp file and inspect the emitted pipeline YAML.

    The compiled text must mention ``should_use_glt_backend`` and must not
    mention ``check-glt-backend-eligibility-component``.
    """
    with TemporaryDirectory() as scratch_dir:
        compiled_dst = LocalUri(Path(scratch_dir) / "pipeline.yaml")
        KfpOrchestrator.compile(
            cuda_container_image="SOME NONEXISTENT IMAGE 1",
            cpu_container_image="SOME NONEXISTENT IMAGE 2",
            dataflow_container_image="SOME NONEXISTENT IMAGE 3",
            dst_compiled_pipeline_path=compiled_dst,
        )

        # Read before the temporary directory (and the file in it) is removed.
        compiled_pipeline_yaml = Path(compiled_dst.uri).read_text()

    self.assertIn("should_use_glt_backend", compiled_pipeline_yaml)
    self.assertNotIn(
        "check-glt-backend-eligibility-component", compiled_pipeline_yaml
    )

@patch(
    "gigl.orchestration.kubeflow.kfp_orchestrator.resolve_should_use_glt_backend"
)
@patch("gigl.orchestration.kubeflow.kfp_orchestrator.VertexAIService")
@patch("gigl.orchestration.kubeflow.kfp_orchestrator.get_resource_config")
@patch("gigl.orchestration.kubeflow.kfp_orchestrator.FileLoader")
def test_run_passes_resolved_glt_backend_param(
    self,
    MockFileLoader,
    mock_get_resource_config,
    MockVertexAIService,
    mock_resolve_should_use_glt_backend,
):
    """``run`` resolves the GLT flag once and forwards it to ``run_pipeline``.

    The file loader, resource config, Vertex AI service and flag resolver
    are all patched in the orchestrator module; only a stub pipeline file
    is written to disk.
    """
    # Stub out every collaborator patched above.
    MockFileLoader.return_value.does_uri_exist.return_value = True
    mock_resolve_should_use_glt_backend.return_value = True
    mock_get_resource_config.return_value = SimpleNamespace(
        project="test-project",
        region="us-central1",
        service_account_email="test@test-project.iam.gserviceaccount.com",
        temp_assets_regional_bucket_path=GcsUri("gs://test-bucket"),
    )
    vertex_service = MockVertexAIService.return_value
    vertex_service.run_pipeline.return_value = "test-run"

    task_config_uri = GcsUri("gs://test-bucket/task_config.yaml")
    stub_pipeline_yaml = """
root:
inputDefinitions:
parameters:
should_use_glt_backend:
parameterType: BOOLEAN
"""

    with TemporaryDirectory() as temp_dir:
        compiled_pipeline_path = LocalUri(Path(temp_dir) / "pipeline.yaml")
        Path(compiled_pipeline_path.uri).write_text(stub_pipeline_yaml)
        run = KfpOrchestrator().run(
            applied_task_identifier=AppliedTaskIdentifier("test_job"),
            task_config_uri=task_config_uri,
            resource_config_uri=GcsUri("gs://test-bucket/resource_config.yaml"),
            compiled_pipeline_path=compiled_pipeline_path,
        )

    self.assertEqual(run, "test-run")
    mock_resolve_should_use_glt_backend.assert_called_once_with(
        task_config_uri=task_config_uri
    )
    vertex_service.run_pipeline.assert_called_once()
    forwarded_kwargs = vertex_service.run_pipeline.call_args.kwargs
    run_keyword_args = forwarded_kwargs["run_keyword_args"]
    self.assertTrue(run_keyword_args["should_use_glt_backend"])


# Run this module's tests through absltest when the file is executed directly.
if __name__ == "__main__":
absltest.main()