From b40cb7fa3c86f73b0ab0b4ce74457334f34a78d7 Mon Sep 17 00:00:00 2001 From: Bella Date: Sat, 18 Apr 2026 17:42:18 +0800 Subject: [PATCH 01/15] feat: prepare data for model training --- pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/configs/mimic3_cf.yaml | 35 +++ pyhealth/datasets/mimic3_cf.py | 279 ++++++++++++++++++ pyhealth/tasks/__init__.py | 1 + .../tasks/circulatory_failure_prediction.py | 62 ++++ 5 files changed, 378 insertions(+) create mode 100644 pyhealth/datasets/configs/mimic3_cf.yaml create mode 100644 pyhealth/datasets/mimic3_cf.py create mode 100644 pyhealth/tasks/circulatory_failure_prediction.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..ba2e77f64 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .mimic3_cf import MIMIC3CirculatoryFailureDataset \ No newline at end of file diff --git a/pyhealth/datasets/configs/mimic3_cf.yaml b/pyhealth/datasets/configs/mimic3_cf.yaml new file mode 100644 index 000000000..6ef098c92 --- /dev/null +++ b/pyhealth/datasets/configs/mimic3_cf.yaml @@ -0,0 +1,35 @@ +version: "1.4" +tables: + patients: + file_path: "PATIENTS.csv.gz" + patient_id: "subject_id" + timestamp: null + attributes: + - "gender" + - "dob" + - "dod" + - "expire_flag" + + admissions: + file_path: "ADMISSIONS.csv.gz" + patient_id: "subject_id" + timestamp: "admittime" + attributes: + - "hadm_id" + - "admittime" + - "dischtime" + - "deathtime" + - "hospital_expire_flag" + - "ethnicity" + + icustays: + file_path: "ICUSTAYS.csv.gz" + patient_id: "subject_id" + timestamp: "intime" + attributes: + - "hadm_id" + - "icustay_id" + - "intime" + - "outtime" + - "first_careunit" + - "last_careunit" \ No newline at end of file diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py new file mode 100644 index 000000000..eb5901b5b --- /dev/null +++ b/pyhealth/datasets/mimic3_cf.py @@ -0,0 +1,279 @@ +import logging +from pathlib import Path +from typing import List, Optional +import pandas as pd +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class MIMIC3CirculatoryFailureDataset(BaseDataset): + """MIMIC-III dataset for circulatory failure early-warning prediction. + + This dataset is designed for a FAMEWS-inspired reproduction setting on + MIMIC-III. It will support cohort construction, event parsing, and + time-series feature extraction for circulatory failure prediction within + a future prediction window. + + Args: + root: Root directory of the MIMIC-III dataset. + tables: Additional tables to load beyond the default cohort tables. + dataset_name: Name of the dataset instance. + config_path: Path to the dataset config YAML file. + **kwargs: Additional keyword arguments passed to BaseDataset. + """ + + def __init__( + self, + root: str, + tables: Optional[List[str]] = None, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs, + ) -> None: + """Initializes the MIMIC-III circulatory failure dataset.""" + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "mimic3_cf.yaml" + + default_tables = ["patients", "admissions", "icustays"] + tables = default_tables + (tables or []) + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "mimic3_cf", + config_path=str(config_path), + **kwargs, + ) + + def load_cohort(self): + """Load patients + admissions + icustays.""" + + import pandas as pd + from pathlib import Path + + root = Path(self.root) + + patients_df = pd.read_csv(root / "PATIENTS.csv.gz") + admissions_df = pd.read_csv(root / "ADMISSIONS.csv.gz") + icustays_df = pd.read_csv(root / "ICUSTAYS.csv.gz") + + df = patients_df.merge(admissions_df, on="SUBJECT_ID") + df = df.merge(icustays_df, on=["SUBJECT_ID", "HADM_ID"]) + + patients = [] + + for _, row in df.iterrows(): + patients.append( + { + "patient_id": row["SUBJECT_ID"], + "gender": row["GENDER"], + "hadm_id": row["HADM_ID"], + "icustay_id": row["ICUSTAY_ID"], + "admittime": row["ADMITTIME"], + "intime": row["INTIME"], + "outtime": row["OUTTIME"], + } + ) + + return patients + + def load_patients(self): + """Backward-compatible wrapper for current development.""" + return self.load_cohort() + + def build_failure_labels(self): + """Build first failure time per ICU stay (MAP < 65) using chunked reads.""" + + import pandas as pd + from pathlib import Path + + root = Path(self.root) + + # load cohort once + cohort = pd.DataFrame(self.load_cohort()) + cohort["intime"] = pd.to_datetime(cohort["intime"]) + cohort["outtime"] = pd.to_datetime(cohort["outtime"]) + + results = [] + + chunks = pd.read_csv( + root / "CHARTEVENTS.csv.gz", + usecols=[ + "SUBJECT_ID", + "HADM_ID", + "ICUSTAY_ID", + "ITEMID", + "CHARTTIME", + "VALUENUM", + ], + chunksize=50000, + ) + + for chunk in chunks: + # filter MAP only + chunk = chunk[chunk["ITEMID"] == 220052].copy() + if chunk.empty: + continue + + chunk["CHARTTIME"] = pd.to_datetime( + chunk["CHARTTIME"], + format="%Y-%m-%d %H:%M:%S", + errors="coerce", + ) + + merged = chunk.merge( + cohort, + left_on="ICUSTAY_ID", + right_on="icustay_id", + ) + + if merged.empty: + continue + + filtered = merged[ + (merged["CHARTTIME"] >= merged["intime"]) + & (merged["CHARTTIME"] <= merged["outtime"]) + ].copy() + + if filtered.empty: + continue + + filtered["failure_label"] = (filtered["VALUENUM"] < 65).astype(int) + + failure_events = filtered[filtered["failure_label"] == 1] + if failure_events.empty: + continue + + first_failure_chunk = ( + failure_events.groupby("ICUSTAY_ID")["CHARTTIME"] + .min() + .reset_index() + .rename(columns={"CHARTTIME": "first_failure_time"}) + ) + + results.append(first_failure_chunk) + + if not results: + return pd.DataFrame(columns=["ICUSTAY_ID", "first_failure_time"]) + + first_failure = pd.concat(results, ignore_index=True) + + # keep earliest failure time per ICU stay across all chunks + first_failure = ( + first_failure.groupby("ICUSTAY_ID")["first_failure_time"] + .min() + .reset_index() + ) + + return first_failure + + # filter MAP + map_df = chartevents[chartevents["ITEMID"] == 220052] + + # convert time + map_df["CHARTTIME"] = pd.to_datetime(map_df["CHARTTIME"]) + + # load cohort + cohort = pd.DataFrame(self.load_cohort()) + cohort["intime"] = pd.to_datetime(cohort["intime"]) + cohort["outtime"] = pd.to_datetime(cohort["outtime"]) + + # merge + merged = map_df.merge( + cohort, + left_on="ICUSTAY_ID", + right_on="icustay_id", + ) + + # filter ICU period + filtered = merged[ + (merged["CHARTTIME"] >= merged["intime"]) + & (merged["CHARTTIME"] <= merged["outtime"]) + ].copy() + + # label + filtered["failure_label"] = (filtered["VALUENUM"] < 65).astype(int) + + # first failure time + failure_events = filtered[filtered["failure_label"] == 1] + + first_failure = ( + failure_events.groupby("ICUSTAY_ID")["CHARTTIME"] + .min() + .reset_index() + .rename(columns={"CHARTTIME": "first_failure_time"}) + ) + + return first_failure + + def get_patient_by_icustay_id(self, icustay_id: int): + """Build one task-ready patient dict for a given ICU stay.""" + + import pandas as pd + from pathlib import Path + + root = Path(self.root) + + # 1) load cohort + cohort_df = pd.DataFrame(self.load_cohort()) + cohort_df["intime"] = pd.to_datetime(cohort_df["intime"]) + cohort_df["outtime"] = pd.to_datetime(cohort_df["outtime"]) + + row = cohort_df[cohort_df["icustay_id"] == icustay_id] + if row.empty: + return None + row = row.iloc[0] + + # 2) load failure labels + first_failure = self.build_failure_labels() + failure_row = first_failure[first_failure["ICUSTAY_ID"] == icustay_id] + + first_failure_time = None + if not failure_row.empty: + first_failure_time = failure_row.iloc[0]["first_failure_time"] + + # 3) load MAP time series for this ICU stay + chartevents = pd.read_csv( + root / "CHARTEVENTS.csv.gz", + usecols=["ICUSTAY_ID", "ITEMID", "CHARTTIME", "VALUENUM"], + ) + + ts = chartevents[chartevents["ITEMID"] == 220052].copy() + ts["CHARTTIME"] = pd.to_datetime( + ts["CHARTTIME"], + format="%Y-%m-%d %H:%M:%S", + errors="coerce", + ) + ts = ts[ts["ICUSTAY_ID"] == icustay_id].copy() + + ts = ts[ + (ts["CHARTTIME"] >= row["intime"]) & + (ts["CHARTTIME"] <= row["outtime"]) + ].copy() + + ts = ts.sort_values("CHARTTIME") + + time_series = [] + for _, ts_row in ts.iterrows(): + if pd.isna(ts_row["VALUENUM"]): + continue + time_series.append( + { + "charttime": ts_row["CHARTTIME"], + "map": float(ts_row["VALUENUM"]), + } + ) + + patient = { + "patient_id": int(row["patient_id"]), + "icustay_id": int(row["icustay_id"]), + "gender": row["gender"], + "intime": row["intime"], + "outtime": row["outtime"], + "time_series": time_series, + "first_failure_time": first_failure_time, + } + + return patient \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..61cbf018c 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .circulatory_failure_prediction import CirculatoryFailurePredictionTask \ No newline at end of file diff --git a/pyhealth/tasks/circulatory_failure_prediction.py b/pyhealth/tasks/circulatory_failure_prediction.py new file mode 100644 index 000000000..389a9d012 --- /dev/null +++ b/pyhealth/tasks/circulatory_failure_prediction.py @@ -0,0 +1,62 @@ +from datetime import timedelta +from typing import Dict, List, Optional + + +class CirculatoryFailurePredictionTask: + """Early-warning task for circulatory failure prediction.""" + + def __init__( + self, + prediction_window_hours: int = 12, + ) -> None: + self.prediction_window_hours = prediction_window_hours + + def _to_timestamp(self, value): + """Converts a value to pandas.Timestamp lazily.""" + import pandas as pd + + if value is None: + return None + if pd.isna(value): + return None + return pd.to_datetime(value) + + def __call__(self, patient: Dict) -> Optional[List[Dict]]: + """Converts one patient record into training samples.""" + time_series = patient.get("time_series", None) + if not time_series: + return None + + first_failure_time = self._to_timestamp( + patient.get("first_failure_time", None) + ) + prediction_window = timedelta(hours=self.prediction_window_hours) + + samples = [] + + for point in time_series: + charttime = self._to_timestamp(point["charttime"]) + map_value = point.get("map", None) + + if charttime is None or map_value is None: + continue + + label = 0 + if first_failure_time is not None: + label = int( + charttime < first_failure_time <= charttime + prediction_window + ) + + sample = { + "patient_id": patient.get("patient_id"), + "icustay_id": patient.get("icustay_id"), + "gender": patient.get("gender"), + "timestamp": charttime, + "features": { + "map": float(map_value), + }, + "label": label, + } + samples.append(sample) + + return samples if samples else None \ No newline at end of file From 93640d254c8aa9532e1353cb667b205ccc1ad61b Mon Sep 17 00:00:00 2001 From: Bella Date: Sat, 18 Apr 2026 22:06:51 +0800 Subject: [PATCH 02/15] feat: load data with cache --- pyhealth/datasets/mimic3_cf.py | 61 ++++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py index eb5901b5b..e9ce0864d 100644 --- a/pyhealth/datasets/mimic3_cf.py +++ b/pyhealth/datasets/mimic3_cf.py @@ -235,19 +235,8 @@ def get_patient_by_icustay_id(self, icustay_id: int): first_failure_time = failure_row.iloc[0]["first_failure_time"] # 3) load MAP time series for this ICU stay - chartevents = pd.read_csv( - root / "CHARTEVENTS.csv.gz", - usecols=["ICUSTAY_ID", "ITEMID", "CHARTTIME", "VALUENUM"], - ) - - ts = chartevents[chartevents["ITEMID"] == 220052].copy() - ts["CHARTTIME"] = pd.to_datetime( - ts["CHARTTIME"], - format="%Y-%m-%d %H:%M:%S", - errors="coerce", - ) - ts = ts[ts["ICUSTAY_ID"] == icustay_id].copy() - + map_df = self.load_map_cache() + ts = map_df[map_df["ICUSTAY_ID"] == icustay_id].copy() ts = ts[ (ts["CHARTTIME"] >= row["intime"]) & (ts["CHARTTIME"] <= row["outtime"]) @@ -276,4 +265,48 @@ def get_patient_by_icustay_id(self, icustay_id: int): "first_failure_time": first_failure_time, } - return patient \ No newline at end of file + return patient + + def load_map_cache(self): + """Load MAP (mean arterial pressure) data once and cache it.""" + + import pandas as pd + from pathlib import Path + + if hasattr(self, "_map_cache"): + return self._map_cache + + root = Path(self.root) + + print("Loading MAP cache (this will take a bit, only once)...") + + chunks = pd.read_csv( + root / "CHARTEVENTS.csv.gz", + usecols=["ICUSTAY_ID", "ITEMID", "CHARTTIME", "VALUENUM"], + chunksize=100000, + ) + + parts = [] + + for chunk in chunks: + chunk = chunk[chunk["ITEMID"] == 220052].copy() + if chunk.empty: + continue + + chunk["CHARTTIME"] = pd.to_datetime( + chunk["CHARTTIME"], + format="%Y-%m-%d %H:%M:%S", + errors="coerce", + ) + + parts.append(chunk) + + if parts: + df = pd.concat(parts, ignore_index=True) + else: + df = pd.DataFrame(columns=["ICUSTAY_ID", "CHARTTIME", "VALUENUM"]) + + self._map_cache = df + print("MAP cache loaded:", len(df)) + + return df \ No newline at end of file From c14fb13851b1c3bda4d37bc03d9bc0a0a06152d4 Mon Sep 17 00:00:00 2001 From: Bella Date: Sat, 18 Apr 2026 22:08:01 +0800 Subject: [PATCH 03/15] feat: show how prepared dataset can be used --- pyhealth/test_pipeline.py | 44 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 pyhealth/test_pipeline.py diff --git a/pyhealth/test_pipeline.py b/pyhealth/test_pipeline.py new file mode 100644 index 000000000..7187d83cb --- /dev/null +++ b/pyhealth/test_pipeline.py @@ -0,0 +1,44 @@ +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +def main(): + # 1. 初始化資料集(請確保路徑正確) + dataset = MIMIC3CirculatoryFailureDataset( + root="/mimic_test" + ) + + # 2. 初始化任務(12小時預警) + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + + # 3. 讀取 cohort,找到第一個真的能產生 samples 的 ICU stay + cohort = dataset.load_cohort() + + samples = None + chosen_icustay_id = None + + for row in cohort: + icustay_id = row["icustay_id"] + patient = dataset.get_patient_by_icustay_id(icustay_id) + samples = task(patient) + + if samples: + chosen_icustay_id = icustay_id + break + + # 4. 檢查結果 + if samples: + print(f"--- 成功測試 ICU Stay ID: {chosen_icustay_id} ---") + print(f"成功產生樣本數: {len(samples)}") + print( + f"其中 Label=1 (未來12小時內會衰竭) 的數量: " + f"{sum(s['label'] for s in samples)}" + ) + print(f"第一筆樣本特徵: {samples[0]['features']}") + print(f"第一筆完整樣本: {samples[0]}") + else: + print("未找到任何可產生樣本的 ICU stay,請檢查資料與路徑。") + + +if __name__ == "__main__": + main() \ No newline at end of file From 664a1ecf57e72f24fcf0c222a834158bdd90cb6c Mon Sep 17 00:00:00 2001 From: Bella Date: Sat, 18 Apr 2026 22:33:54 +0800 Subject: [PATCH 04/15] feat: add baseline model for future models comparison --- pyhealth/train_baseline.py | 89 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 pyhealth/train_baseline.py diff --git a/pyhealth/train_baseline.py b/pyhealth/train_baseline.py new file mode 100644 index 000000000..4085d9e29 --- /dev/null +++ b/pyhealth/train_baseline.py @@ -0,0 +1,89 @@ +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import roc_auc_score, accuracy_score + + +def build_dataset(dataset, task, max_icu=50): + """Build samples from multiple ICU stays""" + cohort = dataset.load_cohort()[:max_icu] + + all_samples = [] + + for row in cohort: + icustay_id = row["icustay_id"] + patient = dataset.get_patient_by_icustay_id(icustay_id) + + if patient is None: + continue + + samples = task(patient) + if samples: + all_samples.extend(samples) + + return all_samples + + +def samples_to_df(samples): + """Convert samples to DataFrame""" + df = pd.DataFrame( + [ + { + "patient_id": s["patient_id"], + "icustay_id": s["icustay_id"], + "gender": s["gender"], + "timestamp": s["timestamp"], + "map": s["features"]["map"], + "label": s["label"], + } + for s in samples + ] + ) + return df + + +def train_model(df): + """Train a simple baseline model""" + + X = df[["map"]] + y = df["label"] + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + + model = LogisticRegression() + model.fit(X_train, y_train) + + preds = model.predict(X_test) + probs = model.predict_proba(X_test)[:, 1] + + print("Accuracy:", accuracy_score(y_test, preds)) + print("ROC-AUC:", roc_auc_score(y_test, probs)) + + return model + + +def main(): + dataset = MIMIC3CirculatoryFailureDataset( + root="/Users/bella/Desktop/UIUC MCS/CS598/mimic_test" + ) + task = CirculatoryFailurePredictionTask() + + print("Building dataset...") + samples = build_dataset(dataset, task, max_icu=20) + + print(f"Total samples: {len(samples)}") + + df = samples_to_df(samples) + print(df.head()) + + print("\nTraining model...") + model = train_model(df) + + +if __name__ == "__main__": + main() \ No newline at end of file From cc406e7a34a60fb283adb52d8e2b6f7c39784eb1 Mon Sep 17 00:00:00 2001 From: Bella Date: Tue, 21 Apr 2026 00:20:06 +0800 Subject: [PATCH 05/15] feat:add MIMIC-III circulatory failure dataset and prediction task --- docs/api/datasets.rst | 1 + .../datasets/pyhealth.datasets.mimic3_cf.rst | 26 +++++ docs/api/tasks.rst | 1 + ...h.tasks.circulatory_failure_prediction.rst | 24 +++++ examples/mimic3_cf_example.py | 23 +++++ pyhealth/datasets/mimic3_cf.py | 24 ++++- .../tasks/circulatory_failure_prediction.py | 97 +++++++++---------- tests/test_circulatory_failure_prediction.py | 28 ++++++ tests/test_mimic3_cf.py | 57 +++++++++++ 9 files changed, 229 insertions(+), 52 deletions(-) create mode 100644 docs/api/datasets/pyhealth.datasets.mimic3_cf.rst create mode 100644 docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst create mode 100644 examples/mimic3_cf_example.py create mode 100644 tests/test_circulatory_failure_prediction.py create mode 100644 tests/test_mimic3_cf.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..6e3517c7a 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,3 +245,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.mimic3_cf \ No newline at end of file diff --git a/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst b/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst new file mode 100644 index 000000000..b0d47472c --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst @@ -0,0 +1,26 @@ +pyhealth.datasets.mimic3_cf +=========================== + +Overview +-------- + +MIMIC3CirculatoryFailureDataset is a MIMIC-III based dataset for early warning +prediction of circulatory failure. + +It constructs an ICU-stay-level cohort from PATIENTS, ADMISSIONS, and ICUSTAYS, +and uses CHARTEVENTS to extract Mean Arterial Pressure (MAP) measurements. + +Circulatory failure is defined using a proxy event: + +- MAP < 65 mmHg + +For each ICU stay, the dataset identifies the first occurrence of this event and +supports building task-ready patient records for downstream prediction tasks. + +API Reference +------------- + +.. autoclass:: pyhealth.datasets.MIMIC3CirculatoryFailureDataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..f39838f6a 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Circulatory Failure Prediction \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst b/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst new file mode 100644 index 000000000..6a9cf0e2c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst @@ -0,0 +1,24 @@ +pyhealth.tasks.circulatory_failure_prediction +============================================= + +Overview +-------- + +CirculatoryFailurePredictionTask defines a time-series prediction task for early +detection of circulatory failure. + +The task predicts whether a patient will experience circulatory failure within +the next 12 hours based on physiological measurements. + +Label definition: + +- label = 1 if circulatory failure occurs within the next 12 hours +- label = 0 otherwise + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.CirculatoryFailurePredictionTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/mimic3_cf_example.py b/examples/mimic3_cf_example.py new file mode 100644 index 000000000..baa4c8783 --- /dev/null +++ b/examples/mimic3_cf_example.py @@ -0,0 +1,23 @@ +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +def main(): + dataset = MIMIC3CirculatoryFailureDataset( + root="/path/to/mimic3" + ) + + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + + # apply task + samples = dataset.set_task(task, max_patients=5) + + print(f"Total samples: {len(samples)}") + + if samples: + print("Sample example:") + print(samples[0]) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py index e9ce0864d..c852745ea 100644 --- a/pyhealth/datasets/mimic3_cf.py +++ b/pyhealth/datasets/mimic3_cf.py @@ -309,4 +309,26 @@ def load_map_cache(self): self._map_cache = df print("MAP cache loaded:", len(df)) - return df \ No newline at end of file + return df + + def set_task(self, task, max_patients: int | None = None): + """Apply a task function to the cohort and return task samples.""" + + samples = [] + cohort = self.load_cohort() + + if max_patients is not None: + cohort = cohort[:max_patients] + + for row in cohort: + icustay_id = row["icustay_id"] + patient = self.get_patient_by_icustay_id(icustay_id) + + if patient is None: + continue + + task_samples = task(patient) + if task_samples: + samples.extend(task_samples) + + return samples \ No newline at end of file diff --git a/pyhealth/tasks/circulatory_failure_prediction.py b/pyhealth/tasks/circulatory_failure_prediction.py index 389a9d012..d3e3083de 100644 --- a/pyhealth/tasks/circulatory_failure_prediction.py +++ b/pyhealth/tasks/circulatory_failure_prediction.py @@ -1,62 +1,57 @@ -from datetime import timedelta -from typing import Dict, List, Optional +from pyhealth.tasks.base_task import BaseTask +from typing import List, Dict -class CirculatoryFailurePredictionTask: - """Early-warning task for circulatory failure prediction.""" +class CirculatoryFailurePredictionTask(BaseTask): - def __init__( - self, - prediction_window_hours: int = 12, - ) -> None: + task_name = "circulatory_failure_prediction" + + input_schema = { + "map": float, + "timestamp": "datetime", + "gender": str, + } + + output_schema = { + "label": int, + } + + def __init__(self, prediction_window_hours: int = 12): + super().__init__() self.prediction_window_hours = prediction_window_hours - def _to_timestamp(self, value): - """Converts a value to pandas.Timestamp lazily.""" + def __call__(self, patient) -> List[Dict]: + if not patient["time_series"]: + return [] + import pandas as pd + from datetime import timedelta + + first_failure_time = patient["first_failure_time"] + if first_failure_time is None: + return [] + + first_failure_time = pd.to_datetime(first_failure_time) - if value is None: - return None - if pd.isna(value): - return None - return pd.to_datetime(value) - - def __call__(self, patient: Dict) -> Optional[List[Dict]]: - """Converts one patient record into training samples.""" - time_series = patient.get("time_series", None) - if not time_series: - return None - - first_failure_time = self._to_timestamp( - patient.get("first_failure_time", None) - ) prediction_window = timedelta(hours=self.prediction_window_hours) samples = [] - for point in time_series: - charttime = self._to_timestamp(point["charttime"]) - map_value = point.get("map", None) - - if charttime is None or map_value is None: - continue - - label = 0 - if first_failure_time is not None: - label = int( - charttime < first_failure_time <= charttime + prediction_window - ) - - sample = { - "patient_id": patient.get("patient_id"), - "icustay_id": patient.get("icustay_id"), - "gender": patient.get("gender"), - "timestamp": charttime, - "features": { - "map": float(map_value), - }, - "label": label, - } - samples.append(sample) - - return samples if samples else None \ No newline at end of file + for row in patient["time_series"]: + t = pd.to_datetime(row["charttime"]) + map_value = row["map"] + + label = int(t < first_failure_time <= t + prediction_window) + + samples.append( + { + "patient_id": patient["patient_id"], + "icustay_id": patient["icustay_id"], + "gender": patient["gender"], + "timestamp": t, + "features": {"map": map_value}, + "label": label, + } + ) + + return samples \ No newline at end of file diff --git a/tests/test_circulatory_failure_prediction.py b/tests/test_circulatory_failure_prediction.py new file mode 100644 index 000000000..e0daba760 --- /dev/null +++ b/tests/test_circulatory_failure_prediction.py @@ -0,0 +1,28 @@ +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +def test_circulatory_failure_task_basic(): + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + + patient = { + "patient_id": 1, + "icustay_id": 1001, + "gender": "F", + "first_failure_time": "2150-01-01 11:00:00", + "time_series": [ + {"charttime": "2150-01-01 00:00:00", "map": 80.0}, + {"charttime": "2150-01-01 01:00:00", "map": 78.0}, + {"charttime": "2150-01-01 10:00:00", "map": 70.0}, + {"charttime": "2150-01-01 11:00:00", "map": 60.0}, + ], + } + + samples = task(patient) + + assert len(samples) == 4 + assert samples[0]["label"] == 1 + assert samples[1]["label"] == 1 + assert samples[2]["label"] == 1 + assert samples[3]["label"] == 0 + assert samples[0]["features"]["map"] == 80.0 + assert samples[0]["gender"] == "F" \ No newline at end of file diff --git a/tests/test_mimic3_cf.py b/tests/test_mimic3_cf.py new file mode 100644 index 000000000..c8fa7b4a1 --- /dev/null +++ b/tests/test_mimic3_cf.py @@ -0,0 +1,57 @@ +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +class DummyMIMIC3CFDataset(MIMIC3CirculatoryFailureDataset): + """Small synthetic subclass for fast unit testing.""" + + def __init__(self): + # do not call super().__init__ because we don't want real files + pass + + def load_cohort(self): + return [ + { + "patient_id": 1, + "gender": "F", + "hadm_id": 100, + "icustay_id": 1001, + "admittime": "2150-01-01 00:00:00", + "intime": "2150-01-01 00:00:00", + "outtime": "2150-01-02 00:00:00", + } + ] + + def get_patient_by_icustay_id(self, icustay_id: int): + if icustay_id != 1001: + return None + + return { + "patient_id": 1, + "icustay_id": 1001, + "gender": "F", + "intime": "2150-01-01 00:00:00", + "outtime": "2150-01-02 00:00:00", + "first_failure_time": "2150-01-01 11:00:00", + "time_series": [ + {"charttime": "2150-01-01 00:00:00", "map": 80.0}, + {"charttime": "2150-01-01 01:00:00", "map": 78.0}, + {"charttime": "2150-01-01 10:00:00", "map": 70.0}, + {"charttime": "2150-01-01 11:00:00", "map": 60.0}, + ], + } + + +def test_set_task_returns_samples(): + dataset = DummyMIMIC3CFDataset() + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + + samples = dataset.set_task(task) + + assert isinstance(samples, list) + assert len(samples) == 4 + assert samples[0]["patient_id"] == 1 + assert samples[0]["icustay_id"] == 1001 + assert samples[0]["features"]["map"] == 80.0 + assert samples[0]["label"] == 1 + assert samples[-1]["label"] == 0 \ No newline at end of file From c8e5f3652926fe71c01361412ff263c7b49d6731 Mon Sep 17 00:00:00 2001 From: Bella Date: Tue, 21 Apr 2026 00:39:16 +0800 Subject: [PATCH 06/15] feat:modify baseline model to fit the function-call style of PyHealth --- pyhealth/train_baseline.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/pyhealth/train_baseline.py b/pyhealth/train_baseline.py index 4085d9e29..10ac3b939 100644 --- a/pyhealth/train_baseline.py +++ b/pyhealth/train_baseline.py @@ -1,32 +1,11 @@ from pyhealth.datasets import MIMIC3CirculatoryFailureDataset from pyhealth.tasks import CirculatoryFailurePredictionTask - import pandas as pd from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.metrics import roc_auc_score, accuracy_score -def build_dataset(dataset, task, max_icu=50): - """Build samples from multiple ICU stays""" - cohort = dataset.load_cohort()[:max_icu] - - all_samples = [] - - for row in cohort: - icustay_id = row["icustay_id"] - patient = dataset.get_patient_by_icustay_id(icustay_id) - - if patient is None: - continue - - samples = task(patient) - if samples: - all_samples.extend(samples) - - return all_samples - - def samples_to_df(samples): """Convert samples to DataFrame""" df = pd.DataFrame( @@ -74,7 +53,7 @@ def main(): task = CirculatoryFailurePredictionTask() print("Building dataset...") - samples = build_dataset(dataset, task, max_icu=20) + samples = dataset.set_task(task, max_patients=20) print(f"Total samples: {len(samples)}") From 550251619da79d81d9bbf5278d6a7d9bcaea88ff Mon Sep 17 00:00:00 2001 From: Bella Date: Wed, 22 Apr 2026 01:16:15 +0800 Subject: [PATCH 07/15] feat: add CirculatoryFailure logistic regression example and register chartevents in MIMIC3 configuration --- .../mimic3_cf_circulatory_failure_logreg.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 examples/mimic3_cf_circulatory_failure_logreg.py diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py new file mode 100644 index 000000000..fd7c6b8c6 --- /dev/null +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -0,0 +1,136 @@ +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, roc_auc_score, recall_score + + +def samples_to_df(samples: list[dict]) -> pd.DataFrame: + rows = [] + for s in samples: + rows.append( + { + "patient_id": s["patient_id"], + "icustay_id": s["icustay_id"], + "gender": s["gender"], + "timestamp": s["timestamp"], + "map": s["features"]["map"], + "label": s["label"], + } + ) + df = pd.DataFrame(rows) + return df + + +def add_advanced_features(df: pd.DataFrame) -> pd.DataFrame: + """Add simple temporal features for the advanced setting.""" + df = df.sort_values(["icustay_id", "timestamp"]).copy() + df["map_prev"] = df.groupby("icustay_id")["map"].shift(1) + df["map_diff"] = df["map"] - df["map_prev"] + df["map_prev"] = df["map_prev"].fillna(df["map"]) + df["map_diff"] = df["map_diff"].fillna(0.0) + return df + + +def evaluate_model( + df: pd.DataFrame, + feature_cols: list[str], + balanced: bool = False, +) -> dict: + if df.empty or df["label"].nunique() < 2: + return { + "n_samples": len(df), + "accuracy": None, + "roc_auc": None, + "recall": None, + } + + X = df[feature_cols] + y = df["label"] + + X_train, X_test, y_train, y_test = train_test_split( + X, + y, + test_size=0.2, + random_state=42, + stratify=y, + ) + + model = LogisticRegression( + max_iter=1000, + class_weight="balanced" if balanced else None, + ) + model.fit(X_train, y_train) + + preds = model.predict(X_test) + probs = model.predict_proba(X_test)[:, 1] + + return { + "n_samples": len(df), + "accuracy": accuracy_score(y_test, preds), + "roc_auc": roc_auc_score(y_test, probs), + "recall": recall_score(y_test, preds), + } + + +def print_metrics(title: str, metrics: dict) -> None: + print(f"\n=== {title} ===") + print(f"n_samples: {metrics['n_samples']}") + print(f"accuracy: {metrics['accuracy']}") + print(f"roc_auc: {metrics['roc_auc']}") + print(f"recall: {metrics['recall']}") + + +def main() -> None: + dataset = MIMIC3CirculatoryFailureDataset( + root="mimic_test" + ) + + # task ablation: prediction windows + for window in [6, 12, 24]: + print(f"\n############################") + print(f"Prediction window = {window}h") + print(f"############################") + + task = CirculatoryFailurePredictionTask(prediction_window_hours=window) + samples = dataset.set_task(task, max_patients=100) + df = samples_to_df(samples) + + print("\nSample preview:") + print(df.head()) + + # baseline setting + baseline_metrics = evaluate_model( + df=df, + feature_cols=["map"], + balanced=False, + ) + print_metrics("Baseline: LogisticRegression(map)", baseline_metrics) + + # advanced setting + df_adv = add_advanced_features(df) + advanced_metrics = evaluate_model( + df=df_adv, + feature_cols=["map", "map_diff"], + balanced=True, + ) + print_metrics( + "Advanced: LogisticRegression(map + map_diff, balanced)", + advanced_metrics, + ) + + # subgroup fairness + for gender in ["M", "F"]: + subgroup_df = df_adv[df_adv["gender"] == gender].copy() + subgroup_metrics = evaluate_model( + df=subgroup_df, + feature_cols=["map", "map_diff"], + balanced=True, + ) + print_metrics(f"Advanced subgroup gender={gender}", subgroup_metrics) + + +if __name__ == "__main__": + main() \ No newline at end of file From 3fd6c59274e0cbca052b1b3a628cf64439e6dc95 Mon Sep 17 00:00:00 2001 From: Bella Date: Wed, 22 Apr 2026 01:16:24 +0800 Subject: [PATCH 08/15] feat: add chartevents configuration to MIMIC-III dataset schema --- pyhealth/datasets/configs/mimic3_cf.yaml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/configs/mimic3_cf.yaml b/pyhealth/datasets/configs/mimic3_cf.yaml index 6ef098c92..0de518e58 100644 --- a/pyhealth/datasets/configs/mimic3_cf.yaml +++ b/pyhealth/datasets/configs/mimic3_cf.yaml @@ -32,4 +32,16 @@ tables: - "intime" - "outtime" - "first_careunit" - - "last_careunit" \ No newline at end of file + - "last_careunit" + + chartevents: + file_path: "CHARTEVENTS.csv.gz" + patient_id: "subject_id" + timestamp: "charttime" + attributes: + - "hadm_id" + - "icustay_id" + - "itemid" + - "charttime" + - "value" + - "valuenum" \ No newline at end of file From f86ca8a7e891c9304882cb618b8a2e6a950cfdcc Mon Sep 17 00:00:00 2001 From: Bella Date: Wed, 22 Apr 2026 01:16:54 +0800 Subject: [PATCH 09/15] refactor: remove baseline training script for circulatory failure prediction --- pyhealth/train_baseline.py | 68 -------------------------------------- 1 file changed, 68 deletions(-) delete mode 100644 pyhealth/train_baseline.py diff --git a/pyhealth/train_baseline.py b/pyhealth/train_baseline.py deleted file mode 100644 index 10ac3b939..000000000 --- a/pyhealth/train_baseline.py +++ /dev/null @@ -1,68 +0,0 @@ -from pyhealth.datasets import MIMIC3CirculatoryFailureDataset -from pyhealth.tasks import CirculatoryFailurePredictionTask -import pandas as pd -from sklearn.model_selection import train_test_split -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import roc_auc_score, accuracy_score - - -def samples_to_df(samples): - """Convert samples to DataFrame""" - df = pd.DataFrame( - [ - { - "patient_id": s["patient_id"], - "icustay_id": s["icustay_id"], - "gender": s["gender"], - "timestamp": s["timestamp"], - "map": s["features"]["map"], - "label": s["label"], - } - for s in samples - ] - ) - return df - - -def train_model(df): - """Train a simple baseline model""" - - X = df[["map"]] - y = df["label"] - - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.2, random_state=42 - ) - - model = LogisticRegression() - model.fit(X_train, y_train) - - preds = model.predict(X_test) - probs = model.predict_proba(X_test)[:, 1] - - print("Accuracy:", accuracy_score(y_test, preds)) - print("ROC-AUC:", roc_auc_score(y_test, probs)) - - return model - - -def main(): - dataset = MIMIC3CirculatoryFailureDataset( - root="/Users/bella/Desktop/UIUC MCS/CS598/mimic_test" - ) - task = CirculatoryFailurePredictionTask() - - print("Building dataset...") - samples = dataset.set_task(task, max_patients=20) - - print(f"Total samples: {len(samples)}") - - df = samples_to_df(samples) - print(df.head()) - - print("\nTraining model...") - model = train_model(df) - - -if __name__ == "__main__": - main() \ No newline at end of file From d4030b0746700598b4795c447f948dac0904e488 Mon Sep 17 00:00:00 2001 From: Bella Date: Wed, 22 Apr 2026 20:01:47 +0800 Subject: [PATCH 10/15] fix: remove redundant code --- .../mimic3_cf_circulatory_failure_logreg.py | 2 +- pyhealth/datasets/mimic3_cf.py | 39 ------------------- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py index fd7c6b8c6..47c075fa1 100644 --- a/examples/mimic3_cf_circulatory_failure_logreg.py +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -85,7 +85,7 @@ def print_metrics(title: str, metrics: dict) -> None: def main() -> None: dataset = MIMIC3CirculatoryFailureDataset( - root="mimic_test" + root="mimic-iii-dataset" ) # task ablation: prediction windows diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py index c852745ea..3afdc339d 100644 --- a/pyhealth/datasets/mimic3_cf.py +++ b/pyhealth/datasets/mimic3_cf.py @@ -169,45 +169,6 @@ def build_failure_labels(self): return first_failure - # filter MAP - map_df = chartevents[chartevents["ITEMID"] == 220052] - - # convert time - map_df["CHARTTIME"] = pd.to_datetime(map_df["CHARTTIME"]) - - # load cohort - cohort = pd.DataFrame(self.load_cohort()) - cohort["intime"] = pd.to_datetime(cohort["intime"]) - cohort["outtime"] = pd.to_datetime(cohort["outtime"]) - - # merge - merged = map_df.merge( - cohort, - left_on="ICUSTAY_ID", - right_on="icustay_id", - ) - - # filter ICU period - filtered = merged[ - (merged["CHARTTIME"] >= merged["intime"]) - & (merged["CHARTTIME"] <= merged["outtime"]) - ].copy() - - # label - filtered["failure_label"] = (filtered["VALUENUM"] < 65).astype(int) - - # first failure time - failure_events = filtered[filtered["failure_label"] == 1] - - first_failure = ( - failure_events.groupby("ICUSTAY_ID")["CHARTTIME"] - .min() - .reset_index() - .rename(columns={"CHARTTIME": "first_failure_time"}) - ) - - return first_failure - def get_patient_by_icustay_id(self, icustay_id: int): """Build one task-ready patient dict for a given ICU stay.""" From a9cc815eceb8416fa819c670eca8b3b32912e62b Mon Sep 17 00:00:00 2001 From: Bella Date: Wed, 22 Apr 2026 20:54:44 +0800 Subject: [PATCH 11/15] fix: move test files to the correct path --- tests/{ => core}/test_circulatory_failure_prediction.py | 0 tests/{ => core}/test_mimic3_cf.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => core}/test_circulatory_failure_prediction.py (100%) rename tests/{ => core}/test_mimic3_cf.py (100%) diff --git a/tests/test_circulatory_failure_prediction.py b/tests/core/test_circulatory_failure_prediction.py similarity index 100% rename from tests/test_circulatory_failure_prediction.py rename to tests/core/test_circulatory_failure_prediction.py diff --git a/tests/test_mimic3_cf.py b/tests/core/test_mimic3_cf.py similarity index 100% rename from tests/test_mimic3_cf.py rename to tests/core/test_mimic3_cf.py From 6591c5b428df78f7a4dc72774d93ef8bf74d62f9 Mon Sep 17 00:00:00 2001 From: Bella Date: Wed, 22 Apr 2026 20:58:48 +0800 Subject: [PATCH 12/15] docs: add module, class, and function docstrings --- .../mimic3_cf_circulatory_failure_logreg.py | 9 +++++ .../tasks/circulatory_failure_prediction.py | 38 ++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py index 47c075fa1..72e81cb5a 100644 --- a/examples/mimic3_cf_circulatory_failure_logreg.py +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -1,3 +1,11 @@ +""" +Example ablation script for MIMIC-III circulatory failure prediction. + +This script compares different prediction windows (6h, 12h, 24h) and +feature settings using logistic regression. It is intended as an example +usage script for the dataset-task pipeline and ablation study. +""" + from pyhealth.datasets import MIMIC3CirculatoryFailureDataset from pyhealth.tasks import CirculatoryFailurePredictionTask @@ -85,6 +93,7 @@ def print_metrics(title: str, metrics: dict) -> None: def main() -> None: dataset = MIMIC3CirculatoryFailureDataset( + # path to the unzipped MIMIC-III database on your machine root="mimic-iii-dataset" ) diff --git a/pyhealth/tasks/circulatory_failure_prediction.py b/pyhealth/tasks/circulatory_failure_prediction.py index d3e3083de..c34250059 100644 --- a/pyhealth/tasks/circulatory_failure_prediction.py +++ b/pyhealth/tasks/circulatory_failure_prediction.py @@ -3,7 +3,23 @@ class CirculatoryFailurePredictionTask(BaseTask): - + """Early-warning task for circulatory failure prediction. + + This task converts one ICU-stay patient record into multiple + time-point prediction samples. At each timestamp t, the label is 1 + if the first circulatory failure event occurs within the next + prediction window, and 0 otherwise. + + Circulatory failure is defined upstream using a proxy event based on + MAP < 65 mmHg. + + Attributes: + task_name: Unique task identifier used by PyHealth. + input_schema: Expected input feature schema. + output_schema: Expected output label schema. + prediction_window_hours: Number of hours used for early-warning label + generation. + """ task_name = "circulatory_failure_prediction" input_schema = { @@ -17,10 +33,30 @@ class CirculatoryFailurePredictionTask(BaseTask): } def __init__(self, prediction_window_hours: int = 12): + """Initializes the circulatory failure prediction task. + + Args: + prediction_window_hours: Future prediction window in hours. + A sample is labeled positive if the first failure event + happens within this horizon. + """ super().__init__() self.prediction_window_hours = prediction_window_hours def __call__(self, patient) -> List[Dict]: + """Converts one patient record into task samples. + + Args: + patient: A task-ready patient dictionary containing ICU-stay + metadata, time-series MAP measurements, and + first_failure_time. + + Returns: + A list of sample dictionaries. Each sample contains patient + metadata, a timestamp, feature values, and a binary label. + Returns an empty list if the patient has no usable + time-series data or no failure time. + """ if not patient["time_series"]: return [] From 3ce7421e03cac71dd4db40b32dd2ea2d5717fc54 Mon Sep 17 00:00:00 2001 From: Bella Date: Thu, 23 Apr 2026 23:23:57 +0800 Subject: [PATCH 13/15] refactor: extract MAP_ITEMID constant and clean up redundant imports in MIMIC3CirculatoryFailureDataset --- pyhealth/datasets/mimic3_cf.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py index 3afdc339d..8aeaed876 100644 --- a/pyhealth/datasets/mimic3_cf.py +++ b/pyhealth/datasets/mimic3_cf.py @@ -5,7 +5,7 @@ from .base_dataset import BaseDataset logger = logging.getLogger(__name__) - +MAP_ITEMID = 220052 class MIMIC3CirculatoryFailureDataset(BaseDataset): """MIMIC-III dataset for circulatory failure early-warning prediction. @@ -86,9 +86,6 @@ def load_patients(self): def build_failure_labels(self): """Build first failure time per ICU stay (MAP < 65) using chunked reads.""" - import pandas as pd - from pathlib import Path - root = Path(self.root) # load cohort once @@ -113,7 +110,7 @@ def build_failure_labels(self): for chunk in chunks: # filter MAP only - chunk = chunk[chunk["ITEMID"] == 220052].copy() + chunk = chunk[chunk["ITEMID"] == MAP_ITEMID].copy() if chunk.empty: continue @@ -172,11 +169,6 @@ def build_failure_labels(self): def get_patient_by_icustay_id(self, icustay_id: int): """Build one task-ready patient dict for a given ICU stay.""" - import pandas as pd - from pathlib import Path - - root = Path(self.root) - # 1) load cohort cohort_df = pd.DataFrame(self.load_cohort()) cohort_df["intime"] = pd.to_datetime(cohort_df["intime"]) @@ -231,9 +223,6 @@ def get_patient_by_icustay_id(self, icustay_id: int): def load_map_cache(self): """Load MAP (mean arterial pressure) data once and cache it.""" - import pandas as pd - from pathlib import Path - if hasattr(self, "_map_cache"): return self._map_cache From 2ff44b188a2b643b493919903d04f218d89e1ad1 Mon Sep 17 00:00:00 2001 From: Bella Date: Thu, 7 May 2026 23:31:28 +0800 Subject: [PATCH 14/15] Align circulatory failure pipeline with PyHealth API --- .../mimic3_cf_circulatory_failure_logreg.py | 79 +++-- examples/mimic3_cf_example.py | 23 -- pyhealth/datasets/mimic3_cf.py | 301 ++++-------------- .../tasks/circulatory_failure_prediction.py | 264 +++++++++++---- pyhealth/test_pipeline.py | 44 --- .../test_circulatory_failure_prediction.py | 112 ++++++- tests/core/test_mimic3_cf.py | 90 +++--- 7 files changed, 434 insertions(+), 479 deletions(-) delete mode 100644 examples/mimic3_cf_example.py delete mode 100644 pyhealth/test_pipeline.py diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py index 72e81cb5a..62bcb6a24 100644 --- a/examples/mimic3_cf_circulatory_failure_logreg.py +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -3,43 +3,40 @@ This script compares different prediction windows (6h, 12h, 24h) and feature settings using logistic regression. It is intended as an example -usage script for the dataset-task pipeline and ablation study. +usage script for the standard PyHealth dataset → task → SampleDataset pipeline. + +Usage: + python mimic3_cf_circulatory_failure_logreg.py --root /path/to/mimic-iii """ -from pyhealth.datasets import MIMIC3CirculatoryFailureDataset -from pyhealth.tasks import CirculatoryFailurePredictionTask +import argparse import pandas as pd -from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression -from sklearn.metrics import accuracy_score, roc_auc_score, recall_score +from sklearn.metrics import accuracy_score, recall_score, roc_auc_score +from sklearn.model_selection import train_test_split +from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.tasks import CirculatoryFailurePredictionTask -def samples_to_df(samples: list[dict]) -> pd.DataFrame: + +def samples_to_df(sample_dataset) -> pd.DataFrame: + """Converts a SampleDataset into a pandas DataFrame.""" rows = [] - for s in samples: + for i in range(len(sample_dataset)): + s = sample_dataset[i] rows.append( { "patient_id": s["patient_id"], "icustay_id": s["icustay_id"], - "gender": s["gender"], - "timestamp": s["timestamp"], - "map": s["features"]["map"], - "label": s["label"], + "gender": s.get("gender"), + "timestamp": s.get("timestamp"), + "map": to_scalar(s["map"]), + "map_diff": to_scalar(s["map_diff"]), + "label": int(to_scalar(s["label"])), } ) - df = pd.DataFrame(rows) - return df - - -def add_advanced_features(df: pd.DataFrame) -> pd.DataFrame: - """Add simple temporal features for the advanced setting.""" - df = df.sort_values(["icustay_id", "timestamp"]).copy() - df["map_prev"] = df.groupby("icustay_id")["map"].shift(1) - df["map_diff"] = df["map"] - df["map_prev"] - df["map_prev"] = df["map_prev"].fillna(df["map"]) - df["map_diff"] = df["map_diff"].fillna(0.0) - return df + return pd.DataFrame(rows) def evaluate_model( @@ -90,27 +87,40 @@ def print_metrics(title: str, metrics: dict) -> None: print(f"roc_auc: {metrics['roc_auc']}") print(f"recall: {metrics['recall']}") +def to_scalar(x): + """Converts scalar tensor-like values to Python scalars.""" + if hasattr(x, "item"): + return x.item() + return x def main() -> None: - dataset = MIMIC3CirculatoryFailureDataset( - # path to the unzipped MIMIC-III database on your machine - root="mimic-iii-dataset" + parser = argparse.ArgumentParser( + description="MIMIC-III circulatory failure prediction ablation study." ) + parser.add_argument( + "--root", + type=str, + required=True, + help="Path to the unzipped MIMIC-III database directory.", + ) + args = parser.parse_args() + + dataset = MIMIC3CirculatoryFailureDataset(root=args.root) - # task ablation: prediction windows + # Task ablation: prediction windows for window in [6, 12, 24]: print(f"\n############################") print(f"Prediction window = {window}h") print(f"############################") task = CirculatoryFailurePredictionTask(prediction_window_hours=window) - samples = dataset.set_task(task, max_patients=100) - df = samples_to_df(samples) + sample_dataset = dataset.set_task(task) + df = samples_to_df(sample_dataset) print("\nSample preview:") print(df.head()) - # baseline setting + # Baseline setting baseline_metrics = evaluate_model( df=df, feature_cols=["map"], @@ -118,10 +128,9 @@ def main() -> None: ) print_metrics("Baseline: LogisticRegression(map)", baseline_metrics) - # advanced setting - df_adv = add_advanced_features(df) + # Advanced setting advanced_metrics = evaluate_model( - df=df_adv, + df=df, feature_cols=["map", "map_diff"], balanced=True, ) @@ -130,9 +139,9 @@ def main() -> None: advanced_metrics, ) - # subgroup fairness + # Subgroup fairness for gender in ["M", "F"]: - subgroup_df = df_adv[df_adv["gender"] == gender].copy() + subgroup_df = df[df["gender"] == gender].copy() subgroup_metrics = evaluate_model( df=subgroup_df, feature_cols=["map", "map_diff"], diff --git a/examples/mimic3_cf_example.py b/examples/mimic3_cf_example.py deleted file mode 100644 index baa4c8783..000000000 --- a/examples/mimic3_cf_example.py +++ /dev/null @@ -1,23 +0,0 @@ -from pyhealth.datasets import MIMIC3CirculatoryFailureDataset -from pyhealth.tasks import CirculatoryFailurePredictionTask - - -def main(): - dataset = MIMIC3CirculatoryFailureDataset( - root="/path/to/mimic3" - ) - - task = CirculatoryFailurePredictionTask(prediction_window_hours=12) - - # apply task - samples = dataset.set_task(task, max_patients=5) - - print(f"Total samples: {len(samples)}") - - if samples: - print("Sample example:") - print(samples[0]) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py index 8aeaed876..82462575c 100644 --- a/pyhealth/datasets/mimic3_cf.py +++ b/pyhealth/datasets/mimic3_cf.py @@ -1,26 +1,60 @@ +""" +MIMIC-III Circulatory Failure Dataset for PyHealth. + +Dataset: + MIMIC-III Clinical Database v1.4 + https://physionet.org/content/mimiciii/1.4/ + +Inspired by: + Hoche, M., Mineeva, O., Burger, M., Blasimme, A., & Ratsch, G. (2024). + FAMEWS: A fairness auditing tool for medical early-warning systems. + Proceedings of the Fifth Conference on Health, Inference, and Learning, 248, 297–311. PMLR. + https://proceedings.mlr.press/v248/hoche24a.html + +Description: + Configures the MIMIC-III tables required for a circulatory-failure + early-warning task. The dataset keeps data loading separate from + task logic; sample generation is handled by + ``CirculatoryFailurePredictionTask`` through the standard PyHealth + ``dataset.set_task(task)`` pipeline. + +Authors: + Kuang-Yu Wang (kuangyu4@illinois.edu) + Ya Hsuan Yang (yhyang3@illinois.edu) +""" + import logging from pathlib import Path from typing import List, Optional -import pandas as pd from .base_dataset import BaseDataset logger = logging.getLogger(__name__) -MAP_ITEMID = 220052 + class MIMIC3CirculatoryFailureDataset(BaseDataset): - """MIMIC-III dataset for circulatory failure early-warning prediction. + """MIMIC-III wrapper for circulatory failure early-warning prediction. - This dataset is designed for a FAMEWS-inspired reproduction setting on - MIMIC-III. It will support cohort construction, event parsing, and - time-series feature extraction for circulatory failure prediction within - a future prediction window. + This dataset configures the MIMIC-III tables required for a + FAMEWS-inspired circulatory failure early-warning task. The dataset keeps + data loading separate from task logic; sample generation is handled by + ``CirculatoryFailurePredictionTask`` through the standard PyHealth + ``dataset.set_task(task)`` pipeline. Args: root: Root directory of the MIMIC-III dataset. - tables: Additional tables to load beyond the default cohort tables. + tables: Additional tables to load beyond the default tables. dataset_name: Name of the dataset instance. config_path: Path to the dataset config YAML file. **kwargs: Additional keyword arguments passed to BaseDataset. + + Examples: + >>> from pyhealth.datasets import MIMIC3CirculatoryFailureDataset + >>> from pyhealth.tasks import CirculatoryFailurePredictionTask + >>> dataset = MIMIC3CirculatoryFailureDataset( + ... root="/path/to/mimic-iii", + ... ) + >>> task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + >>> sample_dataset = dataset.set_task(task) """ def __init__( @@ -36,8 +70,17 @@ def __init__( logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "mimic3_cf.yaml" - default_tables = ["patients", "admissions", "icustays"] - tables = default_tables + (tables or []) + default_tables = [ + "patients", + "admissions", + "icustays", + "chartevents", + ] + + if tables is None: + tables = default_tables + else: + tables = list(dict.fromkeys(default_tables + tables)) super().__init__( root=root, @@ -45,240 +88,4 @@ def __init__( dataset_name=dataset_name or "mimic3_cf", config_path=str(config_path), **kwargs, - ) - - def load_cohort(self): - """Load patients + admissions + icustays.""" - - import pandas as pd - from pathlib import Path - - root = Path(self.root) - - patients_df = pd.read_csv(root / "PATIENTS.csv.gz") - admissions_df = pd.read_csv(root / "ADMISSIONS.csv.gz") - icustays_df = pd.read_csv(root / "ICUSTAYS.csv.gz") - - df = patients_df.merge(admissions_df, on="SUBJECT_ID") - df = df.merge(icustays_df, on=["SUBJECT_ID", "HADM_ID"]) - - patients = [] - - for _, row in df.iterrows(): - patients.append( - { - "patient_id": row["SUBJECT_ID"], - "gender": row["GENDER"], - "hadm_id": row["HADM_ID"], - "icustay_id": row["ICUSTAY_ID"], - "admittime": row["ADMITTIME"], - "intime": row["INTIME"], - "outtime": row["OUTTIME"], - } - ) - - return patients - - def load_patients(self): - """Backward-compatible wrapper for current development.""" - return self.load_cohort() - - def build_failure_labels(self): - """Build first failure time per ICU stay (MAP < 65) using chunked reads.""" - - root = Path(self.root) - - # load cohort once - cohort = pd.DataFrame(self.load_cohort()) - cohort["intime"] = pd.to_datetime(cohort["intime"]) - cohort["outtime"] = pd.to_datetime(cohort["outtime"]) - - results = [] - - chunks = pd.read_csv( - root / "CHARTEVENTS.csv.gz", - usecols=[ - "SUBJECT_ID", - "HADM_ID", - "ICUSTAY_ID", - "ITEMID", - "CHARTTIME", - "VALUENUM", - ], - chunksize=50000, - ) - - for chunk in chunks: - # filter MAP only - chunk = chunk[chunk["ITEMID"] == MAP_ITEMID].copy() - if chunk.empty: - continue - - chunk["CHARTTIME"] = pd.to_datetime( - chunk["CHARTTIME"], - format="%Y-%m-%d %H:%M:%S", - errors="coerce", - ) - - merged = chunk.merge( - cohort, - left_on="ICUSTAY_ID", - right_on="icustay_id", - ) - - if merged.empty: - continue - - filtered = merged[ - (merged["CHARTTIME"] >= merged["intime"]) - & (merged["CHARTTIME"] <= merged["outtime"]) - ].copy() - - if filtered.empty: - continue - - filtered["failure_label"] = (filtered["VALUENUM"] < 65).astype(int) - - failure_events = filtered[filtered["failure_label"] == 1] - if failure_events.empty: - continue - - first_failure_chunk = ( - failure_events.groupby("ICUSTAY_ID")["CHARTTIME"] - .min() - .reset_index() - .rename(columns={"CHARTTIME": "first_failure_time"}) - ) - - results.append(first_failure_chunk) - - if not results: - return pd.DataFrame(columns=["ICUSTAY_ID", "first_failure_time"]) - - first_failure = pd.concat(results, ignore_index=True) - - # keep earliest failure time per ICU stay across all chunks - first_failure = ( - first_failure.groupby("ICUSTAY_ID")["first_failure_time"] - .min() - .reset_index() - ) - - return first_failure - - def get_patient_by_icustay_id(self, icustay_id: int): - """Build one task-ready patient dict for a given ICU stay.""" - - # 1) load cohort - cohort_df = pd.DataFrame(self.load_cohort()) - cohort_df["intime"] = pd.to_datetime(cohort_df["intime"]) - cohort_df["outtime"] = pd.to_datetime(cohort_df["outtime"]) - - row = cohort_df[cohort_df["icustay_id"] == icustay_id] - if row.empty: - return None - row = row.iloc[0] - - # 2) load failure labels - first_failure = self.build_failure_labels() - failure_row = first_failure[first_failure["ICUSTAY_ID"] == icustay_id] - - first_failure_time = None - if not failure_row.empty: - first_failure_time = failure_row.iloc[0]["first_failure_time"] - - # 3) load MAP time series for this ICU stay - map_df = self.load_map_cache() - ts = map_df[map_df["ICUSTAY_ID"] == icustay_id].copy() - ts = ts[ - (ts["CHARTTIME"] >= row["intime"]) & - (ts["CHARTTIME"] <= row["outtime"]) - ].copy() - - ts = ts.sort_values("CHARTTIME") - - time_series = [] - for _, ts_row in ts.iterrows(): - if pd.isna(ts_row["VALUENUM"]): - continue - time_series.append( - { - "charttime": ts_row["CHARTTIME"], - "map": float(ts_row["VALUENUM"]), - } - ) - - patient = { - "patient_id": int(row["patient_id"]), - "icustay_id": int(row["icustay_id"]), - "gender": row["gender"], - "intime": row["intime"], - "outtime": row["outtime"], - "time_series": time_series, - "first_failure_time": first_failure_time, - } - - return patient - - def load_map_cache(self): - """Load MAP (mean arterial pressure) data once and cache it.""" - - if hasattr(self, "_map_cache"): - return self._map_cache - - root = Path(self.root) - - print("Loading MAP cache (this will take a bit, only once)...") - - chunks = pd.read_csv( - root / "CHARTEVENTS.csv.gz", - usecols=["ICUSTAY_ID", "ITEMID", "CHARTTIME", "VALUENUM"], - chunksize=100000, - ) - - parts = [] - - for chunk in chunks: - chunk = chunk[chunk["ITEMID"] == 220052].copy() - if chunk.empty: - continue - - chunk["CHARTTIME"] = pd.to_datetime( - chunk["CHARTTIME"], - format="%Y-%m-%d %H:%M:%S", - errors="coerce", - ) - - parts.append(chunk) - - if parts: - df = pd.concat(parts, ignore_index=True) - else: - df = pd.DataFrame(columns=["ICUSTAY_ID", "CHARTTIME", "VALUENUM"]) - - self._map_cache = df - print("MAP cache loaded:", len(df)) - - return df - - def set_task(self, task, max_patients: int | None = None): - """Apply a task function to the cohort and return task samples.""" - - samples = [] - cohort = self.load_cohort() - - if max_patients is not None: - cohort = cohort[:max_patients] - - for row in cohort: - icustay_id = row["icustay_id"] - patient = self.get_patient_by_icustay_id(icustay_id) - - if patient is None: - continue - - task_samples = task(patient) - if task_samples: - samples.extend(task_samples) - - return samples \ No newline at end of file + ) \ No newline at end of file diff --git a/pyhealth/tasks/circulatory_failure_prediction.py b/pyhealth/tasks/circulatory_failure_prediction.py index c34250059..f6824fc63 100644 --- a/pyhealth/tasks/circulatory_failure_prediction.py +++ b/pyhealth/tasks/circulatory_failure_prediction.py @@ -1,93 +1,235 @@ +""" +Circulatory Failure Prediction Task for PyHealth. + +Dataset: + MIMIC-III Clinical Database v1.4 + https://physionet.org/content/mimiciii/1.4/ + +Inspired by: + Hoche, M., Mineeva, O., Burger, M., Blasimme, A., & Ratsch, G. (2024). + FAMEWS: A fairness auditing tool for medical early-warning systems. + Proceedings of the Fifth Conference on Health, Inference, and Learning, 248, 297–311. PMLR. + https://proceedings.mlr.press/v248/hoche24a.html + +Description: + Time-point prediction task for circulatory failure early warning. + For each MAP measurement at time *t*, the sample is labelled positive + if the first circulatory failure event occurs within the future + prediction window. Circulatory failure is approximated by a mean + arterial pressure (MAP) value below 65 mmHg. + +Authors: + Kuang-Yu Wang (kuangyu4@illinois.edu) + Ya Hsuan Yang (yhyang3@illinois.edu) +""" + +from datetime import timedelta +from typing import Any, Dict, List, Optional +import pandas as pd +from pyhealth.data import Patient from pyhealth.tasks.base_task import BaseTask -from typing import List, Dict + + +MAP_ITEMID = 220052 +MAP_FAILURE_THRESHOLD = 65.0 class CirculatoryFailurePredictionTask(BaseTask): """Early-warning task for circulatory failure prediction. - This task converts one ICU-stay patient record into multiple - time-point prediction samples. At each timestamp t, the label is 1 - if the first circulatory failure event occurs within the next - prediction window, and 0 otherwise. - - Circulatory failure is defined upstream using a proxy event based on - MAP < 65 mmHg. + This task converts a PyHealth Patient object into time-point prediction + samples for circulatory failure early warning. For each MAP measurement at + time t, the sample is labeled positive if the first circulatory failure + event occurs within the future prediction window. + Circulatory failure is approximated by a mean arterial pressure (MAP) + value below 65 mmHg. Attributes: - task_name: Unique task identifier used by PyHealth. - input_schema: Expected input feature schema. - output_schema: Expected output label schema. - prediction_window_hours: Number of hours used for early-warning label - generation. + task_name: Unique task identifier. + input_schema: Input feature schema for PyHealth processors. + output_schema: Output label schema for PyHealth processors. + + Examples: + >>> from pyhealth.datasets import MIMIC3CirculatoryFailureDataset + >>> from pyhealth.tasks import CirculatoryFailurePredictionTask + >>> dataset = MIMIC3CirculatoryFailureDataset( + ... root="/path/to/mimic-iii", + ... ) + >>> task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + >>> sample_dataset = dataset.set_task(task) + >>> sample_dataset[0] # doctest: +SKIP """ - task_name = "circulatory_failure_prediction" - input_schema = { - "map": float, - "timestamp": "datetime", - "gender": str, + input_schema: Dict[str, str] = { + "map": "tensor", + "map_diff": "tensor" } - output_schema = { - "label": int, + output_schema: Dict[str, str] = { + "label": "binary" } - def __init__(self, prediction_window_hours: int = 12): + task_name: str = "circulatory_failure_prediction" + + def __init__( + self, + prediction_window_hours: int = 12, + map_itemid: int = MAP_ITEMID, + failure_threshold: float = MAP_FAILURE_THRESHOLD, + **kwargs: Any, + ) -> None: """Initializes the circulatory failure prediction task. Args: prediction_window_hours: Future prediction window in hours. - A sample is labeled positive if the first failure event - happens within this horizon. + map_itemid: MIMIC-III ITEMID corresponding to MAP. + failure_threshold: MAP threshold used to define circulatory failure. + **kwargs: Additional keyword arguments passed to BaseTask. """ - super().__init__() + super().__init__(**kwargs) self.prediction_window_hours = prediction_window_hours - - def __call__(self, patient) -> List[Dict]: - """Converts one patient record into task samples. + self.map_itemid = map_itemid + self.failure_threshold = failure_threshold + self.task_name = ( + f"circulatory_failure_prediction_{prediction_window_hours}h" + ) + + @staticmethod + def _to_datetime(value: Any) -> Optional[pd.Timestamp]: + """Converts a timestamp-like value to pandas Timestamp.""" + if value is None or pd.isna(value): + return None + return pd.to_datetime(value, errors="coerce") + + @staticmethod + def _to_float(value: Any) -> Optional[float]: + """Converts a numeric-like value to float.""" + if value is None or pd.isna(value): + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _event_attr(event: Any, attr: str, default: Any = None) -> Any: + """Gets an event attribute from either object attr or attr_dict.""" + if hasattr(event, attr): + return getattr(event, attr) + + attr_dict = getattr(event, "attr_dict", None) + if isinstance(attr_dict, dict) and attr in attr_dict: + return attr_dict[attr] + + return default + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Converts one Patient object into prediction samples. Args: - patient: A task-ready patient dictionary containing ICU-stay - metadata, time-series MAP measurements, and - first_failure_time. + patient: A PyHealth Patient object. The patient should contain + `icustays`, `patients`, and `chartevents` events loaded from + MIMIC-III. Returns: - A list of sample dictionaries. Each sample contains patient - metadata, a timestamp, feature values, and a binary label. - Returns an empty list if the patient has no usable - time-series data or no failure time. + A list of sample dictionaries. Each sample contains patient/visit + metadata, MAP-based features, and a binary early-warning label. + Returning an empty list is valid when no usable MAP data or no + failure event is found. """ - if not patient["time_series"]: - return [] - - import pandas as pd - from datetime import timedelta - - first_failure_time = patient["first_failure_time"] - if first_failure_time is None: - return [] - - first_failure_time = pd.to_datetime(first_failure_time) - + samples: List[Dict[str, Any]] = [] prediction_window = timedelta(hours=self.prediction_window_hours) - samples = [] + patient_events = patient.get_events(event_type="patients") + gender = None + if len(patient_events) > 0: + gender = self._event_attr(patient_events[0], "gender") - for row in patient["time_series"]: - t = pd.to_datetime(row["charttime"]) - map_value = row["map"] + icu_stays = patient.get_events(event_type="icustays") + chartevents = patient.get_events(event_type="chartevents") - label = int(t < first_failure_time <= t + prediction_window) + if len(icu_stays) == 0 or len(chartevents) == 0: + return [] - samples.append( - { - "patient_id": patient["patient_id"], - "icustay_id": patient["icustay_id"], - "gender": patient["gender"], - "timestamp": t, - "features": {"map": map_value}, - "label": label, - } - ) + for icu_stay in icu_stays: + icustay_id = self._event_attr(icu_stay, "icustay_id") + intime = self._to_datetime(self._event_attr(icu_stay, "intime")) + outtime = self._to_datetime(self._event_attr(icu_stay, "outtime")) + hadm_id = self._event_attr(icu_stay, "hadm_id") + + if icustay_id is None or intime is None or outtime is None: + continue + + map_events = [] + for event in chartevents: + event_icustay_id = self._event_attr(event, "icustay_id") + itemid = self._event_attr(event, "itemid") + charttime = self._to_datetime(event.timestamp) + valuenum = self._to_float(self._event_attr(event, "valuenum")) + + if event_icustay_id != icustay_id: + continue + if itemid != self.map_itemid: + continue + if charttime is None or pd.isna(charttime): + continue + if valuenum is None: + continue + if not (intime <= charttime <= outtime): + continue + + map_events.append( + { + "charttime": charttime, + "map": valuenum, + } + ) + + if not map_events: + continue + + map_events = sorted(map_events, key=lambda x: x["charttime"]) + + failure_times = [ + row["charttime"] + for row in map_events + if row["map"] < self.failure_threshold + ] + + if not failure_times: + continue + + first_failure_time = min(failure_times) + + previous_map = None + for row in map_events: + timestamp = row["charttime"] + map_value = row["map"] + + if previous_map is None: + map_diff = 0.0 + else: + map_diff = map_value - previous_map + + previous_map = map_value + + label = int( + timestamp < first_failure_time + <= timestamp + prediction_window + ) + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": str(icustay_id), + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "gender": gender, + "timestamp": timestamp, + "map": map_value, + "map_diff": map_diff, + "label": label, + } + ) return samples \ No newline at end of file diff --git a/pyhealth/test_pipeline.py b/pyhealth/test_pipeline.py deleted file mode 100644 index 7187d83cb..000000000 --- a/pyhealth/test_pipeline.py +++ /dev/null @@ -1,44 +0,0 @@ -from pyhealth.datasets import MIMIC3CirculatoryFailureDataset -from pyhealth.tasks import CirculatoryFailurePredictionTask - - -def main(): - # 1. 初始化資料集(請確保路徑正確) - dataset = MIMIC3CirculatoryFailureDataset( - root="/mimic_test" - ) - - # 2. 初始化任務(12小時預警) - task = CirculatoryFailurePredictionTask(prediction_window_hours=12) - - # 3. 讀取 cohort,找到第一個真的能產生 samples 的 ICU stay - cohort = dataset.load_cohort() - - samples = None - chosen_icustay_id = None - - for row in cohort: - icustay_id = row["icustay_id"] - patient = dataset.get_patient_by_icustay_id(icustay_id) - samples = task(patient) - - if samples: - chosen_icustay_id = icustay_id - break - - # 4. 檢查結果 - if samples: - print(f"--- 成功測試 ICU Stay ID: {chosen_icustay_id} ---") - print(f"成功產生樣本數: {len(samples)}") - print( - f"其中 Label=1 (未來12小時內會衰竭) 的數量: " - f"{sum(s['label'] for s in samples)}" - ) - print(f"第一筆樣本特徵: {samples[0]['features']}") - print(f"第一筆完整樣本: {samples[0]}") - else: - print("未找到任何可產生樣本的 ICU stay,請檢查資料與路徑。") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/core/test_circulatory_failure_prediction.py b/tests/core/test_circulatory_failure_prediction.py index e0daba760..73c1ebfac 100644 --- a/tests/core/test_circulatory_failure_prediction.py +++ b/tests/core/test_circulatory_failure_prediction.py @@ -1,21 +1,100 @@ +from dataclasses import dataclass +from datetime import datetime + from pyhealth.tasks import CirculatoryFailurePredictionTask +@dataclass +class FakeEvent: + event_type: str + timestamp: datetime | None = None + attr_dict: dict | None = None + + def __post_init__(self): + if self.attr_dict is None: + self.attr_dict = {} + + def __getattr__(self, name): + if name in self.attr_dict: + return self.attr_dict[name] + raise AttributeError(name) + + +class FakePatient: + patient_id = "1" + + def __init__(self): + self.events = { + "patients": [ + FakeEvent( + event_type="patients", + attr_dict={"gender": "F"}, + ) + ], + "icustays": [ + FakeEvent( + event_type="icustays", + timestamp=datetime(2150, 1, 1, 0, 0, 0), + attr_dict={ + "hadm_id": 100, + "icustay_id": 1001, + "intime": "2150-01-01 00:00:00", + "outtime": "2150-01-02 00:00:00", + }, + ) + ], + "chartevents": [ + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 0, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 80.0, + }, + ), + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 1, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 78.0, + }, + ), + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 10, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 70.0, + }, + ), + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 11, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 60.0, + }, + ), + ], + } + + def get_events(self, event_type=None, *args, **kwargs): + if event_type is None: + all_events = [] + for events in self.events.values(): + all_events.extend(events) + return all_events + return self.events.get(event_type, []) + + def test_circulatory_failure_task_basic(): task = CirculatoryFailurePredictionTask(prediction_window_hours=12) - - patient = { - "patient_id": 1, - "icustay_id": 1001, - "gender": "F", - "first_failure_time": "2150-01-01 11:00:00", - "time_series": [ - {"charttime": "2150-01-01 00:00:00", "map": 80.0}, - {"charttime": "2150-01-01 01:00:00", "map": 78.0}, - {"charttime": "2150-01-01 10:00:00", "map": 70.0}, - {"charttime": "2150-01-01 11:00:00", "map": 60.0}, - ], - } + patient = FakePatient() samples = task(patient) @@ -24,5 +103,8 @@ def test_circulatory_failure_task_basic(): assert samples[1]["label"] == 1 assert samples[2]["label"] == 1 assert samples[3]["label"] == 0 - assert samples[0]["features"]["map"] == 80.0 - assert samples[0]["gender"] == "F" \ No newline at end of file + assert samples[0]["map"] == 80.0 + assert samples[1]["map_diff"] == -2.0 + assert samples[0]["gender"] == "F" + assert samples[0]["patient_id"] == "1" + assert samples[0]["visit_id"] == "1001" \ No newline at end of file diff --git a/tests/core/test_mimic3_cf.py b/tests/core/test_mimic3_cf.py index c8fa7b4a1..4979a8ea3 100644 --- a/tests/core/test_mimic3_cf.py +++ b/tests/core/test_mimic3_cf.py @@ -1,57 +1,39 @@ +""" +Unit tests for ``pyhealth.datasets.MIMIC3CirculatoryFailureDataset``. +""" +from pathlib import Path from pyhealth.datasets import MIMIC3CirculatoryFailureDataset -from pyhealth.tasks import CirculatoryFailurePredictionTask -class DummyMIMIC3CFDataset(MIMIC3CirculatoryFailureDataset): - """Small synthetic subclass for fast unit testing.""" - - def __init__(self): - # do not call super().__init__ because we don't want real files - pass - - def load_cohort(self): - return [ - { - "patient_id": 1, - "gender": "F", - "hadm_id": 100, - "icustay_id": 1001, - "admittime": "2150-01-01 00:00:00", - "intime": "2150-01-01 00:00:00", - "outtime": "2150-01-02 00:00:00", - } - ] - - def get_patient_by_icustay_id(self, icustay_id: int): - if icustay_id != 1001: - return None - - return { - "patient_id": 1, - "icustay_id": 1001, - "gender": "F", - "intime": "2150-01-01 00:00:00", - "outtime": "2150-01-02 00:00:00", - "first_failure_time": "2150-01-01 11:00:00", - "time_series": [ - {"charttime": "2150-01-01 00:00:00", "map": 80.0}, - {"charttime": "2150-01-01 01:00:00", "map": 78.0}, - {"charttime": "2150-01-01 10:00:00", "map": 70.0}, - {"charttime": "2150-01-01 11:00:00", "map": 60.0}, - ], - } - - -def test_set_task_returns_samples(): - dataset = DummyMIMIC3CFDataset() - task = CirculatoryFailurePredictionTask(prediction_window_hours=12) - - samples = dataset.set_task(task) - - assert isinstance(samples, list) - assert len(samples) == 4 - assert samples[0]["patient_id"] == 1 - assert samples[0]["icustay_id"] == 1001 - assert samples[0]["features"]["map"] == 80.0 - assert samples[0]["label"] == 1 - assert samples[-1]["label"] == 0 \ No newline at end of file +def test_mimic3_cf_dataset_initialization(monkeypatch): + captured = {} + + def fake_base_init( + self, + root, + tables, + dataset_name=None, + config_path=None, + **kwargs, + ): + self.root = root + self.tables = tables + self.dataset_name = dataset_name + self.config_path = config_path + captured["tables"] = tables + captured["dataset_name"] = dataset_name + captured["config_path"] = config_path + + monkeypatch.setattr( + "pyhealth.datasets.base_dataset.BaseDataset.__init__", + fake_base_init, + ) + + dataset = MIMIC3CirculatoryFailureDataset(root="dummy-root") + + assert dataset.dataset_name == "mimic3_cf" + assert "patients" in dataset.tables + assert "admissions" in dataset.tables + assert "icustays" in dataset.tables + assert "chartevents" in dataset.tables + assert Path(dataset.config_path).name == "mimic3_cf.yaml" \ No newline at end of file From a8cfc28c755aff923149231f280e400cbba21294 Mon Sep 17 00:00:00 2001 From: Bella Date: Sun, 10 May 2026 12:49:49 +0800 Subject: [PATCH 15/15] refactor: refactor circulatory failure contribution as MIMIC3 task --- docs/api/datasets.rst | 3 +- .../datasets/pyhealth.datasets.mimic3_cf.rst | 26 ------ .../mimic3_cf_circulatory_failure_logreg.py | 8 +- pyhealth/datasets/__init__.py | 3 +- pyhealth/datasets/configs/mimic3_cf.yaml | 47 ---------- pyhealth/datasets/mimic3_cf.py | 91 ------------------- .../tasks/circulatory_failure_prediction.py | 5 +- tests/core/test_mimic3_cf.py | 39 -------- 8 files changed, 10 insertions(+), 212 deletions(-) delete mode 100644 docs/api/datasets/pyhealth.datasets.mimic3_cf.rst delete mode 100644 pyhealth/datasets/configs/mimic3_cf.yaml delete mode 100644 pyhealth/datasets/mimic3_cf.py delete mode 100644 tests/core/test_mimic3_cf.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index ed67dcfcc..1a24e3e8c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,5 +245,4 @@ Available Datasets datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter - datasets/pyhealth.datasets.utils - datasets/pyhealth.datasets.mimic3_cf \ No newline at end of file + datasets/pyhealth.datasets.utils \ No newline at end of file diff --git a/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst b/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst deleted file mode 100644 index b0d47472c..000000000 --- a/docs/api/datasets/pyhealth.datasets.mimic3_cf.rst +++ /dev/null @@ -1,26 +0,0 @@ -pyhealth.datasets.mimic3_cf -=========================== - -Overview --------- - -MIMIC3CirculatoryFailureDataset is a MIMIC-III based dataset for early warning -prediction of circulatory failure. - -It constructs an ICU-stay-level cohort from PATIENTS, ADMISSIONS, and ICUSTAYS, -and uses CHARTEVENTS to extract Mean Arterial Pressure (MAP) measurements. - -Circulatory failure is defined using a proxy event: - -- MAP < 65 mmHg - -For each ICU stay, the dataset identifies the first occurrence of this event and -supports building task-ready patient records for downstream prediction tasks. - -API Reference -------------- - -.. autoclass:: pyhealth.datasets.MIMIC3CirculatoryFailureDataset - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py index 62bcb6a24..a3cdf5369 100644 --- a/examples/mimic3_cf_circulatory_failure_logreg.py +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -16,10 +16,9 @@ from sklearn.metrics import accuracy_score, recall_score, roc_auc_score from sklearn.model_selection import train_test_split -from pyhealth.datasets import MIMIC3CirculatoryFailureDataset +from pyhealth.datasets import MIMIC3Dataset from pyhealth.tasks import CirculatoryFailurePredictionTask - def samples_to_df(sample_dataset) -> pd.DataFrame: """Converts a SampleDataset into a pandas DataFrame.""" rows = [] @@ -105,7 +104,10 @@ def main() -> None: ) args = parser.parse_args() - dataset = MIMIC3CirculatoryFailureDataset(root=args.root) + dataset = MIMIC3Dataset( + root=args.root, + tables=["patients", "admissions", "icustays", "chartevents"], + ) # Task ablation: prediction windows for window in [6, 12, 24]: diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 2cafac05d..01b1ee6d8 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,5 +90,4 @@ def __init__(self, *args, **kwargs): load_processors, save_processors, ) -from .collate import collate_temporal -from .mimic3_cf import MIMIC3CirculatoryFailureDataset \ No newline at end of file +from .collate import collate_temporal \ No newline at end of file diff --git a/pyhealth/datasets/configs/mimic3_cf.yaml b/pyhealth/datasets/configs/mimic3_cf.yaml deleted file mode 100644 index 0de518e58..000000000 --- a/pyhealth/datasets/configs/mimic3_cf.yaml +++ /dev/null @@ -1,47 +0,0 @@ -version: "1.4" -tables: - patients: - file_path: "PATIENTS.csv.gz" - patient_id: "subject_id" - timestamp: null - attributes: - - "gender" - - "dob" - - "dod" - - "expire_flag" - - admissions: - file_path: "ADMISSIONS.csv.gz" - patient_id: "subject_id" - timestamp: "admittime" - attributes: - - "hadm_id" - - "admittime" - - "dischtime" - - "deathtime" - - "hospital_expire_flag" - - "ethnicity" - - icustays: - file_path: "ICUSTAYS.csv.gz" - patient_id: "subject_id" - timestamp: "intime" - attributes: - - "hadm_id" - - "icustay_id" - - "intime" - - "outtime" - - "first_careunit" - - "last_careunit" - - chartevents: - file_path: "CHARTEVENTS.csv.gz" - patient_id: "subject_id" - timestamp: "charttime" - attributes: - - "hadm_id" - - "icustay_id" - - "itemid" - - "charttime" - - "value" - - "valuenum" \ No newline at end of file diff --git a/pyhealth/datasets/mimic3_cf.py b/pyhealth/datasets/mimic3_cf.py deleted file mode 100644 index 82462575c..000000000 --- a/pyhealth/datasets/mimic3_cf.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -MIMIC-III Circulatory Failure Dataset for PyHealth. - -Dataset: - MIMIC-III Clinical Database v1.4 - https://physionet.org/content/mimiciii/1.4/ - -Inspired by: - Hoche, M., Mineeva, O., Burger, M., Blasimme, A., & Ratsch, G. (2024). - FAMEWS: A fairness auditing tool for medical early-warning systems. - Proceedings of the Fifth Conference on Health, Inference, and Learning, 248, 297–311. PMLR. - https://proceedings.mlr.press/v248/hoche24a.html - -Description: - Configures the MIMIC-III tables required for a circulatory-failure - early-warning task. The dataset keeps data loading separate from - task logic; sample generation is handled by - ``CirculatoryFailurePredictionTask`` through the standard PyHealth - ``dataset.set_task(task)`` pipeline. - -Authors: - Kuang-Yu Wang (kuangyu4@illinois.edu) - Ya Hsuan Yang (yhyang3@illinois.edu) -""" - -import logging -from pathlib import Path -from typing import List, Optional -from .base_dataset import BaseDataset - -logger = logging.getLogger(__name__) - - -class MIMIC3CirculatoryFailureDataset(BaseDataset): - """MIMIC-III wrapper for circulatory failure early-warning prediction. - - This dataset configures the MIMIC-III tables required for a - FAMEWS-inspired circulatory failure early-warning task. The dataset keeps - data loading separate from task logic; sample generation is handled by - ``CirculatoryFailurePredictionTask`` through the standard PyHealth - ``dataset.set_task(task)`` pipeline. - - Args: - root: Root directory of the MIMIC-III dataset. - tables: Additional tables to load beyond the default tables. - dataset_name: Name of the dataset instance. - config_path: Path to the dataset config YAML file. - **kwargs: Additional keyword arguments passed to BaseDataset. - - Examples: - >>> from pyhealth.datasets import MIMIC3CirculatoryFailureDataset - >>> from pyhealth.tasks import CirculatoryFailurePredictionTask - >>> dataset = MIMIC3CirculatoryFailureDataset( - ... root="/path/to/mimic-iii", - ... ) - >>> task = CirculatoryFailurePredictionTask(prediction_window_hours=12) - >>> sample_dataset = dataset.set_task(task) - """ - - def __init__( - self, - root: str, - tables: Optional[List[str]] = None, - dataset_name: Optional[str] = None, - config_path: Optional[str] = None, - **kwargs, - ) -> None: - """Initializes the MIMIC-III circulatory failure dataset.""" - if config_path is None: - logger.info("No config path provided, using default config") - config_path = Path(__file__).parent / "configs" / "mimic3_cf.yaml" - - default_tables = [ - "patients", - "admissions", - "icustays", - "chartevents", - ] - - if tables is None: - tables = default_tables - else: - tables = list(dict.fromkeys(default_tables + tables)) - - super().__init__( - root=root, - tables=tables, - dataset_name=dataset_name or "mimic3_cf", - config_path=str(config_path), - **kwargs, - ) \ No newline at end of file diff --git a/pyhealth/tasks/circulatory_failure_prediction.py b/pyhealth/tasks/circulatory_failure_prediction.py index f6824fc63..1a73db80f 100644 --- a/pyhealth/tasks/circulatory_failure_prediction.py +++ b/pyhealth/tasks/circulatory_failure_prediction.py @@ -50,10 +50,11 @@ class CirculatoryFailurePredictionTask(BaseTask): output_schema: Output label schema for PyHealth processors. Examples: - >>> from pyhealth.datasets import MIMIC3CirculatoryFailureDataset + >>> from pyhealth.datasets import MIMIC3Dataset >>> from pyhealth.tasks import CirculatoryFailurePredictionTask - >>> dataset = MIMIC3CirculatoryFailureDataset( + >>> dataset = MIMIC3Dataset( ... root="/path/to/mimic-iii", + ... tables=["chartevents"], ... ) >>> task = CirculatoryFailurePredictionTask(prediction_window_hours=12) >>> sample_dataset = dataset.set_task(task) diff --git a/tests/core/test_mimic3_cf.py b/tests/core/test_mimic3_cf.py deleted file mode 100644 index 4979a8ea3..000000000 --- a/tests/core/test_mimic3_cf.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -Unit tests for ``pyhealth.datasets.MIMIC3CirculatoryFailureDataset``. -""" -from pathlib import Path -from pyhealth.datasets import MIMIC3CirculatoryFailureDataset - - -def test_mimic3_cf_dataset_initialization(monkeypatch): - captured = {} - - def fake_base_init( - self, - root, - tables, - dataset_name=None, - config_path=None, - **kwargs, - ): - self.root = root - self.tables = tables - self.dataset_name = dataset_name - self.config_path = config_path - captured["tables"] = tables - captured["dataset_name"] = dataset_name - captured["config_path"] = config_path - - monkeypatch.setattr( - "pyhealth.datasets.base_dataset.BaseDataset.__init__", - fake_base_init, - ) - - dataset = MIMIC3CirculatoryFailureDataset(root="dummy-root") - - assert dataset.dataset_name == "mimic3_cf" - assert "patients" in dataset.tables - assert "admissions" in dataset.tables - assert "icustays" in dataset.tables - assert "chartevents" in dataset.tables - assert Path(dataset.config_path).name == "mimic3_cf.yaml" \ No newline at end of file