-
Notifications
You must be signed in to change notification settings - Fork 773
Add ADNI dataset and Alzheimer's Disease classification pipeline #921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| pyhealth.datasets.ADNIDataset | ||
| =================================== | ||
|
|
||
| The Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset. A Pyhealth dataset consisting of labelled MRI brain scan images. For more information, and to apply for access, visit `https://adni.loni.usc.edu/data-samples/adni-data/`. | ||
|
|
||
| .. autoclass:: pyhealth.datasets.ADNIDataset | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May as well clean up this trailing whitespace at the end of this file (even though it's just copy/paste from some of the existing docs). |
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| pyhealth.models.AlzheimersDiseaseCNN | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a very generic name, which makes me think it might conflict with future PyHealth contributions. Unfortunately, the paper this model comes from doesn't really provide a unique name. Any ideas for better names? My first thought was |
||
| =================================== | ||
|
|
||
| Pyhealth model to detect Alzheimer's Disease using ADNI datasets, based on the model described in "On the Design of Convolutional Neural Networks for Automatic Detection of Alzheimer's Disease" by Liu et al. (`https://arxiv.org/abs/1911.03740`). | ||
|
|
||
| .. autoclass:: pyhealth.models.AlzheimersDiseaseCNN | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,7 @@ Available Processors | |
| - ``TimeImageProcessor``: For time-stamped image sequences (e.g., serial X-rays) | ||
| - ``TensorProcessor``: For pre-processed tensor data | ||
| - ``RawProcessor``: Pass-through processor for raw data | ||
| - ``NIftIImageProcessor``: For NIftI MRI images | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This probably belongs in the "Specialized Processors" section. |
||
|
|
||
| **Specialized Processors:** | ||
|
|
||
|
|
@@ -496,4 +497,5 @@ API Reference | |
| processors/pyhealth.processors.MultiHotProcessor | ||
| processors/pyhealth.processors.StageNetProcessor | ||
| processors/pyhealth.processors.StageNetTensorProcessor | ||
| processors/pyhealth.processors.GraphProcessor | ||
| processors/pyhealth.processors.GraphProcessor | ||
| processors/pyhealth.processors.NIftIImageProcessor | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| pyhealth.processors.NIftIImageProcessor | ||
| =================================== | ||
|
|
||
| Processor for Neuroimaging Informatics Technology Initiative (NIftI) images. | ||
|
|
||
| .. autoclass:: pyhealth.processors.NIftIImageProcessor | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| pyhealth.tasks.AlzheimersDiseaseClassification | ||
| ======================================= | ||
|
|
||
| .. autoclass:: pyhealth.tasks.base_task.AlzheimersDiseaseClassification | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| """Classification of Alzheimer's Disease using the ADNI Dataset. | ||
|
|
||
| Author: Bryan Lau (bryan16@illinois.edu) | ||
|
|
||
| This script executes the Pyhealth pipeline implementation of the method | ||
| described in the paper "On the Design of Convolutional Neural Networks | ||
| for Automatic Detection of Alzheimer's Disease" by Liu et al. | ||
| (https://arxiv.org/abs/1911.03740). | ||
|
|
||
| The pipeline consists of: | ||
|
|
||
| - ADNIDataset: | ||
| Pyhealth-compatible dataset for ADNI data | ||
|
|
||
| - AlzheimersDiseaseClassification: | ||
| Task to present the necessary features | ||
|
|
||
| - NIftIImageProcessor: | ||
| Image processor to load and process NIftI files | ||
|
|
||
| - AlzheimersDiseaseCNN: | ||
| Model that replicates the structure described by Liu et al. | ||
|
|
||
| This script has been tested against 3182 real MRI brain scan image files | ||
| downloaded from the ADNI dataset, hosted at the Image & Data Archive (IDA) | ||
| at the Laboratory of Neuro Imaging (LONI). Users may apply for access at: | ||
|
|
||
| https://adni.loni.usc.edu/data-samples/adni-data/ | ||
|
|
||
| The required pre-processing parameters are: | ||
|
|
||
| - Multiplanar reconstruction (MPR) | ||
| - Gradient warping correction (GradWarp) | ||
| - B1 non-uniformity correction | ||
| - N3 intensity normalization | ||
|
|
||
| This implementation aligns with the method described by Liu et al. except | ||
| that the following are omitted for the sake of simplicity: | ||
|
|
||
| - Image pre-processing using Clinica (e.g. Dartel template registration) | ||
| - Training augmentation, i.e. gaussian blurring and random cropping | ||
|
|
||
| """ | ||
| import os | ||
| import shutil | ||
|
|
||
| from pathlib import Path | ||
| from torch.optim import SGD | ||
|
|
||
| from pyhealth.datasets import ADNIDataset, get_dataloader | ||
| from pyhealth.datasets.splitter import split_by_patient # New splitter | ||
| from pyhealth.models import AlzheimersDiseaseCNN | ||
| from pyhealth.tasks import AlzheimersDiseaseClassification | ||
| from pyhealth.trainer import Trainer | ||
|
|
||
| from adni_ad_synthetic_data import create_adni_image | ||
|
|
||
| # Initialization | ||
| BATCH_SIZE = 4 | ||
| EPOCHS = 50 | ||
| PATIENCE = 10 | ||
| LEARNING_RATE = 0.01 | ||
| MOMENTUM = 0.9 | ||
| METRICS = ["balanced_accuracy", "accuracy", "f1_macro", "roc_auc_macro_ovr"] | ||
| MONITOR = "balanced_accuracy" | ||
| NUM_SYNTHETIC_SAMPLES = 30 | ||
| NUM_WORKERS = 4 | ||
| SEED = 99 | ||
|
|
||
| # Path where the ADNI files are located | ||
| ADNI_ROOT = "./adni_root" | ||
| CACHE_DIR = "./cache" | ||
|
|
||
| # Set this flag to: | ||
| # True to generate and use synthetic data | ||
| # False to use real ADNI images that you have downloaded | ||
| USE_SYNTHETIC_DATA = True | ||
|
|
||
| if __name__ == '__main__': | ||
|
|
||
| # Convert paths | ||
| adni_path = Path(ADNI_ROOT) | ||
| cache_path = Path(CACHE_DIR) | ||
|
|
||
| # Generate synthetic ADNI data | ||
| if USE_SYNTHETIC_DATA: | ||
| adni_path = Path("./adni_synthetic") | ||
| cache_path = Path("./cache_synthetic") | ||
|
|
||
| for path in [adni_path, cache_path]: | ||
| if os.path.exists(path): | ||
| shutil.rmtree(path) | ||
| os.makedirs(path, exist_ok=True) | ||
|
|
||
| for i in range(NUM_SYNTHETIC_SAMPLES): | ||
| subject_id = f"002_S_{i:04d}" | ||
| if i < int(NUM_SYNTHETIC_SAMPLES * 0.33): | ||
| group = "CN" | ||
| elif i < int(NUM_SYNTHETIC_SAMPLES * 0.66): | ||
| group = "MCI" | ||
| else: | ||
| group = "AD" | ||
| print(f"Generating synthetic image for {subject_id} ({group})") | ||
| create_adni_image(adni_path, subject_id, group) | ||
|
|
||
| # Instantiate base ADNI dataset | ||
| adni_dataset = ADNIDataset(root=str(adni_path), cache_dir=str(cache_path), dev=False, num_workers=NUM_WORKERS) | ||
| adni_dataset.stats() | ||
|
|
||
| # Set task and obtain samples | ||
| adni_task = AlzheimersDiseaseClassification() | ||
| sample_dataset = adni_dataset.set_task(adni_task) | ||
|
|
||
| # Split data by patient into train/val/test (70/15/15) | ||
| split_ratios = [0.7, 0.15, 0.15] | ||
| train_data, val_data, test_data = split_by_patient( | ||
| sample_dataset, ratios=split_ratios, seed=SEED) | ||
|
|
||
| # Create dataloaders | ||
| train_loader = get_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True) | ||
| val_loader = get_dataloader(val_data, batch_size=BATCH_SIZE, shuffle=False) | ||
| test_loader = get_dataloader(test_data, batch_size=BATCH_SIZE, shuffle=False) | ||
|
|
||
| # Instantiate model using samples | ||
| model = AlzheimersDiseaseCNN( | ||
| dataset=sample_dataset, | ||
| width_factor=4, use_age=True, use_gender=True, norm_method="instance" | ||
| ) | ||
|
|
||
| # Instantiate trainer | ||
| trainer = Trainer( | ||
| model=model, | ||
| metrics=METRICS, | ||
| output_path="./output" | ||
| ) | ||
|
|
||
| # Train the model | ||
| trainer.train( | ||
| train_dataloader=train_loader, | ||
| val_dataloader=val_loader, | ||
| test_dataloader=test_loader, | ||
| epochs=EPOCHS, | ||
| optimizer_class=SGD, | ||
| optimizer_params={"lr": LEARNING_RATE, "momentum": MOMENTUM, "weight_decay": 1e-3}, | ||
| max_grad_norm=1.0, | ||
| monitor=MONITOR, | ||
| monitor_criterion="max", | ||
| patience=PATIENCE, | ||
| load_best_model_at_last=True, | ||
| ) | ||
|
|
||
| # Evaluate | ||
| scores = trainer.evaluate(test_loader) | ||
| print(f"\nTest scores: {scores}") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| """Synthetic ADNI data helper. | ||
|
|
||
| Author: Bryan Lau (bryan16@illinois.edu) | ||
| Description: | ||
| This helper function creates an ADNI image record along with the required | ||
| directory structure and metadata files. | ||
| """ | ||
| import nibabel as nib | ||
| import numpy as np | ||
| import random | ||
|
|
||
| def create_adni_image(adni_root_path, subject_id=None, group=None): | ||
| """Create test ADNI directory structure populated with synthetic data. | ||
|
|
||
| Creates a directory structure for one subject, with the same layout as | ||
| the one obtained by downloading actual ADNI data files. | ||
|
|
||
| The directory structure has the following layout: | ||
|
|
||
| - root | ||
| - subject id | ||
| - pre-processing transform | ||
| - date acquired | ||
| - image uid | ||
| MRI image file | ||
| metadata xml file | ||
|
|
||
| Args: | ||
| adni_root_path: Path in which to create the ADNI directory structure | ||
| subject_id: Subject ID for this directory structure, if None then a | ||
| random Subject ID will be generated instead. | ||
| group: Label to assign to this subject, if None then a label will be | ||
| randomly selected from the three valid choices (i.e. MCI, CN, AD). | ||
|
|
||
| Returns: | ||
| Dictionary containing the following values for later comparison: | ||
| - subject_id: Subject ID of the patient. | ||
| - group: Label assigned to the patient. | ||
| - gender: Patient's randomly selected gender. | ||
| - age: Patient's randomly selected age. | ||
| - weight: Patient's randomly selected weight. | ||
| - image_uid: Unique ID of the MRI image. | ||
| - image_path: Path to the MRI image file. | ||
| """ | ||
|
|
||
| if not subject_id: | ||
| subject_id = f"002_S_{random.randint(0, 9999):04d}" | ||
| if not group: | ||
| group = random.choice(["CN", "MCI", "AD"]) | ||
| gender = random.choice(["M", "F"]) | ||
| age = round(random.uniform(40.0000, 85.0000), 4) | ||
| weight = round(random.uniform(55.0, 120.0), 1) | ||
| date_acquired = f"{random.randint(1950, 2000)}-03-15" | ||
|
|
||
| xform_str = "MPR__GradWarp__B1_Correction__N3" | ||
| date_dir = f"{date_acquired}_09_45_30.0" | ||
| series_id = f"{random.randint(0, 99999):05d}" | ||
| image_uid = f"{random.randint(0, 99999):05d}" | ||
|
|
||
| # Create MRI image directory structure | ||
| adni_image_dir = adni_root_path / subject_id / \ | ||
| xform_str / date_dir / f"I{image_uid}" | ||
| adni_image_dir.mkdir(parents=True) | ||
|
|
||
| # Generate test image filename and data | ||
| file_date_str = f"{date_acquired.replace("-", "")}{random.randint(100000000, 300000000):9d}" | ||
| image_filepath = adni_image_dir / \ | ||
| f"ADNI_{subject_id}_MR_{xform_str}_Br_{file_date_str}_S{series_id}_I{image_uid}.nii" | ||
| image_data = np.random.rand(121, 145, 121).astype(np.float32) | ||
|
|
||
| # Generate group marking (to simulate image features) | ||
| mark_value = 10.0 | ||
| if group == "CN": | ||
| image_data[10:15, 10:15, 10:15] = mark_value | ||
| elif group == "MCI": | ||
| image_data[60:65, 60:65, 60:65] = mark_value | ||
| elif group == "AD": | ||
| image_data[100:105, 100:105, 100:105] = mark_value | ||
|
|
||
| # Save the image | ||
| mri_image = nib.Nifti1Image(image_data, affine=np.eye(4)) | ||
| nib.save(mri_image, image_filepath) | ||
|
|
||
| # Generate metadata xml | ||
| metadata_xml = f"""<?xml version="1.0" encoding="UTF-8"?> | ||
| <idaxs> | ||
| <project> | ||
| <projectIdentifier>ADNI</projectIdentifier> | ||
| <siteKey>002</siteKey> | ||
| <subject> | ||
| <subjectIdentifier>{subject_id}</subjectIdentifier> | ||
| <researchGroup>{group}</researchGroup> | ||
| <subjectSex>{gender}</subjectSex> | ||
| <study> | ||
| <subjectAge>{age}</subjectAge> | ||
| <weightKg>{weight}</weightKg> | ||
| <series> | ||
| <seriesIdentifier>{series_id}</seriesIdentifier> | ||
| <dateAcquired>{date_acquired}</dateAcquired> | ||
| <seriesLevelMeta> | ||
| <derivedProduct> | ||
| <imageUID>{image_uid}</imageUID> | ||
| </derivedProduct> | ||
| </seriesLevelMeta> | ||
| </series> | ||
| </study> | ||
| </subject> | ||
| </project> | ||
| </idaxs> | ||
| """ | ||
| metadata_xml_filename = f"ADNI_{subject_id}_{xform_str}_S{series_id}_I{image_uid}.xml" | ||
| metadata_xml_path = adni_root_path / metadata_xml_filename | ||
| with open(metadata_xml_path, "w", encoding="utf-8") as f: | ||
| f.write(metadata_xml) | ||
|
|
||
| # Return test values for later comparison | ||
| return { | ||
| "subject_id": subject_id, | ||
| "group": group, | ||
| "gender": gender, | ||
| "age": age, | ||
| "weight": weight, | ||
| "image_uid": image_uid, | ||
| "image_path": image_filepath, | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty sure this won't create clickable links (here and in
pyhealth.models.AlzheimersDiseaseCNN.rst). See the existing docs for how to create clickable links in RST files (and don't forget the trailing underscore).