From 3c37a018395a6f839b907c5d3eafcd56c6ea984f Mon Sep 17 00:00:00 2001 From: heyan3 Date: Sun, 5 Apr 2026 16:29:44 +0800 Subject: [PATCH 01/21] added mvcl_training_sleepedf_task --- pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/mvcl_training_sleepedf_task.py | 179 ++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 pyhealth/tasks/mvcl_training_sleepedf_task.py diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..c64bc6471 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -68,3 +68,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .mvcl_training_sleepedf_task import MVCLTrainingSleepEEG diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py new file mode 100644 index 000000000..81cebfd9e --- /dev/null +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -0,0 +1,179 @@ +"""PyHealth task: SleepEDF → TF-C / SleepEEG-style fixed-length windows. + +Produces samples aligned with the tensor layout used in +`mims-harvard/TFC-pretraining`: a dict with ``samples`` of shape +``[N, n_channels, L]`` and ``labels`` of shape ``[N]``. See the upstream +`dataloader.py` (expects channel in dimension 1, then crops to +``TSlength_aligned``, default 178). + +References: +- https://github.com/mims-harvard/TFC-pretraining +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import mne +import numpy as np +import torch + +from pyhealth.tasks import BaseTask + + +def _map_to_MVCL_five_class(pyhealth_stage: int) -> int: + """Map PyHealth 6-class staging to 5-class AASM-style (N3+N4 → deep).""" + # PyHealth: W=0, N1=1, N2=2, N3=3, N4=4, R=5 + return (0, 1, 2, 3, 3, 4)[int(pyhealth_stage)] + + +class MVCLTrainingSleepEEG(BaseTask): + """Short EEG windows from SleepEDFDataset for time-series pretraining (TF-C style). 
+ + Reads each recording like ``SleepStagingSleepEDF`` (30 s scored epochs), picks a + single EEG lead, then splits each epoch into non-overlapping windows of + ``window_size`` samples (default 200 @ 100 Hz, consistent with the SleepEEG + description in the TF-C paper). Each window inherits the epoch's sleep-stage + label, remapped to 5 classes (N3 and N4 both map to deep sleep). + + Output samples are dicts with ``signal`` shaped ``(1, L)`` where ``L`` is + ``window_size`` or ``crop_length`` + """ + + task_name: str = "MVCLTrainingSleepEEG" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + def __init__( + self, + chunk_duration: float = 30.0, + window_size: int = 200, + crop_length: Optional[int] = 178, + eeg_channel: Optional[str] = "EEG Fpz-Cz", + ) -> None: + """ + Args: + chunk_duration: Hypnogram epoch length in seconds (PyHealth default 30). + window_size: Non-overlapping window length in samples (TF-C SleepEEG: 200). + crop_length: If set, each window is truncated to this many samples from + the start (TF-C code often uses 178 for cross-dataset alignment). + eeg_channel: MNE channel name to keep (single lead). If None or name + missing, falls back to the first channel whose name contains ``EEG``, + then index 0. 
+ """ + self.chunk_duration = float(chunk_duration) + self.window_size = int(window_size) + self.crop_length = int(crop_length) if crop_length is not None else None + self.eeg_channel = eeg_channel + super().__init__() + + def _pick_eeg_index(self, ch_names: List[str]) -> int: + if self.eeg_channel and self.eeg_channel in ch_names: + return ch_names.index(self.eeg_channel) + for i, n in enumerate(ch_names): + if "eeg" in n.lower(): + return i + return 0 + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + pid = patient.patient_id + events = patient.get_events() + samples: List[Dict[str, Any]] = [] + + event_id = { + "Sleep stage W": 0, + "Sleep stage 1": 1, + "Sleep stage 2": 2, + "Sleep stage 3": 3, + "Sleep stage 4": 4, + "Sleep stage R": 5, + } + + win = self.window_size + crop = self.crop_length + + global_epoch = 0 + for event in events: + if not event.signal_file or not event.label_file: + continue + data = mne.io.read_raw_edf( + event.signal_file, + stim_channel="Event marker", + infer_types=True, + preload=True, + verbose="error", + ) + ann = mne.read_annotations(event.label_file) + data.set_annotations(ann, emit_warning=False) + + ann_events, event_id_used = mne.events_from_annotations( + data, event_id=event_id, chunk_duration=self.chunk_duration + ) + if ann_events.size == 0: + continue + + # Use MNE's filtered event_id (only stages present in this night). Passing + # the full 6-stage dict causes ValueError for missing N3/N4/etc. on some nights. 
+ epochs_train = mne.Epochs( + data, + ann_events, + event_id_used, + tmin=0.0, + tmax=self.chunk_duration - 1.0 / data.info["sfreq"], + baseline=None, + preload=True, + on_missing="ignore", + verbose="error", + ) + + ch_i = self._pick_eeg_index(list(epochs_train.ch_names)) + signals = epochs_train.get_data()[:, ch_i, :] + labels = epochs_train.events[:, 2] + + n_epochs, n_times = signals.shape + n_full = (n_times // win) * win + + for epi in range(n_epochs): + lab = _map_to_MVCL_five_class(int(labels[epi])) + row = signals[epi, :n_full] + for w in range(n_full // win): + seg = row[w * win : (w + 1) * win].astype(np.float32, copy=False) + if crop is not None: + seg = seg[:crop] + vec = seg[np.newaxis, :] + samples.append( + { + "patient_id": pid, + "night": event.night, + "patient_age": event.age, + "patient_sex": event.sex, + "epoch_index": global_epoch, + "window_in_epoch": w, + "signal": vec, + "label": lab, + } + ) + global_epoch += 1 + + return samples + + +def stack_samples_to_mvcl_dict(samples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + """Stack PyHealth task outputs into a TF-C-style tensor dict.""" + if not samples: + raise ValueError("empty sample list") + xs = np.stack([np.asarray(s["signal"]) for s in samples], axis=0) + ys = np.array([int(s["label"]) for s in samples], dtype=np.int64) + x = torch.from_numpy(np.ascontiguousarray(xs)).float() + y = torch.from_numpy(ys).long() + if x.ndim == 2: + x = x.unsqueeze(1) + return {"samples": x, "labels": y} + + +def save_mvcl_pt( + tensor_dict: Dict[str, torch.Tensor], + path: str, +) -> None: + """Save ``{"samples", "labels"}`` in PyTorch ``torch.save`` format (``.pt``).""" + torch.save(tensor_dict, path) From 10f4c6db09732003209d594788df050b0c366366 Mon Sep 17 00:00:00 2001 From: gaohey Date: Mon, 6 Apr 2026 22:32:14 +0800 Subject: [PATCH 02/21] task outputs xt,dx,xf --- pyhealth/tasks/mvcl_training_sleepedf_task.py | 170 +++++++++++++----- 1 file changed, 123 insertions(+), 47 deletions(-) diff 
--git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index 81cebfd9e..14610933e 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -12,32 +12,31 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional - +from typing import Any, Dict, List, Optional, Tuple import mne import numpy as np import torch +import torch.fft as fft from pyhealth.tasks import BaseTask + def _map_to_MVCL_five_class(pyhealth_stage: int) -> int: """Map PyHealth 6-class staging to 5-class AASM-style (N3+N4 → deep).""" - # PyHealth: W=0, N1=1, N2=2, N3=3, N4=4, R=5 return (0, 1, 2, 3, 3, 4)[int(pyhealth_stage)] class MVCLTrainingSleepEEG(BaseTask): - """Short EEG windows from SleepEDFDataset for time-series pretraining (TF-C style). + """SleepEDF windows with Multi-View contrastive tensor views. - Reads each recording like ``SleepStagingSleepEDF`` (30 s scored epochs), picks a - single EEG lead, then splits each epoch into non-overlapping windows of - ``window_size`` samples (default 200 @ 100 Hz, consistent with the SleepEEG - description in the TF-C paper). Each window inherits the epoch's sleep-stage - label, remapped to 5 classes (N3 and N4 both map to deep sleep). + Applies MV preprocessing per event file (one PSG/Hypnogram pair at a time), + then appends samples immediately, so each returned sample includes ``xt``, + ``dx``, and ``xf`` without a patient-level global buffer. - Output samples are dicts with ``signal`` shaped ``(1, L)`` where ``L`` is - ``window_size`` or ``crop_length`` + Tensors are stored as ``numpy.float32`` arrays with shape ``(L, C_view)`` where + ``C_view`` is 1 by default; with ``time_as_feature=True``, a leading time channel + in ``[0,1]`` is concatenated so ``C_view`` is 2. 
""" task_name: str = "MVCLTrainingSleepEEG" @@ -50,21 +49,16 @@ def __init__( window_size: int = 200, crop_length: Optional[int] = 178, eeg_channel: Optional[str] = "EEG Fpz-Cz", + time_as_feature: bool = False, + dx_backend: str = "cde", ) -> None: - """ - Args: - chunk_duration: Hypnogram epoch length in seconds (PyHealth default 30). - window_size: Non-overlapping window length in samples (TF-C SleepEEG: 200). - crop_length: If set, each window is truncated to this many samples from - the start (TF-C code often uses 178 for cross-dataset alignment). - eeg_channel: MNE channel name to keep (single lead). If None or name - missing, falls back to the first channel whose name contains ``EEG``, - then index 0. - """ self.chunk_duration = float(chunk_duration) self.window_size = int(window_size) self.crop_length = int(crop_length) if crop_length is not None else None self.eeg_channel = eeg_channel + # ``False`` matches ``preprocess_data`` defaults in MV run_pretrain / run_finetune. + self.time_as_feature = bool(time_as_feature) + super().__init__() def _pick_eeg_index(self, ch_names: List[str]) -> int: @@ -112,8 +106,6 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if ann_events.size == 0: continue - # Use MNE's filtered event_id (only stages present in this night). Passing - # the full 6-stage dict causes ValueError for missing N3/N4/etc. on some nights. 
epochs_train = mne.Epochs( data, ann_events, @@ -133,6 +125,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: n_epochs, n_times = signals.shape n_full = (n_times // win) * win + event_buffers: List[Dict[str, Any]] = [] for epi in range(n_epochs): lab = _map_to_MVCL_five_class(int(labels[epi])) row = signals[epi, :n_full] @@ -140,40 +133,123 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: seg = row[w * win : (w + 1) * win].astype(np.float32, copy=False) if crop is not None: seg = seg[:crop] - vec = seg[np.newaxis, :] - samples.append( + event_buffers.append( { - "patient_id": pid, + "seg_1d": seg.copy(), + "label": lab, "night": event.night, "patient_age": event.age, "patient_sex": event.sex, "epoch_index": global_epoch, "window_in_epoch": w, - "signal": vec, - "label": lab, } ) global_epoch += 1 + if not event_buffers: + continue + + X = torch.stack( + [torch.from_numpy(b["seg_1d"]).float() for b in event_buffers], dim=0 + ).unsqueeze(-1) + xt, dx, xf = preprocess_mvcl_views( + X, + time_as_feature=self.time_as_feature + ) + + for i, b in enumerate(event_buffers): + seg = b["seg_1d"] + vec = seg[np.newaxis, :] + samples.append( + { + "patient_id": pid, + "night": b["night"], + "patient_age": b["patient_age"], + "patient_sex": b["patient_sex"], + "epoch_index": b["epoch_index"], + "window_in_epoch": b["window_in_epoch"], + "signal": vec, + "xt": xt[i].detach().cpu().numpy().astype(np.float16), + "dx": dx[i].detach().cpu().numpy().astype(np.float16), + "xf": xf[i].detach().cpu().numpy().astype(np.float16), + "label": b["label"], + } + ) + return samples -def stack_samples_to_mvcl_dict(samples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: - """Stack PyHealth task outputs into a TF-C-style tensor dict.""" - if not samples: - raise ValueError("empty sample list") - xs = np.stack([np.asarray(s["signal"]) for s in samples], axis=0) - ys = np.array([int(s["label"]) for s in samples], dtype=np.int64) - x = 
torch.from_numpy(np.ascontiguousarray(xs)).float() - y = torch.from_numpy(ys).long() - if x.ndim == 2: - x = x.unsqueeze(1) - return {"samples": x, "labels": y} - - -def save_mvcl_pt( - tensor_dict: Dict[str, torch.Tensor], - path: str, -) -> None: - """Save ``{"samples", "labels"}`` in PyTorch ``torch.save`` format (``.pt``).""" - torch.save(tensor_dict, path) +def normalize_mvcl( + X_train: torch.Tensor, + X_test: torch.Tensor, + epsilon: float = 1e-8, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + mean = X_train.mean(dim=(0, 1), keepdim=True) + std = X_train.std(dim=(0, 1), keepdim=True).clamp(min=epsilon) + return ( + (X_train - mean) / std, + (X_test - mean) / std, + mean, + std, + ) + + +def add_time_feature(X: torch.Tensor) -> torch.Tensor: + """X: [num_samples, sequence_length, num_features] -> concat time in last dim.""" + num_samples, seq_length, _ = X.shape + time_index = torch.linspace(0, 1, steps=seq_length, dtype=X.dtype, device=X.device) + time_feature = time_index.view(1, seq_length, 1).expand(num_samples, seq_length, 1) + return torch.cat([time_feature, X], dim=-1) + + + +def get_dx_gradient(X: torch.Tensor) -> torch.Tensor: + """Time derivative via ``torch.gradient`` along **dim=1** for **X [N, L, D]**. + + This is **not** equivalent to :func:`get_dx` (torchcde spline); see module docstring. + """ + if X.ndim != 3: + raise ValueError(f"Expected [N, L, D], got {tuple(X.shape)}") + return torch.gradient(X, dim=1)[0] + + + +def get_xf(X: torch.Tensor) -> torch.Tensor: + return torch.abs(fft.fft(X, dim=1)) + + +def preprocess_mvcl_views( + X: torch.Tensor, + time_as_feature: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Apply the same logical pipeline as ``preprocess_data`` when train and test + are the same tensor (per-domain batch, e.g. all windows of one patient). + + X: float tensor **[N, L, D]** (e.g. ``D=1`` for single-channel EEG; **L** is time length). 
+ + dx_backend: ``"cde"`` (default, torchcde, matches MV) or ``"gradient"`` (no torchcde). + + Returns xt, dx, xf each [N, L, D] or [N, L, D+1] if ``time_as_feature``. + """ + if X.ndim != 3: + raise ValueError(f"Expected X shape [N, L, D], got {tuple(X.shape)}") + xt_tr, xt_te, _, _ = normalize_mvcl(X, X) + xt = xt_tr + + + dx_raw = get_dx_gradient(xt) + + dx_tr, dx_te, _, _ = normalize_mvcl(dx_raw, dx_raw) + dx = dx_tr + + xf_raw = get_xf(xt) + xf_tr, xf_te, _, _ = normalize_mvcl(xf_raw, xf_raw) + xf = xf_tr + + if time_as_feature: + xt = add_time_feature(xt) + dx = add_time_feature(dx) + xf = add_time_feature(xf) + + return xt, dx, xf From 679b4ba029946fe899b4d8a2fa126dfa3e64b4dc Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Thu, 9 Apr 2026 17:50:35 +0530 Subject: [PATCH 03/21] Set input_schema and output_schema and do not redefine it --- pyhealth/tasks/mvcl_training_sleepedf_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index 14610933e..de0283971 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -40,8 +40,8 @@ class MVCLTrainingSleepEEG(BaseTask): """ task_name: str = "MVCLTrainingSleepEEG" - input_schema: Dict[str, str] = {"signal": "tensor"} - output_schema: Dict[str, str] = {"label": "multiclass"} + input_schema = {"signal": "tensor"} + output_schema = {"label": "multiclass"} def __init__( self, From 202311ffe5a57aa2f4c3b9c6e8fc884ec9094d98 Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Sun, 12 Apr 2026 03:42:09 +0530 Subject: [PATCH 04/21] Init commit for the model implementation --- pyhealth/datasets/sleepedf.py | 2 + .../models/multi_view_contrastive_model.py | 139 ++++++++++++++++++ pyhealth/tasks/mvcl_training_sleepedf_task.py | 6 +- 3 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 pyhealth/models/multi_view_contrastive_model.py diff 
--git a/pyhealth/datasets/sleepedf.py b/pyhealth/datasets/sleepedf.py index 71ef3a450..979337539 100644 --- a/pyhealth/datasets/sleepedf.py +++ b/pyhealth/datasets/sleepedf.py @@ -58,6 +58,7 @@ def __init__( dataset_name: Optional[str] = None, config_path: Optional[str] = None, subset: Optional[str] = "cassette", + dev: bool = False, ) -> None: subset = (subset or "cassette").lower() if subset not in {"cassette", "telemetry"}: @@ -87,6 +88,7 @@ def __init__( tables=default_tables, dataset_name=dataset_name or "sleepedf", config_path=config_path, + dev=dev, ) def prepare_metadata_cassette(self, root: str) -> None: diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py new file mode 100644 index 000000000..1e8ba8b8f --- /dev/null +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from pyhealth.models import BaseModel +from typing import cast + +class MultiViewContrastiveModel(BaseModel): + """A simple multi-view contrastive model for demonstration purposes.""" + + def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): + super().__init__(dataset=dataset) + self.hidden_dim = 128 + seq_length = 256 + self.training_stage = training_stage + self.lambda_cl = 0.1 + self.tau = 0.07 + + self.proj_t = nn.Linear(1, self.hidden_dim) + self.proj_d = nn.Linear(1, self.hidden_dim) + self.proj_f = nn.Linear(1, self.hidden_dim) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.hidden_dim, nhead=4, batch_first=True + ) + self.encoder_t = nn.TransformerEncoder(encoder_layer, num_layers=3) + self.encoder_d = nn.TransformerEncoder(encoder_layer, num_layers=3) + self.encoder_f = nn.TransformerEncoder(encoder_layer, num_layers=3) + + # Now we need MHA + self.fusion_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) + self.fusion_layer_norm = 
nn.LayerNorm(self.hidden_dim) + + # Feature-specific projectors + def projector(): + return nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ) + self.F_t = projector() + self.F_d = projector() + self.F_f = projector() + + self.classifier = nn.Linear(self.hidden_dim, num_classes) + + def augment(self, z): + # Placeholder for time-series augmentation (e.g., adding Gaussian noise) + noise = torch.randn_like(z) * 0.01 + return z + noise + + def info_nce_loss(self, z_i, z_j, tau): + # Compute cosine similarity + z_i = F.normalize(z_i, dim=1) + z_j = F.normalize(z_j, dim=1) + sim = torch.mm(z_i, z_j.T) / tau + + # Positive pairs are on the diagonal + labels = torch.arange(sim.size(0), device=sim.device) + return F.cross_entropy(sim, labels) + + def forward(self, **kwargs) -> dict[str, torch.Tensor]: + temporal_tensor = self._prepare_tensor(kwargs.get("xt")) # [N, L, 1] + derivative_tensor = self._prepare_tensor(kwargs.get("xd")) # [N, L, 1] + frequency_tensor = self._prepare_tensor(kwargs.get("xf")) # [N, L, 1] + + temporal_tensor = self.proj_t(temporal_tensor) # [N, L, hidden_dim] + derivative_tensor = self.proj_d(derivative_tensor) # [N, L, hidden_dim] + frequency_tensor = self.proj_f(frequency_tensor) # [N, L, hidden_dim] + + h_t = self.encoder_t(temporal_tensor) # [N, L, hidden_dim] + h_d = self.encoder_d(derivative_tensor) # [N, L, hidden_dim] + h_f = self.encoder_f(frequency_tensor) # [N, L, hidden_dim] + + batch_size, seq_length, _ = h_t.shape + H = torch.stack([h_t, h_d, h_f], dim=2) # [N, 2, L, hidden_dim] + H_flat = H.view(-1, 3, self.hidden_dim) + + MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) # [N*L, 3, hidden_dim] + MHA_out = self.fusion_layer_norm(MHA_out + H_flat) # Residual connection + + # Split the MHA output back into individual views + MHA_out = MHA_out.view(batch_size, seq_length, 3, self.hidden_dim) + h_t_star, h_d_star, h_f_star = MHA_out[:, :, 0, :], MHA_out[:, :, 
1, :], MHA_out[:, :, 2, :] + z_t = self.F_t(h_t_star).mean(dim=1) # [N, L, hidden_dim] -> [N, hidden_dim] + z_d = self.F_d(h_d_star).mean(dim=1) # [N, L, hidden_dim] -> [N, hidden_dim] + z_f = self.F_f(h_f_star).mean(dim=1) # [N, L, hidden_dim] -> [N, hidden_dim] + + # --- Stage Routing --- + if self.training_stage == "pretrain": + z_t_aug, z_d_aug, z_f_aug = self.augment(z_t), self.augment(z_d), self.augment(z_f) + loss = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ + self.info_nce_loss(z_d, z_d_aug, self.tau) + \ + self.info_nce_loss(z_f, z_f_aug, self.tau) + + return {"loss": loss} + + elif self.training_stage == "finetune": + z_combined = torch.cat([z_t, z_d, z_f], dim=1) + logits = self.classifier(z_combined) + + # Use PyHealth's automatic label parsing + label_key = self.label_keys[0] + y_true = cast(torch.Tensor, kwargs[label_key]) + + # Use PyHealth's automatic loss function mapping + criterion = self.get_loss_function() + loss_ce = criterion(logits, y_true) + + # Contrastive penalty during finetuning + z_t_aug, z_d_aug, z_f_aug = self.augment(z_t), self.augment(z_d), self.augment(z_f) + loss_cl = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ + self.info_nce_loss(z_d, z_d_aug, self.tau) + \ + self.info_nce_loss(z_f, z_f_aug, self.tau) + + total_loss = (self.lambda_cl * loss_cl) + loss_ce + + # Return PyHealth's expected dictionary schema + return { + "loss": total_loss, + "logit": logits, + "y_prob": self.prepare_y_prob(logits), # Autocast to prob + "y_true": y_true + } + + return {} + + def _prepare_tensor(self, x): + """Converts lists to batched tensors, enforces float32, and moves to device.""" + if isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = torch.stack(x) + else: + import numpy as np + x = torch.from_numpy(np.stack(x)) + + # Enforce standard float precision and push to GPU + return x.float().to(self.device) \ No newline at end of file diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py 
b/pyhealth/tasks/mvcl_training_sleepedf_task.py index de0283971..de17c6ea0 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -169,9 +169,9 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: "epoch_index": b["epoch_index"], "window_in_epoch": b["window_in_epoch"], "signal": vec, - "xt": xt[i].detach().cpu().numpy().astype(np.float16), - "dx": dx[i].detach().cpu().numpy().astype(np.float16), - "xf": xf[i].detach().cpu().numpy().astype(np.float16), + "xt": xt[i].detach().cpu().numpy().astype(np.float32), + "xd": dx[i].detach().cpu().numpy().astype(np.float32), + "xf": xf[i].detach().cpu().numpy().astype(np.float32), "label": b["label"], } ) From ce5ff2e6dd0a5c5bbd3c958b57b7e4b10462a52e Mon Sep 17 00:00:00 2001 From: gaohey Date: Mon, 13 Apr 2026 00:00:09 +0800 Subject: [PATCH 05/21] added function to load external data in task and update finetune model --- pyhealth/models/__init__.py | 1 + .../models/multi_view_contrastive_model.py | 22 ++- pyhealth/tasks/__init__.py | 6 +- pyhealth/tasks/mvcl_training_sleepedf_task.py | 136 +++++++++++++++++- 4 files changed, 159 insertions(+), 6 deletions(-) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..bc3d37bb2 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,3 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .multi_view_contrastive_model import MultiViewContrastiveModel diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index 1e8ba8b8f..f810fe79e 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -16,6 +16,7 @@ def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): self.training_stage = training_stage self.lambda_cl = 
0.1 self.tau = 0.07 + self.num_classes = num_classes self.proj_t = nn.Linear(1, self.hidden_dim) self.proj_d = nn.Linear(1, self.hidden_dim) @@ -43,10 +44,17 @@ def projector(): self.F_d = projector() self.F_f = projector() - self.classifier = nn.Linear(self.hidden_dim, num_classes) + # self.classifier = nn.Linear(self.hidden_dim, num_classes) + self.classifier = nn.Sequential( + nn.Linear(self.hidden_dim * 3 , 1024), + nn.ReLU(), + nn.Linear(1024 , 512), + nn.ReLU(), + nn.Dropout(0.5), + nn.Linear(512, self.num_classes) + ) - def augment(self, z): - # Placeholder for time-series augmentation (e.g., adding Gaussian noise) + def augment(self, z): noise = torch.randn_like(z) * 0.01 return z + noise @@ -94,18 +102,24 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: self.info_nce_loss(z_d, z_d_aug, self.tau) + \ self.info_nce_loss(z_f, z_f_aug, self.tau) - return {"loss": loss} + return {"loss": loss, + "zs":[z_t, z_d, z_f] + } # Return the embeddings for each view + elif self.training_stage == "finetune": z_combined = torch.cat([z_t, z_d, z_f], dim=1) + z_combined = z_combined.view(z_combined.size(0), -1) logits = self.classifier(z_combined) # Use PyHealth's automatic label parsing label_key = self.label_keys[0] y_true = cast(torch.Tensor, kwargs[label_key]) + y_true = y_true.to(logits.device) # Use PyHealth's automatic loss function mapping criterion = self.get_loss_function() + # Cross entropy expects raw logits, not argmax class indices. 
loss_ce = criterion(logits, y_true) # Contrastive penalty during finetuning diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 4fbfb8cb9..96340e360 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,4 +66,8 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task -from .mvcl_training_sleepedf_task import MVCLTrainingSleepEEG +from .mvcl_training_sleepedf_task import ( + MVCLTrainingSleepEEG, + pt_dict_to_pyhealth_samples, + pt_file_to_sample_dataset, +) diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index de17c6ea0..613e8b9b5 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -12,12 +12,14 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import mne import numpy as np import torch import torch.fft as fft +from pyhealth.datasets.sample_dataset import create_sample_dataset from pyhealth.tasks import BaseTask @@ -253,3 +255,135 @@ def preprocess_mvcl_views( xf = add_time_feature(xf) return xt, dx, xf + + +def _to_tensor(data: Any, key_name: str) -> torch.Tensor: + """Convert input payloads to a detached CPU tensor for validation.""" + if isinstance(data, torch.Tensor): + return data.detach().cpu() + try: + return torch.as_tensor(data).detach().cpu() + except Exception as err: # pragma: no cover - defensive branch + raise TypeError(f"Could not convert `{key_name}` to tensor: {err}") from err + + +def _normalize_signal_array(samples_obj: Any) -> np.ndarray: + """Normalize `.pt` sample tensor to [N, L] float32.""" + signal_tensor = _to_tensor(samples_obj, "samples").float().contiguous() + if signal_tensor.ndim == 2: + return signal_tensor.numpy().astype(np.float32, copy=False) + if signal_tensor.ndim != 3: + raise ValueError( 
+ "Expected `samples` shape [N, L], [N, 1, L], or [N, L, 1], " + f"got {tuple(signal_tensor.shape)}" + ) + + if signal_tensor.shape[1] == 1: + signal_tensor = signal_tensor[:, 0, :] + elif signal_tensor.shape[2] == 1: + signal_tensor = signal_tensor[:, :, 0] + else: + raise ValueError( + "Expected a single-channel tensor for `samples` when rank is 3; " + f"got {tuple(signal_tensor.shape)}" + ) + return signal_tensor.numpy().astype(np.float32, copy=False) + + +def _normalize_label_array(labels_obj: Any) -> np.ndarray: + """Normalize labels to [N] int64.""" + label_tensor = _to_tensor(labels_obj, "labels").long().contiguous() + if label_tensor.ndim == 1: + return label_tensor.numpy().astype(np.int64, copy=False) + if label_tensor.ndim == 2 and 1 in label_tensor.shape: + return label_tensor.reshape(-1).numpy().astype(np.int64, copy=False) + raise ValueError( + "Expected `labels` shape [N] or [N, 1], " + f"got {tuple(label_tensor.shape)}" + ) + + +def pt_dict_to_pyhealth_samples( + tensor_dict: Mapping[str, Any], + *, + patient_id_prefix: str = "epilepsy_patient", + record_id_prefix: str = "epilepsy_record", + time_as_feature: bool = False, +) -> List[Dict[str, Any]]: + """Convert `{samples, labels}` tensors into PyHealth raw sample dicts.""" + if "samples" not in tensor_dict or "labels" not in tensor_dict: + keys = sorted(tensor_dict.keys()) + raise KeyError( + "Expected keys `samples` and `labels` in tensor dict, " + f"but found keys: {keys}" + ) + + signal_array = _normalize_signal_array(tensor_dict["samples"]) + label_array = _normalize_label_array(tensor_dict["labels"]) + + if signal_array.shape[0] != label_array.shape[0]: + raise ValueError( + "`samples` and `labels` length mismatch: " + f"{signal_array.shape[0]} vs {label_array.shape[0]}" + ) + + # preprocess_mvcl_views expects [N, L, D], use D=1 for single-channel EEG. 
+ signal_tensor = torch.from_numpy(np.ascontiguousarray(signal_array)).float().unsqueeze(-1) + xt, dx, xf = preprocess_mvcl_views(signal_tensor, time_as_feature=time_as_feature) + + samples: List[Dict[str, Any]] = [] + for i in range(signal_array.shape[0]): + samples.append( + { + "patient_id": f"{patient_id_prefix}_{i}", + "record_id": f"{record_id_prefix}_{i}", + "signal": signal_array[i][np.newaxis, :], + "xt": xt[i].detach().cpu().numpy().astype(np.float32), + "xd": dx[i].detach().cpu().numpy().astype(np.float32), + "xf": xf[i].detach().cpu().numpy().astype(np.float32), + "label": int(label_array[i]), + } + ) + return samples + + +def pt_file_to_sample_dataset( + pt_path: Union[str, Path], + *, + dataset_name: str = "epilepsy_pt", + task_name: str = "MVCLTrainingEpilepsyPT", + in_memory: bool = True, + patient_id_prefix: str = "epilepsy_patient", + record_id_prefix: str = "epilepsy_record", + time_as_feature: bool = False, +): + """Load one `.pt` file and return a PyHealth SampleDataset.""" + try: + tensor_dict = torch.load(pt_path, map_location="cpu", weights_only=False) + except TypeError: + tensor_dict = torch.load(pt_path, map_location="cpu") + + if not isinstance(tensor_dict, Mapping): + raise TypeError( + f"Expected `{pt_path}` to load as a mapping, got {type(tensor_dict)}" + ) + + samples = pt_dict_to_pyhealth_samples( + tensor_dict, + patient_id_prefix=patient_id_prefix, + record_id_prefix=record_id_prefix, + time_as_feature=time_as_feature, + ) + return create_sample_dataset( + samples=samples, + input_schema={ + "signal": "tensor", + "xt": "tensor", + "xd": "tensor", + "xf": "tensor", + }, + output_schema={"label": "multiclass"}, + dataset_name=dataset_name, + task_name=task_name, + in_memory=in_memory, + ) From 12b2c2deb39b8de9c90e67c4954c1a20d65cd4e1 Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Thu, 16 Apr 2026 16:03:54 +0530 Subject: [PATCH 06/21] Encoder encodes xt, xd, xf and doesn't augment z_k --- 
.../models/multi_view_contrastive_model.py | 84 ++++++++++--------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index f810fe79e..08ef1e59d 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -22,12 +22,14 @@ def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): self.proj_d = nn.Linear(1, self.hidden_dim) self.proj_f = nn.Linear(1, self.hidden_dim) - encoder_layer = nn.TransformerEncoderLayer( - d_model=self.hidden_dim, nhead=4, batch_first=True - ) - self.encoder_t = nn.TransformerEncoder(encoder_layer, num_layers=3) - self.encoder_d = nn.TransformerEncoder(encoder_layer, num_layers=3) - self.encoder_f = nn.TransformerEncoder(encoder_layer, num_layers=3) + def make_encoder(): + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.hidden_dim, nhead=4, batch_first=True + ) + return nn.TransformerEncoder(encoder_layer, num_layers=3) + self.encoder_t = make_encoder() + self.encoder_d = make_encoder() + self.encoder_f = make_encoder() # Now we need MHA self.fusion_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) @@ -54,9 +56,10 @@ def projector(): nn.Linear(512, self.num_classes) ) - def augment(self, z): - noise = torch.randn_like(z) * 0.01 - return z + noise + def augment(self, x): + # Placeholder for time-series augmentation (e.g., adding Gaussian noise) + noise = torch.randn_like(x) * 0.01 + return x + noise def info_nce_loss(self, z_i, z_j, tau): # Compute cosine similarity @@ -67,47 +70,50 @@ def info_nce_loss(self, z_i, z_j, tau): # Positive pairs are on the diagonal labels = torch.arange(sim.size(0), device=sim.device) return F.cross_entropy(sim, labels) + + def _forward_features(self, x_t, x_d, x_f): + x_t = self.proj_t(x_t) + x_d = self.proj_d(x_d) + x_f = self.proj_f(x_f) - def forward(self, **kwargs) -> 
dict[str, torch.Tensor]: - temporal_tensor = self._prepare_tensor(kwargs.get("xt")) # [N, L, 1] - derivative_tensor = self._prepare_tensor(kwargs.get("xd")) # [N, L, 1] - frequency_tensor = self._prepare_tensor(kwargs.get("xf")) # [N, L, 1] - - temporal_tensor = self.proj_t(temporal_tensor) # [N, L, hidden_dim] - derivative_tensor = self.proj_d(derivative_tensor) # [N, L, hidden_dim] - frequency_tensor = self.proj_f(frequency_tensor) # [N, L, hidden_dim] - - h_t = self.encoder_t(temporal_tensor) # [N, L, hidden_dim] - h_d = self.encoder_d(derivative_tensor) # [N, L, hidden_dim] - h_f = self.encoder_f(frequency_tensor) # [N, L, hidden_dim] + h_t = self.encoder_t(x_t) + h_d = self.encoder_d(x_d) + h_f = self.encoder_f(x_f) batch_size, seq_length, _ = h_t.shape - H = torch.stack([h_t, h_d, h_f], dim=2) # [N, 2, L, hidden_dim] + H = torch.stack([h_t, h_d, h_f], dim=2) H_flat = H.view(-1, 3, self.hidden_dim) - MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) # [N*L, 3, hidden_dim] - MHA_out = self.fusion_layer_norm(MHA_out + H_flat) # Residual connection + MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) + H_out = self.fusion_layer_norm(MHA_out + H_flat) - # Split the MHA output back into individual views - MHA_out = MHA_out.view(batch_size, seq_length, 3, self.hidden_dim) - h_t_star, h_d_star, h_f_star = MHA_out[:, :, 0, :], MHA_out[:, :, 1, :], MHA_out[:, :, 2, :] - z_t = self.F_t(h_t_star).mean(dim=1) # [N, L, hidden_dim] -> [N, hidden_dim] - z_d = self.F_d(h_d_star).mean(dim=1) # [N, L, hidden_dim] -> [N, hidden_dim] - z_f = self.F_f(h_f_star).mean(dim=1) # [N, L, hidden_dim] -> [N, hidden_dim] + H_out = H_out.view(batch_size, seq_length, 3, self.hidden_dim) + h_t_star, h_d_star, h_f_star = H_out[:, :, 0, :], H_out[:, :, 1, :], H_out[:, :, 2, :] + z_t = self.F_t(h_t_star).mean(dim=1) + z_d = self.F_d(h_d_star).mean(dim=1) + z_f = self.F_f(h_f_star).mean(dim=1) + return z_t, z_d, z_f + + def forward(self, **kwargs) -> dict[str, torch.Tensor]: + 
temporal_tensor = self._prepare_tensor(kwargs.get("xt")) # [N, L, 1] + derivative_tensor = self._prepare_tensor(kwargs.get("xd")) # [N, L, 1] + frequency_tensor = self._prepare_tensor(kwargs.get("xf")) # [N, L, 1] # --- Stage Routing --- if self.training_stage == "pretrain": - z_t_aug, z_d_aug, z_f_aug = self.augment(z_t), self.augment(z_d), self.augment(z_f) + x_t_aug = self.augment(temporal_tensor) + x_d_aug = self.augment(derivative_tensor) + x_f_aug = self.augment(frequency_tensor) + z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) + z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug) loss = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ self.info_nce_loss(z_d, z_d_aug, self.tau) + \ self.info_nce_loss(z_f, z_f_aug, self.tau) - - return {"loss": loss, - "zs":[z_t, z_d, z_f] - } # Return the embeddings for each view - + print (f"Pretrain Loss: {loss.item():.4f}") + return {"loss": loss} elif self.training_stage == "finetune": + z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) z_combined = torch.cat([z_t, z_d, z_f], dim=1) z_combined = z_combined.view(z_combined.size(0), -1) logits = self.classifier(z_combined) @@ -123,7 +129,10 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: loss_ce = criterion(logits, y_true) # Contrastive penalty during finetuning - z_t_aug, z_d_aug, z_f_aug = self.augment(z_t), self.augment(z_d), self.augment(z_f) + x_t_aug = self.augment(temporal_tensor) + x_d_aug = self.augment(derivative_tensor) + x_f_aug = self.augment(frequency_tensor) + z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug) loss_cl = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ self.info_nce_loss(z_d, z_d_aug, self.tau) + \ self.info_nce_loss(z_f, z_f_aug, self.tau) @@ -137,7 +146,6 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: "y_prob": self.prepare_y_prob(logits), # Autocast to prob "y_true": y_true } - 
return {} def _prepare_tensor(self, x): From f411e233cbc2cd833031c46cc59f327ebae7f5be Mon Sep 17 00:00:00 2001 From: gaohey Date: Thu, 16 Apr 2026 23:41:40 +0800 Subject: [PATCH 07/21] added symmetry selfloss;updated the finetuning classifier --- .../models/multi_view_contrastive_model.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index 08ef1e59d..d5dd49045 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -24,7 +24,7 @@ def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): def make_encoder(): encoder_layer = nn.TransformerEncoderLayer( - d_model=self.hidden_dim, nhead=4, batch_first=True + d_model=self.hidden_dim, nhead=4, batch_first=True, dropout=0.2 ) return nn.TransformerEncoder(encoder_layer, num_layers=3) self.encoder_t = make_encoder() @@ -61,15 +61,22 @@ def augment(self, x): noise = torch.randn_like(x) * 0.01 return x + noise - def info_nce_loss(self, z_i, z_j, tau): + def info_nce_loss(self, z_i, z_j, tau, symmetric=True): # Compute cosine similarity z_i = F.normalize(z_i, dim=1) z_j = F.normalize(z_j, dim=1) - sim = torch.mm(z_i, z_j.T) / tau + sim_ij = torch.mm(z_i, z_j.T) / tau # Positive pairs are on the diagonal - labels = torch.arange(sim.size(0), device=sim.device) - return F.cross_entropy(sim, labels) + labels = torch.arange(sim_ij.size(0), device=sim_ij.device) + loss_ij = F.cross_entropy(sim_ij, labels) + + if symmetric: + sim_ji = torch.mm(z_j, z_i.T) / tau + loss_ji = F.cross_entropy(sim_ji, labels) + return (loss_ij + loss_ji) / 2 + + return loss_ij def _forward_features(self, x_t, x_d, x_f): x_t = self.proj_t(x_t) @@ -109,8 +116,10 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: loss = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ self.info_nce_loss(z_d, z_d_aug, self.tau) + \ 
self.info_nce_loss(z_f, z_f_aug, self.tau) - print (f"Pretrain Loss: {loss.item():.4f}") - return {"loss": loss} + # print (f"Pretrain Loss: {loss.item():.4f}") + return {"loss": loss, + "zs":[z_t, z_d, z_f] + } # Return the embeddings for each view elif self.training_stage == "finetune": z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) From 24ed61f3bb0832e3cbae2f5ce231f2b10b456c47 Mon Sep 17 00:00:00 2001 From: gaohey Date: Sat, 18 Apr 2026 00:55:44 +0800 Subject: [PATCH 08/21] fixed and improved task to match papers --- pyhealth/tasks/mvcl_training_sleepedf_task.py | 80 ++++++++++++++----- 1 file changed, 59 insertions(+), 21 deletions(-) diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index 613e8b9b5..b11342663 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -96,7 +96,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: event.signal_file, stim_channel="Event marker", infer_types=True, - preload=True, + preload=False, verbose="error", ) ann = mne.read_annotations(event.label_file) @@ -108,6 +108,10 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if ann_events.size == 0: continue + # Pick only the required EEG channel to save memory (7x reduction) + ch_i = self._pick_eeg_index(list(data.ch_names)) + data.pick([data.ch_names[ch_i]]) + epochs_train = mne.Epochs( data, ann_events, @@ -120,8 +124,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: verbose="error", ) - ch_i = self._pick_eeg_index(list(epochs_train.ch_names)) - signals = epochs_train.get_data()[:, ch_i, :] + # Since we picked exactly 1 channel, it is now at index 0 + signals = epochs_train.get_data()[:, 0, :] labels = epochs_train.events[:, 2] n_epochs, n_times = signals.shape @@ -130,22 +134,23 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: event_buffers: List[Dict[str, Any]] = [] for epi 
in range(n_epochs): lab = _map_to_MVCL_five_class(int(labels[epi])) - row = signals[epi, :n_full] - for w in range(n_full // win): - seg = row[w * win : (w + 1) * win].astype(np.float32, copy=False) - if crop is not None: - seg = seg[:crop] - event_buffers.append( - { - "seg_1d": seg.copy(), - "label": lab, - "night": event.night, - "patient_age": event.age, - "patient_sex": event.sex, - "epoch_index": global_epoch, - "window_in_epoch": w, - } - ) + + # Take only the first window of the epoch to match TFC-pretraining's sample count + seg = signals[epi, :win].astype(np.float32, copy=False) + if crop is not None: + seg = seg[:crop] + + event_buffers.append( + { + "seg_1d": seg.copy(), + "label": lab, + "night": event.night, + "patient_age": event.age, + "patient_sex": event.sex, + "epoch_index": global_epoch, + "window_in_epoch": 0, + } + ) global_epoch += 1 if not event_buffers: @@ -208,13 +213,45 @@ def add_time_feature(X: torch.Tensor) -> torch.Tensor: def get_dx_gradient(X: torch.Tensor) -> torch.Tensor: """Time derivative via ``torch.gradient`` along **dim=1** for **X [N, L, D]**. - This is **not** equivalent to :func:`get_dx` (torchcde spline); see module docstring. + This is **not** equivalent to :func:`get_dx` (torchcde spline);. """ if X.ndim != 3: raise ValueError(f"Expected [N, L, D], got {tuple(X.shape)}") return torch.gradient(X, dim=1)[0] +def get_dx_torchcde_equivalent(X: torch.Tensor) -> torch.Tensor: + """ + Pure PyTorch equivalent of torchcde Hermite cubic spline derivative + evaluated at the knot points with backward differences. 
+ """ + N, L, D = X.shape + dx = torch.zeros_like(X) + + if L < 2: + return dx + + dt = 1.0 / (L - 1) + + # derivs[i] = X[i+1] - X[i] + derivs = X[:, 1:, :] - X[:, :-1, :] + + # derivs_prev[i] = derivs[i-1] for i > 0, and derivs[0] for i = 0 + derivs_prev = torch.cat([derivs[:, :1, :], derivs[:, :-1, :]], dim=1) + + # For i = 0 + dx[:, 0, :] = derivs[:, 0, :] + + # For i > 0 + factor = (4 - 3 * dt) * dt + b = derivs_prev + D = derivs - b + + dx[:, 1:, :] = b + D * factor + + return dx + + def get_xf(X: torch.Tensor) -> torch.Tensor: return torch.abs(fft.fft(X, dim=1)) @@ -240,7 +277,8 @@ def preprocess_mvcl_views( xt = xt_tr - dx_raw = get_dx_gradient(xt) + # dx_raw = get_dx_gradient(xt) # this is approxi to paper's torchcde spline + dx_raw = get_dx_torchcde_equivalent(xt) dx_tr, dx_te, _, _ = normalize_mvcl(dx_raw, dx_raw) dx = dx_tr From df04ed9e20c04c75812e7106cc7c98cae00a596c Mon Sep 17 00:00:00 2001 From: gaohey Date: Sat, 18 Apr 2026 10:13:01 +0800 Subject: [PATCH 09/21] improved augmentation method, noise + frequency --- .../models/multi_view_contrastive_model.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index d5dd49045..c75332fd9 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -14,7 +14,7 @@ def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): self.hidden_dim = 128 seq_length = 256 self.training_stage = training_stage - self.lambda_cl = 0.1 + self.lambda_cl = 0.001 self.tau = 0.07 self.num_classes = num_classes @@ -56,11 +56,27 @@ def projector(): nn.Linear(512, self.num_classes) ) - def augment(self, x): + def augment(self, x, std=0.1): # Placeholder for time-series augmentation (e.g., adding Gaussian noise) - noise = torch.randn_like(x) * 0.01 + noise = torch.randn_like(x) * std return x + noise + def 
data_transform_fd(self, sample: torch.Tensor, pertub_ratio: float = 0.05) -> torch.Tensor: + aug_1 = self.remove_frequency(sample, pertub_ratio) + aug_2 = self.add_frequency(sample, pertub_ratio) + return aug_1 + aug_2 + + def remove_frequency(self, x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor: + mask = torch.rand(x.shape, device=x.device) > pertub_ratio + return x * mask + + def add_frequency(self, x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor: + mask = torch.rand(x.shape, device=x.device) > (1 - pertub_ratio) + max_amplitude = x.max() + random_am = torch.rand(mask.shape, device=x.device) * (max_amplitude * 0.1) + pertub_matrix = mask * random_am + return x + pertub_matrix + def info_nce_loss(self, z_i, z_j, tau, symmetric=True): # Compute cosine similarity z_i = F.normalize(z_i, dim=1) @@ -111,6 +127,10 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: x_t_aug = self.augment(temporal_tensor) x_d_aug = self.augment(derivative_tensor) x_f_aug = self.augment(frequency_tensor) + + # x_f_aug = self.data_transform_fd(frequency_tensor, 0.05) + + z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug) loss = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ @@ -141,6 +161,9 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: x_t_aug = self.augment(temporal_tensor) x_d_aug = self.augment(derivative_tensor) x_f_aug = self.augment(frequency_tensor) + + # x_f_aug = self.data_transform_fd(frequency_tensor, 0.05) + z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug) loss_cl = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ self.info_nce_loss(z_d, z_d_aug, self.tau) + \ From 9b9a8fcb34a507efd87bf086590f00de3af34efa Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Sat, 18 Apr 2026 12:25:36 +0530 Subject: [PATCH 10/21] Added type hints and fixed the return type for the 'forward' method --- 
.../models/multi_view_contrastive_model.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index c75332fd9..d97d82142 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -4,7 +4,7 @@ import numpy as np from pyhealth.models import BaseModel -from typing import cast +from typing import Tuple, cast class MultiViewContrastiveModel(BaseModel): """A simple multi-view contrastive model for demonstration purposes.""" @@ -22,32 +22,32 @@ def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): self.proj_d = nn.Linear(1, self.hidden_dim) self.proj_f = nn.Linear(1, self.hidden_dim) - def make_encoder(): + def make_encoder() -> nn.TransformerEncoder: encoder_layer = nn.TransformerEncoderLayer( d_model=self.hidden_dim, nhead=4, batch_first=True, dropout=0.2 ) return nn.TransformerEncoder(encoder_layer, num_layers=3) - self.encoder_t = make_encoder() - self.encoder_d = make_encoder() - self.encoder_f = make_encoder() + self.encoder_t: nn.TransformerEncoder = make_encoder() + self.encoder_d: nn.TransformerEncoder = make_encoder() + self.encoder_f: nn.TransformerEncoder = make_encoder() # Now we need MHA - self.fusion_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) - self.fusion_layer_norm = nn.LayerNorm(self.hidden_dim) + self.fusion_mha: nn.MultiheadAttention = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) + self.fusion_layer_norm: nn.LayerNorm = nn.LayerNorm(self.hidden_dim) # Feature-specific projectors - def projector(): + def projector() -> nn.Sequential: return nn.Sequential( nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, self.hidden_dim), ) - self.F_t = projector() - self.F_d = projector() - self.F_f = projector() + self.F_t: nn.Sequential = 
projector() + self.F_d: nn.Sequential = projector() + self.F_f: nn.Sequential = projector() # self.classifier = nn.Linear(self.hidden_dim, num_classes) - self.classifier = nn.Sequential( + self.classifier: nn.Sequential = nn.Sequential( nn.Linear(self.hidden_dim * 3 , 1024), nn.ReLU(), nn.Linear(1024 , 512), @@ -77,7 +77,7 @@ def add_frequency(self, x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Ten pertub_matrix = mask * random_am return x + pertub_matrix - def info_nce_loss(self, z_i, z_j, tau, symmetric=True): + def info_nce_loss(self, z_i: torch.Tensor, z_j: torch.Tensor, tau: float, symmetric: bool = True) -> torch.Tensor: # Compute cosine similarity z_i = F.normalize(z_i, dim=1) z_j = F.normalize(z_j, dim=1) @@ -94,7 +94,7 @@ def info_nce_loss(self, z_i, z_j, tau, symmetric=True): return loss_ij - def _forward_features(self, x_t, x_d, x_f): + def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x_t = self.proj_t(x_t) x_d = self.proj_d(x_d) x_f = self.proj_f(x_f) @@ -138,7 +138,7 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: self.info_nce_loss(z_f, z_f_aug, self.tau) # print (f"Pretrain Loss: {loss.item():.4f}") return {"loss": loss, - "zs":[z_t, z_d, z_f] + "z_t": z_t, "z_d": z_d, "z_f": z_f } # Return the embeddings for each view elif self.training_stage == "finetune": @@ -180,7 +180,7 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: } return {} - def _prepare_tensor(self, x): + def _prepare_tensor(self, x) -> torch.Tensor: """Converts lists to batched tensors, enforces float32, and moves to device.""" if isinstance(x, list): if isinstance(x[0], torch.Tensor): From 64112974816dc56d94b17c70dc623784c09f28d1 Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Sat, 18 Apr 2026 19:43:10 +0530 Subject: [PATCH 11/21] Fixed data loss bug and updated the loss function. 
--- .../models/multi_view_contrastive_model.py | 127 +++++++++++------- pyhealth/tasks/mvcl_training_sleepedf_task.py | 42 +++--- 2 files changed, 101 insertions(+), 68 deletions(-) diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index d97d82142..926b224bf 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -7,14 +7,14 @@ from typing import Tuple, cast class MultiViewContrastiveModel(BaseModel): - """A simple multi-view contrastive model for demonstration purposes.""" + """A multi-view contrastive model aligned with Oh and Bui (2025) and TFC.""" def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): super().__init__(dataset=dataset) self.hidden_dim = 128 seq_length = 256 self.training_stage = training_stage - self.lambda_cl = 0.001 + self.lambda_cl = 0.1 self.tau = 0.07 self.num_classes = num_classes @@ -27,26 +27,28 @@ def make_encoder() -> nn.TransformerEncoder: d_model=self.hidden_dim, nhead=4, batch_first=True, dropout=0.2 ) return nn.TransformerEncoder(encoder_layer, num_layers=3) + self.encoder_t: nn.TransformerEncoder = make_encoder() self.encoder_d: nn.TransformerEncoder = make_encoder() self.encoder_f: nn.TransformerEncoder = make_encoder() - # Now we need MHA + # MHA for Hierarchical Fusion self.fusion_mha: nn.MultiheadAttention = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) self.fusion_layer_norm: nn.LayerNorm = nn.LayerNorm(self.hidden_dim) - # Feature-specific projectors + # Feature-specific projectors (with BatchNorm1d aligned to TFC) def projector() -> nn.Sequential: return nn.Sequential( nn.Linear(self.hidden_dim, self.hidden_dim), + nn.BatchNorm1d(self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, self.hidden_dim), ) + self.F_t: nn.Sequential = projector() self.F_d: nn.Sequential = projector() self.F_f: nn.Sequential = projector() - # self.classifier = 
nn.Linear(self.hidden_dim, num_classes) self.classifier: nn.Sequential = nn.Sequential( nn.Linear(self.hidden_dim * 3 , 1024), nn.ReLU(), @@ -56,12 +58,13 @@ def projector() -> nn.Sequential: nn.Linear(512, self.num_classes) ) - def augment(self, x, std=0.1): - # Placeholder for time-series augmentation (e.g., adding Gaussian noise) + def augment_time(self, x: torch.Tensor, std: float = 0.1) -> torch.Tensor: + """Time-domain jitter augmentation""" noise = torch.randn_like(x) * std return x + noise - def data_transform_fd(self, sample: torch.Tensor, pertub_ratio: float = 0.05) -> torch.Tensor: + def augment_freq(self, sample: torch.Tensor, pertub_ratio: float = 0.05) -> torch.Tensor: + """Frequency-domain augmentation (remove and add frequencies)""" aug_1 = self.remove_frequency(sample, pertub_ratio) aug_2 = self.add_frequency(sample, pertub_ratio) return aug_1 + aug_2 @@ -77,22 +80,42 @@ def add_frequency(self, x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Ten pertub_matrix = mask * random_am return x + pertub_matrix - def info_nce_loss(self, z_i: torch.Tensor, z_j: torch.Tensor, tau: float, symmetric: bool = True) -> torch.Tensor: - # Compute cosine similarity - z_i = F.normalize(z_i, dim=1) - z_j = F.normalize(z_j, dim=1) + def ntxent_loss(self, zis: torch.Tensor, zjs: torch.Tensor, tau: float) -> torch.Tensor: + """2N x 2N NTXentLoss aligned with the TFC implementation.""" + batch_size = zis.size(0) - sim_ij = torch.mm(z_i, z_j.T) / tau - # Positive pairs are on the diagonal - labels = torch.arange(sim_ij.size(0), device=sim_ij.device) - loss_ij = F.cross_entropy(sim_ij, labels) + # Normalize the representations + zis = F.normalize(zis, dim=1) + zjs = F.normalize(zjs, dim=1) - if symmetric: - sim_ji = torch.mm(z_j, z_i.T) / tau - loss_ji = F.cross_entropy(sim_ji, labels) - return (loss_ij + loss_ji) / 2 - - return loss_ij + # Concatenate into 2N + representations = torch.cat([zjs, zis], dim=0) # [2N, hidden_dim] + + # Compute 2Nx2N cosine similarity 
matrix + similarity_matrix = torch.mm(representations, representations.T) + + # Extract the positive pairs (offset by batch_size) + l_pos = torch.diag(similarity_matrix, batch_size) + r_pos = torch.diag(similarity_matrix, -batch_size) + positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) + + # Create a mask to remove self-similarity (the diagonal) + mask = (~torch.eye(2 * batch_size, 2 * batch_size, dtype=torch.bool, device=zis.device)) + + # Extract negatives (everything except the diagonal) + negatives = similarity_matrix[mask].view(2 * batch_size, -1) + + # Concatenate logits: [positives, negatives] + logits = torch.cat((positives, negatives), dim=1) + logits /= tau + + # The positive sample is always at index 0 for each row + labels = torch.zeros(2 * batch_size, dtype=torch.long, device=zis.device) + + # PyTorch CrossEntropy applies the log-softmax calculation + loss = F.cross_entropy(logits, labels, reduction="sum") + + return loss / (2 * batch_size) def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x_t = self.proj_t(x_t) @@ -112,9 +135,16 @@ def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Ten H_out = H_out.view(batch_size, seq_length, 3, self.hidden_dim) h_t_star, h_d_star, h_f_star = H_out[:, :, 0, :], H_out[:, :, 1, :], H_out[:, :, 2, :] - z_t = self.F_t(h_t_star).mean(dim=1) - z_d = self.F_d(h_d_star).mean(dim=1) - z_f = self.F_f(h_f_star).mean(dim=1) + + # Pool across sequence length FIRST so BatchNorm1d receives 2D inputs [N, C] + h_t_pool = h_t_star.mean(dim=1) + h_d_pool = h_d_star.mean(dim=1) + h_f_pool = h_f_star.mean(dim=1) + + z_t = self.F_t(h_t_pool) + z_d = self.F_d(h_d_pool) + z_f = self.F_f(h_f_pool) + return z_t, z_d, z_f def forward(self, **kwargs) -> dict[str, torch.Tensor]: @@ -124,27 +154,31 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: # --- Stage Routing --- if self.training_stage == "pretrain": - 
x_t_aug = self.augment(temporal_tensor) - x_d_aug = self.augment(derivative_tensor) - x_f_aug = self.augment(frequency_tensor) - - # x_f_aug = self.data_transform_fd(frequency_tensor, 0.05) - + # 1. Apply domain-specific augmentations + x_t_aug = self.augment_time(temporal_tensor) + x_d_aug = self.augment_time(derivative_tensor) + x_f_aug = self.augment_freq(frequency_tensor) + # 2. Encode z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug) - loss = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ - self.info_nce_loss(z_d, z_d_aug, self.tau) + \ - self.info_nce_loss(z_f, z_f_aug, self.tau) - # print (f"Pretrain Loss: {loss.item():.4f}") - return {"loss": loss, - "z_t": z_t, "z_d": z_d, "z_f": z_f - } # Return the embeddings for each view + + # 3. Apply 2N x 2N NTXentLoss + loss = self.ntxent_loss(z_t, z_t_aug, self.tau) + \ + self.ntxent_loss(z_d, z_d_aug, self.tau) + \ + self.ntxent_loss(z_f, z_f_aug, self.tau) + + # Dict strictly containing torch.Tensor + return { + "loss": loss, + "z_t": z_t, + "z_d": z_d, + "z_f": z_f + } elif self.training_stage == "finetune": z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) z_combined = torch.cat([z_t, z_d, z_f], dim=1) - z_combined = z_combined.view(z_combined.size(0), -1) logits = self.classifier(z_combined) # Use PyHealth's automatic label parsing @@ -158,16 +192,14 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: loss_ce = criterion(logits, y_true) # Contrastive penalty during finetuning - x_t_aug = self.augment(temporal_tensor) - x_d_aug = self.augment(derivative_tensor) - x_f_aug = self.augment(frequency_tensor) - - # x_f_aug = self.data_transform_fd(frequency_tensor, 0.05) + x_t_aug = self.augment_time(temporal_tensor) + x_d_aug = self.augment_time(derivative_tensor) + x_f_aug = self.augment_freq(frequency_tensor) z_t_aug, z_d_aug, z_f_aug = 
self._forward_features(x_t_aug, x_d_aug, x_f_aug) - loss_cl = self.info_nce_loss(z_t, z_t_aug, self.tau) + \ - self.info_nce_loss(z_d, z_d_aug, self.tau) + \ - self.info_nce_loss(z_f, z_f_aug, self.tau) + loss_cl = self.ntxent_loss(z_t, z_t_aug, self.tau) + \ + self.ntxent_loss(z_d, z_d_aug, self.tau) + \ + self.ntxent_loss(z_f, z_f_aug, self.tau) total_loss = (self.lambda_cl * loss_cl) + loss_ce @@ -186,7 +218,6 @@ def _prepare_tensor(self, x) -> torch.Tensor: if isinstance(x[0], torch.Tensor): x = torch.stack(x) else: - import numpy as np x = torch.from_numpy(np.stack(x)) # Enforce standard float precision and push to GPU diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index b11342663..946a6d08c 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -133,26 +133,28 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: event_buffers: List[Dict[str, Any]] = [] for epi in range(n_epochs): - lab = _map_to_MVCL_five_class(int(labels[epi])) - - # Take only the first window of the epoch to match TFC-pretraining's sample count - seg = signals[epi, :win].astype(np.float32, copy=False) - if crop is not None: - seg = seg[:crop] - - event_buffers.append( - { - "seg_1d": seg.copy(), - "label": lab, - "night": event.night, - "patient_age": event.age, - "patient_sex": event.sex, - "epoch_index": global_epoch, - "window_in_epoch": 0, - } - ) - global_epoch += 1 - + lab = _map_to_MVCL_five_class(int(labels[epi])) + row = signals[epi, :n_full] # Get the full 3000 points + + # Extract all 15 non-overlapping windows per epoch + for w in range(n_full // win): + seg = row[w * win : (w + 1) * win].astype(np.float32, copy=False) + if crop is not None: + seg = seg[:crop] + + event_buffers.append( + { + "seg_1d": seg.copy(), + "label": lab, + "night": event.night, + "patient_age": event.age, + "patient_sex": event.sex, + "epoch_index": global_epoch, + 
"window_in_epoch": w, # properly track the window + } + ) + global_epoch += 1 + if not event_buffers: continue From 37027ab058fcc681062927882b23528a1883b822 Mon Sep 17 00:00:00 2001 From: gaohey Date: Sun, 19 Apr 2026 14:06:05 +0800 Subject: [PATCH 12/21] 1. H dimentsion/permute fix(critial); 2. add positional encoding; 3. amend pooling to include residuals; 4. add MHA to finetune --- .../models/multi_view_contrastive_model.py | 62 +++++++++++++------ 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_model.py index 926b224bf..e457e7767 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_model.py @@ -2,10 +2,29 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np +import math from pyhealth.models import BaseModel from typing import Tuple, cast +class PositionalEncoding(nn.Module): + def __init__(self, hidden_dim, dropout=0.1, max_len=1024): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, hidden_dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, hidden_dim, 2) * + (-math.log(10000.0) / hidden_dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # Shape: [1, max_len, hidden_dim] + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + class MultiViewContrastiveModel(BaseModel): """A multi-view contrastive model aligned with Oh and Bui (2025) and TFC.""" @@ -22,6 +41,8 @@ def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): self.proj_d = nn.Linear(1, self.hidden_dim) self.proj_f = nn.Linear(1, self.hidden_dim) + self.pos_encoder = PositionalEncoding(self.hidden_dim, dropout=0.1) + def make_encoder() -> nn.TransformerEncoder: 
encoder_layer = nn.TransformerEncoderLayer( d_model=self.hidden_dim, nhead=4, batch_first=True, dropout=0.2 @@ -36,12 +57,13 @@ def make_encoder() -> nn.TransformerEncoder: self.fusion_mha: nn.MultiheadAttention = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) self.fusion_layer_norm: nn.LayerNorm = nn.LayerNorm(self.hidden_dim) - # Feature-specific projectors (with BatchNorm1d aligned to TFC) + # Feature-specific projectors def projector() -> nn.Sequential: return nn.Sequential( - nn.Linear(self.hidden_dim, self.hidden_dim), - nn.BatchNorm1d(self.hidden_dim), + nn.Linear(self.hidden_dim * 2, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), nn.ReLU(), + nn.Dropout(0.2), nn.Linear(self.hidden_dim, self.hidden_dim), ) @@ -49,14 +71,8 @@ def projector() -> nn.Sequential: self.F_d: nn.Sequential = projector() self.F_f: nn.Sequential = projector() - self.classifier: nn.Sequential = nn.Sequential( - nn.Linear(self.hidden_dim * 3 , 1024), - nn.ReLU(), - nn.Linear(1024 , 512), - nn.ReLU(), - nn.Dropout(0.5), - nn.Linear(512, self.num_classes) - ) + self.classifier_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=1, batch_first=True) + self.classifier = nn.Linear(self.hidden_dim * 3, self.num_classes) def augment_time(self, x: torch.Tensor, std: float = 0.1) -> torch.Tensor: """Time-domain jitter augmentation""" @@ -122,24 +138,28 @@ def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Ten x_d = self.proj_d(x_d) x_f = self.proj_f(x_f) + x_t = self.pos_encoder(x_t) + x_d = self.pos_encoder(x_d) + x_f = self.pos_encoder(x_f) + h_t = self.encoder_t(x_t) h_d = self.encoder_d(x_d) h_f = self.encoder_f(x_f) batch_size, seq_length, _ = h_t.shape H = torch.stack([h_t, h_d, h_f], dim=2) - H_flat = H.view(-1, 3, self.hidden_dim) + H_flat = H.permute(0, 2, 1, 3).contiguous().view(batch_size * 3, seq_length, self.hidden_dim) MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) H_out = 
self.fusion_layer_norm(MHA_out + H_flat) - H_out = H_out.view(batch_size, seq_length, 3, self.hidden_dim) + H_out = H_out.view(batch_size, 3, seq_length, self.hidden_dim).permute(0, 2, 1, 3) h_t_star, h_d_star, h_f_star = H_out[:, :, 0, :], H_out[:, :, 1, :], H_out[:, :, 2, :] - # Pool across sequence length FIRST so BatchNorm1d receives 2D inputs [N, C] - h_t_pool = h_t_star.mean(dim=1) - h_d_pool = h_d_star.mean(dim=1) - h_f_pool = h_f_star.mean(dim=1) + # Pool across sequence length and concatenate with pre-interaction features + h_t_pool = torch.cat([h_t.mean(dim=1), h_t_star.mean(dim=1)], dim=-1) + h_d_pool = torch.cat([h_d.mean(dim=1), h_d_star.mean(dim=1)], dim=-1) + h_f_pool = torch.cat([h_f.mean(dim=1), h_f_star.mean(dim=1)], dim=-1) z_t = self.F_t(h_t_pool) z_d = self.F_d(h_d_pool) @@ -178,7 +198,13 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: elif self.training_stage == "finetune": z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) - z_combined = torch.cat([z_t, z_d, z_f], dim=1) + + # Cross-view attention for classification + stacked_emb = torch.stack([z_t, z_d, z_f], dim=1) # [batch_size, 3, hidden_dim] + attn_out, _ = self.classifier_mha(stacked_emb, stacked_emb, stacked_emb) + emb = attn_out + stacked_emb # Residual connection + + z_combined = emb.reshape(emb.size(0), -1) # Flatten to [batch_size, 3 * hidden_dim] logits = self.classifier(z_combined) # Use PyHealth's automatic label parsing From ebb078778e2e143a03dfb7678ff1c48b0bc649b3 Mon Sep 17 00:00:00 2001 From: gaohey Date: Mon, 20 Apr 2026 14:24:55 +0800 Subject: [PATCH 13/21] small updated in task; plus unit test/exmaples/api docs --- docs/api/tasks.rst | 1 + .../pyhealth.tasks.MVCLTrainingSleepEEG.rst | 7 + examples/mvcl_training_sleepedf.ipynb | 111 ++++++++++++ pyhealth/tasks/mvcl_training_sleepedf_task.py | 17 +- .../core/test_mvcl_training_sleepedf_task.py | 166 ++++++++++++++++++ 5 files changed, 293 insertions(+), 9 deletions(-) 
create mode 100644 docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst create mode 100644 examples/mvcl_training_sleepedf.ipynb create mode 100644 tests/core/test_mvcl_training_sleepedf_task.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..6eaa0e051 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -220,6 +220,7 @@ Available Tasks Readmission Prediction Sleep Staging Sleep Staging (SleepEDF) + MVCL Training (SleepEDF EEG) Temple University EEG Tasks Sleep Staging v2 Benchmark EHRShot diff --git a/docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst b/docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst new file mode 100644 index 000000000..3017180ba --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.MVCLTrainingSleepEEG +=================================== + +.. autoclass:: pyhealth.tasks.MVCLTrainingSleepEEG + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mvcl_training_sleepedf.ipynb b/examples/mvcl_training_sleepedf.ipynb new file mode 100644 index 000000000..9c556f6ef --- /dev/null +++ b/examples/mvcl_training_sleepedf.ipynb @@ -0,0 +1,111 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MVCLTrainingSleepEEG on Sleep-EDF\n", + "\n", + "This example shows how to use `MVCLTrainingSleepEEG` from `pyhealth.tasks`.\n", + "\n", + "The workflow mirrors the task pattern used in PyHealth examples:\n", + "1. Load `SleepEDFDataset`\n", + "2. Run the task on one patient for a quick sanity check\n", + "3. 
Optionally run `set_task()` for the full dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import SleepEDFDataset\n", + "from pyhealth.tasks import MVCLTrainingSleepEEG\n", + "\n", + "# Update this path to your local Sleep-EDF root.\n", + "DATA_ROOT = \"../sleepedf\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SleepEDFDataset(root=DATA_ROOT, subset=\"cassette\")\n", + "dataset.stats()\n", + "print(f\"Number of patients: {len(dataset.unique_patient_ids)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d838e3b9", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "This dataset contains 153 whole-night sleep electroencephalography\n", + "(EEG) recordings collected from 82 healthy subjects. Each recording is sampled at 100 Hz using a 1-lead \n", + "EEG signal. The EEG signals are segmented into non-overlapping windows of size 200, each forming\n", + "one sample. Each sample is labeled with one of five sleep stages: Wake (W), Non-rapid Eye Movement\n", + "(N1, N2, N3), and Rapid Eye Movement (REM). 
This segmentation results in 371,055 samples.\n
Applies MV preprocessing per event file (one PSG/Hypnogram pair at a time), then appends samples immediately, so each returned sample includes ``xt``, - ``dx``, and ``xf`` without a patient-level global buffer. + ``xd``, and ``xf`` without a patient-level global buffer. Tensors are stored as ``numpy.float32`` arrays with shape ``(L, C_view)`` where ``C_view`` is 1 by default; with ``time_as_feature=True``, a leading time channel @@ -42,7 +42,7 @@ class MVCLTrainingSleepEEG(BaseTask): """ task_name: str = "MVCLTrainingSleepEEG" - input_schema = {"signal": "tensor"} + input_schema = {"xt": "tensor", "xd": "tensor", "xf": "tensor"} output_schema = {"label": "multiclass"} def __init__( @@ -176,9 +176,9 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: "epoch_index": b["epoch_index"], "window_in_epoch": b["window_in_epoch"], "signal": vec, - "xt": xt[i].detach().cpu().numpy().astype(np.float32), - "xd": dx[i].detach().cpu().numpy().astype(np.float32), - "xf": xf[i].detach().cpu().numpy().astype(np.float32), + "xt": xt[i], + "xd": dx[i], + "xf": xf[i], "label": b["label"], } ) @@ -376,9 +376,9 @@ def pt_dict_to_pyhealth_samples( "patient_id": f"{patient_id_prefix}_{i}", "record_id": f"{record_id_prefix}_{i}", "signal": signal_array[i][np.newaxis, :], - "xt": xt[i].detach().cpu().numpy().astype(np.float32), - "xd": dx[i].detach().cpu().numpy().astype(np.float32), - "xf": xf[i].detach().cpu().numpy().astype(np.float32), + "xt": xt[i], + "xd": dx[i], + "xf": xf[i], "label": int(label_array[i]), } ) @@ -415,7 +415,6 @@ def pt_file_to_sample_dataset( return create_sample_dataset( samples=samples, input_schema={ - "signal": "tensor", "xt": "tensor", "xd": "tensor", "xf": "tensor", diff --git a/tests/core/test_mvcl_training_sleepedf_task.py b/tests/core/test_mvcl_training_sleepedf_task.py new file mode 100644 index 000000000..08f66d82d --- /dev/null +++ b/tests/core/test_mvcl_training_sleepedf_task.py @@ -0,0 +1,166 @@ +import shutil +import tempfile +import unittest 
+from pathlib import Path +from unittest.mock import patch +from collections import Counter + +import mne +import numpy as np +import pandas as pd + +from pyhealth.datasets import SleepEDFDataset +from pyhealth.tasks import MVCLTrainingSleepEEG + + +class TestMVCLTrainingSleepEEGTask(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.dummy_dataset_dir = Path(cls.temp_dir) / "dummy_dataset" + cls.cassette_dir = cls.dummy_dataset_dir / "sleep-cassette" + cls.cassette_dir.mkdir(parents=True, exist_ok=True) + + cls._create_dummy_subject_spreadsheets() + cls._create_dummy_patient_files() + cls._create_dummy_metadata_csv() + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + @classmethod + def _create_dummy_subject_spreadsheets(cls): + """Create required SC/ST files in both requested locations.""" + df = pd.DataFrame( + { + "subject": [1, 2], + "night": [1, 2], + "age": [25, 30], + "sex (F=1)": [1, 2], + "LightsOff": ["22:00", "22:30"], + } + ) + spreadsheet_targets = [ + cls.dummy_dataset_dir / "SC-subjects.xls", + cls.dummy_dataset_dir / "ST-subjects.xls", + cls.cassette_dir / "SC-subjects.xls", + cls.cassette_dir / "ST-subjects.xls", + ] + for path in spreadsheet_targets: + # These files are only test placeholders; metadata is loaded from + # sleepedf-cassette-pyhealth.csv created below. + df.to_csv(path, index=False) + + @classmethod + def _create_dummy_patient_files(cls): + """Create two patients with 2 x 3000 dummy points each.""" + # Expected cassette metadata rows in `sleepedf-cassette-pyhealth.csv` look like: + # subject,night,age,sex,lights_off,signal_file,label_file + # 1,1,25,F,22:00,<...>/SC4011E0-PSG.edf,<...>/SC4011E0-Hypnogram.edf + # 2,2,30,M,22:30,<...>/SC4022E0-PSG.edf,<...>/SC4022E0-Hypnogram.edf + # + # This helper creates those referenced PSG/Hypnogram files so SleepEDFDataset + # can load events and MVCLTrainingSleepEEG can read per-patient signal/labels. 
+ patient_records = [ + ("SC4011E0", 1), # subject 01, night 1 + ("SC4022E0", 2), # subject 02, night 2 + ] + for stem, seed in patient_records: + signal = np.full(6000, fill_value=seed, dtype=np.float32) + for suffix in ("-PSG.edf", "-Hypnogram.edf"): + file_path = cls.cassette_dir / f"{stem}{suffix}" + signal.tofile(file_path) + + @classmethod + def _create_dummy_metadata_csv(cls): + rows = [ + { + "subject": 1, + "night": 1, + "age": 25, + "sex": "F", + "lights_off": "22:00", + "signal_file": str(cls.cassette_dir / "SC4011E0-PSG.edf"), + "label_file": str(cls.cassette_dir / "SC4011E0-Hypnogram.edf"), + }, + { + "subject": 2, + "night": 2, + "age": 30, + "sex": "M", + "lights_off": "22:30", + "signal_file": str(cls.cassette_dir / "SC4022E0-PSG.edf"), + "label_file": str(cls.cassette_dir / "SC4022E0-Hypnogram.edf"), + }, + ] + pd.DataFrame(rows).to_csv( + cls.dummy_dataset_dir / "sleepedf-cassette-pyhealth.csv", index=False + ) + + @staticmethod + def _mock_read_raw_edf(signal_file, *args, **kwargs): + """Load the binary dummy payload from .edf placeholder file.""" + signal = np.fromfile(signal_file, dtype=np.float32) + if signal.size != 6000: + raise ValueError(f"Expected 6000 points in {signal_file}, got {signal.size}") + data = signal.reshape(1, -1) + info = mne.create_info(["EEG Fpz-Cz"], sfreq=100, ch_types=["eeg"]) + return mne.io.RawArray(data, info, verbose="error") + + @staticmethod + def _mock_read_annotations(label_file, *args, **kwargs): + """Return two 30-second sleep-stage annotations per patient.""" + name = Path(label_file).name + if "SC4011E0" in name: + descriptions = ["Sleep stage W", "Sleep stage R"] + elif "SC4022E0" in name: + descriptions = ["Sleep stage 2", "Sleep stage 4"] + else: + raise ValueError(f"Unexpected label file: {label_file}") + return mne.Annotations( + onset=[0.0, 30.0], + duration=[30.0, 30.0], + description=descriptions, + ) + + def test_import_from_pyhealth_tasks(self): + """Matches notebook usage: from pyhealth.tasks 
import MVCLTrainingSleepEEG.""" + self.assertTrue(callable(MVCLTrainingSleepEEG)) + + def test_sleepedf_dummy_dataset_label_mapping(self): + dataset = SleepEDFDataset(root=str(self.dummy_dataset_dir), subset="cassette") + task = MVCLTrainingSleepEEG(window_size=200, crop_length=178, eeg_channel="EEG Fpz-Cz") + + with patch( + "pyhealth.tasks.mvcl_training_sleepedf_task.mne.io.read_raw_edf", + side_effect=self._mock_read_raw_edf, + ), patch( + "pyhealth.tasks.mvcl_training_sleepedf_task.mne.read_annotations", + side_effect=self._mock_read_annotations, + ): + sample_dataset = dataset.set_task(task, num_workers=1) + + # 2 patients x 2 epochs each x (3000 / 200) windows = 60 windows + self.assertEqual(len(sample_dataset), 60) + self.assertEqual(sample_dataset.input_schema, {"xt": "tensor", "xd": "tensor", "xf": "tensor"}) + self.assertEqual(sample_dataset.output_schema, {"label": "multiclass"}) + + sample = sample_dataset[0] + for key in ("xt", "xd", "xf"): + self.assertIn(key, sample) + # MV views are stored as [L, C] in the task; enforce equivalent 1x178 content. + for key in ("xt", "xd", "xf"): + self.assertEqual(sample[key].ndim, 2) + self.assertIn(1, sample[key].shape) + self.assertIn(178, sample[key].shape) + + # set_task() encodes multiclass labels to contiguous ids, but class balance + # should still match the four injected stages (W, R, 2, 4) => 15 windows each. 
+ label_counts = Counter(int(s["label"]) for s in sample_dataset) + self.assertEqual(len(label_counts), 4) + self.assertTrue(all(count == 15 for count in label_counts.values())) + + +if __name__ == "__main__": + unittest.main() From 8ae76ac419bf6b5201bf21f7be444298ab41f7e3 Mon Sep 17 00:00:00 2001 From: gaohey Date: Mon, 20 Apr 2026 19:11:10 +0800 Subject: [PATCH 14/21] remove loop; added docstrings --- pyhealth/tasks/mvcl_training_sleepedf_task.py | 229 ++++++++++-------- 1 file changed, 130 insertions(+), 99 deletions(-) diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index 014643a94..03e32ade1 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -30,7 +30,14 @@ def _map_to_MVCL_five_class(pyhealth_stage: int) -> int: class MVCLTrainingSleepEEG(BaseTask): - """SleepEDF windows with Multi-View contrastive tensor views. + """ + SleepEDF windows with Multi-View contrastive tensor views. + + This dataset contains 153 whole-night sleep electroencephalography + (EEG) recordings collected from 82 healthy subjects. Each recording is sampled at 100 Hz using a 1-lead + EEG signal. The EEG signals are segmented into non-overlapping windows of size 200, each forming + one sample. Each sample is labeled with one of five sleep stages: Wake (W), Non-rapid Eye Movement + (N1, N2, N3), and Rapid Eye Movement (REM). This segmentation results in 371,055 samples. Applies MV preprocessing per event file (one PSG/Hypnogram pair at a time), then appends samples immediately, so each returned sample includes ``xt``, @@ -39,6 +46,23 @@ class MVCLTrainingSleepEEG(BaseTask): Tensors are stored as ``numpy.float32`` arrays with shape ``(L, C_view)`` where ``C_view`` is 1 by default; with ``time_as_feature=True``, a leading time channel in ``[0,1]`` is concatenated so ``C_view`` is 2. + + Attributes: + task_name (str): The name of the task. 
+ input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + + Examples: + >>> from pyhealth.datasets import SleepEDFDataset + >>> from pyhealth.tasks import MVCLTrainingSleepEEG + >>> import os + >>> os.chdir("/path/to/sleep-edf") + >>> dataset = SleepEDFDataset( + ... root="/path/to/sleep-edf", + ... ) + >>> task = MVCLTrainingSleepEEG() + >>> samples = dataset.set_task(task) + >>> print(samples[0]) """ task_name: str = "MVCLTrainingSleepEEG" @@ -54,6 +78,16 @@ def __init__( time_as_feature: bool = False, dx_backend: str = "cde", ) -> None: + """Initializes the task object. + + Args: + chunk_duration: How long each chunk of EEG signal is (in seconds). Defaults to 30.0. + window_size: Number of samples per window. Defaults to 200. + crop_length: Optional length to crop the windows to. Defaults to 178. + eeg_channel: Which EEG channel to pick. Defaults to "EEG Fpz-Cz". + time_as_feature: Whether to add a time feature channel. Defaults to False. + dx_backend: Backend to use for computing the derivative view. Defaults to "cde". + """ self.chunk_duration = float(chunk_duration) self.window_size = int(window_size) self.crop_length = int(crop_length) if crop_length is not None else None @@ -72,6 +106,26 @@ def _pick_eeg_index(self, ch_names: List[str]) -> int: return 0 def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """ + Generates classification data samples for a single patient. + + Args: + patient (Any): A PyHealth patient object containing sleep events. + + Returns: + List[Dict[str, Any]]: A list containing a dictionary for each sleep window sample with: + - 'patient_id': Patient identifier. + - 'night': The night number of the recording. + - 'patient_age': Age of the patient. + - 'patient_sex': Sex of the patient. + - 'epoch_index': Global index of the 30s epoch. + - 'window_in_epoch': Index of the window within the epoch. + - 'signal': Original raw signal slice. 
+ - 'xt': Time-domain view tensor. + - 'xd': Derivative view tensor. + - 'xf': Frequency-domain view tensor. + - 'label': Mapped 5-class sleep stage label. + """ pid = patient.patient_id events = patient.get_events() samples: List[Dict[str, Any]] = [] @@ -129,72 +183,62 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: labels = epochs_train.events[:, 2] n_epochs, n_times = signals.shape - n_full = (n_times // win) * win - - event_buffers: List[Dict[str, Any]] = [] - for epi in range(n_epochs): - lab = _map_to_MVCL_five_class(int(labels[epi])) - row = signals[epi, :n_full] # Get the full 3000 points - - # Extract all 15 non-overlapping windows per epoch - for w in range(n_full // win): - seg = row[w * win : (w + 1) * win].astype(np.float32, copy=False) - if crop is not None: - seg = seg[:crop] - - event_buffers.append( - { - "seg_1d": seg.copy(), - "label": lab, - "night": event.night, - "patient_age": event.age, - "patient_sex": event.sex, - "epoch_index": global_epoch, - "window_in_epoch": w, # properly track the window - } - ) - global_epoch += 1 - - if not event_buffers: + n_windows = n_times // win + n_full = n_windows * win + + if n_epochs == 0 or n_windows == 0: continue - X = torch.stack( - [torch.from_numpy(b["seg_1d"]).float() for b in event_buffers], dim=0 - ).unsqueeze(-1) - xt, dx, xf = preprocess_mvcl_views( - X, + # Vectorized window extraction + segs = signals[:, :n_full].reshape(n_epochs, n_windows, win) + if crop is not None: + segs = segs[:, :, :crop] + segs = segs.reshape(-1, segs.shape[-1]).astype(np.float32) + + # Vectorized metadata mapping + mapping = np.array([0, 1, 2, 3, 3, 4], dtype=np.int64) + mapped_labels = mapping[labels.astype(np.int64)] + labels_rep = np.repeat(mapped_labels, n_windows) + epoch_indices = np.repeat(np.arange(global_epoch, global_epoch + n_epochs), n_windows) + window_indices = np.tile(np.arange(n_windows), n_epochs) + + global_epoch += n_epochs + + # Create numpy array directly + X_np = segs[..., 
np.newaxis] + xt_np, dx_np, xf_np = preprocess_mvcl_views_numpy( + X_np, time_as_feature=self.time_as_feature ) - for i, b in enumerate(event_buffers): - seg = b["seg_1d"] - vec = seg[np.newaxis, :] + # Construct samples in a single loop + for i in range(len(segs)): samples.append( { "patient_id": pid, - "night": b["night"], - "patient_age": b["patient_age"], - "patient_sex": b["patient_sex"], - "epoch_index": b["epoch_index"], - "window_in_epoch": b["window_in_epoch"], - "signal": vec, - "xt": xt[i], - "xd": dx[i], - "xf": xf[i], - "label": b["label"], + "night": event.night, + "patient_age": event.age, + "patient_sex": event.sex, + "epoch_index": int(epoch_indices[i]), + "window_in_epoch": int(window_indices[i]), + "signal": segs[i][np.newaxis, :].copy(), + "xt": torch.from_numpy(xt_np[i]), + "xd": torch.from_numpy(dx_np[i]), + "xf": torch.from_numpy(xf_np[i]), + "label": int(labels_rep[i]), } ) return samples -def normalize_mvcl( - X_train: torch.Tensor, - X_test: torch.Tensor, +def normalize_mvcl_numpy( + X_train: np.ndarray, + X_test: np.ndarray, epsilon: float = 1e-8, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - mean = X_train.mean(dim=(0, 1), keepdim=True) - std = X_train.std(dim=(0, 1), keepdim=True).clamp(min=epsilon) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + mean = X_train.mean(axis=(0, 1), keepdims=True) + std = np.maximum(X_train.std(axis=(0, 1), keepdims=True), epsilon) return ( (X_train - mean) / std, (X_test - mean) / std, @@ -203,32 +247,24 @@ def normalize_mvcl( ) -def add_time_feature(X: torch.Tensor) -> torch.Tensor: +def add_time_feature_numpy(X: np.ndarray) -> np.ndarray: """X: [num_samples, sequence_length, num_features] -> concat time in last dim.""" num_samples, seq_length, _ = X.shape - time_index = torch.linspace(0, 1, steps=seq_length, dtype=X.dtype, device=X.device) - time_feature = time_index.view(1, seq_length, 1).expand(num_samples, seq_length, 1) - return torch.cat([time_feature, X], 
dim=-1) - - - -def get_dx_gradient(X: torch.Tensor) -> torch.Tensor: - """Time derivative via ``torch.gradient`` along **dim=1** for **X [N, L, D]**. - - This is **not** equivalent to :func:`get_dx` (torchcde spline);. - """ - if X.ndim != 3: - raise ValueError(f"Expected [N, L, D], got {tuple(X.shape)}") - return torch.gradient(X, dim=1)[0] + time_index = np.linspace(0, 1, num=seq_length, dtype=X.dtype) + time_feature = np.broadcast_to( + time_index.reshape(1, seq_length, 1), + (num_samples, seq_length, 1) + ) + return np.concatenate([time_feature, X], axis=-1) -def get_dx_torchcde_equivalent(X: torch.Tensor) -> torch.Tensor: +def get_dx_torchcde_equivalent_numpy(X: np.ndarray) -> np.ndarray: """ - Pure PyTorch equivalent of torchcde Hermite cubic spline derivative + Pure NumPy equivalent of torchcde Hermite cubic spline derivative evaluated at the knot points with backward differences. """ N, L, D = X.shape - dx = torch.zeros_like(X) + dx = np.zeros_like(X) if L < 2: return dx @@ -239,7 +275,7 @@ def get_dx_torchcde_equivalent(X: torch.Tensor) -> torch.Tensor: derivs = X[:, 1:, :] - X[:, :-1, :] # derivs_prev[i] = derivs[i-1] for i > 0, and derivs[0] for i = 0 - derivs_prev = torch.cat([derivs[:, :1, :], derivs[:, :-1, :]], dim=1) + derivs_prev = np.concatenate([derivs[:, :1, :], derivs[:, :-1, :]], axis=1) # For i = 0 dx[:, 0, :] = derivs[:, 0, :] @@ -247,52 +283,47 @@ def get_dx_torchcde_equivalent(X: torch.Tensor) -> torch.Tensor: # For i > 0 factor = (4 - 3 * dt) * dt b = derivs_prev - D = derivs - b + D_diff = derivs - b - dx[:, 1:, :] = b + D * factor + dx[:, 1:, :] = b + D_diff * factor return dx +def get_xf_numpy(X: np.ndarray) -> np.ndarray: + return np.abs(np.fft.fft(X, axis=1)).astype(np.float32) -def get_xf(X: torch.Tensor) -> torch.Tensor: - return torch.abs(fft.fft(X, dim=1)) - -def preprocess_mvcl_views( - X: torch.Tensor, +def preprocess_mvcl_views_numpy( + X: np.ndarray, time_as_feature: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Apply the same logical pipeline as ``preprocess_data`` when train and test are the same tensor (per-domain batch, e.g. all windows of one patient). - X: float tensor **[N, L, D]** (e.g. ``D=1`` for single-channel EEG; **L** is time length). - - dx_backend: ``"cde"`` (default, torchcde, matches MV) or ``"gradient"`` (no torchcde). + X: float numpy array **[N, L, D]** (e.g. ``D=1`` for single-channel EEG; **L** is time length). Returns xt, dx, xf each [N, L, D] or [N, L, D+1] if ``time_as_feature``. """ if X.ndim != 3: raise ValueError(f"Expected X shape [N, L, D], got {tuple(X.shape)}") - xt_tr, xt_te, _, _ = normalize_mvcl(X, X) + xt_tr, xt_te, _, _ = normalize_mvcl_numpy(X, X) xt = xt_tr + dx_raw = get_dx_torchcde_equivalent_numpy(xt) - # dx_raw = get_dx_gradient(xt) # this is approxi to paper's torchcde spline - dx_raw = get_dx_torchcde_equivalent(xt) - - dx_tr, dx_te, _, _ = normalize_mvcl(dx_raw, dx_raw) + dx_tr, dx_te, _, _ = normalize_mvcl_numpy(dx_raw, dx_raw) dx = dx_tr - xf_raw = get_xf(xt) - xf_tr, xf_te, _, _ = normalize_mvcl(xf_raw, xf_raw) + xf_raw = get_xf_numpy(xt) + xf_tr, xf_te, _, _ = normalize_mvcl_numpy(xf_raw, xf_raw) xf = xf_tr if time_as_feature: - xt = add_time_feature(xt) - dx = add_time_feature(dx) - xf = add_time_feature(xf) + xt = add_time_feature_numpy(xt) + dx = add_time_feature_numpy(dx) + xf = add_time_feature_numpy(xf) return xt, dx, xf @@ -367,9 +398,9 @@ def pt_dict_to_pyhealth_samples( f"{signal_array.shape[0]} vs {label_array.shape[0]}" ) - # preprocess_mvcl_views expects [N, L, D], use D=1 for single-channel EEG. - signal_tensor = torch.from_numpy(np.ascontiguousarray(signal_array)).float().unsqueeze(-1) - xt, dx, xf = preprocess_mvcl_views(signal_tensor, time_as_feature=time_as_feature) + # preprocess_mvcl_views_numpy expects [N, L, D], use D=1 for single-channel EEG. 
+ signal_np = np.ascontiguousarray(signal_array)[..., np.newaxis] + xt_np, dx_np, xf_np = preprocess_mvcl_views_numpy(signal_np, time_as_feature=time_as_feature) samples: List[Dict[str, Any]] = [] for i in range(signal_array.shape[0]): @@ -377,10 +408,10 @@ def pt_dict_to_pyhealth_samples( { "patient_id": f"{patient_id_prefix}_{i}", "record_id": f"{record_id_prefix}_{i}", - "signal": signal_array[i][np.newaxis, :], - "xt": xt[i], - "xd": dx[i], - "xf": xf[i], + "signal": signal_array[i][np.newaxis, :].copy(), + "xt": torch.from_numpy(xt_np[i]), + "xd": torch.from_numpy(dx_np[i]), + "xf": torch.from_numpy(xf_np[i]), "label": int(label_array[i]), } ) From f0b71e149994f70342884a95a4995bcf8d9f8aee Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Wed, 22 Apr 2026 01:02:08 +0530 Subject: [PATCH 15/21] Generic MVCL model --- pyhealth/models/mvcl_model.py | 197 ++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 pyhealth/models/mvcl_model.py diff --git a/pyhealth/models/mvcl_model.py b/pyhealth/models/mvcl_model.py new file mode 100644 index 000000000..23ea39e33 --- /dev/null +++ b/pyhealth/models/mvcl_model.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from pyhealth.models import BaseModel +from typing import Callable, Dict, cast + +class GenericMultiViewModel(BaseModel): + """ + A generic Multi-View Contrastive Learning model that supports an arbitrary + number of views and modalities. 
+ """ + def __init__( + self, + dataset, + encoders: nn.ModuleDict, + projectors: nn.ModuleDict, + augmentations: Dict[str, Callable], + pos_encoders: nn.ModuleDict = nn.ModuleDict({}), + hidden_dim: int = 128, + training_stage: str = "pretrain", + num_classes: int = 3, + lambda_cl: float = 0.1, + tau: float = 0.07, + **kwargs + ): + super().__init__(dataset=dataset) + self.training_stage = training_stage + self.mode = "" + + # Disable inference metrics during pre-training + if self.training_stage == "pretrain": + self.mode = None + + self.hidden_dim = hidden_dim + self.num_classes = num_classes + self.lambda_cl = lambda_cl + self.tau = tau + self.view_names = list(encoders.keys()) + + # Dynamic Modules: The model will automatically build itself based on the keys! + self.encoders = encoders + self.projectors = projectors + self.augmentations = augmentations + self.pos_encoders = pos_encoders if len(pos_encoders) > 0 else nn.ModuleDict({ + view: nn.Identity() for view in self.view_names + }) + + # Generic Fusion: Applies attention across the V different views + self.fusion_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) + self.fusion_layer_norm = nn.LayerNorm(self.hidden_dim) + + # Dynamic feature-specific projectors (F_k) + self.F_projectors = nn.ModuleDict({ + view: nn.Sequential( + nn.Linear(self.hidden_dim * 2, self.hidden_dim), + nn.BatchNorm1d(self.hidden_dim), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(self.hidden_dim, self.hidden_dim), + ) for view in self.view_names + }) + + # Classifier dynamically sizes itself based on the number of views + self.classifier_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=1, batch_first=True) + self.classifier = nn.Sequential( + nn.Linear(self.hidden_dim * len(self.view_names), 1024), + nn.ReLU(), + nn.Linear(1024, 512), + nn.ReLU(), + nn.Dropout(0.5), + nn.Linear(512, self.num_classes) + ) + + def ntxent_loss(self, zis: torch.Tensor, zjs: torch.Tensor, tau: float) -> 
+        # 3. Concatenate pre- and post-interaction features, then project
+            # Concatenate pre- and post-interaction representations
(self.lambda_cl * loss_cl) + loss_ce + + return { + "loss": total_loss, + "logit": logits, + "y_prob": self.prepare_y_prob(logits), + "y_true": y_true + } + return {} + + def _prepare_tensor(self, x) -> torch.Tensor: + if isinstance(x, list): + import numpy as np + x = torch.stack(x) if isinstance(x[0], torch.Tensor) else torch.from_numpy(np.stack(x)) + return x.float().to(self.device) \ No newline at end of file From aae6aa55f08149ed06f05a77b0167241c125a0fd Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Wed, 22 Apr 2026 01:56:11 +0530 Subject: [PATCH 16/21] attempt at making the tests run faster --- tests/core/test_mvcl_training_sleepedf_task.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/core/test_mvcl_training_sleepedf_task.py b/tests/core/test_mvcl_training_sleepedf_task.py index 08f66d82d..2d5b2824b 100644 --- a/tests/core/test_mvcl_training_sleepedf_task.py +++ b/tests/core/test_mvcl_training_sleepedf_task.py @@ -5,7 +5,8 @@ from unittest.mock import patch from collections import Counter -import mne +import concurrent.futures +from mne import create_info, io, Annotations import numpy as np import pandas as pd @@ -66,11 +67,15 @@ def _create_dummy_patient_files(cls): ("SC4011E0", 1), # subject 01, night 1 ("SC4022E0", 2), # subject 02, night 2 ] - for stem, seed in patient_records: + def write_patient_file(record): + stem, seed = record signal = np.full(6000, fill_value=seed, dtype=np.float32) for suffix in ("-PSG.edf", "-Hypnogram.edf"): file_path = cls.cassette_dir / f"{stem}{suffix}" signal.tofile(file_path) + + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(write_patient_file, patient_records) @classmethod def _create_dummy_metadata_csv(cls): @@ -105,8 +110,8 @@ def _mock_read_raw_edf(signal_file, *args, **kwargs): if signal.size != 6000: raise ValueError(f"Expected 6000 points in {signal_file}, got {signal.size}") data = signal.reshape(1, -1) - info = mne.create_info(["EEG 
Fpz-Cz"], sfreq=100, ch_types=["eeg"]) - return mne.io.RawArray(data, info, verbose="error") + info = create_info(["EEG Fpz-Cz"], sfreq=100, ch_types=["eeg"]) + return io.RawArray(data, info, verbose="error") @staticmethod def _mock_read_annotations(label_file, *args, **kwargs): @@ -118,7 +123,7 @@ def _mock_read_annotations(label_file, *args, **kwargs): descriptions = ["Sleep stage 2", "Sleep stage 4"] else: raise ValueError(f"Unexpected label file: {label_file}") - return mne.Annotations( + return Annotations( onset=[0.0, 30.0], duration=[30.0, 30.0], description=descriptions, From d24238765904a0c89817e2404b45a1970300966a Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Wed, 22 Apr 2026 14:14:17 +0530 Subject: [PATCH 17/21] Added tests and updated docstrings --- pyhealth/models/__init__.py | 3 +- ...lti_view_contrastive_time_series_model.py} | 2 +- pyhealth/models/mvcl_model.py | 76 ++++++++++++---- pyhealth/tasks/mvcl_training_sleepedf_task.py | 10 ++- tests/core/test_mvcl.py | 88 +++++++++++++++++++ 5 files changed, 159 insertions(+), 20 deletions(-) rename pyhealth/models/{multi_view_contrastive_model.py => multi_view_contrastive_time_series_model.py} (99%) create mode 100644 tests/core/test_mvcl.py diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index bc3d37bb2..47679cae5 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,4 +44,5 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .multi_view_contrastive_model import MultiViewContrastiveModel +from .multi_view_contrastive_time_series_model import MultiViewContrastiveTimeSeriesModel +from .mvcl_model import MultiViewContrastiveModel diff --git a/pyhealth/models/multi_view_contrastive_model.py b/pyhealth/models/multi_view_contrastive_time_series_model.py similarity index 99% rename from pyhealth/models/multi_view_contrastive_model.py rename to 
pyhealth/models/multi_view_contrastive_time_series_model.py index e457e7767..133addd50 100644 --- a/pyhealth/models/multi_view_contrastive_model.py +++ b/pyhealth/models/multi_view_contrastive_time_series_model.py @@ -25,7 +25,7 @@ def forward(self, x): x = x + self.pe[:, :x.size(1)] return self.dropout(x) -class MultiViewContrastiveModel(BaseModel): +class MultiViewContrastiveTimeSeriesModel(BaseModel): """A multi-view contrastive model aligned with Oh and Bui (2025) and TFC.""" def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): diff --git a/pyhealth/models/mvcl_model.py b/pyhealth/models/mvcl_model.py index 23ea39e33..cf7a281a0 100644 --- a/pyhealth/models/mvcl_model.py +++ b/pyhealth/models/mvcl_model.py @@ -2,12 +2,54 @@ import torch.nn as nn import torch.nn.functional as F from pyhealth.models import BaseModel -from typing import Callable, Dict, cast +from typing import Callable, Dict, Optional, cast -class GenericMultiViewModel(BaseModel): - """ - A generic Multi-View Contrastive Learning model that supports an arbitrary - number of views and modalities. +class MultiViewContrastiveModel(BaseModel): + """A generic, plug-and-play Multi-View Contrastive Learning (MVCL) model. + This model supports an arbitrary number of views and modalities by dynamically + constructing its architecture based on the provided dictionaries of encoders, + projectors, and augmentations. It implements hierarchical cross-view fusion + via Multi-Head Attention and uses a 2N x 2N NT-Xent (InfoNCE) contrastive loss, + aligning with the conceptual framework of Oh and Bui (2025) and TFC-pretraining. + + Args: + dataset (SampleDataset): The PyHealth dataset object. + encoders (nn.ModuleDict): A dictionary mapping view names (e.g., "xt", "xf") + to their respective PyTorch encoder modules (e.g., Transformer, CNN). + projectors (nn.ModuleDict): A dictionary mapping view names to their initial + feature projection layers (e.g., mapping raw inputs to `hidden_dim`). 
+ augmentations (Dict[str, Callable]): A dictionary mapping view names to + their specific data augmentation functions (e.g., jittering, frequency masking). + pos_encoders (nn.ModuleDict, optional): A dictionary mapping view names to + their positional encoding modules. If a view is not included, it defaults + to `nn.Identity()`. Useful for sequence models. Defaults to an empty dict. + hidden_dim (int, optional): The hidden dimension size for the embeddings, + MHA fusion, and projections. Defaults to 128. + training_stage (str, optional): The current stage of the model. Accepts + "pretrain" (contrastive representation learning) or "finetune" + (downstream classification). Defaults to "pretrain". + num_classes (int, optional): The number of target classes for the downstream + classification task. Defaults to 3. + lambda_cl (float, optional): The weight/penalty hyperparameter for the contrastive + loss during the fine-tuning stage. Defaults to 0.1. + tau (float, optional): The temperature parameter for the NT-Xent loss function. + Defaults to 0.07. + + Outputs (dict): + During "pretrain": + - loss (torch.Tensor): The aggregated 2N x 2N NT-Xent contrastive loss across all views. + - z_{view} (torch.Tensor): The final fused embeddings for each provided view. + During "finetune": + - loss (torch.Tensor): The combined cross-entropy loss and contrastive penalty. + - logit (torch.Tensor): The raw classification logits. + - y_prob (torch.Tensor): The predicted class probabilities. + - y_true (torch.Tensor): The ground truth labels. 
+ + Example: + >>> encoders = nn.ModuleDict({"v1": CNNEncoder(), "v2": TextEncoder()}) + >>> projectors = nn.ModuleDict({"v1": nn.Linear(3, 128), "v2": nn.Linear(768, 128)}) + >>> augs = {"v1": image_jitter, "v2": text_mask} + >>> model = MultiViewContrastiveModel(dataset, encoders, projectors, augs) """ def __init__( self, @@ -15,7 +57,7 @@ def __init__( encoders: nn.ModuleDict, projectors: nn.ModuleDict, augmentations: Dict[str, Callable], - pos_encoders: nn.ModuleDict = nn.ModuleDict({}), + pos_encoders: Optional[nn.ModuleDict] = None, hidden_dim: int = 128, training_stage: str = "pretrain", num_classes: int = 3, @@ -37,11 +79,10 @@ def __init__( self.tau = tau self.view_names = list(encoders.keys()) - # Dynamic Modules: The model will automatically build itself based on the keys! self.encoders = encoders self.projectors = projectors self.augmentations = augmentations - self.pos_encoders = pos_encoders if len(pos_encoders) > 0 else nn.ModuleDict({ + self.pos_encoders = pos_encoders if pos_encoders is not None else nn.ModuleDict({ view: nn.Identity() for view in self.view_names }) @@ -112,22 +153,26 @@ def _forward_features(self, views_data: Dict[str, torch.Tensor]) -> Dict[str, to # Stack into [N, num_views, L, hidden_dim] H = torch.stack([encoded_views[v] for v in self.view_names], dim=1) - # Flatten for MHA: [N * num_views, L, hidden_dim] - H_flat = H.permute(0, 2, 1, 3).contiguous().view(batch_size * num_views, seq_length, self.hidden_dim) + H_permuted = H.permute(0, 2, 1, 3).contiguous() + + # Flatten Batch and Sequence length together: [N * L, num_views, hidden_dim] + H_flat = H_permuted.view(batch_size * seq_length, num_views, self.hidden_dim) MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) H_out = self.fusion_layer_norm(MHA_out + H_flat) - # Reshape back to [N, num_views, L, hidden_dim] - H_out = H_out.view(batch_size, seq_length, num_views, self.hidden_dim).permute(0, 2, 1, 3) + # Reshape back to [N, L, num_views, hidden_dim] + H_out = 
H_out.view(batch_size, seq_length, num_views, self.hidden_dim) - # 3. Concatenate and Project (This restores your original logic!) + # 3. Concatenate and Project final_zs = {} for i, view in enumerate(self.view_names): h_pre = encoded_views[view].mean(dim=1) # Pre-interaction - h_post = H_out[:, i, :].mean(dim=1) # Post-interaction - # Reintroducing your concatenation! + # Extract from dimension 2 (num_views dimension) + h_post = H_out[:, :, i, :].mean(dim=1) # Post-interaction + + # Concatenation of pre and post features! h_pool = torch.cat([h_pre, h_post], dim=-1) # Shape: [N, hidden_dim * 2] final_zs[view] = self.F_projectors[view](h_pool) @@ -135,7 +180,6 @@ def _forward_features(self, views_data: Dict[str, torch.Tensor]) -> Dict[str, to return final_zs def forward(self, **kwargs) -> dict[str, torch.Tensor]: - # Dynamically extract and prepare the views views_data = {view: self._prepare_tensor(kwargs.get(view)) for view in self.view_names} if self.training_stage == "pretrain": diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index 03e32ade1..b54b0a0aa 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -12,6 +12,7 @@ from __future__ import annotations +import os from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import mne @@ -77,6 +78,7 @@ def __init__( eeg_channel: Optional[str] = "EEG Fpz-Cz", time_as_feature: bool = False, dx_backend: str = "cde", + root_path: Optional[Union[str, Path]] = None, ) -> None: """Initializes the task object. @@ -87,6 +89,7 @@ def __init__( eeg_channel: Which EEG channel to pick. Defaults to "EEG Fpz-Cz". time_as_feature: Whether to add a time feature channel. Defaults to False. dx_backend: Backend to use for computing the derivative view. Defaults to "cde". + root_path: Optional path to the root directory of the dataset. Defaults to None. 
class TestMultiViewContrastiveModel(unittest.TestCase):
    """Unit tests for the generic MultiViewContrastiveModel."""

    def setUp(self):
        """Create dummy hyperparameters and per-view modules for both tests."""
        self.hidden_dim = 32
        self.batch_size = 4
        self.seq_len = 10
        self.num_classes = 3
        self.view_names = ["view_A", "view_B"]

        self.projectors = nn.ModuleDict(
            {name: nn.Linear(1, self.hidden_dim) for name in self.view_names}
        )
        self.encoders = nn.ModuleDict(
            {name: nn.Linear(self.hidden_dim, self.hidden_dim) for name in self.view_names}
        )
        # Tiny additive jitter stands in for real augmentations.
        self.augmentations = {
            name: (lambda x: x + torch.randn_like(x) * 0.01) for name in self.view_names
        }

        self.mock_dataset = MagicMock()
        self.mock_dataset.input_schema = {name: "tensor" for name in self.view_names}
        self.mock_dataset.output_schema = {"label": "multiclass"}

    def _build_model(self, stage):
        """Construct a model wired to the dummy modules for the given stage."""
        return MultiViewContrastiveModel(
            dataset=self.mock_dataset,
            encoders=self.encoders,
            projectors=self.projectors,
            augmentations=self.augmentations,
            hidden_dim=self.hidden_dim,
            training_stage=stage,
            num_classes=self.num_classes,
        )

    def _random_views(self):
        """One random batch per view, shaped [batch, seq, 1]."""
        return {
            name: torch.randn(self.batch_size, self.seq_len, 1)
            for name in self.view_names
        }

    def test_pretrain_forward(self):
        """Pretraining yields a finite NT-Xent loss and per-view embeddings."""
        model = self._build_model("pretrain")

        outputs = model(**self._random_views())

        self.assertIn("loss", outputs)
        self.assertIn("z_view_A", outputs)
        self.assertIn("z_view_B", outputs)
        self.assertEqual(outputs["z_view_A"].shape, (self.batch_size, self.hidden_dim))
        self.assertTrue(torch.is_tensor(outputs["loss"]))
        self.assertFalse(torch.isnan(outputs["loss"]))

    def test_finetune_forward(self):
        """Finetuning returns logits, normalized probabilities, and labels."""
        model = self._build_model("finetune")
        batch = self._random_views()
        batch["label"] = torch.randint(0, self.num_classes, (self.batch_size,))

        outputs = model(**batch)

        self.assertIn("loss", outputs)
        self.assertIn("logit", outputs)
        self.assertIn("y_prob", outputs)
        self.assertIn("y_true", outputs)
        self.assertEqual(outputs["logit"].shape, (self.batch_size, self.num_classes))
        self.assertEqual(outputs["y_prob"].shape, (self.batch_size, self.num_classes))
        self.assertTrue(
            torch.allclose(outputs["y_prob"].sum(dim=-1), torch.ones(self.batch_size))
        )


if __name__ == "__main__":
    unittest.main()
From 3510816dbf1fdace7fbf9cf98278b41a9aadf07a Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Wed, 22 Apr 2026 14:24:38 +0530 Subject: [PATCH 18/21] Updated the notebook --- examples/mvcl_training_sleepedf.ipynb | 167 ++++++++++++++++++++++++-- 1 file changed, 158 insertions(+), 9 deletions(-) diff --git a/examples/mvcl_training_sleepedf.ipynb b/examples/mvcl_training_sleepedf.ipynb index 9c556f6ef..fb4fc4c35 100644 --- a/examples/mvcl_training_sleepedf.ipynb +++ b/examples/mvcl_training_sleepedf.ipynb @@ -9,9 +9,11 @@ "This example shows how to use `MVCLTrainingSleepEEG` from `pyhealth.tasks`.\n", "\n", "The workflow mirrors the task pattern used in PyHealth examples:\n", - "1. Load `SleepEDFDataset`\n", - "2. Run the task on one patient for a quick sanity check\n", - "3. Optionally run `set_task()` for the full dataset" + "1. Load `SleepEDFDataset`.\n", + "2. Run the task on one patient for a quick sanity check.\n", + "3. Optionally run `set_task()` for the full dataset.\n", + "4. Run the pretraining step on the subset of SleepEDF data.\n", + "5. Save the model state." ] }, { @@ -20,11 +22,17 @@ "metadata": {}, "outputs": [], "source": [ + "!uv pip install ipywidgets\n", + "\n", + "import os\n", + "\n", "from pyhealth.datasets import SleepEDFDataset\n", "from pyhealth.tasks import MVCLTrainingSleepEEG\n", "\n", - "# Update this path to your local Sleep-EDF root.\n", - "DATA_ROOT = \"../sleepedf\"" + "# Update this absolute path to your local Sleep-EDF root.\n", + "DATA_ROOT = \"C:\\\\Users\\\\shart\\\\workspace\\\\CS-598\\\\PyHealth\\\\sleepedf\"\n", + "assert os.path.exists(DATA_ROOT), f\"Sleep-EDF root path {DATA_ROOT} does not exist. Please update the path to your local Sleep-EDF root.\"\n", + "assert os.path.isabs(DATA_ROOT), f\"Sleep-EDF root path {DATA_ROOT} is not an absolute path. 
Please update to an absolute path.\"" ] }, { @@ -33,7 +41,7 @@ "metadata": {}, "outputs": [], "source": [ - "dataset = SleepEDFDataset(root=DATA_ROOT, subset=\"cassette\")\n", + "dataset = SleepEDFDataset(root=DATA_ROOT, subset=\"cassette\", dev=True)\n", "dataset.stats()\n", "print(f\"Number of patients: {len(dataset.unique_patient_ids)}\")" ] @@ -68,6 +76,7 @@ " window_size=200, ## Create None overlapping window of 200 lenth \n", " crop_length=178, ## take first 178 data points of the window to match that of Epilepsy data \n", " eeg_channel=\"EEG Fpz-Cz\",\n", + " root_path=DATA_ROOT, ## Pass the root path to the task so it can load the data correctly.\n", ")\n", "samples = task(patient)\n", "\n", @@ -88,22 +97,162 @@ "outputs": [], "source": [ "# Full pipeline (can take a while and uses disk cache).\n", - "sample_dataset = dataset.set_task(task, num_workers=1)\n", + "sample_dataset = dataset.set_task(task, num_workers=16)\n", "print(f\"Total task samples: {len(sample_dataset)}\")\n", "print(f\"Input schema: {sample_dataset.input_schema}\")\n", "print(f\"Output schema: {sample_dataset.output_schema}\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44748101", + "metadata": {}, + "outputs": [], + "source": [ + "# Factory functions and helpers required to setup the model.\n", + "import math\n", + "from pathlib import Path\n", + "from typing import Any, Dict, Dict, List, List, Mapping, Union\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "def augment_time(x: torch.Tensor, std: float = 0.1) -> torch.Tensor:\n", + " \"\"\"Time-domain jitter augmentation\"\"\"\n", + " noise = torch.randn_like(x) * std\n", + " return x + noise\n", + " \n", + "def augment_freq(sample: torch.Tensor, pertub_ratio: float = 0.05) -> torch.Tensor:\n", + " \"\"\"Frequency-domain augmentation (remove and add frequencies)\"\"\"\n", + " aug_1 = remove_frequency(sample, pertub_ratio)\n", + " aug_2 = add_frequency(sample, pertub_ratio)\n", + " return 
aug_1 + aug_2\n", + "\n", + "def remove_frequency(x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor:\n", + " mask = torch.rand(x.shape, device=x.device) > pertub_ratio\n", + " return x * mask\n", + "\n", + "def add_frequency(x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor:\n", + " mask = torch.rand(x.shape, device=x.device) > (1 - pertub_ratio)\n", + " max_amplitude = x.max()\n", + " random_am = torch.rand(mask.shape, device=x.device) * (max_amplitude * 0.1)\n", + " pertub_matrix = mask * random_am\n", + " return x + pertub_matrix\n", + "\n", + "class PositionalEncoding(nn.Module):\n", + " def __init__(self, hidden_dim, dropout=0.1, max_len=1024):\n", + " super(PositionalEncoding, self).__init__()\n", + " self.dropout = nn.Dropout(p=dropout)\n", + " \n", + " pe = torch.zeros(max_len, hidden_dim)\n", + " position = torch.arange(0, max_len).unsqueeze(1)\n", + " div_term = torch.exp(torch.arange(0, hidden_dim, 2) * \n", + " (-math.log(10000.0) / hidden_dim))\n", + " pe[:, 0::2] = torch.sin(position * div_term)\n", + " pe[:, 1::2] = torch.cos(position * div_term)\n", + " pe = pe.unsqueeze(0) # Shape: [1, max_len, hidden_dim]\n", + " self.register_buffer('pe', pe)\n", + " \n", + " def forward(self, x):\n", + " x = x + self.pe[:, :x.size(1)]\n", + " return self.dropout(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "894abbca", + "metadata": {}, + "outputs": [], + "source": [ + "# Now we set up the model and training loop.\n", + "from pyhealth.datasets.sample_dataset import SampleDataset\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "from pyhealth.datasets.utils import get_dataloader\n", + "from pyhealth.models.mvcl_model import MultiViewContrastiveModel\n", + "from torch.optim import Adam\n", + "from torch.utils.data import DataLoader\n", + "from pyhealth.trainer import Trainer\n", + "\n", + "pretrain_ds, _, _ = split_by_sample(sample_dataset, [0.05, 0.05, 0.9])\n", + "assert type(pretrain_ds) == 
SampleDataset, \"Expected a SampleDataset after splitting\"\n", + "pretrain_loader = get_dataloader(pretrain_ds, batch_size=128, shuffle=True)\n", + "\n", + "\n", + "view_keys=[\"xt\", \"xf\", \"xd\"]\n", + "model = MultiViewContrastiveModel(\n", + " dataset=pretrain_ds,\n", + " encoders=torch.nn.ModuleDict({\n", + " k: nn.TransformerEncoder(\n", + " nn.TransformerEncoderLayer(d_model=128, nhead=4, batch_first=True),\n", + " num_layers=2\n", + " ) for k in view_keys\n", + " }),\n", + " projectors = nn.ModuleDict({\n", + " k: nn.Linear(1, 128) \n", + " for k in view_keys\n", + " }),\n", + " augmentations={\"xt\": augment_time, \"xf\": augment_freq, \"xd\": augment_time},\n", + " pos_encoders=nn.ModuleDict({\n", + " k: PositionalEncoding(hidden_dim=128, dropout=0.1) for k in view_keys\n", + " }),\n", + " hidden_dim=128,\n", + " training_stage=\"pretrain\",\n", + " num_classes=3\n", + "\n", + ")\n", + "device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model.to(device)\n", + "\n", + "pretrainer = Trainer(\n", + " model=model,\n", + " device=str(device),\n", + ")\n", + "pretrainer.train(\n", + " train_dataloader=pretrain_loader,\n", + " epochs=2,\n", + " optimizer_class=torch.optim.Adam,\n", + " optimizer_params={\n", + " \"lr\": 0.002,\n", + " \"weight_decay\": 1e-5,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "711ee718", + "metadata": {}, + "outputs": [], + "source": [ + "# Save the pretraining model state for later use.\n", + "import os\n", + "import tempfile\n", + "\n", + "tmp_path = tempfile.mkdtemp()\n", + "torch.save(model.state_dict(), os.path.join(tmp_path, \"mvcl_pretrain.pth\"))" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "pyhealth (3.13.7)", "language": "python", "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", 
- "version": "3.10" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" } }, "nbformat": 4, From 578d1ad07f16da5b294b57d034aa642f63575116 Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Wed, 22 Apr 2026 15:01:29 +0530 Subject: [PATCH 19/21] Updated RST files --- docs/api/models.rst | 1 + docs/api/models/pyhealth.models.MVCL.rst | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 docs/api/models/pyhealth.models.MVCL.rst diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..607dcbc05 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,3 +204,4 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.MultiViewContrastiveModel diff --git a/docs/api/models/pyhealth.models.MVCL.rst b/docs/api/models/pyhealth.models.MVCL.rst new file mode 100644 index 000000000..3fd27cd94 --- /dev/null +++ b/docs/api/models/pyhealth.models.MVCL.rst @@ -0,0 +1,7 @@ +pyhealth.models.MultiViewContrastiveModel +=================================== + +autoclass:: pyhealth.models.MultiViewContrastiveModel + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file From 486e89dfe8334d2a5b73badb073e58445c9e1b0e Mon Sep 17 00:00:00 2001 From: coderookie1994 Date: Wed, 22 Apr 2026 18:28:09 +0530 Subject: [PATCH 20/21] Captured output from the example notebook --- examples/mvcl_training_sleepedf.ipynb | 324 ++++++++++++++++++++++++-- 1 file changed, 310 insertions(+), 14 deletions(-) diff --git a/examples/mvcl_training_sleepedf.ipynb b/examples/mvcl_training_sleepedf.ipynb index fb4fc4c35..874419bb7 100644 --- a/examples/mvcl_training_sleepedf.ipynb +++ b/examples/mvcl_training_sleepedf.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "7db5f192", "metadata": {}, "source": [ "# MVCLTrainingSleepEEG on Sleep-EDF\n", @@ -18,9 +19,19 @@ }, { "cell_type": "code", - 
"execution_count": null, + "execution_count": 1, + "id": "943a666a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[2mUsing Python 3.13.7 environment at: C:\\Users\\shart\\workspace\\CS-598\\PyHealth\\.venv\u001b[0m\n", + "\u001b[2mChecked \u001b[1m1 package\u001b[0m \u001b[2min 182ms\u001b[0m\u001b[0m\n" + ] + } + ], "source": [ "!uv pip install ipywidgets\n", "\n", @@ -37,9 +48,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "id": "866a8eb3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Initializing sleepedf dataset from C:\\Users\\shart\\workspace\\CS-598\\PyHealth\\sleepedf (dev mode: True)\n", + "No cache_dir provided. Using default cache dir: C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\n", + "Found cached event dataframe: C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\\global_event_df.parquet\n", + "Dataset: sleepedf\n", + "Dev mode: True\n", + "Number of patients: 78\n", + "Number of events: 153\n", + "Found 78 unique patient IDs\n", + "Number of patients: 78\n" + ] + } + ], "source": [ "dataset = SleepEDFDataset(root=DATA_ROOT, subset=\"cassette\", dev=True)\n", "dataset.stats()\n", @@ -48,10 +77,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "d838e3b9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nThis dataset contains 153 whole-night sleep electroencephalography\\n(EEG) recordings collected from 82 healthy subjects. Each recording is sampled at 100 Hz using a 1-lead \\nEEG signal. The EEG signals are segmented into non-overlapping windows of size 200, each forming\\none sample. 
Each sample is labeled with one of five sleep stages: Wake (W), Non-rapid Eye Movement\\n(N1, N2, N3), and Rapid Eye Movement (REM). This segmentation results in 371,055 samples\\n'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "\"\"\"\n", "This dataset contains 153 whole-night sleep electroencephalography\n", @@ -64,9 +104,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "fdd22b73", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Used Annotations descriptions: [np.str_('Sleep stage 1'), np.str_('Sleep stage 2'), np.str_('Sleep stage 3'), np.str_('Sleep stage 4'), np.str_('Sleep stage R'), np.str_('Sleep stage W')]\n", + "Used Annotations descriptions: [np.str_('Sleep stage 1'), np.str_('Sleep stage 2'), np.str_('Sleep stage 3'), np.str_('Sleep stage R'), np.str_('Sleep stage W')]\n", + "patient_id: 25\n", + "sample count: 81375\n", + "sample keys: ['patient_id', 'night', 'patient_age', 'patient_sex', 'epoch_index', 'window_in_epoch', 'signal', 'xt', 'xd', 'xf', 'label']\n", + "signal shape: (1, 178)\n", + "xt shape: torch.Size([178, 1])\n", + "xd shape: torch.Size([178, 1])\n", + "xf shape: torch.Size([178, 1])\n", + "label: 0\n" + ] + } + ], "source": [ "# Quick sanity check on one patient.\n", "patient_id = dataset.unique_patient_ids[0]\n", @@ -92,9 +150,74 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, + "id": "535a0419", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task MVCLTrainingSleepEEG for sleepedf base dataset...\n", + "Task cache paths: task_df=C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\\tasks\\MVCLTrainingSleepEEG_2060b134-eac9-5c04-b238-25b00edd7ba5\\task_df.ld, 
samples=C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\\tasks\\MVCLTrainingSleepEEG_2060b134-eac9-5c04-b238-25b00edd7ba5\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld\n", + "Found cached task dataframe at C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\\tasks\\MVCLTrainingSleepEEG_2060b134-eac9-5c04-b238-25b00edd7ba5\\task_df.ld, skipping task transformation.\n", + "Fitting processors on the dataset...\n", + "Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}\n", + "Processing samples and saving to C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\\tasks\\MVCLTrainingSleepEEG_2060b134-eac9-5c04-b238-25b00edd7ba5\\samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld...\n", + "Applying processors on data with 16 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 6224415 samples. 
(0 to 6224415)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/6224415 [00:00\n", + "Optimizer params: {'lr': 0.002, 'weight_decay': 1e-05}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: None\n", + "Monitor: None\n", + "Monitor criterion: max\n", + "Epochs: 2\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9bfad519d14947ceb5700c21c2c3d04f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 2: 0%| | 0/2432 [00:00 Date: Wed, 6 May 2026 21:00:10 +0800 Subject: [PATCH 21/21] fixed feedbacks on gradescope --- docs/api/models.rst | 2 +- docs/api/models/pyhealth.models.MVCL.rst | 4 +- ...ulti_view_contrastive_time_series_model.py | 142 ++++++++++++------ pyhealth/models/mvcl_model.py | 102 +++++++------ pyhealth/tasks/mvcl_training_sleepedf_task.py | 7 +- 5 files changed, 163 insertions(+), 94 deletions(-) diff --git a/docs/api/models.rst b/docs/api/models.rst index 2fe03cb43..c0f0b7b3b 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -205,5 +205,5 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs - models/pyhealth.models.MultiViewContrastiveModel + models/pyhealth.models.MVCL models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.MVCL.rst b/docs/api/models/pyhealth.models.MVCL.rst index 3fd27cd94..2c2200c66 100644 --- a/docs/api/models/pyhealth.models.MVCL.rst +++ b/docs/api/models/pyhealth.models.MVCL.rst @@ -1,7 +1,7 @@ pyhealth.models.MultiViewContrastiveModel =================================== -autoclass:: pyhealth.models.MultiViewContrastiveModel +.. 
autoclass:: pyhealth.models.MultiViewContrastiveModel :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/pyhealth/models/multi_view_contrastive_time_series_model.py b/pyhealth/models/multi_view_contrastive_time_series_model.py index 133addd50..0c426a812 100644 --- a/pyhealth/models/multi_view_contrastive_time_series_model.py +++ b/pyhealth/models/multi_view_contrastive_time_series_model.py @@ -1,11 +1,16 @@ +"""Multi-view contrastive time-series model for PyHealth datasets.""" + +import math +from typing import Any, Tuple, cast + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -import math +from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel -from typing import Tuple, cast + class PositionalEncoding(nn.Module): def __init__(self, hidden_dim, dropout=0.1, max_len=1024): @@ -26,20 +31,58 @@ def forward(self, x): return self.dropout(x) class MultiViewContrastiveTimeSeriesModel(BaseModel): - """A multi-view contrastive model aligned with Oh and Bui (2025) and TFC.""" + """Multi-view contrastive model for time-series tensors. + + This model follows the multi-view contrastive learning setup used for + time-domain, derivative, and frequency-domain views. Each view is projected, + encoded with a Transformer encoder, fused with cross-view attention, and + used for either contrastive pretraining or downstream classification. + + Args: + dataset (SampleDataset): Dataset with ``xt``, ``xd``, and ``xf`` tensor + inputs and one output label. + training_stage (str): Training stage, either ``"pretrain"`` for + contrastive representation learning or ``"finetune"`` for + classification. Default is ``"pretrain"``. + num_classes (int): Number of classes used by the classification head. + Default is 3. + **kwargs: Additional keyword arguments kept for PyHealth model API + compatibility. 
+ + Attributes: + hidden_dim: Hidden dimension used by projections, encoders, and fusion. + lambda_cl: Weight for the contrastive penalty during finetuning. + tau: Temperature used by the NT-Xent contrastive loss. + + Examples: + >>> from pyhealth.models import MultiViewContrastiveTimeSeriesModel + >>> model = MultiViewContrastiveTimeSeriesModel( + ... dataset=sample_dataset, + ... training_stage="pretrain", + ... num_classes=5, + ... ) + >>> output = model(xt=xt, xd=xd, xf=xf) + >>> sorted(output.keys()) + ['loss', 'z_d', 'z_f', 'z_t'] + """ - def __init__(self, dataset, training_stage="pretrain", num_classes=3, **kwargs): + def __init__( + self, + dataset: SampleDataset, + training_stage: str = "pretrain", + num_classes: int = 3, + **kwargs: Any + ): super().__init__(dataset=dataset) self.hidden_dim = 128 - seq_length = 256 self.training_stage = training_stage self.lambda_cl = 0.1 self.tau = 0.07 self.num_classes = num_classes - self.proj_t = nn.Linear(1, self.hidden_dim) - self.proj_d = nn.Linear(1, self.hidden_dim) - self.proj_f = nn.Linear(1, self.hidden_dim) + self.temporal_projection = nn.Linear(1, self.hidden_dim) + self.derivative_projection = nn.Linear(1, self.hidden_dim) + self.frequency_projection = nn.Linear(1, self.hidden_dim) self.pos_encoder = PositionalEncoding(self.hidden_dim, dropout=0.1) @@ -53,11 +96,11 @@ def make_encoder() -> nn.TransformerEncoder: self.encoder_d: nn.TransformerEncoder = make_encoder() self.encoder_f: nn.TransformerEncoder = make_encoder() - # MHA for Hierarchical Fusion - self.fusion_mha: nn.MultiheadAttention = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) + self.fusion_mha: nn.MultiheadAttention = nn.MultiheadAttention( + embed_dim=self.hidden_dim, num_heads=4, batch_first=True + ) self.fusion_layer_norm: nn.LayerNorm = nn.LayerNorm(self.hidden_dim) - # Feature-specific projectors def projector() -> nn.Sequential: return nn.Sequential( nn.Linear(self.hidden_dim * 2, self.hidden_dim), @@ 
-67,11 +110,13 @@ def projector() -> nn.Sequential: nn.Linear(self.hidden_dim, self.hidden_dim), ) - self.F_t: nn.Sequential = projector() - self.F_d: nn.Sequential = projector() - self.F_f: nn.Sequential = projector() + self.temporal_feature_projector: nn.Sequential = projector() + self.derivative_feature_projector: nn.Sequential = projector() + self.frequency_feature_projector: nn.Sequential = projector() - self.classifier_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=1, batch_first=True) + self.classifier_mha = nn.MultiheadAttention( + embed_dim=self.hidden_dim, num_heads=1, batch_first=True + ) self.classifier = nn.Linear(self.hidden_dim * 3, self.num_classes) def augment_time(self, x: torch.Tensor, std: float = 0.1) -> torch.Tensor: @@ -79,22 +124,22 @@ def augment_time(self, x: torch.Tensor, std: float = 0.1) -> torch.Tensor: noise = torch.randn_like(x) * std return x + noise - def augment_freq(self, sample: torch.Tensor, pertub_ratio: float = 0.05) -> torch.Tensor: + def augment_freq(self, sample: torch.Tensor, perturb_ratio: float = 0.05) -> torch.Tensor: """Frequency-domain augmentation (remove and add frequencies)""" - aug_1 = self.remove_frequency(sample, pertub_ratio) - aug_2 = self.add_frequency(sample, pertub_ratio) - return aug_1 + aug_2 + removed_frequency = self.remove_frequency(sample, perturb_ratio) + added_frequency = self.add_frequency(sample, perturb_ratio) + return removed_frequency + added_frequency - def remove_frequency(self, x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor: - mask = torch.rand(x.shape, device=x.device) > pertub_ratio + def remove_frequency(self, x: torch.Tensor, perturb_ratio: float = 0.0) -> torch.Tensor: + mask = torch.rand(x.shape, device=x.device) > perturb_ratio return x * mask - def add_frequency(self, x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor: - mask = torch.rand(x.shape, device=x.device) > (1 - pertub_ratio) + def add_frequency(self, x: torch.Tensor, perturb_ratio: 
float = 0.0) -> torch.Tensor: + mask = torch.rand(x.shape, device=x.device) > (1 - perturb_ratio) max_amplitude = x.max() - random_am = torch.rand(mask.shape, device=x.device) * (max_amplitude * 0.1) - pertub_matrix = mask * random_am - return x + pertub_matrix + random_amplitude = torch.rand(mask.shape, device=x.device) * (max_amplitude * 0.1) + perturbation = mask * random_amplitude + return x + perturbation def ntxent_loss(self, zis: torch.Tensor, zjs: torch.Tensor, tau: float) -> torch.Tensor: """2N x 2N NTXentLoss aligned with the TFC implementation.""" @@ -134,9 +179,9 @@ def ntxent_loss(self, zis: torch.Tensor, zjs: torch.Tensor, tau: float) -> torch return loss / (2 * batch_size) def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - x_t = self.proj_t(x_t) - x_d = self.proj_d(x_d) - x_f = self.proj_f(x_f) + x_t = self.temporal_projection(x_t) + x_d = self.derivative_projection(x_d) + x_f = self.frequency_projection(x_f) x_t = self.pos_encoder(x_t) x_d = self.pos_encoder(x_d) @@ -147,23 +192,34 @@ def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Ten h_f = self.encoder_f(x_f) batch_size, seq_length, _ = h_t.shape - H = torch.stack([h_t, h_d, h_f], dim=2) - H_flat = H.permute(0, 2, 1, 3).contiguous().view(batch_size * 3, seq_length, self.hidden_dim) + view_sequence = torch.stack([h_t, h_d, h_f], dim=2) + flattened_views = ( + view_sequence + .permute(0, 2, 1, 3) + .contiguous() + .view(batch_size * 3, seq_length, self.hidden_dim) + ) - MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) - H_out = self.fusion_layer_norm(MHA_out + H_flat) + attention_output, _ = self.fusion_mha( + flattened_views, flattened_views, flattened_views + ) + fused_views = self.fusion_layer_norm(attention_output + flattened_views) - H_out = H_out.view(batch_size, 3, seq_length, self.hidden_dim).permute(0, 2, 1, 3) - h_t_star, h_d_star, h_f_star = H_out[:, :, 0, :], H_out[:, 
:, 1, :], H_out[:, :, 2, :] + fused_views = fused_views.view(batch_size, 3, seq_length, self.hidden_dim).permute(0, 2, 1, 3) + h_t_star, h_d_star, h_f_star = ( + fused_views[:, :, 0, :], + fused_views[:, :, 1, :], + fused_views[:, :, 2, :], + ) # Pool across sequence length and concatenate with pre-interaction features h_t_pool = torch.cat([h_t.mean(dim=1), h_t_star.mean(dim=1)], dim=-1) h_d_pool = torch.cat([h_d.mean(dim=1), h_d_star.mean(dim=1)], dim=-1) h_f_pool = torch.cat([h_f.mean(dim=1), h_f_star.mean(dim=1)], dim=-1) - z_t = self.F_t(h_t_pool) - z_d = self.F_d(h_d_pool) - z_f = self.F_f(h_f_pool) + z_t = self.temporal_feature_projector(h_t_pool) + z_d = self.derivative_feature_projector(h_d_pool) + z_f = self.frequency_feature_projector(h_f_pool) return z_t, z_d, z_f @@ -172,18 +228,14 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: derivative_tensor = self._prepare_tensor(kwargs.get("xd")) # [N, L, 1] frequency_tensor = self._prepare_tensor(kwargs.get("xf")) # [N, L, 1] - # --- Stage Routing --- if self.training_stage == "pretrain": - # 1. Apply domain-specific augmentations x_t_aug = self.augment_time(temporal_tensor) x_d_aug = self.augment_time(derivative_tensor) x_f_aug = self.augment_freq(frequency_tensor) - # 2. Encode z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor) z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug) - # 3. 
Apply 2N x 2N NTXentLoss loss = self.ntxent_loss(z_t, z_t_aug, self.tau) + \ self.ntxent_loss(z_d, z_d_aug, self.tau) + \ self.ntxent_loss(z_f, z_f_aug, self.tau) @@ -201,8 +253,8 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: # Cross-view attention for classification stacked_emb = torch.stack([z_t, z_d, z_f], dim=1) # [batch_size, 3, hidden_dim] - attn_out, _ = self.classifier_mha(stacked_emb, stacked_emb, stacked_emb) - emb = attn_out + stacked_emb # Residual connection + attention_output, _ = self.classifier_mha(stacked_emb, stacked_emb, stacked_emb) + emb = attention_output + stacked_emb # Residual connection z_combined = emb.reshape(emb.size(0), -1) # Flatten to [batch_size, 3 * hidden_dim] logits = self.classifier(z_combined) diff --git a/pyhealth/models/mvcl_model.py b/pyhealth/models/mvcl_model.py index cf7a281a0..1167155cf 100644 --- a/pyhealth/models/mvcl_model.py +++ b/pyhealth/models/mvcl_model.py @@ -1,8 +1,14 @@ +"""Generic multi-view contrastive learning model for PyHealth datasets.""" + +from typing import Callable, Dict, Optional, cast + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F + from pyhealth.models import BaseModel -from typing import Callable, Dict, Optional, cast + class MultiViewContrastiveModel(BaseModel): """A generic, plug-and-play Multi-View Contrastive Learning (MVCL) model. @@ -86,12 +92,13 @@ def __init__( view: nn.Identity() for view in self.view_names }) - # Generic Fusion: Applies attention across the V different views - self.fusion_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=4, batch_first=True) + # Cross-view fusion applies attention across the provided views. 
+ self.fusion_mha = nn.MultiheadAttention( + embed_dim=self.hidden_dim, num_heads=4, batch_first=True + ) self.fusion_layer_norm = nn.LayerNorm(self.hidden_dim) - # Dynamic feature-specific projectors (F_k) - self.F_projectors = nn.ModuleDict({ + self.feature_projectors = nn.ModuleDict({ view: nn.Sequential( nn.Linear(self.hidden_dim * 2, self.hidden_dim), nn.BatchNorm1d(self.hidden_dim), @@ -102,7 +109,9 @@ def __init__( }) # Classifier dynamically sizes itself based on the number of views - self.classifier_mha = nn.MultiheadAttention(embed_dim=self.hidden_dim, num_heads=1, batch_first=True) + self.classifier_mha = nn.MultiheadAttention( + embed_dim=self.hidden_dim, num_heads=1, batch_first=True + ) self.classifier = nn.Sequential( nn.Linear(self.hidden_dim * len(self.view_names), 1024), nn.ReLU(), @@ -136,7 +145,6 @@ def ntxent_loss(self, zis: torch.Tensor, zjs: torch.Tensor, tau: float) -> torch def _forward_features(self, views_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: encoded_views = {} - # 1. Project and Encode (keep sequence length for MHA) for view in self.view_names: x = views_data[view] x = self.projectors[view](x) @@ -145,84 +153,89 @@ def _forward_features(self, views_data: Dict[str, torch.Tensor]) -> Dict[str, to h = self.encoders[view](x) encoded_views[view] = h # Shape: [N, L, hidden_dim] - # 2. Cross-View Fusion batch_size = encoded_views[self.view_names[0]].shape[0] seq_length = encoded_views[self.view_names[0]].shape[1] num_views = len(self.view_names) - # Stack into [N, num_views, L, hidden_dim] - H = torch.stack([encoded_views[v] for v in self.view_names], dim=1) + # Stack into [N, num_views, L, hidden_dim]. 
+ stacked_views = torch.stack([encoded_views[v] for v in self.view_names], dim=1) - H_permuted = H.permute(0, 2, 1, 3).contiguous() + sequence_first_views = stacked_views.permute(0, 2, 1, 3).contiguous() - # Flatten Batch and Sequence length together: [N * L, num_views, hidden_dim] - H_flat = H_permuted.view(batch_size * seq_length, num_views, self.hidden_dim) + # Flatten batch and sequence length: [N * L, num_views, hidden_dim]. + flattened_views = sequence_first_views.view( + batch_size * seq_length, num_views, self.hidden_dim + ) - MHA_out, _ = self.fusion_mha(H_flat, H_flat, H_flat) - H_out = self.fusion_layer_norm(MHA_out + H_flat) + attention_output, _ = self.fusion_mha( + flattened_views, flattened_views, flattened_views + ) + fused_views = self.fusion_layer_norm(attention_output + flattened_views) - # Reshape back to [N, L, num_views, hidden_dim] - H_out = H_out.view(batch_size, seq_length, num_views, self.hidden_dim) + # Reshape back to [N, L, num_views, hidden_dim]. + fused_views = fused_views.view(batch_size, seq_length, num_views, self.hidden_dim) - # 3. Concatenate and Project - final_zs = {} + view_embeddings = {} for i, view in enumerate(self.view_names): - h_pre = encoded_views[view].mean(dim=1) # Pre-interaction + pre_fusion_embedding = encoded_views[view].mean(dim=1) - # Extract from dimension 2 (num_views dimension) - h_post = H_out[:, :, i, :].mean(dim=1) # Post-interaction + post_fusion_embedding = fused_views[:, :, i, :].mean(dim=1) - # Concatenation of pre and post features! 
- h_pool = torch.cat([h_pre, h_post], dim=-1) # Shape: [N, hidden_dim * 2] + pooled_embedding = torch.cat( + [pre_fusion_embedding, post_fusion_embedding], dim=-1 + ) - final_zs[view] = self.F_projectors[view](h_pool) + view_embeddings[view] = self.feature_projectors[view](pooled_embedding) - return final_zs + return view_embeddings def forward(self, **kwargs) -> dict[str, torch.Tensor]: views_data = {view: self._prepare_tensor(kwargs.get(view)) for view in self.view_names} if self.training_stage == "pretrain": - # Augment augmented_views = { view: self.augmentations[view](views_data[view]) for view in self.view_names } - # Encode - zs = self._forward_features(views_data) - zs_aug = self._forward_features(augmented_views) + view_embeddings = self._forward_features(views_data) + augmented_embeddings = self._forward_features(augmented_views) - # Dynamic Loss Calculation loss = torch.tensor(0.0, device=self.device) for view in self.view_names: - loss += self.ntxent_loss(zs[view], zs_aug[view], self.tau) + loss += self.ntxent_loss( + view_embeddings[view], augmented_embeddings[view], self.tau + ) result = {"loss": loss} - result.update({f"z_{v}": zs[v] for v in self.view_names}) # Add embeddings + result.update({f"z_{v}": view_embeddings[v] for v in self.view_names}) return result elif self.training_stage == "finetune": - zs = self._forward_features(views_data) + view_embeddings = self._forward_features(views_data) - # Stack and fuse for classification - stacked_emb = torch.stack([zs[v] for v in self.view_names], dim=1) - attn_out, _ = self.classifier_mha(stacked_emb, stacked_emb, stacked_emb) - emb = attn_out + stacked_emb + stacked_embeddings = torch.stack( + [view_embeddings[v] for v in self.view_names], dim=1 + ) + attention_output, _ = self.classifier_mha( + stacked_embeddings, stacked_embeddings, stacked_embeddings + ) + fused_embeddings = attention_output + stacked_embeddings - z_combined = emb.reshape(emb.size(0), -1) - logits = self.classifier(z_combined) + 
combined_embedding = fused_embeddings.reshape(fused_embeddings.size(0), -1) + logits = self.classifier(combined_embedding) label_key = self.label_keys[0] y_true = cast(torch.Tensor, kwargs[label_key]).to(logits.device) loss_ce = self.get_loss_function()(logits, y_true) - # Contrastive Penalty augmented_views = {v: self.augmentations[v](views_data[v]) for v in self.view_names} - zs_aug = self._forward_features(augmented_views) + augmented_embeddings = self._forward_features(augmented_views) loss_cl = torch.tensor(0.0, device=self.device) for view in self.view_names: - loss_cl += self.ntxent_loss(zs[view], zs_aug[view], self.tau) + loss_cl += self.ntxent_loss( + view_embeddings[view], augmented_embeddings[view], self.tau + ) total_loss = (self.lambda_cl * loss_cl) + loss_ce @@ -236,6 +249,5 @@ def forward(self, **kwargs) -> dict[str, torch.Tensor]: def _prepare_tensor(self, x) -> torch.Tensor: if isinstance(x, list): - import numpy as np x = torch.stack(x) if isinstance(x[0], torch.Tensor) else torch.from_numpy(np.stack(x)) - return x.float().to(self.device) \ No newline at end of file + return x.float().to(self.device) diff --git a/pyhealth/tasks/mvcl_training_sleepedf_task.py b/pyhealth/tasks/mvcl_training_sleepedf_task.py index b54b0a0aa..cb3443662 100644 --- a/pyhealth/tasks/mvcl_training_sleepedf_task.py +++ b/pyhealth/tasks/mvcl_training_sleepedf_task.py @@ -131,7 +131,12 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: - 'label': Mapped 5-class sleep stage label. """ pid = patient.patient_id - events = patient.get_events() + events = patient.get_events(event_type="sleepedf") + if not events: + for event_type in ("cassette", "telemetry"): + events = patient.get_events(event_type=event_type) + if events: + break samples: List[Dict[str, Any]] = [] event_id = {