Skip to content

Commit

Permalink
db fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
vyokky committed Jan 15, 2025
1 parent d95bf4a commit 26d888c
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
4 changes: 3 additions & 1 deletion learner/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def create_indexer(app: str, docs: str, format: str, incremental: bool, save_pat
if incremental:
if app in records:
print_with_color("Merging with previous indexer...", "yellow")
prev_db = FAISS.load_local(records[app], embeddings)
prev_db = FAISS.load_local(
records[app], embeddings, allow_dangerous_deserialization=True
)
db.merge_from(prev_db)

db_file_path = os.path.join(save_path, app)
Expand Down
6 changes: 5 additions & 1 deletion record_processor/summarizer/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,11 @@ def create_or_update_vector_db(summaries: list, db_path: str):

# Check if the db exists, if not, create a new one.
if os.path.exists(db_path):
prev_db = FAISS.load_local(db_path, get_hugginface_embedding())
prev_db = FAISS.load_local(
db_path,
get_hugginface_embedding(),
allow_dangerous_deserialization=True,
)
db.merge_from(prev_db)

db.save_local(db_path)
Expand Down
12 changes: 12 additions & 0 deletions ufo/module/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,15 @@ def to_dict(self) -> Dict[str, Any]:
:return: The dictionary of the context.
"""
return self._context

def from_dict(self, context_dict: Dict[str, Any]) -> None:
"""
Load the context from a dictionary.
:param context_dict: The dictionary of the context.
"""
for key in ContextNames:
if key.name in context_dict:
self._context[key.name] = context_dict.get(key.name)

# Sync the current round step and cost
self._sync_round_values()
16 changes: 13 additions & 3 deletions ufo/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def get_indexer(self, path: str):
return None

try:
db = FAISS.load_local(path, get_hugginface_embedding())
db = FAISS.load_local(
path, get_hugginface_embedding(), allow_dangerous_deserialization=True
)
return db
except:
# print_with_color(
Expand Down Expand Up @@ -142,7 +144,11 @@ def get_indexer(self, db_path: str):
"""

try:
db = FAISS.load_local(db_path, get_hugginface_embedding())
db = FAISS.load_local(
db_path,
get_hugginface_embedding(),
allow_dangerous_deserialization=True,
)
return db
except:
# print_with_color(
Expand Down Expand Up @@ -209,7 +215,11 @@ def get_indexer(self, db_path: str):
"""

try:
db = FAISS.load_local(db_path, get_hugginface_embedding())
db = FAISS.load_local(
db_path,
get_hugginface_embedding(),
allow_dangerous_deserialization=True,
)
return db
except:
# print_with_color(
Expand Down

0 comments on commit 26d888c

Please sign in to comment.