Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,64 @@ def start(self) -> None:
scheduler_job_id=self.scheduler_job_id,
)

def _coordinator_extra(self, queue: str | None) -> dict[str, Any] | None:
"""
Return the ``extra`` mapping a coordinator declares for *queue*, if any.

Read from the coordinator's declarative ``[sdk]`` config without importing
or instantiating the coordinator. The coordinator manager only exists on
Airflow 3.3+; on older Task SDKs the import fails and we fall back to no
extra. A malformed ``[sdk] coordinators`` / ``queue_to_coordinator`` config
must not crash the scheduler on this first lookup either, so an invalid
config also falls back to no extra. The exception types are imported from
``airflow.sdk`` so they match whatever Task SDK actually raised them.
"""
if not queue:
return None
try:
from airflow.sdk.exceptions import AirflowConfigException
from airflow.sdk.execution_time.coordinator import get_coordinator_manager
except ImportError:
return None
try:
return get_coordinator_manager().extra_for_queue(queue)
except (AirflowConfigException, ValueError):
self.log.warning(
"Ignoring coordinator config for queue %s: invalid [sdk] coordinator config",
queue,
exc_info=True,
)
return None

def _coordinator_pod_template_file(self, queue: str | None) -> str | None:
"""
Return the pod template a coordinator declares for *queue*, if any.

Lets a queue routed to a non-Python coordinator (via ``[sdk]
queue_to_coordinator``) launch its worker pod from a coordinator-specific
template — for example an image carrying the JVM for a Java coordinator.
"""
if (extra := self._coordinator_extra(queue)) is not None:
return extra.get("pod_template_file", None)
return None

def _coordinator_kube_image(self, queue: str | None) -> str | None:
"""
Return the worker base image a coordinator declares for *queue*, if any.

The base container image is never taken from a pod template; it comes
from ``kube_image`` (``worker_container_repository:worker_container_tag``)
or a per-task ``pod_override``. A coordinator may declare its own
``worker_container_repository`` and ``worker_container_tag`` in ``extra``
(e.g. a JRE-bearing image for a Java coordinator); both are required to
compose an override, otherwise the executor default applies.
"""
if (extra := self._coordinator_extra(queue)) is None:
return None
if (repo := extra.get("worker_container_repository")) and (tag := extra.get("worker_container_tag")):
return f"{repo}:{tag}"
return None

def execute_async(
self,
key: TaskInstanceKey,
Expand Down Expand Up @@ -225,8 +283,31 @@ def execute_async(
pod_template_file = executor_config.get("pod_template_file", None)
else:
pod_template_file = None

# A coordinator-level pod_template wins (e.g. a JVM image for JavaCoordinator)
if (coordinator_pod_template_file := self._coordinator_pod_template_file(queue)) is not None:
self.log.debug(
"Using coordinator-declared pod template %s for task %s in queue %s",
coordinator_pod_template_file,
key,
queue,
)
pod_template_file = coordinator_pod_template_file

# The base image is not carried by a pod template, so a coordinator routes
# its worker base image separately (e.g. a JRE image for a Java queue).
if (coordinator_kube_image := self._coordinator_kube_image(queue)) is not None:
self.log.debug(
"Using coordinator-declared base image %s for task %s in queue %s",
coordinator_kube_image,
key,
queue,
)

self.event_buffer[key] = (TaskInstanceState.QUEUED, self.scheduler_job_id)
self.task_queue.put(KubernetesJob(key, command, kube_executor_config, pod_template_file))
self.task_queue.put(
KubernetesJob(key, command, kube_executor_config, pod_template_file, coordinator_kube_image)
)

def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class KubernetesJob(NamedTuple):
command: Sequence[str]
kube_executor_config: Any
pod_template_file: str | None
kube_image: str | None = None


ALL_NAMESPACES = "ALL_NAMESPACES"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def run_next(self, next_job: KubernetesJob) -> None:
command = next_job.command
kube_executor_config = next_job.kube_executor_config
pod_template_file = next_job.pod_template_file
kube_image = next_job.kube_image or self.kube_config.kube_image

dag_id, task_id, run_id, try_number, map_index = key
if len(command) == 1:
Expand Down Expand Up @@ -586,7 +587,7 @@ def run_next(self, next_job: KubernetesJob) -> None:
pod_id=create_unique_id(dag_id, task_id),
dag_id=dag_id,
task_id=task_id,
kube_image=self.kube_config.kube_image,
kube_image=kube_image,
try_number=try_number,
map_index=map_index,
date=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
ADOPTED,
KubernetesJob,
KubernetesResults,
KubernetesWatch,
)
Expand Down Expand Up @@ -66,7 +67,11 @@
from airflow.utils.state import State, TaskInstanceState

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_2_PLUS,
AIRFLOW_V_3_3_PLUS,
)

try:
# Check whether a module-level function from stats is importable.
Expand Down Expand Up @@ -863,11 +868,10 @@ def test_pod_template_file_override_in_executor_config(

assert not executor.task_queue.empty()
task = executor.task_queue.get_nowait()
_, _, expected_executor_config, expected_pod_template_file = task
executor.task_queue.task_done()
# Test that the correct values have been put to queue
assert expected_executor_config.metadata.labels == {"release": "stable"}
assert expected_pod_template_file == executor_template_file
assert task.kube_executor_config.metadata.labels == {"release": "stable"}
assert task.pod_template_file == executor_template_file

self.kubernetes_executor.kube_scheduler.run_next(task)
mock_run_pod_async.assert_called_once_with(
Expand Down Expand Up @@ -915,6 +919,205 @@ def test_pod_template_file_override_in_executor_config(
finally:
executor.end()

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
@pytest.mark.parametrize(
    ("coordinator_template", "executor_config", "expected_template"),
    [
        pytest.param(
            "/coord/java.yaml",
            None,
            "/coord/java.yaml",
            id="coordinator-template-used",
        ),
        pytest.param(
            "/coord/java.yaml",
            {"pod_template_file": "/from/executor_config.yaml"},
            "/coord/java.yaml",
            id="coordinator-template-wins",
        ),
        pytest.param(
            None,
            {"pod_template_file": "/from/executor_config.yaml"},
            "/from/executor_config.yaml",
            id="executor-config-used-without-coordinator",
        ),
    ],
)
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
@mock.patch.object(KubernetesExecutor, "_coordinator_pod_template_file")
def test_coordinator_pod_template_file_used_for_queue(
    self,
    mock_coordinator_template,
    mock_get_kube_client,
    mock_kubernetes_job_watcher,
    coordinator_template,
    executor_config,
    expected_template,
):
    """
    Check which pod template ends up on the queued job.

    The coordinator lookup is patched to return ``coordinator_template``.
    When it is set, it replaces any ``pod_template_file`` from
    ``executor_config``; when it is ``None``, the ``executor_config`` value
    is kept.
    """
    mock_coordinator_template.return_value = coordinator_template
    task_key = TaskInstanceKey("dag", "task", "run_id", 1, -1)
    k8s_executor = self.kubernetes_executor
    k8s_executor.start()
    try:
        k8s_executor.execute_async(
            key=task_key,
            queue="java",
            command=["airflow", "tasks", "run", "true", "some_parameter"],
            executor_config=executor_config,
        )
        assert not k8s_executor.task_queue.empty()
        job = k8s_executor.task_queue.get_nowait()
        k8s_executor.task_queue.task_done()
        assert job.pod_template_file == expected_template
    finally:
        # Always shut the executor down, even when an assertion failed.
        k8s_executor.end()

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
def test_coordinator_pod_template_file_skips_lookup_without_queue(self):
    """Without a queue the template lookup returns ``None`` and never asks for the coordinator manager."""
    manager_target = "airflow.sdk.execution_time.coordinator.get_coordinator_manager"
    with mock.patch(manager_target) as manager_factory:
        template = self.kubernetes_executor._coordinator_pod_template_file(None)
    assert template is None
    manager_factory.assert_not_called()

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
@pytest.mark.parametrize(
    ("extra", "expected"),
    [
        pytest.param(
            {"pod_template_file": "/coord/go.yaml"},
            "/coord/go.yaml",
            id="template-in-extra",
        ),
        pytest.param({"other": "value"}, None, id="extra-without-template"),
        pytest.param(None, None, id="no-extra"),
    ],
)
def test_coordinator_pod_template_file_reads_extra(self, extra, expected):
    """The template comes from the ``pod_template_file`` key of the ``extra`` mapping returned for the queue."""
    manager_target = "airflow.sdk.execution_time.coordinator.get_coordinator_manager"
    with mock.patch(manager_target) as manager_factory:
        extra_for_queue = manager_factory.return_value.extra_for_queue
        extra_for_queue.return_value = extra
        template = self.kubernetes_executor._coordinator_pod_template_file("go")
    assert template == expected
    # Exactly one lookup, keyed by the queue name.
    extra_for_queue.assert_called_once_with("go")

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
def test_coordinator_pod_template_file_returns_none_on_old_task_sdk(self):
    """When the coordinator module cannot be imported (as on pre-3.3 Task SDKs), the lookup yields ``None``."""
    # A ``None`` entry in ``sys.modules`` makes importing that module raise ImportError.
    hidden_module = {"airflow.sdk.execution_time.coordinator": None}
    with mock.patch.dict("sys.modules", hidden_module):
        template = self.kubernetes_executor._coordinator_pod_template_file("go")
    assert template is None

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
@mock.patch.object(KubernetesExecutor, "_coordinator_kube_image", return_value="repo/java:1")
def test_coordinator_kube_image_carried_on_job(
    self,
    mock_coordinator_kube_image,
    mock_get_kube_client,
    mock_kubernetes_job_watcher,
):
    """The base image returned by the (patched) coordinator lookup is stored on the queued job."""
    task_key = TaskInstanceKey("dag", "task", "run_id", 1, -1)
    k8s_executor = self.kubernetes_executor
    k8s_executor.start()
    try:
        k8s_executor.execute_async(
            key=task_key,
            queue="java",
            command=["airflow", "tasks", "run", "true", "some_parameter"],
        )
        job = k8s_executor.task_queue.get_nowait()
        k8s_executor.task_queue.task_done()
        assert job.kube_image == "repo/java:1"
    finally:
        # Always shut the executor down, even when an assertion failed.
        k8s_executor.end()

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
def test_coordinator_kube_image_skips_lookup_without_queue(self):
    """Without a queue the image lookup returns ``None`` and never asks for the coordinator manager."""
    manager_target = "airflow.sdk.execution_time.coordinator.get_coordinator_manager"
    with mock.patch(manager_target) as manager_factory:
        image = self.kubernetes_executor._coordinator_kube_image(None)
    assert image is None
    manager_factory.assert_not_called()

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
@pytest.mark.parametrize(
    ("extra", "expected"),
    [
        pytest.param(
            {"worker_container_repository": "repo/java", "worker_container_tag": "17"},
            "repo/java:17",
            id="repository-and-tag",
        ),
        pytest.param(
            {"worker_container_repository": "repo/java"},
            None,
            id="repository-only",
        ),
        pytest.param({"worker_container_tag": "17"}, None, id="tag-only"),
        pytest.param({"other": "value"}, None, id="extra-without-image"),
        pytest.param(None, None, id="no-extra"),
    ],
)
def test_coordinator_kube_image_reads_extra(self, extra, expected):
    """The image is ``repository:tag`` built from ``extra``; a missing half (or no ``extra``) gives ``None``."""
    manager_target = "airflow.sdk.execution_time.coordinator.get_coordinator_manager"
    with mock.patch(manager_target) as manager_factory:
        extra_for_queue = manager_factory.return_value.extra_for_queue
        extra_for_queue.return_value = extra
        image = self.kubernetes_executor._coordinator_kube_image("java")
    assert image == expected
    # Exactly one lookup, keyed by the queue name.
    extra_for_queue.assert_called_once_with("java")

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
def test_coordinator_kube_image_returns_none_on_old_task_sdk(self):
    """When the coordinator module cannot be imported (as on pre-3.3 Task SDKs), the lookup yields ``None``."""
    # A ``None`` entry in ``sys.modules`` makes importing that module raise ImportError.
    hidden_module = {"airflow.sdk.execution_time.coordinator": None}
    with mock.patch.dict("sys.modules", hidden_module):
        image = self.kubernetes_executor._coordinator_kube_image("java")
    assert image is None

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="The coordinator interface only support since 3.3+")
@pytest.mark.parametrize("exc_type", ["airflow_config_exception", "value_error"])
def test_coordinator_extra_falls_back_on_invalid_config(self, exc_type):
    """
    An invalid ``[sdk]`` coordinator config degrades to "nothing declared".

    ``extra_for_queue`` is made to raise either ``AirflowConfigException`` or
    ``ValueError``; both the template and the image lookup must return
    ``None`` rather than propagate the error.
    """
    if exc_type != "airflow_config_exception":
        raised = ValueError("invalid coordinator key")
    else:
        # Imported lazily: only this branch needs the Task SDK exception class.
        from airflow.sdk.exceptions import AirflowConfigException

        raised = AirflowConfigException("invalid json")

    manager_target = "airflow.sdk.execution_time.coordinator.get_coordinator_manager"
    with mock.patch(manager_target) as manager_factory:
        manager_factory.return_value.extra_for_queue.side_effect = raised
        assert self.kubernetes_executor._coordinator_pod_template_file("java") is None
        assert self.kubernetes_executor._coordinator_kube_image("java") is None

@pytest.mark.skipif(
    AirflowKubernetesScheduler is None, reason="kubernetes python package is not installed"
)
@pytest.mark.parametrize(
    ("job_image", "use_default"),
    [
        pytest.param("repo/java:17", False, id="job-image-wins"),
        pytest.param(None, True, id="falls-back-to-kube_config"),
    ],
)
@mock.patch(
    "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.run_pod_async"
)
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.PodGenerator")
@mock.patch(
    "airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.get_base_pod_from_template"
)
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
def test_run_next_applies_job_kube_image(
    self,
    mock_get_kube_client,
    mock_get_base_pod,
    mock_pod_generator,
    mock_run_pod_async,
    job_image,
    use_default,
):
    """
    Check the ``kube_image`` that ``run_next`` hands to ``construct_pod``.

    A job carrying its own image must pass that image on; a job whose image
    is ``None`` must pass the scheduler's ``kube_config.kube_image``.
    """
    k8s_executor = self.kubernetes_executor
    k8s_executor.start()
    try:
        kube_scheduler = k8s_executor.kube_scheduler
        job = KubernetesJob(
            key=TaskInstanceKey("dag", "task", "run_id", 1, -1),
            command=["airflow", "tasks", "run", "true", "some_parameter"],
            kube_executor_config=None,
            pod_template_file=None,
            kube_image=job_image,
        )
        kube_scheduler.run_next(job)

        if use_default:
            expected = kube_scheduler.kube_config.kube_image
        else:
            expected = job_image
        passed_image = mock_pod_generator.construct_pod.call_args.kwargs["kube_image"]
        assert passed_image == expected
    finally:
        # Always shut the executor down, even when an assertion failed.
        k8s_executor.end()

@pytest.mark.db_test
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.KubernetesJobWatcher")
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
Expand Down
Loading