Skip to content

Commit

Permalink
Merge pull request #17 from SCAI-BIO/add-weaviate-functions
Browse files Browse the repository at this point in the history
refactor: add functions to retrieve concept and terminology
  • Loading branch information
tiadams authored Aug 26, 2024
2 parents fa768b9 + 92bf3ce commit 9d3ba42
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
52 changes: 51 additions & 1 deletion datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,43 @@ def get_all_sentence_embedders(self) -> List[str]:
except Exception as e:
raise RuntimeError(f"Failed to fetch sentence embedders: {e}")
return list(sentence_embedders)

def get_concept(self, concept_id: str) -> Concept:
try:
if not self._concept_exists(concept_id):
raise RuntimeError(f"Concept {concept_id} does not exists")
result = self.client.query.get(
"Concept",
["conceptID",
"prefLabel",
"_additional { id }",
"hasTerminology { ... on Terminology { _additional { id } name } }"]
).with_where({
"path": "conceptID",
"operator": "Equal",
"valueText": concept_id
}).do()
concept_data = result["data"]["Get"]["Concept"][0]
terminology_data = concept_data["hasTerminology"][0]
terminology_name = terminology_data["name"]
terminology_id = terminology_data["_additional"]["id"]
terminology = Terminology(terminology_name, terminology_id)
id = concept_data["_additional"]["id"]
concept_name = result["data"]["Get"]["Concept"][0]["prefLabel"]
concept = Concept(terminology, concept_name, concept_id, id)
except Exception as e:
raise RuntimeError(f"Failed to fetch concept {concept_id}: {e}")
return concept

def get_all_concepts(self) -> List[Concept]:
concepts = []
try:
result = self.client.query.get(
"Concept",
["conceptID", "prefLabel", "hasTerminology { ... on Terminology { _additional { id } name } }"]
["conceptID",
"prefLabel",
"_additional { id }",
"hasTerminology { ... on Terminology { _additional { id } name } }"]
).with_additional("vector").do()
for item in result['data']['Get']['Concept']:
terminology_data = item["hasTerminology"][0] # Assuming it has only one terminology
Expand All @@ -86,11 +116,31 @@ def get_all_concepts(self) -> List[Concept]:
concept_identifier=item["conceptID"],
pref_label=item["prefLabel"],
terminology=terminology,
id=item["_additional"]["id"]
)
concepts.append(concept)
except Exception as e:
raise RuntimeError(f"Failed to fetch concepts: {e}")
return concepts

def get_terminology(self, terminology_name: str) -> Terminology:
try:
if not self._terminology_exists(terminology_name):
raise RuntimeError(f"Terminology {terminology_name} does not exists")
result = self.client.query.get(
"Terminology",
["name", "_additional { id }"]
).with_where({
"path": "name",
"operator": "Equal",
"valueText": terminology_name
}).do()
terminology_data = result["data"]["Get"]["Terminology"][0]
terminology_id = terminology_data["_additional"]["id"]
terminology = Terminology(terminology_name, terminology_id)
except Exception as e:
raise RuntimeError(f"Failed to fetch terminology {terminology_name}: {e}")
return terminology

def get_all_terminologies(self) -> List[Terminology]:
terminologies = []
Expand Down
10 changes: 8 additions & 2 deletions tests/test_weaviate_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,16 @@ def test_repository(self):
self.assertEqual(len(mappings), 5)

concepts = repository.get_all_concepts()
concept = repository.get_concept("Concept ID: 11893007")
self.assertEqual(concept.concept_identifier, "Concept ID: 11893007")
self.assertEqual(concept.pref_label, text1)
self.assertEqual(concept.terminology.name, terminology1.name)
self.assertEqual(len(concepts), 9)

terminology = repository.get_terminology("snomed CT")
terminologies = repository.get_all_terminologies()
terminology_names = [embedding.name for embedding in terminologies]
self.assertEqual(terminology.name, "snomed CT")
self.assertEqual(len(terminologies), 2)
self.assertIn("NCI Thesaurus OBO Edition", terminology_names)
self.assertIn("snomed CT", terminology_names)
Expand All @@ -92,14 +98,14 @@ def test_repository(self):
self.assertEqual(len(closest_mappings_with_similarities), 5)
self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold")
self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder, model_name1)
self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197)
self.assertAlmostEqual(closest_mappings_with_similarities[0][1], 0.6747197, 3)

terminology_and_model_specific_closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(test_embedding, "snomed CT", model_name1)
self.assertEqual(len(terminology_and_model_specific_closest_mappings), 2)
self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold")
self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].concept.terminology.name, "snomed CT")
self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].sentence_embedder, model_name1)
self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197)
self.assertAlmostEqual(closest_mappings_with_similarities[0][1], 0.6747197, 3)

# check if it crashed (due to schema re-creation) after restart
repository = WeaviateRepository(mode="disk", path="db")
Expand Down

0 comments on commit 9d3ba42

Please sign in to comment.