diff --git a/README.md b/README.md
index 42031b5..f16f8f2 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,9 @@ A web application for the ensemble is available at https://chebifier.hastingslab
 ## Installation
 
+Note: `chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks,
+install it by following the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph). All other dependencies are installed automatically.
+
 You can get the package from PyPI:
 
 ```bash
 pip install chebifier
@@ -21,9 +24,6 @@ cd python-chebifier
 pip install -e .
 ```
 
-`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow
-the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph).
-
 ## Usage
 
 ### Command Line Interface
diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py
index 989c691..7a3aef2 100644
--- a/chebifier/ensemble/base_ensemble.py
+++ b/chebifier/ensemble/base_ensemble.py
@@ -3,8 +3,8 @@
 import torch
 import tqdm
 
-from chebai.preprocessing.datasets.chebi import ChEBIOver50
-from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
+from chebifier.inconsistency_resolution import PredictionSmoother
+from chebifier.utils import load_chebi_graph, get_disjoint_files
 
 from chebifier.check_env import check_package_installed
 from chebifier.prediction_models.base_predictor import BasePredictor
@@ -21,32 +21,8 @@ def __init__(
         # Deferred Import: To avoid circular import error
         from chebifier.model_registry import MODEL_TYPES
 
-        self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
-        self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
-        local_disjoint_files = [
-            os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv"),
-        ]
-        self.disjoint_files = []
-        for file in local_disjoint_files:
-            if os.path.isfile(file):
-                self.disjoint_files.append(file)
-            else:
-                print(
-                    f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
-                )
-                from chebifier.hugging_face import download_model_files
-
-                self.disjoint_files.append(
-                    download_model_files(
-                        {
-                            "repo_id": "chebai/chebifier",
-                            "repo_type": "dataset",
-                            "files": {"disjoint_file": os.path.basename(file)},
-                        }
-                    )["disjoint_file"]
-                )
+        self.chebi_graph = load_chebi_graph()
+        self.disjoint_files = get_disjoint_files()
 
         self.models = []
         self.positive_prediction_threshold = 0.5
@@ -72,7 +48,7 @@ def __init__(
         if resolve_inconsistencies:
             self.smoother = PredictionSmoother(
-                self.chebi_dataset,
+                self.chebi_graph,
                 label_names=None,
                 disjoint_files=self.disjoint_files,
             )
 
@@ -203,10 +179,11 @@ def predict_smiles_list(
                     "Warning: No classes have been predicted for the given SMILES list."
                 )
             # save predictions
-            torch.save(ordered_predictions, preds_file)
-            with open(predicted_classes_file, "w") as f:
-                for cls in predicted_classes:
-                    f.write(f"{cls}\n")
+            if load_preds_if_possible:
+                torch.save(ordered_predictions, preds_file)
+                with open(predicted_classes_file, "w") as f:
+                    for cls in predicted_classes:
+                        f.write(f"{cls}\n")
             predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
         else:
             print(
diff --git a/chebifier/inconsistency_resolution.py b/chebifier/inconsistency_resolution.py
new file mode 100644
index 0000000..f6640c2
--- /dev/null
+++ b/chebifier/inconsistency_resolution.py
@@ -0,0 +1,124 @@
+import csv
+import os
+from pathlib import Path
+
+import torch
+
+
+def get_disjoint_groups(disjoint_files):
+    if disjoint_files is None:
+        disjoint_files = [os.path.join("data", "chebi-disjoints.owl")]
+    disjoint_pairs, disjoint_groups = [], []
+    for file in disjoint_files:
+        if isinstance(file, Path):
+            file = str(file)
+        if file.endswith(".csv"):
+            with open(file, "r") as f:
+                reader = csv.reader(f)
+                disjoint_pairs += [line for line in reader]
+        elif file.endswith(".owl"):
+            with open(file, "r") as f:
+                plaintext = f.read()
+            segments = plaintext.split("<")
+            left = None
+            for seg in segments:
+                if seg.startswith("rdf:Description ") or seg.startswith(
+                    "owl:Class"
+                ):
+                    left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
+                elif seg.startswith("owl:disjointWith"):
+                    right = int(
+                        seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
+                    )
+                    disjoint_pairs.append([left, right])
+
+            for seg in plaintext.split("<rdf:Description>"):
+                if "owl;AllDisjointClasses" in seg:
+                    classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
+                    classes = [int(c.split('"')[0]) for c in classes]
+                    disjoint_groups.append(classes)
+        else:
+            raise NotImplementedError(
+                "Unsupported disjoint file format: " + file.split(".")[-1]
+            )
+
+    disjoint_all = disjoint_pairs + disjoint_groups
+    # one disjointness axiom is commented out in the owl file; instead of parsing
+    # the comment syntax properly, we simply remove that pair here
+    if [22729, 51880] in disjoint_all:
+        disjoint_all.remove([22729, 51880])
+    # print(f"Found {len(disjoint_all)} disjoint groups")
+    return disjoint_all
+
+
+class PredictionSmoother:
+    """Removes implication and disjointness violations from predictions"""
+
+    def __init__(self, chebi_graph, label_names=None, disjoint_files=None):
+        self.chebi_graph = chebi_graph
+        self.set_label_names(label_names)
+        self.disjoint_groups = get_disjoint_groups(disjoint_files)
+
+    def set_label_names(self, label_names):
+        if label_names is not None:
+            self.label_names = label_names
+            chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
+            self.label_successors = torch.zeros(
+                (len(self.label_names), len(self.label_names)), dtype=torch.bool
+            )
+            for i, label in enumerate(self.label_names):
+                self.label_successors[i, i] = 1
+                for p in chebi_subgraph.successors(label):
+                    if p in self.label_names:
+                        self.label_successors[i, self.label_names.index(p)] = 1
+            self.label_successors = self.label_successors.unsqueeze(0)
+
+    def __call__(self, preds):
+        if preds.shape[1] == 0:
+            # no labels predicted
+            return preds
+        # preds shape: (n_samples, n_labels)
+        preds_sum_orig = torch.sum(preds)
+        # step 1: apply implications: for each class, set prediction to max of itself and all successors
+        preds = preds.unsqueeze(1)
+        preds_masked_succ = torch.where(self.label_successors, preds, 0)
+        # preds_masked_succ shape: (n_samples, n_labels, n_labels)
+
+        preds = preds_masked_succ.max(dim=2).values
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
+        preds_sum_orig = torch.sum(preds)
+        # step 2: eliminate disjointness violations: for each group of disjoint classes, set all except the maximum to 0.49 (unless it is already lower)
+        preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
+        for disj_group in self.disjoint_groups:
+            disj_group = [
+                self.label_names.index(g) for g in disj_group if g in self.label_names
+            ]
+            if len(disj_group) > 1:
+                old_preds = preds[:, disj_group]
+                disj_max = torch.max(preds[:, disj_group], dim=1)
+                for i, row in enumerate(preds):
+                    for l_ in range(len(preds[i])):
+                        if l_ in disj_group and l_ != disj_group[disj_max.indices[i]]:
+                            preds[i, l_] = preds_bounded[i, l_]
+                samples_changed = 0
+                for i, row in enumerate(preds[:, disj_group]):
+                    if any(r != o for r, o in zip(row, old_preds[i])):
+                        samples_changed += 1
+                if samples_changed != 0:
+                    print(
+                        f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
+                    )
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
+        preds_sum_orig = torch.sum(preds)
+        # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
+        preds = preds.unsqueeze(1)
+        preds_masked_predec = torch.where(
+            torch.transpose(self.label_successors, 1, 2), preds, 1
+        )
+        preds = preds_masked_predec.min(dim=2).values
+        if torch.sum(preds) != preds_sum_orig:
+            print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
+        return preds
diff --git a/chebifier/prediction_models/c3p_predictor.py b/chebifier/prediction_models/c3p_predictor.py
index a8e2e10..00c71f7 100644
--- a/chebifier/prediction_models/c3p_predictor.py
+++ b/chebifier/prediction_models/c3p_predictor.py
@@ -39,7 +39,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
                     chebi_id
                 ] = result.is_match
                 if result.is_match and self.chebi_graph is not None:
-                    for parent in list(self.chebi_graph.predecessors(int(chebi_id))):
+                    for parent in list(self.chebi_graph.predecessors(chebi_id)):
                         result_reformatted[smiles_list.index(result.input_smiles)][
                             str(parent)
                         ] = 1
diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py
index 10729d2..8232641 100644
--- a/chebifier/prediction_models/chemlog_predictor.py
+++ b/chebifier/prediction_models/chemlog_predictor.py
@@ -63,7 +63,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
             sample_additions = dict()
             for cls in sample:
                 if sample[cls] == 1:
-                    successors = list(self.chebi_graph.predecessors(int(cls)))
+                    successors = list(self.chebi_graph.predecessors(cls))
                     if successors:
                         for succ in successors:
                             sample_additions[str(succ)] = 1
@@ -114,7 +114,7 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
         indirect_pos_labels = [
             str(pr)
             for label in pos_labels
-            for pr in self.chebi_graph.predecessors(int(label))
+            for pr in self.chebi_graph.predecessors(label)
         ]
         pos_labels = list(set(pos_labels + indirect_pos_labels))
         return {
diff --git a/chebifier/utils.py b/chebifier/utils.py
new file mode 100644
index 0000000..e6fefae
--- /dev/null
+++ b/chebifier/utils.py
@@ -0,0 +1,131 @@
+import os
+import pickle
+
+import fastobo
+import networkx as nx
+import requests
+
+from chebifier.hugging_face import download_model_files
+
+
+def load_chebi_graph(filename=None):
+ """Load ChEBI graph from Hugging Face (if filename is None) or local file""" + if filename is None: + print("Loading ChEBI graph from Hugging Face...") + file = download_model_files( + { + "repo_id": "chebai/chebifier", + "repo_type": "dataset", + "files": {"f": "chebi_graph.pkl"}, + } + )["f"] + else: + print(f"Loading ChEBI graph from local {filename}...") + file = filename + return pickle.load(open(file, "rb")) + + +def term_callback(doc): + """Similar to the chebai function, but reduced to the necessary fields. Also, ChEBI IDs are strings""" + parents = [] + name = None + smiles = None + for clause in doc: + if isinstance(clause, fastobo.term.PropertyValueClause): + t = clause.property_value + if str(t.relation) == "http://purl.obolibrary.org/obo/chebi/smiles": + assert smiles is None + smiles = t.value + # in older chebi versions, smiles strings are synonyms + # e.g. synonym: "[F-].[Na+]" RELATED SMILES [ChEBI] + elif isinstance(clause, fastobo.term.SynonymClause): + if "SMILES" in clause.raw_value(): + assert smiles is None + smiles = clause.raw_value().split('"')[1] + elif isinstance(clause, fastobo.term.IsAClause): + chebi_id = str(clause.term) + chebi_id = chebi_id[chebi_id.index(":") + 1 :] + parents.append(chebi_id) + elif isinstance(clause, fastobo.term.NameClause): + name = str(clause.name) + + if isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + # if the term document contains clause as obsolete as true, skips this document. + return False + chebi_id = str(doc.id) + chebi_id = chebi_id[chebi_id.index(":") + 1 :] + return { + "id": chebi_id, + "parents": parents, + "name": name, + "smiles": smiles, + } + + +def build_chebi_graph(chebi_version=241): + """Creates a networkx graph for the ChEBI hierarchy. Usually, you don't want to call this function directly, but rather use the `load_chebi_graph` function.""" + chebi_path = os.path.join("data", f"chebi_v{chebi_version}", "chebi.obo") + os.makedirs(os.path.join("data", f"chebi_v{chebi_version}"), exist_ok=True) + if not os.path.exists(chebi_path): + url = f"http://purl.obolibrary.org/obo/chebi/{chebi_version}/chebi.obo" + r = requests.get(url, allow_redirects=True) + open(chebi_path, "wb").write(r.content) + with open(chebi_path, encoding="utf-8") as chebi: + chebi = "\n".join(line for line in chebi if not line.startswith("xref:")) + + elements = [] + for term_doc in fastobo.loads(chebi): + if ( + term_doc + and isinstance(term_doc.id, fastobo.id.PrefixedIdent) + and term_doc.id.prefix == "CHEBI" + ): + term_dict = term_callback(term_doc) + if term_dict: + elements.append(term_dict) + + g = nx.DiGraph() + for n in elements: + g.add_node(n["id"], **n) + + # Only take the edges which connect the existing nodes, to avoid internal creation of obsolete nodes + # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142 + g.add_edges_from( + [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)] + ) + return nx.transitive_closure_dag(g) + + +def get_disjoint_files(): + """Gets local disjointness files if they are present in the right location, otherwise downloads them from Hugging Face.""" + local_disjoint_files = [ + os.path.join("data", "disjoint_chebi.csv"), + os.path.join("data", "disjoint_additional.csv"), + ] + disjoint_files = [] + for file in local_disjoint_files: + if os.path.isfile(file): + disjoint_files.append(file) + else: + print( + f"Disjoint axiom file {file} not found. Loading from huggingface instead..." 
+            )
+
+            disjoint_files.append(
+                download_model_files(
+                    {
+                        "repo_id": "chebai/chebifier",
+                        "repo_type": "dataset",
+                        "files": {"disjoint_file": os.path.basename(file)},
+                    }
+                )["disjoint_file"]
+            )
+    return disjoint_files
+
+
+if __name__ == "__main__":
+    # to rebuild the graph from a ChEBI release and save it to a file:
+    # chebi_graph = build_chebi_graph(chebi_version=241)
+    # pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
+    chebi_graph = load_chebi_graph()
+    print(chebi_graph)
diff --git a/pyproject.toml b/pyproject.toml
index 6c0f2fb..325d3d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "chebifier"
-version = "1.1.0"
+version = "1.1.1"
 description = "An AI ensemble model for predicting chemical classes"
 readme = "README.md"
 requires-python = ">=3.9"
@@ -24,12 +24,18 @@ dependencies = [
     "tqdm",
     "rdkit",
     "chebai>=1.0.1",
-    "chemlog>=1.0.4"
+    "chemlog>=1.0.4",
+    "chemlog_extra @ git+https://github.com/ChEB-AI/chemlog-extra.git",
+    # forked version of c3p that is Windows-compatible
+    "c3p @ git+https://github.com/sfluegel05/c3p.git"
 ]
 
 [tool.setuptools]
 packages = ["chebifier", "chebifier.ensemble", "chebifier.prediction_models"]
 
+[tool.setuptools.package-data]
+chebifier = ["*.yml"]
+
 [project.optional-dependencies]
 dev = ["black", "isort", "pre-commit"]
 
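Reviewer note: below is a minimal usage sketch of the APIs this diff introduces (`load_chebi_graph`, `get_disjoint_files`, `PredictionSmoother`). It assumes the Hugging Face dataset `chebai/chebifier` provides `chebi_graph.pkl` and the disjointness CSVs, as the code above expects; the ChEBI IDs and scores are purely illustrative.

```python
import torch

from chebifier.inconsistency_resolution import PredictionSmoother
from chebifier.utils import get_disjoint_files, load_chebi_graph

# Fetch the pickled ChEBI hierarchy and the disjointness axiom files
# (both are downloaded from Hugging Face if not cached locally).
chebi_graph = load_chebi_graph()
disjoint_files = get_disjoint_files()

smoother = PredictionSmoother(chebi_graph, disjoint_files=disjoint_files)

# Graph nodes are plain-string ChEBI IDs; these three are illustrative only.
smoother.set_label_names(["24431", "23367", "24867"])

# One sample, one score per label. The smoother propagates each score up the
# hierarchy and caps all but the highest-scoring member of a disjoint group.
preds = torch.tensor([[0.3, 0.8, 0.1]])
print(smoother(preds))
```

Since `PredictionSmoother` now receives the networkx graph directly instead of a `ChEBIOver50` dataset, it can be reused outside `BaseEnsemble` without pulling in chebai's dataset machinery.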