diff --git a/medcat-trainer/webapp/api/api/model_cache.py b/medcat-trainer/webapp/api/api/model_cache.py
index ac438ef09..2073e58d9 100644
--- a/medcat-trainer/webapp/api/api/model_cache.py
+++ b/medcat-trainer/webapp/api/api/model_cache.py
@@ -32,6 +32,45 @@
     logger.warning("MAX_MEDCAT_MODELS is not an integer, using default value of 1")
 
 
+# Pipeline attribute under which the unfiltered addon list is stashed.
+_FULL_ADDONS_ATTR = "_mct_full_addons"
+
+
+def _apply_addon_filter(cat: CAT,
+                        addons: Optional[list[str]] = None) -> CAT:
+    """Enable only the requested pipeline addons on *cat*, in place.
+
+    The cached model is shared between callers, so the unfiltered addon list
+    is stashed on the pipeline the first time it is filtered. Every call then
+    filters from that full list; otherwise a call with ``['meta_cat']`` would
+    permanently drop all other addons for later callers.
+
+    Args:
+        cat: Model whose pipeline addons are adjusted.
+        addons: Addon types to keep enabled. An empty list disables every
+            addon; ``None`` re-enables all of them.
+
+    Returns:
+        CAT: The same *cat* instance, with its pipeline and config updated.
+    """
+    pipe = cat._pipeline
+    full_addons = getattr(pipe, _FULL_ADDONS_ATTR, None)
+    if full_addons is None:
+        # First filter on this pipeline: snapshot the addons before narrowing.
+        # NOTE(review): addons registered after this snapshot are not tracked.
+        full_addons = list(pipe._addons)
+        setattr(pipe, _FULL_ADDONS_ATTR, full_addons)
+    if addons is None:
+        pipe._addons = list(full_addons)
+    else:
+        allowed_addons = set(addons)
+        pipe._addons = [
+            addon for addon in full_addons if addon.addon_type in allowed_addons
+        ]
+    cat.config.components.addons = [addon.config for addon in pipe._addons]
+    return cat
+
+
 def _clear_models(cdb_map: Dict[str, CDB]=CDB_MAP,
                   vocab_map: Dict[str, Vocab]=VOCAB_MAP,
                   cat_map: Dict[str, CAT]=CAT_MAP):
@@ -186,19 +225,40 @@ def get_medcat_from_model_pack_id(modelpack_id: int, cat_map: Dict[str, CAT]=CAT
 
 @tracer.start_as_current_span("get_medcat")
 def get_medcat(project,
+               addons: Optional[list[str]] = None,
                cdb_map: Dict[str, CDB]=CDB_MAP,
                vocab_map: Dict[str, Vocab]=VOCAB_MAP,
                cat_map: Dict[str, CAT]=CAT_MAP):
+    """Load and cache a MedCAT model for a trainer project.
+
+    Args:
+        project: ``ProjectAnnotateEntities`` to load the model for.
+        addons: Addon types to enable on the returned model, e.g.
+            ``['meta_cat']`` or ``['rel_cat']``. Pass an empty list for NER and
+            linking only. Defaults to ``None`` (all addons enabled).
+        cdb_map: Module-level CDB cache. Defaults to ``CDB_MAP``.
+        vocab_map: Module-level vocab cache. Defaults to ``VOCAB_MAP``.
+        cat_map: Module-level CAT cache. Defaults to ``CAT_MAP``.
+
+    Returns:
+        CAT: A cached MedCAT instance for the project.
+
+    Raises:
+        Exception: If the project ConceptDB, vocab, or model pack is missing
+            or misconfigured.
+    """
     cat = get_cached_medcat(project, cat_map)
     if cat is not None:
         trace.get_current_span().add_event("Loaded medcat from cache")
-        return cat
+        # NOTE: addon filtering needs to be handled on the core lib side in the future
+        return _apply_addon_filter(cat, addons)
 
     try:
         if project.model_pack is None:
             cat = get_medcat_from_cdb_vocab(project, cdb_map, vocab_map, cat_map)
         else:
             cat = get_medcat_from_model_pack(project, cat_map)
-        return cat
+        # NOTE: addon filtering needs to be handled on the core lib side in the future
+        return _apply_addon_filter(cat, addons)
     except AttributeError as err:
         raise Exception('Failure loading Project ConceptDB, Vocab or Model Pack. Are these set correctly?') from err
diff --git a/medcat-trainer/webapp/api/api/utils.py b/medcat-trainer/webapp/api/api/utils.py
index aff47c20f..20c236276 100644
--- a/medcat-trainer/webapp/api/api/utils.py
+++ b/medcat-trainer/webapp/api/api/utils.py
@@ -454,7 +454,7 @@ def prep_docs(project_id: List[int], doc_ids: List[int], user_id: int):
     else:
         # Use local medcat model
         logger.info('Loading CAT object in bg process for project: %s', project.id)
-        cat = get_medcat(project=project)
+        cat = get_medcat(project=project, addons=["meta_cat"])
 
         # Set CAT filters
         cat.config.components.linking.filters.cuis = cuis
diff --git a/medcat-trainer/webapp/api/api/views.py b/medcat-trainer/webapp/api/api/views.py
index 825397b72..388f4a954 100644
--- a/medcat-trainer/webapp/api/api/views.py
+++ b/medcat-trainer/webapp/api/api/views.py
@@ -45,7 +45,6 @@
 logger = logging.getLogger(__name__)
 
 
-
 # Get the basic version of MedCAT
 cat = None
 
@@ -329,7 +328,7 @@ def prepare_documents(request):
                                                existing_annotations=anns)
             else:
                 # Use local medcat model
-                cat = get_medcat(project=project)
+                cat = get_medcat(project=project, addons=["meta_cat"])
                 logger.info('loaded medcat model for project: %s', project.id)
 
                 # Set CAT filters