diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py index 6a58b3a0d..b7fba6a9b 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py @@ -70,13 +70,46 @@ class EmbeddingLinking(Linking): """Choose a device for the linking model to be stored. If None then an appropriate GPU device that is available will be chosen""" context_window_size: int = 14 - """Choose the window size to get context vectors.""" + """Choose the window size to get context vectors. In a trained model + if you increase the context window after training then performance will + degrade significantly.""" use_ner_link_candidates: bool = True """Link candidates are provided by some NER steps. This will flag if - you want to trust them or not.""" + you want to trust them or not. A good guideline is if you've trained + on data from the same distribution then this is probably best set to True. + If you have no training data from the same source distribution then it MIGHT + be better set to false.""" + append_to_ner_link_candidates: bool = False + """If `use_ner_link_candidates` is enabled, generate additional + candidates and append them to existing NER candidates instead of only + generating for entities that have none. This will often result in a slight + increase in recall, and precision.""" + use_pre_inference: bool = True + """Whether to use the pre-inference step to filter candidates before + calculating similarities. This can speed up inference by only calculating + similarities for candidates that are likely to be correct based direct on word + matching.""" learning_rate: float = 1e-4 """Learning rate for training the embedding linker. Only used if the embedding linker is trainable.""" weight_decay: float = 0.01 """Weight decay for training the embedding linker. 
Only used if the embedding linker is trainable.""" + multiple_predictions_per_detected_entity: bool = False + """Whether to allow multiple predictions per detected entity. If False, only + the highest scoring candidate will be returned for each entity. If True, all + candidates that exceed the similarity thresholds will be returned. This can be + useful if you want to return multiple CUIs for an entity, but can also lead to + more false positives.""" + pre_inference_top_k_sampling: int = 1 + """When using pre-inference to filter candidates, how many names to then add + their related CUIs as potential candidates. Higher numbers will increase recall + but also increase inference time, and reduce precision. This is influenced by + `short_similarity_threshold`, i.e. pass the top k samples over the threshold + for inference.""" + inference_top_k_sampling: int = 1 + """At the inference step, after calculating similarity scores, how many candidates + to keep for each entity. Higher numbers will increase recall but also increase + inference time, and often reduce precision. This is influenced by + `long_similarity_threshold`, i.e. take the top k samples over the threshold. This + will be ignored if `multiple_predictions_per_detected_entity` is set to False.""" diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py index 60ac8e6fd..546d624dc 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py @@ -33,18 +33,21 @@ def __init__( self, cdb: CDB, config: Config, + tokenizer: BaseTokenizer, model_init_kwargs: Optional[dict[str, Any]] = None, ) -> None: """Initializes the embedding linker with a CDB and configuration. Args: cdb (CDB): The concept database to use. config (Config): The base config. 
+ tokenizer (BaseTokenizer): The tokenizer to use. model_init_kwargs (Optional[dict[str, Any]]): Explicit kwargs that override linker defaults. """ super().__init__() self.cdb = cdb self.config = config + self.tokenizer = tokenizer if not isinstance(config.components.linking, EmbeddingLinking): raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking @@ -369,7 +372,7 @@ def _set_filters(self) -> None: def _disambiguate_by_cui( self, cui_candidates: list[str], scores: Tensor - ) -> tuple[str, float]: + ) -> list[tuple[str, float]]: """Disambiguate a detected concept by a list of potential cuis Args: cuis (list[str]): Potential cuis @@ -377,19 +380,71 @@ def _disambiguate_by_cui( scores (Tensor): Scores for the detected cui2info concepts similarity cui_keys (list[str]): idx_to_cui inverse Returns: - tuple[str, float]: - The CUI and its similarity + list[tuple[str, float]]: + The selected CUIs and their similarities. """ cui_idxs = [ self._cui_to_idx[cui] for cui in cui_candidates if cui in self._cui_to_idx ] + if not cui_idxs: + return [] + candidate_scores = scores[cui_idxs] + + if self.cnf_l.multiple_predictions_per_detected_entity: + threshold = self.cnf_l.long_similarity_threshold + selected_mask = candidate_scores >= threshold + selected_positions = torch.nonzero(selected_mask, as_tuple=True)[0] + + return [ + ( + self._cui_keys[cui_idxs[pos]], + float(candidate_scores[pos].item()), + ) + for pos in selected_positions.tolist() + ] + candidate_idx = int(torch.argmax(candidate_scores).item()) best_idx = cui_idxs[candidate_idx] - predicted_cui = self._cui_keys[best_idx] - similarity = float(candidate_scores[candidate_idx].item()) - return predicted_cui, similarity + return [ + ( + self._cui_keys[best_idx], + float(candidate_scores[candidate_idx].item()), + ) + ] + + def _get_predictions_from_names( + self, + selected_name_idxs: list[int], + row_scores: Tensor, + cui_scores_row: Tensor, + 
name_to_cuis: Optional[dict[str, list[str]]] = None, + ) -> list[tuple[str, float]]: + """Retrieve all cuis from the candidate names + + Optional - use name to cuis that has already been generated + with link candidates + """ + cuis_set: set[str] = set() + for name_idx in selected_name_idxs: + selected_name = self._name_keys[name_idx] + if name_to_cuis is None: + cuis_set.update( + self.cdb.name2info[selected_name]["per_cui_status"].keys() + ) + else: + cuis_set.update(name_to_cuis[selected_name]) + cuis = list(cuis_set) + if len(cuis) == 1: + # If there's only one possible cui from the names + # We don't get the similarity for the longest cui score + # Just for speed - this may alter performance if the longest name + # for the cui doesn't meet it's threshold + similarity = max(float(row_scores[name_idx].item()) + for name_idx in selected_name_idxs) + return [(cuis[0], similarity)] + return self._disambiguate_by_cui(cuis, cui_scores_row) def _inference( self, doc: MutableDocument, entities: list[MutableEntity] @@ -409,9 +464,9 @@ def _inference( # score all detected contexts vs all names names_scores = detected_context_vectors @ self.names_context_matrix.T cui_scores = detected_context_vectors @ self.cui_context_matrix.T - sorted_indices = torch.argsort(names_scores, dim=1, descending=True) for i, entity in enumerate(entities): + predictions: list[tuple[str, float]] = [] link_candidates = entity.link_candidates if self.config.components.linking.filter_before_disamb: link_candidates = [ @@ -419,7 +474,13 @@ def _inference( for cui in link_candidates if self.cnf_l.filters.check_filters(cui) ] - if len(link_candidates) == 1: + # TODO: Is this "not" correct? if I skip pre inference I don't care + # about the link candidates? 
+ if ( + len(link_candidates) == 1 and + (self.cnf_l.use_pre_inference or + self.cnf_l.use_ner_link_candidates) + ): best_idx = self._cui_to_idx[link_candidates[0]] predicted_cui = link_candidates[0] if best_idx < 0 or best_idx >= cui_scores.shape[1]: @@ -431,13 +492,19 @@ def _inference( cui_scores.shape[1], ) continue - similarity = cui_scores[i, best_idx].item() - elif len(link_candidates) > 1: + similarity = float(cui_scores[i, best_idx].item()) + predictions = [(predicted_cui, similarity)] + elif ( + len(link_candidates) > 1 and + (self.cnf_l.use_pre_inference or + self.cnf_l.use_ner_link_candidates) + ): + # get all possible names from candidate cuis name_to_cuis = defaultdict(list) for cui in link_candidates: for name in self.cdb.cui2info[cui]["names"]: name_to_cuis[name].append(cui) - + # their position within matricies name_idxs = [ self._name_to_idx[name] for name in name_to_cuis @@ -451,39 +518,83 @@ def _inference( entity.detected_name, ) continue + # get all the scores for the names indexed_scores = names_scores[i, name_idxs] best_local_pos = int(torch.argmax(indexed_scores).item()) best_global_idx = name_idxs[best_local_pos] - similarity = names_scores[i, best_global_idx].item() - best_name = self._name_keys[best_global_idx] - cuis = name_to_cuis[best_name] - if len(cuis) == 1: - predicted_cui = cuis[0] - else: - predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) - else: - row_sorted = sorted_indices[i] # sorted candidate indices for entity i + similarity = float(names_scores[i, best_global_idx].item()) + selected_name_idxs = [ + name_idx + for name_idx in name_idxs + if float(names_scores[i, name_idx].item()) >= + self.cnf_l.long_similarity_threshold + ] + # if no names pass the threshold - no cuis will + # skip this detected entity + if not selected_name_idxs: + continue - # Find the first candidate in this row with CUIs - first_true_pos = int( - torch.nonzero(self._valid_names[row_sorted], as_tuple=True)[0][ - 0 - ].item() + 
predictions = self._get_predictions_from_names( + selected_name_idxs, + names_scores[i], + cui_scores[i, :], + name_to_cuis, + ) + else: + # if there are no link candidates + # or you don't want to use them + row_scores = names_scores[i] + # get all names that pass the threshold + selected_mask = self._valid_names & ( + row_scores >= self.cnf_l.long_similarity_threshold ) + selected_name_idxs = torch.nonzero( + selected_mask, as_tuple=True + )[0].tolist() + # if none pass the threshold + if not selected_name_idxs: + continue - # Get global index + name - top_name_idx = int(row_sorted[first_true_pos].item()) - similarity = names_scores[i, top_name_idx].item() - detected_name = self._name_keys[top_name_idx] - cuis = list(self.cdb.name2info[detected_name]["per_cui_status"].keys()) + # if there are too many, take the top k to reduce processing time + # this is a trade off between compute time and predictive power + # as k increases, processing time increases + if len(selected_name_idxs) > self.cnf_l.inference_top_k_sampling: + selected_scores = row_scores[selected_name_idxs] + topk_positions = torch.topk( + selected_scores, k=self.cnf_l.inference_top_k_sampling + ).indices.tolist() + selected_name_idxs = [ + selected_name_idxs[pos] for pos in topk_positions + ] + + + predictions = self._get_predictions_from_names( + selected_name_idxs, + row_scores, + cui_scores[i, :], + ) - predicted_cui, _ = self._disambiguate_by_cui(cuis, cui_scores[i, :]) - if not self.cnf_l.filters.check_filters(predicted_cui): - continue - if self._check_similarity(similarity): - entity.cui = predicted_cui - entity.context_similarity = similarity - yield entity + for predicted_cui, predicted_similarity in predictions: + # check if the predicted cui passes the filters + if not self.cnf_l.filters.check_filters(predicted_cui): + continue + # This check is useful when there's a single link candidate + # Or only a single prediction that's been disambiguated + if not 
self._check_similarity(predicted_similarity): + continue + if self.cnf_l.multiple_predictions_per_detected_entity: + # create a barebones entity that has what is requried + ent = self.tokenizer.create_entity( + doc, + entity.base.start_index, + entity.base.end_index, + entity.detected_name, + ) + else: + ent = entity + ent.cui = predicted_cui + ent.context_similarity = predicted_similarity + yield ent def _check_similarity(self, context_similarity: float) -> bool: if self.cnf_l.long_similarity_threshold: @@ -503,7 +614,10 @@ def _build_context_matrices(self) -> None: ) def _generate_link_candidates( - self, doc: MutableDocument, entities: list[MutableEntity] + self, + doc: MutableDocument, + entities: list[MutableEntity], + append_to_existing: bool = False, ) -> None: """Generate link candidates for each detected entity based on context vectors with size 0. Compare to names to get the most @@ -523,22 +637,55 @@ def _generate_link_candidates( # valid names via filtering and contain at least 1 cui valid_mask = self._valid_names[row_sorted] - if self.cnf_l.short_similarity_threshold > 0: - # thresholded selection + valid_positions = torch.nonzero(valid_mask, as_tuple=True)[0] + + if ( + self.cnf_l.short_similarity_threshold > 0 and + self.cnf_l.pre_inference_top_k_sampling > 0 + ): + # Require candidates to satisfy BOTH criteria: + # (a) score above threshold and (b) within top-k valid names. + valid_scores = row_scores[valid_positions] + k = min(self.cnf_l.pre_inference_top_k_sampling, len(valid_positions)) + if k > 0: + topk_rel = torch.topk(valid_scores, k=k).indices + topk_positions = valid_positions[topk_rel] + keep_mask = ( + row_scores[topk_positions] >= + self.cnf_l.short_similarity_threshold + ) + valid_positions = topk_positions[keep_mask] + else: + valid_positions = valid_positions[:0] + elif self.cnf_l.short_similarity_threshold > 0: + # Threshold-only mode. 
above_thresh_mask = row_scores >= self.cnf_l.short_similarity_threshold selected_mask = valid_mask & above_thresh_mask valid_positions = torch.nonzero(selected_mask, as_tuple=True)[0] + elif self.cnf_l.pre_inference_top_k_sampling > 0: + # Top-k-only mode among valid names. + valid_scores = row_scores[valid_positions] + k = min(self.cnf_l.pre_inference_top_k_sampling, len(valid_positions)) + if k > 0: + topk_rel = torch.topk(valid_scores, k=k).indices + valid_positions = valid_positions[topk_rel] + else: + valid_positions = valid_positions[:0] else: - # just take the single best valid candidate - first_valid = torch.nonzero(valid_mask, as_tuple=True)[0][:1] - valid_positions = first_valid + # If neither criterion is enabled, keep only the best valid candidate. + valid_positions = valid_positions[:1] + # getting cuis from all valid names that pass the threshold and top-k for pos in valid_positions.tolist(): top_name_idx = int(row_sorted[pos].item()) detected_name = self._name_keys[top_name_idx] cuis.update(self.cdb.name2info[detected_name]["per_cui_status"].keys()) - - entity.link_candidates = list(cuis) + + if append_to_existing: + existing = set(entity.link_candidates) + entity.link_candidates = list(existing | cuis) + else: + entity.link_candidates = list(cuis) def _pre_inference( self, doc: MutableDocument @@ -547,9 +694,25 @@ def _pre_inference( avoid full inference step. 
If we want to calculate similarities, or not use link candidates then just return the entities""" all_ents = doc.ner_ents + # if we don't care to use pre inference just return all entities + # as they are + if not self.cnf_l.use_pre_inference: + return [], all_ents + + append_generated_to_ner = ( + self.cnf_l.use_ner_link_candidates + and self.cnf_l.append_to_ner_link_candidates + ) + if not self.cnf_l.use_ner_link_candidates: + # ignoring link candidates generated by NER + to_generate_link_candidates = all_ents + elif append_generated_to_ner: + # Keep NER candidates and append model-generated candidates. to_generate_link_candidates = all_ents else: + # here we only generate link candidates if they don't exist + # i.e. out of vocabulary to_generate_link_candidates = [ entity for entity in all_ents if not entity.link_candidates ] @@ -558,7 +721,11 @@ def _pre_inference( for entities in self._batch_data( to_generate_link_candidates, self.cnf_l.linking_batch_size ): - self._generate_link_candidates(doc, entities) + self._generate_link_candidates( + doc, + entities, + append_generated_to_ner + ) # filter out entities with no link candidates after thresholding filtered_ents = [ent for ent in all_ents if ent.link_candidates] @@ -569,6 +736,10 @@ def _pre_inference( le: list[MutableEntity] = [] to_infer: list[MutableEntity] = [] for entity in all_ents: + # if no candidates just skip it + if not entity.link_candidates: + continue + # TODO: Check if this is right now with multiple entites being possible if len(entity.link_candidates) == 1: # if the include filter exists and the only cui is in it if self.cnf_l.filters.check_filters(entity.link_candidates[0]): @@ -576,8 +747,6 @@ def _pre_inference( entity.context_similarity = 1 le.append(entity) continue - elif self.cnf_l.use_ner_link_candidates and not entity.link_candidates: - continue # it has to be inferred due to filters or number of link candidates to_infer.append(entity) return le, to_infer @@ -585,12 +754,6 @@ def 
_pre_inference( def predict_entities( self, doc: MutableDocument, ents: list[MutableEntity] | None = None ) -> list[MutableEntity]: - if self.cnf_l.train and self.name == "embedding_linker": - logger.warning( - "Attemping to train a static embedding linker. " - "This is not possible / required." - "Use the `trainable_embedding_linker` instead." - ) if self.cnf_l.filters.cuis and self.cnf_l.filters.cuis_exclude: logger.warning( "You have both include and exclude filters for CUIs set. " @@ -604,7 +767,7 @@ def predict_entities( for entities in self._batch_data(to_infer, self.cnf_l.linking_batch_size): le.extend(list(self._inference(doc, entities))) - return filter_linked_annotations(doc, le) + return filter_linked_annotations(doc, le, True) @property def names_context_matrix(self): @@ -627,4 +790,4 @@ def create_new_component( vocab: Vocab, model_load_path: Optional[str], ) -> "Linker": - return cls(cdb, cdb.config) + return cls(cdb, cdb.config, tokenizer) diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py index 4d3202f00..cb052e0ac 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py @@ -2,6 +2,7 @@ from medcat_embedding_linker.config import EmbeddingLinking from torch import Tensor from medcat.cdb import CDB +from medcat.components.types import TrainableComponent from medcat.config.config import Config, ComponentConfig from medcat.components.linking.vector_context_model import PerDocumentTokenCache from medcat.tokenizing.tokenizers import BaseTokenizer @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) -class Linker(StaticEmbeddingLinker, AbstractManualSerialisable): +class Linker(StaticEmbeddingLinker, AbstractManualSerialisable, TrainableComponent): """Trainable variant of 
the embedding linker. This class inherits inference and embedding behavior from Linker and provides method hooks for online/offline training. @@ -28,7 +29,10 @@ class Linker(StaticEmbeddingLinker, AbstractManualSerialisable): _MODEL_FOLDER_NAME = "trainable_embedding_model" _MODEL_STATE_FILE_NAME = "model_state.pt" - def __init__(self, cdb: CDB, config: Config) -> None: + def __init__(self, + cdb: CDB, + config: Config, + tokenizer: BaseTokenizer) -> None: if not isinstance(config.components.linking, EmbeddingLinking): raise TypeError("Linking config must be an EmbeddingLinking instance") self.cnf_l: EmbeddingLinking = config.components.linking @@ -41,6 +45,7 @@ def __init__(self, cdb: CDB, config: Config) -> None: super().__init__( cdb, config, + tokenizer, model_init_kwargs=model_init_kwargs, ) self.training_batch: list[tuple] = [] @@ -407,7 +412,7 @@ def create_new_component( vocab: Vocab, model_load_path: Optional[str], ) -> "Linker": - return cls(cdb, cdb.config) + return cls(cdb, cdb.config, tokenizer) def serialise_to(self, folder_path: str) -> None: os.makedirs(folder_path, exist_ok=True) @@ -424,7 +429,8 @@ def deserialise_from( cls, folder_path: str, **init_kwargs ) -> "Linker": cdb = init_kwargs["cdb"] - linker = cls(cdb, cdb.config) + tokenizer = init_kwargs["tokenizer"] + linker = cls(cdb, cdb.config, tokenizer) model_state_path = os.path.join( folder_path, cls._MODEL_FOLDER_NAME, cls._MODEL_STATE_FILE_NAME diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py index fb08af4a7..fe60f0fa4 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py @@ -3,6 +3,7 @@ from medcat.storage.serialisables import AbstractSerialisable from torch import Tensor, nn from transformers import 
AutoModel, AutoTokenizer +from medcat_embedding_linker.config import EmbeddingLinking as LinkingConfig from tqdm import tqdm import json import logging @@ -23,14 +24,16 @@ class ModelForEmbeddingLinking(nn.Module): def __init__( self, embedding_model_name: str, + cnf_l: LinkingConfig, use_projection_layer: bool = False, - top_n_layers_to_unfreeze: int = -1, + top_n_layers_to_unfreeze: int = 0, device: Optional[Union[str, torch.device]] = None, ) -> None: super().__init__() self.language_model = AutoModel.from_pretrained(embedding_model_name) self.base_model_name = self.language_model.name_or_path + self.cnf_l = cnf_l self.use_projection_layer = use_projection_layer self.top_n_layers_to_unfreeze = top_n_layers_to_unfreeze @@ -86,6 +89,10 @@ def _freeze_all_parameters(self) -> None: param.requires_grad = True def unfreeze_top_n_lm_layers(self, n: int) -> None: + self.cnf_l.top_n_layers_to_unfreeze = n + self.top_n_layers_to_unfreeze = n + # Re-apply from a known baseline so repeated calls are deterministic. + self._freeze_all_parameters() # train all LM layers - each layer requires more data if n == -1: for param in self.language_model.parameters(): @@ -133,6 +140,7 @@ def save_pretrained(self, save_directory: Union[str, Path]) -> None: def from_pretrained( cls, path_or_model_name: Union[str, Path], + cnf_l: LinkingConfig, device: Optional[Union[str, torch.device]] = None, **kwargs, ) -> "ModelForEmbeddingLinking": @@ -147,7 +155,7 @@ def from_pretrained( config = json.load(f) config.update(kwargs) - model = cls(**config) + model = cls(cnf_l=cnf_l, **config) state_dict = torch.load(weights_path, map_location="cpu") model.load_state_dict(state_dict) model.to(target_device) @@ -156,6 +164,7 @@ def from_pretrained( # Hugging Face model id/path. 
model = cls( embedding_model_name=str(path_or_model_name), + cnf_l=cnf_l, device=target_device, **kwargs, ) @@ -208,8 +217,19 @@ def _resolve_model_source(path_or_model_name: Union[str, Path]) -> str: return str(path_or_model_name) def _get_model_init_kwargs(self) -> dict[str, Any]: - """Build kwargs passed to ModelForEmbeddingLinking.from_pretrained.""" - return dict(self._model_init_kwargs) + """Build kwargs passed to ModelForEmbeddingLinking.from_pretrained. + + Keep these in sync with runtime linker config so model swaps preserve + trainability settings (i.e. top-n LM layers to unfreeze). + """ + kwargs = dict(self._model_init_kwargs) + if hasattr(self.cnf_l, "use_projection_layer"): + kwargs["use_projection_layer"] = self.cnf_l.use_projection_layer + if hasattr(self.cnf_l, "top_n_layers_to_unfreeze"): + kwargs["top_n_layers_to_unfreeze"] = ( + self.cnf_l.top_n_layers_to_unfreeze + ) + return kwargs def load_transformers(self, embedding_model_name: Union[str, Path]) -> None: """Load tokenizer/model from local path or Hugging Face model id.""" @@ -224,7 +244,9 @@ def load_transformers(self, embedding_model_name: Union[str, Path]) -> None: self.cnf_l.embedding_model_name = str(embedding_model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_source) self.model = ModelForEmbeddingLinking.from_pretrained( - model_source, **model_init_kwargs + model_source, + cnf_l=self.cnf_l, + **model_init_kwargs, ) self.model.eval() self.device = torch.device( diff --git a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py index 1b9591c10..892a81a42 100644 --- a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py +++ b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py @@ -67,7 +67,8 @@ class NonTrainableEmbeddingLinkerTests(unittest.TestCase): cnf = Config() cnf.components.linking = embedding_linker.EmbeddingLinking() cnf.components.linking.comp_name = 
embedding_linker.Linker.name - linker = embedding_linker.Linker(FakeCDB(cnf), cnf) + vtokenizer = FakeTokenizer() + linker = embedding_linker.Linker(FakeCDB(cnf), cnf, vtokenizer) def test_linker_is_not_trainable(self): self.assertNotIsInstance(self.linker, TrainableComponent) @@ -83,7 +84,8 @@ class TrainableEmbeddingLinkerTests(unittest.TestCase): cnf.components.linking.comp_name = ( trainable_embedding_linker.Linker.name ) - linker = trainable_embedding_linker.Linker(FakeCDB(cnf), cnf) + vtokenizer = FakeTokenizer() + linker = trainable_embedding_linker.Linker(FakeCDB(cnf), cnf, vtokenizer) def test_linker_is_trainable(self): self.assertIsInstance(self.linker, TrainableComponent) diff --git a/medcat-plugins/rawstring-tokenizer/README.md b/medcat-plugins/rawstring-tokenizer/README.md new file mode 100644 index 000000000..59142596c --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/README.md @@ -0,0 +1,80 @@ +# MedCAT Rawstring Tokenizer + +A MedCAT plugin that provides a Rawstring tokenizer, essentially splitting on whitespace characters (" ", "\n", "\t") only. + +## Overview + +This plugin replaces MedCAT's default tokenizing component with a rawstring tokenizer, so that components that perform linking are not limited by requiring spaCy representations. + +## Requirements + +- **MedCAT**: 2.7+ ([PyPI](https://pypi.org/project/medcat/) | [GitHub](https://github.com/CogStack/MedCAT)) +- Python 3.10+ + +## Installation + +```bash +pip install medcat-rawstring-tokenizer +``` + +## Quick Start + +### Replacing current tokenizer with a rawstring_tokenizer + +```python +from medcat.cat import CAT +from medcat_rawstring_tokenizer.tokenizer import RawstringTokenizer +from medcat.tokenizing.tokenizers import register_tokenizer + +MODEL_PACK_PATH = ".." +TARGET_FOLDER = ".." +TARGET_PACK_NAME = ".." +TOKENIZER_NAME = "rawstring_tokenizer" + +# The custom tokenizer must be registered before we rebuild the pipeline. 
+register_tokenizer(TOKENIZER_NAME, RawstringTokenizer) + +cat = CAT.load_model_pack(MODEL_PACK_PATH) +print("Tokenizer provider before:", cat.config.general.nlp.provider) + +# Switch tokenizer provider in config, then recreate pipeline to apply it. +cat.config.general.nlp.provider = TOKENIZER_NAME + +cat.config.components.addons.clear() +cat._recreate_pipe() + +print("Tokenizer provider after:", cat.config.general.nlp.provider) + +cat.save_model_pack( + target_folder=TARGET_FOLDER, + pack_name=TARGET_PACK_NAME, + add_hash_to_pack_name=False, + make_archive=False, +) +print("Saved model pack to:", f"{TARGET_FOLDER.rstrip('/')}/{TARGET_PACK_NAME}") +``` + +## How It Works + +### Component Registration + +Register the tokenizer by name before trying to add the tokenizer to the pipeline. If loading a model with a rawstring tokenizer, register it beforehand. + +### Embedding Generation + +## Limitations + +- Cannot be used with the default `context_based_linker`, as that linker uses spaCy tokens and spaCy embeddings for linking, which are not used with this tokenizer. + +## Citation + +If you use this plugin, please cite MedCAT: + +```bibtex +@article{medcat2021, + title={Medical Concept Annotation Tool (MedCAT)}, + author={Kraljevic, Zeljko and others}, + journal={arXiv preprint arXiv:2010.01165}, + year={2021} +} +``` diff --git a/medcat-plugins/rawstring-tokenizer/pyproject.toml b/medcat-plugins/rawstring-tokenizer/pyproject.toml new file mode 100644 index 000000000..dd531a802 --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/pyproject.toml @@ -0,0 +1,113 @@ +[project] +name = "medcat-rawstring-tokenizer" + +dynamic = ["version"] + +description = "Rawstring tokenizer for MedCAT" + +readme = "README.md" + +requires-python = ">=3.10" + +license = {text = "Apache-2.0"} + +keywords = ["ML", "NLP", "NER+L"] + +authors = [ + {name = "A. Sutton"}, + {name = "M. 
Ratas"}, +] + +# This should be your name or the names of the organization who currently +# maintains the project, and a valid email address corresponding to the name +# listed. +maintainers = [ + {name = "CogStack", email = "contact@cogstack.org" } +] + +classifiers = [ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + "Development Status :: 5 - Production/Stable", + + "Intended Audience :: Healthcare Industry", + # "Topic :: Natural Language Processing :: Named Entity Recognition and Linking", + + # Specify the Python versions you support here. In particular, ensure + # that you indicate you support Python 3. These classifiers are *not* + # checked by "pip install". See instead "python_requires" below. + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: OS Independent", +] + +# This field lists other packages that your project depends on to run. +# Any package you put here will be installed by pip when your project is +# installed, so they must be valid existing projects. +# +# For an analysis of this field vs pip's requirements files see: +# https://packaging.python.org/discussions/install-requires-vs-requirements/ +dependencies = [ + "medcat[spacy]>=2.7", +] + +# List additional groups of dependencies here (e.g. development +# dependencies). Users will be able to install these using the "extras" +# syntax, for example: +# +# $ pip install sampleproject[dev] +# +# Similar to `dependencies` above, these must be valid existing +# projects. 
+[project.optional-dependencies] # Optional +dev = [ + "ruff~=0.1.7", + "mypy", + "types-tqdm", + "types-setuptools", + "types-PyYAML", +] + +# entry-points to add onto medcat +[project.entry-points."medcat.plugins"] +medcat_rawstring_tokenizer = "medcat_rawstring_tokenizer" + +[project.urls] +"Homepage" = "https://cogstack.org/" +"Bug Reports" = "https://discourse.cogstack.org/" +"Source" = "https://github.com/CogStack/cogstack-nlp/tree/main/medcat-plugins/rawstring-tokenizer" + +[build-system] +# These are the assumed default build requirements from pip: +# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support +requires = ["setuptools>=43.0.0", "setuptools_scm>=8", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +"medcat_rawstring_tokenizer" = ["py.typed"] + +[tool.setuptools_scm] +# look for .git folder in root of repo +root = "../.." +version_scheme = "post-release" +local_scheme = "no-local-version" +tag_regex = "^medcat-rawstring-tokenizer/v(?P<version>\\d+(?:\\.\\d+)*)(?:[ab]\\d+|rc\\d+)?$" +git_describe_command = "git describe --dirty --tags --long --match 'medcat-rawstring-tokenizer/v*'" + +[tool.ruff.lint] +# 1. 
Enable some extra checks for ruff +select = ["E", "F"] +# ignore unused local variables +ignore = ["F841"] diff --git a/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/__init__.py b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokenizer.py b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokenizer.py new file mode 100644 index 000000000..34f15d398 --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokenizer.py @@ -0,0 +1,132 @@ +from medcat.tokenizing.tokenizers import MutableDocument, MutableEntity, MutableToken +from medcat.config.config import Config +from medcat_rawstring_tokenizer.tokens import Entity, Document +from typing import Optional, Type + +class RawstringTokenizer: + """The base tokenizer protocol.""" + + def __init__(self, config: Config): + self.config = config + + def create_entity(self, doc: MutableDocument, + token_start_index: int, token_end_index: int, + label: str) -> MutableEntity: + """Create an entity from a document. + + Args: + doc (MutableDocument): The document to use. + token_start_index (int): The token start index. + token_end_index (int): The token end index. + label (str): The detected name for the entity. + + Returns: + MutableEntity: The resulting entity. 
+ """ + # Get tokens to determine character span and text + tokens = doc[token_start_index:token_end_index] + if not tokens: + raise ValueError("No tokens in the specified range") + # Construct entity text and determine character span + text = " ".join(tkn.text for tkn in tokens) + start_char = tokens[0].char_index + end_char = tokens[-1].end_char_index + # TODO: Check this is the correct length + # maybe + 1 + text = doc.text[start_char:end_char] + # Create entity with both token and character spans + # The end index needs to be pushed forward by one + # i.e. index 9:10 means token 9 is included + # we address this in the Entity by setting end_index to be end_token.index - 1 + entity = Entity(text, + token_start_index, + token_end_index+1, + start_char, + end_char, + label) + return entity + + def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: + """Get an entity from the list of tokens. + + This will create a new instance instead of looking for existing entity. + This method should be used only if/when there was no existing entity + within the specified document for the given span of tokens. + + Args: + tokens (list[MutableToken]): List of tokens. + + Returns: + MutableEntity: The resulting entity. + """ + if not tokens: + raise ValueError("Need at least one token for an entity") + text = " ".join(tkn.text for tkn in tokens) + start_index = tokens[0].index + # The end index needs to be pushed forward by one + # i.e. index 9:10 means token 9 is included + # we address this in the Entity by setting end_index to be end_token.index - 1 + end_index = tokens[-1].index + 1 + start_char = tokens[0].char_index + end_char = tokens[-1].end_char_index + # Entity uses [start, end] char semantics, so end must stay exclusive. 
+ return Entity(text, start_index, end_index, start_char, end_char, text) + + + def _get_existing_entity(self, tokens: list[MutableToken], + doc: MutableDocument) -> Optional[MutableEntity]: + if not tokens: + return None + for ent in doc.ner_ents + doc.linked_ents: + # The end index is exclusive + if (ent.start_index == tokens[0].base.index and + ent.end_index - 1 == tokens[-1].base.index): + return ent + return None + + def entity_from_tokens_in_doc(self, tokens: list[MutableToken], + doc: MutableDocument) -> MutableEntity: + """Get an entity from the list of tokens in the specified document. + + This method is designed to reuse entities where possible. + I don't think the document is required for this implementation. + + Args: + tokens (list[MutableToken]): List of tokens. + doc (MutableDocument): The document for these tokens. + + Returns: + MutableEntity: The resulting entity. + """ + existing_ent = self._get_existing_entity(tokens, doc) + if existing_ent: + print("Existing entity found: ", existing_ent) + return existing_ent + return self.entity_from_tokens(tokens) + + def __call__(self, text: str) -> MutableDocument: + doc = Document(text) + return doc + + @classmethod + def create_new_tokenizer(cls, config: Config) -> 'RawstringTokenizer': + return cls(config) + + def get_doc_class(self) -> Type[MutableDocument]: + """Get the document implementation class used by the tokenizer. + + This can be used (e.g) to register addon paths. + + Returns: + Type[MutableDocument]: The document class. + """ + return Document + + def get_entity_class(self) -> Type[MutableEntity]: + """Get the entity implementation class used by the tokenizer. + + Returns: + Type[MutableEntity]: The entity class. 
+ """ + return Entity + diff --git a/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokens.py b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokens.py new file mode 100644 index 000000000..28e3a66be --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/src/medcat_rawstring_tokenizer/tokens.py @@ -0,0 +1,253 @@ +from typing import Any, Iterator, Optional, Union, cast, overload +from bisect import bisect_right +from medcat.tokenizing.tokens import (BaseToken, MutableToken, + BaseEntity, MutableEntity, + BaseDocument, + UnregisteredDataPathException) + +import unicodedata +import re + + +# keep both hyphens and slashes within words +_WORD_RE = re.compile(r"[^\W_]+(?:[^\W_]+)*", re.UNICODE) +# _WORD_RE = re.compile(r"[^\W_]+(?:[-/][^\W_]+)*", re.UNICODE) + + +def _iter_word_spans( + text: str, + base_char_index: int = 0 + ) -> Iterator[tuple[str, int, int]]: + for match in _WORD_RE.finditer(text): + yield (match.group(0), + base_char_index + match.start(), + base_char_index + match.end()) + +class Token: + def __init__(self, + text: str, + index: int, + char_index: int, + end_char_index: int) -> None: + # --- BaseToken fields --- + self._text = text + self._index = index + self._char_index = char_index + self._end_char_index = end_char_index + # --- MutableToken fields --- + self._norm: str = text.lower() + self._to_skip: bool = False + self._is_punctuation: bool = ( + text != "" and unicodedata.category(text[0]).startswith("P") + ) + + # --- BaseToken --- + @property + def text(self) -> str: return self._text + @property + def lower(self) -> str: return self._text.lower() + @property + def text_versions(self) -> list[str]: return [self._norm, self.lower] + @property + def is_upper(self) -> bool: return self._text.isupper() + @property + def is_stop(self) -> bool: return False # handled by transformers + @property + def is_digit(self) -> bool: return self._text.isdigit() + @property + def char_index(self) -> int: return 
self._char_index + @property + def index(self) -> int: return self._index + @property + def end_char_index(self) -> int: return self._end_char_index + @property + def text_with_ws(self) -> str: return self._text + + # --- MutableToken --- + @property + def base(self) -> BaseToken: return cast(BaseToken, self) + @property + def is_punctuation(self) -> bool: return self._is_punctuation + @is_punctuation.setter + def is_punctuation(self, val: bool) -> None: self._is_punctuation = val + @property + def to_skip(self) -> bool: return self._to_skip + @to_skip.setter + def to_skip(self, val: bool) -> None: self._to_skip = val + @property + def lemma(self) -> str: return self._text # no lemmatization, return text as lemma + @property + def tag(self) -> Optional[str]: return None + @property + def norm(self) -> str: return self._norm + @norm.setter + def norm(self, val: str) -> None: self._norm = val + +class Entity: + _addon_extension_paths: set[str] = set() + + def __init__(self, text: str, start_index: int, end_index: int, + start_char: int, end_char: int, label: str = "") -> None: + # --- BaseEntity fields --- + # Token span is [start_index, end_index]: end is exclusive. + # Character span is [start_char, end_char]: end is exclusive. 
+ self._text = text + self._start_index = start_index + self._end_index = end_index + self._start_char = start_char + self._end_char = end_char + self._label = label + self._addon_data: dict[str, Any] = {} + # --- MutableEntity fields --- + self.cui: str = '' + self.detected_name: str = label + self.link_candidates: list[str] = [] + self.context_similarity: float = 0.0 + self.confidence: float = 0.0 + self.id: int = -1 + + # --- BaseEntity --- + @property + def base(self) -> BaseEntity: return cast(BaseEntity, self) + @property + def text(self) -> str: return self._text + @property + def label(self) -> str: return self._label + @property + def start_index(self) -> int: return self._start_index + # This requires -1 for compatibility + @property + def end_index(self) -> int: return self._end_index - 1 + @property + def start_char_index(self) -> int: return self._start_char + @property + def end_char_index(self) -> int: return self._end_char # exclusive end index + + def __iter__(self) -> Iterator[MutableToken]: + for i, (text, char_index, end_char_index) in enumerate( + _iter_word_spans(self._text, self._start_char)): + yield Token(text, self._start_index + i, char_index, end_char_index) + + def __len__(self) -> int: return max(0, self._end_index - self._start_index) + + # --- addon data --- + def set_addon_data(self, path: str, val: Any) -> None: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + self._addon_data[path] = val + + def has_addon_data(self, path: str) -> bool: + return bool(self._addon_data.get(path)) + + def get_addon_data(self, path: str) -> Any: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + return self._addon_data.get(path) + + def get_available_addon_paths(self) -> list[str]: + return [p for p in self._addon_extension_paths if self.has_addon_data(p)] + + @classmethod + def register_addon_path(cls, path: str, def_val: Any = None, + 
force: bool = True) -> None: + cls._addon_extension_paths.add(path) + + +class Document: + _addon_extension_paths: set[str] = set() + + def __init__(self, text: str) -> None: + self._text = text + self._addon_data: dict[str, Any] = {} + self.ner_ents: list[MutableEntity] = [] + self.linked_ents: list[MutableEntity] = [] + self._char_indices: Optional[list[int]] = None + self._tokens: list[Token] = [ + Token(token_text, token_index, char_index, end_char_index) + for token_index, (token_text, char_index, end_char_index) in + enumerate(_iter_word_spans(text)) + ] + + @property + def base(self) -> BaseDocument: return cast(BaseDocument, self) + + @property + def text(self) -> str: return self._text + + @overload + def __getitem__(self, index: int) -> MutableToken: + pass + + @overload + def __getitem__(self, index: slice) -> list[MutableToken]: + pass + + def __getitem__(self, index: Union[int, slice] + ) -> Union[MutableToken, list[MutableToken]]: + if isinstance(index, int): + if index < 0: + index += len(self._tokens) + if index < 0 or index >= len(self._tokens): + raise IndexError("Document index out of range") + return self._tokens[index] + + start, stop, step = index.indices(len(self._tokens)) + if step != 1: + raise ValueError("Token slices must use step=1") + return self._tokens[start:stop] + + def __iter__(self) -> Iterator[MutableToken]: + return iter(self._tokens) + + def __len__(self) -> int: + return len(self._tokens) + + def isupper(self) -> bool: + return self._text.isupper() + + def _ensure_char_indices(self) -> list[int]: + if self._char_indices is None: + self._char_indices = [tkn.char_index for tkn in self._tokens] + return self._char_indices + + def get_tokens(self, start_index: int, end_index: int + ) -> list[MutableToken]: + # Keep MedCAT compatibility (inclusive end index), then resolve to + # full tokens by overlap so partial subword offsets map to words. 
+ span_start = max(0, start_index) + span_end_exclusive = max(span_start, end_index) + 1 + + token_char_indices = self._ensure_char_indices() + lo = max(0, bisect_right(token_char_indices, span_start) - 1) + hi = min( + len(self._tokens), + bisect_right(token_char_indices, span_end_exclusive - 1) + 1 + ) + + return [ + token for token in self._tokens[lo:hi] + if token.end_char_index > span_start and + token.char_index < span_end_exclusive + ] + + + def set_addon_data(self, path: str, val: Any) -> None: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + self._addon_data[path] = val + + def has_addon_data(self, path: str) -> bool: + return bool(self._addon_data.get(path)) + + def get_addon_data(self, path: str) -> Any: + if path not in self._addon_extension_paths: + raise UnregisteredDataPathException(self.__class__, path) + return self._addon_data.get(path) + + def get_available_addon_paths(self) -> list[str]: + return [p for p in self._addon_extension_paths if self.has_addon_data(p)] + + @classmethod + def register_addon_path(cls, path: str, def_val: Any = None, + force: bool = True) -> None: + cls._addon_extension_paths.add(path) \ No newline at end of file diff --git a/medcat-plugins/rawstring-tokenizer/tests/__init__.py b/medcat-plugins/rawstring-tokenizer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat-plugins/rawstring-tokenizer/tests/test_rawstring_tokenizer.py b/medcat-plugins/rawstring-tokenizer/tests/test_rawstring_tokenizer.py new file mode 100644 index 000000000..bdfb16bfe --- /dev/null +++ b/medcat-plugins/rawstring-tokenizer/tests/test_rawstring_tokenizer.py @@ -0,0 +1,78 @@ +from typing import runtime_checkable +from medcat.tokenizing import tokenizers +from medcat_rawstring_tokenizer.tokenizer import RawstringTokenizer +from medcat.config import Config +from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken +from 
medcat.utils.registry import Registry +from medcat.tokenizing.tokenizers import register_tokenizer + +import unittest + + +class RawstringTokenizerInitTests(unittest.TestCase): + default_provider = 'rawstring_tokenizer' + default_cls = RawstringTokenizer + default_creator = RawstringTokenizer.create_new_tokenizer + # spacy, regex, and now this + exp_num_def_tokenizers = 3 + + @classmethod + def setUpClass(cls): + register_tokenizer('rawstring_tokenizer', RawstringTokenizer.create_new_tokenizer) + cls.cnf = Config() + + def def_creator_name(self) -> str: + return Registry.translate_name(self.default_creator) + + def test_has_default(self): + avail_tokenizers = tokenizers.list_available_tokenizers() + self.assertEqual(len(avail_tokenizers), self.exp_num_def_tokenizers) + name, cls_name = [(t_name, t_cls) for t_name, t_cls in avail_tokenizers + if t_name == self.default_provider][0] + self.assertEqual(name, self.default_provider) + self.assertEqual(cls_name, self.def_creator_name()) + + def test_can_create_def_tokenizer(self): + tokenizer = tokenizers.create_tokenizer( + self.default_provider, self.cnf) + self.assertIsInstance(tokenizer, + runtime_checkable(tokenizers.BaseTokenizer)) + self.assertIsInstance(tokenizer, self.default_cls) + + +class TokenizerTests(unittest.TestCase): + default_provider = 'rawstring_tokenizer' + text = "Some text to tokenize" + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + + def setUp(self) -> None: + self.tokenizer = tokenizers.create_tokenizer( + self.default_provider, self.cnf) + self.doc = self.tokenizer(self.text) + self.doc.ner_ents = self._create_ner_ents(self.doc) + self.doc.linked_ents = self.doc.ner_ents.copy() + + def _create_ner_ents( + self, doc: MutableDocument, + targets: list[str] = ["text",]) -> list[MutableEntity]: + token_start = 1 + token_end = 2 + return [ + self.tokenizer.create_entity( + doc=doc, + token_start_index=token_start, + token_end_index=token_end, + label=target) + for target in targets + ] 
+ + def test_getting_entity_based_on_tokens_gets_same_instance(self): + for ent in self.doc.ner_ents: + with self.subTest(f"Ent: {ent} in doc {self.doc}"): + tokens = list(ent) + got_ent = self.tokenizer.entity_from_tokens_in_doc(tokens, self.doc) + self.assertIs(got_ent, ent) + self.assertIn(got_ent, self.doc.ner_ents) diff --git a/medcat-plugins/transformer-ner/README.md b/medcat-plugins/transformer-ner/README.md new file mode 100644 index 000000000..378a3a458 --- /dev/null +++ b/medcat-plugins/transformer-ner/README.md @@ -0,0 +1,100 @@ +# MedCAT Transformer NER + +A MedCAT plugin that provides a transformer-based NER component using transformer models from HuggingFace. + +## Overview + +This plugin replaces MedCAT's default NER component with a transformer-based approach that uses BIOES token classification to identify spans of text that contain medical entities. + +**Key features:** +- BIOES token format for accurate labeling of longer / shorter spans. +- CRF head to ensure consistent label generation. +- Trainable and configurable for all potential transformer huggingface language models. 
+ +## Requirements + +- **MedCAT**: 2.0+ ([PyPI](https://pypi.org/project/medcat/) | [GitHub](https://github.com/CogStack/MedCAT)) +- Python 3.10+ +- PyTorch +- Transformers + +## Installation + +```bash +pip install medcat-transformer-ner +``` + +## Quick Start + +### Replacing current NER with transformer NER + +```python +from medcat.cat import CAT +from medcat_transformer_ner.transformer_ner import NER +from medcat_transformer_ner.config import TransformerNER + +cat = CAT.load_model_pack("..") + +cat.config.components.ner = TransformerNER() +cat.config.components.ner.comp_name = NER.name + +cat.config.components.addons.clear() +cat._recreate_pipe() + +cat.save_model_pack(target_folder="/data/adam/models/trainable/", + pack_name="kch_gstt_v2_NER_BioLinkBERT", + add_hash_to_pack_name=False, + make_archive=False +) +``` + +## How It Works + +### Component Registration + +The transformer NER has a default untrained transformers model from huggingface it downloads. Using `.load_transformers()` with another huggingface model will use that model instead. + +### Inference Process + +Pass a document through the transformer BIOES model tagging if each entity is: + +1. **Beginning Of Entity Span** +2. **Intermediate Of Entity Span** +3. **Outside Of Entity Span** +4. **End Of Entity Span** +5. **Single Span Of Entity** - Meaning it is a single entity within its own token + +## Configuration + +### Key Parameters +```python +from medcat_transformer_ner.config import TransformerNER +from medcat.cat import CAT +cat = CAT.load_model_pack("..your transformer ner model..") + +ner_component = cat._pipeline.get_component(CoreComponentType.ner) + +# Do you want to only pass forward detected entities where there is a +# perfect match in the name vocabulary? +ner_component.cnf_ner.require_link_candidates = True +# What pretrained transformers model would you like to use? 
+ner_component.load_transformers("michiyasunaga/BioLinkBERT-large") + +``` + +### Suggested Models + +Any HuggingFace model will work. However, smaller models will be unable to model the task appropriately, leading to significantly reduced performance. We strongly recommend `BioLinkBERT-large`, as this is one of the smaller models that can appropriately detect entities. All models will be worth testing. + +## Citation + +If you use this plugin, please cite MedCAT: + +```bibtex +@article{medcat2021, + title={Medical Concept Annotation Tool (MedCAT)}, + author={Kraljevic, Zeljko and et al.}, + journal={arXiv preprint arXiv:2010.01165}, + year={2021} +} +``` diff --git a/medcat-plugins/transformer-ner/pyproject.toml b/medcat-plugins/transformer-ner/pyproject.toml new file mode 100644 index 000000000..9c0ab6e1f --- /dev/null +++ b/medcat-plugins/transformer-ner/pyproject.toml @@ -0,0 +1,117 @@ +[project] +name = "medcat-transformer-ner" + +dynamic = ["version"] + +description = "Transformer based NER for MedCAT" + +readme = "README.md" + +requires-python = ">=3.10" + +license = {text = "Apache-2.0"} + +keywords = ["ML", "NLP", "NER+L"] + +authors = [ + {name = "A. Sutton"}, + {name = "T. Searle"}, + {name = "M. Ratas"}, +] + +# This should be your name or the names of the organization who currently +# maintains the project, and a valid email address corresponding to the name +# listed. +maintainers = [ + {name = "CogStack", email = "contact@cogstack.org" } +] + +classifiers = [ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + "Development Status :: 3 - Alpha", + + "Intended Audience :: Healthcare Industry", + # "Topic :: Natural Language Processing :: Named Entity Recognition and Linking", + + # Specify the Python versions you support here. In particular, ensure + # that you indicate you support Python 3. These classifiers are *not* + # checked by "pip install". See instead "python_requires" below. 
+ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: OS Independent", +] + +# This field lists other packages that your project depends on to run. +# Any package you put here will be installed by pip when your project is +# installed, so they must be valid existing projects. +# +# For an analysis of this field vs pip's requirements files see: +# https://packaging.python.org/discussions/install-requires-vs-requirements/ +dependencies = [ + "medcat[spacy]>=2.7", + "transformers>=4.41.0,<5.0", # avoid major bump + "torch>=2.4.0,<3.0", + "tqdm", +] + +# List additional groups of dependencies here (e.g. development +# dependencies). Users will be able to install these using the "extras" +# syntax, for example: +# +# $ pip install sampleproject[dev] +# +# Similar to `dependencies` above, these must be valid existing +# projects. 
+[project.optional-dependencies] # Optional +dev = [ + "ruff~=0.1.7", + "mypy", + "types-tqdm", + "types-setuptools", + "types-PyYAML", +] + +# entry-points to add onto medcat +[project.entry-points."medcat.plugins"] +medcat_transformer_ner = "medcat_transformer_ner" + +[project.urls] +"Homepage" = "https://cogstack.org/" +"Bug Reports" = "https://discourse.cogstack.org/" +"Source" = "https://github.com/CogStack/cogstack-nlp/tree/main/medcat-plugins/transformer-ner" + +[build-system] +# These are the assumed default build requirements from pip: +# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support +requires = ["setuptools>=43.0.0", "setuptools_scm>=8", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +"medcat_transformer_ner" = ["py.typed"] + +[tool.setuptools_scm] +# look for .git folder in root of repo +root = "../.." +version_scheme = "post-release" +local_scheme = "no-local-version" +tag_regex = "^medcat-transformer-ner/v(?P<version>\\d+(?:\\.\\d+)*)(?:[ab]\\d+|rc\\d+)?$" +git_describe_command = "git describe --dirty --tags --long --match 'medcat-transformer-ner/v*'" + +[tool.ruff.lint] +# 1. 
Enable some extra checks for ruff +select = ["E", "F"] +# ignore unused local variables +ignore = ["F841"] diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py new file mode 100644 index 000000000..1f1d2b174 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py @@ -0,0 +1,3 @@ +from .registration import do_registration as __register + +__register() diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py new file mode 100644 index 000000000..d9c7cf2a0 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py @@ -0,0 +1,34 @@ +from typing import Optional, Any +from medcat.config import Ner + +class TransformerNER(Ner): + """The config exclusively used for the transformer NER""" + language_model_name: str = "michiyasunaga/BioLinkBERT-large" + """Name/path of the language model. It must be downloadable from + huggingface or linked from an appropriate file directory. NOTE: + use ner_component.load_transformers to load the model, changing this + does nothing.""" + training_batch_size: int = 32 + """The size of the batch to be used for training.""" + max_token_length: int = 512 + """Max number of tokens to be passed to the language model. + Longer sequences will be chunked""" + overlap_chunking: float = 0.1 + """How much each chunk should overlap with the previous one. + This is important to avoid missing entities that are on the border of two chunks.""" + gpu_device: Optional[Any] = None + """Choose a device for the model to be stored / computed on. If None + then an appropriate GPU device that is available will be chosen""" + require_link_candidates: bool = True + """Generate ent.link_candidates based on detected names. This requires + checking the CDB.name2info, and is required for vocab based linking. 
+ Enabling this will lower recall, and most likely increase precision, + it will also decrease computation time.""" + use_prefix_token: bool = False + """Given a detected span, include one token previous to improve recall. + This helps with low signal words not being detected by the model. This will + increase computation time, and could reduce precision.""" + learning_rate: float = 1e-5 + """The learning rate to be used for training the model""" + weight_decay: float = 0.001 + """The weight decay to be used for training the model""" \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py new file mode 100644 index 000000000..71b42af43 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py @@ -0,0 +1,16 @@ +import logging + +from medcat.components.types import CoreComponentType +from medcat.components.types import lazy_register_core_component + + +logger = logging.getLogger(__name__) + + +def do_registration(): + lazy_register_core_component( + CoreComponentType.ner, + "transformer_ner", + "medcat_transformer_ner.transformer_ner", + "NER.create_new_component", + ) diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py new file mode 100644 index 000000000..535d0fae6 --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py @@ -0,0 +1,574 @@ +from pathlib import Path +from typing import Any, Optional, Union +from medcat.tokenizing.tokens import MutableDocument, MutableEntity +from medcat.components.types import CoreComponentType, TrainableComponent +from medcat.components.types import AbstractEntityProvidingComponent +from medcat.components.ner.vocab_based_annotator import annotate_name +from medcat.tokenizing.tokenizers import BaseTokenizer 
+from medcat.vocab import Vocab +from medcat.cdb import CDB +from medcat.config.config import ComponentConfig +from medcat.storage.serialisables import AbstractManualSerialisable +from transformers import AutoTokenizer, get_constant_schedule_with_warmup +from medcat_transformer_ner.transformer_ner_model import ModelForBinaryNER +from medcat_transformer_ner.config import TransformerNER +from torch import Tensor +import logging +import os +import torch + +logger = logging.getLogger(__name__) + + +class NER(AbstractEntityProvidingComponent, + TrainableComponent, + AbstractManualSerialisable): + name = 'transformer_ner' + + comp_name = "transformer_ner" + _MODEL_FOLDER_NAME = "transformer_ner_model" + + def __init__(self, tokenizer: BaseTokenizer, + cdb: CDB) -> None: + super().__init__() + self.tokenizer = tokenizer + self.cdb = cdb + self.config = self.cdb.config + + # NER model stuff! + self.cnf_ner: TransformerNER = self.config.components.ner + self.label2id = { + "O": 0, + "B-ENT": 1, + "I-ENT": 2, + "E-ENT": 3, + "S-ENT": 4 + } + self.id2label = {v: k for k, v in self.label2id.items()} + self._model_init_kwargs: dict[str, Any] = dict() + self.load_transformers(self.cnf_ner.language_model_name) + self.max_token_length = self.cnf_ner.max_token_length + self.overlap_chunking = self.cnf_ner.overlap_chunking + + @staticmethod + def _resolve_model_source(path_or_model_name: Union[str, Path]) -> str: + """Return local absolute path if it exists, otherwise keep HF model id.""" + candidate = Path(path_or_model_name).expanduser() + if candidate.exists(): + return str(candidate.resolve()) + return str(path_or_model_name) + + def _get_model_init_kwargs(self) -> dict[str, Any]: + """Build kwargs passed to ModelForEmbeddingLinking.from_pretrained.""" + return dict(self._model_init_kwargs) + + def load_transformers(self, language_model_name: Union[str, Path]) -> None: + """Load tokenizer/model from local path or Hugging Face model id.""" + model_source = 
self._resolve_model_source(language_model_name) + model_init_kwargs = self._get_model_init_kwargs() + + if ( + not hasattr(self, "model") + or not hasattr(self, "transformer_tokenizer") + or model_source != self._loaded_model_source + or model_init_kwargs != self._loaded_model_init_kwargs + ): + self.cnf_ner.language_model_name = str(language_model_name) + + self.transformer_tokenizer = AutoTokenizer.from_pretrained( + model_source, + clean_up_tokenization_spaces=False + ) + self.model = ModelForBinaryNER( + embedding_model_name=model_source, + id2label=self.id2label, + **model_init_kwargs + ) + + self.model.eval() + self.device = torch.device( + self.cnf_ner.gpu_device + or ("cuda" if torch.cuda.is_available() else "cpu") + ) + self.model.to(self.device) + self._loaded_model_source: str = model_source + self._loaded_model_init_kwargs: dict[str, Any] = model_init_kwargs + self.optimizer = torch.optim.AdamW(self.model.parameters(), + lr=1e-5, + weight_decay=0.001) + self.scheduler = get_constant_schedule_with_warmup( + self.optimizer, + num_warmup_steps=20, + ) + logger.debug( + "Loaded embedding model: %s (resolved source: %s) with kwargs=%s " \ + "on device: %s", + language_model_name, + model_source, + model_init_kwargs, + self.device, + ) + + def get_type(self) -> CoreComponentType: + return CoreComponentType.ner + + def _chunk_and_encode(self, + text: str, + entities: Optional[list[MutableEntity]] = None + ) -> tuple[Tensor, Tensor, list[Any], list[Any], + Optional[Tensor]]: + labels_enabled = entities is not None + # First pass: tokenize full text to get offsets for chunking and label alignment + base_encoding = self.transformer_tokenizer( + text, + return_offsets_mapping=True, + add_special_tokens=False + ) + + offsets = base_encoding["offset_mapping"] + + stride = (self.max_token_length - + int(self.max_token_length * self.overlap_chunking)) + + n_tokens = len(base_encoding["input_ids"]) + start_idx = 0 + + all_input_ids = [] + all_attention_masks = [] + 
all_labels: list[Tensor] = [] + offset_mappings = [] + chunk_char_starts = [] + while start_idx < n_tokens: + end_idx = min(start_idx + self.max_token_length, n_tokens) + + chunk_offsets = offsets[start_idx:end_idx] + + char_start = chunk_offsets[0][0] + char_end = chunk_offsets[-1][1] + chunk_text = text[char_start:char_end] + + # Rebase entities to chunk + # iff this is a training example + if entities is not None: + chunk_entities = [] + for ent in entities: + ent_start = ent.base.start_char_index + ent_end = ent.base.end_char_index # make end exclusive + + if ent_end > char_start and ent_start < char_end: + chunk_entities.append({ + "start": ent_start - char_start, + "end": ent_end - char_start + }) + + # Tokenize chunk + encoding = self.transformer_tokenizer( + chunk_text, + return_offsets_mapping=True, + truncation=True, + padding="max_length", + max_length=self.max_token_length + ) + + offsets_chunk = encoding["offset_mapping"] + + # Label alignment to relevant chunks + if labels_enabled: + chunk_labels = [ + -100 if (start == end) else self.label2id["O"] + for start, end in offsets_chunk + ] + + + for ent in chunk_entities: + ent_token_indices = [] + for i, (token_start, token_end) in enumerate(offsets_chunk): + if token_start == token_end: + continue + if token_start < ent["end"] and token_end > ent["start"]: + ent_token_indices.append(i) + + if not ent_token_indices: + continue + + if len(ent_token_indices) == 1: + chunk_labels[ent_token_indices[0]] = self.label2id["S-ENT"] + continue + + chunk_labels[ent_token_indices[0]] = self.label2id["B-ENT"] + chunk_labels[ent_token_indices[-1]] = self.label2id["E-ENT"] + for i in ent_token_indices[1:-1]: + chunk_labels[i] = self.label2id["I-ENT"] + + all_labels.append(torch.tensor(chunk_labels, dtype=torch.long)) + + all_input_ids.append(torch.tensor(encoding["input_ids"], + dtype=torch.long)) + all_attention_masks.append(torch.tensor(encoding["attention_mask"], + dtype=torch.long)) + 
offset_mappings.append(offsets_chunk) + chunk_char_starts.append(char_start) + + if end_idx == n_tokens: + break + + start_idx += stride + input_ids = torch.stack(all_input_ids).to(self.device) + attention_masks = torch.stack(all_attention_masks).to(self.device) + labels = None + if labels_enabled: + labels = torch.stack(all_labels).to(self.device) + return input_ids, attention_masks, offset_mappings, chunk_char_starts, labels + + def train(self, cui: str, + entity: MutableEntity, + doc: MutableDocument, + negative: bool = False, + names: Union[list[str], dict] = []) -> None: + """Train the NER component on a given document. This is used in the + supervised training loop of the MedCAT trainer. + """ + # if this is the last entity, we'll train + # kind of a hacky work around, but it's minimal impact on the CAT trainer + if entity is doc.ner_ents[-1]: + text = doc.base.text + entities = doc.ner_ents + input_ids, attention_masks, _, _, labels = ( + self._chunk_and_encode(text, entities) + ) + self.optimizer.zero_grad() + self.model.train() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_masks, + labels=labels + ) + loss = outputs.loss + + loss.backward() + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + 1.0 + ) + self.optimizer.step() + self.scheduler.step() + logger.debug("NER training step - loss: ", + loss.item()) + + def _decode_chunk(self, preds, offsets_chunk, chunk_char_start): + """For inference only. Decode a single chunk of predictions into entity + spans, then merge them across chunks.""" + spans = [] + current = None + for pred_id, (tok_start, tok_end) in zip(preds, offsets_chunk): + + # skip padding / special tokens + if (tok_start, tok_end) == (0, 0): + continue + + label = self.id2label[pred_id] + + # If label is "O", close any open entity span and move on. 
+ if label == "O": + if current is not None: + spans.append(current) + current = None + continue + + # This is a bit too general for a binary ENT/ Non Ent + # But it's extendable... maybe! + prefix, ent_type = label.split("-", 1) + + abs_start = chunk_char_start + tok_start + abs_end = chunk_char_start + tok_end + + # B starts a new span + if prefix == "B": + if current is not None: + spans.append(current) + current = { + "start": abs_start, + "end": abs_end, + "label": ent_type + } + + # I continues + elif prefix == "I": + if current is not None and current["label"] == ent_type: + current["end"] = abs_end + else: + # Broken sequence -> treat as a new span. + current = { + "start": abs_start, + "end": abs_end, + "label": ent_type + } + + # E closes + elif prefix == "E": + if current is not None and current["label"] == ent_type: + current["end"] = abs_end + spans.append(current) + current = None + else: + # Broken sequence -> treat standalone E as a single-token span. + spans.append({ + "start": abs_start, + "end": abs_end, + "label": ent_type + }) + + # S is a single token span + elif prefix == "S": + if current is not None: + spans.append(current) + current = None + spans.append({ + "start": abs_start, + "end": abs_end, + "label": ent_type + }) + + if current is not None: + spans.append(current) + + return spans + + def _merge_spans(self, spans, text: str) -> list[dict]: + """Merge spans across chunk boundaries. This is required before creating + entities in the doc, otherwise we might have duplicates for the same + entity that got split across chunks. 
Used in inference only.""" + if not spans: + return [] + + spans = sorted(spans, key=lambda x: (x["start"], x["end"])) + merged = [spans[0]] + + for span in spans[1:]: + last = merged[-1] + gap_text = text[last["end"]:span["start"]] + gap_is_soft_separator = not ( + gap_text.strip() or gap_text.strip() in {"/", "-"} + ) + + if span["label"] == last["label"] and ( + span["start"] <= last["end"] or gap_is_soft_separator + ): + last["end"] = max(last["end"], span["end"]) + else: + merged.append(span) + + return merged + + + # Build segments in two modes: + # 1) keep half separators inside tokens, 2) split on half separators. + def _build_segments(self, + split_chars: set[str], + detected_string: str, + detected_start: int, + detected_end: int) -> list[tuple[int, int]]: + segs = [] + seg_start = None + for idx, ch in enumerate(detected_string): + if ch in split_chars: + if seg_start is not None: + segs.append((detected_start + seg_start, detected_start + idx)) + seg_start = None + elif seg_start is None: + seg_start = idx + if seg_start is not None: + segs.append((detected_start + seg_start, detected_end)) + return segs + + def _char_span_to_token_span( + self, + doc: MutableDocument, + start_char: int, + end_char: int, + ) -> Optional[tuple[int, int]]: + token_start = None + token_end = None + + for token in doc: + if token.char_index + len(token.text) <= start_char: + continue + if token.char_index >= end_char: + break + + if token_start is None: + token_start = token.index + token_end = token.index + 1 + + if token_start is None or token_end is None: + return None + + return token_start, token_end + + def _span_inference(self, spans: list[dict], + doc: MutableDocument, + text: str) -> list[MutableEntity]: + ner_ents: list[MutableEntity] = [] + seen_token_spans = set() + logger.debug("Num detected spans: %s", len(spans)) + for span in spans: + detected_start = span["start"] + detected_end = span["end"] + detected_string = text[detected_start:detected_end] + if not 
detected_string: + continue + logger.debug( + "Detected span: [%s, %s] %r", + detected_start, + detected_end, + detected_string, + ) + + token_span = self._char_span_to_token_span(doc, + detected_start, + detected_end) + if token_span is None: + continue + + token_start, token_end = token_span + if self.cnf_ner.use_prefix_token: + token_start = token_start - 1 if token_start > 0 else token_start + # Loop through all contiguous token subspans [i:j] + for i in range(token_start, token_end): + for j in range(i + 1, token_end + 1): + span_key = (i, j) + if span_key in seen_token_spans: + continue + + sub_tokens = list(doc[i:j]) + # there might be more cleaning required here + detected_name = self.config.general.separator.join( + token.text.lower() for token in sub_tokens + ) + ent = None + if detected_name in self.cdb.name2info: + ent = annotate_name( + self.tokenizer, + detected_name, + sub_tokens, + doc, + self.cdb, + len(ner_ents), + detected_name + ) + elif not self.cnf_ner.require_link_candidates: + ent = self.tokenizer.create_entity( + doc, + i, + j, + detected_name, + ) + + if ent: + logger.debug( + "Created entity: %r tokens [%s, %s]", + ent.text, + i, + j, + ent.base.start_char_index, + ent.base.end_char_index, + ) + ner_ents.append(ent) + seen_token_spans.add(span_key) + + return ner_ents + + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: + """Detect candidates for concepts - linker will then be able + to do the rest. It adds `entities` to the doc.ner_ents and each + entity can have the entity.link_candidates - that the linker + will resolve. + + Args: + doc (MutableDocument): + Spacy document to be annotated with named entities. + ents (list[MutableEntity] | None): + The entities given. This should be None. + + Returns: + list[MutableEntity]: + The NER'ed entities. + """ + # Keep offset generation in the same coordinate space as spaCy char_span. 
+ text = doc.text + input_ids, attention_masks, offset_mappings, chunk_char_starts, _ = ( + self._chunk_and_encode(text) + ) + + self.model.eval() + with torch.no_grad(): + input_ids = input_ids.to(self.device) + attention_masks = attention_masks.to(self.device) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_masks + ) + predictions = outputs.predictions.cpu().tolist() + + all_spans = [] + for preds, offsets_chunk, char_start in zip( + predictions, + offset_mappings, + chunk_char_starts + ): + spans = self._decode_chunk(preds, offsets_chunk, char_start) + all_spans.extend(spans) + final_spans = self._merge_spans(all_spans, text) + + return self._span_inference(final_spans, doc, text) + + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> 'TransformerNER': + return cls(tokenizer, cdb) + + def serialise_to(self, folder_path: str) -> None: + os.makedirs(folder_path, exist_ok=True) + model_folder = os.path.join(folder_path, self._MODEL_FOLDER_NAME) + os.makedirs(model_folder, exist_ok=True) + + # Save in HuggingFace format for forward compatibility. + self.model.save_pretrained(model_folder) + + @classmethod + def deserialise_from( + cls, folder_path: str, **init_kwargs + ) -> "NER": + cdb = init_kwargs["cdb"] + tokenizer = init_kwargs["tokenizer"] + ner = cls(tokenizer, cdb) + model_folder = os.path.join( + folder_path, cls._MODEL_FOLDER_NAME + ) + config_path = os.path.join(model_folder, "config.json") + weights_path = os.path.join(model_folder, "pytorch_model.bin") + if not os.path.exists(config_path) or not os.path.exists(weights_path): + raise FileNotFoundError( + "Could not find transformer-ner checkpoint files in " + f"{model_folder}. Expected both config.json and pytorch_model.bin." 
+ ) + + # ner.model = AutoModelForTokenClassification.from_pretrained(model_folder) + ner.model = ModelForBinaryNER.from_pretrained( + model_folder, + device=ner.device, + ) + ner.optimizer = torch.optim.AdamW(ner.model.parameters(), + lr=1e-5, + weight_decay=0.001) + ner.scheduler = get_constant_schedule_with_warmup(ner.optimizer, + num_warmup_steps=20) + ner.model.to(ner.device) + ner.model.eval() + + return ner \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py new file mode 100644 index 000000000..86e71b96b --- /dev/null +++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py @@ -0,0 +1,285 @@ +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Optional, Union +from torch import Tensor, nn +from torchcrf import CRF +from transformers import AutoModelForTokenClassification +import json +import logging +import torch + +logger = logging.getLogger(__name__) + +class ModelForBinaryNER(nn.Module): + """Wrapper around a Hugging Face transformer for transformer-based NER. + + The architecture is: transformer backbone -> linear classifier -> CRF. + """ + # for mypy checking + label_is_start: Tensor + label_is_end: Tensor + + def __init__( + self, + embedding_model_name: str, + id2label: dict[int, str], + num_labels: int = 5, + top_n_layers_to_unfreeze: int = -1, + device: Optional[Union[str, torch.device]] = None, + aux_loss_weight: float = 0.5, + ) -> None: + super().__init__() + self.num_labels = num_labels + self.aux_loss_weight = aux_loss_weight + self.id2label = id2label + self.language_model = AutoModelForTokenClassification.from_pretrained( + embedding_model_name, + num_labels=self.num_labels, + ) + # Make sure hidden states are available for the auxiliary heads. 
+ self.language_model.config.output_hidden_states = True + self.base_model_name = self.language_model.config.name_or_path + + # For the auxiliary start/end position heads, we use the hidden states + # from the last layer of the transformer. + hidden_size = self.language_model.config.hidden_size + self.start_head = nn.Linear(hidden_size, 1) + self.end_head = nn.Linear(hidden_size, 1) + + # For state transitions + self.crf = CRF(num_tags=self.num_labels, batch_first=True) + + # Build boundary lookup tables from BIOES label names. + # This is future proof in the sense that "B" and "E" would still be the same + # for multiple different types of entities. + start_flags = [] + end_flags = [] + for i in range(self.num_labels): + label = self.id2label[i] + prefix = label.split("-", 1)[0] + + start_flags.append(1.0 if prefix == "B" else 0.0) + end_flags.append(1.0 if prefix == "E" else 0.0) + + self.register_buffer("label_is_start", + torch.tensor(start_flags, dtype=torch.float)) + self.register_buffer("label_is_end", + torch.tensor(end_flags, dtype=torch.float)) + + target_device = self._resolve_device(device) + self.to(target_device) + + @staticmethod + def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device: + if device is None: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(device) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, **inputs) -> Any: + labels: Optional[Tensor] = inputs.pop("labels", None) + attention_mask: Tensor = inputs["attention_mask"] + + outputs = self.language_model(**inputs, + return_dict=True, + output_hidden_states=True) + emissions = outputs.logits + # the last layer's hidden states for the start/end heads + hidden_states = outputs.hidden_states[-1] + + # Linear classifiers for boundary heads + start_logits = self.start_head(hidden_states).squeeze(-1) # [B, T] + end_logits = self.end_head(hidden_states).squeeze(-1) 
# [B, T] + + # CRF can't handle -100 labels so this handles it + loss = None + crf_loss = None + start_loss = None + end_loss = None + + valid_mask = attention_mask.bool() + + if labels is not None: + labels = labels.long() + + # CRF can't handle -100 labels, so we mask them out. + crf_mask = valid_mask & (labels != -100) + safe_labels = labels.clone() + safe_labels[safe_labels == -100] = 0 + + # CRF requires the first timestep to be valid for every sequence. + crf_mask[:, 0] = True + + # CRF also requires each sequence to have at least one valid timestep. + no_valid_tokens = ~crf_mask.any(dim=1) + if no_valid_tokens.any(): + crf_mask[no_valid_tokens, 0] = True + + safe_labels[~crf_mask] = 0 + crf_loss = -self.crf( + emissions, + safe_labels, + mask=crf_mask, + reduction="token_mean", + ) + + # Auxiliary start/end targets + start_targets, end_targets = self._build_boundary_targets(labels) + + # Use the standard attention mask, but exclude -100 positions + aux_mask = valid_mask & (labels != -100) + + start_loss = self._masked_bce_loss(start_logits, start_targets, aux_mask) + end_loss = self._masked_bce_loss(end_logits, end_targets, aux_mask) + + loss = crf_loss + self.aux_loss_weight * (start_loss + end_loss) + + decoded_sequences = self.crf.decode(emissions, mask=valid_mask) + decoded_tensor = torch.zeros( + emissions.shape[:2], + dtype=torch.long, + device=emissions.device, + ) + for row_idx, seq in enumerate(decoded_sequences): + if seq: + decoded_tensor[row_idx, : len(seq)] = torch.tensor( + seq, + dtype=torch.long, + device=emissions.device, + ) + + return SimpleNamespace( + loss=loss, + crf_loss=crf_loss, + start_loss=start_loss, + end_loss=end_loss, + logits=emissions, + start_logits=start_logits, + end_logits=end_logits, + predictions=decoded_tensor, + decoded_sequences=decoded_sequences, + ) + + def _masked_bce_loss(self, logits: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + """ + Normal BCE doesn't handle masking (i.e. 
handling [-100]), + so this implements a masked version. + """ + loss_fn = nn.BCEWithLogitsLoss(reduction="none") + loss = loss_fn(logits, targets.float()) + loss = loss * mask.float() + + denom = mask.float().sum().clamp_min(1.0) + return loss.sum() / denom + + def _build_boundary_targets(self, labels: Tensor) -> tuple[Tensor, Tensor]: + """ + Convert BIOES token labels into binary start/end targets. + - start = 1 for B-*, (and maybe S-*) + - end = 1 for E-*, (and maybe S-*) + """ + safe_labels = labels.clone() + safe_labels[safe_labels == -100] = 0 + + start_targets = self.label_is_start[safe_labels].to(labels.device) + end_targets = self.label_is_end[safe_labels].to(labels.device) + + start_targets = start_targets.masked_fill(labels == -100, 0.0) + end_targets = end_targets.masked_fill(labels == -100, 0.0) + + return start_targets, end_targets + + def _freeze_all_parameters(self) -> None: + for param in self.language_model.parameters(): + param.requires_grad = False + + # The classification head always needs to be trainable + for param in self.language_model.classifier.parameters(): + param.requires_grad = True + + # Same for the CRF + for param in self.crf.parameters(): + param.requires_grad = True + + def unfreeze_top_n_lm_layers(self, n: int) -> None: + # train all LM layers - each layer requires more data + if n == -1: + for param in self.language_model.parameters(): + param.requires_grad = True + return + + # keep LM fully frozen - better with less data + if n == 0: + return + + base = self.language_model.base_model + # BERT-likes + if hasattr(base, "encoder") and hasattr(base.encoder, "layer"): + layers = base.encoder.layer + # DistilBERT-likes + elif hasattr(base, "transformer") and hasattr(base.transformer, "layer"): + layers = base.transformer.layer + else: + raise ValueError("Unsupported LM architecture for layer unfreezing.") + + total_layers = len(layers) + n = min(n, total_layers) + for layer in layers[-n:]: + for param in layer.parameters(): + 
param.requires_grad = True + + def save_pretrained(self, save_directory: Union[str, Path]) -> None: + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + torch.save(self.state_dict(), save_path / "pytorch_model.bin") + + config = { + "embedding_model_name": self.base_model_name, + "num_labels": self.num_labels, + "id2label": self.id2label, + "aux_loss_weight": self.aux_loss_weight, + } + with open(save_path / "config.json", "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + + @classmethod + def from_pretrained( + cls, + path_or_model_name: Union[str, Path], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> "ModelForBinaryNER": + path = Path(path_or_model_name) + config_path = path / "config.json" + weights_path = path / "pytorch_model.bin" + target_device = cls._resolve_device(device) + + # Local saved wrapper model. + if config_path.exists() and weights_path.exists(): + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + + # because loading in turns int keys into strings + if "id2label" in config: + id2label: dict[int, str] = {} + for key, value in config["id2label"].items(): + id2label[int(key)] = value + config["id2label"] = id2label + config.update(kwargs) + model = cls(device=target_device, **config) + state_dict = torch.load(weights_path, map_location="cpu") + model.load_state_dict(state_dict) + model.to(target_device) + return model + + # Hugging Face model id/path. 
+ model = cls( + embedding_model_name=str(path_or_model_name), + device=target_device, + **kwargs, + ) + return model \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/tests/__init__.py b/medcat-plugins/transformer-ner/tests/__init__.py new file mode 100644 index 000000000..b40364e1c --- /dev/null +++ b/medcat-plugins/transformer-ner/tests/__init__.py @@ -0,0 +1,26 @@ +# NOTE: mostly copied from medcat tests +import atexit +import os +import shutil + + +RESOURCES_PATH = os.path.join(os.path.dirname(__file__), "resources") +EXAMPLE_MODEL_PACK_ZIP = os.path.join(RESOURCES_PATH, "mct2_model_pack.zip") +UNPACKED_EXAMPLE_MODEL_PACK_PATH = os.path.join( + RESOURCES_PATH, "mct2_model_pack") + + +# unpack model pack at start so we can access stuff like Vocab +print("Unpacking included test model pack") +shutil.unpack_archive(EXAMPLE_MODEL_PACK_ZIP, UNPACKED_EXAMPLE_MODEL_PACK_PATH) + + +def _del_unpacked_model(): + print( + "Cleaning up! Removing unpacked exmaple model pack:", + UNPACKED_EXAMPLE_MODEL_PACK_PATH, + ) + shutil.rmtree(UNPACKED_EXAMPLE_MODEL_PACK_PATH) + + +atexit.register(_del_unpacked_model) diff --git a/medcat-plugins/transformer-ner/tests/helper.py b/medcat-plugins/transformer-ner/tests/helper.py new file mode 100644 index 000000000..4a70e259d --- /dev/null +++ b/medcat-plugins/transformer-ner/tests/helper.py @@ -0,0 +1,84 @@ +from typing import runtime_checkable, Type, Callable + +from medcat.components import types +from medcat.config.config import Config, ComponentConfig + + +class FakeCDB: + def __init__(self, cnf: Config): + self.config = cnf + self.token_counts = {} + self.cui2info = {} + self.name2info = {} + + def weighted_average_function(self, v: int) -> float: + return v * 0.5 + + +class FVocab: + pass + + +class FTokenizer: + pass + + +class ComponentInitTests: + expected_def_components = 1 + default = "default" + # these need to be specified when overriding + comp_type: types.CoreComponentType + default_cls: 
Type[types.BaseComponent] + default_creator: Callable[..., types.BaseComponent] + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + cls.fcdb = FakeCDB(cls.cnf) + cls.fvocab = FVocab() + cls.vtokenizer = FTokenizer() + cls.comp_cnf: ComponentConfig = getattr(cls.cnf.components, cls.comp_type.name) + if isinstance(cls.default_creator, Type): + cls._def_creator_name_opts = (cls.default_creator.__name__,) + else: + # classmethod + cls._def_creator_name_opts = ( + ".".join( + ( + # etiher class.method_name + cls.default_creator.__self__.__name__, + cls.default_creator.__name__, + ) + ), + # or just method_name + cls.default_creator.__name__, + ) + + def test_has_default(self): + avail_components = types.get_registered_components(self.comp_type) + self.assertEqual(len(avail_components), self.expected_def_components) + name, cls_name = avail_components[0] + # 1 name / cls name + eq_name = [name == self.default for name, _ in avail_components] + eq_cls = [ + cls_name in self._def_creator_name_opts for _, cls_name in avail_components + ] + self.assertEqual(sum(eq_name), 1) + # NOTE: for NER both the default as well as the Dict based NER + # have the came class name, so may be more than 1 + self.assertGreaterEqual(sum(eq_cls), 1) + # needs to have the same class where name is equal + self.assertTrue(eq_cls[eq_name.index(True)]) + + def test_can_create_def_component(self): + component = types.create_core_component( + self.comp_type, + self.default, + self.cnf, + self.vtokenizer, + self.fcdb, + self.fvocab, + None, + ) + self.assertIsInstance(component, runtime_checkable(types.BaseComponent)) + self.assertIsInstance(component, self.default_cls) diff --git a/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip b/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip new file mode 100644 index 000000000..b6bc74e49 Binary files /dev/null and b/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip differ diff --git 
a/medcat-plugins/transformer-ner/tests/test_transformer_ner.py b/medcat-plugins/transformer-ner/tests/test_transformer_ner.py new file mode 100644 index 000000000..62f187df5 --- /dev/null +++ b/medcat-plugins/transformer-ner/tests/test_transformer_ner.py @@ -0,0 +1,49 @@ +from medcat_transformer_ner import transformer_ner +from medcat.components import types +from medcat.config import Config +from medcat.vocab import Vocab +from medcat.components.types import _DEFAULT_NER as DEFAULT_NER +import unittest + +from .helper import ComponentInitTests + +class FakeDocument: + + def __init__(self, text): + self.text = text + + +class FakeTokenizer: + + def __call__(selt, text: str) -> FakeDocument: + return FakeDocument(text) + + +class FakeCDB: + + def __init__(self, config: Config): + self.config = config + + +class NerInitTests(ComponentInitTests, unittest.TestCase): + expected_def_components = len(DEFAULT_NER) + comp_type = types.CoreComponentType.ner + default = "transformer_ner" + default_cls = transformer_ner.NER + default_creator = transformer_ner.NER.create_new_component + module = transformer_ner + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + cls.cnf.components.ner = transformer_ner.TransformerNER() + cls.cnf.components.linking.comp_name = transformer_ner.NER.name + cls.fcdb = FakeCDB(cls.cnf) + cls.fvocab = Vocab() + cls.vtokenizer = FakeTokenizer() + cls.comp_cnf = getattr(cls.cnf.components, cls.comp_type.name) + + def test_has_default(self): + avail_components = types.get_registered_components(self.comp_type) + registered_names = [name for name, _ in avail_components] + self.assertIn("transformer_ner", registered_names) \ No newline at end of file