Skip to content

Commit bdaefd0

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Add _trial_type_to_metric_names to base Experiment
Summary: This is Phase 1 of moving MultiTypeExperiment features into the base Experiment class, enabling eventual deprecation of MultiTypeExperiment. Adds `_trial_type_to_metric_names: dict[str, set[str]]` to Experiment — a mapping from trial type to the set of metric names relevant to that type. This is the natural complement to the existing `_trial_type_to_runner` dict. Along with it, adds the following properties and methods to Experiment: - `trial_type_to_metric_names`: read-only property (shallow copy) - `metric_to_trial_type`: computed inverse mapping, with optimization config metrics pinned to `default_trial_type` - `metrics_for_trial_type(trial_type)`: returns Metric objects for a given trial type - `default_trials`: returns trial indices matching the default type MultiTypeExperiment is updated to populate `_trial_type_to_metric_names` alongside `_metric_to_trial_type` in all mutation paths (init, optimization_config setter, add/update/remove tracking metric). The redundant MTE overrides for `metric_to_trial_type`, `metrics_for_trial_type`, `default_trials`, and `default_trial_type` are removed — they are now inherited from the base class. The JSON decoder is updated to rebuild `_trial_type_to_metric_names` from `_metric_to_trial_type` during deserialization for backward compatibility. Differential Revision: D94970662
1 parent bf9f66a commit bdaefd0

3 files changed

Lines changed: 87 additions & 49 deletions

File tree

ax/core/experiment.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ def __init__(
160160
self._trial_type_to_runner: dict[str | None, Runner | None] = {
161161
default_trial_type: runner
162162
}
163+
# Maps each trial type to the set of metric names relevant to that type.
164+
# This is the complement of _trial_type_to_runner and is used for
165+
# multi-type experiments where different metrics apply to different
166+
# trial types.
167+
self._trial_type_to_metric_names: dict[str, set[str]] = {}
163168
# Used to keep track of whether any trials on the experiment
164169
# specify a TTL. Since trials need to be checked for their TTL's
165170
# expiration often, having this attribute helps avoid unnecessary
@@ -1910,6 +1915,58 @@ def default_trial_type(self) -> str | None:
19101915
"""
19111916
return self._default_trial_type
19121917

1918+
@property
1919+
def trial_type_to_metric_names(self) -> dict[str, set[str]]:
1920+
"""Map from trial type to the set of metric names relevant to that
1921+
type.
1922+
1923+
Returns a shallow copy of the internal mapping.
1924+
"""
1925+
return dict(self._trial_type_to_metric_names)
1926+
1927+
@property
1928+
def metric_to_trial_type(self) -> dict[str, str]:
1929+
"""Map each metric name to its associated trial type.
1930+
1931+
Computed from ``_trial_type_to_metric_names``. When a
1932+
``default_trial_type`` is set and an ``optimization_config`` exists,
1933+
optimization config metrics are pinned to the default trial type.
1934+
"""
1935+
result: dict[str, str] = {}
1936+
for trial_type, metric_names in self._trial_type_to_metric_names.items():
1937+
for name in metric_names:
1938+
result[name] = trial_type
1939+
if (
1940+
self._default_trial_type is not None
1941+
and self._optimization_config is not None
1942+
):
1943+
for metric_name in self._optimization_config.metric_names:
1944+
result[metric_name] = self._default_trial_type
1945+
return result
1946+
1947+
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
1948+
"""Return the metrics associated with a given trial type.
1949+
1950+
Args:
1951+
trial_type: The trial type to look up metrics for.
1952+
1953+
Raises:
1954+
ValueError: If the trial type is not supported.
1955+
"""
1956+
if not self.supports_trial_type(trial_type):
1957+
raise ValueError(f"Trial type `{trial_type}` is not supported.")
1958+
valid_names = self._trial_type_to_metric_names.get(trial_type, set())
1959+
return [self._metrics[name] for name in valid_names if name in self._metrics]
1960+
1961+
@property
1962+
def default_trials(self) -> set[int]:
1963+
"""Return the indices for trials of the default type."""
1964+
return {
1965+
idx
1966+
for idx, trial in self.trials.items()
1967+
if trial.trial_type == self.default_trial_type
1968+
}
1969+
19131970
def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
19141971
"""The default runner to use for a given trial type.
19151972

ax/core/multi_type_experiment.py

Lines changed: 24 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,16 @@ def __init__(
9696
default_data_type=default_data_type,
9797
)
9898

99-
# Ensure tracking metrics are registered in _metric_to_trial_type.
99+
# Ensure tracking metrics are registered in _metric_to_trial_type
100+
# and _trial_type_to_metric_names.
100101
# super().__init__ sets self._metrics directly, bypassing
101102
# add_tracking_metric, so tracking metrics won't be in
102103
# _metric_to_trial_type yet.
103104
for m in tracking_metrics or []:
104105
if m.name not in self._metric_to_trial_type:
105-
self._metric_to_trial_type[m.name] = none_throws(
106-
self._default_trial_type
107-
)
106+
tt = none_throws(self._default_trial_type)
107+
self._metric_to_trial_type[m.name] = tt
108+
self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name)
108109

109110
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
110111
"""Add a new trial_type to be supported by this experiment.
@@ -129,9 +130,9 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
129130
for metric_name in optimization_config.metric_names:
130131
# Optimization config metrics are required to be the default trial type
131132
# currently. TODO: remove that restriction (T202797235)
132-
self._metric_to_trial_type[metric_name] = none_throws(
133-
self.default_trial_type
134-
)
133+
tt = none_throws(self.default_trial_type)
134+
self._metric_to_trial_type[metric_name] = tt
135+
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name)
135136

136137
def update_runner(self, trial_type: str, runner: Runner) -> Self:
137138
"""Update the default runner for an existing trial_type.
@@ -166,7 +167,9 @@ def add_tracking_metric(
166167
raise ValueError(f"`{trial_type}` is not a supported trial type.")
167168

168169
super().add_tracking_metric(metric)
169-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
170+
tt = none_throws(trial_type)
171+
self._metric_to_trial_type[metric.name] = tt
172+
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
170173
if canonical_name is not None:
171174
self._metric_to_canonical_name[metric.name] = canonical_name
172175
return self
@@ -242,7 +245,14 @@ def update_tracking_metric(
242245
raise ValueError(f"`{trial_type}` is not a supported trial type.")
243246

244247
super().update_tracking_metric(metric)
245-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
248+
# Remove from old trial type set
249+
old_tt = self._metric_to_trial_type.get(metric.name)
250+
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
251+
self._trial_type_to_metric_names[old_tt].discard(metric.name)
252+
# Add to new trial type set
253+
tt = none_throws(trial_type)
254+
self._metric_to_trial_type[metric.name] = tt
255+
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
246256
if canonical_name is not None:
247257
self._metric_to_canonical_name[metric.name] = canonical_name
248258
return self
@@ -252,6 +262,11 @@ def remove_tracking_metric(self, metric_name: str) -> Self:
252262
if metric_name not in self._metrics:
253263
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
254264

265+
# Clean up _trial_type_to_metric_names
266+
old_tt = self._metric_to_trial_type.get(metric_name)
267+
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
268+
self._trial_type_to_metric_names[old_tt].discard(metric_name)
269+
255270
# Required fields
256271
del self._metrics[metric_name]
257272
del self._metric_to_trial_type[metric_name]
@@ -295,46 +310,6 @@ def _fetch_trial_data(
295310
# Invoke parent's fetch method using only metrics for this trial_type
296311
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
297312

298-
@property
299-
def default_trials(self) -> set[int]:
300-
"""Return the indicies for trials of the default type."""
301-
return {
302-
idx
303-
for idx, trial in self.trials.items()
304-
if trial.trial_type == self.default_trial_type
305-
}
306-
307-
@property
308-
def metric_to_trial_type(self) -> dict[str, str]:
309-
"""Map metrics to trial types.
310-
311-
Adds in default trial type for OC metrics to custom defined trial types..
312-
"""
313-
opt_config_types = {
314-
metric_name: self.default_trial_type
315-
for metric_name in self.optimization_config.metric_names
316-
}
317-
return {**opt_config_types, **self._metric_to_trial_type}
318-
319-
# -- Overridden functions from Base Experiment Class --
320-
@property
321-
def default_trial_type(self) -> str | None:
322-
"""Default trial type assigned to trials in this experiment."""
323-
return self._default_trial_type
324-
325-
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
326-
"""The default runner to use for a given trial type.
327-
328-
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
329-
"""
330-
if not self.supports_trial_type(trial_type):
331-
raise ValueError(f"Trial type `{trial_type}` is not supported.")
332-
return [
333-
self.metrics[metric_name]
334-
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
335-
if metric_trial_type == trial_type
336-
]
337-
338313
def supports_trial_type(self, trial_type: str | None) -> bool:
339314
"""Whether this experiment allows trials of the given type.
340315

ax/storage/json_store/decoder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,12 @@ def multi_type_experiment_from_json(
720720
experiment._metric_to_trial_type = _metric_to_trial_type
721721
experiment._trial_type_to_runner = _trial_type_to_runner
722722

723+
# Rebuild _trial_type_to_metric_names from _metric_to_trial_type
724+
trial_type_to_metric_names: dict[str, set[str]] = {}
725+
for metric_name, trial_type in _metric_to_trial_type.items():
726+
trial_type_to_metric_names.setdefault(trial_type, set()).add(metric_name)
727+
experiment._trial_type_to_metric_names = trial_type_to_metric_names
728+
723729
_load_experiment_info(
724730
exp=experiment,
725731
exp_info=experiment_info,

0 commit comments

Comments
 (0)