-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 58b4abd
Showing
22 changed files
with
477 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
.vscode/ | ||
__pycache__/ | ||
venv/ | ||
|
||
assets/ | ||
|
||
docs/ | ||
models/ | ||
temp/ | ||
|
||
config.env | ||
err.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
### Setup | ||
- Install Dependencies with `pip install -r requirements.txt`. | ||
- Rename `config.env.sample` to `config.env` and fill the vars. | ||
- There's an unexpected issue with mongodb (check [this](https://www.mongodb.com/community/forums/t/error-connecting-to-search-index-management-service/270272)) that wouldn't let us create index programmatically, so we need to create a vector search index manually through atlas console. Follow [this guide](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) to create the index for `vector_store` collection with this config below. | ||
|
||
``` | ||
{ | ||
"type": "vector", | ||
"path": "embedding", | ||
"numDimensions": 768, | ||
"similarity": "cosine" | ||
} | ||
``` | ||
### Usage | ||
- Run `python ./main.py` | ||
- Head over to `http://localhost:8080/docs` | ||
### Models Used | ||
- sentence-transformers/all-mpnet-base-v2 | ||
It is used for embedding vectors and this will run locally. Initally it will download the model into `models` directory. | ||
- TinyLlama/TinyLlama-1.1B-Chat-v1.0 | ||
It is the main LLM and being used through the HuggingFace Hub Inference API. | ||
> Note: Since it is using Inference API and a model locally, the setup would be too slow 😓. On my side, it took more than 2 mins exactly to add a document of 45+ pages to vector store and almost 1 min to process the messages with LLM. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .utils import ensure_envs | ||
|
||
ensure_envs(["MONGO_URL", "HUGGINGFACEHUB_API_TOKEN"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from langchain.chains.combine_documents import create_stuff_documents_chain | ||
from langchain.chains.history_aware_retriever import create_history_aware_retriever | ||
from langchain.chains.retrieval import create_retrieval_chain | ||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | ||
from langchain_core.runnables import RunnableWithMessageHistory | ||
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | ||
|
||
from .database import get_retriever, get_session_history | ||
from .defaults import LLM_NAME | ||
|
||
system_prompt = ( | ||
"You are an assistant for question-answering tasks. " | ||
"Only use the following pieces of retrieved context to answer " | ||
"the question. If you don't know the answer, say that you " | ||
"don't know. Use three sentences maximum and keep the " | ||
"answer concise." | ||
"\n\n" | ||
"{context}" | ||
) | ||
|
||
prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", system_prompt), | ||
MessagesPlaceholder("chat_history"), | ||
("human", "{input}"), | ||
] | ||
) | ||
|
||
contextualize_q_system_prompt = ( | ||
"Given a chat history and the latest user question " | ||
"which might reference context in the chat history, " | ||
"formulate a standalone question which can be understood " | ||
"without the chat history. Do NOT answer the question, " | ||
"just reformulate it if needed and otherwise return it as is." | ||
) | ||
|
||
contextualize_q_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", contextualize_q_system_prompt), | ||
MessagesPlaceholder("chat_history"), | ||
("human", "{input}"), | ||
] | ||
) | ||
|
||
llm = ChatHuggingFace( | ||
llm=HuggingFaceEndpoint( | ||
repo_id=LLM_NAME, | ||
temperature=0.1, | ||
) # type: ignore | ||
) | ||
|
||
|
||
history_aware_retriever = create_history_aware_retriever( | ||
llm, get_retriever(), contextualize_q_prompt | ||
) | ||
|
||
question_answer_chain = create_stuff_documents_chain(llm, prompt) | ||
|
||
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) | ||
|
||
rag_chain_with_history = RunnableWithMessageHistory( | ||
rag_chain, | ||
get_session_history, | ||
input_messages_key="input", | ||
history_messages_key="chat_history", | ||
output_messages_key="answer", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .chat_history import get_session_history | ||
from .vector_store import get_retriever, vector_store | ||
|
||
__all__ = ("vector_store", "get_retriever", "get_session_history") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
|
||
from langchain_mongodb import MongoDBChatMessageHistory | ||
|
||
from ..defaults import ( | ||
CHAT_HISTORY_COLLECTION_NAME, | ||
CHAT_HISTORY_HISTORY_KEY, | ||
CHAT_HISTORY_SESSION_KEY, | ||
DB_NAME, | ||
) | ||
|
||
|
||
def get_session_history(session_id: str): | ||
history = MongoDBChatMessageHistory( | ||
os.environ["MONGO_URL"], | ||
session_id, | ||
DB_NAME, | ||
CHAT_HISTORY_COLLECTION_NAME, | ||
session_id_key=CHAT_HISTORY_SESSION_KEY, | ||
history_key=CHAT_HISTORY_HISTORY_KEY, | ||
) | ||
return history |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
|
||
from pymongo import MongoClient | ||
from pymongo.collection import Collection | ||
|
||
from ..defaults import DB_NAME | ||
|
||
client = MongoClient(os.environ["MONGO_URL"]) | ||
db = client[DB_NAME] | ||
|
||
|
||
def get_collection(name: str) -> Collection: | ||
return db[name] | ||
|
||
|
||
# useless now since mongodb have issues with index management service | ||
# https://www.mongodb.com/community/forums/t/error-connecting-to-search-index-management-service/270272 | ||
def ensure_index(col: Collection, model: dict): | ||
if col.count_documents({}) == 0: | ||
return False | ||
|
||
index_name = model.get("name") | ||
if index_name is not None: | ||
if len(tuple(col.list_search_indexes(index_name))) == 0: | ||
col.create_search_index(model) | ||
|
||
return True | ||
|
||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from langchain_huggingface import HuggingFaceEmbeddings | ||
from langchain_mongodb import MongoDBAtlasVectorSearch | ||
|
||
from ..defaults import ( | ||
MODEL_DIR, | ||
VECTORSTORE_COLLECTION_NAME, | ||
VECTORSTORE_EMBEDDING_KEY, | ||
VECTORSTORE_INDEX_NAME, | ||
VECTORSTORE_SEARCH_FUNC, | ||
) | ||
from ..logging import logger | ||
from .mongo import get_collection | ||
|
||
log = logger(__name__) | ||
|
||
embeddings = HuggingFaceEmbeddings(cache_folder=MODEL_DIR) | ||
|
||
vector_store = MongoDBAtlasVectorSearch( | ||
collection=get_collection(VECTORSTORE_COLLECTION_NAME), | ||
embedding=embeddings, | ||
index_name=VECTORSTORE_INDEX_NAME, | ||
embedding_key=VECTORSTORE_EMBEDDING_KEY, | ||
relevance_score_fn=VECTORSTORE_SEARCH_FUNC, | ||
) | ||
|
||
|
||
def get_retriever(): | ||
return vector_store.as_retriever(k=3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
CONFIG_FILE = "config.env" | ||
LOGGING_FILE = "err.log" | ||
MODEL_DIR = "models/" | ||
LLM_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" | ||
|
||
DB_NAME = "chatbot" | ||
|
||
VECTORSTORE_COLLECTION_NAME = "vector_store" | ||
VECTORSTORE_INDEX_NAME = "vector_index" | ||
VECTORSTORE_EMBEDDING_KEY = "embedding" | ||
VECTORSTORE_SEARCH_FUNC = "cosine" | ||
|
||
CHAT_HISTORY_COLLECTION_NAME = "chat_history" | ||
CHAT_HISTORY_SESSION_KEY = "sid" | ||
CHAT_HISTORY_HISTORY_KEY = "history" | ||
|
||
SPLITTER_CHUNK_SIZE = 1000 | ||
SPLITTER_CHUNK_OVERLAP = 25 | ||
MAX_READ_LINES_FOR_TEXT_FILE = 40 | ||
|
||
TEMP_DIR = "temp" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .loader import DocumentLoader | ||
|
||
__all__ = ("DocumentLoader",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
|
||
from langchain_community.document_loaders import CSVLoader as RawCSVLoader | ||
|
||
from .utils import get_splitter | ||
|
||
|
||
class CSVLoader: | ||
def __init__(self) -> None: | ||
self.splitter = get_splitter() | ||
|
||
def load(self, path: str): | ||
csv = RawCSVLoader(path) | ||
for row in csv.lazy_load(): | ||
for idx, chunk in enumerate(self.splitter.transform_documents([row])): | ||
chunk.metadata = { | ||
"source": ( | ||
f"{os.path.basename(chunk.metadata['source'])}" | ||
f":{chunk.metadata['row']}" | ||
f":{idx}" | ||
) | ||
} | ||
yield chunk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from ..exceptions import LoaderNotSupported | ||
from ..utils import chunk_iter | ||
from .csv import CSVLoader | ||
from .pdf import PDFLoader | ||
from .text import TextLoader | ||
|
||
LOADERS = { | ||
"application/pdf": PDFLoader, | ||
"text/csv": CSVLoader, | ||
"text/plain": TextLoader, | ||
} | ||
|
||
|
||
class DocumentLoader: | ||
|
||
def __init__(self, loader) -> None: | ||
self._loader = loader | ||
|
||
@classmethod | ||
def by_type(cls, mime_type: str): | ||
if mime_type is None or mime_type not in LOADERS: | ||
raise LoaderNotSupported(f"Files with mime type {mime_type} are not supported...") | ||
|
||
return cls(LOADERS[mime_type]()) | ||
|
||
def load(self, path: str, num_docs=20): | ||
for docs in chunk_iter(self._loader.load(path), num_docs): | ||
yield list(filter(lambda doc: len(doc.page_content.strip()) > 0, docs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
|
||
from langchain_community.document_loaders import PyPDFLoader | ||
|
||
from .utils import get_splitter | ||
|
||
|
||
class PDFLoader: | ||
|
||
def __init__(self) -> None: | ||
self.splitter = get_splitter() | ||
|
||
def load(self, path: str): | ||
pdf = PyPDFLoader(path) | ||
for page in pdf.lazy_load(): | ||
for idx, chunk in enumerate(self.splitter.transform_documents([page])): | ||
chunk.metadata = { | ||
"source": ( | ||
f"{os.path.basename(chunk.metadata['source'])}" | ||
f":{chunk.metadata['page']}" | ||
f":{idx}" | ||
) | ||
} | ||
yield chunk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import os | ||
|
||
from .utils import get_splitter | ||
|
||
|
||
class TextLoader: | ||
def __init__(self) -> None: | ||
self.splitter = get_splitter() | ||
|
||
def load(self, path: str): | ||
with open(path) as f: | ||
src_file = os.path.basename(path) | ||
count = 0 | ||
while data := f.read(self.splitter._chunk_size * 2): | ||
splitted_text = self.splitter.split_text(data) | ||
for idx, chunk in enumerate(self.splitter.create_documents(splitted_text)): | ||
if len(chunk.page_content.strip()) > 0: | ||
chunk.metadata = {"source": f"{src_file}:{2 * count + idx}"} | ||
yield chunk | ||
count += 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from langchain_text_splitters import RecursiveCharacterTextSplitter | ||
|
||
from ..defaults import SPLITTER_CHUNK_OVERLAP, SPLITTER_CHUNK_SIZE | ||
|
||
|
||
def get_splitter(): | ||
return RecursiveCharacterTextSplitter( | ||
chunk_size=SPLITTER_CHUNK_SIZE, | ||
chunk_overlap=SPLITTER_CHUNK_OVERLAP, | ||
length_function=len, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
class ChatBotException(Exception): | ||
pass | ||
|
||
|
||
class LoaderNotSupported(ChatBotException): | ||
pass | ||
|
||
|
||
class MimeTypeInvalid(ChatBotException): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import mimetypes | ||
from typing import Optional | ||
|
||
from .chain import rag_chain_with_history | ||
from .database import vector_store | ||
from .document_loader import DocumentLoader | ||
from .exceptions import MimeTypeInvalid | ||
from .logging import logger | ||
|
||
log = logger(__name__) | ||
|
||
|
||
def add_document(path: str, mime: Optional[str] = None): | ||
if mime is None: | ||
mime = mimetypes.guess_type(path)[0] | ||
if mime is None: | ||
raise MimeTypeInvalid("Invalid Mimetype...") | ||
|
||
loader = DocumentLoader.by_type(mime) | ||
for docs in loader.load(path): | ||
vector_store.add_documents(docs) | ||
log.info(f"Added {len(docs)} docs from path {path}.") | ||
|
||
log.info(f"Added all docs for path {path}") | ||
|
||
|
||
def ask_question(session_id: str, question: str): | ||
res = rag_chain_with_history.invoke( | ||
{"input": question}, config={"configurable": {"session_id": session_id}} | ||
) | ||
return res["answer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import logging | ||
|
||
from .defaults import LOGGING_FILE | ||
|
||
logging.basicConfig( | ||
level="INFO", | ||
filename=LOGGING_FILE, | ||
style="{", | ||
format="{asctime} - {levelname}({levelno}) : {filename}(Line {lineno}) : {message}", | ||
) | ||
|
||
|
||
def logger(name=None): | ||
return logging.getLogger(name) |
Oops, something went wrong.