Skip to content

Commit

Permalink
增加一个配置,默认不启用ES服务
Browse files Browse the repository at this point in the history
  • Loading branch information
deadwalks committed Sep 9, 2024
1 parent 3510f82 commit 310ce62
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 13 deletions.
18 changes: 12 additions & 6 deletions app/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# ES需要导入的库
from typing import List
import logging
import settings

logger = LoggerManager().logger

Expand All @@ -32,14 +33,19 @@ def create_retriever(self):
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

# 创建一个 ES 的 Retriever
es_retriever = ElasticsearchRetriever()
if settings.ELASTIC_ENABLE_USE == True:
# 创建一个 ES 的 Retriever
es_retriever = ElasticsearchRetriever()

# 将集合在一起
ensemble_retriever = EnsembleRetriever(
retrievers=[es_retriever, mq_retriever], weights=[0.5, 0.5])
# 将集合在一起
ensemble_retriever = EnsembleRetriever(
retrievers=[es_retriever, mq_retriever], weights=[0.5, 0.5])

return ensemble_retriever
logger.info(f'使用的检索器类: {ensemble_retriever.__class__.__name__}')
return ensemble_retriever
else:
logger.info(f'使用的检索器类: {mq_retriever.__class__.__name__}')
return mq_retriever


class ElasticsearchRetrieverWrapper(RetrieverBase):
Expand Down
5 changes: 4 additions & 1 deletion app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@
"""
ES数据库相关的配置
"""
# 默认不使用ES服务
ELASTIC_ENABLE_USE = False
ELASTIC_PASSWORD = os.getenv("ELASTIC_PASSWORD", "123abc")
ELASTIC_HOST = os.getenv("ELASTIC_HOST", "175.27.143.233")
ELASTIC_PORT = os.getenv("ELASTIC_PORT", 9200)
ELASTIC_SCHEMA = "https"
ELASTIC_INDEX_NAME = "smart_bot_index"
# ELASTIC_INDEX_NAME = "smart_bot_index"
ELASTIC_INDEX_NAME = "smart_test_index"
11 changes: 5 additions & 6 deletions app/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def test_agent():

# 测试RAG主流程
def test_rag():
from rag.retrievers import MultiQueryRetrieverWrapper
from rag.rag import RagManager
from rag.vector_db import ChromaDB
from rag.retrievers import SimpleRetriever
llm, chat, embed = settings.LLM, settings.CHAT, settings.EMBED

# Chroma的配置
Expand All @@ -39,7 +39,7 @@ def test_rag():

# 多查询检索器
rag_manager = RagManager(vector_db_class=ChromaDB, db_config=db_config, llm=llm, embed=embed,
etriever_cls=MultiQueryRetrieverWrapper)
etriever_cls=SimpleRetriever)


example_query = "湖南长远锂科股份有限公司"
Expand Down Expand Up @@ -319,7 +319,6 @@ def splitFiles(docs):

def test_es_search():
from rag.retrievers import ElasticsearchRetriever
from rag.retrievers import MultiQueryRetrieverWrapper
from rag.rag import RagManager
from rag.vector_db import ChromaDB
llm, chat, embed = settings.LLM, settings.CHAT, settings.EMBED
Expand Down Expand Up @@ -356,12 +355,12 @@ def test_es_search():
# test_import_vector_db()
# test_import_elasticsearch()
# test_agent()
# test_rag()
test_rag()
# test_financebot()
test_financebot_ex()
# test_financebot_ex()
# test_llm_api()
# test_answer_question()
# test_clean_test_result()
# test_es_search()
test_es_search()


0 comments on commit 310ce62

Please sign in to comment.