From 26eea887615a325e43f1974b8377ec48eee6511a Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Tue, 13 Aug 2024 11:42:29 +0200 Subject: [PATCH 1/8] add: sentence embedder property to repository schema --- datastew/embedding.py | 27 ++++++--- datastew/process/ols.py | 6 +- datastew/repository/__init__.py | 3 +- datastew/repository/base.py | 8 ++- datastew/repository/model.py | 18 +++++- datastew/repository/sqllite.py | 6 +- datastew/repository/weaviate.py | 76 +++++++++++++++++++++++--- datastew/repository/weaviate_schema.py | 15 +++++ tests/test_sql_repository.py | 24 ++++++-- tests/test_system.py | 9 +-- tests/test_weaviate_repository.py | 25 +++++---- 11 files changed, 171 insertions(+), 46 deletions(-) diff --git a/datastew/embedding.py b/datastew/embedding.py index 8a77ac6..d1767e9 100644 --- a/datastew/embedding.py +++ b/datastew/embedding.py @@ -12,17 +12,21 @@ def get_embedding(self, text: str) -> [float]: def get_embeddings(self, messages: [str]) -> [[float]]: pass + def get_model_name(self) -> str: + pass + def sanitize(self, message: str) -> str: return message.strip().lower() class GPT4Adapter(EmbeddingModel): - def __init__(self, api_key: str): + def __init__(self, api_key: str, model: str = "text-embedding-ada-002"): self.api_key = api_key openai.api_key = api_key + self.model = model logging.getLogger().setLevel(logging.INFO) - def get_embedding(self, text: str, model="text-embedding-ada-002"): + def get_embedding(self, text: str): logging.info(f"Getting embedding for {text}") try: if text is None or text == "" or text is np.nan: @@ -31,12 +35,12 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"): if isinstance(text, str): text = text.replace("\n", " ") text = self.sanitize(text) - return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"] + return openai.Embedding.create(input=[text], model=self.model)["data"][0]["embedding"] except Exception as e: logging.error(f"Error getting embedding for {text}: {e}") return None - def get_embeddings(self, messages: [str], model="text-embedding-ada-002", max_length=2048): + def get_embeddings(self, messages: [str], max_length=2048): sanitized_messages = [self.sanitize(message) for message in messages] embeddings = [] total_chunks = (len(sanitized_messages) + max_length - 1) // max_length @@ -44,16 +48,20 @@ def get_embeddings(self, messages: [str], model="text-embedding-ada-002", max_le for i in range(0, len(sanitized_messages), max_length): current_chunk += 1 chunk = sanitized_messages[i:i + max_length] - response = openai.Embedding.create(input=chunk, model=model) + response = openai.Embedding.create(input=chunk, model=self.model) embeddings.extend([item["embedding"] for item in response["data"]]) logging.info("Processed chunk %d/%d", current_chunk, total_chunks) return embeddings + + def get_model_name(self) -> str: + return self.model class MPNetAdapter(EmbeddingModel): def __init__(self, model="sentence-transformers/all-mpnet-base-v2"): logging.getLogger().setLevel(logging.INFO) - self.mpnet_model = SentenceTransformer(model) + self.model = SentenceTransformer(model) + self.model_name = model # For Weaviate def get_embedding(self, text: str): logging.info(f"Getting embedding for {text}") @@ -64,7 +72,7 @@ def get_embedding(self, text: str): if isinstance(text, str): text = text.replace("\n", " ") text = self.sanitize(text) - return self.mpnet_model.encode(text) + return self.model.encode(text) except Exception as e: logging.error(f"Error getting embedding for {text}: {e}") return None @@ -72,11 +80,14 @@ def get_embedding(self, text: str): def get_embeddings(self, messages: [str]) -> [[float]]: sanitized_messages = [self.sanitize(message) for message in messages] try: - embeddings = self.mpnet_model.encode(sanitized_messages) + embeddings = self.model.encode(sanitized_messages) except Exception as e: logging.error(f"Failed for messages {sanitized_messages}") flattened_embeddings = [[float(element) for element in row] for row in embeddings] return flattened_embeddings + + def get_model_name(self) -> str: + return self.model_name class TextEmbedding: diff --git a/datastew/process/ols.py b/datastew/process/ols.py index 100f81c..ba058df 100644 --- a/datastew/process/ols.py +++ b/datastew/process/ols.py @@ -2,7 +2,7 @@ import requests -from datastew.repository.model import Terminology, Concept, Mapping +from datastew.repository.model import SentenceEmbedder, Terminology, Concept, Mapping from datastew.embedding import EmbeddingModel from datastew.repository.base import BaseRepository @@ -52,9 +52,11 @@ def process(self): else: descriptions.append(term["label"]) embeddings = self.embedding_model.get_embeddings(descriptions) + model_name = self.embedding_model.get_model_name() + sentence_embedder = SentenceEmbedder(model_name) for identifier, label, description, embedding in zip(identifiers, labels, descriptions, embeddings): concept = Concept(self.terminology, label, identifier) - mapping = Mapping(concept, description, embedding) + mapping = Mapping(concept, description, embedding, sentence_embedder) self.repository.store(concept) self.repository.store(mapping) except Exception as e: diff --git a/datastew/repository/__init__.py b/datastew/repository/__init__.py index 9e273ee..4678eb6 100644 --- a/datastew/repository/__init__.py +++ b/datastew/repository/__init__.py @@ -1,10 +1,11 @@ -from .model import Terminology, Concept, Mapping +from .model import Terminology, Concept, Mapping, SentenceEmbedder from .sqllite import SQLLiteRepository from .weaviate import WeaviateRepository __all__ = [ "Terminology", "Concept", + "SentenceEmbedder", "Mapping", "SQLLiteRepository", "WeaviateRepository" diff --git a/datastew/repository/base.py b/datastew/repository/base.py index 4aa6a85..681bf1c 100644 --- a/datastew/repository/base.py +++ b/datastew/repository/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from datastew.repository.model import Mapping, Concept, Terminology +from datastew.repository.model import Mapping, Concept, Terminology, SentenceEmbedder class BaseRepository(ABC): @@ -27,10 +27,14 @@ def get_all_terminologies(self) -> List[Terminology]: pass @abstractmethod - def get_all_mappings(self, limit=1000) -> [Mapping]: + def get_all_mappings(self, limit=1000) -> List[Mapping]: """Get all embeddings up to a limit""" pass + @abstractmethod + def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: + pass + @abstractmethod def get_closest_mappings(self, embedding, limit=5): """Get the closest mappings based on embedding.""" diff --git a/datastew/repository/model.py b/datastew/repository/model.py index a0cbec6..9c7c9ca 100644 --- a/datastew/repository/model.py +++ b/datastew/repository/model.py @@ -13,11 +13,20 @@ class Terminology(Base): id = Column(String, primary_key=True) name = Column(String) - def __init__(self, name: str, id: str) -> object: + def __init__(self, name: str, id: str) -> None: self.name = name self.id = id +class SentenceEmbedder(Base): + __tablename__ = 'sentence_embedder' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String) + + def __init__(self, name: str) -> None: + self.name = name + + class Concept(Base): __tablename__ = 'concept' concept_identifier = Column(String, primary_key=True) @@ -26,7 +35,7 @@ class Concept(Base): terminology = relationship("Terminology") uuid = Column(String) - def __init__(self, terminology: Terminology, pref_label: str, concept_identifier: str, id: str = None) -> object: + def __init__(self, terminology: Terminology, pref_label: str, concept_identifier: str, id: str | None = None) -> None: self.terminology = terminology self.pref_label = pref_label # should be unique @@ -42,13 +51,16 @@ class Mapping(Base): concept = relationship("Concept") text = Column(Text) embedding_json = Column(Text) + sentence_embedder_id = Column(String, ForeignKey('sentence_embedder.id')) + sentence_embedder = relationship("SentenceEmbedder") - def __init__(self, concept: Concept, text: str, embedding: list) -> object: + def __init__(self, concept: Concept, text: str, embedding: list, sentence_embedder: SentenceEmbedder) -> None: self.concept = concept self.text = text if isinstance(embedding, np.ndarray): embedding = embedding.tolist() self.embedding_json = json.dumps(embedding) # Store embedding as JSON + self.sentence_embedder = sentence_embedder def __str__(self): return f"{self.concept.terminology.name} > {self.concept.concept_identifier} : {self.concept.pref_label} | {self.text}" diff --git a/datastew/repository/sqllite.py b/datastew/repository/sqllite.py index 270e5e7..5c6f317 100644 --- a/datastew/repository/sqllite.py +++ b/datastew/repository/sqllite.py @@ -1,5 +1,4 @@ import random -import sqlite3 import numpy as np @@ -7,7 +6,7 @@ from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker -from datastew.repository.model import Base, Terminology, Concept, Mapping +from datastew.repository.model import Base, SentenceEmbedder, Terminology, Concept, Mapping from datastew.repository.base import BaseRepository @@ -51,6 +50,9 @@ def get_all_mappings(self, limit=1000): # Query for mappings corresponding to the random indices mappings = self.session.query(Mapping).filter(Mapping.id.in_(random_indices)).all() return mappings + + def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: + return self.session.query(SentenceEmbedder).all() def get_closest_mappings(self, embedding: List[float], limit=5): mappings = self.session.query(Mapping).all() diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index 22c4c87..eb746a0 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -5,9 +5,9 @@ import weaviate from weaviate.embedded import EmbeddedOptions -from datastew.repository import Mapping, Terminology, Concept +from datastew.repository import Mapping, Terminology, Concept, SentenceEmbedder from datastew.repository.base import BaseRepository -from datastew.repository.weaviate_schema import terminology_schema, concept_schema, mapping_schema +from datastew.repository.weaviate_schema import terminology_schema, concept_schema, mapping_schema, sentence_embedder_schema class WeaviateRepository(BaseRepository): @@ -38,6 +38,7 @@ def __init__(self, mode="memory", path=None): raise ConnectionError(f"Failed to initialize Weaviate client: {e}") try: + self._create_schema_if_not_exists(sentence_embedder_schema) self._create_schema_if_not_exists(terminology_schema) self._create_schema_if_not_exists(concept_schema) self._create_schema_if_not_exists(mapping_schema) @@ -58,6 +59,17 @@ def store_all(self, model_object_instances): for instance in model_object_instances: self.store(instance) + def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: + sentence_embedders = [] + try: + result = self.client.query.get("SentenceEmbedders", ["name"]).do() + for item in result['data']['Get']['SentenceEmbedders']: + sentence_embedder = SentenceEmbedder(item["name"]) + sentence_embedders.append(sentence_embedder) + except Exception as e: + raise RuntimeError(f"Failed to fetch terminologies: {e}") + return sentence_embedder + def get_all_concepts(self) -> List[Concept]: concepts = [] try: @@ -101,10 +113,12 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]: result = self.client.query.get( "Mapping", ["text", - "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", + "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] ).with_additional("vector").with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] + sentence_embedder_data = item["hasSentenceEmbedder"][0] concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -117,10 +131,14 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]: terminology=terminology, id=concept_data["_additional"]["id"] ) + sentence_embedder = SentenceEmbedder( + name=sentence_embedder_data["name"] + ) mapping = Mapping( text=item["text"], concept=concept, - embedding=embedding_vector + embedding=embedding_vector, + sentence_embedder=sentence_embedder ) mappings.append(mapping) except Exception as e: @@ -133,10 +151,12 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: result = self.client.query.get( "Mapping", ["text", "_additional { distance }", - "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", + "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] + sentence_embedder_data = item["hasSentenceEmbedder"][0] concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -149,10 +169,14 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: terminology=terminology, id=concept_data["_additional"]["id"] ) + sentence_embedder = SentenceEmbedder( + name=sentence_embedder_data["name"] + ) mapping = Mapping( text=item["text"], concept=concept, - embedding=embedding_vector + embedding=embedding_vector, + sentence_embedder=sentence_embedder ) mappings.append(mapping) except Exception as e: @@ -165,11 +189,13 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup result = self.client.query.get( "Mapping", ["text", "_additional { distance }", - "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", + "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() for item in result['data']['Get']['Mapping']: similarity = 1 - item["_additional"]["distance"] embedding_vector = item["_additional"]["vector"] + sentence_embedder_data = item["hasSentenceEmbedder"][0] concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -182,10 +208,14 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup terminology=terminology, id=concept_data["_additional"]["id"] ) + sentence_embedder = SentenceEmbedder( + name=sentence_embedder_data["name"] + ) mapping = Mapping( text=item["text"], concept=concept, - embedding=embedding_vector + embedding=embedding_vector, + sentence_embedder=sentence_embedder ) mappings_with_similarities.append((mapping, similarity)) except Exception as e: @@ -196,10 +226,20 @@ def shut_down(self): if self.mode == "memory": shutil.rmtree("db") - def store(self, model_object_instance: Union[Terminology, Concept, Mapping]): + def store(self, model_object_instance: Union[Terminology, Concept, Mapping, SentenceEmbedder]): random_uuid = uuid.uuid4() model_object_instance.id = random_uuid try: + if isinstance(model_object_instance, SentenceEmbedder): + if not self._sentence_embedder_exists(model_object_instance.name): + properties = { + "name": model_object_instance.name + } + self.client.data_object.create( + class_name="SentenceEmbedder", + data_object=properties, + uuid=random_uuid + ) if isinstance(model_object_instance, Terminology): if not self._terminology_exists(model_object_instance.name): properties = { @@ -252,6 +292,13 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping]): to_class_name="Concept", to_uuid=model_object_instance.concept.uuid, ) + self.client.data_object.reference.add( + from_class_name="Mapping", + from_uuid=random_uuid, + from_property_name="hasSentenceEmbedder", + to_class_name="SentenceEmbedder", + to_uuid=model_object_instance.sentence_embedder.uuid, + ) else: self.logger.info(f'Mapping with same embedding already exists. Skipping.') else: @@ -259,6 +306,17 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping]): except Exception as e: raise RuntimeError(f"Failed to store object in Weaviate: {e}") + + def _sentence_embedder_exists(self, name: str) -> bool: + try: + result = self.client.query.get("SentenceEmbedder", ["name"]).with_where({ + "path": ["name"], + "operator": "Equal", + "valueText": name + }).do() + return len(result['data']['Get']['SentenceEmbedder']) > 0 + except Exception as e: + raise RuntimeError(f"Failed to check if sentence embedder exists: {e}") def _terminology_exists(self, name: str) -> bool: try: diff --git a/datastew/repository/weaviate_schema.py b/datastew/repository/weaviate_schema.py index 1f28992..1e7551e 100644 --- a/datastew/repository/weaviate_schema.py +++ b/datastew/repository/weaviate_schema.py @@ -1,3 +1,14 @@ +sentence_embedder_schema = { + "class": "SentenceEmbedder", + "description": "A sentence embedder model entry", + "properties": [ + { + "name": "name", + "dataType": ["string"] + } + ] +} + terminology_schema = { "class": "Terminology", "description": "A terminology entry", @@ -43,6 +54,10 @@ { "name": "hasConcept", "dataType": ["Concept"] + }, + { + "name": "hasSentenceEmbedder", + "dataType": ["SentenceEmbedder"] } ] } \ No newline at end of file diff --git a/tests/test_sql_repository.py b/tests/test_sql_repository.py index f9c1a67..1f6ea78 100644 --- a/tests/test_sql_repository.py +++ b/tests/test_sql_repository.py @@ -1,6 +1,6 @@ import unittest -from datastew.repository.model import Terminology, Concept, Mapping +from datastew.repository.model import Terminology, Concept, Mapping, SentenceEmbedder from datastew.repository.sqllite import SQLLiteRepository @@ -14,12 +14,26 @@ def tearDown(self): def test_get_closest_mappings(self): terminology = Terminology(name="Terminology 1", id="1") + sentence_embedder = SentenceEmbedder(name="sentence-transformers/all-mpnet-base-v2") concept = Concept(terminology=terminology, pref_label="Concept 1", concept_identifier="1") - mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3]) - mapping_2 = Mapping(concept=concept, text="Text 2", embedding=[0.2, 0.3, 0.4]) - mapping_3 = Mapping(concept=concept, text="Text 3", embedding=[1.2, 2.3, 3.4]) - self.repository.store_all([terminology, concept, mapping_1, mapping_2, mapping_3]) + mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=sentence_embedder) + mapping_2 = Mapping(concept=concept, text="Text 2", embedding=[0.2, 0.3, 0.4], sentence_embedder=sentence_embedder) + mapping_3 = Mapping(concept=concept, text="Text 3", embedding=[1.2, 2.3, 3.4], sentence_embedder=sentence_embedder) + self.repository.store_all([terminology, concept, mapping_1, mapping_2, mapping_3, sentence_embedder]) sample_embedding = [0.2, 0.4, 0.35] closest_mappings, distances = self.repository.get_closest_mappings(sample_embedding, limit=3) self.assertEqual(len(closest_mappings), 3) self.assertEqual(mapping_2.text, closest_mappings[0].text) + + def test_get_all_sentence_embedders(self): + terminology = Terminology(name="Terminology 1", id="1") + sentence_embedder_1 = SentenceEmbedder(name="sentence-transformers/all-mpnet-base-v2") + sentence_embedder_2 = SentenceEmbedder(name="text-embedding-ada-002") + concept = Concept(terminology=terminology, pref_label="Concept 1", concept_identifier="1") + mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=sentence_embedder_1) + mapping_2 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=sentence_embedder_2) + self.repository.store_all([terminology, concept, mapping_1, mapping_2, sentence_embedder_1, sentence_embedder_2]) + sentence_embedders = self.repository.get_all_sentence_embedders() + self.assertEqual(len(sentence_embedders), 2) + self.assertEqual(sentence_embedders[0].name, "sentence-transformers/all-mpnet-base-v2") + self.assertEqual(sentence_embedders[1].name, "text-embedding-ada-002") diff --git a/tests/test_system.py b/tests/test_system.py index 3855d83..f82b631 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -1,6 +1,6 @@ import unittest -from datastew.repository.model import Terminology, Concept, Mapping +from datastew.repository.model import Terminology, Concept, Mapping, SentenceEmbedder from datastew.embedding import MPNetAdapter from datastew.repository.sqllite import SQLLiteRepository @@ -19,13 +19,14 @@ def test_mapping_storage_and_closest_retrieval(self): terminology = Terminology("test", "test") concept1 = Concept(terminology, "cat", "TEST:1") concept1_description = "The cat is sitting on the mat." - mapping1 = Mapping(concept1, concept1_description, self.embedding_model.get_embedding(concept1_description)) + sentence_embedder = SentenceEmbedder("test") + mapping1 = Mapping(concept1, concept1_description, self.embedding_model.get_embedding(concept1_description), sentence_embedder=sentence_embedder) concept2 = Concept(terminology, "sunrise", "TEST:2") concept2_description = "The sun rises in the east." - mapping2 = Mapping(concept2, concept2_description, self.embedding_model.get_embedding(concept2_description)) + mapping2 = Mapping(concept2, concept2_description, self.embedding_model.get_embedding(concept2_description), sentence_embedder=sentence_embedder) concept3 = Concept(terminology, "dog", "TEST:3") concept3_description = "A loyal companion to humans." - mapping3 = Mapping(concept3, concept3_description, self.embedding_model.get_embedding(concept3_description)) + mapping3 = Mapping(concept3, concept3_description, self.embedding_model.get_embedding(concept3_description), sentence_embedder=sentence_embedder) self.repository.store_all([terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3]) # test new mappings text1 = "A furry feline rests on the rug." diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index f4d5466..0be9b09 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -2,7 +2,7 @@ from unittest import TestCase from datastew import MPNetAdapter -from datastew.repository import Terminology, Concept, Mapping +from datastew.repository import Terminology, Concept, Mapping, SentenceEmbedder from datastew.repository.weaviate import WeaviateRepository @@ -15,42 +15,43 @@ def test_repository(self): embedding_model = MPNetAdapter() terminology = Terminology("snomed CT", "SNOMED") + sentence_embedder = SentenceEmbedder("sentence-transformers/all-mpnet-base-v2") text1 = "Diabetes mellitus (disorder)" concept1 = Concept(terminology, text1, "Concept ID: 11893007") - mapping1 = Mapping(concept1, text1, embedding_model.get_embedding(text1)) + mapping1 = Mapping(concept1, text1, embedding_model.get_embedding(text1), sentence_embedder) text2 = "Hypertension (disorder)" concept2 = Concept(terminology, text2, "Concept ID: 73211009") - mapping2 = Mapping(concept2, text2, embedding_model.get_embedding(text2)) + mapping2 = Mapping(concept2, text2, embedding_model.get_embedding(text2), sentence_embedder) text3 = "Asthma" concept3 = Concept(terminology, text3, "Concept ID: 195967001") - mapping3 = Mapping(concept3, text3, embedding_model.get_embedding(text3)) + mapping3 = Mapping(concept3, text3, embedding_model.get_embedding(text3), sentence_embedder) text4 = "Heart attack" concept4 = Concept(terminology, text4, "Concept ID: 22298006") - mapping4 = Mapping(concept4, text4, embedding_model.get_embedding(text4)) + mapping4 = Mapping(concept4, text4, embedding_model.get_embedding(text4), sentence_embedder) text5 = "Common cold" concept5 = Concept(terminology, text5, "Concept ID: 13260007") - mapping5 = Mapping(concept5, text5, embedding_model.get_embedding(text5)) + mapping5 = Mapping(concept5, text5, embedding_model.get_embedding(text5), sentence_embedder) text6 = "Stroke" concept6 = Concept(terminology, text6, "Concept ID: 422504002") - mapping6 = Mapping(concept6, text6, embedding_model.get_embedding(text6)) + mapping6 = Mapping(concept6, text6, embedding_model.get_embedding(text6), sentence_embedder) text7 = "Migraine" concept7 = Concept(terminology, text7, "Concept ID: 386098009") - mapping7 = Mapping(concept7, text7, embedding_model.get_embedding(text7)) + mapping7 = Mapping(concept7, text7, embedding_model.get_embedding(text7), sentence_embedder) text8 = "Influenza" concept8 = Concept(terminology, text8, "Concept ID: 57386000") - mapping8 = Mapping(concept8, text8, embedding_model.get_embedding(text8)) + mapping8 = Mapping(concept8, text8, embedding_model.get_embedding(text8), sentence_embedder) text9 = "Osteoarthritis" concept9 = Concept(terminology, text9, "Concept ID: 399206004") - mapping9 = Mapping(concept9, text9, embedding_model.get_embedding(text9)) + mapping9 = Mapping(concept9, text9, embedding_model.get_embedding(text9), sentence_embedder) text10 = "The flu" @@ -69,6 +70,10 @@ def test_repository(self): terminologies = repository.get_all_terminologies() self.assertEqual(len(terminologies), 1) + sentence_embedders = repository.get_all_sentence_embedders() + self.assertEqual(len(sentence_embedders), 1) + self.assertEqual(sentence_embedders[0].name, "sentence-transformers/all-mpnet-base-v2") + test_embedding = embedding_model.get_embedding(text10) closest_mappings = repository.get_closest_mappings(test_embedding) From 2bc976e2da86f6000cffe2e37a34c8d65f47ed3e Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Tue, 13 Aug 2024 11:51:24 +0200 Subject: [PATCH 2/8] refactor: minor changes --- datastew/repository/model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/datastew/repository/model.py b/datastew/repository/model.py index 9c7c9ca..f0dc1a9 100644 --- a/datastew/repository/model.py +++ b/datastew/repository/model.py @@ -2,8 +2,7 @@ import numpy as np from sqlalchemy import Column, ForeignKey, Integer, String, Text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, declarative_base Base = declarative_base() @@ -13,7 +12,7 @@ class Terminology(Base): id = Column(String, primary_key=True) name = Column(String) - def __init__(self, name: str, id: str) -> None: + def __init__(self, name: str, id: str) -> object: self.name = name self.id = id @@ -23,7 +22,7 @@ class SentenceEmbedder(Base): id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String) - def __init__(self, name: str) -> None: + def __init__(self, name: str) -> object: self.name = name @@ -35,7 +34,7 @@ class Concept(Base): terminology = relationship("Terminology") uuid = Column(String) - def __init__(self, terminology: Terminology, pref_label: str, concept_identifier: str, id: str | None = None) -> None: + def __init__(self, terminology: Terminology, pref_label: str, concept_identifier: str, id: str = None) -> object: self.terminology = terminology self.pref_label = pref_label # should be unique @@ -54,7 +53,7 @@ class Mapping(Base): sentence_embedder_id = Column(String, ForeignKey('sentence_embedder.id')) sentence_embedder = relationship("SentenceEmbedder") - def __init__(self, concept: Concept, text: str, embedding: list, sentence_embedder: SentenceEmbedder) -> None: + def __init__(self, concept: Concept, text: str, embedding: list, sentence_embedder: SentenceEmbedder) -> object: self.concept = concept self.text = text if isinstance(embedding, np.ndarray): From 1214e180cd9d5a269ce1aa5c67cd0e8857bc3f40 Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Wed, 14 Aug 2024 10:18:29 +0200 Subject: [PATCH 3/8] refactor: minor changes --- datastew/repository/weaviate.py | 10 +++++----- tests/test_weaviate_repository.py | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index eb746a0..c15d953 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -62,13 +62,13 @@ def store_all(self, model_object_instances): def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: sentence_embedders = [] try: - result = self.client.query.get("SentenceEmbedders", ["name"]).do() - for item in result['data']['Get']['SentenceEmbedders']: + result = self.client.query.get("SentenceEmbedder", ["name"]).do() + for item in result['data']['Get']['SentenceEmbedder']: sentence_embedder = SentenceEmbedder(item["name"]) sentence_embedders.append(sentence_embedder) except Exception as e: raise RuntimeError(f"Failed to fetch terminologies: {e}") - return sentence_embedder + return sentence_embedders def get_all_concepts(self) -> List[Concept]: concepts = [] @@ -240,7 +240,7 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping, Sent data_object=properties, uuid=random_uuid ) - if isinstance(model_object_instance, Terminology): + elif isinstance(model_object_instance, Terminology): if not self._terminology_exists(model_object_instance.name): properties = { "name": model_object_instance.name @@ -297,7 +297,7 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping, Sent from_uuid=random_uuid, from_property_name="hasSentenceEmbedder", to_class_name="SentenceEmbedder", - to_uuid=model_object_instance.sentence_embedder.uuid, + to_uuid=model_object_instance.sentence_embedder.id, ) else: self.logger.info(f'Mapping with same embedding already exists. Skipping.') diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index 0be9b09..0024943 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -13,9 +13,10 @@ def test_repository(self): repository = WeaviateRepository(mode="disk", path="db") embedding_model = MPNetAdapter() + model_name = embedding_model.get_model_name() + sentence_embedder = SentenceEmbedder(model_name) terminology = Terminology("snomed CT", "SNOMED") - sentence_embedder = SentenceEmbedder("sentence-transformers/all-mpnet-base-v2") text1 = "Diabetes mellitus (disorder)" concept1 = Concept(terminology, text1, "Concept ID: 11893007") @@ -56,7 +57,7 @@ def test_repository(self): text10 = "The flu" repository.store_all([ - terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3, + sentence_embedder, terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3, concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, concept8, mapping8, concept9, mapping9 ]) @@ -90,7 +91,7 @@ def test_repository(self): # try to store all again (should not create new entries since they already exist) repository.store_all([ - terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3, + sentence_embedder, terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3, concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, concept8, mapping8, concept9, mapping9 ]) From 8c6f8fbd039b90aeaca2ac3542721a222bd09697 Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Thu, 15 Aug 2024 12:38:10 +0200 Subject: [PATCH 4/8] feat: enable filter based on terminology and model --- datastew/repository/weaviate.py | 65 +++++++++++++++++++++-- tests/test_weaviate_repository.py | 88 +++++++++++++++++++------------ 2 files changed, 113 insertions(+), 40 deletions(-) diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index c15d953..1038325 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -1,13 +1,14 @@ import logging import shutil -from typing import List, Union, Tuple import uuid as uuid +from typing import List, Tuple, Union + import weaviate from weaviate.embedded import EmbeddedOptions -from datastew.repository import Mapping, Terminology, Concept, SentenceEmbedder +from datastew.repository import Concept, Mapping, SentenceEmbedder, Terminology from datastew.repository.base import BaseRepository -from datastew.repository.weaviate_schema import terminology_schema, concept_schema, mapping_schema, sentence_embedder_schema +from datastew.repository.weaviate_schema import concept_schema, mapping_schema, sentence_embedder_schema, terminology_schema class WeaviateRepository(BaseRepository): @@ -156,7 +157,7 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] - sentence_embedder_data = item["hasSentenceEmbedder"][0] + sentence_embedder_data = item["hasSentenceEmbedder"][0] # Assuming it has only one sentence embedder concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -195,7 +196,7 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup for item in result['data']['Get']['Mapping']: similarity = 1 - item["_additional"]["distance"] embedding_vector = item["_additional"]["vector"] - sentence_embedder_data = item["hasSentenceEmbedder"][0] + sentence_embedder_data = item["hasSentenceEmbedder"][0] # Assuming it has only one sentence embedder concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -221,6 +222,60 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup except Exception as e: raise RuntimeError(f"Failed to fetch closest mappings with similarities: {e}") return mappings_with_similarities + + def get_terminology_and_model_specific_closest_mappings(self, embedding, terminology_name: str, sentence_embedder_name: str, limit: int = 5) -> List[Tuple[Mapping, float]]: + mappings_with_similarities = [] + try: + result = self.client.query.get( + "Mapping", + ["text", + "_additional { distance }", + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", + "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] + ).with_where({ + "operator": "And", + "operands": [ + { + "path": ["hasSentenceEmbedder", "SentenceEmbedder", "name"], + "operator": "Equal", + "valueText": sentence_embedder_name + }, + { + "path": ["hasConcept", "Concept", "hasTerminology", "Terminology", "name"], + "operator": "Equal", + "valueText": terminology_name + } + ] + }).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() + for item in result['data']['Get']['Mapping']: + similarity = 1 - item["_additional"]["distance"] + embedding_vector = item["_additional"]["vector"] + sentence_embedder_data = item["hasSentenceEmbedder"][0] + concept_data = item["hasConcept"][0] # Assuming it has only one concept + terminology_data = concept_data["hasTerminology"][0] + terminology = Terminology( + name=terminology_data["name"], + id=terminology_data["_additional"]["id"] + ) + concept = Concept( + concept_identifier=concept_data["conceptID"], + pref_label=concept_data["prefLabel"], + terminology=terminology, + id=concept_data["_additional"]["id"] + ) + sentence_embedder = SentenceEmbedder( + name=sentence_embedder_data["name"] + ) + mapping = Mapping( + text=item["text"], + concept=concept, + embedding=embedding_vector, + sentence_embedder=sentence_embedder + ) + mappings_with_similarities.append((mapping, similarity)) + except Exception as e: + raise RuntimeError(f"Failed to fetch the closest mappings for terminology {terminology_name} and model {sentence_embedder_name}: {e}") + return mappings_with_similarities def shut_down(self): if self.mode == "memory": diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index 0024943..905cf90 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -12,54 +12,58 @@ def test_repository(self): repository = WeaviateRepository(mode="disk", path="db") - embedding_model = MPNetAdapter() - model_name = embedding_model.get_model_name() - sentence_embedder = SentenceEmbedder(model_name) + embedding_model1 = MPNetAdapter() + embedding_model2 = MPNetAdapter("FremyCompany/BioLORD-2023") + model_name1 = embedding_model1.get_model_name() + model_name2 = embedding_model2.get_model_name() + sentence_embedder1 = SentenceEmbedder(model_name1) + sentence_embedder2 = SentenceEmbedder(model_name2) - terminology = Terminology("snomed CT", "SNOMED") + terminology1 = Terminology("snomed CT", "SNOMED") + terminology2 = Terminology("NCI Thesaurus OBO Edition", "NCIT") text1 = "Diabetes mellitus (disorder)" - concept1 = Concept(terminology, text1, "Concept ID: 11893007") - mapping1 = Mapping(concept1, text1, embedding_model.get_embedding(text1), sentence_embedder) + concept1 = Concept(terminology1, text1, "Concept ID: 11893007") + mapping1 = Mapping(concept1, text1, embedding_model1.get_embedding(text1), sentence_embedder1) text2 = "Hypertension (disorder)" - concept2 = Concept(terminology, text2, "Concept ID: 73211009") - mapping2 = Mapping(concept2, text2, embedding_model.get_embedding(text2), sentence_embedder) + concept2 = Concept(terminology1, text2, "Concept ID: 73211009") + mapping2 = Mapping(concept2, text2, embedding_model2.get_embedding(text2), sentence_embedder2) text3 = "Asthma" - concept3 = Concept(terminology, text3, "Concept ID: 195967001") - mapping3 = Mapping(concept3, text3, embedding_model.get_embedding(text3), sentence_embedder) + concept3 = Concept(terminology1, text3, "Concept ID: 195967001") + mapping3 = Mapping(concept3, text3, embedding_model1.get_embedding(text3), sentence_embedder1) text4 = "Heart attack" - concept4 = Concept(terminology, text4, "Concept ID: 22298006") - mapping4 = Mapping(concept4, text4, embedding_model.get_embedding(text4), sentence_embedder) + concept4 = Concept(terminology1, text4, "Concept ID: 22298006") + mapping4 = Mapping(concept4, text4, embedding_model2.get_embedding(text4), sentence_embedder2) text5 = "Common cold" - concept5 = Concept(terminology, text5, "Concept ID: 13260007") - mapping5 = Mapping(concept5, text5, embedding_model.get_embedding(text5), sentence_embedder) + concept5 = Concept(terminology2, text5, "Concept ID: 13260007") + mapping5 = Mapping(concept5, text5, embedding_model1.get_embedding(text5), sentence_embedder1) text6 = "Stroke" - concept6 = Concept(terminology, text6, "Concept ID: 422504002") - mapping6 = Mapping(concept6, text6, embedding_model.get_embedding(text6), sentence_embedder) + concept6 = Concept(terminology2, text6, "Concept ID: 422504002") + mapping6 = Mapping(concept6, text6, embedding_model2.get_embedding(text6), sentence_embedder2) text7 = "Migraine" - concept7 = Concept(terminology, text7, "Concept ID: 386098009") - mapping7 = Mapping(concept7, text7, embedding_model.get_embedding(text7), sentence_embedder) + concept7 = Concept(terminology2, text7, "Concept ID: 386098009") + mapping7 = Mapping(concept7, text7, embedding_model1.get_embedding(text7), sentence_embedder1) text8 = "Influenza" - concept8 = Concept(terminology, text8, "Concept ID: 57386000") - mapping8 = Mapping(concept8, text8, embedding_model.get_embedding(text8), sentence_embedder) + concept8 = Concept(terminology2, text8, "Concept ID: 57386000") + mapping8 = Mapping(concept8, text8, embedding_model2.get_embedding(text8), sentence_embedder2) text9 = "Osteoarthritis" - concept9 = Concept(terminology, text9, "Concept ID: 399206004") - mapping9 = Mapping(concept9, text9, embedding_model.get_embedding(text9), sentence_embedder) + concept9 = Concept(terminology2, text9, "Concept ID: 399206004") + mapping9 = Mapping(concept9, text9, embedding_model1.get_embedding(text9), sentence_embedder1) text10 = "The flu" repository.store_all([ - sentence_embedder, terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3, - concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, - concept8, mapping8, concept9, mapping9 + sentence_embedder1, sentence_embedder2, terminology1, terminology2, concept1, mapping1, + concept2, mapping2, concept3, mapping3, concept4, mapping4, concept5, mapping5, concept6, + mapping6, concept7, mapping7, concept8, mapping8, concept9, mapping9 ]) mappings = repository.get_all_mappings(limit=5) @@ -69,30 +73,44 @@ def test_repository(self): self.assertEqual(len(concepts), 9) terminologies = repository.get_all_terminologies() - self.assertEqual(len(terminologies), 1) + terminology_names = [embedding.name for embedding in terminologies] + self.assertEqual(len(terminologies), 2) + self.assertIn("NCI Thesaurus OBO Edition", terminology_names) + self.assertIn("snomed CT", terminology_names) sentence_embedders = repository.get_all_sentence_embedders() - self.assertEqual(len(sentence_embedders), 1) - self.assertEqual(sentence_embedders[0].name, "sentence-transformers/all-mpnet-base-v2") + embedder_names = [embedder.name for embedder in sentence_embedders] + self.assertEqual(len(sentence_embedders), 2) + self.assertIn(model_name1, embedder_names) + self.assertIn(model_name2, embedder_names) - test_embedding = embedding_model.get_embedding(text10) + test_embedding = embedding_model1.get_embedding(text10) closest_mappings = repository.get_closest_mappings(test_embedding) self.assertEqual(len(closest_mappings), 5) - self.assertEqual(closest_mappings[0].text, "Influenza") + self.assertEqual(closest_mappings[0].text, "Common cold") + self.assertEqual(closest_mappings[0].sentence_embedder.name, model_name1) closest_mappings_with_similarities = repository.get_closest_mappings_with_similarities(test_embedding) self.assertEqual(len(closest_mappings_with_similarities), 5) - self.assertEqual(closest_mappings_with_similarities[0][0].text, "Influenza") - self.assertEqual(closest_mappings_with_similarities[0][1], 0.86187172) + self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold") + self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder.name, model_name1) + self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197) + + terminology_and_model_specific_closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(test_embedding, "snomed CT", model_name1) + self.assertEqual(len(terminology_and_model_specific_closest_mappings), 2) + self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold") + self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].concept.terminology.name, "snomed CT") + self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].sentence_embedder.name, model_name1) + self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197) # check if it crashed (due to schema re-creation) after restart repository = WeaviateRepository(mode="disk", path="db") # try to store all again (should not create new entries since they already exist) repository.store_all([ - sentence_embedder, terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3, - concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, - concept8, mapping8, concept9, mapping9 + sentence_embedder1, sentence_embedder2, terminology1, terminology2, concept1, + mapping1, concept2, mapping2, concept3, mapping3, concept4, mapping4, concept5, + mapping5, concept6, mapping6, concept7, mapping7, concept8, mapping8, concept9, mapping9 ]) From f360f1827a6e5031e963823e6438051aff9d0fdb Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Tue, 20 Aug 2024 12:20:34 +0200 Subject: [PATCH 5/8] refactor: rename model attribute of GPT adapter --- datastew/embedding.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datastew/embedding.py b/datastew/embedding.py index d1767e9..0e74e82 100644 --- a/datastew/embedding.py +++ b/datastew/embedding.py @@ -20,10 +20,10 @@ def sanitize(self, message: str) -> str: class GPT4Adapter(EmbeddingModel): - def __init__(self, api_key: str, model: str = "text-embedding-ada-002"): + def __init__(self, api_key: str, model_name: str = "text-embedding-ada-002"): self.api_key = api_key openai.api_key = api_key - self.model = model + self.model_name = model_name logging.getLogger().setLevel(logging.INFO) def get_embedding(self, text: str): @@ -35,7 +35,7 @@ def get_embedding(self, text: str): if isinstance(text, str): text = text.replace("\n", " ") text = self.sanitize(text) - return openai.Embedding.create(input=[text], model=self.model)["data"][0]["embedding"] + return openai.Embedding.create(input=[text], model=self.model_name)["data"][0]["embedding"] except Exception as e: logging.error(f"Error getting embedding for {text}: {e}") return None @@ -48,20 +48,20 @@ def get_embeddings(self, messages: [str], max_length=2048): for i in range(0, len(sanitized_messages), max_length): current_chunk += 1 chunk = sanitized_messages[i:i + max_length] - response = openai.Embedding.create(input=chunk, model=self.model) + response = openai.Embedding.create(input=chunk, model=self.model_name) embeddings.extend([item["embedding"] for item in response["data"]]) logging.info("Processed chunk %d/%d", current_chunk, total_chunks) return embeddings def get_model_name(self) -> str: - return self.model + return self.model_name class MPNetAdapter(EmbeddingModel): - def __init__(self, model="sentence-transformers/all-mpnet-base-v2"): + def __init__(self, model_name="sentence-transformers/all-mpnet-base-v2"): logging.getLogger().setLevel(logging.INFO) - self.model = SentenceTransformer(model) - self.model_name = model # For Weaviate + self.model = SentenceTransformer(model_name) + self.model_name = model_name # For Weaviate def get_embedding(self, text: str): logging.info(f"Getting embedding for {text}") From 6aa7b4ece09a158d9db76b18e5502618ac339af4 Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Tue, 20 Aug 2024 12:21:37 +0200 Subject: [PATCH 6/8] refactor: remove SentenceEmbedder class and make it an attribute of Mapping class --- datastew/process/ols.py | 5 +- datastew/repository/__init__.py | 3 +- datastew/repository/base.py | 4 +- datastew/repository/model.py | 14 +---- datastew/repository/sqllite.py | 6 +- datastew/repository/weaviate.py | 87 +++++++------------------- datastew/repository/weaviate_schema.py | 19 ++---- tests/test_embedding.py | 2 +- tests/test_sql_repository.py | 26 ++++---- tests/test_system.py | 4 +- tests/test_weaviate_repository.py | 51 +++++++-------- 11 files changed, 75 insertions(+), 146 deletions(-) diff --git a/datastew/process/ols.py b/datastew/process/ols.py index ba058df..5724fb0 100644 --- a/datastew/process/ols.py +++ b/datastew/process/ols.py @@ -2,7 +2,7 @@ import requests -from datastew.repository.model import SentenceEmbedder, Terminology, Concept, Mapping +from datastew.repository.model import Terminology, Concept, Mapping from datastew.embedding import EmbeddingModel from datastew.repository.base import BaseRepository @@ -53,10 +53,9 @@ def process(self): descriptions.append(term["label"]) embeddings = self.embedding_model.get_embeddings(descriptions) model_name = self.embedding_model.get_model_name() - sentence_embedder = SentenceEmbedder(model_name) for identifier, label, description, embedding in zip(identifiers, labels, descriptions, embeddings): concept = Concept(self.terminology, label, identifier) - mapping = Mapping(concept, description, embedding, sentence_embedder) + mapping = Mapping(concept, description, embedding, model_name) self.repository.store(concept) self.repository.store(mapping) except Exception as e: diff --git a/datastew/repository/__init__.py b/datastew/repository/__init__.py index 4678eb6..9e273ee 100644 --- a/datastew/repository/__init__.py +++ b/datastew/repository/__init__.py @@ -1,11 +1,10 @@ -from .model import Terminology, Concept, Mapping, SentenceEmbedder +from .model import Terminology, Concept, Mapping from .sqllite import SQLLiteRepository from .weaviate import WeaviateRepository __all__ = [ "Terminology", "Concept", - "SentenceEmbedder", "Mapping", "SQLLiteRepository", "WeaviateRepository" diff --git a/datastew/repository/base.py b/datastew/repository/base.py index 681bf1c..5bb345e 100644 --- a/datastew/repository/base.py +++ b/datastew/repository/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from datastew.repository.model import Mapping, Concept, Terminology, SentenceEmbedder +from datastew.repository.model import Mapping, Concept, Terminology class BaseRepository(ABC): @@ -32,7 +32,7 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]: pass @abstractmethod - def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: + def get_all_sentence_embedders(self) -> List[str]: pass @abstractmethod diff --git a/datastew/repository/model.py b/datastew/repository/model.py index f0dc1a9..4afc4ef 100644 --- a/datastew/repository/model.py +++ b/datastew/repository/model.py @@ -17,15 +17,6 @@ def __init__(self, name: str, id: str) -> object: self.id = id -class SentenceEmbedder(Base): - __tablename__ = 'sentence_embedder' - id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String) - - def __init__(self, name: str) -> object: - self.name = name - - class Concept(Base): __tablename__ = 'concept' concept_identifier = Column(String, primary_key=True) @@ -50,10 +41,9 @@ class Mapping(Base): concept = relationship("Concept") text = Column(Text) embedding_json = Column(Text) - sentence_embedder_id = Column(String, ForeignKey('sentence_embedder.id')) - sentence_embedder = relationship("SentenceEmbedder") + sentence_embedder = Column(Text) - def __init__(self, concept: Concept, text: str, embedding: list, sentence_embedder: SentenceEmbedder) -> object: + def __init__(self, concept: Concept, text: str, embedding: list, sentence_embedder: str) -> object: self.concept = concept self.text = text if isinstance(embedding, np.ndarray): diff --git a/datastew/repository/sqllite.py b/datastew/repository/sqllite.py index 5c6f317..b3a4ced 100644 --- a/datastew/repository/sqllite.py +++ b/datastew/repository/sqllite.py @@ -6,7 +6,7 @@ from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker -from datastew.repository.model import Base, SentenceEmbedder, Terminology, Concept, Mapping +from datastew.repository.model import Base, Terminology, Concept, Mapping from datastew.repository.base import BaseRepository @@ -51,8 +51,8 @@ def get_all_mappings(self, limit=1000): mappings = self.session.query(Mapping).filter(Mapping.id.in_(random_indices)).all() return mappings - def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: - return self.session.query(SentenceEmbedder).all() + def get_all_sentence_embedders(self) -> List[str]: + return [embedder for embedder, in self.session.query(Mapping.sentence_embedder).distinct().all()] def get_closest_mappings(self, embedding: List[float], limit=5): mappings = self.session.query(Mapping).all() diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index 1038325..d537795 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -6,9 +6,9 @@ import weaviate from weaviate.embedded import EmbeddedOptions -from datastew.repository import Concept, Mapping, SentenceEmbedder, Terminology +from datastew.repository import Concept, Mapping, Terminology from datastew.repository.base import BaseRepository -from datastew.repository.weaviate_schema import concept_schema, mapping_schema, sentence_embedder_schema, terminology_schema +from datastew.repository.weaviate_schema import concept_schema, mapping_schema, terminology_schema class WeaviateRepository(BaseRepository): @@ -39,7 +39,6 @@ def __init__(self, mode="memory", path=None): raise ConnectionError(f"Failed to initialize Weaviate client: {e}") try: - self._create_schema_if_not_exists(sentence_embedder_schema) self._create_schema_if_not_exists(terminology_schema) self._create_schema_if_not_exists(concept_schema) self._create_schema_if_not_exists(mapping_schema) @@ -60,16 +59,15 @@ def store_all(self, model_object_instances): for instance in model_object_instances: self.store(instance) - def get_all_sentence_embedders(self) -> List[SentenceEmbedder]: - sentence_embedders = [] + def get_all_sentence_embedders(self) -> List[str]: + sentence_embedders = set() try: - result = self.client.query.get("SentenceEmbedder", ["name"]).do() - for item in result['data']['Get']['SentenceEmbedder']: - sentence_embedder = SentenceEmbedder(item["name"]) - sentence_embedders.append(sentence_embedder) + result = self.client.query.get("Mapping", ["hasSentenceEmbedder"]).do() + for item in result['data']['Get']['Mapping']: + sentence_embedders.add(item["hasSentenceEmbedder"]) except Exception as e: raise RuntimeError(f"Failed to fetch terminologies: {e}") - return sentence_embedders + return list(sentence_embedders) def get_all_concepts(self) -> List[Concept]: concepts = [] @@ -114,12 +112,11 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]: result = self.client.query.get( "Mapping", ["text", - "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", - "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] + "hasSentenceEmbedder", + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] ).with_additional("vector").with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] - sentence_embedder_data = item["hasSentenceEmbedder"][0] concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -132,14 +129,11 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]: terminology=terminology, id=concept_data["_additional"]["id"] ) - sentence_embedder = SentenceEmbedder( - name=sentence_embedder_data["name"] - ) mapping = Mapping( text=item["text"], concept=concept, embedding=embedding_vector, - sentence_embedder=sentence_embedder + sentence_embedder=item["hasSentenceEmbedder"] ) mappings.append(mapping) except Exception as e: @@ -153,11 +147,10 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: "Mapping", ["text", "_additional { distance }", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", - "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] + "hasSentenceEmbedder"] ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] - sentence_embedder_data = item["hasSentenceEmbedder"][0] # Assuming it has only one sentence embedder concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -170,14 +163,11 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: terminology=terminology, id=concept_data["_additional"]["id"] ) - sentence_embedder = SentenceEmbedder( - name=sentence_embedder_data["name"] - ) mapping = Mapping( text=item["text"], concept=concept, embedding=embedding_vector, - sentence_embedder=sentence_embedder + sentence_embedder=item["hasSentenceEmbedder"] ) mappings.append(mapping) except Exception as e: @@ -191,12 +181,11 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup "Mapping", ["text", "_additional { distance }", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", - "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] + "hasSentenceEmbedder"] ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() for item in result['data']['Get']['Mapping']: similarity = 1 - item["_additional"]["distance"] embedding_vector = item["_additional"]["vector"] - sentence_embedder_data = item["hasSentenceEmbedder"][0] # Assuming it has only one sentence embedder concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology terminology = Terminology( @@ -209,14 +198,11 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup terminology=terminology, id=concept_data["_additional"]["id"] ) - sentence_embedder = SentenceEmbedder( - name=sentence_embedder_data["name"] - ) mapping = Mapping( text=item["text"], concept=concept, embedding=embedding_vector, - sentence_embedder=sentence_embedder + sentence_embedder=item["hasSentenceEmbedder"] ) mappings_with_similarities.append((mapping, similarity)) except Exception as e: @@ -231,12 +217,12 @@ def get_terminology_and_model_specific_closest_mappings(self, embedding, termino ["text", "_additional { distance }", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }", - "hasSentenceEmbedder { ... on SentenceEmbedder { name } }"] + "hasSentenceEmbedder"] ).with_where({ "operator": "And", "operands": [ { - "path": ["hasSentenceEmbedder", "SentenceEmbedder", "name"], + "path": ["hasSentenceEmbedder"], "operator": "Equal", "valueText": sentence_embedder_name }, @@ -250,7 +236,6 @@ def get_terminology_and_model_specific_closest_mappings(self, embedding, termino for item in result['data']['Get']['Mapping']: similarity = 1 - item["_additional"]["distance"] embedding_vector = item["_additional"]["vector"] - sentence_embedder_data = item["hasSentenceEmbedder"][0] concept_data = item["hasConcept"][0] # Assuming it has only one concept terminology_data = concept_data["hasTerminology"][0] terminology = Terminology( @@ -263,14 +248,11 @@ def get_terminology_and_model_specific_closest_mappings(self, embedding, termino terminology=terminology, id=concept_data["_additional"]["id"] ) - sentence_embedder = SentenceEmbedder( - name=sentence_embedder_data["name"] - ) mapping = Mapping( text=item["text"], concept=concept, embedding=embedding_vector, - sentence_embedder=sentence_embedder + sentence_embedder=item["hasSentenceEmbedder"] ) mappings_with_similarities.append((mapping, similarity)) except Exception as e: @@ -281,21 +263,11 @@ def shut_down(self): if self.mode == "memory": shutil.rmtree("db") - def store(self, model_object_instance: Union[Terminology, Concept, Mapping, SentenceEmbedder]): + def store(self, model_object_instance: Union[Terminology, Concept, Mapping]): random_uuid = uuid.uuid4() model_object_instance.id = random_uuid try: - if isinstance(model_object_instance, SentenceEmbedder): - if not self._sentence_embedder_exists(model_object_instance.name): - properties = { - "name": model_object_instance.name - } - self.client.data_object.create( - class_name="SentenceEmbedder", - data_object=properties, - uuid=random_uuid - ) - elif isinstance(model_object_instance, Terminology): + if isinstance(model_object_instance, Terminology): if not self._terminology_exists(model_object_instance.name): properties = { "name": model_object_instance.name @@ -333,6 +305,7 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping, Sent if not self._mapping_exists(model_object_instance.embedding): properties = { "text": model_object_instance.text, + "hasSentenceEmbedder": model_object_instance.sentence_embedder } self.client.data_object.create( class_name="Mapping", @@ -347,13 +320,6 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping, Sent to_class_name="Concept", to_uuid=model_object_instance.concept.uuid, ) - self.client.data_object.reference.add( - from_class_name="Mapping", - from_uuid=random_uuid, - from_property_name="hasSentenceEmbedder", - to_class_name="SentenceEmbedder", - to_uuid=model_object_instance.sentence_embedder.id, - ) else: self.logger.info(f'Mapping with same embedding already exists. Skipping.') else: @@ -361,17 +327,6 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping, Sent except Exception as e: raise RuntimeError(f"Failed to store object in Weaviate: {e}") - - def _sentence_embedder_exists(self, name: str) -> bool: - try: - result = self.client.query.get("SentenceEmbedder", ["name"]).with_where({ - "path": ["name"], - "operator": "Equal", - "valueText": name - }).do() - return len(result['data']['Get']['SentenceEmbedder']) > 0 - except Exception as e: - raise RuntimeError(f"Failed to check if sentence embedder exists: {e}") def _terminology_exists(self, name: str) -> bool: try: diff --git a/datastew/repository/weaviate_schema.py b/datastew/repository/weaviate_schema.py index 1e7551e..b97f57a 100644 --- a/datastew/repository/weaviate_schema.py +++ b/datastew/repository/weaviate_schema.py @@ -1,14 +1,3 @@ -sentence_embedder_schema = { - "class": "SentenceEmbedder", - "description": "A sentence embedder model entry", - "properties": [ - { - "name": "name", - "dataType": ["string"] - } - ] -} - terminology_schema = { "class": "Terminology", "description": "A terminology entry", @@ -47,6 +36,10 @@ "name": "text", "dataType": ["string"] }, + { + "name": "hasSentenceEmbedder", + "dataType": ["string"] + }, { "name": "vector", "dataType": ["number[]"] @@ -54,10 +47,6 @@ { "name": "hasConcept", "dataType": ["Concept"] - }, - { - "name": "hasSentenceEmbedder", - "dataType": ["SentenceEmbedder"] } ] } \ No newline at end of file diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 976db2a..9e171c7 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -5,7 +5,7 @@ class TestEmbedding(unittest.TestCase): def setUp(self): - self.mpnet_adapter = MPNetAdapter(model="sentence-transformers/all-mpnet-base-v2") + self.mpnet_adapter = MPNetAdapter(model_name="sentence-transformers/all-mpnet-base-v2") def test_mpnet_adapter_get_embedding(self): text = "This is a test sentence." diff --git a/tests/test_sql_repository.py b/tests/test_sql_repository.py index 1f6ea78..0e466eb 100644 --- a/tests/test_sql_repository.py +++ b/tests/test_sql_repository.py @@ -1,6 +1,6 @@ import unittest -from datastew.repository.model import Terminology, Concept, Mapping, SentenceEmbedder +from datastew.repository.model import Terminology, Concept, Mapping from datastew.repository.sqllite import SQLLiteRepository @@ -14,12 +14,12 @@ def tearDown(self): def test_get_closest_mappings(self): terminology = Terminology(name="Terminology 1", id="1") - sentence_embedder = SentenceEmbedder(name="sentence-transformers/all-mpnet-base-v2") + model_name = "sentence-transformers/all-mpnet-base-v2" concept = Concept(terminology=terminology, pref_label="Concept 1", concept_identifier="1") - mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=sentence_embedder) - mapping_2 = Mapping(concept=concept, text="Text 2", embedding=[0.2, 0.3, 0.4], sentence_embedder=sentence_embedder) - mapping_3 = Mapping(concept=concept, text="Text 3", embedding=[1.2, 2.3, 3.4], sentence_embedder=sentence_embedder) - self.repository.store_all([terminology, concept, mapping_1, mapping_2, mapping_3, sentence_embedder]) + mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=model_name) + mapping_2 = Mapping(concept=concept, text="Text 2", embedding=[0.2, 0.3, 0.4], sentence_embedder=model_name) + mapping_3 = Mapping(concept=concept, text="Text 3", embedding=[1.2, 2.3, 3.4], sentence_embedder=model_name) + self.repository.store_all([terminology, concept, mapping_1, mapping_2, mapping_3]) sample_embedding = [0.2, 0.4, 0.35] closest_mappings, distances = self.repository.get_closest_mappings(sample_embedding, limit=3) self.assertEqual(len(closest_mappings), 3) @@ -27,13 +27,13 @@ def test_get_closest_mappings(self): def test_get_all_sentence_embedders(self): terminology = Terminology(name="Terminology 1", id="1") - sentence_embedder_1 = SentenceEmbedder(name="sentence-transformers/all-mpnet-base-v2") - sentence_embedder_2 = SentenceEmbedder(name="text-embedding-ada-002") + model_name_1 = "sentence-transformers/all-mpnet-base-v2" + model_name_2 = "text-embedding-ada-002" concept = Concept(terminology=terminology, pref_label="Concept 1", concept_identifier="1") - mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=sentence_embedder_1) - mapping_2 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=sentence_embedder_2) - self.repository.store_all([terminology, concept, mapping_1, mapping_2, sentence_embedder_1, sentence_embedder_2]) + mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=model_name_1) + mapping_2 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=model_name_2) + self.repository.store_all([terminology, concept, mapping_1, mapping_2]) sentence_embedders = self.repository.get_all_sentence_embedders() self.assertEqual(len(sentence_embedders), 2) - self.assertEqual(sentence_embedders[0].name, "sentence-transformers/all-mpnet-base-v2") - self.assertEqual(sentence_embedders[1].name, "text-embedding-ada-002") + self.assertEqual(sentence_embedders[0], "sentence-transformers/all-mpnet-base-v2") + self.assertEqual(sentence_embedders[1], "text-embedding-ada-002") diff --git a/tests/test_system.py b/tests/test_system.py index f82b631..00172a5 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -1,6 +1,6 @@ import unittest -from datastew.repository.model import Terminology, Concept, Mapping, SentenceEmbedder +from datastew.repository.model import Terminology, Concept, Mapping from datastew.embedding import MPNetAdapter from datastew.repository.sqllite import SQLLiteRepository @@ -19,7 +19,7 @@ def test_mapping_storage_and_closest_retrieval(self): terminology = Terminology("test", "test") concept1 = Concept(terminology, "cat", "TEST:1") concept1_description = "The cat is sitting on the mat." - sentence_embedder = SentenceEmbedder("test") + sentence_embedder = "test" mapping1 = Mapping(concept1, concept1_description, self.embedding_model.get_embedding(concept1_description), sentence_embedder=sentence_embedder) concept2 = Concept(terminology, "sunrise", "TEST:2") concept2_description = "The sun rises in the east." diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index 905cf90..d63f8ab 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -2,68 +2,66 @@ from unittest import TestCase from datastew import MPNetAdapter -from datastew.repository import Terminology, Concept, Mapping, SentenceEmbedder +from datastew.repository import Terminology, Concept, Mapping from datastew.repository.weaviate import WeaviateRepository class Test(TestCase): - @unittest.skip("currently broken on github workflows") + #@unittest.skip("currently broken on github workflows") def test_repository(self): - repository = WeaviateRepository(mode="disk", path="db") + repository = WeaviateRepository(mode="remote", path="http://localhost:8080") embedding_model1 = MPNetAdapter() embedding_model2 = MPNetAdapter("FremyCompany/BioLORD-2023") model_name1 = embedding_model1.get_model_name() model_name2 = embedding_model2.get_model_name() - sentence_embedder1 = SentenceEmbedder(model_name1) - sentence_embedder2 = SentenceEmbedder(model_name2) terminology1 = Terminology("snomed CT", "SNOMED") terminology2 = Terminology("NCI Thesaurus OBO Edition", "NCIT") text1 = "Diabetes mellitus (disorder)" concept1 = Concept(terminology1, text1, "Concept ID: 11893007") - mapping1 = Mapping(concept1, text1, embedding_model1.get_embedding(text1), sentence_embedder1) + mapping1 = Mapping(concept1, text1, embedding_model1.get_embedding(text1), model_name1) text2 = "Hypertension (disorder)" concept2 = Concept(terminology1, text2, "Concept ID: 73211009") - mapping2 = Mapping(concept2, text2, embedding_model2.get_embedding(text2), sentence_embedder2) + mapping2 = Mapping(concept2, text2, embedding_model2.get_embedding(text2), model_name2) text3 = "Asthma" concept3 = Concept(terminology1, text3, "Concept ID: 195967001") - mapping3 = Mapping(concept3, text3, embedding_model1.get_embedding(text3), sentence_embedder1) + mapping3 = Mapping(concept3, text3, embedding_model1.get_embedding(text3), model_name1) text4 = "Heart attack" concept4 = Concept(terminology1, text4, "Concept ID: 22298006") - mapping4 = Mapping(concept4, text4, embedding_model2.get_embedding(text4), sentence_embedder2) + mapping4 = Mapping(concept4, text4, embedding_model2.get_embedding(text4), model_name2) text5 = "Common cold" concept5 = Concept(terminology2, text5, "Concept ID: 13260007") - mapping5 = Mapping(concept5, text5, embedding_model1.get_embedding(text5), sentence_embedder1) + mapping5 = Mapping(concept5, text5, embedding_model1.get_embedding(text5), model_name1) text6 = "Stroke" concept6 = Concept(terminology2, text6, "Concept ID: 422504002") - mapping6 = Mapping(concept6, text6, embedding_model2.get_embedding(text6), sentence_embedder2) + mapping6 = Mapping(concept6, text6, embedding_model2.get_embedding(text6), model_name2) text7 = "Migraine" concept7 = Concept(terminology2, text7, "Concept ID: 386098009") - mapping7 = Mapping(concept7, text7, embedding_model1.get_embedding(text7), sentence_embedder1) + mapping7 = Mapping(concept7, text7, embedding_model1.get_embedding(text7), model_name1) text8 = "Influenza" concept8 = Concept(terminology2, text8, "Concept ID: 57386000") - mapping8 = Mapping(concept8, text8, embedding_model2.get_embedding(text8), sentence_embedder2) + mapping8 = Mapping(concept8, text8, embedding_model2.get_embedding(text8), model_name2) text9 = "Osteoarthritis" concept9 = Concept(terminology2, text9, "Concept ID: 399206004") - mapping9 = Mapping(concept9, text9, embedding_model1.get_embedding(text9), sentence_embedder1) + mapping9 = Mapping(concept9, text9, embedding_model1.get_embedding(text9), model_name1) text10 = "The flu" repository.store_all([ - sentence_embedder1, sentence_embedder2, terminology1, terminology2, concept1, mapping1, - concept2, mapping2, concept3, mapping3, concept4, mapping4, concept5, mapping5, concept6, - mapping6, concept7, mapping7, concept8, mapping8, concept9, mapping9 + terminology1, terminology2, concept1, mapping1, concept2, mapping2, concept3, mapping3, concept4, + mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, concept8, mapping8, concept9, + mapping9 ]) mappings = repository.get_all_mappings(limit=5) @@ -79,38 +77,37 @@ def test_repository(self): self.assertIn("snomed CT", terminology_names) sentence_embedders = repository.get_all_sentence_embedders() - embedder_names = [embedder.name for embedder in sentence_embedders] self.assertEqual(len(sentence_embedders), 2) - self.assertIn(model_name1, embedder_names) - self.assertIn(model_name2, embedder_names) + self.assertIn(model_name1, sentence_embedders) + self.assertIn(model_name2, sentence_embedders) test_embedding = embedding_model1.get_embedding(text10) closest_mappings = repository.get_closest_mappings(test_embedding) self.assertEqual(len(closest_mappings), 5) self.assertEqual(closest_mappings[0].text, "Common cold") - self.assertEqual(closest_mappings[0].sentence_embedder.name, model_name1) + self.assertEqual(closest_mappings[0].sentence_embedder, model_name1) closest_mappings_with_similarities = repository.get_closest_mappings_with_similarities(test_embedding) self.assertEqual(len(closest_mappings_with_similarities), 5) self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold") - self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder.name, model_name1) + self.assertEqual(closest_mappings_with_similarities[0][0].sentence_embedder, model_name1) self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197) terminology_and_model_specific_closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(test_embedding, "snomed CT", model_name1) self.assertEqual(len(terminology_and_model_specific_closest_mappings), 2) self.assertEqual(closest_mappings_with_similarities[0][0].text, "Common cold") self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].concept.terminology.name, "snomed CT") - self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].sentence_embedder.name, model_name1) + self.assertEqual(terminology_and_model_specific_closest_mappings[0][0].sentence_embedder, model_name1) self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197) # check if it crashed (due to schema re-creation) after restart - repository = WeaviateRepository(mode="disk", path="db") + repository = WeaviateRepository(mode="remote", path="http://localhost:8080") # try to store all again (should not create new entries since they already exist) repository.store_all([ - sentence_embedder1, sentence_embedder2, terminology1, terminology2, concept1, - mapping1, concept2, mapping2, concept3, mapping3, concept4, mapping4, concept5, - mapping5, concept6, mapping6, concept7, mapping7, concept8, mapping8, concept9, mapping9 + terminology1, terminology2, concept1, mapping1, concept2, mapping2, concept3, mapping3, + concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7, concept8, + mapping8, concept9, mapping9 ]) From 29855f56a033d29fff191cb80454422b19d58578 Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Tue, 20 Aug 2024 12:26:13 +0200 Subject: [PATCH 7/8] refactor: minor changes --- tests/test_weaviate_repository.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index d63f8ab..f5c0d80 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -7,10 +7,10 @@ class Test(TestCase): - #@unittest.skip("currently broken on github workflows") + @unittest.skip("currently broken on github workflows") def test_repository(self): - repository = WeaviateRepository(mode="remote", path="http://localhost:8080") + repository = WeaviateRepository(mode="disk", path="db") embedding_model1 = MPNetAdapter() embedding_model2 = MPNetAdapter("FremyCompany/BioLORD-2023") @@ -102,7 +102,7 @@ def test_repository(self): self.assertEqual(closest_mappings_with_similarities[0][1], 0.6747197) # check if it crashed (due to schema re-creation) after restart - repository = WeaviateRepository(mode="remote", path="http://localhost:8080") + repository = WeaviateRepository(mode="disk", path="db") # try to store all again (should not create new entries since they already exist) repository.store_all([ From 104557b8f15bc814c0d22e3667e32ccd9ac357bc Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Wed, 21 Aug 2024 13:38:43 +0200 Subject: [PATCH 8/8] refactor: fix typo --- datastew/repository/weaviate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index d537795..0842e7f 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -66,7 +66,7 @@ def get_all_sentence_embedders(self) -> List[str]: for item in result['data']['Get']['Mapping']: sentence_embedders.add(item["hasSentenceEmbedder"]) except Exception as e: - raise RuntimeError(f"Failed to fetch terminologies: {e}") + raise RuntimeError(f"Failed to fetch sentence embedders: {e}") return list(sentence_embedders) def get_all_concepts(self) -> List[Concept]: