From 58b4abdffe3791473a6c07231c0a633093241b41 Mon Sep 17 00:00:00 2001 From: Harsh <65716674+Harsh-br0@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:22:33 +0530 Subject: [PATCH] Initial Commit --- .gitignore | 12 +++++ README.md | 29 ++++++++++++ chatbot/__init__.py | 3 ++ chatbot/chain.py | 67 ++++++++++++++++++++++++++ chatbot/database/__init__.py | 4 ++ chatbot/database/chat_history.py | 22 +++++++++ chatbot/database/mongo.py | 29 ++++++++++++ chatbot/database/vector_store.py | 28 +++++++++++ chatbot/defaults.py | 21 +++++++++ chatbot/document_loader/__init__.py | 3 ++ chatbot/document_loader/csv.py | 23 +++++++++ chatbot/document_loader/loader.py | 28 +++++++++++ chatbot/document_loader/pdf.py | 24 ++++++++++ chatbot/document_loader/text.py | 20 ++++++++ chatbot/document_loader/utils.py | 11 +++++ chatbot/exceptions.py | 10 ++++ chatbot/functions.py | 31 ++++++++++++ chatbot/logging.py | 14 ++++++ chatbot/utils.py | 26 +++++++++++ config.env.sample | 2 + main.py | 70 ++++++++++++++++++++++++++++ requirements.txt | Bin 0 -> 3152 bytes 22 files changed, 477 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 chatbot/__init__.py create mode 100644 chatbot/chain.py create mode 100644 chatbot/database/__init__.py create mode 100644 chatbot/database/chat_history.py create mode 100644 chatbot/database/mongo.py create mode 100644 chatbot/database/vector_store.py create mode 100644 chatbot/defaults.py create mode 100644 chatbot/document_loader/__init__.py create mode 100644 chatbot/document_loader/csv.py create mode 100644 chatbot/document_loader/loader.py create mode 100644 chatbot/document_loader/pdf.py create mode 100644 chatbot/document_loader/text.py create mode 100644 chatbot/document_loader/utils.py create mode 100644 chatbot/exceptions.py create mode 100644 chatbot/functions.py create mode 100644 chatbot/logging.py create mode 100644 chatbot/utils.py create mode 100644 config.env.sample create mode 100644 main.py 
create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1796d21 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +.vscode/ +__pycache__/ +venv/ + +assets/ + +docs/ +models/ +temp/ + +config.env +err.log \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e984381 --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +### Setup +- Install Dependencies with `pip install -r requirements.txt`. +- Rename `config.env.sample` to `config.env` and fill in the variables. +- There's an unexpected issue with MongoDB (check [this](https://www.mongodb.com/community/forums/t/error-connecting-to-search-index-management-service/270272)) that prevents us from creating the index programmatically, so we need to create a vector search index manually through the Atlas console. Follow [this guide](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) to create the index for the `vector_store` collection with the config below. + + ``` + { + "type": "vector", + "path": "embedding", + "numDimensions": 768, + "similarity": "cosine" + } + ``` + +### Usage +- Run `python ./main.py` +- Head over to `http://localhost:8080/docs` + +### Models Used +- sentence-transformers/all-mpnet-base-v2 + + It is used for embedding vectors and runs locally. Initially, it will download the model into the `models` directory. + +- TinyLlama/TinyLlama-1.1B-Chat-v1.0 + + It is the main LLM and is used through the HuggingFace Hub Inference API. + + +> Note: Since it is using the Inference API and a local model, the setup is quite slow 😓. On my side, it took more than 2 minutes to add a document of 45+ pages to the vector store and almost 1 minute to process the messages with the LLM. 
\ No newline at end of file diff --git a/chatbot/__init__.py b/chatbot/__init__.py new file mode 100644 index 0000000..8c76228 --- /dev/null +++ b/chatbot/__init__.py @@ -0,0 +1,3 @@ +from .utils import ensure_envs + +ensure_envs(["MONGO_URL", "HUGGINGFACEHUB_API_TOKEN"]) diff --git a/chatbot/chain.py b/chatbot/chain.py new file mode 100644 index 0000000..27daffa --- /dev/null +++ b/chatbot/chain.py @@ -0,0 +1,67 @@ +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain.chains.history_aware_retriever import create_history_aware_retriever +from langchain.chains.retrieval import create_retrieval_chain +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableWithMessageHistory +from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint + +from .database import get_retriever, get_session_history +from .defaults import LLM_NAME + +system_prompt = ( + "You are an assistant for question-answering tasks. " + "Only use the following pieces of retrieved context to answer " + "the question. If you don't know the answer, say that you " + "don't know. Use three sentences maximum and keep the " + "answer concise." + "\n\n" + "{context}" +) + +prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) + +contextualize_q_system_prompt = ( + "Given a chat history and the latest user question " + "which might reference context in the chat history, " + "formulate a standalone question which can be understood " + "without the chat history. Do NOT answer the question, " + "just reformulate it if needed and otherwise return it as is." 
+) + +contextualize_q_prompt = ChatPromptTemplate.from_messages( + [ + ("system", contextualize_q_system_prompt), + MessagesPlaceholder("chat_history"), + ("human", "{input}"), + ] +) + +llm = ChatHuggingFace( + llm=HuggingFaceEndpoint( + repo_id=LLM_NAME, + temperature=0.1, + ) # type: ignore +) + + +history_aware_retriever = create_history_aware_retriever( + llm, get_retriever(), contextualize_q_prompt +) + +question_answer_chain = create_stuff_documents_chain(llm, prompt) + +rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) + +rag_chain_with_history = RunnableWithMessageHistory( + rag_chain, + get_session_history, + input_messages_key="input", + history_messages_key="chat_history", + output_messages_key="answer", +) diff --git a/chatbot/database/__init__.py b/chatbot/database/__init__.py new file mode 100644 index 0000000..722eaac --- /dev/null +++ b/chatbot/database/__init__.py @@ -0,0 +1,4 @@ +from .chat_history import get_session_history +from .vector_store import get_retriever, vector_store + +__all__ = ("vector_store", "get_retriever", "get_session_history") diff --git a/chatbot/database/chat_history.py b/chatbot/database/chat_history.py new file mode 100644 index 0000000..87d1839 --- /dev/null +++ b/chatbot/database/chat_history.py @@ -0,0 +1,22 @@ +import os + +from langchain_mongodb import MongoDBChatMessageHistory + +from ..defaults import ( + CHAT_HISTORY_COLLECTION_NAME, + CHAT_HISTORY_HISTORY_KEY, + CHAT_HISTORY_SESSION_KEY, + DB_NAME, +) + + +def get_session_history(session_id: str): + history = MongoDBChatMessageHistory( + os.environ["MONGO_URL"], + session_id, + DB_NAME, + CHAT_HISTORY_COLLECTION_NAME, + session_id_key=CHAT_HISTORY_SESSION_KEY, + history_key=CHAT_HISTORY_HISTORY_KEY, + ) + return history diff --git a/chatbot/database/mongo.py b/chatbot/database/mongo.py new file mode 100644 index 0000000..a0638bb --- /dev/null +++ b/chatbot/database/mongo.py @@ -0,0 +1,29 @@ +import os + +from pymongo import 
MongoClient +from pymongo.collection import Collection + +from ..defaults import DB_NAME + +client = MongoClient(os.environ["MONGO_URL"]) +db = client[DB_NAME] + + +def get_collection(name: str) -> Collection: + return db[name] + + +# useless now since mongodb have issues with index management service +# https://www.mongodb.com/community/forums/t/error-connecting-to-search-index-management-service/270272 +def ensure_index(col: Collection, model: dict): + if col.count_documents({}) == 0: + return False + + index_name = model.get("name") + if index_name is not None: + if len(tuple(col.list_search_indexes(index_name))) == 0: + col.create_search_index(model) + + return True + + return False diff --git a/chatbot/database/vector_store.py b/chatbot/database/vector_store.py new file mode 100644 index 0000000..3452de6 --- /dev/null +++ b/chatbot/database/vector_store.py @@ -0,0 +1,28 @@ +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_mongodb import MongoDBAtlasVectorSearch + +from ..defaults import ( + MODEL_DIR, + VECTORSTORE_COLLECTION_NAME, + VECTORSTORE_EMBEDDING_KEY, + VECTORSTORE_INDEX_NAME, + VECTORSTORE_SEARCH_FUNC, +) +from ..logging import logger +from .mongo import get_collection + +log = logger(__name__) + +embeddings = HuggingFaceEmbeddings(cache_folder=MODEL_DIR) + +vector_store = MongoDBAtlasVectorSearch( + collection=get_collection(VECTORSTORE_COLLECTION_NAME), + embedding=embeddings, + index_name=VECTORSTORE_INDEX_NAME, + embedding_key=VECTORSTORE_EMBEDDING_KEY, + relevance_score_fn=VECTORSTORE_SEARCH_FUNC, +) + + +def get_retriever(): + return vector_store.as_retriever(k=3) diff --git a/chatbot/defaults.py b/chatbot/defaults.py new file mode 100644 index 0000000..c02c7c0 --- /dev/null +++ b/chatbot/defaults.py @@ -0,0 +1,21 @@ +CONFIG_FILE = "config.env" +LOGGING_FILE = "err.log" +MODEL_DIR = "models/" +LLM_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + +DB_NAME = "chatbot" + +VECTORSTORE_COLLECTION_NAME = "vector_store" 
+VECTORSTORE_INDEX_NAME = "vector_index" +VECTORSTORE_EMBEDDING_KEY = "embedding" +VECTORSTORE_SEARCH_FUNC = "cosine" + +CHAT_HISTORY_COLLECTION_NAME = "chat_history" +CHAT_HISTORY_SESSION_KEY = "sid" +CHAT_HISTORY_HISTORY_KEY = "history" + +SPLITTER_CHUNK_SIZE = 1000 +SPLITTER_CHUNK_OVERLAP = 25 +MAX_READ_LINES_FOR_TEXT_FILE = 40 + +TEMP_DIR = "temp" diff --git a/chatbot/document_loader/__init__.py b/chatbot/document_loader/__init__.py new file mode 100644 index 0000000..7acaddc --- /dev/null +++ b/chatbot/document_loader/__init__.py @@ -0,0 +1,3 @@ +from .loader import DocumentLoader + +__all__ = ("DocumentLoader",) diff --git a/chatbot/document_loader/csv.py b/chatbot/document_loader/csv.py new file mode 100644 index 0000000..f590d13 --- /dev/null +++ b/chatbot/document_loader/csv.py @@ -0,0 +1,23 @@ +import os + +from langchain_community.document_loaders import CSVLoader as RawCSVLoader + +from .utils import get_splitter + + +class CSVLoader: + def __init__(self) -> None: + self.splitter = get_splitter() + + def load(self, path: str): + csv = RawCSVLoader(path) + for row in csv.lazy_load(): + for idx, chunk in enumerate(self.splitter.transform_documents([row])): + chunk.metadata = { + "source": ( + f"{os.path.basename(chunk.metadata['source'])}" + f":{chunk.metadata['row']}" + f":{idx}" + ) + } + yield chunk diff --git a/chatbot/document_loader/loader.py b/chatbot/document_loader/loader.py new file mode 100644 index 0000000..679d59a --- /dev/null +++ b/chatbot/document_loader/loader.py @@ -0,0 +1,28 @@ +from ..exceptions import LoaderNotSupported +from ..utils import chunk_iter +from .csv import CSVLoader +from .pdf import PDFLoader +from .text import TextLoader + +LOADERS = { + "application/pdf": PDFLoader, + "text/csv": CSVLoader, + "text/plain": TextLoader, +} + + +class DocumentLoader: + + def __init__(self, loader) -> None: + self._loader = loader + + @classmethod + def by_type(cls, mime_type: str): + if mime_type is None or mime_type not in LOADERS: + 
raise LoaderNotSupported(f"Files with mime type {mime_type} are not supported...") + + return cls(LOADERS[mime_type]()) + + def load(self, path: str, num_docs=20): + for docs in chunk_iter(self._loader.load(path), num_docs): + yield list(filter(lambda doc: len(doc.page_content.strip()) > 0, docs)) diff --git a/chatbot/document_loader/pdf.py b/chatbot/document_loader/pdf.py new file mode 100644 index 0000000..d5b4df9 --- /dev/null +++ b/chatbot/document_loader/pdf.py @@ -0,0 +1,24 @@ +import os + +from langchain_community.document_loaders import PyPDFLoader + +from .utils import get_splitter + + +class PDFLoader: + + def __init__(self) -> None: + self.splitter = get_splitter() + + def load(self, path: str): + pdf = PyPDFLoader(path) + for page in pdf.lazy_load(): + for idx, chunk in enumerate(self.splitter.transform_documents([page])): + chunk.metadata = { + "source": ( + f"{os.path.basename(chunk.metadata['source'])}" + f":{chunk.metadata['page']}" + f":{idx}" + ) + } + yield chunk diff --git a/chatbot/document_loader/text.py b/chatbot/document_loader/text.py new file mode 100644 index 0000000..edc09e3 --- /dev/null +++ b/chatbot/document_loader/text.py @@ -0,0 +1,20 @@ +import os + +from .utils import get_splitter + + +class TextLoader: + def __init__(self) -> None: + self.splitter = get_splitter() + + def load(self, path: str): + with open(path) as f: + src_file = os.path.basename(path) + count = 0 + while data := f.read(self.splitter._chunk_size * 2): + splitted_text = self.splitter.split_text(data) + for idx, chunk in enumerate(self.splitter.create_documents(splitted_text)): + if len(chunk.page_content.strip()) > 0: + chunk.metadata = {"source": f"{src_file}:{2 * count + idx}"} + yield chunk + count += 1 diff --git a/chatbot/document_loader/utils.py b/chatbot/document_loader/utils.py new file mode 100644 index 0000000..35f2fcd --- /dev/null +++ b/chatbot/document_loader/utils.py @@ -0,0 +1,11 @@ +from langchain_text_splitters import 
RecursiveCharacterTextSplitter + +from ..defaults import SPLITTER_CHUNK_OVERLAP, SPLITTER_CHUNK_SIZE + + +def get_splitter(): + return RecursiveCharacterTextSplitter( + chunk_size=SPLITTER_CHUNK_SIZE, + chunk_overlap=SPLITTER_CHUNK_OVERLAP, + length_function=len, + ) diff --git a/chatbot/exceptions.py b/chatbot/exceptions.py new file mode 100644 index 0000000..532f063 --- /dev/null +++ b/chatbot/exceptions.py @@ -0,0 +1,10 @@ +class ChatBotException(Exception): + pass + + +class LoaderNotSupported(ChatBotException): + pass + + +class MimeTypeInvalid(ChatBotException): + pass diff --git a/chatbot/functions.py b/chatbot/functions.py new file mode 100644 index 0000000..4f2e750 --- /dev/null +++ b/chatbot/functions.py @@ -0,0 +1,31 @@ +import mimetypes +from typing import Optional + +from .chain import rag_chain_with_history +from .database import vector_store +from .document_loader import DocumentLoader +from .exceptions import MimeTypeInvalid +from .logging import logger + +log = logger(__name__) + + +def add_document(path: str, mime: Optional[str] = None): + if mime is None: + mime = mimetypes.guess_type(path)[0] + if mime is None: + raise MimeTypeInvalid("Invalid Mimetype...") + + loader = DocumentLoader.by_type(mime) + for docs in loader.load(path): + vector_store.add_documents(docs) + log.info(f"Added {len(docs)} docs from path {path}.") + + log.info(f"Added all docs for path {path}") + + +def ask_question(session_id: str, question: str): + res = rag_chain_with_history.invoke( + {"input": question}, config={"configurable": {"session_id": session_id}} + ) + return res["answer"] diff --git a/chatbot/logging.py b/chatbot/logging.py new file mode 100644 index 0000000..0e83738 --- /dev/null +++ b/chatbot/logging.py @@ -0,0 +1,14 @@ +import logging + +from .defaults import LOGGING_FILE + +logging.basicConfig( + level="INFO", + filename=LOGGING_FILE, + style="{", + format="{asctime} - {levelname}({levelno}) : {filename}(Line {lineno}) : {message}", +) + + +def 
logger(name=None): + return logging.getLogger(name) diff --git a/chatbot/utils.py b/chatbot/utils.py new file mode 100644 index 0000000..4bc66c0 --- /dev/null +++ b/chatbot/utils.py @@ -0,0 +1,26 @@ +import os +from typing import Sequence + +import dotenv + +from .defaults import CONFIG_FILE + + +def ensure_envs(envs: Sequence[str]): + if not dotenv.load_dotenv(CONFIG_FILE): + raise RuntimeError("Failed to load the config file...") + + for env in envs: + if os.getenv(env) is None: + raise RuntimeError(f"env {env} not found...") + + +def chunk_iter(iterator, size=20): + temp = [] + for item in iterator: + temp.append(item) + if len(temp) == size: + yield temp + temp.clear() + if temp: + yield temp diff --git a/config.env.sample b/config.env.sample new file mode 100644 index 0000000..365a04c --- /dev/null +++ b/config.env.sample @@ -0,0 +1,2 @@ +MONGO_URL= +HUGGINGFACEHUB_API_TOKEN= \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..7730817 --- /dev/null +++ b/main.py @@ -0,0 +1,70 @@ +import os +from typing import Optional + +from fastapi import ( + APIRouter, + BackgroundTasks, + FastAPI, + HTTPException, + Request, + UploadFile, + status, +) + +from chatbot.defaults import TEMP_DIR +from chatbot.document_loader import DocumentLoader +from chatbot.exceptions import ChatBotException +from chatbot.functions import add_document, ask_question + +router = APIRouter() + + +def process_file(file: UploadFile): + if not os.path.isdir(TEMP_DIR): + os.makedirs(TEMP_DIR) + new_path = os.path.join(TEMP_DIR, file.filename or "Document.txt") + with open(new_path, "wb") as f: + while data := file.file.read(30): + f.write(data) + + return new_path + + +def post_process_file(path: str, mime: Optional[str]): + add_document(path, mime) + os.remove(path) + + +@router.post("/addDocument") +def add_doc_route(file: UploadFile, tasks: BackgroundTasks): + # document check + DocumentLoader.by_type(file.content_type) # type: ignore + path = 
process_file(file) + tasks.add_task(post_process_file, path, file.content_type) + return {"msg": "Document is being processed..."} + + +@router.get("/ask") +def ask_question_route(session_id: str, question: str): + session_id, question = map(str.strip, [session_id, question]) + if session_id == "" or question == "": + raise HTTPException(status.HTTP_400_BAD_REQUEST, "Query params are empty.") + answer = ask_question(session_id, question) + return {"output": answer} + + +app = FastAPI( + title="Simple RAG Server", +) +app.include_router(router, prefix="/api") + + +@app.exception_handler(ChatBotException) +async def unicorn_exception_handler(request: Request, exc: ChatBotException): + raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, str(exc)) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="localhost", port=8080) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a168f29d4873098b367a96114afdf5a286fe51e8 GIT binary patch literal 3152 zcmaKuOK;m&5QXo$Kz|B@N=lq`(M7gJ1GLDht3cMnl4FsID7#jFeB1NQ;qcOlQV_Ic z&HK!mx%c0HmStbMvMbAS)N?7DcyIOpgPy+pqGw)SmDhTIDnFJVB(cT0oMnN0A=ymQ zSLMw_qnF-YX~Pp*Gs&%KFRf&qd@S{D%DOcAyvXBbh>E|rdNXP(DUtC-Hq(rH$=b3C zANz8UkFGq*U+Us(G2}+xr?E1;P@ki$WO(tZ z1+#tmJA7bEwY>Ae3hAnR4!`g8JBYj;WSwj~#a@OVtjU?|ZbD-nI4x!Wpl^M|?iKr9 z-#TTG7B{2MwSMiRr)TAxI@F5JEWB>YvpjC3Ma}5kDF?7~XUPx?SFnx9bnG~wLyj{& z=-4B;J;)=~WUcP{8|fIoZNvblR(gA7=}r-iiqjKS!iQB^7hh8M=wcoqQ zFVUHWfA@MPeHgq_rg%Yz46rqiQ{y>g{63#BawnAT^y~vQumQZ%| zjhB{2=X;jfZFjU{bxr?zmopk3yjQN-9+V~Zb?SR%1l@TT(|65={yeK9W6E^EQhH`2 zGv}8TvWZURNq;l5I<*<47hZTXNAN=?KkI>KAJmJJb`z8TQQxfC5BA-C(-8lQsY$yN8$U87%USALNVvG7*o9TjmKI3V-%rR`A_cdD7~ z3fs>5@1P3sWK7{l^_6ExBcGhwbmJ+svkQsCMN*$PjXY+bL6DPm>fg-!U6uS<-iv|1 zg)~`UhEqRW+oyyZZx=59@TN@d7=QW;aG8{QZJ z&KPbmaH&@wMCZnnIe1qmD!sJW%ACu-JK?=@;bZwjxW89+Y+?4ka{v`81{sj^jf1Vk z`G1OPyqU;2w}cVAu48YN@7&p5PUI}|ay(bd)YbitEtr{=u#A2#)vcpy;|2wn&0#8> 
z>UQb__~reqZqu1IV50T)6nKg=BaqHBmWu0ipT(-RssRgry@}o8J@$XSGFSNPPWuLI zjyKBHH#pB|a-M7x0q)Dz{-o+SH>mODmb7>Bbsdk