diff --git a/.gitignore b/.gitignore
index be471dd0..f1286a8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,7 +175,7 @@ chebai.egg-info
 lightning_logs
 logs
 .isort.cfg
-/.vscode
+/.vscode/launch.json
 *.out
 *.err
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 00000000..d1e06324
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,11 @@
+{
+    "recommendations": [
+        "ms-python.python",
+        "ms-python.vscode-pylance",
+        "charliermarsh.ruff",
+        "usernamehw.errorlens"
+    ],
+    "unwantedRecommendations": [
+        "ms-python.vscode-python2"
+    ]
+}
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..dbebc3c5
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,16 @@
+{
+    "python.testing.unittestArgs": [
+        "-v",
+        "-s",
+        "./tests",
+        "-p",
+        "test*.py"
+    ],
+    "python.testing.pytestEnabled": false,
+    "python.testing.unittestEnabled": true,
+    "python.analysis.typeCheckingMode": "basic",
+    "editor.formatOnSave": true,
+    "[python]": {
+        "editor.defaultFormatter": "charliermarsh.ruff"
+    }
+}
diff --git a/README.md b/README.md
index 7672bc28..952bc858 100644
--- a/README.md
+++ b/README.md
@@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
 ```
 A command with additional options may look like this:
 ```
-python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
+python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```
 
 ### Fine-tuning for classification tasks, e.g. Toxicity prediction
@@ -78,11 +78,16 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
 
 ### Predicting classes given SMILES strings
 ```
-python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--save_to=[path-to-output]]
 ```
-The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
-one row for each SMILES string and one column for each class.
-The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
+
+* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
+
+* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
+
+* **`--save_to`** *(optional)*: Path where the predictions are saved as a CSV file. The CSV contains one row per SMILES string and one column per predicted class. By default, the file is written as `predictions.csv` in the current working directory.
+
+> **Note**: This prediction pipeline requires checkpoints created after PR #148, since only those checkpoints store the list of ChEBI classes (classification labels) used during training.
 
 ## Evaluation
 
@@ -96,7 +101,7 @@ An example notebook is provided at `tutorials/eval_model_basic.ipynb`. Alternatively, you can evaluate the model via the CLI:
 
 ```bash
-python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce_weighted.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
 ```
 > **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
diff --git a/chebai/cli.py b/chebai/cli.py index 1aaba53c..8b51e45f 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]): apply_on="instantiate", ) + parser.link_arguments( + "data.classes_txt_file_path", + "model.init_args.classes_txt_file_path", + apply_on="instantiate", + ) + for kind in ("train", "val", "test"): for average in ( "micro-f1", @@ -112,7 +118,6 @@ def subcommands() -> Dict[str, Set[str]]: "validate": {"model", "dataloaders", "datamodule"}, "test": {"model", "dataloaders", "datamodule"}, "predict": {"model", "dataloaders", "datamodule"}, - "predict_from_file": {"model"}, } diff --git a/chebai/models/base.py b/chebai/models/base.py index 82d84033..df060e9a 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -40,6 +40,7 @@ def __init__( pass_loss_kwargs: bool = True, optimizer_kwargs: Optional[Dict[str, Any]] = None, exclude_hyperparameter_logging: Optional[Iterable[str]] = None, + classes_txt_file_path: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -47,8 +48,8 @@ def __init__( if exclude_hyperparameter_logging is None: exclude_hyperparameter_logging = tuple() self.criterion = criterion - assert out_dim is not None, "out_dim must be specified" - assert input_dim is not None, "input_dim must be specified" + assert out_dim is not None and out_dim > 0, "out_dim must be specified" + assert input_dim is not None and input_dim > 0, "input_dim must be specified" self.out_dim = out_dim self.input_dim = input_dim print( @@ -62,6 +63,7 @@ def __init__( "train_metrics", "val_metrics", "test_metrics", + "classes_txt_file_path", *exclude_hyperparameter_logging, ] ) @@ -78,6 +80,23 @@ def __init__( self.test_metrics = test_metrics self.pass_loss_kwargs = pass_loss_kwargs + self.classes_txt_file_path = classes_txt_file_path + + # During prediction `classes_txt_file_path` is set to None + if classes_txt_file_path is not None: + with open(classes_txt_file_path, "r") as f: + self.labels_list = [cls.strip() for cls in f.readlines()] + assert len(self.labels_list) > 0, "Class labels list is empty." + assert len(self.labels_list) == out_dim, ( + f"Number of class labels ({len(self.labels_list)}) does not match " + f"the model output dimension ({out_dim})." + ) + + def on_save_checkpoint(self, checkpoint): + if self.classes_txt_file_path is not None: + # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere + checkpoint["classification_labels"] = self.labels_list + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a # different loss) @@ -100,7 +119,7 @@ def __init_subclass__(cls, **kwargs): def _get_prediction_and_labels( self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): + ) -> tuple[torch.Tensor, torch.Tensor]: """ Gets the predictions and labels from the model output. @@ -151,7 +170,7 @@ def _process_for_loss( model_output: torch.Tensor, labels: torch.Tensor, loss_kwargs: Dict[str, Any], - ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]): + ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: """ Processes the data for loss computation. @@ -237,7 +256,7 @@ def predict_step( Returns: Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. 
""" - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + return self._execute(batch, batch_idx, log=False) def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 9ff40748..ad92f892 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] - # filter for missing features in resulting data, keep features length below token limit - data = [ - val - for val in data - if val["features"] is not None - and ( - self.n_token_limit is None or len(val["features"]) <= self.n_token_limit - ) - ] + data = [val for val in data if self._filter_to_token_limit(val)] return data + def _filter_to_token_limit(self, data_instance: dict) -> bool: + # filter for missing features in resulting data, keep features length below token limit + if data_instance["features"] is not None and ( + self.n_token_limit is None + or len(data_instance["features"]) <= self.n_token_limit + ): + return True + return False + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -401,22 +402,84 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs - ) -> Union[DataLoader, List[DataLoader]]: + self, + smiles_list: List[str], + model_hparams: Optional[dict] = None, + **kwargs, + ) -> tuple[DataLoader, list[int]]: """ Returns the predict DataLoader. Args: - *args: Additional positional arguments (unused). + smiles_list (List[str]): List of SMILES strings to predict. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. **kwargs: Additional keyword arguments, passed to dataloader(). Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. + tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data, valid_indices = self._process_input_for_prediction( + smiles_list, model_hparams + ) + return ( + DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ), + valid_indices, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: Optional[dict] = None + ) -> tuple[list, list]: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. + + Returns: + tuple[list, list]: Processed input data and valid indices. 
+ """ + data, valid_indices = [], [] + num_of_labels = int(model_hparams["out_dim"]) + self._dummy_labels: list = list(range(1, num_of_labels + 1)) + + for idx, smiles in enumerate(smiles_list): + result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) + if result is None or result["features"] is None: + continue + if not self._filter_to_token_limit(result): + continue + data.append(result) + valid_indices.append(idx) + + return data, valid_indices + + def _preprocess_smiles_for_pred( + self, idx, smiles: str, model_hparams: Optional[dict] = None + ) -> dict: + """Preprocess prediction data.""" + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. + return self.reader.to_data( + { + "id": f"smiles_{idx}", + "features": smiles, + "labels": self._dummy_labels, + } + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: @@ -563,6 +626,19 @@ def raw_file_names_dict(self) -> dict: """ raise NotImplementedError + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in following places: + # - results/prediction.py: to load class names for csv columns names + # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` + return os.path.join(self.processed_dir_main, "classes.txt") + class MergedDataset(XYBaseDataModule): MERGED = [] @@ -1189,7 +1265,8 @@ def _retrieve_splits_from_csv(self) -> None: print(f"Applying label filter from {self.apply_label_filter}...") with open(self.apply_label_filter, "r") as f: label_filter = [line.strip() for line in f] - with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + + with open(self.classes_txt_file_path, "r") as cf: classes = [line.strip() for line in cf] # reorder labels old_labels = np.stack(df_data["labels"]) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..8748ef73 --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,179 @@ +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__( + self, + checkpoint_path: _PATH, + batch_size: Optional[int] = None, + compile_model: bool = True, + ): + """Initializes the Predictor with a model loaded from the checkpoint. + + Args: + checkpoint_path: Path to the model checkpoint. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. + compile_model: Whether to compile the model using torch.compile. Default is True. 
+ """ + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + assert ( + "_class_path" in ckpt_file["datamodule_hyper_parameters"] + and "_class_path" in ckpt_file["hyper_parameters"] + ), ( + "Datamodule and Model hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) + + print("-" * 50) + print(f"Using device: {self.device}") + print(f"For Loaded checkpoint from: {checkpoint_path}") + print("Below are the modules loaded from the checkpoint:") + + self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] + self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("augment_smiles", None) + self._dm_hparams.pop("aug_smiles_variations", None) + self._dm_hparams.pop("_instantiator", None) + self._dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, self._dm_hparams + ) + if batch_size is not None and int(batch_size) > 0: + self._dm.batch_size = int(batch_size) + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") + + self._model_hparams = ckpt_file["hyper_parameters"] + self._model_hparams.pop("_instantiator", None) + self._model_hparams.pop("classes_txt_file_path", None) + self._model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, self._model_hparams + ) + self._model.to(self.device) + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + + self._classification_labels: list = ckpt_file.get("classification_labels") + print(f"Loaded {len(self._classification_labels)} classification labels.") + assert len(self._classification_labels) > 0, ( + "Classification labels list is empty." + ) + assert len(self._classification_labels) == self._model.out_dim, ( + f"Number of class labels ({len(self._classification_labels)}) does not match " + f"the model output dimension ({self._model.out_dim})." + ) + + if compile_model: + self._model = torch.compile(self._model) # type: ignore + self._model.eval() + print("-" * 50) + + def predict_from_file( + self, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + smiles_file_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + """ + with open(smiles_file_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + preds: list[torch.Tensor | None] = self.predict_smiles(smiles=smiles_strings) + if all(pred is None for pred in preds): + print("No valid predictions were made. (All predictions are None.)") + return + + num_of_cols = len(self._classification_labels) + rows = [ + pred.tolist() if pred is not None else [None] * num_of_cols + for pred in preds + ] + predictions_df = pd.DataFrame( + rows, columns=self._classification_labels, index=smiles_strings + ) + + predictions_df.to_csv(save_to) + print(f"Predictions saved to: {save_to}") + + @torch.inference_mode() + def predict_smiles( + self, + smiles: List[str], + ) -> list[torch.Tensor | None]: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. 
+ """ + # For certain data prediction piplines, we may need model hyperparameters + pred_dl, valid_indices = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams + ) + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + # For certain model prediction pipelines, we may need data module hyperparameters + result = self._model.predict_step( + batch, batch_idx, dm_hparams=self._dm_hparams + ) + preds.append(result["preds"]) + preds = torch.cat(preds) + + # Initialize output with None + output: list[torch.Tensor | None] = [None] * len(smiles) + + # Scatter predictions back + for pred, idx in zip(preds, valid_indices): + output[idx] = pred + + return output + + +class MainPredictor: + @staticmethod + def predict_from_file( + checkpoint_path: _PATH, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + batch_size: Optional[int] = None, + ) -> None: + predictor = Predictor(checkpoint_path, batch_size) + predictor.predict_from_file( + smiles_file_path, + save_to, + ) + + @staticmethod + def predict_smiles( + checkpoint_path: _PATH, + smiles: List[str], + batch_size: Optional[int] = None, + ) -> list[torch.Tensor | None]: + predictor = Predictor(checkpoint_path, batch_size) + return predictor.predict_smiles(smiles) + + +if __name__ == "__main__": + # python chebai/result/prediction.py predict_from_file --help + # python chebai/result/prediction.py predict_smiles --help + CLI(MainPredictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 5c960007..b9e4b0f3 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,18 +1,13 @@ import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd -import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader log = logging.getLogger(__name__) @@ -74,68 +69,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - def predict_from_file( + def predict( self, - model: LightningModule, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). 
- """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) - - def _predict_smiles( - self, model: LightningModule, smiles: List[str] - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - smiles: A list of SMILES strings. - - Returns: - A tensor containing the predictions. - """ - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] - x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], - batch_first=True, + model=None, + dataloaders=None, + datamodule=None, + return_predictions=None, + ckpt_path=None, + ): + raise NotImplementedError( + "CustomTrainer.predict is not implemented." + "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead." ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) - * CLS_TOKEN - ) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - if model.model_type == "regression": - preds = model_output["logits"] - else: - preds = torch.sigmoid(model_output["logits"]) - - return preds @property def log_dir(self) -> Optional[str]: diff --git a/tests/unit/cli/classification_labels.txt b/tests/unit/cli/classification_labels.txt new file mode 100644 index 00000000..06d2d6d1 --- /dev/null +++ b/tests/unit/cli/classification_labels.txt @@ -0,0 +1,10 @@ +label_1 +label_2 +label_3 +label_4 +label_5 +label_6 +label_7 +label_8 +label_9 +label_10 diff --git a/tests/unit/cli/mock_dm.py b/tests/unit/cli/mock_dm.py index 25116e21..e3fd60a7 100644 --- a/tests/unit/cli/mock_dm.py +++ b/tests/unit/cli/mock_dm.py @@ -1,3 +1,5 @@ +import os + import torch from lightning.pytorch.core.datamodule import LightningDataModule from torch.utils.data import DataLoader @@ -29,6 +31,10 @@ def num_of_labels(self): def feature_vector_size(self): return self._feature_vector_size + @property + def classes_txt_file_path(self) -> str: + return os.path.join("tests", "unit", "cli", "classification_labels.txt") + def train_dataloader(self): assert self.feature_vector_size is not None, "feature_vector_size must be set" # Dummy dataset for example purposes diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py index 863a6df3..584da5e7 100644 --- a/tests/unit/cli/testCLI.py +++ b/tests/unit/cli/testCLI.py @@ -9,7 +9,7 @@ def setUp(self): "fit", "--trainer=configs/training/default_trainer.yml", "--model=configs/model/ffn.yml", - "--model.init_args.hidden_layers=[10]", + "--model.init_args.hidden_layers=[1]", "--model.train_metrics=configs/metrics/micro-macro-f1.yml", "--data=tests/unit/cli/mock_dm_config.yml", "--model.pass_loss_kwargs=false",