diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index 80d495dd8..c0e8b594c 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -619,7 +619,8 @@ def _generate_link_candidates( entity.link_candidates = list(cuis) - def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]: + def _pre_inference(self, doc: MutableDocument + ) -> tuple[list[MutableEntity], list[MutableEntity]]: """Checking all entities for entites with only a single link candidate and to avoid full inference step. If we want to calculate similarities, or not use link candidates then just return the entities""" @@ -643,8 +644,8 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]: if self.cnf_l.always_calculate_similarity: return [], filtered_ents - le = [] - to_infer = [] + le: list[MutableEntity] = [] + to_infer: list[MutableEntity] = [] for entity in all_ents: if len(entity.link_candidates) == 1: # if the include filter exists and the only cui is in it @@ -653,7 +654,7 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]: entity.context_similarity = 1 le.append(entity) continue - elif self.cnf_l.use_ner_link_candidates: + elif self.cnf_l.use_ner_link_candidates and not entity.link_candidates: continue # it has to be inferred due to filters or number of link candidates to_infer.append(entity) diff --git a/medcat-v2/tests/components/linking/test_embedding_linker.py b/medcat-v2/tests/components/linking/test_embedding_linker.py index 187bc2189..658ecc52f 100644 --- a/medcat-v2/tests/components/linking/test_embedding_linker.py +++ b/medcat-v2/tests/components/linking/test_embedding_linker.py @@ -1,13 +1,17 @@ from medcat.components.linking import embedding_linker from medcat.components import types from medcat.config import Config +from medcat.data.entities import Entity from medcat.vocab import Vocab +from medcat.cat import CAT from medcat.cdb.concepts import CUIInfo, NameInfo from medcat.components.types import TrainableComponent from medcat.components.types import _DEFAULT_LINKING as DEF_LINKING import unittest from ..helper import ComponentInitTests +from ... import UNPACKED_EXAMPLE_MODEL_PACK_PATH + class FakeDocument: linked_ents = [] ner_ents = [] @@ -64,4 +68,36 @@ def test_linker_is_not_trainable(self): def test_linker_processes_document(self): doc = FakeDocument("Test Document") - self.linker(doc) \ No newline at end of file + self.linker(doc) + + +class EmbeddingModelDisambiguationTests(unittest.TestCase): + PLACEHOLDER = "{SOME_PLACEHOLDER}" + TEXT = f"""The issue has a lot to do with the {PLACEHOLDER}""" + + @classmethod + def setUpClass(cls) -> None: + cls.model = CAT.load_model_pack(UNPACKED_EXAMPLE_MODEL_PACK_PATH) + cls.model.config.components.linking = embedding_linker.EmbeddingLinking() + cls.model._recreate_pipe() + linker: embedding_linker.Linker = cls.model.pipe.get_component( + types.CoreComponentType.linking) + linker.create_embeddings() + + def assert_has_name(self, out_ents: dict[int, Entity], name: str): + self.assertTrue( + any(ent["source_value"] == name for ent in out_ents.values()) + ) + + def test_does_disambiguation(self): + used_names = 0 + for name, info in self.model.cdb.name2info.items(): + if len(info['per_cui_status']) <= 1: + continue + used_names += 1 + with self.subTest(name): + cur_text = self.TEXT.replace(self.PLACEHOLDER, name) + out_ents = self.model.get_entities(cur_text)["entities"] + self.assert_has_name(out_ents, name) + self.assertGreater(used_names, 0) +