Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def __init__(
self._trial_type_to_runner: dict[str | None, Runner | None] = {
default_trial_type: runner
}
# Maps each trial type to the set of metric names relevant to that type.
# This is the complement of _trial_type_to_runner and is used for
# multi-type experiments where different metrics apply to different
# trial types.
self._trial_type_to_metric_names: dict[str, set[str]] = {}
# Used to keep track of whether any trials on the experiment
# specify a TTL. Since trials need to be checked for their TTL's
# expiration often, having this attribute helps avoid unnecessary
Expand Down Expand Up @@ -1928,6 +1933,57 @@ def default_trial_type(self) -> str | None:
"""
return self._default_trial_type

@property
def trial_type_to_metric_names(self) -> dict[str, set[str]]:
"""Map from trial type to the set of metric names relevant to that
type.

Returns a shallow copy of the internal mapping.
"""
return dict(self._trial_type_to_metric_names)

@property
def metric_to_trial_type(self) -> dict[str, str]:
"""Map each metric name to its associated trial type.

Computed from ``_trial_type_to_metric_names``. When a
``default_trial_type`` is set and an ``optimization_config`` exists,
optimization config metrics are pinned to the default trial type.
"""
result: dict[str, str] = {}
for trial_type, metric_names in self._trial_type_to_metric_names.items():
for name in metric_names:
result[name] = trial_type
opt_config = self._optimization_config
default_trial_type = self._default_trial_type
if default_trial_type is not None and opt_config is not None:
for metric_name in opt_config.metric_names:
result[metric_name] = default_trial_type
return result

def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
"""Return the metrics associated with a given trial type.

Args:
trial_type: The trial type to look up metrics for.

Raises:
ValueError: If the trial type is not supported.
"""
if not self.supports_trial_type(trial_type):
raise ValueError(f"Trial type `{trial_type}` is not supported.")
valid_names = self._trial_type_to_metric_names.get(trial_type, set())
return [self._metrics[name] for name in valid_names if name in self._metrics]

@property
def default_trials(self) -> set[int]:
"""Return the indices for trials of the default type."""
return {
idx
for idx, trial in self.trials.items()
if trial.trial_type == self.default_trial_type
}

def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
"""The default runner to use for a given trial type.

Expand Down
73 changes: 24 additions & 49 deletions ax/core/multi_type_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,16 @@ def __init__(
default_data_type=default_data_type,
)

# Ensure tracking metrics are registered in _metric_to_trial_type.
# Ensure tracking metrics are registered in _metric_to_trial_type
# and _trial_type_to_metric_names.
# super().__init__ sets self._metrics directly, bypassing
# add_tracking_metric, so tracking metrics won't be in
# _metric_to_trial_type yet.
for m in tracking_metrics or []:
if m.name not in self._metric_to_trial_type:
self._metric_to_trial_type[m.name] = none_throws(
self._default_trial_type
)
tt = none_throws(self._default_trial_type)
self._metric_to_trial_type[m.name] = tt
self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name)

def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
"""Add a new trial_type to be supported by this experiment.
Expand All @@ -129,9 +130,9 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
for metric_name in optimization_config.metric_names:
# Optimization config metrics are required to be the default trial type
# currently. TODO: remove that restriction (T202797235)
self._metric_to_trial_type[metric_name] = none_throws(
self.default_trial_type
)
tt = none_throws(self.default_trial_type)
self._metric_to_trial_type[metric_name] = tt
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name)

def update_runner(self, trial_type: str, runner: Runner) -> Self:
"""Update the default runner for an existing trial_type.
Expand Down Expand Up @@ -166,7 +167,9 @@ def add_tracking_metric(
raise ValueError(f"`{trial_type}` is not a supported trial type.")

super().add_tracking_metric(metric)
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
tt = none_throws(trial_type)
self._metric_to_trial_type[metric.name] = tt
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
if canonical_name is not None:
self._metric_to_canonical_name[metric.name] = canonical_name
return self
Expand Down Expand Up @@ -242,7 +245,14 @@ def update_tracking_metric(
raise ValueError(f"`{trial_type}` is not a supported trial type.")

super().update_tracking_metric(metric)
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
# Remove from old trial type set
old_tt = self._metric_to_trial_type.get(metric.name)
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
self._trial_type_to_metric_names[old_tt].discard(metric.name)
# Add to new trial type set
tt = none_throws(trial_type)
self._metric_to_trial_type[metric.name] = tt
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
if canonical_name is not None:
self._metric_to_canonical_name[metric.name] = canonical_name
return self
Expand All @@ -252,6 +262,11 @@ def remove_tracking_metric(self, metric_name: str) -> Self:
if metric_name not in self._metrics:
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")

# Clean up _trial_type_to_metric_names
old_tt = self._metric_to_trial_type.get(metric_name)
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
self._trial_type_to_metric_names[old_tt].discard(metric_name)

# Required fields
del self._metrics[metric_name]
del self._metric_to_trial_type[metric_name]
Expand Down Expand Up @@ -295,46 +310,6 @@ def _fetch_trial_data(
# Invoke parent's fetch method using only metrics for this trial_type
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)

@property
def default_trials(self) -> set[int]:
"""Return the indicies for trials of the default type."""
return {
idx
for idx, trial in self.trials.items()
if trial.trial_type == self.default_trial_type
}

@property
def metric_to_trial_type(self) -> dict[str, str]:
"""Map metrics to trial types.

Adds in default trial type for OC metrics to custom defined trial types..
"""
opt_config_types = {
metric_name: self.default_trial_type
for metric_name in self.optimization_config.metric_names
}
return {**opt_config_types, **self._metric_to_trial_type}

# -- Overridden functions from Base Experiment Class --
@property
def default_trial_type(self) -> str | None:
"""Default trial type assigned to trials in this experiment."""
return self._default_trial_type

def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
"""The default runner to use for a given trial type.

Looks up the appropriate runner for this trial type in the trial_type_to_runner.
"""
if not self.supports_trial_type(trial_type):
raise ValueError(f"Trial type `{trial_type}` is not supported.")
return [
self.metrics[metric_name]
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
if metric_trial_type == trial_type
]

def supports_trial_type(self, trial_type: str | None) -> bool:
"""Whether this experiment allows trials of the given type.

Expand Down
6 changes: 6 additions & 0 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,12 @@ def multi_type_experiment_from_json(
experiment._metric_to_trial_type = _metric_to_trial_type
experiment._trial_type_to_runner = _trial_type_to_runner

# Rebuild _trial_type_to_metric_names from _metric_to_trial_type
trial_type_to_metric_names: dict[str, set[str]] = {}
for metric_name, trial_type in _metric_to_trial_type.items():
trial_type_to_metric_names.setdefault(trial_type, set()).add(metric_name)
experiment._trial_type_to_metric_names = trial_type_to_metric_names

_load_experiment_info(
exp=experiment,
exp_info=experiment_info,
Expand Down