diff --git a/python/pathway/xpacks/llm/question_answering.py b/python/pathway/xpacks/llm/question_answering.py index b52ba371..2ca08901 100644 --- a/python/pathway/xpacks/llm/question_answering.py +++ b/python/pathway/xpacks/llm/question_answering.py @@ -307,6 +307,7 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer): A pw.udf function is expected. Defaults to ``pathway.xpacks.llm.prompts.prompt_qa``. summarize_template: Template for text summarization. Defaults to ``pathway.xpacks.llm.prompts.prompt_summarize``. search_topk: Top k parameter for the retrieval. Adjusts number of chunks in the context. + query_rewrite_method: Method for query transformation. Accepts values: 'hyde', 'default', or None. Defaults to None. Example: @@ -357,6 +358,7 @@ def __init__( long_prompt_template: pw.UDF = prompts.prompt_qa, summarize_template: pw.UDF = prompts.prompt_summarize, search_topk: int = 6, + query_rewrite_method: str | None = None, ) -> None: self.llm = llm @@ -372,6 +374,7 @@ def __init__( self.summarize_template = summarize_template self.search_topk = search_topk + self.query_rewrite_method = query_rewrite_method self.server: None | QASummaryRestServer = None self._pending_endpoints: list[tuple] = [] @@ -402,6 +405,15 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table: """Main function for RAG applications that answer questions based on available information.""" + if self.query_rewrite_method == "hyde": + pw_ai_queries += pw_ai_queries.select( + prompt=prompts.prompt_query_rewrite_hyde(pw.this.prompt) + ) + elif self.query_rewrite_method == "default": + pw_ai_queries += pw_ai_queries.select( + prompt=prompts.prompt_query_rewrite(pw.this.prompt) + ) + pw_ai_results = pw_ai_queries + self.indexer.retrieve_query( pw_ai_queries.select( metadata_filter=pw.this.filters, @@ -653,6 +665,7 @@ def __init__( factor: int = 2, max_iterations: int = 4, strict_prompt: bool = False, + query_rewrite_method: str | None = None, ) -> None: super().__init__( llm, @@ -661,6 +674,7 @@ def __init__( short_prompt_template=short_prompt_template, long_prompt_template=long_prompt_template, summarize_template=summarize_template, + query_rewrite_method=query_rewrite_method, ) self.n_starting_documents = n_starting_documents self.factor = factor @@ -677,6 +691,15 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table: else: data_column_name = "text" + if self.query_rewrite_method == "hyde": + pw_ai_queries += pw_ai_queries.select( + prompt=prompts.prompt_query_rewrite_hyde(pw.this.prompt) + ) + elif self.query_rewrite_method == "default": + pw_ai_queries += pw_ai_queries.select( + prompt=prompts.prompt_query_rewrite(pw.this.prompt) + ) + result = pw_ai_queries.select( *pw.this, result=answer_with_geometric_rag_strategy_from_index( diff --git a/python/pathway/xpacks/llm/tests/test_rag.py b/python/pathway/xpacks/llm/tests/test_rag.py index d8593b24..8a1ac753 100644 --- a/python/pathway/xpacks/llm/tests/test_rag.py +++ b/python/pathway/xpacks/llm/tests/test_rag.py @@ -95,3 +95,121 @@ def test_base_rag(): """ ), ) + + +def test_base_rag_with_query_rewrite(): + schema = pw.schema_from_types(data=bytes, _metadata=dict) + input = pw.debug.table_from_rows( + schema=schema, rows=[("foo", {}), ("bar", {}), ("baz", {})] + ) + + vector_server = VectorStoreServer( + input, + embedder=fake_embeddings_model, + ) + + rag = BaseRAGQuestionAnswerer( + IdentityMockChat(), + vector_server, + short_prompt_template=_short_template, + long_prompt_template=_long_template, + summarize_template=_summarize_template, + search_topk=2, + query_rewrite_method="default", + ) + + answer_queries = pw.debug.table_from_rows( + schema=rag.AnswerQuerySchema, + rows=[ + ("foo", None, "gpt3.5", "short"), + ("baz", None, "gpt4", "long"), + ], + ) + + answer_output = rag.answer_query(answer_queries) + assert_table_equality( + answer_output.select(result=pw.this.result), + pw.debug.table_from_markdown( + """ + result + gpt3.5,short,foo,foo,bar + gpt4,long,baz,baz,bar + """ + ), + ) + + summarize_query = pw.debug.table_from_rows( + schema=rag.SummarizeQuerySchema, + rows=[(["foo", "bar"], "gpt2")], + ) + + summarize_outputs = rag.summarize_query(summarize_query) + + assert_table_equality( + summarize_outputs.select(result=pw.this.result), + pw.debug.table_from_markdown( + """ + result + gpt2,summarize,foo,bar + """ + ), + ) + + +def test_base_rag_with_hyde_query_rewrite(): + schema = pw.schema_from_types(data=bytes, _metadata=dict) + input = pw.debug.table_from_rows( + schema=schema, rows=[("foo", {}), ("bar", {}), ("baz", {})] + ) + + vector_server = VectorStoreServer( + input, + embedder=fake_embeddings_model, + ) + + rag = BaseRAGQuestionAnswerer( + IdentityMockChat(), + vector_server, + short_prompt_template=_short_template, + long_prompt_template=_long_template, + summarize_template=_summarize_template, + search_topk=2, + query_rewrite_method="hyde", + ) + + answer_queries = pw.debug.table_from_rows( + schema=rag.AnswerQuerySchema, + rows=[ + ("foo", None, "gpt3.5", "short"), + ("baz", None, "gpt4", "long"), + ], + ) + + answer_output = rag.answer_query(answer_queries) + assert_table_equality( + answer_output.select(result=pw.this.result), + pw.debug.table_from_markdown( + """ + result + gpt3.5,short,foo,foo,bar + gpt4,long,baz,baz,bar + """ + ), + ) + + summarize_query = pw.debug.table_from_rows( + schema=rag.SummarizeQuerySchema, + rows=[(["foo", "bar"], "gpt2")], + ) + + summarize_outputs = rag.summarize_query(summarize_query) + + assert_table_equality( + summarize_outputs.select(result=pw.this.result), + pw.debug.table_from_markdown( + """ + result + gpt2,summarize,foo,bar + """ + ), + )