diff --git a/doc/changes/dev/13680.other.rst b/doc/changes/dev/13680.other.rst new file mode 100644 index 00000000000..affeff62c03 --- /dev/null +++ b/doc/changes/dev/13680.other.rst @@ -0,0 +1 @@ +Document :attr:`~mne.Annotations.onset`, :attr:`~mne.Annotations.duration`, :attr:`~mne.Annotations.description`, and :attr:`~mne.Annotations.ch_names` attributes of :class:`mne.Annotations`, by `Famous077`_. \ No newline at end of file diff --git a/mne/annotations.py b/mne/annotations.py index c03e0610f28..1ef323e5234 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -158,48 +158,82 @@ def _validate_extras(extras, length: int): return _AnnotationsExtrasList(extras or [None] * length) -def _check_o_d_s_c_e(onset, duration, description, ch_names, extras): +def _check_onset(onset, n=None): + """Convert and validate onset to a 1D float array.""" onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: raise ValueError( f"Onset must be a one dimensional array, got {onset.ndim} (shape " f"{onset.shape})." ) + if n is not None and len(onset) != n: + raise ValueError( + f"Length of onset ({len(onset)}) must match the length of " + f"existing annotations ({n})." + ) + return onset + + +def _check_duration(duration, n): + """Convert and validate duration to a 1D float array of length n.""" duration = np.array(duration, dtype=float) if duration.ndim == 0 or duration.shape == (1,): - duration = np.repeat(duration, len(onset)) + duration = np.repeat(duration, n) if duration.ndim != 1: raise ValueError( f"Duration must be a one dimensional array, got {duration.ndim}." ) + if len(duration) != n: + raise ValueError( + f"Length of duration ({len(duration)}) must match the length of " + f"existing onset ({n})." + ) + return duration - description = np.array(description, dtype=str) + +def _check_description(description, n): + """Convert and validate description to a 1D str array of length n.""" + description = np.atleast_1d(np.array(description, dtype=str)) if description.ndim == 0 or description.shape == (1,): - description = np.repeat(description, len(onset)) + description = np.repeat(description, n) if description.ndim != 1: raise ValueError( f"Description must be a one dimensional array, got {description.ndim}." ) + if len(description) != n: + raise ValueError( + f"Length of description ({len(description)}) must match the " + f"length of existing onset ({n})." + ) _safe_name_list(description, "write", "description") + return description - # ch_names: convert to ndarray of tuples + +def _check_ch_names_annot(ch_names, n): + """Convert and validate ch_names to an ndarray of tuples of length n.""" _validate_type(ch_names, (None, tuple, list, np.ndarray), "ch_names") if ch_names is None: - ch_names = [()] * len(onset) + ch_names = [()] * n ch_names = list(ch_names) + if len(ch_names) != n: + raise ValueError( + f"Length of ch_names ({len(ch_names)}) must match the length of " + f"existing annotations ({n})." + ) for ai, ch in enumerate(ch_names): _validate_type(ch, (list, tuple, np.ndarray), f"ch_names[{ai}]") ch_names[ai] = tuple(ch) for ci, name in enumerate(ch_names[ai]): _validate_type(name, str, f"ch_names[{ai}][{ci}]") - ch_names = _ndarray_ch_names(ch_names) + return _ndarray_ch_names(ch_names) - if not (len(onset) == len(duration) == len(description) == len(ch_names)): - raise ValueError( - "Onset, duration, description, and ch_names must be " - f"equal in sizes, got {len(onset)}, {len(duration)}, " - f"{len(description)}, and {len(ch_names)}." - ) + +def _check_o_d_s_c_e(onset, duration, description, ch_names, extras): + onset = _check_onset(onset) + n = len(onset) + duration = _check_duration(duration, n) + description = _check_description(description, n) + ch_names = _check_ch_names_annot(ch_names, n) extras = _validate_extras(extras, len(onset)) return onset, duration, description, ch_names, extras @@ -408,7 +442,7 @@ def __init__( f"' '. Got: {orig_time}. Defaulting `orig_time` to None.", RuntimeWarning, ) - self.onset, self.duration, self.description, self.ch_names, self._extras = ( + self._onset, self._duration, self._description, self._ch_names, self._extras = ( _check_o_d_s_c_e(onset, duration, description, ch_names, extras) ) self._sort() # ensure we're sorted @@ -418,6 +452,96 @@ def orig_time(self): """The time base of the Annotations.""" return self._orig_time + @property + def onset(self): + """Onset of each annotation (in seconds). + + Returns + ------- + onset : array of shape (n_annotations,) + The onset of each annotation in seconds from the start of + the recording. + + See Also + -------- + :attr:`~mne.Annotations.duration` + :attr:`~mne.Annotations.description` + """ + return self._onset + + @onset.setter + def onset(self, onset): + onset = _check_onset(onset, n=len(self._onset)) + self._onset = onset + + @property + def duration(self): + """Duration of each annotation (in seconds). + + Returns + ------- + duration : array of shape (n_annotations,) + The duration of each annotation in seconds. + + See Also + -------- + :attr:`~mne.Annotations.onset` + :attr:`~mne.Annotations.description` + """ + return self._duration + + @duration.setter + def duration(self, duration): + n = len(self._duration) + duration = _check_duration(duration, n) + self._duration = duration + + @property + def description(self): + """Description of each annotation. + + Returns + ------- + description : array of shape (n_annotations,) + A string description for each annotation (e.g., event + label or condition name). + + See Also + -------- + :attr:`~mne.Annotations.onset` + :attr:`~mne.Annotations.duration` + """ + return self._description + + @description.setter + def description(self, description): + n = len(self._description) + description = _check_description(description, n) + self._description = description + + @property + def ch_names(self): + """Channel names associated with each annotation. + + Returns + ------- + ch_names : list of tuple + Channel names associated with each annotation. + + See Also + -------- + :attr:`~mne.Annotations.onset` + :attr:`~mne.Annotations.duration` + :attr:`~mne.Annotations.description` + """ + return self._ch_names + + @ch_names.setter + def ch_names(self, ch_names): + n = len(self._ch_names) + ch_names = _check_ch_names_annot(ch_names, n) + self._ch_names = ch_names + @property def extras(self): """The extras of the Annotations. @@ -573,11 +697,15 @@ def append(self, onset, duration, description, ch_names=None, *, extras=None): onset, duration, description, ch_names, extras = _check_o_d_s_c_e( onset, duration, description, ch_names, extras ) - self.onset = np.append(self.onset, onset) - self.duration = np.append(self.duration, duration) - self.description = np.append(self.description, description) - self.ch_names = np.append(self.ch_names, ch_names) - self.extras.extend(extras) + # Write directly to private attributes to avoid triggering the public + # setter validation, which would raise an error due to temporary length + # mismatches while fields are being extended one at a time. + # The data is already validated by _check_o_d_s_c_e above. + self._onset = np.append(self._onset, onset) + self._duration = np.append(self._duration, duration) + self._description = np.append(self._description, description) + self._ch_names = np.append(self._ch_names, ch_names) + self._extras.extend(extras) self._sort() return self @@ -600,10 +728,10 @@ def delete(self, idx): Index of the annotation to remove. Can be array-like to remove multiple indices. """ - self.onset = np.delete(self.onset, idx) - self.duration = np.delete(self.duration, idx) - self.description = np.delete(self.description, idx) - self.ch_names = np.delete(self.ch_names, idx) + self._onset = np.delete(self._onset, idx) + self._duration = np.delete(self._duration, idx) + self._description = np.delete(self._description, idx) + self._ch_names = np.delete(self._ch_names, idx) if isinstance(idx, int_like): del self.extras[idx] elif len(idx) > 0: @@ -740,11 +868,11 @@ def _sort(self): # the onset-then-duration hierarchy vals = sorted(zip(self.onset, self.duration, range(len(self)))) order = list(list(zip(*vals))[-1]) if len(vals) else [] - self.onset = self.onset[order] - self.duration = self.duration[order] - self.description = self.description[order] - self.ch_names = self.ch_names[order] - self.extras = [self.extras[i] for i in order] + self._onset = self._onset[order] + self._duration = self._duration[order] + self._description = self._description[order] + self._ch_names = self._ch_names[order] + self._extras = [self._extras[i] for i in order] return order def _get_crop_lims(self, tmin, tmax, use_orig_time): @@ -848,12 +976,12 @@ def crop( ch_names.append(ch) extras.append(extra) logger.debug(f"Cropping complete (kept {len(onsets)})") - self.onset = np.array(onsets, float) - self.duration = np.array(durations, float) - assert (self.duration >= 0).all() - self.description = np.array(descriptions, dtype=str) - self.ch_names = _ndarray_ch_names(ch_names) - self.extras = extras + self._onset = np.array(onsets, float) + self._duration = np.array(durations, float) + assert (self._duration >= 0).all() + self._description = np.array(descriptions, dtype=str) + self._ch_names = _ndarray_ch_names(ch_names) + self._extras = extras if emit_warning: omitted = np.array(out_of_bounds).sum() @@ -1168,11 +1296,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack from serialized format.""" self._orig_time = state["_orig_time"] - self.onset = state["onset"] - self.duration = state["duration"] - self.description = state["description"] - self.ch_names = state["ch_names"] - self.extras = state.get("_extras", [None] * len(self.onset)) + self._onset, self._duration, self._description, self._ch_names, self._extras = ( + _check_o_d_s_c_e( + state["onset"], + state["duration"], + state["description"], + state["ch_names"], + state.get("_extras", None), + ) + ) self._hed_version = state["_hed_version"] self.hed_string = _HEDStrings( state["hed_string"], hed_version=self._hed_version @@ -1211,6 +1343,7 @@ def append( onset, duration, description, ch_names, extras ) hed_string = self._check_hed_strings(hed_string, len(onset)) + hed_objs = [ self.hed_string._validate_hed_string(v, self.hed_string._schema) for v in hed_string @@ -1229,7 +1362,6 @@ def append( def __iadd__(self, other): """Add (concatenate) two HEDAnnotations objects in-place.""" if not isinstance(other, type(self)): - # Convert self to plain Annotations, preserving HED in extras extras = _hed_extras_from_hed_annotations(self) result = Annotations( onset=self.onset, diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 1c9a1416a29..b91627ef7da 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -1752,6 +1752,41 @@ def test_annotation_duration_setting(): a.set_durations({"aaa", 2.2}) +def test_setter_validation(): + """Test that onset/duration/description/ch_names setters validate length.""" + annots = Annotations(onset=[1, 3, 2, 4], duration=0, description="foo") + + # onset mismatch should raise + with pytest.raises(ValueError, match="Length of onset"): + annots.onset = annots.onset[:2] + + # duration mismatch should raise + with pytest.raises(ValueError, match="Length of duration"): + annots.duration = annots.duration[:2] + + # description mismatch should raise (the bug drammock reported) + with pytest.raises(ValueError, match="Length of description"): + annots.description = annots.description[:2] + + # scalar duration should broadcast without error + annots.duration = 1.0 + assert len(annots.duration) == 4 + assert all(annots.duration == 1.0) + + # scalar description should broadcast without error + annots.description = "bad" + assert len(annots.description) == 4 + assert all(annots.description == "bad") + + # ch_names mismatch should raise + with pytest.raises(ValueError, match="Length of ch_names"): + annots.ch_names = [(), ()] + + # valid ch_names assignment (correct length) should succeed + annots.ch_names = [("MEG 0111",), (), (), ()] + assert annots.ch_names[0] == ("MEG 0111",) + + @pytest.mark.parametrize("meas_date", (None, 1)) @pytest.mark.parametrize("set_meas_date", ("before", "after")) @pytest.mark.parametrize("first_samp", (0, 100, 3000))