Skip to content

Commit

Permalink
给集成检索器增加上下文重排和上下文压缩
Browse files Browse the repository at this point in the history
  • Loading branch information
deadwalks committed Sep 13, 2024
1 parent c43a167 commit 9a88fa3
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 28 deletions.
141 changes: 114 additions & 27 deletions app/rag/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from utils.logger_config import LoggerManager
from langchain_core.retrievers import BaseRetriever
from langchain.retrievers import EnsembleRetriever
from langchain_core.documents import Document
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from rag.elasticsearch_db import ElasticsearchDB
# ES需要导入的库
from typing import Any, List
Expand All @@ -11,12 +14,11 @@
from langchain_community.document_transformers import (
LongContextReorder,
)

from utils.util import get_rerank_model

logger = LoggerManager().logger



class SimpleRetrieverWrapper():
"""自定义检索器实现"""

Expand All @@ -26,27 +28,53 @@ def __init__(self, store, llm, **kwargs):
logger.info(f'检索器所使用的Chat模型:{self.llm}')

def create_retriever(self):
    """Build the ensemble retriever from the retrievers enabled in settings.

    Steps:
      1. Wrap the vector store's retriever in a ``MultiQueryRetrieverWrapper``.
      2. If ``settings.COMPRESSOR_ENABLE`` is True, wrap that retriever in a
         ``ContextualCompressionRetrieverWrapper`` backed by an
         ``LLMChainExtractor``; otherwise use it directly.
      3. If ``settings.ELASTIC_ENABLE_ES`` is True, add an
         ``ElasticsearchRetriever`` (compressed under the same switch).

    Every enabled retriever is given a weight of 0.5.

    Returns:
        An ``EnsembleRetriever`` over all enabled retrievers. An ensemble is
        returned even when only one retriever is enabled.
    """
    logger.info('初始化自定义的Retriever')

    retrievers = []
    weights = []

    # Step 1: multi-query retriever on top of the vector store.
    chromadb_retriever = self.store.as_retriever()
    logging.basicConfig()
    logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
    mq_retriever = MultiQueryRetrieverWrapper.from_llm(
        retriever=chromadb_retriever, llm=self.llm
    )

    # Step 2: optional contextual compression. A single extractor is shared
    # by every compressed retriever; the previous code built a second
    # extractor for ES but never passed it on.
    compression_enabled = settings.COMPRESSOR_ENABLE is True
    compressor = None
    if compression_enabled:
        compressor = LLMChainExtractor.from_llm(llm=self.llm)
        retrievers.append(
            ContextualCompressionRetrieverWrapper(
                base_compressor=compressor, base_retriever=mq_retriever
            )
        )
        logger.info('已启用 ContextualCompressionRetriever')
    else:
        retrievers.append(mq_retriever)
        logger.info('已启用 MultiQueryRetriever')
    weights.append(0.5)

    # Step 3: optional Elasticsearch retriever, compressed under the same switch.
    if settings.ELASTIC_ENABLE_ES is True:
        es_retriever = ElasticsearchRetriever()
        if compression_enabled:
            retrievers.append(
                ContextualCompressionRetrieverWrapper(
                    base_compressor=compressor, base_retriever=es_retriever
                )
            )
            logger.info('已启用 ES的ContextualCompressionRetriever')
        else:
            retrievers.append(es_retriever)
            logger.info('已启用 ElasticsearchRetriever')
        weights.append(0.5)

    # Combine every enabled retriever into one ensemble.
    return EnsembleRetriever(retrievers=retrievers, weights=weights)

class ElasticsearchRetriever(BaseRetriever):
def _get_relevant_documents(self, query: str, ) -> List[Document]:
Expand All @@ -57,16 +85,15 @@ def _get_relevant_documents(self, query: str, ) -> List[Document]:
# 增加长上下文重排序
reordering = LongContextReorder()
reordered_docs = reordering.transform_documents(query_result)
logger.info(f"检索到的原始文档:")

for poriginal in query_result:
logger.info(f"{poriginal}")
# logger.info(f"ElasticSearch检索到的原始文档:")
# for poriginal in query_result:
# logger.info(f"{poriginal}")

logger.info(f"重新排序后的文档:")
logger.info(f"ElasticSearch检索重排后的文档:")
for preordered in reordered_docs:
logger.info(f"{preordered}")

logger.info(f"检索到资料文件个数{len(query_result)}")
logger.info(f"ElasticSearch检索到资料文件个数{len(query_result)}")

if reordered_docs:
return [Document(page_content=doc) for doc in reordered_docs]
Expand All @@ -79,3 +106,63 @@ async def _aget_relevant_documents(self, query: str) -> List[Document]:
if query_result:
return [Document(page_content=doc) for doc in query_result]
return []



class MultiQueryRetrieverWrapper(MultiQueryRetriever):
    """``MultiQueryRetriever`` variant that logs its work and reorders results."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Retrieve documents for ``query`` via generated query variants.

        Overrides the parent implementation to log the generated queries and
        the retrieved documents, and to apply ``LongContextReorder`` to the
        result.

        Args:
            query: The user query to expand into several search queries.
            run_manager: Callback manager passed through to query generation
                and retrieval.

        Returns:
            The de-duplicated documents, in long-context-reordered order.
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = self.retrieve_documents(queries, run_manager)

        logger.info('MultiQuery生成的检索语句:')
        for generated_query in queries:
            logger.info(f"{generated_query}")
        logger.info('MultiQuery检索到的资料文件:')
        for doc in documents:
            logger.info(f"{doc}")
        logger.info(f"MultiQuery检索到资料文件个数:{len(documents)}")

        # De-duplicate first, then reorder: reordering before the union let
        # duplicates take part in the reorder and let their later removal
        # disturb the resulting layout.
        unique_docs = self.unique_union(documents)
        return list(LongContextReorder().transform_documents(unique_docs))


class ContextualCompressionRetrieverWrapper(ContextualCompressionRetriever):
    """``ContextualCompressionRetriever`` variant that logs the compressed output.

    The ``typing`` names used in the signature are imported at module level;
    a class-body import would leave stray non-annotated attributes on the
    retriever class.
    """

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve documents with the base retriever, then compress them.

        Args:
            query: The query passed to the base retriever and the compressor.
            run_manager: Callback manager; a child manager is handed to both
                the base retriever and the compressor.
            **kwargs: Forwarded unchanged to ``base_retriever.invoke``.

        Returns:
            The compressed documents as a list, or an empty list when the
            base retriever returns nothing.
        """
        docs = self.base_retriever.invoke(
            query, config={"callbacks": run_manager.get_child()}, **kwargs
        )
        if not docs:
            # Nothing retrieved, so there is nothing to compress or log.
            return []

        compressed_docs = self.base_compressor.compress_documents(
            docs, query, callbacks=run_manager.get_child()
        )
        logger.info(f'压缩后的文档长度:{len(compressed_docs)}')
        logger.info(f'压缩后的文档:{compressed_docs}')
        return list(compressed_docs)

8 changes: 7 additions & 1 deletion app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,15 @@
ES数据库相关的配置
"""
# ES service switch: True enables the ES service, False disables it.
# NOTE(review): two switch names are defined here; confirm which one callers
# read and remove the other.
ELASTIC_ENABLE_USE = True
ELASTIC_ENABLE_ES = True
# NOTE(review): the fallback password and host below are hard-coded in source;
# consider requiring them from the environment instead.
ELASTIC_PASSWORD = os.getenv("ELASTIC_PASSWORD", "123abc")
ELASTIC_HOST = os.getenv("ELASTIC_HOST", "175.27.143.233")
# os.getenv returns a str whenever the variable is set, so convert explicitly
# to keep the port an int in both the default and the overridden case.
ELASTIC_PORT = int(os.getenv("ELASTIC_PORT", "9200"))
ELASTIC_SCHEMA = "https"
ELASTIC_INDEX_NAME = "smart_test_index"


"""
COMPRESSOR 检索器相关的配置
"""
COMPRESSOR_ENABLE = True
6 changes: 6 additions & 0 deletions app/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,9 @@ def get_zhipu_chat_model():
api_key="xxxx", )

return chat

def get_rerank_model():
    """Return a Qianfan ``Reranker`` configured for ``bce-reranker-base_v1``.

    ``qianfan`` is imported inside the function, so the package is only
    loaded when a reranker is actually requested.
    """
    from qianfan.resources import Reranker

    return Reranker(model="bce-reranker-base_v1")

0 comments on commit 9a88fa3

Please sign in to comment.