Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions medcat-v2/medcat/components/linking/embedding_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,8 @@ def _generate_link_candidates(

entity.link_candidates = list(cuis)

def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]:
def _pre_inference(self, doc: MutableDocument
) -> tuple[list[MutableEntity], list[MutableEntity]]:
"""Check all entities for those with only a single link candidate, so that
the full inference step can be avoided for them. If we want to calculate
similarities, or do not use link candidates, then just return the entities."""
Expand All @@ -643,8 +644,8 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]:
if self.cnf_l.always_calculate_similarity:
return [], filtered_ents

le = []
to_infer = []
le: list[MutableEntity] = []
to_infer: list[MutableEntity] = []
for entity in all_ents:
if len(entity.link_candidates) == 1:
# if the include filter exists and the only cui is in it
Expand All @@ -653,7 +654,7 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]:
entity.context_similarity = 1
le.append(entity)
continue
elif self.cnf_l.use_ner_link_candidates:
elif self.cnf_l.use_ner_link_candidates and not entity.link_candidates:
continue
# it has to be inferred due to filters or number of link candidates
to_infer.append(entity)
Expand Down
38 changes: 37 additions & 1 deletion medcat-v2/tests/components/linking/test_embedding_linker.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from medcat.components.linking import embedding_linker
from medcat.components import types
from medcat.config import Config
from medcat.data.entities import Entity
from medcat.vocab import Vocab
from medcat.cat import CAT
from medcat.cdb.concepts import CUIInfo, NameInfo
from medcat.components.types import TrainableComponent
from medcat.components.types import _DEFAULT_LINKING as DEF_LINKING
import unittest
from ..helper import ComponentInitTests

from ... import UNPACKED_EXAMPLE_MODEL_PACK_PATH

class FakeDocument:
linked_ents = []
ner_ents = []
Expand Down Expand Up @@ -64,4 +68,36 @@ def test_linker_is_not_trainable(self):

def test_linker_processes_document(self):
doc = FakeDocument("Test Document")
self.linker(doc)
self.linker(doc)


class EmbeddingModelDisambiguationTests(unittest.TestCase):
    """End-to-end checks that the embedding linker still yields entities for
    names that have more than one entry in their ``per_cui_status``.

    The example model pack is loaded once per class, its linking config is
    replaced with ``EmbeddingLinking``, the pipe is recreated and the linker's
    embeddings are created before any test runs.
    """
    PLACEHOLDER = "{SOME_PLACEHOLDER}"
    # The f-string interpolates PLACEHOLDER, so TEXT ends with the literal
    # marker that each test iteration swaps for a concrete name.
    TEXT = f"""The issue has a lot to do with the {PLACEHOLDER}"""

    @classmethod
    def setUpClass(cls) -> None:
        """Load the model pack and switch its linker to the embedding linker."""
        cls.model = CAT.load_model_pack(UNPACKED_EXAMPLE_MODEL_PACK_PATH)
        cls.model.config.components.linking = embedding_linker.EmbeddingLinking()
        # the pipe has to be rebuilt for the new linking config to be used
        cls.model._recreate_pipe()
        linker: embedding_linker.Linker = cls.model.pipe.get_component(
            types.CoreComponentType.linking)
        linker.create_embeddings()

    def assert_has_name(self, out_ents: dict[int, Entity], name: str):
        """Assert that at least one output entity has ``name`` as its
        ``source_value``.

        Args:
            out_ents: The ``"entities"`` mapping returned by ``get_entities``.
            name: The source value expected among the entities.
        """
        found_values = [ent["source_value"] for ent in out_ents.values()]
        # assertIn reports the collected values on failure, unlike a bare
        # assertTrue(any(...)) which only says "False is not true"
        self.assertIn(
            name, found_values,
            msg=f"No entity with source value {name!r} was produced")

    def test_does_disambiguation(self):
        """Every name with multiple per-CUI statuses is found in the output."""
        used_names = 0
        for name, info in self.model.cdb.name2info.items():
            # only names with more than one per-CUI status entry are relevant
            if len(info['per_cui_status']) <= 1:
                continue
            used_names += 1
            with self.subTest(name):
                cur_text = self.TEXT.replace(self.PLACEHOLDER, name)
                out_ents = self.model.get_entities(cur_text)["entities"]
                self.assert_has_name(out_ents, name)
        # guard against the loop silently testing nothing
        self.assertGreater(
            used_names, 0,
            msg="The model pack has no names with multiple per-CUI statuses")

Loading