diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..1a24e3e8c 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,4 +245,4 @@ Available Datasets datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter - datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.utils \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..e68d80185 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Circulatory Failure Prediction \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst b/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst new file mode 100644 index 000000000..6a9cf0e2c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.circulatory_failure_prediction.rst @@ -0,0 +1,24 @@ +pyhealth.tasks.circulatory_failure_prediction +============================================= + +Overview +-------- + +CirculatoryFailurePredictionTask defines a time-series prediction task for early +detection of circulatory failure. + +The task predicts whether a patient will experience circulatory failure within +the next 12 hours based on physiological measurements. + +Label definition: + +- label = 1 if circulatory failure occurs within the next 12 hours +- label = 0 otherwise + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.CirculatoryFailurePredictionTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/mimic3_cf_circulatory_failure_logreg.py b/examples/mimic3_cf_circulatory_failure_logreg.py new file mode 100644 index 000000000..a3cdf5369 --- /dev/null +++ b/examples/mimic3_cf_circulatory_failure_logreg.py @@ -0,0 +1,156 @@ +""" +Example ablation script for MIMIC-III circulatory failure prediction. + +This script compares different prediction windows (6h, 12h, 24h) and +feature settings using logistic regression. It is intended as an example +usage script for the standard PyHealth dataset → task → SampleDataset pipeline. + +Usage: + python mimic3_cf_circulatory_failure_logreg.py --root /path/to/mimic-iii +""" + +import argparse + +import pandas as pd +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, recall_score, roc_auc_score +from sklearn.model_selection import train_test_split + +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.tasks import CirculatoryFailurePredictionTask + +def samples_to_df(sample_dataset) -> pd.DataFrame: + """Converts a SampleDataset into a pandas DataFrame.""" + rows = [] + for i in range(len(sample_dataset)): + s = sample_dataset[i] + rows.append( + { + "patient_id": s["patient_id"], + "icustay_id": s["icustay_id"], + "gender": s.get("gender"), + "timestamp": s.get("timestamp"), + "map": to_scalar(s["map"]), + "map_diff": to_scalar(s["map_diff"]), + "label": int(to_scalar(s["label"])), + } + ) + return pd.DataFrame(rows) + + +def evaluate_model( + df: pd.DataFrame, + feature_cols: list[str], + balanced: bool = False, +) -> dict: + if df.empty or df["label"].nunique() < 2: + return { + "n_samples": len(df), + "accuracy": None, + "roc_auc": None, + "recall": None, + } + + X = df[feature_cols] + y = df["label"] + + X_train, X_test, y_train, y_test = train_test_split( + X, + y, + test_size=0.2, + random_state=42, + stratify=y, + ) + + model = LogisticRegression( + max_iter=1000, + class_weight="balanced" if balanced else None, + ) + model.fit(X_train, y_train) + + preds = model.predict(X_test) + probs = model.predict_proba(X_test)[:, 1] + + return { + "n_samples": len(df), + "accuracy": accuracy_score(y_test, preds), + "roc_auc": roc_auc_score(y_test, probs), + "recall": recall_score(y_test, preds), + } + + +def print_metrics(title: str, metrics: dict) -> None: + print(f"\n=== {title} ===") + print(f"n_samples: {metrics['n_samples']}") + print(f"accuracy: {metrics['accuracy']}") + print(f"roc_auc: {metrics['roc_auc']}") + print(f"recall: {metrics['recall']}") + +def to_scalar(x): + """Converts scalar tensor-like values to Python scalars.""" + if hasattr(x, "item"): + return x.item() + return x + +def main() -> None: + parser = argparse.ArgumentParser( + description="MIMIC-III circulatory failure prediction ablation study." + ) + parser.add_argument( + "--root", + type=str, + required=True, + help="Path to the unzipped MIMIC-III database directory.", + ) + args = parser.parse_args() + + dataset = MIMIC3Dataset( + root=args.root, + tables=["patients", "admissions", "icustays", "chartevents"], + ) + + # Task ablation: prediction windows + for window in [6, 12, 24]: + print(f"\n############################") + print(f"Prediction window = {window}h") + print(f"############################") + + task = CirculatoryFailurePredictionTask(prediction_window_hours=window) + sample_dataset = dataset.set_task(task) + df = samples_to_df(sample_dataset) + + print("\nSample preview:") + print(df.head()) + + # Baseline setting + baseline_metrics = evaluate_model( + df=df, + feature_cols=["map"], + balanced=False, + ) + print_metrics("Baseline: LogisticRegression(map)", baseline_metrics) + + # Advanced setting + advanced_metrics = evaluate_model( + df=df, + feature_cols=["map", "map_diff"], + balanced=True, + ) + print_metrics( + "Advanced: LogisticRegression(map + map_diff, balanced)", + advanced_metrics, + ) + + # Subgroup fairness + for gender in ["M", "F"]: + subgroup_df = df[df["gender"] == gender].copy() + subgroup_metrics = evaluate_model( + df=subgroup_df, + feature_cols=["map", "map_diff"], + balanced=True, + ) + print_metrics(f"Advanced subgroup gender={gender}", subgroup_metrics) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..01b1ee6d8 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,4 +90,4 @@ def __init__(self, *args, **kwargs): load_processors, save_processors, ) -from .collate import collate_temporal +from .collate import collate_temporal \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..a883adfac 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -67,3 +67,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .circulatory_failure_prediction import CirculatoryFailurePredictionTask \ No newline at end of file diff --git a/pyhealth/tasks/circulatory_failure_prediction.py b/pyhealth/tasks/circulatory_failure_prediction.py new file mode 100644 index 000000000..1a73db80f --- /dev/null +++ b/pyhealth/tasks/circulatory_failure_prediction.py @@ -0,0 +1,236 @@ +""" +Circulatory Failure Prediction Task for PyHealth. + +Dataset: + MIMIC-III Clinical Database v1.4 + https://physionet.org/content/mimiciii/1.4/ + +Inspired by: + Hoche, M., Mineeva, O., Burger, M., Blasimme, A., & Ratsch, G. (2024). + FAMEWS: A fairness auditing tool for medical early-warning systems. + Proceedings of the Fifth Conference on Health, Inference, and Learning, 248, 297–311. PMLR. + https://proceedings.mlr.press/v248/hoche24a.html + +Description: + Time-point prediction task for circulatory failure early warning. + For each MAP measurement at time *t*, the sample is labelled positive + if the first circulatory failure event occurs within the future + prediction window. Circulatory failure is approximated by a mean + arterial pressure (MAP) value below 65 mmHg. + +Authors: + Kuang-Yu Wang (kuangyu4@illinois.edu) + Ya Hsuan Yang (yhyang3@illinois.edu) +""" + +from datetime import timedelta +from typing import Any, Dict, List, Optional +import pandas as pd +from pyhealth.data import Patient +from pyhealth.tasks.base_task import BaseTask + + +MAP_ITEMID = 220052 +MAP_FAILURE_THRESHOLD = 65.0 + + +class CirculatoryFailurePredictionTask(BaseTask): + """Early-warning task for circulatory failure prediction. + + This task converts a PyHealth Patient object into time-point prediction + samples for circulatory failure early warning. For each MAP measurement at + time t, the sample is labeled positive if the first circulatory failure + event occurs within the future prediction window. + Circulatory failure is approximated by a mean arterial pressure (MAP) + value below 65 mmHg. + + Attributes: + task_name: Unique task identifier. + input_schema: Input feature schema for PyHealth processors. + output_schema: Output label schema for PyHealth processors. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import CirculatoryFailurePredictionTask + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii", + ... tables=["chartevents"], + ... ) + >>> task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + >>> sample_dataset = dataset.set_task(task) + >>> sample_dataset[0] # doctest: +SKIP + """ + + input_schema: Dict[str, str] = { + "map": "tensor", + "map_diff": "tensor" + } + + output_schema: Dict[str, str] = { + "label": "binary" + } + + task_name: str = "circulatory_failure_prediction" + + def __init__( + self, + prediction_window_hours: int = 12, + map_itemid: int = MAP_ITEMID, + failure_threshold: float = MAP_FAILURE_THRESHOLD, + **kwargs: Any, + ) -> None: + """Initializes the circulatory failure prediction task. + + Args: + prediction_window_hours: Future prediction window in hours. + map_itemid: MIMIC-III ITEMID corresponding to MAP. + failure_threshold: MAP threshold used to define circulatory failure. + **kwargs: Additional keyword arguments passed to BaseTask. + """ + super().__init__(**kwargs) + self.prediction_window_hours = prediction_window_hours + self.map_itemid = map_itemid + self.failure_threshold = failure_threshold + self.task_name = ( + f"circulatory_failure_prediction_{prediction_window_hours}h" + ) + + @staticmethod + def _to_datetime(value: Any) -> Optional[pd.Timestamp]: + """Converts a timestamp-like value to pandas Timestamp.""" + if value is None or pd.isna(value): + return None + return pd.to_datetime(value, errors="coerce") + + @staticmethod + def _to_float(value: Any) -> Optional[float]: + """Converts a numeric-like value to float.""" + if value is None or pd.isna(value): + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _event_attr(event: Any, attr: str, default: Any = None) -> Any: + """Gets an event attribute from either object attr or attr_dict.""" + if hasattr(event, attr): + return getattr(event, attr) + + attr_dict = getattr(event, "attr_dict", None) + if isinstance(attr_dict, dict) and attr in attr_dict: + return attr_dict[attr] + + return default + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Converts one Patient object into prediction samples. + + Args: + patient: A PyHealth Patient object. The patient should contain + `icustays`, `patients`, and `chartevents` events loaded from + MIMIC-III. + + Returns: + A list of sample dictionaries. Each sample contains patient/visit + metadata, MAP-based features, and a binary early-warning label. + Returning an empty list is valid when no usable MAP data or no + failure event is found. + """ + samples: List[Dict[str, Any]] = [] + prediction_window = timedelta(hours=self.prediction_window_hours) + + patient_events = patient.get_events(event_type="patients") + gender = None + if len(patient_events) > 0: + gender = self._event_attr(patient_events[0], "gender") + + icu_stays = patient.get_events(event_type="icustays") + chartevents = patient.get_events(event_type="chartevents") + + if len(icu_stays) == 0 or len(chartevents) == 0: + return [] + + for icu_stay in icu_stays: + icustay_id = self._event_attr(icu_stay, "icustay_id") + intime = self._to_datetime(self._event_attr(icu_stay, "intime")) + outtime = self._to_datetime(self._event_attr(icu_stay, "outtime")) + hadm_id = self._event_attr(icu_stay, "hadm_id") + + if icustay_id is None or intime is None or outtime is None: + continue + + map_events = [] + for event in chartevents: + event_icustay_id = self._event_attr(event, "icustay_id") + itemid = self._event_attr(event, "itemid") + charttime = self._to_datetime(event.timestamp) + valuenum = self._to_float(self._event_attr(event, "valuenum")) + + if event_icustay_id != icustay_id: + continue + if itemid != self.map_itemid: + continue + if charttime is None or pd.isna(charttime): + continue + if valuenum is None: + continue + if not (intime <= charttime <= outtime): + continue + + map_events.append( + { + "charttime": charttime, + "map": valuenum, + } + ) + + if not map_events: + continue + + map_events = sorted(map_events, key=lambda x: x["charttime"]) + + failure_times = [ + row["charttime"] + for row in map_events + if row["map"] < self.failure_threshold + ] + + if not failure_times: + continue + + first_failure_time = min(failure_times) + + previous_map = None + for row in map_events: + timestamp = row["charttime"] + map_value = row["map"] + + if previous_map is None: + map_diff = 0.0 + else: + map_diff = map_value - previous_map + + previous_map = map_value + + label = int( + timestamp < first_failure_time + <= timestamp + prediction_window + ) + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": str(icustay_id), + "hadm_id": hadm_id, + "icustay_id": icustay_id, + "gender": gender, + "timestamp": timestamp, + "map": map_value, + "map_diff": map_diff, + "label": label, + } + ) + + return samples \ No newline at end of file diff --git a/tests/core/test_circulatory_failure_prediction.py b/tests/core/test_circulatory_failure_prediction.py new file mode 100644 index 000000000..73c1ebfac --- /dev/null +++ b/tests/core/test_circulatory_failure_prediction.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass +from datetime import datetime + +from pyhealth.tasks import CirculatoryFailurePredictionTask + + +@dataclass +class FakeEvent: + event_type: str + timestamp: datetime | None = None + attr_dict: dict | None = None + + def __post_init__(self): + if self.attr_dict is None: + self.attr_dict = {} + + def __getattr__(self, name): + if name in self.attr_dict: + return self.attr_dict[name] + raise AttributeError(name) + + +class FakePatient: + patient_id = "1" + + def __init__(self): + self.events = { + "patients": [ + FakeEvent( + event_type="patients", + attr_dict={"gender": "F"}, + ) + ], + "icustays": [ + FakeEvent( + event_type="icustays", + timestamp=datetime(2150, 1, 1, 0, 0, 0), + attr_dict={ + "hadm_id": 100, + "icustay_id": 1001, + "intime": "2150-01-01 00:00:00", + "outtime": "2150-01-02 00:00:00", + }, + ) + ], + "chartevents": [ + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 0, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 80.0, + }, + ), + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 1, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 78.0, + }, + ), + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 10, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 70.0, + }, + ), + FakeEvent( + event_type="chartevents", + timestamp=datetime(2150, 1, 1, 11, 0, 0), + attr_dict={ + "icustay_id": 1001, + "itemid": 220052, + "valuenum": 60.0, + }, + ), + ], + } + + def get_events(self, event_type=None, *args, **kwargs): + if event_type is None: + all_events = [] + for events in self.events.values(): + all_events.extend(events) + return all_events + return self.events.get(event_type, []) + + +def test_circulatory_failure_task_basic(): + task = CirculatoryFailurePredictionTask(prediction_window_hours=12) + patient = FakePatient() + + samples = task(patient) + + assert len(samples) == 4 + assert samples[0]["label"] == 1 + assert samples[1]["label"] == 1 + assert samples[2]["label"] == 1 + assert samples[3]["label"] == 0 + assert samples[0]["map"] == 80.0 + assert samples[1]["map_diff"] == -2.0 + assert samples[0]["gender"] == "F" + assert samples[0]["patient_id"] == "1" + assert samples[0]["visit_id"] == "1001" \ No newline at end of file