Skip to content

Commit

Permalink
Merge pull request #15 from SCAI-BIO/extend-weaviate-schema
Browse files Browse the repository at this point in the history
Extend repository schema to include sentence embedder
  • Loading branch information
tiadams authored Aug 21, 2024
2 parents a1db41e + 104557b commit fa768b9
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 61 deletions.
29 changes: 20 additions & 9 deletions datastew/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@ def get_embedding(self, text: str) -> [float]:
def get_embeddings(self, messages: [str]) -> [[float]]:
pass

def get_model_name(self) -> str:
    """Return the identifier of the embedding model; overridden by the concrete adapters below."""
    pass

def sanitize(self, message: str) -> str:
    """Normalize text before embedding: drop surrounding whitespace and lowercase it."""
    trimmed = message.strip()
    return trimmed.lower()


class GPT4Adapter(EmbeddingModel):
def __init__(self, api_key: str, model_name: str = "text-embedding-ada-002"):
    """Create an adapter for the OpenAI embeddings API.

    :param api_key: OpenAI API key; also installed as the module-global `openai.api_key`.
    :param model_name: embedding model to request (defaults to text-embedding-ada-002).
    """
    self.api_key = api_key
    # NOTE(review): this mutates the global openai.api_key, so the most recently
    # constructed adapter wins if several are created with different keys.
    openai.api_key = api_key
    self.model_name = model_name
    logging.getLogger().setLevel(logging.INFO)

def get_embedding(self, text: str, model="text-embedding-ada-002"):
def get_embedding(self, text: str):
logging.info(f"Getting embedding for {text}")
try:
if text is None or text == "" or text is np.nan:
Expand All @@ -31,29 +35,33 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"):
if isinstance(text, str):
text = text.replace("\n", " ")
text = self.sanitize(text)
return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"]
return openai.Embedding.create(input=[text], model=self.model_name)["data"][0]["embedding"]
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str], max_length=2048):
    """Embed a batch of messages via the OpenAI API, chunking requests.

    :param messages: texts to embed; each is sanitized (stripped/lowercased) first.
    :param max_length: maximum number of inputs sent per API call.
    :return: embedding vectors in the same order as `messages`.
    """
    sanitized_messages = [self.sanitize(message) for message in messages]
    embeddings = []
    # Ceiling division: number of API calls needed to cover all messages.
    total_chunks = (len(sanitized_messages) + max_length - 1) // max_length
    current_chunk = 0
    for i in range(0, len(sanitized_messages), max_length):
        current_chunk += 1
        chunk = sanitized_messages[i:i + max_length]
        response = openai.Embedding.create(input=chunk, model=self.model_name)
        embeddings.extend([item["embedding"] for item in response["data"]])
        logging.info("Processed chunk %d/%d", current_chunk, total_chunks)
    return embeddings

def get_model_name(self) -> str:
    """Name of the OpenAI embedding model this adapter requests."""
    current_model = self.model_name
    return current_model


class MPNetAdapter(EmbeddingModel):
def __init__(self, model_name="sentence-transformers/all-mpnet-base-v2"):
    """Create an adapter around a local SentenceTransformer model.

    :param model_name: model identifier passed straight to SentenceTransformer.
    """
    logging.getLogger().setLevel(logging.INFO)
    self.model = SentenceTransformer(model_name)
    # Recorded so repositories (e.g. Weaviate) can store which model produced an embedding.
    self.model_name = model_name

def get_embedding(self, text: str):
logging.info(f"Getting embedding for {text}")
Expand All @@ -64,19 +72,22 @@ def get_embedding(self, text: str):
if isinstance(text, str):
text = text.replace("\n", " ")
text = self.sanitize(text)
return self.mpnet_model.encode(text)
return self.model.encode(text)
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str]) -> [[float]]:
    """Embed a batch of messages with the SentenceTransformer model.

    :param messages: texts to embed; each is sanitized (stripped/lowercased) first.
    :return: embeddings as plain lists of floats (one row per message).
    :raises Exception: re-raises whatever `encode` raised, after logging.
    """
    sanitized_messages = [self.sanitize(message) for message in messages]
    try:
        embeddings = self.model.encode(sanitized_messages)
    except Exception:
        # Previously the error was only logged, leaving `embeddings` unbound and
        # crashing below with a confusing NameError; re-raise the real cause instead.
        logging.error(f"Failed for messages {sanitized_messages}")
        raise
    flattened_embeddings = [[float(element) for element in row] for row in embeddings]
    return flattened_embeddings

def get_model_name(self) -> str:
    """Return the sentence-transformers model identifier used by this adapter."""
    return self.model_name


class TextEmbedding:
Expand Down
3 changes: 2 additions & 1 deletion datastew/process/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ def process(self):
else:
descriptions.append(term["label"])
embeddings = self.embedding_model.get_embeddings(descriptions)
model_name = self.embedding_model.get_model_name()
for identifier, label, description, embedding in zip(identifiers, labels, descriptions, embeddings):
concept = Concept(self.terminology, label, identifier)
mapping = Mapping(concept, description, embedding)
mapping = Mapping(concept, description, embedding, model_name)
self.repository.store(concept)
self.repository.store(mapping)
except Exception as e:
Expand Down
6 changes: 5 additions & 1 deletion datastew/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ def get_all_terminologies(self) -> List[Terminology]:
pass

@abstractmethod
def get_all_mappings(self, limit=1000) -> List[Mapping]:
    """Return stored mappings, up to `limit` entries."""
    pass

@abstractmethod
def get_all_sentence_embedders(self) -> List[str]:
    """Return the distinct sentence-embedder (model) names stored in the repository."""
    pass

@abstractmethod
def get_closest_mappings(self, embedding, limit=5):
"""Get the closest mappings based on embedding."""
Expand Down
7 changes: 4 additions & 3 deletions datastew/repository/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import numpy as np
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, declarative_base

Base = declarative_base()

Expand Down Expand Up @@ -42,13 +41,15 @@ class Mapping(Base):
concept = relationship("Concept")
text = Column(Text)
embedding_json = Column(Text)
sentence_embedder = Column(Text)

def __init__(self, concept: Concept, text: str, embedding: list, sentence_embedder: str) -> None:
    """Create a mapping between a concept and a text plus its embedding.

    :param concept: the Concept this text maps to.
    :param text: the source text that was embedded.
    :param embedding: embedding vector (list or numpy array).
    :param sentence_embedder: name of the model that produced the embedding.
    """
    self.concept = concept
    self.text = text
    # numpy arrays are not JSON-serializable; convert to a plain list first.
    if isinstance(embedding, np.ndarray):
        embedding = embedding.tolist()
    self.embedding_json = json.dumps(embedding)  # Store embedding as JSON
    self.sentence_embedder = sentence_embedder

def __str__(self):
    # Human-readable summary: terminology > concept id : preferred label | mapped text.
    return f"{self.concept.terminology.name} > {self.concept.concept_identifier} : {self.concept.pref_label} | {self.text}"
Expand Down
4 changes: 3 additions & 1 deletion datastew/repository/sqllite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import random
import sqlite3

import numpy as np

Expand Down Expand Up @@ -51,6 +50,9 @@ def get_all_mappings(self, limit=1000):
# Query for mappings corresponding to the random indices
mappings = self.session.query(Mapping).filter(Mapping.id.in_(random_indices)).all()
return mappings

def get_all_sentence_embedders(self) -> List[str]:
    """Return the distinct sentence-embedder names present among stored mappings."""
    rows = self.session.query(Mapping.sentence_embedder).distinct().all()
    # Each row is a 1-tuple; pull out its single column.
    return [row[0] for row in rows]

def get_closest_mappings(self, embedding: List[float], limit=5):
mappings = self.session.query(Mapping).all()
Expand Down
84 changes: 76 additions & 8 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
import shutil
from typing import List, Union, Tuple
import uuid as uuid
from typing import List, Tuple, Union

import weaviate
from weaviate.embedded import EmbeddedOptions

from datastew.repository import Mapping, Terminology, Concept
from datastew.repository import Concept, Mapping, Terminology
from datastew.repository.base import BaseRepository
from datastew.repository.weaviate_schema import terminology_schema, concept_schema, mapping_schema
from datastew.repository.weaviate_schema import concept_schema, mapping_schema, terminology_schema


class WeaviateRepository(BaseRepository):
Expand Down Expand Up @@ -58,6 +59,16 @@ def store_all(self, model_object_instances):
for instance in model_object_instances:
self.store(instance)

def get_all_sentence_embedders(self) -> List[str]:
    """Collect the distinct sentence-embedder names recorded on Mapping objects."""
    try:
        response = self.client.query.get("Mapping", ["hasSentenceEmbedder"]).do()
        records = response['data']['Get']['Mapping']
        unique_embedders = {record["hasSentenceEmbedder"] for record in records}
    except Exception as e:
        raise RuntimeError(f"Failed to fetch sentence embedders: {e}")
    return list(unique_embedders)

def get_all_concepts(self) -> List[Concept]:
concepts = []
try:
Expand Down Expand Up @@ -101,6 +112,7 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]:
result = self.client.query.get(
"Mapping",
["text",
"hasSentenceEmbedder",
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
).with_additional("vector").with_limit(limit).do()
for item in result['data']['Get']['Mapping']:
Expand All @@ -120,7 +132,8 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]:
mapping = Mapping(
text=item["text"],
concept=concept,
embedding=embedding_vector
embedding=embedding_vector,
sentence_embedder=item["hasSentenceEmbedder"]
)
mappings.append(mapping)
except Exception as e:
Expand All @@ -133,7 +146,8 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]:
result = self.client.query.get(
"Mapping",
["text", "_additional { distance }",
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }",
"hasSentenceEmbedder"]
).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do()
for item in result['data']['Get']['Mapping']:
embedding_vector = item["_additional"]["vector"]
Expand All @@ -152,7 +166,8 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]:
mapping = Mapping(
text=item["text"],
concept=concept,
embedding=embedding_vector
embedding=embedding_vector,
sentence_embedder=item["hasSentenceEmbedder"]
)
mappings.append(mapping)
except Exception as e:
Expand All @@ -165,7 +180,8 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup
result = self.client.query.get(
"Mapping",
["text", "_additional { distance }",
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"]
"hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }",
"hasSentenceEmbedder"]
).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do()
for item in result['data']['Get']['Mapping']:
similarity = 1 - item["_additional"]["distance"]
Expand All @@ -185,12 +201,63 @@ def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tup
mapping = Mapping(
text=item["text"],
concept=concept,
embedding=embedding_vector
embedding=embedding_vector,
sentence_embedder=item["hasSentenceEmbedder"]
)
mappings_with_similarities.append((mapping, similarity))
except Exception as e:
raise RuntimeError(f"Failed to fetch closest mappings with similarities: {e}")
return mappings_with_similarities

def get_terminology_and_model_specific_closest_mappings(self, embedding, terminology_name: str, sentence_embedder_name: str, limit: int = 5) -> List[Tuple[Mapping, float]]:
    """Return up to `limit` (Mapping, similarity) pairs nearest to `embedding`,
    restricted to one terminology and one sentence-embedder model.

    :param embedding: query vector used for the near-vector search.
    :param terminology_name: only mappings whose concept belongs to this terminology are considered.
    :param sentence_embedder_name: only mappings produced by this embedder are considered.
    :param limit: maximum number of results.
    :raises RuntimeError: if the Weaviate query fails for any reason.
    """
    mappings_with_similarities = []
    try:
        # Filter (where) on both the embedder name and the terminology name of the
        # referenced concept, then rank by vector proximity to `embedding`.
        result = self.client.query.get(
            "Mapping",
            ["text",
             "_additional { distance }",
             "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }",
             "hasSentenceEmbedder"]
        ).with_where({
            "operator": "And",
            "operands": [
                {
                    "path": ["hasSentenceEmbedder"],
                    "operator": "Equal",
                    "valueText": sentence_embedder_name
                },
                {
                    # Follow the cross-reference from Mapping to its Concept's Terminology.
                    "path": ["hasConcept", "Concept", "hasTerminology", "Terminology", "name"],
                    "operator": "Equal",
                    "valueText": terminology_name
                }
            ]
        }).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do()
        for item in result['data']['Get']['Mapping']:
            # Weaviate reports a distance; convert to a similarity score as 1 - distance.
            similarity = 1 - item["_additional"]["distance"]
            embedding_vector = item["_additional"]["vector"]
            concept_data = item["hasConcept"][0]  # Assuming it has only one concept
            terminology_data = concept_data["hasTerminology"][0]
            # Rebuild the domain objects from the raw GraphQL payload.
            terminology = Terminology(
                name=terminology_data["name"],
                id=terminology_data["_additional"]["id"]
            )
            concept = Concept(
                concept_identifier=concept_data["conceptID"],
                pref_label=concept_data["prefLabel"],
                terminology=terminology,
                id=concept_data["_additional"]["id"]
            )
            mapping = Mapping(
                text=item["text"],
                concept=concept,
                embedding=embedding_vector,
                sentence_embedder=item["hasSentenceEmbedder"]
            )
            mappings_with_similarities.append((mapping, similarity))
    except Exception as e:
        raise RuntimeError(f"Failed to fetch the closest mappings for terminology {terminology_name} and model {sentence_embedder_name}: {e}")
    return mappings_with_similarities

def shut_down(self):
if self.mode == "memory":
Expand Down Expand Up @@ -238,6 +305,7 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping]):
if not self._mapping_exists(model_object_instance.embedding):
properties = {
"text": model_object_instance.text,
"hasSentenceEmbedder": model_object_instance.sentence_embedder
}
self.client.data_object.create(
class_name="Mapping",
Expand Down
4 changes: 4 additions & 0 deletions datastew/repository/weaviate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
"name": "text",
"dataType": ["string"]
},
{
"name": "hasSentenceEmbedder",
"dataType": ["string"]
},
{
"name": "vector",
"dataType": ["number[]"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class TestEmbedding(unittest.TestCase):

def setUp(self):
    # Build a fresh adapter per test, using the renamed `model_name` keyword.
    self.mpnet_adapter = MPNetAdapter(model_name="sentence-transformers/all-mpnet-base-v2")

def test_mpnet_adapter_get_embedding(self):
text = "This is a test sentence."
Expand Down
20 changes: 17 additions & 3 deletions tests/test_sql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,26 @@ def tearDown(self):

def test_get_closest_mappings(self):
    """Nearest-neighbour search should rank the mapping closest to the query first."""
    terminology = Terminology(name="Terminology 1", id="1")
    model_name = "sentence-transformers/all-mpnet-base-v2"
    concept = Concept(terminology=terminology, pref_label="Concept 1", concept_identifier="1")
    mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=model_name)
    mapping_2 = Mapping(concept=concept, text="Text 2", embedding=[0.2, 0.3, 0.4], sentence_embedder=model_name)
    mapping_3 = Mapping(concept=concept, text="Text 3", embedding=[1.2, 2.3, 3.4], sentence_embedder=model_name)
    self.repository.store_all([terminology, concept, mapping_1, mapping_2, mapping_3])
    sample_embedding = [0.2, 0.4, 0.35]
    closest_mappings, distances = self.repository.get_closest_mappings(sample_embedding, limit=3)
    self.assertEqual(len(closest_mappings), 3)
    # mapping_2's vector is nearest to the sample embedding.
    self.assertEqual(mapping_2.text, closest_mappings[0].text)

def test_get_all_sentence_embedders(self):
    """Distinct embedder names from stored mappings should all be reported."""
    terminology = Terminology(name="Terminology 1", id="1")
    model_name_1 = "sentence-transformers/all-mpnet-base-v2"
    model_name_2 = "text-embedding-ada-002"
    concept = Concept(terminology=terminology, pref_label="Concept 1", concept_identifier="1")
    mapping_1 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=model_name_1)
    mapping_2 = Mapping(concept=concept, text="Text 1", embedding=[0.1, 0.2, 0.3], sentence_embedder=model_name_2)
    self.repository.store_all([terminology, concept, mapping_1, mapping_2])
    sentence_embedders = self.repository.get_all_sentence_embedders()
    self.assertEqual(len(sentence_embedders), 2)
    # DISTINCT makes no ordering guarantee, so compare as an unordered
    # collection instead of asserting fixed indices.
    self.assertCountEqual(sentence_embedders, [model_name_1, model_name_2])
7 changes: 4 additions & 3 deletions tests/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ def test_mapping_storage_and_closest_retrieval(self):
terminology = Terminology("test", "test")
concept1 = Concept(terminology, "cat", "TEST:1")
concept1_description = "The cat is sitting on the mat."
mapping1 = Mapping(concept1, concept1_description, self.embedding_model.get_embedding(concept1_description))
sentence_embedder = "test"
mapping1 = Mapping(concept1, concept1_description, self.embedding_model.get_embedding(concept1_description), sentence_embedder=sentence_embedder)
concept2 = Concept(terminology, "sunrise", "TEST:2")
concept2_description = "The sun rises in the east."
mapping2 = Mapping(concept2, concept2_description, self.embedding_model.get_embedding(concept2_description))
mapping2 = Mapping(concept2, concept2_description, self.embedding_model.get_embedding(concept2_description), sentence_embedder=sentence_embedder)
concept3 = Concept(terminology, "dog", "TEST:3")
concept3_description = "A loyal companion to humans."
mapping3 = Mapping(concept3, concept3_description, self.embedding_model.get_embedding(concept3_description))
mapping3 = Mapping(concept3, concept3_description, self.embedding_model.get_embedding(concept3_description), sentence_embedder=sentence_embedder)
self.repository.store_all([terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3])
# test new mappings
text1 = "A furry feline rests on the rug."
Expand Down
Loading

0 comments on commit fa768b9

Please sign in to comment.