From 6571cfbe6277d6016794e8782fd9655de004f3e9 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 30 Jan 2026 16:52:13 +0100
Subject: [PATCH 1/9] changes moved from https://github.com/ChEB-AI/python-chebai/pull/135

---
 README.md                                |  16 +-
 chebai/cli.py                            |   7 +-
 chebai/models/base.py                    |  30 +++-
 chebai/preprocessing/datasets/base.py    | 100 +++++++++--
 chebai/result/prediction.py              | 214 +++++++++++++++++++++++
 chebai/trainer/CustomTrainer.py          |  79 ++-------
 tests/unit/cli/classification_labels.txt |  10 ++
 tests/unit/cli/mock_dm.py                |   6 +
 tests/unit/cli/testCLI.py                |   2 +-
 9 files changed, 371 insertions(+), 93 deletions(-)
 create mode 100644 chebai/result/prediction.py
 create mode 100644 tests/unit/cli/classification_labels.txt

diff --git a/README.md b/README.md
index 7672bc28..d9145cf3 100644
--- a/README.md
+++ b/README.md
@@ -78,11 +78,19 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
 ### Predicting classes given SMILES strings
 ```
-python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
 ```
-The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
-one row for each SMILES string and one column for each class.
-The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
+
+* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
+
+* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
+
+* **`--save_to`** *(optional)*: Path to save the predictions as a CSV file. The CSV will contain one row per SMILES string and one column per predicted class. By default, the file is saved as `predictions.csv` in the current working directory.
+
+* **`--classes_path`** *(optional)*: Path to the dataset's `classes.txt` file, which maps model output indices to ChEBI IDs.
+  * Checkpoints created after PR #135 have the classification labels stored in them, so this parameter is not required for them.
+  * If provided, the CSV columns will be named using the ChEBI IDs.
+  * If omitted, the script will try to locate the file automatically; if it cannot be found, the columns will be numbered sequentially.
 
 ## Evaluation
diff --git a/chebai/cli.py b/chebai/cli.py
index 1aaba53c..8b51e45f 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
             apply_on="instantiate",
         )
 
+        parser.link_arguments(
+            "data.classes_txt_file_path",
+            "model.init_args.classes_txt_file_path",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in (
                 "micro-f1",
@@ -112,7 +118,6 @@ def subcommands() -> Dict[str, Set[str]]:
             "validate": {"model", "dataloaders", "datamodule"},
             "test": {"model", "dataloaders", "datamodule"},
             "predict": {"model", "dataloaders", "datamodule"},
-            "predict_from_file": {"model"},
         }
diff --git a/chebai/models/base.py b/chebai/models/base.py
index 82d84033..6c2daa41 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -40,6 +40,7 @@ def __init__(
         pass_loss_kwargs: bool = True,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
+        classes_txt_file_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -47,8 +48,8 @@ def __init__(
         if exclude_hyperparameter_logging is None:
             exclude_hyperparameter_logging = tuple()
         self.criterion = criterion
-        assert out_dim is not None, "out_dim must be specified"
-        assert input_dim is not None, "input_dim must be specified"
+        assert out_dim is not None and out_dim > 0, "out_dim must be specified"
+        assert input_dim is not None and input_dim > 0, "input_dim must be specified"
         self.out_dim = out_dim
         self.input_dim = input_dim
         print(
@@ -77,6 +78,17 @@ def __init__(
         self.validation_metrics = val_metrics
         self.test_metrics = test_metrics
         self.pass_loss_kwargs = pass_loss_kwargs
+        with open(classes_txt_file_path, "r") as f:
+            self.labels_list = [cls.strip() for cls in f.readlines()]
+        assert len(self.labels_list) > 0, "Class labels list is empty."
+        assert len(self.labels_list) == out_dim, (
+            f"Number of class labels ({len(self.labels_list)}) does not match "
+            f"the model output dimension ({out_dim})."
+        )
+
+    def on_save_checkpoint(self, checkpoint):
+        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
+        checkpoint["classification_labels"] = self.labels_list
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -100,7 +112,7 @@ def __init_subclass__(cls, **kwargs):
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
-    ) -> (torch.Tensor, torch.Tensor):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the predictions and labels from the model output.
@@ -151,7 +163,7 @@ def _process_for_loss(
         model_output: torch.Tensor,
         labels: torch.Tensor,
         loss_kwargs: Dict[str, Any],
-    ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
+    ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
         """
         Processes the data for loss computation.
@@ -237,7 +249,15 @@ def predict_step(
 
         Returns:
             Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
""" - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + model_output = self(data, **data.get("model_kwargs", dict())) + + # Dummy labels to avoid errors in _get_prediction_and_labels + labels = torch.zeros((len(batch), self.out_dim)).to(self.device) + pr, _ = self._get_prediction_and_labels(data, labels, model_output) + return pr def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 9ff40748..df3dab6a 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] - # filter for missing features in resulting data, keep features length below token limit - data = [ - val - for val in data - if val["features"] is not None - and ( - self.n_token_limit is None or len(val["features"]) <= self.n_token_limit - ) - ] + data = [val for val in data if self._filter_to_token_limit(val)] return data + def _filter_to_token_limit(self, data_instance: dict) -> bool: + # filter for missing features in resulting data, keep features length below token limit + if data_instance["features"] is not None and ( + self.n_token_limit is None + or len(data_instance["features"]) <= self.n_token_limit + ): + return True + return False + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -401,22 +402,77 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs - ) -> Union[DataLoader, List[DataLoader]]: + self, + smiles_list: List[str], + model_hparams: Optional[dict] = None, + **kwargs, + ) -> tuple[DataLoader, list[int]]: """ Returns the predict DataLoader. Args: - *args: Additional positional arguments (unused). + smiles_list (List[str]): List of SMILES strings to predict. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. **kwargs: Additional keyword arguments, passed to dataloader(). Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. + tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data, valid_indices = self._process_input_for_prediction( + smiles_list, model_hparams + ) + return ( + DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ), + valid_indices, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: Optional[dict] = None + ) -> tuple[list, list]: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. + + Returns: + tuple[list, list]: Processed input data and valid indices. 
+ """ + data, valid_indices = [], [] + for idx, smiles in enumerate(smiles_list): + result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) + if result is None or result["features"] is None: + continue + if not self._filter_to_token_limit(result): + continue + data.append(result) + valid_indices.append(idx) + + return data, valid_indices + + def _preprocess_smiles_for_pred( + self, idx, smiles: str, model_hparams: Optional[dict] = None + ) -> dict: + """Preprocess prediction data.""" + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. + return self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: @@ -563,6 +619,19 @@ def raw_file_names_dict(self) -> dict: """ raise NotImplementedError + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in following places: + # - results/prediction.py: to load class names for csv columns names + # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` + return os.path.join(self.processed_dir_main, "classes.txt") + class MergedDataset(XYBaseDataModule): MERGED = [] @@ -1189,7 +1258,8 @@ def _retrieve_splits_from_csv(self) -> None: print(f"Applying label filter from {self.apply_label_filter}...") with open(self.apply_label_filter, "r") as f: label_filter = [line.strip() for line in f] - with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + + with open(self.classes_txt_file_path, "r") as cf: classes = [line.strip() for line in cf] # reorder labels old_labels = np.stack(df_data["labels"]) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..fda8a308 --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,214 @@ +import os +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__( + self, + checkpoint_path: _PATH, + batch_size: Optional[int] = None, + compile_model: bool = True, + ): + """Initializes the Predictor with a model loaded from the checkpoint. + + Args: + checkpoint_path: Path to the model checkpoint. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. + compile_model: Whether to compile the model using torch.compile. Default is True. 
+ """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + assert ( + "_class_path" in ckpt_file["datamodule_hyper_parameters"] + and "_class_path" in ckpt_file["hyper_parameters"] + ), ( + "Datamodule and Model hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) + + print("-" * 50) + print(f"For Loaded checkpoint from: {checkpoint_path}") + print("Below are the modules loaded from the checkpoint:") + + self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] + self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("augment_smiles", None) + self._dm_hparams.pop("aug_smiles_variations", None) + self._dm_hparams.pop("_instantiator", None) + self._dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, self._dm_hparams + ) + if batch_size is not None and int(batch_size) > 0: + self._dm.batch_size = int(batch_size) + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") + + self._model_hparams = ckpt_file["hyper_parameters"] + self._model_hparams.pop("_instantiator", None) + self._model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, self._model_hparams + ) + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + + self._classification_labels: list | None = ckpt_file.get( + "classification_labels", None + ) + if self._classification_labels is not None: + print(f"Loaded {len(self._classification_labels)} classification labels.") + assert len(self._classification_labels) > 0, ( + "Classification labels list is empty." + ) + assert len(self._classification_labels) == self._model.out_dim, ( + f"Number of class labels ({len(self._classification_labels)}) does not match " + f"the model output dimension ({self._model.out_dim})." + ) + + if compile_model: + self._model = torch.compile(self._model) + self._model.eval() + print("-" * 50) + + def predict_from_file( + self, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + smiles_file_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + classes_path: Optional path to a file containing class names: + if no class names are provided, code will try to get the class path + from the datamodule, else the columns will be numbered. + """ + with open(smiles_file_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + CLASS_LABELS: list | None = None + + def _add_class_columns(class_file_path: _PATH) -> list[str]: + with open(class_file_path, "r") as f: + return [cls.strip() for cls in f.readlines()] + + if self._classification_labels is not None: + CLASS_LABELS = self._classification_labels + # --- For old checkpoints that do not have classification_labels saved --- + elif classes_path is not None: + CLASS_LABELS = _add_class_columns(classes_path) + elif os.path.exists(self._dm.classes_txt_file_path): + CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path) + + preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings) + if all(pred is None for pred in preds): + print("No valid predictions were made. 
+            return
+
+        # --- Logic for old checkpoints that do not have classification_labels saved ---
+        if CLASS_LABELS is not None and self._model.out_dim is not None:
+            assert len(CLASS_LABELS) > 0, "Class labels list is empty."
+            assert len(CLASS_LABELS) == self._model.out_dim, (
+                f"Number of class labels ({len(CLASS_LABELS)}) does not match "
+                f"the model output dimension ({self._model.out_dim})."
+            )
+            num_of_cols = len(CLASS_LABELS)
+        elif CLASS_LABELS is not None:
+            assert len(CLASS_LABELS) > 0, "Class labels list is empty."
+            num_of_cols = len(CLASS_LABELS)
+        elif self._model.out_dim is not None:
+            num_of_cols = self._model.out_dim
+        else:
+            # find first non-None tensor to determine width
+            num_of_cols = next(x.numel() for x in preds if x is not None)
+            CLASS_LABELS = [f"class_{i}" for i in range(num_of_cols)]
+
+        rows = [
+            pred.tolist() if pred is not None else [None] * num_of_cols
+            for pred in preds
+        ]
+        predictions_df = pd.DataFrame(rows, columns=CLASS_LABELS, index=smiles_strings)
+
+        predictions_df.to_csv(save_to)
+        print(f"Predictions saved to: {save_to}")
+
+    @torch.inference_mode()
+    def predict_smiles(
+        self,
+        smiles: List[str],
+    ) -> list[torch.Tensor | None]:
+        """
+        Predicts the output for a list of SMILES strings using the model.
+
+        Args:
+            smiles: A list of SMILES strings.
+
+        Returns:
+            A list containing one prediction tensor per SMILES string
+            (None for strings that could not be processed).
+        """
+        # For certain data prediction pipelines, we may need model hyperparameters
+        pred_dl, valid_indices = self._dm.predict_dataloader(
+            smiles_list=smiles, model_hparams=self._model_hparams
+        )
+
+        preds = []
+        for batch_idx, batch in enumerate(pred_dl):
+            # For certain model prediction pipelines, we may need data module hyperparameters
+            preds.append(
+                self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
+            )
+        preds = torch.cat(preds)
+
+        # Initialize output with None
+        output: list[torch.Tensor | None] = [None] * len(smiles)
+
+        # Scatter predictions back
+        for pred, idx in zip(preds, valid_indices):
+            output[idx] = pred
+
+        return output
+
+
+class MainPredictor:
+    @staticmethod
+    def predict_from_file(
+        checkpoint_path: _PATH,
+        smiles_file_path: _PATH,
+        save_to: _PATH = "predictions.csv",
+        classes_path: Optional[_PATH] = None,
+        batch_size: Optional[int] = None,
+    ) -> None:
+        predictor = Predictor(checkpoint_path, batch_size)
+        predictor.predict_from_file(
+            smiles_file_path,
+            save_to,
+            classes_path,
+        )
+
+    @staticmethod
+    def predict_smiles(
+        checkpoint_path: _PATH,
+        smiles: List[str],
+        batch_size: Optional[int] = None,
+    ) -> torch.Tensor:
+        predictor = Predictor(checkpoint_path, batch_size)
+        return predictor.predict_smiles(smiles)
+
+
+if __name__ == "__main__":
+    # python chebai/result/prediction.py predict_from_file --help
+    # python chebai/result/prediction.py predict_smiles --help
+    CLI(MainPredictor, as_positional=False)
diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index 5c960007..b9e4b0f3 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -1,18 +1,13 @@
 import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional, Tuple
 
-import pandas as pd
-import torch
-from lightning import LightningModule, Trainer
+from lightning import Trainer
 from lightning.fabric.utilities.data import _set_sampler_epoch
-from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.loops.fit_loop import _FitLoop
 from lightning.pytorch.trainer import call
-from torch.nn.utils.rnn import pad_sequence
 
 from chebai.loggers.custom import CustomLogger
-from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
 
 log = logging.getLogger(__name__)
@@ -74,68 +69,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
         else:
             return key, value
 
-    def predict_from_file(
+    def predict(
         self,
-        model: LightningModule,
-        checkpoint_path: _PATH,
-        input_path: _PATH,
-        save_to: _PATH = "predictions.csv",
-        classes_path: Optional[_PATH] = None,
-        **kwargs,
-    ) -> None:
-        """
-        Loads a model from a checkpoint and makes predictions on input data from a file.
-
-        Args:
-            model: The model to use for predictions.
-            checkpoint_path: Path to the model checkpoint.
-            input_path: Path to the input file containing SMILES strings.
-            save_to: Path to save the predictions CSV file.
-            classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered).
-        """
-        loaded_model = model.__class__.load_from_checkpoint(checkpoint_path)
-        with open(input_path, "r") as input:
-            smiles_strings = [inp.strip() for inp in input.readlines()]
-        loaded_model.eval()
-        predictions = self._predict_smiles(loaded_model, smiles_strings)
-        predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
-        if classes_path is not None:
-            with open(classes_path, "r") as f:
-                predictions_df.columns = [cls.strip() for cls in f.readlines()]
-        predictions_df.index = smiles_strings
-        predictions_df.to_csv(save_to)
-
-    def _predict_smiles(
-        self, model: LightningModule, smiles: List[str]
-    ) -> torch.Tensor:
-        """
-        Predicts the output for a list of SMILES strings using the model.
-
-        Args:
-            model: The model to use for predictions.
-            smiles: A list of SMILES strings.
-
-        Returns:
-            A tensor containing the predictions.
-        """
-        reader = ChemDataReader()
-        parsed_smiles = [reader._read_data(s) for s in smiles]
-        x = pad_sequence(
-            [torch.tensor(a, device=model.device) for a in parsed_smiles],
-            batch_first=True,
+        model=None,
+        dataloaders=None,
+        datamodule=None,
+        return_predictions=None,
+        ckpt_path=None,
+    ):
+        raise NotImplementedError(
+            "CustomTrainer.predict is not implemented. "
+            "Use `MainPredictor.predict_from_file` or `MainPredictor.predict_smiles` from `chebai/result/prediction.py` instead."
         )
-        cls_tokens = (
-            torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1)
-            * CLS_TOKEN
-        )
-        features = torch.cat((cls_tokens, x), dim=1)
-        model_output = model({"features": features})
-        if model.model_type == "regression":
-            preds = model_output["logits"]
-        else:
-            preds = torch.sigmoid(model_output["logits"])
-
-        return preds
 
     @property
     def log_dir(self) -> Optional[str]:
diff --git a/tests/unit/cli/classification_labels.txt b/tests/unit/cli/classification_labels.txt
new file mode 100644
index 00000000..06d2d6d1
--- /dev/null
+++ b/tests/unit/cli/classification_labels.txt
@@ -0,0 +1,10 @@
+label_1
+label_2
+label_3
+label_4
+label_5
+label_6
+label_7
+label_8
+label_9
+label_10
diff --git a/tests/unit/cli/mock_dm.py b/tests/unit/cli/mock_dm.py
index 25116e21..e3fd60a7 100644
--- a/tests/unit/cli/mock_dm.py
+++ b/tests/unit/cli/mock_dm.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 from lightning.pytorch.core.datamodule import LightningDataModule
 from torch.utils.data import DataLoader
@@ -29,6 +31,10 @@ def num_of_labels(self):
     def feature_vector_size(self):
         return self._feature_vector_size
 
+    @property
+    def classes_txt_file_path(self) -> str:
+        return os.path.join("tests", "unit", "cli", "classification_labels.txt")
+
     def train_dataloader(self):
         assert self.feature_vector_size is not None, "feature_vector_size must be set"
         # Dummy dataset for example purposes
diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py
index 863a6df3..584da5e7 100644
--- a/tests/unit/cli/testCLI.py
+++ b/tests/unit/cli/testCLI.py
@@ -9,7 +9,7 @@ def setUp(self):
             "fit",
             "--trainer=configs/training/default_trainer.yml",
             "--model=configs/model/ffn.yml",
-            "--model.init_args.hidden_layers=[10]",
+            "--model.init_args.hidden_layers=[1]",
             "--model.train_metrics=configs/metrics/micro-macro-f1.yml",
             "--data=tests/unit/cli/mock_dm_config.yml",
             "--model.pass_loss_kwargs=false",
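
A note on patch 1: `on_save_checkpoint` above embeds the class labels into every checkpoint the model writes. A minimal sketch of how such a checkpoint can be inspected — this assumes a checkpoint produced with this patch, and `model.ckpt` is a placeholder path:

```python
# Sketch only: inspect the labels that on_save_checkpoint stores in a checkpoint.
# "model.ckpt" is a placeholder; the "classification_labels" key is the one added here.
import torch

ckpt = torch.load("model.ckpt", map_location="cpu", weights_only=False)
labels = ckpt.get("classification_labels")  # None for checkpoints predating this patch
if labels is not None:
    print(f"{len(labels)} classification labels, e.g. {labels[:3]}")
```
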
From 1997d9f712f5cb5f33b73624b081096fa24aa5fc Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 30 Jan 2026 19:06:42 +0100
Subject: [PATCH 2/9] fix pipeline

---
 chebai/models/base.py                 | 25 +++++++++++++++----------
 chebai/preprocessing/datasets/base.py |  7 ++++++-
 chebai/result/prediction.py           | 24 ++++++++++--------------
 3 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/chebai/models/base.py b/chebai/models/base.py
index 6c2daa41..ec02e343 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -63,6 +63,7 @@ def __init__(
                 "train_metrics",
                 "val_metrics",
                 "test_metrics",
+                "classes_txt_file_path",
                 *exclude_hyperparameter_logging,
             ]
         )
@@ -78,17 +79,21 @@ def __init__(
         self.validation_metrics = val_metrics
         self.test_metrics = test_metrics
         self.pass_loss_kwargs = pass_loss_kwargs
-        with open(classes_txt_file_path, "r") as f:
-            self.labels_list = [cls.strip() for cls in f.readlines()]
-        assert len(self.labels_list) > 0, "Class labels list is empty."
-        assert len(self.labels_list) == out_dim, (
-            f"Number of class labels ({len(self.labels_list)}) does not match "
-            f"the model output dimension ({out_dim})."
-        )
+
+        self.classes_txt_file_path = classes_txt_file_path
+        if classes_txt_file_path is not None:
+            with open(classes_txt_file_path, "r") as f:
+                self.labels_list = [cls.strip() for cls in f.readlines()]
+            assert len(self.labels_list) > 0, "Class labels list is empty."
+            assert len(self.labels_list) == out_dim, (
+                f"Number of class labels ({len(self.labels_list)}) does not match "
+                f"the model output dimension ({out_dim})."
+            )
 
     def on_save_checkpoint(self, checkpoint):
-        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
-        checkpoint["classification_labels"] = self.labels_list
+        if self.classes_txt_file_path is not None:
+            # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
+            checkpoint["classification_labels"] = self.labels_list
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -257,7 +262,7 @@ def predict_step(
         # Dummy labels to avoid errors in _get_prediction_and_labels
         labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
         pr, _ = self._get_prediction_and_labels(data, labels, model_output)
-        return pr
+        return {"prediction": pr, "model_output": model_output}
 
     def _execute(
         self,
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index df3dab6a..297bbe0b 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -470,8 +470,13 @@ def _preprocess_smiles_for_pred(
         # Add dummy labels because the collate function requires them.
         # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
         # which later causes the `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
+        num_of_labels = int(model_hparams["out_dim"])
         return self.reader.to_data(
-            {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
+            {
+                "id": f"smiles_{idx}",
+                "features": smiles,
+                "labels": list(range(1, num_of_labels + 1)),
+            }
         )
 
     def prepare_data(self, *args, **kwargs) -> None:
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index fda8a308..51bec4cb 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -57,6 +57,7 @@ def __init__(
 
         self._model_hparams = ckpt_file["hyper_parameters"]
         self._model_hparams.pop("_instantiator", None)
+        self._model_hparams.pop("classes_txt_file_path", None)
         self._model: ChebaiBaseNet = instantiate_module(
             ChebaiBaseNet, self._model_hparams
         )
@@ -106,11 +107,14 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]:
                 return [cls.strip() for cls in f.readlines()]
 
         if self._classification_labels is not None:
+            # Prioritize classification labels saved in the checkpoint
             CLASS_LABELS = self._classification_labels
         # --- For old checkpoints that do not have classification_labels saved ---
         elif classes_path is not None:
+            # If user provides a classes_path, use it
             CLASS_LABELS = _add_class_columns(classes_path)
         elif os.path.exists(self._dm.classes_txt_file_path):
+            # Check existence of classes_txt_file_path which the datamodule points to
             CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path)
 
         preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings)
@@ -119,21 +123,12 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]:
             return
 
         # --- Logic for old checkpoints that do not have classification_labels saved ---
-        if CLASS_LABELS is not None and self._model.out_dim is not None:
-            assert len(CLASS_LABELS) > 0, "Class labels list is empty."
-            assert len(CLASS_LABELS) == self._model.out_dim, (
-                f"Number of class labels ({len(CLASS_LABELS)}) does not match "
-                f"the model output dimension ({self._model.out_dim})."
-            )
-            num_of_cols = len(CLASS_LABELS)
-        elif CLASS_LABELS is not None:
+        if CLASS_LABELS is not None:
             assert len(CLASS_LABELS) > 0, "Class labels list is empty."
             num_of_cols = len(CLASS_LABELS)
-        elif self._model.out_dim is not None:
-            num_of_cols = self._model.out_dim
         else:
-            # find first non-None tensor to determine width
-            num_of_cols = next(x.numel() for x in preds if x is not None)
+            # self._model.out_dim is already asserted during model initialization
+            num_of_cols = self._model.out_dim
             CLASS_LABELS = [f"class_{i}" for i in range(num_of_cols)]
 
         rows = [
@@ -167,9 +162,10 @@ def predict_smiles(
         preds = []
         for batch_idx, batch in enumerate(pred_dl):
             # For certain model prediction pipelines, we may need data module hyperparameters
-            preds.append(
-                self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
+            result = self._model.predict_step(
+                batch, batch_idx, dm_hparams=self._dm_hparams
             )
+            preds.append(result["prediction"])
         preds = torch.cat(preds)
 
         # Initialize output with None
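
A note on patch 2: the column-naming fallback now has a fixed precedence — labels embedded in the checkpoint, then a user-supplied `classes_path`, then the datamodule's `classes.txt`, and finally numbered columns sized by `out_dim`. A standalone sketch of that precedence (the function and parameter names here are illustrative, not part of the patch):

```python
import os
from typing import Optional


def resolve_column_names(
    ckpt_labels: Optional[list],
    classes_path: Optional[str],
    dm_classes_path: str,
    out_dim: int,
) -> list:
    """Illustrative rewrite of the precedence implemented in patch 2."""
    if ckpt_labels is not None:  # 1. labels stored in the checkpoint win
        return ckpt_labels
    if classes_path is not None:  # 2. explicit classes.txt from the user
        with open(classes_path) as f:
            return [line.strip() for line in f]
    if os.path.exists(dm_classes_path):  # 3. classes.txt the datamodule points to
        with open(dm_classes_path) as f:
            return [line.strip() for line in f]
    return [f"class_{i}" for i in range(out_dim)]  # 4. numbered fallback
```
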
From 7e1143e665fafc3c6f23c1a13eb0a8ade49afb83 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 30 Jan 2026 20:14:10 +0100
Subject: [PATCH 3/9] make classification labels in ckpt file mandatory

---
 README.md                   |  7 ++---
 chebai/models/base.py       |  2 ++
 chebai/result/prediction.py | 57 ++++++++-----------------------------
 3 files changed, 16 insertions(+), 50 deletions(-)

diff --git a/README.md b/README.md
index d9145cf3..e6cc0099 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,7 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
 ### Predicting classes given SMILES strings
 ```
-python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--save_to=[path-to-output]]
 ```
 
 * **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
@@ -87,10 +87,7 @@ python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-t
 
 * **`--save_to`** *(optional)*: Path to save the predictions as a CSV file. The CSV will contain one row per SMILES string and one column per predicted class. By default, the file is saved as `predictions.csv` in the current working directory.
 
-* **`--classes_path`** *(optional)*: Path to the dataset's `classes.txt` file, which maps model output indices to ChEBI IDs.
-  * Checkpoints created after PR #135 have the classification labels stored in them, so this parameter is not required for them.
-  * If provided, the CSV columns will be named using the ChEBI IDs.
-  * If omitted, the script will try to locate the file automatically; if it cannot be found, the columns will be numbered sequentially.
+> **Note**: This prediction pipeline requires checkpoints created after PR #148, which store the list of ChEBI classes (classification labels) used during training.
 
 ## Evaluation
diff --git a/chebai/models/base.py b/chebai/models/base.py
index ec02e343..b77757cd 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -81,6 +81,8 @@ def __init__(
         self.pass_loss_kwargs = pass_loss_kwargs
 
         self.classes_txt_file_path = classes_txt_file_path
+
+        # During prediction `classes_txt_file_path` is set to None
         if classes_txt_file_path is not None:
             with open(classes_txt_file_path, "r") as f:
                 self.labels_list = [cls.strip() for cls in f.readlines()]
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index 51bec4cb..c0ac909b 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -1,4 +1,3 @@
-import os
 from typing import List, Optional
 
 import pandas as pd
@@ -63,18 +62,15 @@ def __init__(
-        self._classification_labels: list | None = ckpt_file.get(
-            "classification_labels", None
+        self._classification_labels: list = ckpt_file.get("classification_labels")
+        print(f"Loaded {len(self._classification_labels)} classification labels.")
+        assert len(self._classification_labels) > 0, (
+            "Classification labels list is empty."
+        )
+        assert len(self._classification_labels) == self._model.out_dim, (
+            f"Number of class labels ({len(self._classification_labels)}) does not match "
+            f"the model output dimension ({self._model.out_dim})."
         )
-        if self._classification_labels is not None:
-            print(f"Loaded {len(self._classification_labels)} classification labels.")
-            assert len(self._classification_labels) > 0, (
-                "Classification labels list is empty."
-            )
-            assert len(self._classification_labels) == self._model.out_dim, (
-                f"Number of class labels ({len(self._classification_labels)}) does not match "
-                f"the model output dimension ({self._model.out_dim})."
-            )
 
         if compile_model:
             self._model = torch.compile(self._model)
@@ -85,7 +81,6 @@ def predict_from_file(
         self,
         smiles_file_path: _PATH,
         save_to: _PATH = "predictions.csv",
-        classes_path: Optional[_PATH] = None,
     ) -> None:
         """
         Loads a model from a checkpoint and makes predictions on input data from a file.
@@ -93,49 +88,23 @@ def predict_from_file(
         Args:
             smiles_file_path: Path to the input file containing SMILES strings.
             save_to: Path to save the predictions CSV file.
-            classes_path: Optional path to a file containing class names.
-                If not provided, the code will try to get the classes file path
-                from the datamodule; otherwise, the columns will be numbered.
         """
         with open(smiles_file_path, "r") as input:
             smiles_strings = [inp.strip() for inp in input.readlines()]
 
-        CLASS_LABELS: list | None = None
-
-        def _add_class_columns(class_file_path: _PATH) -> list[str]:
-            with open(class_file_path, "r") as f:
-                return [cls.strip() for cls in f.readlines()]
-
-        if self._classification_labels is not None:
-            # Prioritize classification labels saved in the checkpoint
-            CLASS_LABELS = self._classification_labels
-        # --- For old checkpoints that do not have classification_labels saved ---
-        elif classes_path is not None:
-            # If user provides a classes_path, use it
-            CLASS_LABELS = _add_class_columns(classes_path)
-        elif os.path.exists(self._dm.classes_txt_file_path):
-            # Check existence of classes_txt_file_path which the datamodule points to
-            CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path)
-
         preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings)
         if all(pred is None for pred in preds):
             print("No valid predictions were made. (All predictions are None.)")
             return
 
-        # --- Logic for old checkpoints that do not have classification_labels saved ---
-        if CLASS_LABELS is not None:
-            assert len(CLASS_LABELS) > 0, "Class labels list is empty."
-            num_of_cols = len(CLASS_LABELS)
-        else:
-            # self._model.out_dim is already asserted during model initialization
-            num_of_cols = self._model.out_dim
-            CLASS_LABELS = [f"class_{i}" for i in range(num_of_cols)]
-
+        num_of_cols = len(self._classification_labels)
         rows = [
             pred.tolist() if pred is not None else [None] * num_of_cols
             for pred in preds
         ]
-        predictions_df = pd.DataFrame(rows, columns=CLASS_LABELS, index=smiles_strings)
+        predictions_df = pd.DataFrame(
+            rows, columns=self._classification_labels, index=smiles_strings
+        )
 
         predictions_df.to_csv(save_to)
         print(f"Predictions saved to: {save_to}")
@@ -184,14 +153,12 @@ def predict_from_file(
         checkpoint_path: _PATH,
         smiles_file_path: _PATH,
         save_to: _PATH = "predictions.csv",
-        classes_path: Optional[_PATH] = None,
         batch_size: Optional[int] = None,
     ) -> None:
         predictor = Predictor(checkpoint_path, batch_size)
         predictor.predict_from_file(
             smiles_file_path,
             save_to,
-            classes_path,
         )
 
     @staticmethod
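
A note on patch 3: with the labels now mandatory, old checkpoints fail at `Predictor` construction. A quick pre-flight check — a sketch with a placeholder path, potentially useful before launching a long prediction run:

```python
import torch

ckpt = torch.load("model.ckpt", map_location="cpu", weights_only=False)  # placeholder path
if "classification_labels" not in ckpt:
    raise SystemExit(
        "Checkpoint predates the label-embedding change (PR #148); "
        "this prediction pipeline cannot use it."
    )
print(f"OK: {len(ckpt['classification_labels'])} labels embedded.")
```
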
From ac863cd7d583ab3195f8bf122cb71e69c4223983 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 30 Jan 2026 23:27:57 +0100
Subject: [PATCH 4/9] avoid generating dummy labels for each smiles

---
 chebai/preprocessing/datasets/base.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 297bbe0b..ad92f892 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -452,6 +452,9 @@ def _process_input_for_prediction(
             tuple[list, list]: Processed input data and valid indices.
         """
         data, valid_indices = [], []
+        num_of_labels = int(model_hparams["out_dim"])
+        self._dummy_labels: list = list(range(1, num_of_labels + 1))
+
         for idx, smiles in enumerate(smiles_list):
             result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams)
             if result is None or result["features"] is None:
@@ -470,12 +473,11 @@ def _preprocess_smiles_for_pred(
         # Add dummy labels because the collate function requires them.
         # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
         # which later causes the `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
-        num_of_labels = int(model_hparams["out_dim"])
         return self.reader.to_data(
             {
                 "id": f"smiles_{idx}",
                 "features": smiles,
-                "labels": list(range(1, num_of_labels + 1)),
+                "labels": self._dummy_labels,
             }
         )

From 9daa71248fd8957d7058a64e2df11f1c6756cdc4 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 2 Feb 2026 13:16:20 +0100
Subject: [PATCH 5/9] add vscode extension settings for the project

---
 .vscode/extensions.json | 11 +++++++++++
 .vscode/settings.json   | 16 ++++++++++++++++
 2 files changed, 27 insertions(+)
 create mode 100644 .vscode/extensions.json
 create mode 100644 .vscode/settings.json

diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 00000000..d1e06324
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,11 @@
+{
+  "recommendations": [
+    "ms-python.python",
+    "ms-python.vscode-pylance",
+    "charliermarsh.ruff",
+    "usernamehw.errorlens"
+  ],
+  "unwantedRecommendations": [
+    "ms-python.vscode-python2"
+  ]
+}
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..d3ff676e
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,16 @@
+{
+  "python.testing.unittestArgs": [
+    "-v",
+    "-s",
+    "./tests",
+    "-p",
+    "test*.py"
+  ],
+  "python.testing.pytestEnabled": false,
+  "python.testing.unittestEnabled": true,
+  "python.analysis.typeCheckingMode": "standard",
+  "editor.formatOnSave": true,
+  "[python]": {
+    "editor.defaultFormatter": "charliermarsh.ruff"
+  },
+}
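
A note on patch 4: the dummy labels exist only to satisfy the collator, so one shared list per prediction call is enough; rebuilding `list(range(1, out_dim + 1))` per SMILES was pure overhead. A sketch of the shared-list idea (the values are illustrative; sharing one list object is safe here because it is never mutated downstream):

```python
out_dim = 1446  # hypothetical model output dimension
dummy_labels = list(range(1, out_dim + 1))  # built once per prediction call

records = [
    {"id": f"smiles_{i}", "features": smi, "labels": dummy_labels}  # same object reused
    for i, smi in enumerate(["CCO", "c1ccccc1", "CC(=O)O"])
]
```
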
From c58428484c9997a0f8d6b0920a7bb8310160fc04 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 2 Feb 2026 16:56:21 +0100
Subject: [PATCH 6/9] fix model device error + config fp

---
 .gitignore                  | 2 +-
 README.md                   | 4 ++--
 chebai/result/prediction.py | 6 ++++--
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index be471dd0..f1286a8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,7 +175,7 @@ chebai.egg-info
 lightning_logs
 logs
 .isort.cfg
-/.vscode
+/.vscode/launch.json
 
 *.out
 *.err
diff --git a/README.md b/README.md
index e6cc0099..952bc858 100644
--- a/README.md
+++ b/README.md
@@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
 ```
 A command with additional options may look like this:
 ```
-python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
+python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```
 
 ### Fine-tuning for classification tasks, e.g. Toxicity prediction
@@ -101,7 +101,7 @@ An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 
 Alternatively, you can evaluate the model via the CLI:
 ```bash
-python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce_weighted.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
 ```
 
 > **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py
index c0ac909b..68a16d97 100644
--- a/chebai/result/prediction.py
+++ b/chebai/result/prediction.py
@@ -39,6 +39,7 @@ def __init__(
         )
 
         print("-" * 50)
+        print(f"Using device: {self.device}")
         print(f"Loaded checkpoint from: {checkpoint_path}")
         print("Below are the modules loaded from the checkpoint:")
 
@@ -60,6 +61,7 @@ def __init__(
         self._model: ChebaiBaseNet = instantiate_module(
             ChebaiBaseNet, self._model_hparams
         )
+        self._model.to(self.device)
         print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
 
         self._classification_labels: list = ckpt_file.get("classification_labels")
@@ -73,7 +75,7 @@ def __init__(
         )
 
         if compile_model:
-            self._model = torch.compile(self._model)
+            self._model = torch.compile(self._model)  # type: ignore
         self._model.eval()
         print("-" * 50)
 
@@ -166,7 +168,7 @@ def predict_smiles(
         checkpoint_path: _PATH,
         smiles: List[str],
         batch_size: Optional[int] = None,
-    ) -> torch.Tensor:
+    ) -> list[torch.Tensor | None]:
         predictor = Predictor(checkpoint_path, batch_size)
         return predictor.predict_smiles(smiles)
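
A note on patch 6: `instantiate_module` builds the model with parameters on the CPU even when the checkpoint tensors were mapped to CUDA, so a forward pass would mix devices; `self._model.to(self.device)` aligns them. A self-contained sketch of the pattern (the `Linear` layer stands in for the instantiated model):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(128, 10)  # stand-in: freshly instantiated modules live on the CPU
model.to(device)  # the fix: move parameters to the same device as the inputs
x = torch.randn(4, 128, device=device)
with torch.inference_mode():
    y = model(x)  # parameters and inputs now agree on device
print(y.shape)
```
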
From 2ef9cc81ad0670e35999943d6b7f0fe78112dea2 Mon Sep 17 00:00:00 2001
From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com>
Date: Tue, 3 Feb 2026 11:46:19 +0100
Subject: [PATCH 7/9] Change Python type checking mode to 'basic'

---
 .vscode/settings.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index d3ff676e..1bada353 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -8,7 +8,7 @@
   ],
   "python.testing.pytestEnabled": false,
   "python.testing.unittestEnabled": true,
-  "python.analysis.typeCheckingMode": "standard",
+  "python.analysis.typeCheckingMode": "basic",
   "editor.formatOnSave": true,
   "[python]": {
     "editor.defaultFormatter": "charliermarsh.ruff"
""" - assert isinstance(batch, XYData) - batch = batch.to(self.device) - data = self._process_batch(batch, batch_idx) - model_output = self(data, **data.get("model_kwargs", dict())) - - # Dummy labels to avoid errors in _get_prediction_and_labels - labels = torch.zeros((len(batch), self.out_dim)).to(self.device) - pr, _ = self._get_prediction_and_labels(data, labels, model_output) - return {"prediction": pr, "model_output": model_output} + return self._execute(batch, batch_idx, log=False) def _execute( self, diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 68a16d97..8748ef73 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -136,7 +136,7 @@ def predict_smiles( result = self._model.predict_step( batch, batch_idx, dm_hparams=self._dm_hparams ) - preds.append(result["prediction"]) + preds.append(result["preds"]) preds = torch.cat(preds) # Initialize output with None From c7b3a86c2e4f0ae33ec77829313d01b3ab23e0bf Mon Sep 17 00:00:00 2001 From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com> Date: Tue, 10 Feb 2026 00:32:28 +0100 Subject: [PATCH 9/9] Update .vscode/settings.json Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 1bada353..dbebc3c5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,5 +12,5 @@ "editor.formatOnSave": true, "[python]": { "editor.defaultFormatter": "charliermarsh.ruff" - }, + } }