From 41525a487b66dba1a2d007b666f2223fc88a3a28 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:39:49 +0530 Subject: [PATCH] fix: Introduced new index_document() to fix chunking related issue (#96) * Introduced new index_document() to fix chunking related issue * Minor deprecated version correction --- src/unstract/sdk/__init__.py | 2 +- src/unstract/sdk/index.py | 21 +++++++-------------- src/unstract/sdk/vector_db.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/unstract/sdk/__init__.py b/src/unstract/sdk/__init__.py index 14cf8252..7f22c00f 100644 --- a/src/unstract/sdk/__init__.py +++ b/src/unstract/sdk/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.47.0" +__version__ = "0.48.0" def get_sdk_version(): diff --git a/src/unstract/sdk/index.py b/src/unstract/sdk/index.py index c82029b1..bbd88eb4 100644 --- a/src/unstract/sdk/index.py +++ b/src/unstract/sdk/index.py @@ -4,7 +4,7 @@ from deprecated import deprecated from llama_index.core import Document -from llama_index.core.node_parser import SimpleNodeParser +from llama_index.core.node_parser import SentenceSplitter from llama_index.core.vector_stores import ( FilterOperator, MetadataFilter, @@ -199,7 +199,8 @@ def index( self.tool.stream_log(f"No nodes found for {doc_id}") except Exception as e: self.tool.stream_log( - f"Error querying {vector_db_instance_id}: {e}", level=LogLevel.ERROR + f"Error querying {vector_db_instance_id}: {e}, proceeding to index", + level=LogLevel.ERROR, ) if doc_id_found and reindex: @@ -288,7 +289,7 @@ def index( try: if chunk_size == 0: - parser = SimpleNodeParser.from_defaults( + parser = SentenceSplitter.from_defaults( chunk_size=len(documents[0].text) + 10, chunk_overlap=0, callback_manager=embedding.get_callback_manager(), @@ -301,12 +302,6 @@ def index( vector_db.add(doc_id, nodes=[node]) self.tool.stream_log("Added node to vector db") else: - storage_context = vector_db.get_storage_context() - parser = SimpleNodeParser.from_defaults( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - callback_manager=embedding.get_callback_manager(), - ) self.tool.stream_log("Adding nodes to vector db...") # TODO: Phase 2: # Post insertion to VDB, use query using doc_id and @@ -318,13 +313,11 @@ def index( # Once this is in place, the overridden implementation # of prefixing ids with doc_id before adding to VDB # can be removed - vector_db.get_vector_store_index_from_storage_context( + vector_db.index_document( documents, - storage_context=storage_context, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, show_progress=True, - embed_model=embedding, - node_parser=parser, - callback_manager=embedding.get_callback_manager(), ) except Exception as e: self.tool.stream_log( diff --git a/src/unstract/sdk/vector_db.py b/src/unstract/sdk/vector_db.py index 4189236f..67098bcb 100644 --- a/src/unstract/sdk/vector_db.py +++ b/src/unstract/sdk/vector_db.py @@ -5,6 +5,7 @@ from deprecated import deprecated from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.indices.base import IndexType +from llama_index.core.node_parser import SentenceSplitter from llama_index.core.schema import BaseNode, Document from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -119,6 +120,33 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]: ) raise VectorDBError(f"Error getting vectorDB instance: {e}") from e + def index_document( + self, + documents: Sequence[Document], + chunk_size: int = 1024, + chunk_overlap: int = 128, + show_progress: bool = False, + **index_kwargs, + ) -> IndexType: + if not self._embedding_instance: + raise VectorDBError(self.EMBEDDING_INSTANCE_ERROR) + storage_context = self.get_storage_context() + parser = SentenceSplitter.from_defaults( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + callback_manager=self._embedding_instance.callback_manager, + ) + return VectorStoreIndex.from_documents( + documents, + storage_context=storage_context, + show_progress=show_progress, + embed_model=self._embedding_instance, + transformations=[parser], + callback_manager=self._embedding_instance.callback_manager, + **index_kwargs, + ) + + @deprecated(version="0.47.0", reason="Use index_document() instead") def get_vector_store_index_from_storage_context( self, documents: Sequence[Document],