From c46ee65a014de65360507c95e6954103bd2afbb4 Mon Sep 17 00:00:00 2001
From: Amarnath Mullick
Date: Mon, 16 Dec 2024 23:17:51 +0000
Subject: [PATCH] Few minor changes to the SpannerGraphNodeVectorRetriever

---
 src/langchain_google_spanner/graph_retriever.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/langchain_google_spanner/graph_retriever.py b/src/langchain_google_spanner/graph_retriever.py
index 120805a..1e8a492 100644
--- a/src/langchain_google_spanner/graph_retriever.py
+++ b/src/langchain_google_spanner/graph_retriever.py
@@ -234,10 +234,12 @@ class SpannerGraphNodeVectorRetriever(BaseRetriever):
     embeddings_column: str = "embedding"
     """The name of the column that stores embedding"""
     query_parameters: QueryParameters = QueryParameters()
-    k: int = 10
-    """Number of top results to return"""
+    top_k: int = 3
+    """Number of vector similarity matches to return"""
     graph_expansion_query: str = ""
     """GQL query to expand the returned context"""
+    k: int = 10
+    """Number of graph results to return"""
 
     @classmethod
     def from_params(
@@ -278,9 +280,10 @@ def _get_relevant_documents(
         VECTOR_QUERY = """
             GRAPH {graph_name}
             MATCH ({node_var}:{label_expr})
+            WHERE {node_var}.{embeddings_column} IS NOT NULL
             ORDER BY {distance_fn}({node_var}.{embeddings_column},
                   ARRAY[{query_embeddings}])
-            LIMIT {k}
+            LIMIT {top_k}
         """
         gql_query = VECTOR_QUERY.format(
             graph_name=graph_name,
@@ -289,7 +292,7 @@ def _get_relevant_documents(
             embeddings_column=self.embeddings_column,
             distance_fn=distance_fn,
             query_embeddings=",".join(map(str, query_embeddings)),
-            k=self.k,
+            top_k=self.top_k,
         )
 
         if self.return_properties_list: