From 02c54090d5f5372db9ea0be8bdb906e4859ba18e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 24 Jun 2025 20:29:07 +0200 Subject: [PATCH 01/38] api code to download model from hugging face --- .gitignore | 1 + api/hugging_face.py | 56 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 .gitignore create mode 100644 api/hugging_face.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..905568c --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/api/api_models diff --git a/api/hugging_face.py b/api/hugging_face.py new file mode 100644 index 0000000..cfcf97c --- /dev/null +++ b/api/hugging_face.py @@ -0,0 +1,56 @@ +import shutil +from pathlib import Path + +from huggingface_hub import hf_hub_download + +# Updated registry: use a list of filenames if you're downloading a folder +MODEL_REGISTRY = { + "electra": { + "repo_id": "aditya0by0/python-chebifier", + "subfolder": "electra", + "filenames": ["electra.ckpt", "classes.txt"], + } +} + +DOWNLOAD_PATH = Path(__file__).resolve().parent / "api_models" + + +def download_model(model_name): + if model_name not in MODEL_REGISTRY: + raise ValueError( + f"Unknown model name. Available models: {list(MODEL_REGISTRY.keys())}" + ) + + model_info = MODEL_REGISTRY[model_name] + repo_id = model_info["repo_id"] + subfolder = model_info["subfolder"] + filenames = model_info["filenames"] + + local_paths = [] + for filename in filenames: + local_model_path = DOWNLOAD_PATH / model_name / filename + if local_model_path.exists(): + print(f"File already exists: {local_model_path}") + local_paths.append(local_model_path) + continue + + print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})") + downloaded_file = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + ) + + local_model_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(downloaded_file, local_model_path) + print(f"Saved to: {local_model_path}") + local_paths.append(local_model_path) + + return local_paths + + +if __name__ == "__main__": + paths = download_model("electra") + print("Downloaded files:") + for p in paths: + print(p) From b539f0af136692aa83ea4370896f97151b359dc7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 00:33:39 +0200 Subject: [PATCH 02/38] Create .pre-commit-config.yaml --- .pre-commit-config.yaml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e32d80c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: +- repo: https://github.com/psf/black + rev: "24.2.0" + hooks: + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.1 + hooks: + - id: ruff + args: [] # No --fix, disables formatting From 2c2aba2315ae0fc10c501dd4db4c44e8619df8c0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 00:34:40 +0200 Subject: [PATCH 03/38] utility to setup env and model package dependencies --- api/__init__.py | 0 
api/setup_env.py | 165 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 api/__init__.py create mode 100644 api/setup_env.py diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/setup_env.py b/api/setup_env.py new file mode 100644 index 0000000..a246c26 --- /dev/null +++ b/api/setup_env.py @@ -0,0 +1,165 @@ +import os +import re +import subprocess +import sys +from pathlib import Path + +# Conditional import of tomllib based on Python version +if sys.version_info >= (3, 11): + import tomllib # built-in in Python 3.11+ +else: + import toml as tomllib # use third-party toml library for older versions + + +class SetupEnvAndPackage: + """Utility class for cloning a repository, setting up a virtual environment, and installing a package.""" + + def setup( + self, + repo_url: str, + clone_dir: Path, + venv_dir: Path, + venv_name: str = ".venv-chebifier", + ) -> None: + """ + Orchestrates the full setup process: cloning the repository, + creating a virtual environment, and installing the package. + + Args: + repo_url (str): URL of the Git repository. + clone_dir (Path): Directory to clone the repo into. + venv_dir (Path): Directory where the virtual environment will be created. + venv_name (str): Name of the virtual environment folder. + """ + cloned_repo_path = self._clone_repo(repo_url, clone_dir) + venv_path = self._create_virtualenv(venv_dir, venv_name) + self._install_from_pyproject(venv_path, cloned_repo_path) + + def _clone_repo(self, repo_url: str, clone_dir: Path) -> Path: + """ + Clone a Git repository into a specified directory. + + Args: + repo_url (str): Git URL to clone. + clone_dir (Path): Directory to clone into. + + Returns: + Path: Path to the cloned repository. + """ + repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "") + clone_path = Path(clone_dir or repo_name) + + if not clone_path.exists(): + print(f"Cloning {repo_url} into {clone_path}...") + subprocess.check_call( + ["git", "clone", "--depth=1", repo_url, str(clone_path)] + ) + else: + print(f"Repo already exists at {clone_path}") + + return clone_path + + @staticmethod + def _create_virtualenv(venv_dir: Path, venv_name: str = ".venv-chebifier") -> Path: + """ + Create a virtual environment at the specified path. + + Args: + venv_dir (Path): Base directory where the venv will be created. + venv_name (str): Name of the virtual environment directory. + + Returns: + Path: Path to the virtual environment. + """ + venv_path = venv_dir / venv_name + + if venv_path.exists(): + print(f"Virtual environment already exists at: {venv_path}") + return venv_path + + print(f"Creating virtual environment at: {venv_path}") + + try: + subprocess.check_call(["virtualenv", str(venv_path)]) + except FileNotFoundError: + print("virtualenv not found, installing it now...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "virtualenv"] + ) + subprocess.check_call(["virtualenv", str(venv_path)]) + + return venv_path + + def _install_from_pyproject(self, venv_dir: Path, cloned_repo_path: Path) -> None: + """ + Install the cloned package in editable mode. + + Args: + venv_dir (Path): Path to the virtual environment. + cloned_repo_path (Path): Path to the cloned repository. 
+ """ + pip_executable = ( + venv_dir / "Scripts" / "pip.exe" + if os.name == "nt" + else venv_dir / "bin" / "pip" + ) + + if not pip_executable.exists(): + raise FileNotFoundError(f"pip not found at {pip_executable}") + + try: + package_name = self._get_package_name(cloned_repo_path) + except Exception as e: + raise RuntimeError(f"Error extracting package name: {e}") + + try: + subprocess.check_output( + [str(pip_executable), "show", package_name], stderr=subprocess.DEVNULL + ) + print(f"Package '{package_name}' is already installed.") + except subprocess.CalledProcessError: + print(f"Installing '{package_name}' from {cloned_repo_path}...") + subprocess.check_call( + [str(pip_executable), "install", "-e", "."], + cwd=cloned_repo_path, + ) + + @staticmethod + def _get_package_name(cloned_repo_path: Path) -> str: + """ + Extracts the package name from `pyproject.toml` or `setup.py`. + + Args: + cloned_repo_path (Path): Path to the cloned repository. + + Returns: + str: Name of the Python package. + + Raises: + ValueError: If parsing fails. + FileNotFoundError: If neither config file is found. + """ + pyproject_path = cloned_repo_path / "pyproject.toml" + setup_path = cloned_repo_path / "setup.py" + + if pyproject_path.exists(): + try: + with pyproject_path.open("rb") as f: + pyproject = tomllib.load(f) + return pyproject["project"]["name"] + except Exception as e: + raise ValueError(f"Failed to parse pyproject.toml: {e}") + + elif setup_path.exists(): + try: + setup_contents = setup_path.read_text() + match = re.search(r'name\s*=\s*[\'"]([^\'"]+)[\'"]', setup_contents) + if match: + return match.group(1) + else: + raise ValueError("Could not find package name in setup.py") + except Exception as e: + raise ValueError(f"Failed to parse setup.py: {e}") + + else: + raise FileNotFoundError("Neither pyproject.toml nor setup.py found.") From 2b9f335c060040c6f86242435fa8bd05d3678ea6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 00:37:03 +0200 Subject: [PATCH 04/38] `gather_predictions` will return predicted_classes_dict --- chebifier/ensemble/base_ensemble.py | 70 ++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index edaaf5e..1869923 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,8 +1,8 @@ import os from abc import ABC + import torch import tqdm -from rdkit import Chem from chebifier.prediction_models.base_predictor import BasePredictor from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor @@ -12,11 +12,11 @@ MODEL_TYPES = { "electra": ElectraPredictor, "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor + "chemlog": ChemLogPredictor, } -class BaseEnsemble(ABC): +class BaseEnsemble(ABC): def __init__(self, model_configs: dict): self.models = [] self.positive_prediction_threshold = 0.5 @@ -37,22 +37,30 @@ def gather_predictions(self, smiles_list): if logits_for_smiles is not None: for cls in logits_for_smiles: predicted_classes.add(cls) - print(f"Sorting predictions...") + print("Sorting predictions...") predicted_classes = sorted(list(predicted_classes)) predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)} - ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan + ordered_logits = ( + torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) + * torch.nan + ) for i, model_prediction in 
enumerate(model_predictions): - for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction), - total=len(model_prediction), - desc=f"Sorting predictions for {self.models[i].model_name}"): + for j, logits_for_smiles in tqdm.tqdm( + enumerate(model_prediction), + total=len(model_prediction), + desc=f"Sorting predictions for {self.models[i].model_name}", + ): if logits_for_smiles is not None: for cls in logits_for_smiles: - ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls] + ordered_logits[j, predicted_classes_dict[cls], i] = ( + logits_for_smiles[cls] + ) - return ordered_logits, predicted_classes + return ordered_logits, predicted_classes_dict - - def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs): + def consolidate_predictions( + self, predictions, predicted_classes, classwise_weights, **kwargs + ): """ Aggregates predictions from multiple models using weighted majority voting. Optimized version using tensor operations instead of for loops. @@ -74,7 +82,9 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig positive_mask = (predictions > 0.5) & valid_predictions negative_mask = (predictions < 0.5) & valid_predictions - confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold) + confidence = 2 * torch.abs( + predictions.nan_to_num() - self.positive_prediction_threshold + ) # Extract positive and negative weights pos_weights = classwise_weights[0] # Shape: (num_classes, num_models) @@ -83,8 +93,12 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig # Calculate weighted predictions using broadcasting # predictions shape: (num_smiles, num_classes, num_models) # weights shape: (num_classes, num_models) - positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0) - negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0) + positive_weighted = ( + positive_mask.float() * confidence * pos_weights.unsqueeze(0) + ) + negative_weighted = ( + negative_mask.float() * confidence * neg_weights.unsqueeze(0) + ) # Sum over models dimension positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) @@ -92,17 +106,21 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) - class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) + class_decisions = ( + net_score > 0 + ) & has_valid_predictions # Shape: (num_smiles, num_classes) # Convert tensor decisions to result list using list comprehension for efficiency result = [ - [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]] + [ + class_indices[idx.item()] + for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0] + ] for i in range(num_smiles) ] return result - def calculate_classwise_weights(self, predicted_classes): """No weights, simple majority voting""" positive_weights = torch.ones(len(predicted_classes), len(self.models)) @@ -114,18 +132,26 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): - 
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list) + ordered_predictions, predicted_classes = self.gather_predictions( + smiles_list + ) # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: for cls in predicted_classes: f.write(f"{cls}\n") else: - print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}") + print( + f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}" + ) ordered_predictions = torch.load(preds_file) with open(predicted_classes_file, "r") as f: - predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} + predicted_classes = { + line.strip(): i for i, line in enumerate(f.readlines()) + } classwise_weights = self.calculate_classwise_weights(predicted_classes) - aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights) + aggregated_predictions = self.consolidate_predictions( + ordered_predictions, predicted_classes, classwise_weights + ) return aggregated_predictions From 6faf3bd67298b6e33efc226f1ff612b1e4846498 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 10:25:52 +0200 Subject: [PATCH 05/38] use package namespace imports for prediction models --- chebifier/ensemble/base_ensemble.py | 10 ++++++---- chebifier/prediction_models/__init__.py | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 1869923..d4a4fe3 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,10 +4,12 @@ import torch import tqdm -from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor -from chebifier.prediction_models.electra_predictor import ElectraPredictor -from chebifier.prediction_models.gnn_predictor import ResGatedPredictor +from chebifier.prediction_models import ( + BasePredictor, + ChemLogPredictor, + ElectraPredictor, + ResGatedPredictor, +) MODEL_TYPES = { "electra": ElectraPredictor, diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index e69de29..ed08890 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -0,0 +1,6 @@ +from .base_predictor import BasePredictor +from .chemlog_predictor import ChemLogPredictor +from .electra_predictor import ElectraPredictor +from .gnn_predictor import ResGatedPredictor + +__all__ = ["BasePredictor", "ChemLogPredictor", "ElectraPredictor", "ResGatedPredictor"] From a4f5f85cc4eaaeffe4a6a5973826e49ba43f520c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:20:22 +0200 Subject: [PATCH 06/38] add hugging face api --- api/hugging_face.py | 58 ++++++++++++++------------------------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/api/hugging_face.py b/api/hugging_face.py index cfcf97c..19debb4 100644 --- a/api/hugging_face.py +++ b/api/hugging_face.py @@ -3,54 +3,32 @@ from huggingface_hub import hf_hub_download -# Updated registry: use a list of filenames if you're downloading a folder -MODEL_REGISTRY = { - "electra": { - "repo_id": "aditya0by0/python-chebifier", - "subfolder": "electra", - "filenames": ["electra.ckpt", "classes.txt"], - } -} -DOWNLOAD_PATH = Path(__file__).resolve().parent / "api_models" - - -def download_model(model_name): - if model_name not 
in MODEL_REGISTRY: - raise ValueError( - f"Unknown model name. Available models: {list(MODEL_REGISTRY.keys())}" - ) - - model_info = MODEL_REGISTRY[model_name] - repo_id = model_info["repo_id"] - subfolder = model_info["subfolder"] - filenames = model_info["filenames"] - - local_paths = [] - for filename in filenames: - local_model_path = DOWNLOAD_PATH / model_name / filename - if local_model_path.exists(): - print(f"File already exists: {local_model_path}") - local_paths.append(local_model_path) +def download_model_files(model_config: dict, download_path: Path): + repo_id = model_config["repo_id"] + subfolder = model_config["subfolder"] + filenames = model_config["files"] + + local_paths = {} + for file_type, filename in filenames.items(): + local_file_path = download_path / filename + if local_file_path.exists(): + print(f"File already exists: {local_file_path}") + local_paths[file_type] = local_file_path continue - print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})") + print( + f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}" + ) downloaded_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder, ) - local_model_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(downloaded_file, local_model_path) - print(f"Saved to: {local_model_path}") - local_paths.append(local_model_path) + local_file_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(downloaded_file, local_file_path) + print(f"Saved to: {local_file_path}") + local_paths[file_type] = local_file_path return local_paths - - -if __name__ == "__main__": - paths = download_model("electra") - print("Downloaded files:") - for p in paths: - print(p) From 481a2eb6da4fd12de4f98ecd3efdb6268e2affcc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:20:57 +0200 Subject: [PATCH 07/38] api registry --- api/registry.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 api/registry.yml diff --git a/api/registry.yml b/api/registry.yml new file mode 100644 index 0000000..c9069a8 --- /dev/null +++ b/api/registry.yml @@ -0,0 +1,23 @@ +electra: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt + repo_url: https://github.com/ChEB-AI/python-chebai + wrapper: chebifier.prediction_models.ElectraPredictor + +resgated: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: resgated + files: + ckpt: resgated.ckpt + labels: classes.txt + repo_url: https://github.com/ChEB-AI/python-chebai-graph + wrapper: chebifier.prediction_models.ResGatedPredictor + +chemlog: + repo_url: https://github.com/sfluegel05/chemlog-peptides + wrapper: chebifier.prediction_models.ChemLogPredictor From 584b6a6ac21b230732f7200078100e13896657c9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:21:19 +0200 Subject: [PATCH 08/38] api cli --- api/__main__.py | 10 +++++ api/cli.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 api/__main__.py create mode 100644 api/cli.py diff --git a/api/__main__.py b/api/__main__.py new file mode 100644 index 0000000..ec70a17 --- /dev/null +++ b/api/__main__.py @@ -0,0 +1,10 @@ +from .cli import cli + +if __name__ == "__main__": + """ + Entry point for the CLI application. + + This script calls the `cli` function from the `api.cli` module + when executed as the main program. 
+ """ + cli() diff --git a/api/cli.py b/api/cli.py new file mode 100644 index 0000000..6f94431 --- /dev/null +++ b/api/cli.py @@ -0,0 +1,114 @@ +import importlib +from pathlib import Path + +import click +import yaml + +from chebifier.prediction_models.base_predictor import BasePredictor + +from .hugging_face import download_model_files +from .setup_env import SetupEnvAndPackage + +yaml_path = Path("api/registry.yml") +if yaml_path.exists(): + with yaml_path.open("r") as f: + model_registry = yaml.safe_load(f) +else: + raise FileNotFoundError(f"{yaml_path} not found.") + + +@click.group() +def cli(): + """Command line interface for Api-Chebifier.""" + pass + + +@cli.command() +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--model-type", + "-m", + type=click.Choice(model_registry.keys()), + default="mv", + help="Type of model to use", +) +def predict(smiles, smiles_file, output, model_type): + """Predict ChEBI classes for SMILES strings using an ensemble model. + + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. + """ + + # Collect SMILES strings from arguments and/or file + smiles_list = list(smiles) + if smiles_file: + with open(smiles_file, "r") as f: + smiles_list.extend([line.strip() for line in f if line.strip()]) + + if not smiles_list: + click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") + return + + model_config = model_registry[model_type] + predictor_kwargs = {"model_name": model_type} + + current_dir = Path(__file__).resolve().parent + + if "hugging_face" in model_config: + local_file_path = download_model_files( + model_config["hugging_face"], + current_dir / ".api_models" / model_type, + ) + predictor_kwargs["ckpt_path"] = local_file_path["ckpt"] + predictor_kwargs["target_labels_path"] = local_file_path["labels"] + + SetupEnvAndPackage().setup( + repo_url=model_config["repo_url"], + clone_dir=current_dir / ".cloned_repos", + venv_dir=current_dir, + ) + + model_cls_path = model_config["wrapper"] + module_path, class_name = model_cls_path.rsplit(".", 1) + module = importlib.import_module(module_path) + model_cls: type = getattr(module, class_name) + model_instance = model_cls(**predictor_kwargs) + assert isinstance(model_instance, BasePredictor) + + # Make predictions + predictions = model_instance.predict_smiles_list(smiles_list) + + if output: + # save as json + import json + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) + + else: + # Print results + for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)): + click.echo(f"Result for: {smiles}") + if prediction: + click.echo(f" Predicted classes: {', '.join(map(str, prediction))}") + else: + click.echo(" No predictions") + + +if __name__ == "__main__": + cli() From 05d8580358038047aa387c2080f2e4f78aa84196 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:21:33 +0200 Subject: [PATCH 09/38] Update .gitignore --- .gitignore | 180 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/.gitignore b/.gitignore index 905568c..90044ae 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,181 @@ +# Byte-compiled / 
optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt + +build +.virtual_documents +.jupyter +chebai.egg-info +lightning_logs +logs +.isort.cfg +/.vscode /api/api_models +/api/.api_models +/api/.cloned_repos From 997120e9574c260b3c8199aa20cc05ac67032193 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 19:20:17 +0200 Subject: [PATCH 10/38] use hugging face's cache system instead of custom file management --- .gitignore | 2 -- api/cli.py | 6 ++---- api/hugging_face.py | 50 +++++++++++++++++++++++++++++---------------- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 90044ae..613c70b 100644 --- a/.gitignore +++ b/.gitignore @@ -176,6 +176,4 @@ lightning_logs logs .isort.cfg /.vscode -/api/api_models -/api/.api_models /api/.cloned_repos diff --git a/api/cli.py b/api/cli.py index 6f94431..99f7572 100644 --- a/api/cli.py +++ b/api/cli.py @@ -66,10 +66,8 @@ def predict(smiles, smiles_file, output, model_type): current_dir = Path(__file__).resolve().parent if "hugging_face" in model_config: - local_file_path = download_model_files( - model_config["hugging_face"], - current_dir / ".api_models" / model_type, - ) + print(f"For model type `{model_type}` following files are used:") + local_file_path = download_model_files(model_config["hugging_face"]) predictor_kwargs["ckpt_path"] = local_file_path["ckpt"] predictor_kwargs["target_labels_path"] = local_file_path["labels"] diff --git a/api/hugging_face.py b/api/hugging_face.py index 19debb4..62d16e8 100644 --- a/api/hugging_face.py +++ b/api/hugging_face.py @@ -1,34 +1,48 @@ -import shutil +""" +Hugging Face Api: + - For Windows Users check: https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache#limitations + + Refer for Hugging Face Hub caching and versioning documentation: + https://huggingface.co/docs/huggingface_hub/en/guides/download + https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache +""" + from pathlib import Path from huggingface_hub import hf_hub_download -def download_model_files(model_config: dict, download_path: Path): +def download_model_files( + model_config: dict[str, str | dict[str, str]], +) -> dict[str, Path]: + """ + Downloads specified model files from a Hugging Face Hub repository using hf_hub_download. + + Hugging Face Hub provides internal caching and versioning, so file management or duplication + checks are not required. + + Args: + model_config (Dict[str, str | Dict[str, str]]): A dictionary containing: + - 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname'). + - 'subfolder' (str): The subfolder within the repo where the files are located. + - 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to + actual file names (e.g., 'electra.ckpt', 'classes.txt'). + + Returns: + Dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file. 
+ """ repo_id = model_config["repo_id"] subfolder = model_config["subfolder"] filenames = model_config["files"] - local_paths = {} + local_paths: dict[str, Path] = {} for file_type, filename in filenames.items(): - local_file_path = download_path / filename - if local_file_path.exists(): - print(f"File already exists: {local_file_path}") - local_paths[file_type] = local_file_path - continue - - print( - f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}" - ) - downloaded_file = hf_hub_download( + downloaded_file_path = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder, ) - - local_file_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(downloaded_file, local_file_path) - print(f"Saved to: {local_file_path}") - local_paths[file_type] = local_file_path + local_paths[file_type] = Path(downloaded_file_path) + print(f"\t Using file `{filename}` from: {downloaded_file_path}") return local_paths From 9c3beea542985ca28e38e8a69843bec83a6e77e7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 19:33:08 +0200 Subject: [PATCH 11/38] pre-commit -run -a --- chebifier/cli.py | 63 +++++++++++++------ .../ensemble/weighted_majority_ensemble.py | 21 ++++--- chebifier/prediction_models/base_predictor.py | 18 ++++-- .../prediction_models/chemlog_predictor.py | 33 ++++++---- .../prediction_models/electra_predictor.py | 11 ++-- chebifier/prediction_models/gnn_predictor.py | 54 +++++++++++----- chebifier/prediction_models/nn_predictor.py | 46 ++++++++++---- 7 files changed, 169 insertions(+), 77 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 704f8a0..a6b8743 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,11 +1,11 @@ - - - import click import yaml -import sys + from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble +from chebifier.ensemble.weighted_majority_ensemble import ( + WMVwithF1Ensemble, + WMVwithPPVNPVEnsemble, +) @click.group() @@ -13,36 +13,54 @@ def cli(): """Command line interface for Chebifier.""" pass + ENSEMBLES = { "mv": BaseEnsemble, "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble + "wmv-f1": WMVwithF1Ensemble, } + @cli.command() -@click.argument('config_file', type=click.Path(exists=True)) -@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') -@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') -@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') -@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') +@click.argument("config_file", type=click.Path(exists=True)) +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--ensemble-type", + "-e", + type=click.Choice(ENSEMBLES.keys()), + default="mv", + help="Type of ensemble to use (default: Majority Voting)", +) def predict(config_file, smiles, smiles_file, output, ensemble_type): """Predict ChEBI classes for SMILES strings using an ensemble model. 
- + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. """ # Load configuration from YAML file - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = yaml.safe_load(f) - + # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type](config) - + # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) if smiles_file: - with open(smiles_file, 'r') as f: + with open(smiles_file, "r") as f: smiles_list.extend([line.strip() for line in f if line.strip()]) - + if not smiles_list: click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") return @@ -53,8 +71,13 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type): if output: # save as json import json - with open(output, 'w') as f: - json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2) + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) else: # Print results @@ -66,5 +89,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type): click.echo(" No predictions") -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/chebifier/ensemble/weighted_majority_ensemble.py b/chebifier/ensemble/weighted_majority_ensemble.py index 811770d..95e9956 100644 --- a/chebifier/ensemble/weighted_majority_ensemble.py +++ b/chebifier/ensemble/weighted_majority_ensemble.py @@ -3,9 +3,7 @@ from chebifier.ensemble.base_ensemble import BaseEnsemble - class WMVwithPPVNPVEnsemble(BaseEnsemble): - def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -23,15 +21,18 @@ def calculate_classwise_weights(self, predicted_classes): positive_weights[predicted_classes[cls], j] *= weights["PPV"] negative_weights[predicted_classes[cls], j] *= weights["NPV"] - print(f"Calculated model weightings. The averages for positive / negative weights are:") + print( + "Calculated model weightings. The averages for positive / negative weights are:" + ) for i, model in enumerate(self.models): - print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}") + print( + f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}" + ) return positive_weights, negative_weights class WMVwithF1Ensemble(BaseEnsemble): - def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -45,11 +46,15 @@ def calculate_classwise_weights(self, predicted_classes): continue for cls, weights in model.classwise_weights.items(): if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0: - f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + f1 = ( + 2 + * weights["TP"] + / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + ) weights_by_cls[predicted_classes[cls], j] *= f1 - print(f"Calculated model weightings. The average weights are:") + print("Calculated model weightings. 
The average weights are:") for i, model in enumerate(self.models): print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}") - return weights_by_cls, weights_by_cls \ No newline at end of file + return weights_by_cls, weights_by_cls diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index 5633458..e6b7952 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,16 +1,24 @@ -from abc import ABC import json +from abc import ABC + class BasePredictor(ABC): - def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs): + def __init__( + self, + model_name: str, + model_weight: int = 1, + classwise_weights_path: str = None, + **kwargs + ): self.model_name = model_name self.model_weight = model_weight if classwise_weights_path is not None: - self.classwise_weights = json.load(open(classwise_weights_path, encoding="utf-8")) + self.classwise_weights = json.load( + open(classwise_weights_path, encoding="utf-8") + ) else: self.classwise_weights = None - def predict_smiles_list(self, smiles_list: list[str]) -> dict: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 54b020a..692c79c 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -1,23 +1,22 @@ import tqdm +from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call from chebifier.prediction_models.base_predictor import BasePredictor -from chemlog.alg_classification.charge_classifier import AlgChargeClassifier -from chemlog.alg_classification.peptide_size_classifier import AlgPeptideSizeClassifier -from chemlog.alg_classification.proteinogenics_classifier import AlgProteinogenicsClassifier -from chemlog.alg_classification.substructure_classifier import AlgSubstructureClassifier -from chemlog.cli import strategy_call, _smiles_to_mol, CLASSIFIERS -class ChemLogPredictor(BasePredictor): +class ChemLogPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) self.strategy = "algo" self.classifier_instances = { k: v() for k, v in CLASSIFIERS[self.strategy].items() } - self.peptide_labels = ["15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", - "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"] - + # fmt: off + self.peptide_labels = [ + "15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", + "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837" + ] + # fmt: on print(f"Initialised ChemLog model {self.model_name}") def predict_smiles_list(self, smiles_list: list[str]) -> list: @@ -27,9 +26,21 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: if mol is None: results.append(None) else: - results.append({label: 1 if label in strategy_call(self.strategy, self.classifier_instances, mol)["chebi_classes"] else 0 for label in self.peptide_labels}) + results.append( + { + label: ( + 1 + if label + in strategy_call( + self.strategy, self.classifier_instances, mol + )["chebi_classes"] + else 0 + ) + for label in self.peptide_labels + } + ) for classifier in self.classifier_instances.values(): classifier.on_finish() - return results \ No newline at end of file + return results diff --git 
a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py index 075eafa..7a3bcaa 100644 --- a/chebifier/prediction_models/electra_predictor.py +++ b/chebifier/prediction_models/electra_predictor.py @@ -1,7 +1,8 @@ -from chebifier.prediction_models.nn_predictor import NNPredictor from chebai.models.electra import Electra from chebai.preprocessing.reader import ChemDataReader +from chebifier.prediction_models.nn_predictor import NNPredictor + class ElectraPredictor(NNPredictor): @@ -13,10 +14,10 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra: model = Electra.load_from_checkpoint( ckpt_path, map_location=self.device, - criterion=None, strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, ) model.eval() return model - - diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index b139c6c..ef354c1 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -1,16 +1,19 @@ -from chebifier.prediction_models.nn_predictor import NNPredictor import chebai_graph.preprocessing.properties as p import torch from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred -from chebai_graph.preprocessing.reader import GraphPropertyReader from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder +from chebai_graph.preprocessing.reader import GraphPropertyReader from torch_geometric.data.data import Data as GeomData +from chebifier.prediction_models.nn_predictor import NNPredictor + class ResGatedPredictor(NNPredictor): def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): - super().__init__(model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs) + super().__init__( + model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs + ) # molecular_properties is a list of class paths if molecular_properties is not None: properties = [self.load_class(prop)() for prop in molecular_properties] @@ -32,11 +35,23 @@ def load_class(self, class_path: str): def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( - ckpt_path, map_location=torch.device(self.device), criterion=None, strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None, - config={"in_length": 256, "hidden_length": 512, "dropout_rate": 0.1, "n_conv_layers": 3, - "n_linear_layers": 3, "n_atom_properties": 158, "n_bond_properties": 7, - "n_molecule_properties": 200}) + ckpt_path, + map_location=torch.device(self.device), + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, + config={ + "in_length": 256, + "hidden_length": 512, + "dropout_rate": 0.1, + "n_conv_layers": 3, + "n_linear_layers": 3, + "n_atom_properties": 158, + "n_bond_properties": 7, + "n_molecule_properties": 200, + }, + ) model.eval() return model @@ -55,14 +70,21 @@ def read_smiles(self, smiles): # use default value if we meet an unseen value if isinstance(prop.encoder, IndexEncoder): if str(value) in prop.encoder.cache: - index = prop.encoder.cache.index(str(value)) + prop.encoder.offset + index = ( + prop.encoder.cache.index(str(value)) + prop.encoder.offset + ) else: index = 0 - 
print(f"Unknown property value {value} for property {prop} at smiles {smiles}") + print( + f"Unknown property value {value} for property {prop} at smiles {smiles}" + ) if isinstance(prop.encoder, OneHotEncoder): - encoded_values.append(torch.nn.functional.one_hot( - torch.tensor(index), num_classes=prop.encoder.get_encoding_length() - )) + encoded_values.append( + torch.nn.functional.one_hot( + torch.tensor(index), + num_classes=prop.encoder.get_encoding_length(), + ) + ) else: encoded_values.append(torch.tensor([index])) @@ -77,9 +99,7 @@ def read_smiles(self, smiles): if len(encoded_values.size()) == 1: encoded_values = encoded_values.unsqueeze(1) else: - encoded_values = torch.zeros( - (0, prop.encoder.get_encoding_length()) - ) + encoded_values = torch.zeros((0, prop.encoder.get_encoding_length())) if isinstance(prop, p.AtomProperty): x = torch.cat([x, encoded_values], dim=1) elif isinstance(prop, p.BondProperty): @@ -93,4 +113,4 @@ def read_smiles(self, smiles): edge_attr=edge_attr, molecule_attr=molecule_attr, ) - return d \ No newline at end of file + return d diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 1ee5e46..9f2e00a 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -1,24 +1,35 @@ +import numpy as np +import torch import tqdm +from rdkit import Chem from chebifier.prediction_models.base_predictor import BasePredictor -from rdkit import Chem -import numpy as np -import torch + class NNPredictor(BasePredictor): - def __init__(self, model_name: str, ckpt_path: str, reader_cls, target_labels_path: str, **kwargs): + def __init__( + self, + model_name: str, + ckpt_path: str, + reader_cls, + target_labels_path: str, + **kwargs, + ): super().__init__(model_name, **kwargs) self.reader_cls = reader_cls self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self.init_model(ckpt_path=ckpt_path) - self.target_labels = [line.strip() for line in open(target_labels_path, encoding="utf-8")] + self.target_labels = [ + line.strip() for line in open(target_labels_path, encoding="utf-8") + ] self.batch_size = kwargs.get("batch_size", 1) - def init_model(self, ckpt_path: str, **kwargs): - raise NotImplementedError("Model initialization must be implemented in subclasses.") + raise NotImplementedError( + "Model initialization must be implemented in subclasses." 
+ ) def calculate_results(self, batch): collator = self.reader_cls.COLLATOR() @@ -66,14 +77,27 @@ def predict_smiles_list(self, smiles_list) -> list: token_dicts.append(d) results = [] if token_dicts: - for batch in tqdm.tqdm(self.batchify(token_dicts), desc=f"{self.model_name}", total=len(token_dicts)//self.batch_size): + for batch in tqdm.tqdm( + self.batchify(token_dicts), + desc=f"{self.model_name}", + total=len(token_dicts) // self.batch_size, + ): result = self.calculate_results(batch) if isinstance(result, dict) and "logits" in result: result = result["logits"] results += torch.sigmoid(result).cpu().detach().tolist() results = np.stack(results, axis=0) - preds = [{self.target_labels[j]: p for j, p in enumerate(results[index_map[i]])} - if i not in could_not_parse else None for i in range(len(smiles_list))] + preds = [ + ( + { + self.target_labels[j]: p + for j, p in enumerate(results[index_map[i]]) + } + if i not in could_not_parse + else None + ) + for i in range(len(smiles_list)) + ] return preds else: - return [None for _ in smiles_list] \ No newline at end of file + return [None for _ in smiles_list] From e6602ef24249d20634fbdf539e8117696973946b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Jul 2025 23:13:34 +0200 Subject: [PATCH 12/38] remove explicit config kwargs for resgated --- chebifier/prediction_models/gnn_predictor.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index ef354c1..57afcfc 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -9,7 +9,6 @@ class ResGatedPredictor(NNPredictor): - def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): super().__init__( model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs @@ -41,16 +40,6 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: strict=False, metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None, - config={ - "in_length": 256, - "hidden_length": 512, - "dropout_rate": 0.1, - "n_conv_layers": 3, - "n_linear_layers": 3, - "n_atom_properties": 158, - "n_bond_properties": 7, - "n_molecule_properties": 200, - }, ) model.eval() return model From fd814e928a62d17899859942c58f684613df2a1d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 18:30:46 +0200 Subject: [PATCH 13/38] api support for ensemble --- api/api_registry.yml | 24 ++++ api/check_env.py | 30 +++++ api/cli.py | 67 ++++++----- api/hugging_face.py | 5 +- api/registry.yml | 23 ---- api/setup_env.py | 165 ---------------------------- chebifier/cli.py | 13 +-- chebifier/ensemble/base_ensemble.py | 16 +-- chebifier/model_registry.py | 29 +++++ 9 files changed, 130 insertions(+), 242 deletions(-) create mode 100644 api/api_registry.yml create mode 100644 api/check_env.py delete mode 100644 api/registry.yml delete mode 100644 api/setup_env.py create mode 100644 chebifier/model_registry.py diff --git a/api/api_registry.yml b/api/api_registry.yml new file mode 100644 index 0000000..b6e30bd --- /dev/null +++ b/api/api_registry.yml @@ -0,0 +1,24 @@ +electra: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt + package_name: chebai + +resgated: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: resgated + files: + ckpt: resgated.ckpt + labels: classes.txt + package_name: chebai-graph + +chemlog: 
+ package_name: chemlog + + +en_mv: + ensemble_of: {electra, chemlog} diff --git a/api/check_env.py b/api/check_env.py new file mode 100644 index 0000000..b215fda --- /dev/null +++ b/api/check_env.py @@ -0,0 +1,30 @@ +import subprocess +import sys + + +def get_current_environment() -> str: + """ + Return the path of the Python executable for the current environment. + """ + return sys.executable + + +def check_package_installed(package_name: str) -> None: + """ + Check if the given package is installed in the current Python environment. + """ + python_exec = get_current_environment() + try: + subprocess.check_output( + [python_exec, "-m", "pip", "show", package_name], stderr=subprocess.DEVNULL + ) + print(f"✅ Package '{package_name}' is already installed.") + except subprocess.CalledProcessError: + raise ( + f"❌ Please install '{package_name}' into your environment: {python_exec}" + ) + + +if __name__ == "__main__": + print(f"🔍 Using Python executable: {get_current_environment()}") + check_package_installed("numpy") # Replace with your desired package diff --git a/api/cli.py b/api/cli.py index 99f7572..e20ed14 100644 --- a/api/cli.py +++ b/api/cli.py @@ -1,18 +1,17 @@ -import importlib from pathlib import Path import click import yaml -from chebifier.prediction_models.base_predictor import BasePredictor +from chebifier.model_registry import ENSEMBLES, MODEL_TYPES +from .check_env import check_package_installed, get_current_environment from .hugging_face import download_model_files -from .setup_env import SetupEnvAndPackage -yaml_path = Path("api/registry.yml") +yaml_path = Path("api/api_registry.yml") if yaml_path.exists(): with yaml_path.open("r") as f: - model_registry = yaml.safe_load(f) + api_registry = yaml.safe_load(f) else: raise FileNotFoundError(f"{yaml_path} not found.") @@ -40,7 +39,7 @@ def cli(): @click.option( "--model-type", "-m", - type=click.Choice(model_registry.keys()), + type=click.Choice(api_registry.keys()), default="mv", help="Type of model to use", ) @@ -60,29 +59,39 @@ def predict(smiles, smiles_file, output, model_type): click.echo("No SMILES strings provided. 
Use --smiles or --smiles-file options.") return - model_config = model_registry[model_type] - predictor_kwargs = {"model_name": model_type} - - current_dir = Path(__file__).resolve().parent - - if "hugging_face" in model_config: - print(f"For model type `{model_type}` following files are used:") - local_file_path = download_model_files(model_config["hugging_face"]) - predictor_kwargs["ckpt_path"] = local_file_path["ckpt"] - predictor_kwargs["target_labels_path"] = local_file_path["labels"] - - SetupEnvAndPackage().setup( - repo_url=model_config["repo_url"], - clone_dir=current_dir / ".cloned_repos", - venv_dir=current_dir, - ) - - model_cls_path = model_config["wrapper"] - module_path, class_name = model_cls_path.rsplit(".", 1) - module = importlib.import_module(module_path) - model_cls: type = getattr(module, class_name) - model_instance = model_cls(**predictor_kwargs) - assert isinstance(model_instance, BasePredictor) + print("Current working environment is:", get_current_environment()) + + def get_individual_model(model_config): + predictor_kwargs = {} + if "hugging_face" in model_config: + predictor_kwargs = download_model_files(model_config["hugging_face"]) + check_package_installed(model_config["package_name"]) + return predictor_kwargs + + if model_type in MODEL_TYPES: + print(f"Predictor for Single/Individual Model: {model_type}") + model_config = api_registry[model_type] + predictor_kwargs = get_individual_model(model_config) + predictor_kwargs["model_name"] = model_type + model_instance = MODEL_TYPES[model_type](**predictor_kwargs) + + elif model_type in ENSEMBLES: + print(f"Predictor for Ensemble Model: {model_type}") + ensemble_config = {} + for i, en_comp in enumerate(api_registry[model_type]["ensemble_of"]): + assert en_comp in MODEL_TYPES + print(f"For ensemble component {en_comp}") + predictor_kwargs = get_individual_model(api_registry[en_comp]) + model_key = f"model_{i + 1}" + ensemble_config[model_key] = { + "type": en_comp, + "model_name": f"{en_comp}_{model_key}", + **predictor_kwargs, + } + model_instance = ENSEMBLES[model_type](ensemble_config) + + else: + raise ValueError("") # Make predictions predictions = model_instance.predict_smiles_list(smiles_list) diff --git a/api/hugging_face.py b/api/hugging_face.py index 62d16e8..5569d86 100644 --- a/api/hugging_face.py +++ b/api/hugging_face.py @@ -45,4 +45,7 @@ def download_model_files( local_paths[file_type] = Path(downloaded_file_path) print(f"\t Using file `{filename}` from: {downloaded_file_path}") - return local_paths + return { + "ckpt_path": local_paths["ckpt"], + "target_labels_path": local_paths["labels"], + } diff --git a/api/registry.yml b/api/registry.yml deleted file mode 100644 index c9069a8..0000000 --- a/api/registry.yml +++ /dev/null @@ -1,23 +0,0 @@ -electra: - hugging_face: - repo_id: aditya0by0/python-chebifier - subfolder: electra - files: - ckpt: electra.ckpt - labels: classes.txt - repo_url: https://github.com/ChEB-AI/python-chebai - wrapper: chebifier.prediction_models.ElectraPredictor - -resgated: - hugging_face: - repo_id: aditya0by0/python-chebifier - subfolder: resgated - files: - ckpt: resgated.ckpt - labels: classes.txt - repo_url: https://github.com/ChEB-AI/python-chebai-graph - wrapper: chebifier.prediction_models.ResGatedPredictor - -chemlog: - repo_url: https://github.com/sfluegel05/chemlog-peptides - wrapper: chebifier.prediction_models.ChemLogPredictor diff --git a/api/setup_env.py b/api/setup_env.py deleted file mode 100644 index a246c26..0000000 --- a/api/setup_env.py +++ /dev/null 
@@ -1,165 +0,0 @@ -import os -import re -import subprocess -import sys -from pathlib import Path - -# Conditional import of tomllib based on Python version -if sys.version_info >= (3, 11): - import tomllib # built-in in Python 3.11+ -else: - import toml as tomllib # use third-party toml library for older versions - - -class SetupEnvAndPackage: - """Utility class for cloning a repository, setting up a virtual environment, and installing a package.""" - - def setup( - self, - repo_url: str, - clone_dir: Path, - venv_dir: Path, - venv_name: str = ".venv-chebifier", - ) -> None: - """ - Orchestrates the full setup process: cloning the repository, - creating a virtual environment, and installing the package. - - Args: - repo_url (str): URL of the Git repository. - clone_dir (Path): Directory to clone the repo into. - venv_dir (Path): Directory where the virtual environment will be created. - venv_name (str): Name of the virtual environment folder. - """ - cloned_repo_path = self._clone_repo(repo_url, clone_dir) - venv_path = self._create_virtualenv(venv_dir, venv_name) - self._install_from_pyproject(venv_path, cloned_repo_path) - - def _clone_repo(self, repo_url: str, clone_dir: Path) -> Path: - """ - Clone a Git repository into a specified directory. - - Args: - repo_url (str): Git URL to clone. - clone_dir (Path): Directory to clone into. - - Returns: - Path: Path to the cloned repository. - """ - repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "") - clone_path = Path(clone_dir or repo_name) - - if not clone_path.exists(): - print(f"Cloning {repo_url} into {clone_path}...") - subprocess.check_call( - ["git", "clone", "--depth=1", repo_url, str(clone_path)] - ) - else: - print(f"Repo already exists at {clone_path}") - - return clone_path - - @staticmethod - def _create_virtualenv(venv_dir: Path, venv_name: str = ".venv-chebifier") -> Path: - """ - Create a virtual environment at the specified path. - - Args: - venv_dir (Path): Base directory where the venv will be created. - venv_name (str): Name of the virtual environment directory. - - Returns: - Path: Path to the virtual environment. - """ - venv_path = venv_dir / venv_name - - if venv_path.exists(): - print(f"Virtual environment already exists at: {venv_path}") - return venv_path - - print(f"Creating virtual environment at: {venv_path}") - - try: - subprocess.check_call(["virtualenv", str(venv_path)]) - except FileNotFoundError: - print("virtualenv not found, installing it now...") - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "virtualenv"] - ) - subprocess.check_call(["virtualenv", str(venv_path)]) - - return venv_path - - def _install_from_pyproject(self, venv_dir: Path, cloned_repo_path: Path) -> None: - """ - Install the cloned package in editable mode. - - Args: - venv_dir (Path): Path to the virtual environment. - cloned_repo_path (Path): Path to the cloned repository. 
- """ - pip_executable = ( - venv_dir / "Scripts" / "pip.exe" - if os.name == "nt" - else venv_dir / "bin" / "pip" - ) - - if not pip_executable.exists(): - raise FileNotFoundError(f"pip not found at {pip_executable}") - - try: - package_name = self._get_package_name(cloned_repo_path) - except Exception as e: - raise RuntimeError(f"Error extracting package name: {e}") - - try: - subprocess.check_output( - [str(pip_executable), "show", package_name], stderr=subprocess.DEVNULL - ) - print(f"Package '{package_name}' is already installed.") - except subprocess.CalledProcessError: - print(f"Installing '{package_name}' from {cloned_repo_path}...") - subprocess.check_call( - [str(pip_executable), "install", "-e", "."], - cwd=cloned_repo_path, - ) - - @staticmethod - def _get_package_name(cloned_repo_path: Path) -> str: - """ - Extracts the package name from `pyproject.toml` or `setup.py`. - - Args: - cloned_repo_path (Path): Path to the cloned repository. - - Returns: - str: Name of the Python package. - - Raises: - ValueError: If parsing fails. - FileNotFoundError: If neither config file is found. - """ - pyproject_path = cloned_repo_path / "pyproject.toml" - setup_path = cloned_repo_path / "setup.py" - - if pyproject_path.exists(): - try: - with pyproject_path.open("rb") as f: - pyproject = tomllib.load(f) - return pyproject["project"]["name"] - except Exception as e: - raise ValueError(f"Failed to parse pyproject.toml: {e}") - - elif setup_path.exists(): - try: - setup_contents = setup_path.read_text() - match = re.search(r'name\s*=\s*[\'"]([^\'"]+)[\'"]', setup_contents) - if match: - return match.group(1) - else: - raise ValueError("Could not find package name in setup.py") - except Exception as e: - raise ValueError(f"Failed to parse setup.py: {e}") - - else: - raise FileNotFoundError("Neither pyproject.toml nor setup.py found.") diff --git a/chebifier/cli.py b/chebifier/cli.py index a6b8743..b51dc04 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,11 +1,7 @@ import click import yaml -from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import ( - WMVwithF1Ensemble, - WMVwithPPVNPVEnsemble, -) +from .model_registry import ENSEMBLES @click.group() @@ -14,13 +10,6 @@ def cli(): pass -ENSEMBLES = { - "mv": BaseEnsemble, - "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble, -} - - @cli.command() @click.argument("config_file", type=click.Path(exists=True)) @click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index d4a4fe3..19f49d2 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,22 +4,14 @@ import torch import tqdm -from chebifier.prediction_models import ( - BasePredictor, - ChemLogPredictor, - ElectraPredictor, - ResGatedPredictor, -) - -MODEL_TYPES = { - "electra": ElectraPredictor, - "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor, -} +from chebifier.prediction_models import BasePredictor class BaseEnsemble(ABC): def __init__(self, model_configs: dict): + # Deferred Import: To avoid circular import error + from chebifier.model_registry import MODEL_TYPES + self.models = [] self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py new file mode 100644 index 0000000..4961f3e --- /dev/null +++ b/chebifier/model_registry.py @@ -0,0 
+1,29 @@ +from chebifier.ensemble.base_ensemble import BaseEnsemble +from chebifier.ensemble.weighted_majority_ensemble import ( + WMVwithF1Ensemble, + WMVwithPPVNPVEnsemble, +) +from chebifier.prediction_models import ( + ChemLogPredictor, + ElectraPredictor, + ResGatedPredictor, +) + +ENSEMBLES = { + "en_mv": BaseEnsemble, + "en_wmv-ppvnpv": WMVwithPPVNPVEnsemble, + "en_wmv-f1": WMVwithF1Ensemble, +} + + +MODEL_TYPES = { + "electra": ElectraPredictor, + "resgated": ResGatedPredictor, + "chemlog": ChemLogPredictor, +} + + +common_keys = MODEL_TYPES.keys() & ENSEMBLES.keys() +assert ( + not common_keys +), f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" From a044f23d6ec797645d05987eee590b9f3c46adf6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 18:34:31 +0200 Subject: [PATCH 14/38] add ruff action workflow --- .github/workflows/lint.yml | 26 ++++++++++++++++++++++++++ .pre-commit-config.yaml | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..1b63c41 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # or any version your project uses + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install black ruff + + - name: Run Black + run: black --check . + + - name: Run Ruff (no formatting) + run: ruff check . --no-fix diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e32d80c..b8a785a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,4 +28,4 @@ repos: rev: v0.12.1 hooks: - id: ruff - args: [] # No --fix, disables formatting + args: [--fix] From 51a2d348e5d2dcaf4a1b6a2b79debb4c63a47964 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 18:49:42 +0200 Subject: [PATCH 15/38] same version for workflow and pre-commit yaml --- .github/workflows/lint.yml | 2 +- .pre-commit-config.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1b63c41..bb9154f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install black ruff + pip install black==25.1.0 ruff==0.12.2 - name: Run Black run: black --check . 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8a785a..cbb7284 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: "24.2.0" + rev: "25.1.0" hooks: - id: black - id: black-jupyter # for formatting jupyter-notebook @@ -25,7 +25,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.1 + rev: v0.12.2 hooks: - id: ruff args: [--fix] From d2c586aa9e8e25a3734fa3de48351264351d615e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 19:02:48 +0200 Subject: [PATCH 16/38] Update base_predictor.py --- chebifier/prediction_models/base_predictor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index e6b7952..3eeee52 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -3,13 +3,12 @@ class BasePredictor(ABC): - def __init__( self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, - **kwargs + **kwargs, ): self.model_name = model_name self.model_weight = model_weight From f3b39052857f4357b5baa528a058ff2b8c836dc2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 11:56:19 +0200 Subject: [PATCH 17/38] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c42e4df..0559817 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ cd python-chebifier pip install -e . ``` -Some dependencies of `chebai-graph` cannot be installed automatically. If you want to use Graph Neural Networks, follow +`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph).
## Usage From 001538daf33abc584f22c4a0afccb2c62a510f33 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 11:59:41 +0200 Subject: [PATCH 18/38] fix cli and ensemble imports --- chebifier/cli.py | 9 --------- chebifier/ensemble/base_ensemble.py | 13 ++----------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 2c3ad0d..5fa9679 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -2,21 +2,12 @@ import yaml from .model_registry import ENSEMBLES -from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble - @click.group() def cli(): """Command line interface for Chebifier.""" pass -ENSEMBLES = { - "mv": BaseEnsemble, - "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble -} - @cli.command() @click.argument('config_file', type=click.Path(exists=True)) @click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 5f94d02..a071a33 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,22 +1,13 @@ import os -from abc import ABC import torch import tqdm from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor -from chebifier.prediction_models.electra_predictor import ElectraPredictor -from chebifier.prediction_models.gnn_predictor import ResGatedPredictor - -MODEL_TYPES = { - "electra": ElectraPredictor, - "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor -} -class BaseEnsemble(ABC): + +class BaseEnsemble: def __init__(self, model_configs: dict, chebi_version: int = 241): # Deferred Import: To avoid circular import error From f8583cbdfa0059378e3039cb2c763c6866e737bd Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 13:08:03 +0200 Subject: [PATCH 19/38] add huggingface download to cli --- chebifier/__main__.py | 4 ++++ chebifier/cli.py | 8 +++++--- chebifier/ensemble/base_ensemble.py | 22 +++++++++++++++------- chebifier/model_registry.py | 6 +++--- configs/huggingface_config.yml | 22 ++++++++++++++++++++++ pyproject.toml | 3 --- 6 files changed, 49 insertions(+), 16 deletions(-) create mode 100644 chebifier/__main__.py create mode 100644 configs/huggingface_config.yml diff --git a/chebifier/__main__.py b/chebifier/__main__.py new file mode 100644 index 0000000..9aebe0f --- /dev/null +++ b/chebifier/__main__.py @@ -0,0 +1,4 @@ +from chebifier.cli import cli + +if __name__ == '__main__': + cli() \ No newline at end of file diff --git a/chebifier/cli.py b/chebifier/cli.py index 5fa9679..a21ebf3 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,3 +1,5 @@ +import os + import click import yaml @@ -9,14 +11,14 @@ def cli(): pass @cli.command() -@click.argument('config_file', type=click.Path(exists=True)) +@click.option('--config_file', type=click.Path(exists=True), default=os.path.join('configs', 'huggingface_config.yml'), help="Configuration file for ensemble models") @click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') @click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') @click.option('--output', '-o', type=click.Path(), help='Output file to 
save predictions (optional)') @click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') @click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)") @click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)") -def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version): +def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence): """Predict ChEBI classes for SMILES strings using an ensemble model. CONFIG_FILE is the path to a YAML configuration file for the ensemble model. @@ -39,7 +41,7 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi return # Make predictions - predictions = ensemble.predict_smiles_list(smiles_list) + predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=use_confidence) if output: # save as json diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index a071a33..a946c2a 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,6 +4,7 @@ from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother +from api.hugging_face import download_model_files from chebifier.prediction_models.base_predictor import BasePredictor @@ -17,14 +18,20 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] - model_instance = model_cls(model_name, **model_config) + if "hugging_face" in model_config: + hugging_face_kwargs = download_model_files(model_config["hugging_face"]) + else: + hugging_face_kwargs = {} + model_instance = model_cls(model_name, **model_config, **hugging_face_kwargs) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) - self.smoother = PredictionSmoother(ChEBIOver50(chebi_version=chebi_version), disjoint_files=[ + self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) + self.chebi_dataset._download_required_data() # download chebi if not already downloaded + self.disjoint_files=[ os.path.join("data", "disjoint_chebi.csv"), os.path.join("data", "disjoint_additional.csv") - ]) + ] def gather_predictions(self, smiles_list): @@ -110,7 +117,7 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights - def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: + def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): @@ -128,11 +135,12 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights) + class_decisions = self.consolidate_predictions(ordered_predictions, 
classwise_weights, **kwargs) # Smooth predictions class_names = list(predicted_classes.keys()) - self.smoother.label_names = class_names - class_decisions = self.smoother(class_decisions) + # initialise new smoother class since we don't know the labels beforehand (this could be more efficient) + new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files) + class_decisions = new_smoother(class_decisions) class_names = list(predicted_classes.keys()) class_indices = {predicted_classes[cls]: cls for cls in class_names} diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 4961f3e..cf7d6d0 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -10,9 +10,9 @@ ) ENSEMBLES = { - "en_mv": BaseEnsemble, - "en_wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "en_wmv-f1": WMVwithF1Ensemble, + "mv": BaseEnsemble, + "wmv-ppvnpv": WMVwithPPVNPVEnsemble, + "wmv-f1": WMVwithF1Ensemble, } diff --git a/configs/huggingface_config.yml b/configs/huggingface_config.yml new file mode 100644 index 0000000..c26950d --- /dev/null +++ b/configs/huggingface_config.yml @@ -0,0 +1,22 @@ + +chemlog_peptides: + type: chemlog + model_weight: 100 + +#resgated_huggingface: +# type: resgated +# hugging_face: +# repo_id: aditya0by0/python-chebifier +# subfolder: resgated +# files: +# ckpt: resgated.ckpt +# labels: classes.txt + +electra_huggingface: + type: electra + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt diff --git a/pyproject.toml b/pyproject.toml index 8a0223f..ff7837d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,6 @@ dependencies = [ "chemlog>=1.0.4" ] -[project.scripts] -chebifier = "chebifier.cli:cli" - [tool.setuptools] packages = ["chebifier", "chebifier.ensemble", "chebifier.prediction_models"] From 90aedd43a105f83c65f2351bd59d124fc0bc0c51 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 13:11:10 +0200 Subject: [PATCH 20/38] reformat with black --- chebifier/__main__.py | 4 +- chebifier/cli.py | 86 +++++++++++++++---- chebifier/ensemble/base_ensemble.py | 125 +++++++++++++++++----------- 3 files changed, 148 insertions(+), 67 deletions(-) diff --git a/chebifier/__main__.py b/chebifier/__main__.py index 9aebe0f..22bf70c 100644 --- a/chebifier/__main__.py +++ b/chebifier/__main__.py @@ -1,4 +1,4 @@ from chebifier.cli import cli -if __name__ == '__main__': - cli() \ No newline at end of file +if __name__ == "__main__": + cli() diff --git a/chebifier/cli.py b/chebifier/cli.py index a21ebf3..11c138b 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -5,49 +5,99 @@ from .model_registry import ENSEMBLES + @click.group() def cli(): """Command line interface for Chebifier.""" pass + @cli.command() -@click.option('--config_file', type=click.Path(exists=True), default=os.path.join('configs', 'huggingface_config.yml'), help="Configuration file for ensemble models") -@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') -@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') -@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') -@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') -@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking 
consistency (default: 241)") -@click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)") -def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence): +@click.option( + "--config_file", + type=click.Path(exists=True), + default=os.path.join("configs", "huggingface_config.yml"), + help="Configuration file for ensemble models", +) +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--ensemble-type", + "-e", + type=click.Choice(ENSEMBLES.keys()), + default="mv", + help="Type of ensemble to use (default: Majority Voting)", +) +@click.option( + "--chebi-version", + "-v", + type=int, + default=241, + help="ChEBI version to use for checking consistency (default: 241)", +) +@click.option( + "--use-confidence", + "-c", + is_flag=True, + default=True, + help="Weight predictions based on how 'confident' a model is in its prediction (default: True)", +) +def predict( + config_file, + smiles, + smiles_file, + output, + ensemble_type, + chebi_version, + use_confidence, +): """Predict ChEBI classes for SMILES strings using an ensemble model. - + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. """ # Load configuration from YAML file - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = yaml.safe_load(f) - + # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version) - + # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) if smiles_file: - with open(smiles_file, 'r') as f: + with open(smiles_file, "r") as f: smiles_list.extend([line.strip() for line in f if line.strip()]) - + if not smiles_list: click.echo("No SMILES strings provided. 
Use --smiles or --smiles-file options.") return # Make predictions - predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=use_confidence) + predictions = ensemble.predict_smiles_list( + smiles_list, use_confidence=use_confidence + ) if output: # save as json import json - with open(output, 'w') as f: - json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2) + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) else: # Print results @@ -59,5 +109,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi click.echo(" No predictions") -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index a946c2a..0795fc0 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -22,18 +22,19 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): hugging_face_kwargs = download_model_files(model_config["hugging_face"]) else: hugging_face_kwargs = {} - model_instance = model_cls(model_name, **model_config, **hugging_face_kwargs) + model_instance = model_cls( + model_name, **model_config, **hugging_face_kwargs + ) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) self.chebi_dataset._download_required_data() # download chebi if not already downloaded - self.disjoint_files=[ + self.disjoint_files = [ os.path.join("data", "disjoint_chebi.csv"), - os.path.join("data", "disjoint_additional.csv") + os.path.join("data", "disjoint_additional.csv"), ] - def gather_predictions(self, smiles_list): # get predictions from all models for the SMILES list # order them by alphabetically by label class @@ -60,11 +61,12 @@ def gather_predictions(self, smiles_list): ): if logits_for_smiles is not None: for cls in logits_for_smiles: - ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls] + ordered_logits[j, predicted_classes_dict[cls], i] = ( + logits_for_smiles[cls] + ) return ordered_logits, predicted_classes - def consolidate_predictions(self, predictions, classwise_weights, **kwargs): """ Aggregates predictions from multiple models using weighted majority voting. 
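As a rough, self-contained sketch of the weighted-majority scheme this docstring describes (equal class weights as in BaseEnsemble's calculate_classwise_weights, confidence weighting enabled, and NaN assumed to mark a model that made no prediction; the tensor shape (num_smiles, num_classes, num_models) follows the comments in the hunk below):

import torch

probs = torch.tensor([[[0.9, 0.2, float("nan")]]])   # 1 SMILES, 1 class, 3 models
valid = ~torch.isnan(probs)                           # models that actually made a prediction
confidence = 2 * torch.abs(probs.nan_to_num() - 0.5)  # 0 at the 0.5 threshold, 1 at full certainty
votes_for = ((probs > 0.5) & valid).float() * confidence
votes_against = ((probs < 0.5) & valid).float() * confidence
net_score = votes_for.sum(dim=2) - votes_against.sum(dim=2)
print(net_score > 0)  # tensor([[True]]): 0.8 in favour outweighs 0.6 against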
@@ -80,11 +82,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): has_valid_predictions = valid_counts > 0 # Calculate positive and negative predictions for all classes at once - positive_mask = (predictions > self.positive_prediction_threshold) & valid_predictions - negative_mask = (predictions < self.positive_prediction_threshold) & valid_predictions + positive_mask = ( + predictions > self.positive_prediction_threshold + ) & valid_predictions + negative_mask = ( + predictions < self.positive_prediction_threshold + ) & valid_predictions if "use_confidence" in kwargs and kwargs["use_confidence"]: - confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold) + confidence = 2 * torch.abs( + predictions.nan_to_num() - self.positive_prediction_threshold + ) else: confidence = torch.ones_like(predictions) @@ -95,8 +103,12 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Calculate weighted predictions using broadcasting # predictions shape: (num_smiles, num_classes, num_models) # weights shape: (num_classes, num_models) - positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0) - negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0) + positive_weighted = ( + positive_mask.float() * confidence * pos_weights.unsqueeze(0) + ) + negative_weighted = ( + negative_mask.float() * confidence * neg_weights.unsqueeze(0) + ) # Sum over models dimension positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) @@ -104,9 +116,9 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) - class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) - - + class_decisions = ( + net_score > 0 + ) & has_valid_predictions # Shape: (num_smiles, num_classes) return class_decisions @@ -117,11 +129,15 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights - def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list: + def predict_smiles_list( + self, smiles_list, load_preds_if_possible=True, **kwargs + ) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): - ordered_predictions, predicted_classes = self.gather_predictions(smiles_list) + ordered_predictions, predicted_classes = self.gather_predictions( + smiles_list + ) # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: @@ -129,17 +145,27 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs f.write(f"{cls}\n") predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)} else: - print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}") + print( + f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}" + ) ordered_predictions = torch.load(preds_file) with open(predicted_classes_file, "r") as f: - predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} + predicted_classes = { + line.strip(): i for i, line in enumerate(f.readlines()) + } 
classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights, **kwargs) + class_decisions = self.consolidate_predictions( + ordered_predictions, classwise_weights, **kwargs + ) # Smooth predictions class_names = list(predicted_classes.keys()) # initialise new smoother class since we don't know the labels beforehand (this could be more efficient) - new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files) + new_smoother = PredictionSmoother( + self.chebi_dataset, + label_names=class_names, + disjoint_files=self.disjoint_files, + ) class_decisions = new_smoother(class_decisions) class_names = list(predicted_classes.keys()) @@ -153,31 +179,36 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs if __name__ == "__main__": - ensemble = BaseEnsemble({"resgated_0ps1g189":{ - "type": "resgated", - "ckpt_path": "data/0ps1g189/epoch=122.ckpt", - "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", - "molecular_properties": [ - "chebai_graph.preprocessing.properties.AtomType", - "chebai_graph.preprocessing.properties.NumAtomBonds", - "chebai_graph.preprocessing.properties.AtomCharge", - "chebai_graph.preprocessing.properties.AtomAromaticity", - "chebai_graph.preprocessing.properties.AtomHybridization", - "chebai_graph.preprocessing.properties.AtomNumHs", - "chebai_graph.preprocessing.properties.BondType", - "chebai_graph.preprocessing.properties.BondInRing", - "chebai_graph.preprocessing.properties.BondAromaticity", - "chebai_graph.preprocessing.properties.RDKit2DNormalized", - ], - #"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" - }, - -"electra_14ko0zcf": { - "type": "electra", - "ckpt_path": "data/14ko0zcf/epoch=193.ckpt", - "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", - #"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json", -} - }) - r = ensemble.predict_smiles_list(["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], load_preds_if_possible=False) + ensemble = BaseEnsemble( + { + "resgated_0ps1g189": { + "type": "resgated", + "ckpt_path": "data/0ps1g189/epoch=122.ckpt", + "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", + "molecular_properties": [ + "chebai_graph.preprocessing.properties.AtomType", + "chebai_graph.preprocessing.properties.NumAtomBonds", + "chebai_graph.preprocessing.properties.AtomCharge", + "chebai_graph.preprocessing.properties.AtomAromaticity", + "chebai_graph.preprocessing.properties.AtomHybridization", + "chebai_graph.preprocessing.properties.AtomNumHs", + "chebai_graph.preprocessing.properties.BondType", + "chebai_graph.preprocessing.properties.BondInRing", + "chebai_graph.preprocessing.properties.BondAromaticity", + "chebai_graph.preprocessing.properties.RDKit2DNormalized", + ], + # "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" + }, + "electra_14ko0zcf": { + "type": "electra", + "ckpt_path": "data/14ko0zcf/epoch=193.ckpt", + "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", + # "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json", + }, + } + ) + r = ensemble.predict_smiles_list( + ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], + load_preds_if_possible=False, + ) print(len(r), r[0]) From 2bead4a561101942ab51391250be7b7cdd3e5942 Mon Sep 17 00:00:00 2001 From: 
sfluegel Date: Fri, 11 Jul 2025 16:05:27 +0200 Subject: [PATCH 21/38] use None values to mark samples where all methods failed (usually due to a faulty SMILES string) --- chebifier/ensemble/base_ensemble.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 0795fc0..b3a0fd1 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,10 +1,11 @@ import os +import time + import torch import tqdm from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother -from api.hugging_face import download_model_files from chebifier.prediction_models.base_predictor import BasePredictor @@ -19,6 +20,7 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): for model_name, model_config in model_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] if "hugging_face" in model_config: + from api.hugging_face import download_model_files hugging_face_kwargs = download_model_files(model_config["hugging_face"]) else: hugging_face_kwargs = {} @@ -118,9 +120,10 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) class_decisions = ( net_score > 0 - ) & has_valid_predictions # Shape: (num_smiles, num_classes) + ) & has_valid_predictions # Shape: (num_smiles, num_classes) - return class_decisions + complete_failure = torch.all(~has_valid_predictions, dim=1) + return class_decisions, complete_failure def calculate_classwise_weights(self, predicted_classes): """No weights, simple majority voting""" @@ -155,24 +158,27 @@ def predict_smiles_list( } classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions = self.consolidate_predictions( + class_decisions, is_failure = self.consolidate_predictions( ordered_predictions, classwise_weights, **kwargs ) # Smooth predictions + start_time = time.perf_counter() class_names = list(predicted_classes.keys()) - # initialise new smoother class since we don't know the labels beforehand (this could be more efficient) + # initialise new smoother class since we don't know the labels beforehand (#todo this could be more efficient) new_smoother = PredictionSmoother( self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files, ) class_decisions = new_smoother(class_decisions) + end_time = time.perf_counter() + print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") class_names = list(predicted_classes.keys()) class_indices = {predicted_classes[cls]: cls for cls in class_names} result = [ - [class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]] - for i in class_decisions + [class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]] if not failure else None + for i, failure in zip(class_decisions, is_failure) ] return result @@ -208,7 +214,7 @@ def predict_smiles_list( } ) r = ensemble.predict_smiles_list( - ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], + ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O", "C[C@H](N)C(=O)NCC(O)=O#", ""], load_preds_if_possible=False, ) print(len(r), r[0]) From df68ecb9f9e75faea94940850ab17a2030a3edb7 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 18:01:02 +0200 Subject: [PATCH 22/38] init smoother at init to avoid re-initialising it for every prediction-call --- 
chebifier/ensemble/base_ensemble.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index b3a0fd1..207a96a 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -37,6 +37,12 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): os.path.join("data", "disjoint_additional.csv"), ] + self.smoother = PredictionSmoother( + self.chebi_dataset, + label_names=None, + disjoint_files=self.disjoint_files, + ) + def gather_predictions(self, smiles_list): # get predictions from all models for the SMILES list # order them by alphabetically by label class @@ -164,13 +170,8 @@ def predict_smiles_list( # Smooth predictions start_time = time.perf_counter() class_names = list(predicted_classes.keys()) - # initialise new smoother class since we don't know the labels beforehand (#todo this could be more efficient) - new_smoother = PredictionSmoother( - self.chebi_dataset, - label_names=class_names, - disjoint_files=self.disjoint_files, - ) - class_decisions = new_smoother(class_decisions) + self.smoother.set_label_names(class_names) + class_decisions = self.smoother(class_decisions) end_time = time.perf_counter() print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") From c575637c1ac0bf8f48cdf2326e69bf4a6ecefcae Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 14 Jul 2025 17:53:47 +0200 Subject: [PATCH 23/38] add lookup classifier and chemlog-by-element classifier --- chebifier/model_registry.py | 7 +- chebifier/prediction_models/__init__.py | 7 +- chebifier/prediction_models/chebi_lookup.py | 115 ++++++++++++++++++ .../prediction_models/chemlog_predictor.py | 25 +++- 4 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 chebifier/prediction_models/chebi_lookup.py diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index cf7d6d0..019e3f9 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -4,9 +4,10 @@ WMVwithPPVNPVEnsemble, ) from chebifier.prediction_models import ( - ChemLogPredictor, + ChemlogPeptidesPredictor, ElectraPredictor, ResGatedPredictor, + ChEBILookupPredictor, ChemlogByElementPredictor ) ENSEMBLES = { @@ -19,7 +20,9 @@ MODEL_TYPES = { "electra": ElectraPredictor, "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor, + "chemlog_peptides": ChemlogPeptidesPredictor, + "chebi_lookup": ChEBILookupPredictor, + "chemlog_element": ChemlogByElementPredictor } diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index ed08890..ec0726a 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -1,6 +1,7 @@ from .base_predictor import BasePredictor -from .chemlog_predictor import ChemLogPredictor +from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogByElementPredictor from .electra_predictor import ElectraPredictor from .gnn_predictor import ResGatedPredictor - -__all__ = ["BasePredictor", "ChemLogPredictor", "ElectraPredictor", "ResGatedPredictor"] +from .chebi_lookup import ChEBILookupPredictor +__all__ = ["BasePredictor", "ChemlogPeptidesPredictor", "ElectraPredictor", "ResGatedPredictor", "ChEBILookupPredictor", + "ChemlogByElementPredictor"] diff --git a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py new file mode 100644 index 0000000..070e3b1 --- /dev/null +++ b/chebifier/prediction_models/chebi_lookup.py @@ -0,0 
+1,115 @@ +from chebifier.prediction_models import BasePredictor +import os +import networkx as nx +from rdkit import Chem +import json + +class ChEBILookupPredictor(BasePredictor): + + def __init__(self, model_name: str, description: str = None, chebi_version: int = 241, **kwargs): + super().__init__(model_name, **kwargs) + self._description = description or "ChEBI Lookup: If the SMILES is equivalent to a ChEBI entry, retrieve the classification of that entry." + self.chebi_version = chebi_version + self.lookup_table = self.get_smiles_lookup() + + def get_smiles_lookup(self): + path = os.path.join("data", f"chebi_v{self.chebi_version}", "smiles_lookup.json") + if not os.path.exists(path): + smiles_lookup = self.build_smiles_lookup() + with open(path, "w", encoding="utf-8") as f: + json.dump(smiles_lookup, f, indent=4) + else: + print("Loading existing SMILES lookup...") + with open(path, "r", encoding="utf-8") as f: + smiles_lookup = json.load(f) + return smiles_lookup + + + def build_smiles_lookup(self): + # todo test + from chebai.preprocessing.datasets.chebi import ChEBIOver50 + self.chebi_dataset = ChEBIOver50(chebi_version=self.chebi_version) + self.chebi_dataset._download_required_data() + chebi_graph = self.chebi_dataset._extract_class_hierarchy( + os.path.join(self.chebi_dataset.raw_dir, "chebi.obo") + ) + smiles_lookup = dict() + for chebi_id, smiles in nx.get_node_attributes(chebi_graph, "smiles").items(): + if smiles is not None: + try: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}") + continue + canonical_smiles = Chem.MolToSmiles(mol) + if canonical_smiles not in smiles_lookup: + smiles_lookup[canonical_smiles] = [] + # if the canonical SMILES is already in the lookup, append "different interpretation of the SMILES" + smiles_lookup[canonical_smiles].append((chebi_id, list(chebi_graph.predecessors(chebi_id)))) + except Exception as e: + print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}") + return smiles_lookup + + + def predict_smiles_list(self, smiles_list: list[str]) -> list: + predictions = [] + for smiles in smiles_list: + if not smiles: + predictions.append(None) + continue + mol = Chem.MolFromSmiles(smiles) + if mol is None: + predictions.append(None) + continue + canonical_smiles = Chem.MolToSmiles(mol) + if canonical_smiles in self.lookup_table: + parent_candidates = self.lookup_table[canonical_smiles] + preds_i = dict() + if len(parent_candidates) > 1: + print(f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}") + for k in list(set(pp for _, p in parent_candidates for pp in p)): + preds_i[str(k)] = 1 + elif len(parent_candidates) == 1: + chebi_id, parents = parent_candidates[0] + for k in parents: + preds_i[str(k)] = 1 + else: + preds_i = None + predictions.append(preds_i) + + return predictions + + @property + def info_text(self): + if self._description is None: + return "No description is available for this model." 
+ return self._description + + def explain_smiles(self, smiles: str) -> dict: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return {"highlights": [ + ("text", "The input SMILES could not be parsed into a valid molecule.") + ]} + canonical_smiles = Chem.MolToSmiles(mol) + if canonical_smiles not in self.lookup_table: + return {"highlights": [ + ("text", "The input SMILES does not match any ChEBI entry.") + ]} + parent_candidates = self.lookup_table[canonical_smiles] + return {"highlights": [ + ("text", + f"The ChEBI Lookup matches the canonical version of the input SMILES against ChEBI (v{self.chebi_version})." + f" It found {'1 match' if len(parent_candidates) == 1 else f'{len(parent_candidates)} matches'}:" + f" {', '.join(f'CHEBI:{chebi_id}' for chebi_id, _ in parent_candidates)}. The predicted classes are the" + f" parent classes of the matched ChEBI entries.") + ]} + + +if __name__ == "__main__": + predictor = ChEBILookupPredictor("ChEBI Lookup") + print(predictor.info_text) + # Example usage + smiles_list = ["CCO", "C1=CC=CC=C1" '*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O'] # SMILES with 251 matches in ChEBI + predictions = predictor.predict_smiles_list(smiles_list) + print(predictions) \ No newline at end of file diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 4bcb9b8..63f0467 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -9,6 +9,7 @@ is_emericellamide, ) from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call +from chemlog_extra.alg_classification.by_element_classification import XMolecularEntityClassifier, OrganoXCompoundClassifier from .base_predictor import BasePredictor @@ -38,8 +39,23 @@ "Y": "L-tyrosine", } +class ChemlogByElementPredictor(BasePredictor): -class ChemLogPredictor(BasePredictor): + def __init__(self, model_name: str, **kwargs): + super().__init__(model_name, **kwargs) + self.x_molecular = XMolecularEntityClassifier() + self.organo_x = OrganoXCompoundClassifier() + + def predict_smiles_list(self, smiles_list: list[str]) -> list: + mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] + return [ + {str(cls): 1 for cls in self.x_molecular.classify(mol)[0] + self.organo_x.classify(mol)[0]} + if mol + else None + for mol in mol_list + ] + +class ChemlogPeptidesPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) self.strategy = "algo" @@ -333,7 +349,12 @@ def build_explain_blocks_proteinogenics(self, proteinogenics, atoms): def explain_smiles(self, smiles) -> dict: info = self.get_chemlog_result_info(smiles) - highlight_blocks = self.build_explain_blocks_peptides(info) + zero_blocks = [ + ("text", "Results for peptides and peptide-related classes (e.g. peptide anion, depsipeptide) have been calculated" + "with a rule-based system. 
The following shows which parts of the molecule were identified as relevant" + "structures and have influenced the classification.") + ] + highlight_blocks = zero_blocks + self.build_explain_blocks_peptides(info) for chebi_id, internal_name in [ (64372, "emericellamide"), From 32da68ba90d343c7ff419663668e0ec05200acd2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 15 Jul 2025 13:52:06 +0200 Subject: [PATCH 24/38] split chemlog extra predictors into two for generalisability --- chebifier/cli.py | 2 +- chebifier/model_registry.py | 6 +++-- chebifier/prediction_models/__init__.py | 4 ++-- .../prediction_models/chemlog_predictor.py | 22 +++++++++++-------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 11c138b..b3e7230 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -3,7 +3,7 @@ import click import yaml -from .model_registry import ENSEMBLES +from chebifier.model_registry import ENSEMBLES @click.group() diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 019e3f9..46dcf33 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -7,8 +7,9 @@ ChemlogPeptidesPredictor, ElectraPredictor, ResGatedPredictor, - ChEBILookupPredictor, ChemlogByElementPredictor + ChEBILookupPredictor ) +from chebifier.prediction_models.chemlog_predictor import ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor ENSEMBLES = { "mv": BaseEnsemble, @@ -22,7 +23,8 @@ "resgated": ResGatedPredictor, "chemlog_peptides": ChemlogPeptidesPredictor, "chebi_lookup": ChEBILookupPredictor, - "chemlog_element": ChemlogByElementPredictor + "chemlog_element": ChemlogXMolecularEntityPredictor, + "chemlog_organox": ChemlogOrganoXCompoundPredictor, } diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index ec0726a..ce33cca 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -1,7 +1,7 @@ from .base_predictor import BasePredictor -from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogByElementPredictor +from .chemlog_predictor import ChemlogPeptidesPredictor, ChemlogExtraPredictor from .electra_predictor import ElectraPredictor from .gnn_predictor import ResGatedPredictor from .chebi_lookup import ChEBILookupPredictor __all__ = ["BasePredictor", "ChemlogPeptidesPredictor", "ElectraPredictor", "ResGatedPredictor", "ChEBILookupPredictor", - "ChemlogByElementPredictor"] + "ChemlogExtraPredictor"] diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 63f0467..6af0cf5 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -39,21 +39,25 @@ "Y": "L-tyrosine", } -class ChemlogByElementPredictor(BasePredictor): +class ChemlogExtraPredictor(BasePredictor): + + CHEMLOG_CLASSIFIER = None def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) - self.x_molecular = XMolecularEntityClassifier() - self.organo_x = OrganoXCompoundClassifier() + self.classifier = self.CHEMLOG_CLASSIFIER() def predict_smiles_list(self, smiles_list: list[str]) -> list: mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] - return [ - {str(cls): 1 for cls in self.x_molecular.classify(mol)[0] + self.organo_x.classify(mol)[0]} - if mol - else None - for mol in mol_list - ] + return self.classifier.classify(mol_list) + +class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor): + + 
CHEMLOG_CLASSIFIER = XMolecularEntityClassifier + +class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor): + + CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier class ChemlogPeptidesPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): From ecb48ffbf12b4101a78636dd86079cacea73a3ad Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 15 Jul 2025 18:10:18 +0200 Subject: [PATCH 25/38] fix typos --- chebifier/prediction_models/chemlog_predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 6af0cf5..b40410f 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -355,8 +355,8 @@ def explain_smiles(self, smiles) -> dict: info = self.get_chemlog_result_info(smiles) zero_blocks = [ ("text", "Results for peptides and peptide-related classes (e.g. peptide anion, depsipeptide) have been calculated" - "with a rule-based system. The following shows which parts of the molecule were identified as relevant" - "structures and have influenced the classification.") + " with a rule-based system. The following shows which parts of the molecule were identified as relevant" + " structures and have influenced the classification.") ] highlight_blocks = zero_blocks + self.build_explain_blocks_peptides(info) From 89b4812c109ff1a08565bfcf2e2376e23692665f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 17 Jul 2025 14:30:41 +0200 Subject: [PATCH 26/38] add c3p integration --- chebifier/model_registry.py | 2 ++ chebifier/prediction_models/c3p_predictor.py | 27 ++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 chebifier/prediction_models/c3p_predictor.py diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 46dcf33..5020f17 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -9,6 +9,7 @@ ResGatedPredictor, ChEBILookupPredictor ) +from chebifier.prediction_models.c3p_predictor import C3PPredictor from chebifier.prediction_models.chemlog_predictor import ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor ENSEMBLES = { @@ -25,6 +26,7 @@ "chebi_lookup": ChEBILookupPredictor, "chemlog_element": ChemlogXMolecularEntityPredictor, "chemlog_organox": ChemlogOrganoXCompoundPredictor, + "c3p": C3PPredictor } diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py new file mode 100644 index 0000000..da3c7f0 --- /dev/null +++ b/chebifier/prediction_models/c3p_predictor.py @@ -0,0 +1,27 @@ +from typing import Optional, List +from pathlib import Path + +from c3p import classifier as c3p_classifier + +from chebifier.prediction_models import BasePredictor + + +class C3PPredictor(BasePredictor): + """ + Wrapper for C3P (url). 
+ """ + + def __init__(self, model_name: str, program_directory: Optional[Path]=None, chemical_classes: Optional[List[str]]=None, **kwargs): + super().__init__(model_name, **kwargs) + self.program_directory = program_directory + self.chemical_classes = chemical_classes + + def predict_smiles_list(self, smiles_list: list[str]) -> list: + result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False) + result_reformatted = [dict() for _ in range(len(smiles_list))] + for result in result_list: + result_reformatted[smiles_list.index(result.input_smiles)][result.class_id.split(":")[1]] = result.is_match + print(f"C3P predictions for {len(smiles_list)} SMILES strings:") + for i, smiles in enumerate(smiles_list): + print(f"{smiles}: {result_reformatted[i]}") + return result_reformatted \ No newline at end of file From e8e4ec35a1c9b3bd56be6fa4c12fc29c1f4cb060 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 18 Jul 2025 19:25:01 +0200 Subject: [PATCH 27/38] use class scores for smoothing, explicitly predict transitive closure for all models --- chebifier/ensemble/base_ensemble.py | 47 +++++++++++-------- chebifier/prediction_models/c3p_predictor.py | 12 +++-- chebifier/prediction_models/chebi_lookup.py | 19 ++++---- .../prediction_models/chemlog_predictor.py | 28 +++++++++-- 4 files changed, 67 insertions(+), 39 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 207a96a..e320e06 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,7 +4,7 @@ import torch import tqdm from chebai.preprocessing.datasets.chebi import ChEBIOver50 -from chebai.result.analyse_sem import PredictionSmoother +from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph from chebifier.prediction_models.base_predictor import BasePredictor @@ -15,6 +15,14 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES + self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) + self.chebi_dataset._download_required_data() # download chebi if not already downloaded + self.chebi_graph = get_chebi_graph(self.chebi_dataset, None) + self.disjoint_files = [ + os.path.join("data", "disjoint_chebi.csv"), + os.path.join("data", "disjoint_additional.csv"), + ] + self.models = [] self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): @@ -25,17 +33,12 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): else: hugging_face_kwargs = {} model_instance = model_cls( - model_name, **model_config, **hugging_face_kwargs + model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph ) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) - self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) - self.chebi_dataset._download_required_data() # download chebi if not already downloaded - self.disjoint_files = [ - os.path.join("data", "disjoint_chebi.csv"), - os.path.join("data", "disjoint_additional.csv"), - ] + self.smoother = PredictionSmoother( self.chebi_dataset, @@ -54,7 +57,7 @@ def gather_predictions(self, smiles_list): if logits_for_smiles is not None: for cls in logits_for_smiles: predicted_classes.add(cls) - print("Sorting predictions...") + print(f"Sorting predictions from {len(model_predictions)} models...") predicted_classes = sorted(list(predicted_classes)) 
predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)} ordered_logits = ( @@ -75,7 +78,7 @@ def gather_predictions(self, smiles_list): return ordered_logits, predicted_classes - def consolidate_predictions(self, predictions, classwise_weights, **kwargs): + def consolidate_predictions(self, predictions, classwise_weights, predicted_classes, **kwargs): """ Aggregates predictions from multiple models using weighted majority voting. Optimized version using tensor operations instead of for loops. @@ -124,8 +127,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) + + # Smooth predictions + start_time = time.perf_counter() + class_names = list(predicted_classes.keys()) + self.smoother.set_label_names(class_names) + smooth_net_score = self.smoother(net_score) + end_time = time.perf_counter() + print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") + class_decisions = ( - net_score > 0 + smooth_net_score > 0.5 ) & has_valid_predictions # Shape: (num_smiles, num_classes) complete_failure = torch.all(~has_valid_predictions, dim=1) @@ -139,7 +151,7 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights def predict_smiles_list( - self, smiles_list, load_preds_if_possible=True, **kwargs + self, smiles_list, load_preds_if_possible=False, **kwargs ) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" @@ -147,6 +159,8 @@ def predict_smiles_list( ordered_predictions, predicted_classes = self.gather_predictions( smiles_list ) + if len(predicted_classes) == 0: + print(f"Warning: No classes have been predicted for the given SMILES list.") # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: @@ -165,15 +179,8 @@ def predict_smiles_list( classwise_weights = self.calculate_classwise_weights(predicted_classes) class_decisions, is_failure = self.consolidate_predictions( - ordered_predictions, classwise_weights, **kwargs + ordered_predictions, classwise_weights, predicted_classes, **kwargs ) - # Smooth predictions - start_time = time.perf_counter() - class_names = list(predicted_classes.keys()) - self.smoother.set_label_names(class_names) - class_decisions = self.smoother(class_decisions) - end_time = time.perf_counter() - print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") class_names = list(predicted_classes.keys()) class_indices = {predicted_classes[cls]: cls for cls in class_names} diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index da3c7f0..bc2f23e 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -15,13 +15,15 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem super().__init__(model_name, **kwargs) self.program_directory = program_directory self.chemical_classes = chemical_classes + self.chebi_graph = kwargs.get("chebi_graph", None) def predict_smiles_list(self, smiles_list: list[str]) -> list: - result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False) + result_list = c3p_classifier.classify(smiles_list, self.program_directory, 
self.chemical_classes, strict=True) result_reformatted = [dict() for _ in range(len(smiles_list))] for result in result_list: - result_reformatted[smiles_list.index(result.input_smiles)][result.class_id.split(":")[1]] = result.is_match - print(f"C3P predictions for {len(smiles_list)} SMILES strings:") - for i, smiles in enumerate(smiles_list): - print(f"{smiles}: {result_reformatted[i]}") + chebi_id = result.class_id.split(":")[1] + result_reformatted[smiles_list.index(result.input_smiles)][chebi_id] = result.is_match + if result.is_match and self.chebi_graph is not None: + for parent in list(self.chebi_graph.predecessors(int(chebi_id))): + result_reformatted[smiles_list.index(result.input_smiles)][str(parent)] = 1 return result_reformatted \ No newline at end of file diff --git a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index 070e3b1..5a45fab 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -10,6 +10,14 @@ def __init__(self, model_name: str, description: str = None, chebi_version: int super().__init__(model_name, **kwargs) self._description = description or "ChEBI Lookup: If the SMILES is equivalent to a ChEBI entry, retrieve the classification of that entry." self.chebi_version = chebi_version + self.chebi_graph = kwargs.get("chebi_graph", None) + if self.chebi_graph is None: + from chebai.preprocessing.datasets.chebi import ChEBIOver50 + self.chebi_dataset = ChEBIOver50(chebi_version=self.chebi_version) + self.chebi_dataset._download_required_data() + self.chebi_graph = self.chebi_dataset._extract_class_hierarchy( + os.path.join(self.chebi_dataset.raw_dir, "chebi.obo") + ) self.lookup_table = self.get_smiles_lookup() def get_smiles_lookup(self): @@ -26,15 +34,8 @@ def get_smiles_lookup(self): def build_smiles_lookup(self): - # todo test - from chebai.preprocessing.datasets.chebi import ChEBIOver50 - self.chebi_dataset = ChEBIOver50(chebi_version=self.chebi_version) - self.chebi_dataset._download_required_data() - chebi_graph = self.chebi_dataset._extract_class_hierarchy( - os.path.join(self.chebi_dataset.raw_dir, "chebi.obo") - ) smiles_lookup = dict() - for chebi_id, smiles in nx.get_node_attributes(chebi_graph, "smiles").items(): + for chebi_id, smiles in nx.get_node_attributes(self.chebi_graph, "smiles").items(): if smiles is not None: try: mol = Chem.MolFromSmiles(smiles) @@ -45,7 +46,7 @@ def build_smiles_lookup(self): if canonical_smiles not in smiles_lookup: smiles_lookup[canonical_smiles] = [] # if the canonical SMILES is already in the lookup, append "different interpretation of the SMILES" - smiles_lookup[canonical_smiles].append((chebi_id, list(chebi_graph.predecessors(chebi_id)))) + smiles_lookup[canonical_smiles].append((chebi_id, list(self.chebi_graph.predecessors(chebi_id)))) except Exception as e: print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}") return smiles_lookup diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index b40410f..2c1ceda 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -45,11 +45,24 @@ class ChemlogExtraPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) + self.chebi_graph = kwargs.get("chebi_graph", None) self.classifier = self.CHEMLOG_CLASSIFIER() def predict_smiles_list(self, smiles_list: list[str]) -> list: mol_list = 
[_smiles_to_mol(smiles) for smiles in smiles_list] - return self.classifier.classify(mol_list) + res = self.classifier.classify(mol_list) + if self.chebi_graph is not None: + for sample in res: + sample_additions = dict() + for cls in sample: + if sample[cls] == 1: + successors = list(self.chebi_graph.predecessors(int(cls))) + if successors: + for succ in successors: + sample_additions[str(succ)] = 1 + sample.update(sample_additions) + return res + class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor): @@ -63,6 +76,7 @@ class ChemlogPeptidesPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) self.strategy = "algo" + self.chebi_graph = kwargs.get("chebi_graph", None) self.classifier_instances = { k: v() for k, v in CLASSIFIERS[self.strategy].items() } @@ -81,17 +95,21 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: if mol is None: results.append(None) else: + pos_labels = [label for label in self.peptide_labels if label in strategy_call( + self.strategy, self.classifier_instances, mol + )["chebi_classes"]] + if self.chebi_graph: + indirect_pos_labels = [str(pr) for label in pos_labels for pr in self.chebi_graph.predecessors(int(label))] + pos_labels = list(set(pos_labels + indirect_pos_labels)) results.append( { label: ( 1 if label - in strategy_call( - self.strategy, self.classifier_instances, mol - )["chebi_classes"] + in pos_labels else 0 ) - for label in self.peptide_labels + for label in self.peptide_labels + pos_labels } ) From a11b0b7949b6d8a062e9256845ca8ef6bd660e09 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 18 Jul 2025 19:25:29 +0200 Subject: [PATCH 28/38] restructure error handling and update cache indexing for gnn --- chebifier/prediction_models/gnn_predictor.py | 2 +- chebifier/prediction_models/nn_predictor.py | 25 ++++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index edddba7..64f21ae 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -60,7 +60,7 @@ def read_smiles(self, smiles): if isinstance(prop.encoder, IndexEncoder): if str(value) in prop.encoder.cache: index = ( - prop.encoder.cache.index(str(value)) + prop.encoder.offset + prop.encoder.cache[str(value)] + prop.encoder.offset ) else: index = 0 diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 3b603b5..d572196 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -57,25 +57,26 @@ def predict_smiles_list(self, smiles_list) -> list: could_not_parse = [] index_map = dict() for i, smiles in enumerate(smiles_list): + if not smiles: + print(f"Model {self.model_name} received a missing SMILES string at position {i}.") + could_not_parse.append(i) + continue try: - # Try to parse the smiles string - if not smiles: - raise ValueError() d = self.read_smiles(smiles) + # This is just for sanity checks rdmol = Chem.MolFromSmiles(smiles, sanitize=False) - except Exception as e: - # Note if it fails - could_not_parse.append(i) - print(f"Failing to parse {smiles} due to {e}") - else: if rdmol is None: + print(f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}") could_not_parse.append(i) - else: - index_map[i] = len(token_dicts) - token_dicts.append(d) + continue + except Exception as e: + 
could_not_parse.append(i) + print(f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}") + index_map[i] = len(token_dicts) + token_dicts.append(d) results = [] - if token_dicts: + if len(token_dicts) > 0: for batch in tqdm.tqdm( self.batchify(token_dicts), desc=f"{self.model_name}", From 598ed6b44d9edf8deedb76a11f7f0c31731521c2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 18 Jul 2025 19:51:55 +0200 Subject: [PATCH 29/38] fix weight calculation if a model does not make predictions for all classes lists in its trust-weights file --- chebifier/ensemble/weighted_majority_ensemble.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/chebifier/ensemble/weighted_majority_ensemble.py b/chebifier/ensemble/weighted_majority_ensemble.py index ed40626..97338f5 100644 --- a/chebifier/ensemble/weighted_majority_ensemble.py +++ b/chebifier/ensemble/weighted_majority_ensemble.py @@ -45,13 +45,14 @@ def calculate_classwise_weights(self, predicted_classes): if model.classwise_weights is None: continue for cls, weights in model.classwise_weights.items(): - if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0: - f1 = ( - 2 - * weights["TP"] - / (2 * weights["TP"] + weights["FP"] + weights["FN"]) - ) - weights_by_cls[predicted_classes[cls], j] *= 1 + f1 + if cls in predicted_classes: + if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0: + f1 = ( + 2 + * weights["TP"] + / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + ) + weights_by_cls[predicted_classes[cls], j] *= 1 + f1 print("Calculated model weightings. The average weights are:") for i, model in enumerate(self.models): From 4d918d569c3026bbbe9f7fcb6e2ce896029ea4c1 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 18 Jul 2025 19:52:25 +0200 Subject: [PATCH 30/38] Add C3P explanations --- chebifier/prediction_models/c3p_predictor.py | 22 ++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index bc2f23e..fd233e9 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -18,7 +18,7 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem self.chebi_graph = kwargs.get("chebi_graph", None) def predict_smiles_list(self, smiles_list: list[str]) -> list: - result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=True) + result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False) result_reformatted = [dict() for _ in range(len(smiles_list))] for result in result_list: chebi_id = result.class_id.split(":")[1] @@ -26,4 +26,22 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: if result.is_match and self.chebi_graph is not None: for parent in list(self.chebi_graph.predecessors(int(chebi_id))): result_reformatted[smiles_list.index(result.input_smiles)][str(parent)] = 1 - return result_reformatted \ No newline at end of file + return result_reformatted + + def explain_smiles(self, smiles): + """ + C3P provides natural language explanations for each prediction (positive or negative). Since there are more + than 300 classes, only take the positive ones. 
+ """ + highlights = [] + result_list = c3p_classifier.classify([smiles], self.program_directory, self.chemical_classes, strict=False) + for result in result_list: + if result.is_match: + highlights.append( + ("text", f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}") + ) + highlights = [ + ("text", f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:") + ] + highlights + + return {"highlights": highlights} \ No newline at end of file From e4f1c54b265bedef2382704571bf9e56eac8d92d Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 16:43:22 +0200 Subject: [PATCH 31/38] add cache --- chebifier/ensemble/base_ensemble.py | 2 +- chebifier/prediction_models/base_predictor.py | 13 ++++- chebifier/prediction_models/c3p_predictor.py | 6 ++- chebifier/prediction_models/chebi_lookup.py | 53 +++++++++++-------- .../prediction_models/chemlog_predictor.py | 51 ++++++++++-------- chebifier/prediction_models/nn_predictor.py | 5 +- 6 files changed, 79 insertions(+), 51 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index e320e06..65b2f43 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -7,7 +7,7 @@ from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph from chebifier.prediction_models.base_predictor import BasePredictor - +from functools import lru_cache class BaseEnsemble: diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index 287f097..ba1412d 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,6 +1,8 @@ import json from abc import ABC +from functools import lru_cache + class BasePredictor(ABC): def __init__( @@ -22,7 +24,16 @@ def __init__( self._description = kwargs.get("description", None) def predict_smiles_list(self, smiles_list: list[str]) -> dict: - raise NotImplementedError + # list is not hashable, so we convert it to a tuple (useful for caching) + return self.predict_smiles_tuple(tuple(smiles_list)) + + @lru_cache(maxsize=100) + def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict: + raise NotImplementedError() + + def predict_smiles(self, smiles: str) -> dict: + # by default, use list-based prediction + return self.predict_smiles_tuple((smiles,))[0] @property def info_text(self): diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index fd233e9..ccf4eed 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Optional, List from pathlib import Path @@ -17,8 +18,9 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem self.chemical_classes = chemical_classes self.chebi_graph = kwargs.get("chebi_graph", None) - def predict_smiles_list(self, smiles_list: list[str]) -> list: - result_list = c3p_classifier.classify(smiles_list, self.program_directory, self.chemical_classes, strict=False) + @lru_cache(maxsize=100) + def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: + result_list = c3p_classifier.classify(list(smiles_list), self.program_directory, self.chemical_classes, strict=False) result_reformatted = [dict() for _ in range(len(smiles_list))] for result in result_list: chebi_id = result.class_id.split(":")[1] diff --git 
a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index 5a45fab..8b99fad 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -1,3 +1,6 @@ +from functools import lru_cache +from typing import Optional + from chebifier.prediction_models import BasePredictor import os import networkx as nx @@ -51,32 +54,36 @@ def build_smiles_lookup(self): print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}") return smiles_lookup + @lru_cache(maxsize=100) + def predict_smiles(self, smiles: str) -> Optional[dict]: + if not smiles: + return None + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + canonical_smiles = Chem.MolToSmiles(mol) + if canonical_smiles in self.lookup_table: + parent_candidates = self.lookup_table[canonical_smiles] + preds_i = dict() + if len(parent_candidates) > 1: + print( + f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}") + for k in list(set(pp for _, p in parent_candidates for pp in p)): + preds_i[str(k)] = 1 + elif len(parent_candidates) == 1: + chebi_id, parents = parent_candidates[0] + for k in parents: + preds_i[str(k)] = 1 + else: + preds_i = None + return preds_i + else: + return None - def predict_smiles_list(self, smiles_list: list[str]) -> list: + def predict_smiles_tuple(self, smiles_list: list[str]) -> list: predictions = [] for smiles in smiles_list: - if not smiles: - predictions.append(None) - continue - mol = Chem.MolFromSmiles(smiles) - if mol is None: - predictions.append(None) - continue - canonical_smiles = Chem.MolToSmiles(mol) - if canonical_smiles in self.lookup_table: - parent_candidates = self.lookup_table[canonical_smiles] - preds_i = dict() - if len(parent_candidates) > 1: - print(f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}") - for k in list(set(pp for _, p in parent_candidates for pp in p)): - preds_i[str(k)] = 1 - elif len(parent_candidates) == 1: - chebi_id, parents = parent_candidates[0] - for k in parents: - preds_i[str(k)] = 1 - else: - preds_i = None - predictions.append(preds_i) + predictions.append(self.predict_smiles(smiles)) return predictions diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 2c1ceda..a1343a4 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -1,3 +1,5 @@ +from typing import Optional + import tqdm from chemlog.alg_classification.charge_classifier import get_charge_category from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues @@ -10,6 +12,7 @@ ) from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call from chemlog_extra.alg_classification.by_element_classification import XMolecularEntityClassifier, OrganoXCompoundClassifier +from functools import lru_cache from .base_predictor import BasePredictor @@ -48,7 +51,7 @@ def __init__(self, model_name: str, **kwargs): self.chebi_graph = kwargs.get("chebi_graph", None) self.classifier = self.CHEMLOG_CLASSIFIER() - def predict_smiles_list(self, smiles_list: list[str]) -> list: + def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list] res = self.classifier.classify(mol_list) if self.chebi_graph is not None: @@ -88,30 +91,32 @@ def __init__(self, model_name: str, 
**kwargs): # fmt: on print(f"Initialised ChemLog model {self.model_name}") - def predict_smiles_list(self, smiles_list: list[str]) -> list: + @lru_cache(maxsize=100) + def predict_smiles(self, smiles: str) -> Optional[dict]: + mol = _smiles_to_mol(smiles) + if mol is None: + return None + pos_labels = [label for label in self.peptide_labels if label in strategy_call( + self.strategy, self.classifier_instances, mol + )["chebi_classes"]] + if self.chebi_graph: + indirect_pos_labels = [str(pr) for label in pos_labels for pr in + self.chebi_graph.predecessors(int(label))] + pos_labels = list(set(pos_labels + indirect_pos_labels)) + return { + label: ( + 1 + if label + in pos_labels + else 0 + ) + for label in self.peptide_labels + pos_labels + } + + def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: results = [] for i, smiles in tqdm.tqdm(enumerate(smiles_list)): - mol = _smiles_to_mol(smiles) - if mol is None: - results.append(None) - else: - pos_labels = [label for label in self.peptide_labels if label in strategy_call( - self.strategy, self.classifier_instances, mol - )["chebi_classes"]] - if self.chebi_graph: - indirect_pos_labels = [str(pr) for label in pos_labels for pr in self.chebi_graph.predecessors(int(label))] - pos_labels = list(set(pos_labels + indirect_pos_labels)) - results.append( - { - label: ( - 1 - if label - in pos_labels - else 0 - ) - for label in self.peptide_labels + pos_labels - } - ) + results.append(self.predict_smiles(smiles)) for classifier in self.classifier_instances.values(): classifier.on_finish() diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index d572196..fdc251c 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -1,3 +1,5 @@ +from functools import lru_cache + import numpy as np import torch import tqdm @@ -50,7 +52,8 @@ def read_smiles(self, smiles): d = reader.to_data(dict(features=smiles, labels=None)) return d - def predict_smiles_list(self, smiles_list) -> list: + @lru_cache(maxsize=100) + def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: """Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary Of classes and predicted values.""" token_dicts = [] From 87fb66af070542b8100ee9fce000e030bbc44b99 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 19:53:31 +0200 Subject: [PATCH 32/38] move files from api to chebifier, add files to huggingface --- api/__init__.py | 0 api/__main__.py | 10 --- api/api_registry.yml | 24 ------ api/cli.py | 121 ---------------------------- {api => chebifier}/check_env.py | 0 chebifier/cli.py | 54 ++++++++++--- chebifier/ensemble.yml | 15 ++++ chebifier/ensemble/base_ensemble.py | 50 ++++++++---- {api => chebifier}/hugging_face.py | 11 ++- chebifier/model_registry.yml | 34 ++++++++ 10 files changed, 130 insertions(+), 189 deletions(-) delete mode 100644 api/__init__.py delete mode 100644 api/__main__.py delete mode 100644 api/api_registry.yml delete mode 100644 api/cli.py rename {api => chebifier}/check_env.py (100%) create mode 100644 chebifier/ensemble.yml rename {api => chebifier}/hugging_face.py (89%) create mode 100644 chebifier/model_registry.yml diff --git a/api/__init__.py b/api/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/api/__main__.py b/api/__main__.py deleted file mode 100644 index ec70a17..0000000 --- a/api/__main__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .cli import cli - -if __name__ == 
"__main__": - """ - Entry point for the CLI application. - - This script calls the `cli` function from the `api.cli` module - when executed as the main program. - """ - cli() diff --git a/api/api_registry.yml b/api/api_registry.yml deleted file mode 100644 index b6e30bd..0000000 --- a/api/api_registry.yml +++ /dev/null @@ -1,24 +0,0 @@ -electra: - hugging_face: - repo_id: aditya0by0/python-chebifier - subfolder: electra - files: - ckpt: electra.ckpt - labels: classes.txt - package_name: chebai - -resgated: - hugging_face: - repo_id: aditya0by0/python-chebifier - subfolder: resgated - files: - ckpt: resgated.ckpt - labels: classes.txt - package_name: chebai-graph - -chemlog: - package_name: chemlog - - -en_mv: - ensemble_of: {electra, chemlog} diff --git a/api/cli.py b/api/cli.py deleted file mode 100644 index e20ed14..0000000 --- a/api/cli.py +++ /dev/null @@ -1,121 +0,0 @@ -from pathlib import Path - -import click -import yaml - -from chebifier.model_registry import ENSEMBLES, MODEL_TYPES - -from .check_env import check_package_installed, get_current_environment -from .hugging_face import download_model_files - -yaml_path = Path("api/api_registry.yml") -if yaml_path.exists(): - with yaml_path.open("r") as f: - api_registry = yaml.safe_load(f) -else: - raise FileNotFoundError(f"{yaml_path} not found.") - - -@click.group() -def cli(): - """Command line interface for Api-Chebifier.""" - pass - - -@cli.command() -@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") -@click.option( - "--smiles-file", - "-f", - type=click.Path(exists=True), - help="File containing SMILES strings (one per line)", -) -@click.option( - "--output", - "-o", - type=click.Path(), - help="Output file to save predictions (optional)", -) -@click.option( - "--model-type", - "-m", - type=click.Choice(api_registry.keys()), - default="mv", - help="Type of model to use", -) -def predict(smiles, smiles_file, output, model_type): - """Predict ChEBI classes for SMILES strings using an ensemble model. - - CONFIG_FILE is the path to a YAML configuration file for the ensemble model. - """ - - # Collect SMILES strings from arguments and/or file - smiles_list = list(smiles) - if smiles_file: - with open(smiles_file, "r") as f: - smiles_list.extend([line.strip() for line in f if line.strip()]) - - if not smiles_list: - click.echo("No SMILES strings provided. 
Use --smiles or --smiles-file options.") - return - - print("Current working environment is:", get_current_environment()) - - def get_individual_model(model_config): - predictor_kwargs = {} - if "hugging_face" in model_config: - predictor_kwargs = download_model_files(model_config["hugging_face"]) - check_package_installed(model_config["package_name"]) - return predictor_kwargs - - if model_type in MODEL_TYPES: - print(f"Predictor for Single/Individual Model: {model_type}") - model_config = api_registry[model_type] - predictor_kwargs = get_individual_model(model_config) - predictor_kwargs["model_name"] = model_type - model_instance = MODEL_TYPES[model_type](**predictor_kwargs) - - elif model_type in ENSEMBLES: - print(f"Predictor for Ensemble Model: {model_type}") - ensemble_config = {} - for i, en_comp in enumerate(api_registry[model_type]["ensemble_of"]): - assert en_comp in MODEL_TYPES - print(f"For ensemble component {en_comp}") - predictor_kwargs = get_individual_model(api_registry[en_comp]) - model_key = f"model_{i + 1}" - ensemble_config[model_key] = { - "type": en_comp, - "model_name": f"{en_comp}_{model_key}", - **predictor_kwargs, - } - model_instance = ENSEMBLES[model_type](ensemble_config) - - else: - raise ValueError("") - - # Make predictions - predictions = model_instance.predict_smiles_list(smiles_list) - - if output: - # save as json - import json - - with open(output, "w") as f: - json.dump( - {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, - f, - indent=2, - ) - - else: - # Print results - for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)): - click.echo(f"Result for: {smiles}") - if prediction: - click.echo(f" Predicted classes: {', '.join(map(str, prediction))}") - else: - click.echo(" No predictions") - - -if __name__ == "__main__": - cli() diff --git a/api/check_env.py b/chebifier/check_env.py similarity index 100% rename from api/check_env.py rename to chebifier/check_env.py diff --git a/chebifier/cli.py b/chebifier/cli.py index b3e7230..2f99b47 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,3 +1,4 @@ +import importlib.resources import os import click @@ -14,9 +15,10 @@ def cli(): @cli.command() @click.option( - "--config_file", + "--ensemble-config", + "-e", type=click.Path(exists=True), - default=os.path.join("configs", "huggingface_config.yml"), + default=None, help="Configuration file for ensemble models", ) @click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") @@ -34,10 +36,10 @@ def cli(): ) @click.option( "--ensemble-type", - "-e", + "-t", type=click.Choice(ENSEMBLES.keys()), - default="mv", - help="Type of ensemble to use (default: Majority Voting)", + default="wmv-f1", + help="Type of ensemble to use (default: Weighted Majority Voting)", ) @click.option( "--chebi-version", @@ -53,25 +55,53 @@ def cli(): default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)", ) +@click.option( + "--resolve-inconsistencies", + "-r", + is_flag=True, + default=True, + help="Resolve inconsistencies in predictions automatically (default: True)", +) def predict( - config_file, + ensemble_config, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence, + resolve_inconsistencies=True, ): """Predict ChEBI classes for SMILES strings using an ensemble model. - - CONFIG_FILE is the path to a YAML configuration file for the ensemble model. 
- """ + """ # Load configuration from YAML file - with open(config_file, "r") as f: - config = yaml.safe_load(f) + if not ensemble_config: + print(f"Using default ensemble configuration") + with importlib.resources.files("chebifier").joinpath("ensemble.yml").open("r") as f: + config = yaml.safe_load(f) + else: + print(f"Loading ensemble configuration from {ensemble_config}") + with open(ensemble_config, "r") as f: + config = yaml.safe_load(f) + + with importlib.resources.files("chebifier").joinpath("model_registry.yml").open("r") as f: + model_registry = yaml.safe_load(f) + + new_config = {} + for model_name, entry in config.items(): + if "load_model" in entry: + if entry["load_model"] not in model_registry: + raise ValueError( + f"Model {entry['load_model']} not found in model registry. " + f"Available models are: {','.join(model_registry.keys())}." + ) + new_config[model_name] = {**model_registry[entry["load_model"]], **entry} + else: + new_config[model_name] = entry + config = new_config # Instantiate ensemble model - ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version) + ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies) # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) diff --git a/chebifier/ensemble.yml b/chebifier/ensemble.yml new file mode 100644 index 0000000..1744bad --- /dev/null +++ b/chebifier/ensemble.yml @@ -0,0 +1,15 @@ +electra: + load_model: electra_chebi50_v241 +resgated: + load_model: resgated_chebi50_v241 +chemlog_peptides: + type: chemlog_peptides + model_weight: 100 +chemlog_element: + type: chemlog_element + model_weight: 100 +chemlog_organox: + type: chemlog_organox + model_weight: 100 +c3p: + load_model: c3p_with_weights diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 65b2f43..8fdb332 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -6,32 +6,48 @@ from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph +from chebifier.check_env import check_package_installed from chebifier.prediction_models.base_predictor import BasePredictor -from functools import lru_cache + class BaseEnsemble: - def __init__(self, model_configs: dict, chebi_version: int = 241): + def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_inconsistencies: bool = True): # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) self.chebi_dataset._download_required_data() # download chebi if not already downloaded self.chebi_graph = get_chebi_graph(self.chebi_dataset, None) - self.disjoint_files = [ + local_disjoint_files = [ os.path.join("data", "disjoint_chebi.csv"), os.path.join("data", "disjoint_additional.csv"), ] + self.disjoint_files = [] + for file in local_disjoint_files: + if os.path.isfile(file): + self.disjoint_files.append(file) + else: + print(f"Disjoint axiom file {file} not found. 
Loading from huggingface instead...") + from chebifier.hugging_face import download_model_files + self.disjoint_files.append(download_model_files({ + "repo_id": "chebai/chebifier", + "repo_type": "dataset", + "files": {"disjoint_file": os.path.basename(file)}, + })["disjoint_file"]) self.models = [] self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] if "hugging_face" in model_config: - from api.hugging_face import download_model_files + from chebifier.hugging_face import download_model_files hugging_face_kwargs = download_model_files(model_config["hugging_face"]) else: hugging_face_kwargs = {} + if "package_name" in model_config: + check_package_installed(model_config["package_name"]) + model_instance = model_cls( model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph ) @@ -39,12 +55,14 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): self.models.append(model_instance) - - self.smoother = PredictionSmoother( - self.chebi_dataset, - label_names=None, - disjoint_files=self.disjoint_files, - ) + if resolve_inconsistencies: + self.smoother = PredictionSmoother( + self.chebi_dataset, + label_names=None, + disjoint_files=self.disjoint_files, + ) + else: + self.smoother = None def gather_predictions(self, smiles_list): # get predictions from all models for the SMILES list @@ -131,15 +149,15 @@ def consolidate_predictions(self, predictions, classwise_weights, predicted_clas # Smooth predictions start_time = time.perf_counter() class_names = list(predicted_classes.keys()) - self.smoother.set_label_names(class_names) - smooth_net_score = self.smoother(net_score) + if self.smoother is not None: + self.smoother.set_label_names(class_names) + smooth_net_score = self.smoother(net_score) + class_decisions = (smooth_net_score > 0.5) & has_valid_predictions # Shape: (num_smiles, num_classes) + else: + class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) end_time = time.perf_counter() print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") - class_decisions = ( - smooth_net_score > 0.5 - ) & has_valid_predictions # Shape: (num_smiles, num_classes) - complete_failure = torch.all(~has_valid_predictions, dim=1) return class_decisions, complete_failure diff --git a/api/hugging_face.py b/chebifier/hugging_face.py similarity index 89% rename from api/hugging_face.py rename to chebifier/hugging_face.py index 5569d86..9b9e599 100644 --- a/api/hugging_face.py +++ b/chebifier/hugging_face.py @@ -25,14 +25,15 @@ def download_model_files( model_config (Dict[str, str | Dict[str, str]]): A dictionary containing: - 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname'). - 'subfolder' (str): The subfolder within the repo where the files are located. - - 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to + - 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt_path', 'target_labels_path') to actual file names (e.g., 'electra.ckpt', 'classes.txt'). Returns: Dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file. 
""" repo_id = model_config["repo_id"] - subfolder = model_config["subfolder"] + subfolder = model_config.get("subfolder", None) + repo_type = model_config.get("repo_type", "model") filenames = model_config["files"] local_paths: dict[str, Path] = {} @@ -40,12 +41,10 @@ def download_model_files( downloaded_file_path = hf_hub_download( repo_id=repo_id, filename=filename, + repo_type=repo_type, subfolder=subfolder, ) local_paths[file_type] = Path(downloaded_file_path) print(f"\t Using file `{filename}` from: {downloaded_file_path}") - return { - "ckpt_path": local_paths["ckpt"], - "target_labels_path": local_paths["labels"], - } + return local_paths diff --git a/chebifier/model_registry.yml b/chebifier/model_registry.yml new file mode 100644 index 0000000..0cef3af --- /dev/null +++ b/chebifier/model_registry.yml @@ -0,0 +1,34 @@ +electra_chebi50_v241: + type: electra + hugging_face: + repo_id: chebai/electra_chebi50_v241 + files: + ckpt_path: 14ko0zcf_epoch=193.ckpt + target_labels_path: classes.txt + classwise_weights_path: metrics_electra_14ko0zcf_80-10-10_short.json +resgated_chebi50_v241: + type: resgated + hugging_face: + repo_id: chebai/resgated_gcn_chebi50_v241 + files: + ckpt_path: 0ps1g189_epoch=122.ckpt + target_labels_path: classes.txt + classwise_weights_path: metrics_0ps1g189_80-10-10_short.json + molecular_properties: + - chebai_graph.preprocessing.properties.AtomType + - chebai_graph.preprocessing.properties.NumAtomBonds + - chebai_graph.preprocessing.properties.AtomCharge + - chebai_graph.preprocessing.properties.AtomAromaticity + - chebai_graph.preprocessing.properties.AtomHybridization + - chebai_graph.preprocessing.properties.AtomNumHs + - chebai_graph.preprocessing.properties.BondType + - chebai_graph.preprocessing.properties.BondInRing + - chebai_graph.preprocessing.properties.BondAromaticity + - chebai_graph.preprocessing.properties.RDKit2DNormalized +c3p_with_weights: + type: c3p + hugging_face: + repo_id: chebai/chebifier + repo_type: dataset + files: + classwise_weights_path: c3p_trust.json \ No newline at end of file From 7ebbacbdd4e9e278670fc69dea29e2ae48681913 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 19:53:52 +0200 Subject: [PATCH 33/38] fix error handling for nns --- chebifier/prediction_models/nn_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index fdc251c..dd404ba 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -66,7 +66,6 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: continue try: d = self.read_smiles(smiles) - # This is just for sanity checks rdmol = Chem.MolFromSmiles(smiles, sanitize=False) if rdmol is None: @@ -76,6 +75,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: except Exception as e: could_not_parse.append(i) print(f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}") + continue index_map[i] = len(token_dicts) token_dicts.append(d) results = [] From c30e06d8aa3beb764f3601f65ffa6b6508f323cb Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 20:03:24 +0200 Subject: [PATCH 34/38] reformat with black, add development dependencies --- chebifier/cli.py | 21 +++-- chebifier/ensemble/base_ensemble.py | 64 ++++++++++---- chebifier/model_registry.py | 9 +- chebifier/prediction_models/__init__.py | 11 ++- chebifier/prediction_models/c3p_predictor.py | 41 +++++++-- 
chebifier/prediction_models/chebi_lookup.py | 83 +++++++++++++------ .../prediction_models/chemlog_predictor.py | 46 ++++++---- chebifier/prediction_models/gnn_predictor.py | 4 +- chebifier/prediction_models/nn_predictor.py | 12 ++- pyproject.toml | 3 + 10 files changed, 213 insertions(+), 81 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 2f99b47..924ef5d 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -72,19 +72,26 @@ def predict( use_confidence, resolve_inconsistencies=True, ): - """Predict ChEBI classes for SMILES strings using an ensemble model. - """ + """Predict ChEBI classes for SMILES strings using an ensemble model.""" # Load configuration from YAML file if not ensemble_config: print(f"Using default ensemble configuration") - with importlib.resources.files("chebifier").joinpath("ensemble.yml").open("r") as f: + with ( + importlib.resources.files("chebifier") + .joinpath("ensemble.yml") + .open("r") as f + ): config = yaml.safe_load(f) else: print(f"Loading ensemble configuration from {ensemble_config}") with open(ensemble_config, "r") as f: config = yaml.safe_load(f) - with importlib.resources.files("chebifier").joinpath("model_registry.yml").open("r") as f: + with ( + importlib.resources.files("chebifier") + .joinpath("model_registry.yml") + .open("r") as f + ): model_registry = yaml.safe_load(f) new_config = {} @@ -101,7 +108,11 @@ def predict( config = new_config # Instantiate ensemble model - ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies) + ensemble = ENSEMBLES[ensemble_type]( + config, + chebi_version=chebi_version, + resolve_inconsistencies=resolve_inconsistencies, + ) # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 8fdb332..39f643a 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -12,7 +12,12 @@ class BaseEnsemble: - def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_inconsistencies: bool = True): + def __init__( + self, + model_configs: dict, + chebi_version: int = 241, + resolve_inconsistencies: bool = True, + ): # Deferred Import: To avoid circular import error from chebifier.model_registry import MODEL_TYPES @@ -28,13 +33,20 @@ def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_incons if os.path.isfile(file): self.disjoint_files.append(file) else: - print(f"Disjoint axiom file {file} not found. Loading from huggingface instead...") + print( + f"Disjoint axiom file {file} not found. Loading from huggingface instead..." 
+ ) from chebifier.hugging_face import download_model_files - self.disjoint_files.append(download_model_files({ - "repo_id": "chebai/chebifier", - "repo_type": "dataset", - "files": {"disjoint_file": os.path.basename(file)}, - })["disjoint_file"]) + + self.disjoint_files.append( + download_model_files( + { + "repo_id": "chebai/chebifier", + "repo_type": "dataset", + "files": {"disjoint_file": os.path.basename(file)}, + } + )["disjoint_file"] + ) self.models = [] self.positive_prediction_threshold = 0.5 @@ -42,6 +54,7 @@ def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_incons model_cls = MODEL_TYPES[model_config["type"]] if "hugging_face" in model_config: from chebifier.hugging_face import download_model_files + hugging_face_kwargs = download_model_files(model_config["hugging_face"]) else: hugging_face_kwargs = {} @@ -49,12 +62,14 @@ def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_incons check_package_installed(model_config["package_name"]) model_instance = model_cls( - model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph + model_name, + **model_config, + **hugging_face_kwargs, + chebi_graph=self.chebi_graph, ) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) - if resolve_inconsistencies: self.smoother = PredictionSmoother( self.chebi_dataset, @@ -96,7 +111,9 @@ def gather_predictions(self, smiles_list): return ordered_logits, predicted_classes - def consolidate_predictions(self, predictions, classwise_weights, predicted_classes, **kwargs): + def consolidate_predictions( + self, predictions, classwise_weights, predicted_classes, **kwargs + ): """ Aggregates predictions from multiple models using weighted majority voting. Optimized version using tensor operations instead of for loops. @@ -152,9 +169,13 @@ def consolidate_predictions(self, predictions, classwise_weights, predicted_clas if self.smoother is not None: self.smoother.set_label_names(class_names) smooth_net_score = self.smoother(net_score) - class_decisions = (smooth_net_score > 0.5) & has_valid_predictions # Shape: (num_smiles, num_classes) + class_decisions = ( + smooth_net_score > 0.5 + ) & has_valid_predictions # Shape: (num_smiles, num_classes) else: - class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) + class_decisions = ( + net_score > 0 + ) & has_valid_predictions # Shape: (num_smiles, num_classes) end_time = time.perf_counter() print(f"Prediction smoothing took {end_time - start_time:.2f} seconds") @@ -178,7 +199,9 @@ def predict_smiles_list( smiles_list ) if len(predicted_classes) == 0: - print(f"Warning: No classes have been predicted for the given SMILES list.") + print( + f"Warning: No classes have been predicted for the given SMILES list." 
+ ) # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: @@ -203,7 +226,14 @@ def predict_smiles_list( class_names = list(predicted_classes.keys()) class_indices = {predicted_classes[cls]: cls for cls in class_names} result = [ - [class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]] if not failure else None + ( + [ + class_indices[idx.item()] + for idx in torch.nonzero(i, as_tuple=True)[0] + ] + if not failure + else None + ) for i, failure in zip(class_decisions, is_failure) ] @@ -240,7 +270,11 @@ def predict_smiles_list( } ) r = ensemble.predict_smiles_list( - ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O", "C[C@H](N)C(=O)NCC(O)=O#", ""], + [ + "[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O", + "C[C@H](N)C(=O)NCC(O)=O#", + "", + ], load_preds_if_possible=False, ) print(len(r), r[0]) diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 5020f17..76a3cde 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -7,10 +7,13 @@ ChemlogPeptidesPredictor, ElectraPredictor, ResGatedPredictor, - ChEBILookupPredictor + ChEBILookupPredictor, ) from chebifier.prediction_models.c3p_predictor import C3PPredictor -from chebifier.prediction_models.chemlog_predictor import ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor +from chebifier.prediction_models.chemlog_predictor import ( + ChemlogXMolecularEntityPredictor, + ChemlogOrganoXCompoundPredictor, +) ENSEMBLES = { "mv": BaseEnsemble, @@ -26,7 +29,7 @@ "chebi_lookup": ChEBILookupPredictor, "chemlog_element": ChemlogXMolecularEntityPredictor, "chemlog_organox": ChemlogOrganoXCompoundPredictor, - "c3p": C3PPredictor + "c3p": C3PPredictor, } diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index ce33cca..51fa6d6 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -3,5 +3,12 @@ from .electra_predictor import ElectraPredictor from .gnn_predictor import ResGatedPredictor from .chebi_lookup import ChEBILookupPredictor -__all__ = ["BasePredictor", "ChemlogPeptidesPredictor", "ElectraPredictor", "ResGatedPredictor", "ChEBILookupPredictor", - "ChemlogExtraPredictor"] + +__all__ = [ + "BasePredictor", + "ChemlogPeptidesPredictor", + "ElectraPredictor", + "ResGatedPredictor", + "ChEBILookupPredictor", + "ChemlogExtraPredictor", +] diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py index ccf4eed..a8e2e10 100644 --- a/chebifier/prediction_models/c3p_predictor.py +++ b/chebifier/prediction_models/c3p_predictor.py @@ -12,7 +12,13 @@ class C3PPredictor(BasePredictor): Wrapper for C3P (url). 
""" - def __init__(self, model_name: str, program_directory: Optional[Path]=None, chemical_classes: Optional[List[str]]=None, **kwargs): + def __init__( + self, + model_name: str, + program_directory: Optional[Path] = None, + chemical_classes: Optional[List[str]] = None, + **kwargs, + ): super().__init__(model_name, **kwargs) self.program_directory = program_directory self.chemical_classes = chemical_classes @@ -20,14 +26,23 @@ def __init__(self, model_name: str, program_directory: Optional[Path]=None, chem @lru_cache(maxsize=100) def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: - result_list = c3p_classifier.classify(list(smiles_list), self.program_directory, self.chemical_classes, strict=False) + result_list = c3p_classifier.classify( + list(smiles_list), + self.program_directory, + self.chemical_classes, + strict=False, + ) result_reformatted = [dict() for _ in range(len(smiles_list))] for result in result_list: chebi_id = result.class_id.split(":")[1] - result_reformatted[smiles_list.index(result.input_smiles)][chebi_id] = result.is_match + result_reformatted[smiles_list.index(result.input_smiles)][ + chebi_id + ] = result.is_match if result.is_match and self.chebi_graph is not None: for parent in list(self.chebi_graph.predecessors(int(chebi_id))): - result_reformatted[smiles_list.index(result.input_smiles)][str(parent)] = 1 + result_reformatted[smiles_list.index(result.input_smiles)][ + str(parent) + ] = 1 return result_reformatted def explain_smiles(self, smiles): @@ -36,14 +51,22 @@ def explain_smiles(self, smiles): than 300 classes, only take the positive ones. """ highlights = [] - result_list = c3p_classifier.classify([smiles], self.program_directory, self.chemical_classes, strict=False) + result_list = c3p_classifier.classify( + [smiles], self.program_directory, self.chemical_classes, strict=False + ) for result in result_list: if result.is_match: highlights.append( - ("text", f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}") + ( + "text", + f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}", + ) ) highlights = [ - ("text", f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:") - ] + highlights + ( + "text", + f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:", + ) + ] + highlights - return {"highlights": highlights} \ No newline at end of file + return {"highlights": highlights} diff --git a/chebifier/prediction_models/chebi_lookup.py b/chebifier/prediction_models/chebi_lookup.py index 8b99fad..f314bf5 100644 --- a/chebifier/prediction_models/chebi_lookup.py +++ b/chebifier/prediction_models/chebi_lookup.py @@ -7,15 +7,26 @@ from rdkit import Chem import json + class ChEBILookupPredictor(BasePredictor): - def __init__(self, model_name: str, description: str = None, chebi_version: int = 241, **kwargs): + def __init__( + self, + model_name: str, + description: str = None, + chebi_version: int = 241, + **kwargs, + ): super().__init__(model_name, **kwargs) - self._description = description or "ChEBI Lookup: If the SMILES is equivalent to a ChEBI entry, retrieve the classification of that entry." + self._description = ( + description + or "ChEBI Lookup: If the SMILES is equivalent to a ChEBI entry, retrieve the classification of that entry." 
+ ) self.chebi_version = chebi_version self.chebi_graph = kwargs.get("chebi_graph", None) if self.chebi_graph is None: from chebai.preprocessing.datasets.chebi import ChEBIOver50 + self.chebi_dataset = ChEBIOver50(chebi_version=self.chebi_version) self.chebi_dataset._download_required_data() self.chebi_graph = self.chebi_dataset._extract_class_hierarchy( @@ -24,7 +35,9 @@ def __init__(self, model_name: str, description: str = None, chebi_version: int self.lookup_table = self.get_smiles_lookup() def get_smiles_lookup(self): - path = os.path.join("data", f"chebi_v{self.chebi_version}", "smiles_lookup.json") + path = os.path.join( + "data", f"chebi_v{self.chebi_version}", "smiles_lookup.json" + ) if not os.path.exists(path): smiles_lookup = self.build_smiles_lookup() with open(path, "w", encoding="utf-8") as f: @@ -35,23 +48,30 @@ def get_smiles_lookup(self): smiles_lookup = json.load(f) return smiles_lookup - def build_smiles_lookup(self): smiles_lookup = dict() - for chebi_id, smiles in nx.get_node_attributes(self.chebi_graph, "smiles").items(): + for chebi_id, smiles in nx.get_node_attributes( + self.chebi_graph, "smiles" + ).items(): if smiles is not None: try: mol = Chem.MolFromSmiles(smiles) if mol is None: - print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}") + print( + f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}" + ) continue canonical_smiles = Chem.MolToSmiles(mol) if canonical_smiles not in smiles_lookup: smiles_lookup[canonical_smiles] = [] # if the canonical SMILES is already in the lookup, append "different interpretation of the SMILES" - smiles_lookup[canonical_smiles].append((chebi_id, list(self.chebi_graph.predecessors(chebi_id)))) + smiles_lookup[canonical_smiles].append( + (chebi_id, list(self.chebi_graph.predecessors(chebi_id))) + ) except Exception as e: - print(f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}") + print( + f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}: {e}" + ) return smiles_lookup @lru_cache(maxsize=100) @@ -67,7 +87,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]: preds_i = dict() if len(parent_candidates) > 1: print( - f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}") + f"Multiple matches found in ChEBI for SMILES {smiles}: {', '.join(str(chebi_id) for chebi_id, _ in parent_candidates)}" + ) for k in list(set(pp for _, p in parent_candidates for pp in p)): preds_i[str(k)] = 1 elif len(parent_candidates) == 1: @@ -96,28 +117,42 @@ def info_text(self): def explain_smiles(self, smiles: str) -> dict: mol = Chem.MolFromSmiles(smiles) if mol is None: - return {"highlights": [ - ("text", "The input SMILES could not be parsed into a valid molecule.") - ]} + return { + "highlights": [ + ( + "text", + "The input SMILES could not be parsed into a valid molecule.", + ) + ] + } canonical_smiles = Chem.MolToSmiles(mol) if canonical_smiles not in self.lookup_table: - return {"highlights": [ - ("text", "The input SMILES does not match any ChEBI entry.") - ]} + return { + "highlights": [ + ("text", "The input SMILES does not match any ChEBI entry.") + ] + } parent_candidates = self.lookup_table[canonical_smiles] - return {"highlights": [ - ("text", - f"The ChEBI Lookup matches the canonical version of the input SMILES against ChEBI (v{self.chebi_version})." - f" It found {'1 match' if len(parent_candidates) == 1 else f'{len(parent_candidates)} matches'}:" - f" {', '.join(f'CHEBI:{chebi_id}' for chebi_id, _ in parent_candidates)}. 
The predicted classes are the" - f" parent classes of the matched ChEBI entries.") - ]} + return { + "highlights": [ + ( + "text", + f"The ChEBI Lookup matches the canonical version of the input SMILES against ChEBI (v{self.chebi_version})." + f" It found {'1 match' if len(parent_candidates) == 1 else f'{len(parent_candidates)} matches'}:" + f" {', '.join(f'CHEBI:{chebi_id}' for chebi_id, _ in parent_candidates)}. The predicted classes are the" + f" parent classes of the matched ChEBI entries.", + ) + ] + } if __name__ == "__main__": predictor = ChEBILookupPredictor("ChEBI Lookup") print(predictor.info_text) # Example usage - smiles_list = ["CCO", "C1=CC=CC=C1" '*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O'] # SMILES with 251 matches in ChEBI + smiles_list = [ + "CCO", + "C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O", + ] # SMILES with 251 matches in ChEBI predictions = predictor.predict_smiles_list(smiles_list) - print(predictions) \ No newline at end of file + print(predictions) diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index a1343a4..10729d2 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -11,7 +11,10 @@ is_emericellamide, ) from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call -from chemlog_extra.alg_classification.by_element_classification import XMolecularEntityClassifier, OrganoXCompoundClassifier +from chemlog_extra.alg_classification.by_element_classification import ( + XMolecularEntityClassifier, + OrganoXCompoundClassifier, +) from functools import lru_cache from .base_predictor import BasePredictor @@ -42,6 +45,7 @@ "Y": "L-tyrosine", } + class ChemlogExtraPredictor(BasePredictor): CHEMLOG_CLASSIFIER = None @@ -71,10 +75,12 @@ class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor): CHEMLOG_CLASSIFIER = XMolecularEntityClassifier + class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor): CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier + class ChemlogPeptidesPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) @@ -96,22 +102,25 @@ def predict_smiles(self, smiles: str) -> Optional[dict]: mol = _smiles_to_mol(smiles) if mol is None: return None - pos_labels = [label for label in self.peptide_labels if label in strategy_call( - self.strategy, self.classifier_instances, mol - )["chebi_classes"]] + pos_labels = [ + label + for label in self.peptide_labels + if label + in strategy_call(self.strategy, self.classifier_instances, mol)[ + "chebi_classes" + ] + ] if self.chebi_graph: - indirect_pos_labels = [str(pr) for label in pos_labels for pr in - self.chebi_graph.predecessors(int(label))] + indirect_pos_labels = [ + str(pr) + for label in pos_labels + for pr in self.chebi_graph.predecessors(int(label)) + ] pos_labels = list(set(pos_labels + indirect_pos_labels)) return { - label: ( - 1 - if label - in pos_labels - else 0 - ) - for label in self.peptide_labels + pos_labels - } + label: (1 if label in pos_labels else 0) + for label in self.peptide_labels + pos_labels + } def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: results = [] @@ -377,9 +386,12 @@ def build_explain_blocks_proteinogenics(self, proteinogenics, atoms): def explain_smiles(self, smiles) -> dict: info = self.get_chemlog_result_info(smiles) zero_blocks = [ - ("text", "Results for peptides and peptide-related classes (e.g. 
peptide anion, depsipeptide) have been calculated" - " with a rule-based system. The following shows which parts of the molecule were identified as relevant" - " structures and have influenced the classification.") + ( + "text", + "Results for peptides and peptide-related classes (e.g. peptide anion, depsipeptide) have been calculated" + " with a rule-based system. The following shows which parts of the molecule were identified as relevant" + " structures and have influenced the classification.", + ) ] highlight_blocks = zero_blocks + self.build_explain_blocks_peptides(info) diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index 64f21ae..3d6fc92 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -59,9 +59,7 @@ def read_smiles(self, smiles): # use default value if we meet an unseen value if isinstance(prop.encoder, IndexEncoder): if str(value) in prop.encoder.cache: - index = ( - prop.encoder.cache[str(value)] + prop.encoder.offset - ) + index = prop.encoder.cache[str(value)] + prop.encoder.offset else: index = 0 print( diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index dd404ba..312766d 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -61,7 +61,9 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: index_map = dict() for i, smiles in enumerate(smiles_list): if not smiles: - print(f"Model {self.model_name} received a missing SMILES string at position {i}.") + print( + f"Model {self.model_name} received a missing SMILES string at position {i}." + ) could_not_parse.append(i) continue try: @@ -69,12 +71,16 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: # This is just for sanity checks rdmol = Chem.MolFromSmiles(smiles, sanitize=False) if rdmol is None: - print(f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}") + print( + f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}" + ) could_not_parse.append(i) continue except Exception as e: could_not_parse.append(i) - print(f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}") + print( + f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}" + ) continue index_map[i] = len(token_dicts) token_dicts.append(d) diff --git a/pyproject.toml b/pyproject.toml index ff7837d..efa9a3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,3 +30,6 @@ dependencies = [ [tool.setuptools] packages = ["chebifier", "chebifier.ensemble", "chebifier.prediction_models"] + +[project.optional-dependencies] +dev = ["black", "isort", "pre-commit"] From 2115cfd2d8510389255ba6a1a99b28991564ad10 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 20:05:51 +0200 Subject: [PATCH 35/38] refactor with ruff --- chebifier/cli.py | 3 +-- chebifier/ensemble/base_ensemble.py | 2 +- chebifier/prediction_models/nn_predictor.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 924ef5d..c201187 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,5 +1,4 @@ import importlib.resources -import os import click import yaml @@ -75,7 +74,7 @@ def predict( """Predict ChEBI classes for SMILES strings using an ensemble model.""" # Load configuration from YAML file if not ensemble_config: - print(f"Using default 
ensemble configuration") + print("Using default ensemble configuration") with ( importlib.resources.files("chebifier") .joinpath("ensemble.yml") diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 39f643a..989c691 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -200,7 +200,7 @@ def predict_smiles_list( ) if len(predicted_classes) == 0: print( - f"Warning: No classes have been predicted for the given SMILES list." + "Warning: No classes have been predicted for the given SMILES list." ) # save predictions torch.save(ordered_predictions, preds_file) diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 312766d..e7d72c9 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -76,7 +76,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list: ) could_not_parse.append(i) continue - except Exception as e: + except Exception: could_not_parse.append(i) print( f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}" From 572f0edaec55da30475b3380fcdead4ba32f17a8 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 20:50:35 +0200 Subject: [PATCH 36/38] update readme --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index efa9a3f..6c0f2fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "chebifier" -version = "1.0.0" +version = "1.1.0" description = "An AI ensemble model for predicting chemical classes" readme = "README.md" requires-python = ">=3.9" From 9769e23d59ef5645ed19de2a3143f76961de7fe2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 21 Jul 2025 20:54:22 +0200 Subject: [PATCH 37/38] update readme --- README.md | 61 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 0559817..29141da 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,17 @@ # python-chebifier -An AI ensemble model for predicting chemical classes in the ChEBI ontology. +An AI ensemble model for predicting chemical classes in the ChEBI ontology. It integrates deep learning models, +rule-based models and generative AI-based models. + +A web application for the ensemble is available at https://chebifier.hastingslab.org/. ## Installation +You can get the package from PyPI: +```bash +pip install chebifier +``` + +or get the latest development version from GitHub: ```bash # Clone the repository git clone https://github.com/yourusername/python-chebifier.git @@ -12,7 +21,7 @@ cd python-chebifier pip install -e . ``` -u`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow +`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph). ## Usage @@ -21,23 +30,24 @@ the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/pyt The package provides a command-line interface (CLI) for making predictions using an ensemble model. -```bash -# Get help -python -m chebifier.cli --help +The ensemble configuration is given by a configuration file (by default, this is `chebifier/ensemble.yml`). 
If you
+want to change which models are included in the ensemble or how they are weighted, you can create your own configuration file.

-# Make predictions using a configuration file
-python -m chebifier.cli predict configs/example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"
+Model weights for deep learning models are downloaded automatically from [Hugging Face](https://huggingface.co/chebai).
+However, you can also supply your own model checkpoints (see `configs/example_config.yml` for an example).

-# Make predictions using SMILES from a file
-python -m chebifier.cli predict configs/example_config.yml --smiles-file smiles.txt
-```
+```bash
+# Make predictions
+python -m chebifier predict --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" --smiles "C1=CC=C(C=C1)C(=O)O"

-### Configuration File
+# Make predictions using SMILES from a file
+python -m chebifier predict --smiles-file smiles.txt

-The CLI requires a YAML configuration file that defines the ensemble model. An example can be found in `configs/example_config.yml`.
+# Make predictions using a configuration file
+python -m chebifier predict --ensemble-config configs/my_config.yml --smiles-file smiles.txt

-The models and other required files are trained / generated by our [chebai](https://github.com/ChEB-AI/python-chebai) package.
-Examples for models can be found on [kaggle](https://www.kaggle.com/datasets/sfluegel/chebai).
+python -m chebifier predict --help
+```

 ### Python API

@@ -67,6 +77,27 @@ for smiles, prediction in zip(smiles_list, predictions):
         print("No predictions")
 ```

+### The models
+Currently, the following models are supported:
+
+
+| Model | Description | #Classes | Publication | Repository |
+|-------|-------------|----------|-----------------------------------------------------------------------|----------------------------------------------------------------------------------------|
+| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings. | 1522 | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) |
+| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1522 | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
+| `chemlog_peptides` | A rule-based model specialised on peptide classes. | 18 | [Flügel, Simon, et al., 2025: ChemLog: Making MSOL Viable for Ontological Classification and Learning, arXiv](https://arxiv.org/abs/2507.13987) | [chemlog-peptides](https://github.com/sfluegel05/chemlog-peptides) |
+| `chemlog_element`, `chemlog_organox` | Extensions of ChemLog for classes that are defined either by the presence of a specific element or by the presence of an organic bond. | 118 + 37 | | [chemlog-extra](https://github.com/ChEB-AI/chemlog-extra) |
+| `c3p` | A collection of _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, arXiv](https://arxiv.org/abs/2505.18470) | [c3p](https://github.com/chemkg/c3p) |
+
+In addition, Chebifier also includes a ChEBI lookup that automatically retrieves the ChEBI superclasses for a class
+matched by a SMILES string. 
This is not activated by default, but can be included by adding
+```yaml
+chebi_lookup:
+  type: chebi_lookup
+  model_weight: 10 # optional
+```
+to your configuration file.
+
 ### The ensemble

 Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows:
@@ -110,7 +141,7 @@ belongs to the direct and indirect superclasses (e.g., primary alcohol, aromatic
 - (2) Next, we check for disjointness. This is not specified directly in ChEBI, but in an additional ChEBI module
 ([chebi-disjoints.owl](https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/)). We have extracted these disjointness
 axioms into a CSV file and added some more disjointness axioms ourselves (see `data>disjoint_chebi.csv` and
 `data>disjoint_additional.csv`). If two classes $A$ and $B$ are disjoint and we predict
-both, we select one of them randomly and set the other to 0.
+both, we select the one with the higher class score and set the other to 0.
 - (3) Since the second step might have introduced new inconsistencies into the hierarchy, we repeat the first step, but
 with a small change. For a pair of classes $A \subseteq B$ with predictions $1$ and $0$, instead of setting $B$ to $1$,
 we now set $A$ to $0$. This has the advantage that we cannot introduce new disjointness-inconsistencies and don't have

From 0377027c46d97ec4ba183e9ff3401209d22e470b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20Fl=C3=BCgel?= <43573433+sfluegel05@users.noreply.github.com>
Date: Mon, 21 Jul 2025 21:01:21 +0200
Subject: [PATCH 38/38] Update README.md

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 965ec8f..42031b5 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,7 @@ python -m chebifier predict --smiles-file smiles.txt
 # Make predictions using a configuration file
 python -m chebifier predict --ensemble-config configs/my_config.yml --smiles-file smiles.txt

+# Get all available options
 python -m chebifier predict --help
 ```

@@ -99,6 +100,7 @@ chebi_lookup:
 to your configuration file.

 ### The ensemble
+ensemble_architecture

 Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows:
 1. Get predictions from each model $m_i$ for the sample.
@@ -134,7 +136,7 @@ Trust is based on the model's performance on a validation set. After training, w
 on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as 1 + the F1
 score. If the `ensemble_type` is set to `mv` (the default), the trust is set to 1 for all models.

-### Inconsistency correction
+### Inconsistency resolution
 After a decision has been made for each class independently, the consistency of the predictions with regard to the
 ChEBI hierarchy and disjointness axioms is checked. This is done in 3 steps: