Skip to content

Commit

Permalink
Adding support for multiple files
Browse files Browse the repository at this point in the history
  • Loading branch information
Harsh-br0 committed Aug 26, 2024
1 parent 03b0ff9 commit a5667b7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
| Setting CHUNK_SIZE to 400 with k = 4 | `chatbot/defaults.py`:`L13` | Added a constant for k = 4. |
| | `chatbot/defaults.py`:`L19` | Updated the constant from 1000 to 400 |
| | `chatbot/database/vector_store.py`:`L29` | Imported and used the `K` param constant introduced previously. |
| Adding support for multiple files | `main.py`:`L35-69` | Replaced the single-file logic with logic that handles multiple files. |
4 changes: 3 additions & 1 deletion chatbot/document_loader/loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from ..exceptions import LoaderNotSupported
from ..utils import chunk_iter
from .csv import CSVLoader
Expand All @@ -17,7 +19,7 @@ def __init__(self, loader) -> None:
self._loader = loader

@classmethod
def by_type(cls, mime_type: str):
def by_type(cls, mime_type: Optional[str]):
if mime_type is None or mime_type not in LOADERS:
raise LoaderNotSupported(f"Files with mime type {mime_type} are not supported...")

Expand Down
2 changes: 1 addition & 1 deletion chatbot/document_loader/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self) -> None:
self.splitter = get_splitter()

def load(self, path: str):
with open(path) as f:
with open(path, encoding="utf8") as f:
src_file = os.path.basename(path)
count = 0
while data := f.read(self.splitter._chunk_size * 2):
Expand Down
45 changes: 35 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@

from chatbot.defaults import TEMP_DIR
from chatbot.document_loader import DocumentLoader
from chatbot.exceptions import ChatBotException
from chatbot.exceptions import ChatBotException, LoaderNotSupported
from chatbot.functions import add_document, ask_question
from chatbot.logging import logger

# Router that this module's endpoint functions are registered on.
router = APIRouter()
# Module-level logger used by the file-processing helpers below.
log = logger(__name__)


def process_file(file: UploadFile):
Expand All @@ -30,18 +32,41 @@ def process_file(file: UploadFile):
return new_path


def post_process_file(path: str, mime: Optional[str]):
    """Pass the file at ``path`` to ``add_document``, then delete it.

    Args:
        path: Location on disk of the file to process.
        mime: MIME type forwarded to ``add_document``; may be ``None``.

    Raises:
        Whatever ``add_document`` or ``os.remove`` raises; nothing is
        swallowed here.
    """
    try:
        add_document(path, mime)
    finally:
        # Delete the file even when processing fails so it is not left behind.
        os.remove(path)
def process_files(files: list[UploadFile]):
    """Check each upload's type and store the supported ones via ``process_file``.

    A file whose content type is rejected by ``DocumentLoader.by_type`` is
    skipped and its extension is recorded. A file that fails for any other
    reason is logged and skipped without being reported in the result.

    Args:
        files: Uploaded files to check and store.

    Returns:
        A ``(paths, mimes, errs_with_exts)`` tuple. ``paths`` and ``mimes``
        are parallel lists: the value returned by ``process_file`` and the
        content type of every accepted upload. ``errs_with_exts`` is a set
        holding the extension of each unsupported upload, or its base name
        when it has no extension.
    """
    paths, mimes = [], []
    errs_with_exts = set()
    for file in files:
        try:
            # Type check first, so unsupported uploads never reach process_file.
            DocumentLoader.by_type(file.content_type)
            paths.append(process_file(file))
            mimes.append(file.content_type)
        except LoaderNotSupported as e:
            log.exception(e)
            name = os.path.basename(file.filename or ".unknown")
            # Report the extension; splitext yields "" for names such as
            # ".unknown" or "README", in which case keep the name itself.
            errs_with_exts.add(os.path.splitext(name)[1] or name)
        except Exception as e:
            # Best effort: one bad upload must not abort the rest.
            log.exception(e)

    return paths, mimes, errs_with_exts


def post_process_files(paths: list[str], mimes: list[Optional[str]]):
    """Run ``add_document`` on every stored file and delete each one afterwards.

    Failures are logged per file and never propagate, so one bad file does
    not stop the remaining ones from being processed and cleaned up.

    Args:
        paths: Locations on disk of the files to process.
        mimes: MIME types paired with ``paths`` by position.
    """
    for path, mime in zip(paths, mimes):
        try:
            add_document(path, mime)
        except Exception as e:
            log.exception(e)
        finally:
            # A failed delete (e.g. the file is already gone) must not escape
            # the loop, or the remaining files would be skipped.
            try:
                os.remove(path)
            except OSError as e:
                log.exception(e)


@router.post("/addDocument")
def add_doc_route(files: list[UploadFile], tasks: BackgroundTasks):
    """Accept uploaded files and queue the accepted ones for processing.

    Args:
        files: Uploaded files from the request.
        tasks: Background task queue; ``post_process_files`` is scheduled on
            it with the stored paths and their MIME types.

    Returns:
        A dict with a ``"msg"`` entry, plus an ``"error"`` entry listing the
        rejected items when ``process_files`` reported any.
    """
    paths, mimes, errs = process_files(files)
    tasks.add_task(post_process_files, paths, mimes)
    res = {"msg": "Documents are being processed..."}
    if errs:
        # errs is a set; sort it so the response text is deterministic.
        res["error"] = "Following extensions are not supported: " + ", ".join(sorted(errs))
    return res


@router.get("/ask")
Expand Down

0 comments on commit a5667b7

Please sign in to comment.