Initial Commit
Harsh-br0 committed Aug 26, 2024
0 parents commit 58b4abd
Showing 22 changed files with 477 additions and 0 deletions.
12 changes: 12 additions & 0 deletions .gitignore
.vscode/
__pycache__/
venv/

assets/

docs/
models/
temp/

config.env
err.log
29 changes: 29 additions & 0 deletions README.md
### Setup
- Install dependencies with `pip install -r requirements.txt`.
- Rename `config.env.sample` to `config.env` and fill in the variables.
- There's an unexpected issue with MongoDB (see [this thread](https://www.mongodb.com/community/forums/t/error-connecting-to-search-index-management-service/270272)) that prevents creating the search index programmatically, so the vector search index has to be created manually through the Atlas console. Follow [this guide](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) to create the index for the `vector_store` collection with the config below (`numDimensions` is 768 to match the all-mpnet-base-v2 embedding size).

```
{
"type": "vector",
"path": "embedding",
"numDimensions": 768,
"similarity": "cosine"
}
```
### Usage
- Run `python ./main.py`.
- Head over to `http://localhost:8080/docs` for the interactive API docs.
### Models Used
- sentence-transformers/all-mpnet-base-v2
  It generates the embedding vectors and runs locally. On first run it downloads the model into the `models` directory.
- TinyLlama/TinyLlama-1.1B-Chat-v1.0
  It is the main LLM and is used through the Hugging Face Hub Inference API.
> Note: Since it uses the Inference API together with a local embedding model, the setup is quite slow 😓. On my side, it took just over 2 minutes to add a 45+ page document to the vector store and almost 1 minute to process a message with the LLM.
3 changes: 3 additions & 0 deletions chatbot/__init__.py
from .utils import ensure_envs

ensure_envs(["MONGO_URL", "HUGGINGFACEHUB_API_TOKEN"])
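
`ensure_envs` is defined in `chatbot/utils.py`, which isn't visible in this view. A minimal sketch of what it presumably does (hypothetical; the real helper may differ, and this assumes `python-dotenv` is available to load `config.env`):

```
import os

from dotenv import load_dotenv


def ensure_envs(names):
    # Hypothetical sketch: load config.env, then fail fast on any missing var.
    load_dotenv("config.env")
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {missing}")
```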
67 changes: 67 additions & 0 deletions chatbot/chain.py
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

from .database import get_retriever, get_session_history
from .defaults import LLM_NAME

system_prompt = (
"You are an assistant for question-answering tasks. "
"Only use the following pieces of retrieved context to answer "
"the question. If you don't know the answer, say that you "
"don't know. Use three sentences maximum and keep the "
"answer concise."
"\n\n"
"{context}"
)

prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

contextualize_q_system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)

contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)

llm = ChatHuggingFace(
llm=HuggingFaceEndpoint(
repo_id=LLM_NAME,
temperature=0.1,
) # type: ignore
)


history_aware_retriever = create_history_aware_retriever(
llm, get_retriever(), contextualize_q_prompt
)

question_answer_chain = create_stuff_documents_chain(llm, prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

rag_chain_with_history = RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)
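
For reference, the history-aware chain is invoked with a per-session config; this mirrors `ask_question` in `chatbot/functions.py` later in this commit:

```
# Assumes MONGO_URL and HUGGINGFACEHUB_API_TOKEN are set.
res = rag_chain_with_history.invoke(
    {"input": "What is this document about?"},
    config={"configurable": {"session_id": "demo-session"}},
)
print(res["answer"])
```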
4 changes: 4 additions & 0 deletions chatbot/database/__init__.py
from .chat_history import get_session_history
from .vector_store import get_retriever, vector_store

__all__ = ("vector_store", "get_retriever", "get_session_history")
22 changes: 22 additions & 0 deletions chatbot/database/chat_history.py
import os

from langchain_mongodb import MongoDBChatMessageHistory

from ..defaults import (
CHAT_HISTORY_COLLECTION_NAME,
CHAT_HISTORY_HISTORY_KEY,
CHAT_HISTORY_SESSION_KEY,
DB_NAME,
)


def get_session_history(session_id: str):
history = MongoDBChatMessageHistory(
os.environ["MONGO_URL"],
session_id,
DB_NAME,
CHAT_HISTORY_COLLECTION_NAME,
session_id_key=CHAT_HISTORY_SESSION_KEY,
history_key=CHAT_HISTORY_HISTORY_KEY,
)
return history
29 changes: 29 additions & 0 deletions chatbot/database/mongo.py
import os

from pymongo import MongoClient
from pymongo.collection import Collection

from ..defaults import DB_NAME

client = MongoClient(os.environ["MONGO_URL"])
db = client[DB_NAME]


def get_collection(name: str) -> Collection:
return db[name]


# Unused for now since MongoDB has an issue with its index management service:
# https://www.mongodb.com/community/forums/t/error-connecting-to-search-index-management-service/270272
def ensure_index(col: Collection, model: dict):
    # Nothing to index yet; skip empty collections.
    if col.count_documents({}) == 0:
        return False

    index_name = model.get("name")
    if index_name is not None:
        # Create the search index only if one with this name doesn't exist yet.
        if len(tuple(col.list_search_indexes(index_name))) == 0:
            col.create_search_index(model)

        return True

    return False
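
For reference, the `model` dict that `ensure_index` expects would look roughly like this; a hedged sketch matching the README's index config (the exact shape accepted by `create_search_index` can vary across PyMongo versions):

```
model = {
    "name": "vector_index",  # VECTORSTORE_INDEX_NAME in defaults.py
    "type": "vectorSearch",
    "definition": {
        "fields": [
            {
                "type": "vector",
                "path": "embedding",
                "numDimensions": 768,  # all-mpnet-base-v2 embedding size
                "similarity": "cosine",
            }
        ]
    },
}
ensure_index(get_collection("vector_store"), model)
```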
28 changes: 28 additions & 0 deletions chatbot/database/vector_store.py
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mongodb import MongoDBAtlasVectorSearch

from ..defaults import (
MODEL_DIR,
VECTORSTORE_COLLECTION_NAME,
VECTORSTORE_EMBEDDING_KEY,
VECTORSTORE_INDEX_NAME,
VECTORSTORE_SEARCH_FUNC,
)
from ..logging import logger
from .mongo import get_collection

log = logger(__name__)

embeddings = HuggingFaceEmbeddings(cache_folder=MODEL_DIR)

vector_store = MongoDBAtlasVectorSearch(
collection=get_collection(VECTORSTORE_COLLECTION_NAME),
embedding=embeddings,
index_name=VECTORSTORE_INDEX_NAME,
embedding_key=VECTORSTORE_EMBEDDING_KEY,
relevance_score_fn=VECTORSTORE_SEARCH_FUNC,
)


def get_retriever():
    # `k` must go through search_kwargs; VectorStoreRetriever has no bare `k` field.
    return vector_store.as_retriever(search_kwargs={"k": 3})
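
A quick way to sanity-check the store directly (assuming documents have already been added; the query string is made up):

```
docs = vector_store.similarity_search("example query", k=3)
for doc in docs:
    print(doc.metadata.get("source"), doc.page_content[:80])
```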
21 changes: 21 additions & 0 deletions chatbot/defaults.py
CONFIG_FILE = "config.env"
LOGGING_FILE = "err.log"
MODEL_DIR = "models/"
LLM_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

DB_NAME = "chatbot"

VECTORSTORE_COLLECTION_NAME = "vector_store"
VECTORSTORE_INDEX_NAME = "vector_index"
VECTORSTORE_EMBEDDING_KEY = "embedding"
VECTORSTORE_SEARCH_FUNC = "cosine"

CHAT_HISTORY_COLLECTION_NAME = "chat_history"
CHAT_HISTORY_SESSION_KEY = "sid"
CHAT_HISTORY_HISTORY_KEY = "history"

SPLITTER_CHUNK_SIZE = 1000
SPLITTER_CHUNK_OVERLAP = 25
MAX_READ_LINES_FOR_TEXT_FILE = 40

TEMP_DIR = "temp"
3 changes: 3 additions & 0 deletions chatbot/document_loader/__init__.py
from .loader import DocumentLoader

__all__ = ("DocumentLoader",)
23 changes: 23 additions & 0 deletions chatbot/document_loader/csv.py
import os

from langchain_community.document_loaders import CSVLoader as RawCSVLoader

from .utils import get_splitter


class CSVLoader:
def __init__(self) -> None:
self.splitter = get_splitter()

def load(self, path: str):
csv = RawCSVLoader(path)
for row in csv.lazy_load():
for idx, chunk in enumerate(self.splitter.transform_documents([row])):
chunk.metadata = {
"source": (
f"{os.path.basename(chunk.metadata['source'])}"
f":{chunk.metadata['row']}"
f":{idx}"
)
}
yield chunk
28 changes: 28 additions & 0 deletions chatbot/document_loader/loader.py
from ..exceptions import LoaderNotSupported
from ..utils import chunk_iter
from .csv import CSVLoader
from .pdf import PDFLoader
from .text import TextLoader

LOADERS = {
"application/pdf": PDFLoader,
"text/csv": CSVLoader,
"text/plain": TextLoader,
}


class DocumentLoader:

def __init__(self, loader) -> None:
self._loader = loader

@classmethod
def by_type(cls, mime_type: str):
if mime_type is None or mime_type not in LOADERS:
raise LoaderNotSupported(f"Files with mime type {mime_type} are not supported...")

return cls(LOADERS[mime_type]())

def load(self, path: str, num_docs=20):
for docs in chunk_iter(self._loader.load(path), num_docs):
yield list(filter(lambda doc: len(doc.page_content.strip()) > 0, docs))
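
For reference, a minimal driving loop (the file path is hypothetical; `chatbot/functions.py` later in this commit does the same before inserting into the vector store):

```
loader = DocumentLoader.by_type("application/pdf")
for batch in loader.load("docs/sample.pdf"):  # batches of up to 20 non-empty chunks
    print(f"loaded {len(batch)} chunks")
```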
24 changes: 24 additions & 0 deletions chatbot/document_loader/pdf.py
import os

from langchain_community.document_loaders import PyPDFLoader

from .utils import get_splitter


class PDFLoader:

def __init__(self) -> None:
self.splitter = get_splitter()

def load(self, path: str):
pdf = PyPDFLoader(path)
for page in pdf.lazy_load():
for idx, chunk in enumerate(self.splitter.transform_documents([page])):
chunk.metadata = {
"source": (
f"{os.path.basename(chunk.metadata['source'])}"
f":{chunk.metadata['page']}"
f":{idx}"
)
}
yield chunk
20 changes: 20 additions & 0 deletions chatbot/document_loader/text.py
import os

from .utils import get_splitter


class TextLoader:
def __init__(self) -> None:
self.splitter = get_splitter()

    def load(self, path: str):
        with open(path) as f:
            src_file = os.path.basename(path)
            count = 0
            # Read roughly two chunks' worth of text at a time and re-split it.
            while data := f.read(self.splitter._chunk_size * 2):
                splitted_text = self.splitter.split_text(data)
                for idx, chunk in enumerate(self.splitter.create_documents(splitted_text)):
                    if len(chunk.page_content.strip()) > 0:
                        # Ids assume at most two chunks per read of 2 * chunk_size chars.
                        chunk.metadata = {"source": f"{src_file}:{2 * count + idx}"}
                        yield chunk
                count += 1
11 changes: 11 additions & 0 deletions chatbot/document_loader/utils.py
from langchain_text_splitters import RecursiveCharacterTextSplitter

from ..defaults import SPLITTER_CHUNK_OVERLAP, SPLITTER_CHUNK_SIZE


def get_splitter():
return RecursiveCharacterTextSplitter(
chunk_size=SPLITTER_CHUNK_SIZE,
chunk_overlap=SPLITTER_CHUNK_OVERLAP,
length_function=len,
)
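
A tiny illustration of the configured chunking (1000-character chunks with a 25-character overlap):

```
splitter = get_splitter()
chunks = splitter.split_text("x" * 2500)
print([len(c) for c in chunks])  # roughly [1000, 1000, 550] given the 25-char overlap
```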
10 changes: 10 additions & 0 deletions chatbot/exceptions.py
class ChatBotException(Exception):
pass


class LoaderNotSupported(ChatBotException):
pass


class MimeTypeInvalid(ChatBotException):
pass
31 changes: 31 additions & 0 deletions chatbot/functions.py
import mimetypes
from typing import Optional

from .chain import rag_chain_with_history
from .database import vector_store
from .document_loader import DocumentLoader
from .exceptions import MimeTypeInvalid
from .logging import logger

log = logger(__name__)


def add_document(path: str, mime: Optional[str] = None):
if mime is None:
mime = mimetypes.guess_type(path)[0]
if mime is None:
raise MimeTypeInvalid("Invalid Mimetype...")

loader = DocumentLoader.by_type(mime)
for docs in loader.load(path):
vector_store.add_documents(docs)
log.info(f"Added {len(docs)} docs from path {path}.")

log.info(f"Added all docs for path {path}")


def ask_question(session_id: str, question: str):
res = rag_chain_with_history.invoke(
{"input": question}, config={"configurable": {"session_id": session_id}}
)
return res["answer"]
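
A minimal end-to-end sketch (the file path and question are hypothetical):

```
add_document("docs/handbook.pdf")  # mime type inferred from the extension
answer = ask_question("session-1", "What topics does the handbook cover?")
print(answer)
```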
14 changes: 14 additions & 0 deletions chatbot/logging.py
import logging

from .defaults import LOGGING_FILE

logging.basicConfig(
level="INFO",
filename=LOGGING_FILE,
style="{",
format="{asctime} - {levelname}({levelno}) : {filename}(Line {lineno}) : {message}",
)


def logger(name=None):
return logging.getLogger(name)