Skip to content

Commit

Permalink
fix embeddings module
Browse files — browse the repository at this point in the history
  • Loading branch information
azliu0 committed Apr 15, 2024
1 parent b9a99f7 commit 9f1e02d
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 92 deletions.
174 changes: 83 additions & 91 deletions server/nlp/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Embeddings.
This module provides functions for embedding documents and querying them in Redis.
"""

import os
import time

Expand Down Expand Up @@ -31,17 +36,13 @@


def load_corpus(corpus: list[RedisDocument]):
"""Loads given corpus into redis
"""Loads given corpus into redis.
PARAMETERS
----------
corpus : :obj:`list` of :obj:`RedisDocument`
list of documents, each represented by dictionary
Args:
corpus: list of documents, each represented by dictionary
Raises:
------
Exception
if failed to load corpus into redis
Exception: if failed to load corpus into redis
"""
print("loading corpus...")

Expand All @@ -57,6 +58,14 @@ def load_corpus(corpus: list[RedisDocument]):


def compute_openai_embeddings(texts):
"""Compute embeddings from texts using OpenAI.
Args:
texts: list of texts to embed
Returns:
list of embeddings
"""
embeddings = []
for i in range(len(texts)):
embeddings.append(
Expand All @@ -68,7 +77,7 @@ def compute_openai_embeddings(texts):


def compute_embeddings():
"""Compute embeddings from redis documents"""
"""Compute embeddings from redis documents."""
print("computing embeddings...")

# get keys, questions, content
Expand Down Expand Up @@ -97,17 +106,14 @@ def compute_embeddings():


def load_embeddings(embeddings: list[list[float]]):
"""Load embeddings into redis
"""Load embeddings into redis.
PARAMETERS
----------
embeddings : :obj:`list` of :obj:`list` of :obj:`float`
list of embeddings
Args:
embeddings:
list of embeddings
Raises:
------
Exception
if failed to load embeddings into redis
Exception: if failed to load embeddings into redis
"""
print("loading embeddings into redis...")

Expand All @@ -125,18 +131,16 @@ def load_embeddings(embeddings: list[list[float]]):


def create_index(corpus_len: int):
"""Create search index in redis
assumes that documents and embeddings have already been loaded into redis
"""Create search index in redis.
Assumes that documents and embeddings have already been loaded into redis.
PARAMETERS
----------
corpus_len : :obj:`int`
number of documents in corpus
Args:
corpus_len:
number of documents in corpus
Raises:
------
Exception
if failed to create index
Exception: if failed to create index
"""
print("creating index...")

Expand Down Expand Up @@ -178,12 +182,13 @@ def create_index(corpus_len: int):


def create_query(k: int):
"""Create k-NN redis query
"""Create k-NN redis query.
PARAMETERS
----------
k : :obj:`int`
number of nearest neighbors to return
Args:
k: number of nearest neighbors to return
Returns:
redis query object
"""
return (
Query(f"(*)=>[KNN {k} @vector $query_vector AS vector_score]")
Expand All @@ -194,18 +199,13 @@ def create_query(k: int):


def queries(query, queries: list[str]) -> list[dict]:
"""Run queries against redis
"""Run queries against redis.
PARAMETERS
----------
query : :obj:`Query`
redis query object
queries : :obj:`list` of :obj:`str`
list of question queries
Args:
query: redis query object
queries: list of question queries
Returns:
-------
:obj:`list` of :obj:`dict`
list of dictionaries containing query and result
"""
print("running queries...")
Expand Down Expand Up @@ -243,36 +243,27 @@ def queries(query, queries: list[str]) -> list[dict]:


def query_all(k: int, questions: list[str]):
"""Return k most similar documents for each query
"""Return k most similar documents for each query.
PARAMETERS
----------
k : :obj:`int`
number of nearest neighbors to return
questions : :obj:`list` of :obj:`str`
list of question queries
Args:
k: number of nearest neighbors to return
questions: list of question queries
Returns:
-------
:obj:`list` of :obj:`dict`
list of dictionaries containing query and result
"""
redis_query = create_query(k)
return queries(redis_query, questions)


def embed_corpus(corpus: list[RedisDocument]):
"""Load corpus, compute embeddings, load embeddings into redis
"""Load corpus, compute embeddings, load embeddings into redis.
PARAMETERS
----------
corpus : :obj:`list` of :obj:`dict`
list of documents, each represented by dictionary
Args:
corpus: list of documents, each represented by dictionary
Raises:
------
Exception
if failed to load corpus
Exception: if failed to load corpus
"""
# flush database
print("cleaning database...")
Expand All @@ -288,37 +279,38 @@ def embed_corpus(corpus: list[RedisDocument]):
create_index(len(corpus))


def test():
try:
embed_corpus()
except Exception as err:
print(f"Unexpected {err=}, {type(err)=}")
raise

questions = [
"What is the deadline to apply for the hackathon?",
"When is HackMIT?",
"What are the challenges?",
"How does judging work?",
"What building should I go to during the event?",
"What prizes are available?",
"How many people are allowed on a team?",
"What is HackMIT?",
"Can I attend HackMIT if I am an MIT grad student?",
"Can I attend HackMIT if I am a sophomore in high school?",
"I'm a high school student, but I'm really advanced. Can I attend HackMIT?",
"Do I need to bring money to the event?",
"Will we be able to sleep at the event?",
"Will we be able to stay overnight at the event?",
"What should I do if I am a beginner at the event?",
]
results = query_all(3, questions)

for result in results:
print(result["query"])
for doc in result["result"]:
print(f"Score: {doc['score']}")
print(f"Source: {doc['source']}")
print(f"Q: {doc['question']}")
print(f"A: {doc['content']}")
print()
# TODO(azliu): turn this into a test case
# def test():
# try:
# embed_corpus()
# except Exception as err:
# print(f"Unexpected {err=}, {type(err)=}")
# raise

# questions = [
# "What is the deadline to apply for the hackathon?",
# "When is HackMIT?",
# "What are the challenges?",
# "How does judging work?",
# "What building should I go to during the event?",
# "What prizes are available?",
# "How many people are allowed on a team?",
# "What is HackMIT?",
# "Can I attend HackMIT if I am an MIT grad student?",
# "Can I attend HackMIT if I am a sophomore in high school?",
# "I'm a high school student, but I'm really advanced. Can I attend HackMIT?",
# "Do I need to bring money to the event?",
# "Will we be able to sleep at the event?",
# "Will we be able to stay overnight at the event?",
# "What should I do if I am a beginner at the event?",
# ]
# results = query_all(3, questions)

# for result in results:
# print(result["query"])
# for doc in result["result"]:
# print(f"Score: {doc['score']}")
# print(f"Source: {doc['source']}")
# print(f"Q: {doc['question']}")
# print(f"A: {doc['content']}")
# print()
2 changes: 1 addition & 1 deletion server/nlp/responses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Provides responses module.
"""Responses.
This module is used to generate responses to incoming emails using OpenAI.
"""
Expand Down

0 comments on commit 9f1e02d

Please sign in to comment.