Skip to content

Commit 70772ea

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Add _trial_type_to_metric_names to base Experiment (#5004)
Summary: Pull Request resolved: #5004 This is Phase 1 of moving MultiTypeExperiment features into the base Experiment class, enabling eventual deprecation of MultiTypeExperiment. Adds `_trial_type_to_metric_names: dict[str, set[str]]` to Experiment — a mapping from trial type to the set of metric names relevant to that type. This is the natural complement to the existing `_trial_type_to_runner` dict. Along with it, adds the following properties and methods to Experiment: - `trial_type_to_metric_names`: read-only property (shallow copy) - `metric_to_trial_type`: computed inverse mapping, with optimization config metrics pinned to `default_trial_type` - `metrics_for_trial_type(trial_type)`: returns Metric objects for a given trial type - `default_trials`: returns trial indices matching the default type MultiTypeExperiment is updated to populate `_trial_type_to_metric_names` alongside `_metric_to_trial_type` in all mutation paths (init, optimization_config setter, add/update/remove tracking metric). The redundant MTE overrides for `metric_to_trial_type`, `metrics_for_trial_type`, `default_trials`, and `default_trial_type` are removed — they are now inherited from the base class. The JSON decoder is updated to rebuild `_trial_type_to_metric_names` from `_metric_to_trial_type` during deserialization for backward compatibility. Reviewed By: lena-kashtelyan Differential Revision: D94970662
1 parent 6e527d5 commit 70772ea

3 files changed

Lines changed: 86 additions & 49 deletions

File tree

ax/core/experiment.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def __init__(
161161
self._trial_type_to_runner: dict[str | None, Runner | None] = {
162162
default_trial_type: runner
163163
}
164+
# Maps each trial type to the set of metric names relevant to that type.
165+
# This is the complement of _trial_type_to_runner and is used for
166+
# multi-type experiments where different metrics apply to different
167+
# trial types.
168+
self._trial_type_to_metric_names: dict[str, set[str]] = {}
164169
# Used to keep track of whether any trials on the experiment
165170
# specify a TTL. Since trials need to be checked for their TTL's
166171
# expiration often, having this attribute helps avoid unnecessary
@@ -1928,6 +1933,57 @@ def default_trial_type(self) -> str | None:
19281933
"""
19291934
return self._default_trial_type
19301935

1936+
@property
1937+
def trial_type_to_metric_names(self) -> dict[str, set[str]]:
1938+
"""Map from trial type to the set of metric names relevant to that
1939+
type.
1940+
1941+
Returns a shallow copy of the internal mapping.
1942+
"""
1943+
return dict(self._trial_type_to_metric_names)
1944+
1945+
@property
1946+
def metric_to_trial_type(self) -> dict[str, str]:
1947+
"""Map each metric name to its associated trial type.
1948+
1949+
Computed from ``_trial_type_to_metric_names``. When a
1950+
``default_trial_type`` is set and an ``optimization_config`` exists,
1951+
optimization config metrics are pinned to the default trial type.
1952+
"""
1953+
result: dict[str, str] = {}
1954+
for trial_type, metric_names in self._trial_type_to_metric_names.items():
1955+
for name in metric_names:
1956+
result[name] = trial_type
1957+
opt_config = self._optimization_config
1958+
default_trial_type = self._default_trial_type
1959+
if default_trial_type is not None and opt_config is not None:
1960+
for metric_name in opt_config.metric_names:
1961+
result[metric_name] = default_trial_type
1962+
return result
1963+
1964+
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
1965+
"""Return the metrics associated with a given trial type.
1966+
1967+
Args:
1968+
trial_type: The trial type to look up metrics for.
1969+
1970+
Raises:
1971+
ValueError: If the trial type is not supported.
1972+
"""
1973+
if not self.supports_trial_type(trial_type):
1974+
raise ValueError(f"Trial type `{trial_type}` is not supported.")
1975+
valid_names = self._trial_type_to_metric_names.get(trial_type, set())
1976+
return [self._metrics[name] for name in valid_names if name in self._metrics]
1977+
1978+
@property
1979+
def default_trials(self) -> set[int]:
1980+
"""Return the indices for trials of the default type."""
1981+
return {
1982+
idx
1983+
for idx, trial in self.trials.items()
1984+
if trial.trial_type == self.default_trial_type
1985+
}
1986+
19311987
def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
19321988
"""The default runner to use for a given trial type.
19331989

ax/core/multi_type_experiment.py

Lines changed: 24 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,16 @@ def __init__(
9696
default_data_type=default_data_type,
9797
)
9898

99-
# Ensure tracking metrics are registered in _metric_to_trial_type.
99+
# Ensure tracking metrics are registered in _metric_to_trial_type
100+
# and _trial_type_to_metric_names.
100101
# super().__init__ sets self._metrics directly, bypassing
101102
# add_tracking_metric, so tracking metrics won't be in
102103
# _metric_to_trial_type yet.
103104
for m in tracking_metrics or []:
104105
if m.name not in self._metric_to_trial_type:
105-
self._metric_to_trial_type[m.name] = none_throws(
106-
self._default_trial_type
107-
)
106+
tt = none_throws(self._default_trial_type)
107+
self._metric_to_trial_type[m.name] = tt
108+
self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name)
108109

109110
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
110111
"""Add a new trial_type to be supported by this experiment.
@@ -129,9 +130,9 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
129130
for metric_name in optimization_config.metric_names:
130131
# Optimization config metrics are required to be the default trial type
131132
# currently. TODO: remove that restriction (T202797235)
132-
self._metric_to_trial_type[metric_name] = none_throws(
133-
self.default_trial_type
134-
)
133+
tt = none_throws(self.default_trial_type)
134+
self._metric_to_trial_type[metric_name] = tt
135+
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name)
135136

136137
def update_runner(self, trial_type: str, runner: Runner) -> Self:
137138
"""Update the default runner for an existing trial_type.
@@ -166,7 +167,9 @@ def add_tracking_metric(
166167
raise ValueError(f"`{trial_type}` is not a supported trial type.")
167168

168169
super().add_tracking_metric(metric)
169-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
170+
tt = none_throws(trial_type)
171+
self._metric_to_trial_type[metric.name] = tt
172+
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
170173
if canonical_name is not None:
171174
self._metric_to_canonical_name[metric.name] = canonical_name
172175
return self
@@ -242,7 +245,14 @@ def update_tracking_metric(
242245
raise ValueError(f"`{trial_type}` is not a supported trial type.")
243246

244247
super().update_tracking_metric(metric)
245-
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
248+
# Remove from old trial type set
249+
old_tt = self._metric_to_trial_type.get(metric.name)
250+
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
251+
self._trial_type_to_metric_names[old_tt].discard(metric.name)
252+
# Add to new trial type set
253+
tt = none_throws(trial_type)
254+
self._metric_to_trial_type[metric.name] = tt
255+
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
246256
if canonical_name is not None:
247257
self._metric_to_canonical_name[metric.name] = canonical_name
248258
return self
@@ -252,6 +262,11 @@ def remove_tracking_metric(self, metric_name: str) -> Self:
252262
if metric_name not in self._metrics:
253263
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
254264

265+
# Clean up _trial_type_to_metric_names
266+
old_tt = self._metric_to_trial_type.get(metric_name)
267+
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
268+
self._trial_type_to_metric_names[old_tt].discard(metric_name)
269+
255270
# Required fields
256271
del self._metrics[metric_name]
257272
del self._metric_to_trial_type[metric_name]
@@ -295,46 +310,6 @@ def _fetch_trial_data(
295310
# Invoke parent's fetch method using only metrics for this trial_type
296311
return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs)
297312

298-
@property
299-
def default_trials(self) -> set[int]:
300-
"""Return the indicies for trials of the default type."""
301-
return {
302-
idx
303-
for idx, trial in self.trials.items()
304-
if trial.trial_type == self.default_trial_type
305-
}
306-
307-
@property
308-
def metric_to_trial_type(self) -> dict[str, str]:
309-
"""Map metrics to trial types.
310-
311-
Adds in default trial type for OC metrics to custom defined trial types..
312-
"""
313-
opt_config_types = {
314-
metric_name: self.default_trial_type
315-
for metric_name in self.optimization_config.metric_names
316-
}
317-
return {**opt_config_types, **self._metric_to_trial_type}
318-
319-
# -- Overridden functions from Base Experiment Class --
320-
@property
321-
def default_trial_type(self) -> str | None:
322-
"""Default trial type assigned to trials in this experiment."""
323-
return self._default_trial_type
324-
325-
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
326-
"""The default runner to use for a given trial type.
327-
328-
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
329-
"""
330-
if not self.supports_trial_type(trial_type):
331-
raise ValueError(f"Trial type `{trial_type}` is not supported.")
332-
return [
333-
self.metrics[metric_name]
334-
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
335-
if metric_trial_type == trial_type
336-
]
337-
338313
def supports_trial_type(self, trial_type: str | None) -> bool:
339314
"""Whether this experiment allows trials of the given type.
340315

ax/storage/json_store/decoder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,12 @@ def multi_type_experiment_from_json(
720720
experiment._metric_to_trial_type = _metric_to_trial_type
721721
experiment._trial_type_to_runner = _trial_type_to_runner
722722

723+
# Rebuild _trial_type_to_metric_names from _metric_to_trial_type
724+
trial_type_to_metric_names: dict[str, set[str]] = {}
725+
for metric_name, trial_type in _metric_to_trial_type.items():
726+
trial_type_to_metric_names.setdefault(trial_type, set()).add(metric_name)
727+
experiment._trial_type_to_metric_names = trial_type_to_metric_names
728+
723729
_load_experiment_info(
724730
exp=experiment,
725731
exp_info=experiment_info,

0 commit comments

Comments
 (0)