diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py
index 0842e7f..4b1fdeb 100644
--- a/datastew/repository/weaviate.py
+++ b/datastew/repository/weaviate.py
@@ -68,13 +68,54 @@ def get_all_sentence_embedders(self) -> List[str]:
         except Exception as e:
             raise RuntimeError(f"Failed to fetch sentence embedders: {e}")
         return list(sentence_embedders)
+
+    def get_concept(self, concept_id: str) -> Concept:
+        """Fetch a single concept by its ``conceptID`` property.
+
+        :param concept_id: Value of the ``conceptID`` property to look up.
+        :return: The first matching concept, with its first linked terminology.
+        :raises RuntimeError: If the concept does not exist or the lookup fails.
+        """
+        try:
+            if not self._concept_exists(concept_id):
+                raise RuntimeError(f"Concept {concept_id} does not exists")
+            result = self.client.query.get(
+                "Concept",
+                ["conceptID",
+                 "prefLabel",
+                 "_additional { id }",
+                 "hasTerminology { ... on Terminology { _additional { id } name } }"]
+            ).with_where({
+                "path": "conceptID",
+                "operator": "Equal",
+                "valueText": concept_id
+            }).do()
+            concept_data = result["data"]["Get"]["Concept"][0]
+            # Only the first linked terminology is used.
+            terminology_data = concept_data["hasTerminology"][0]
+            terminology = Terminology(
+                terminology_data["name"],
+                terminology_data["_additional"]["id"],
+            )
+            concept = Concept(
+                terminology=terminology,
+                pref_label=concept_data["prefLabel"],
+                concept_identifier=concept_id,
+                id=concept_data["_additional"]["id"],
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to fetch concept {concept_id}: {e}") from e
+        return concept
 
     def get_all_concepts(self) -> List[Concept]:
         concepts = []
         try:
             result = self.client.query.get(
                 "Concept",
-                ["conceptID", "prefLabel", "hasTerminology { ... on Terminology { _additional { id } name } }"]
+                ["conceptID",
+                 "prefLabel",
+                 "_additional { id }",
+                 "hasTerminology { ... on Terminology { _additional { id } name } }"]
             ).with_additional("vector").do()
             for item in result['data']['Get']['Concept']:
                 terminology_data = item["hasTerminology"][0] # Assuming it has only one terminology
@@ -86,11 +127,37 @@ def get_all_concepts(self) -> List[Concept]:
                     concept_identifier=item["conceptID"],
                     pref_label=item["prefLabel"],
                     terminology=terminology,
+                    id=item["_additional"]["id"]
                 )
                 concepts.append(concept)
         except Exception as e:
             raise RuntimeError(f"Failed to fetch concepts: {e}")
         return concepts
+
+    def get_terminology(self, terminology_name: str) -> Terminology:
+        """Fetch a single terminology by its ``name`` property.
+
+        :param terminology_name: Value of the ``name`` property to look up.
+        :return: The first matching terminology.
+        :raises RuntimeError: If the terminology does not exist or the lookup fails.
+        """
+        try:
+            if not self._terminology_exists(terminology_name):
+                raise RuntimeError(f"Terminology {terminology_name} does not exists")
+            result = self.client.query.get(
+                "Terminology",
+                ["name", "_additional { id }"]
+            ).with_where({
+                "path": "name",
+                "operator": "Equal",
+                "valueText": terminology_name
+            }).do()
+            terminology_data = result["data"]["Get"]["Terminology"][0]
+            terminology_id = terminology_data["_additional"]["id"]
+            terminology = Terminology(terminology_name, terminology_id)
+        except Exception as e:
+            raise RuntimeError(f"Failed to fetch terminology {terminology_name}: {e}") from e
+        return terminology
 
     def get_all_terminologies(self) -> List[Terminology]:
         terminologies = []
diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py
index f5c0d80..55d0eff 100644
--- a/tests/test_weaviate_repository.py
+++ b/tests/test_weaviate_repository.py
@@ -68,10 +68,16 @@ def test_repository(self):
         self.assertEqual(len(mappings), 5)
 
         concepts = repository.get_all_concepts()
+        concept = repository.get_concept("Concept ID: 11893007")
+        self.assertEqual(concept.concept_identifier, "Concept ID: 11893007")
+        self.assertEqual(concept.pref_label, text1)
+        self.assertEqual(concept.terminology.name, terminology1.name)
         self.assertEqual(len(concepts), 9)
 
+        terminology = repository.get_terminology("snomed CT")
         terminologies = repository.get_all_terminologies()
         terminology_names = [embedding.name for embedding in terminologies]
+        self.assertEqual(terminology.name, "snomed CT")
         self.assertEqual(len(terminologies), 2)
         self.assertIn("NCI Thesaurus OBO Edition", terminology_names)
         self.assertIn("snomed CT", terminology_names)
@@ -92,14 +98,14 @@ def test_repository(self):
         self.assertEqual(len(closest_mappings_with_similarities), 5)
         self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold")
         self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder, model_name1)
-        self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197)
+        self.assertAlmostEqual(closest_mappings_with_similarities[0][1], 0.6747197, 3)
 
         terminology_and_model_specific_closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(test_embedding, "snomed CT", model_name1)
         self.assertEqual(len(terminology_and_model_specific_closest_mappings), 2)
         self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold")
         self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].concept.terminology.name, "snomed CT")
         self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].sentence_embedder, model_name1)
-        self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197)
+        self.assertAlmostEqual(closest_mappings_with_similarities[0][1], 0.6747197, 3)
 
         # check if it crashed (due to schema re-creation) after restart
         repository = WeaviateRepository(mode="disk", path="db")