Skip to content

Commit

Permalink
Wire up ANN vector index creation
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 26, 2024
1 parent aeab06b commit dfeab0a
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 17 deletions.
142 changes: 125 additions & 17 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class SecondaryIndex:
index_name: str
columns: list[str]
storing_columns: Optional[list[str]] = None
num_leaves: Optional[int] = None # Only necessary for ANN
num_branches: Optional[int] = None # Only necessary for ANN
tree_depth: Optional[int] = None # Only necessary for ANN
index_type: Optional[DistanceStrategy] = None # Only necessary for ANN

def __post_init__(self):
# Check if column_name is None after initialization
Expand All @@ -109,6 +113,16 @@ class DistanceStrategy(Enum):
APPROX_COSINE = 5
APPROX_EUCLIDEAN = 6

def __str__(self):
    """
    Return the human-readable label for this distance strategy.

    Returns:
    - str: The entry for this member in DISTANCE_STRATEGY_STRING, or the
      member's own name when the table has no entry for it.
    """
    # The table only lists some members (the APPROX_* ones are absent), and
    # str() is used when formatting error messages, so it must never raise.
    return DISTANCE_STRATEGY_STRING.get(self, self.name)


# Labels returned by DistanceStrategy.__str__. Keys must be qualified enum
# members: the bare member names are not defined at module scope. The label
# for the (historically misspelled) EUCLIDEIAN member uses the correct
# spelling, matching the "EUCLIDEAN" value in _GOOGLE_ALGO_INDEX_NAME.
DISTANCE_STRATEGY_STRING = {
    DistanceStrategy.COSINE: "COSINE",
    DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN",
    DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT",
}


class DialectSemantics(ABC):
"""
Expand Down Expand Up @@ -152,6 +166,12 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]:
DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE",
}

# Distance-type names keyed by DistanceStrategy, consulted by
# GoogleSqlSemnatics.getIndexDistanceType. A strategy with no entry here
# makes that lookup raise.
_GOOGLE_ALGO_INDEX_NAME = {
DistanceStrategy.COSINE: "COSINE",
DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT",
DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN",
}


class GoogleSqlSemnatics(DialectSemantics):
"""
Expand All @@ -173,6 +193,12 @@ def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]:
def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]:
    """
    Pair each column with the value at the same position.

    Returns:
    - Dict[str, Any]: Mapping of column to value; pairing stops at the
      shorter of the two inputs.
    """
    return {column: value for column, value in zip(columns, values)}

def getIndexDistanceType(self, distance_strategy) -> str:
    """
    Look up the index distance-type name for a distance strategy.

    Returns:
    - str: The entry for distance_strategy in _GOOGLE_ALGO_INDEX_NAME.

    Raises:
    - Exception: If the strategy has no entry in that table.
    """
    index_distance_name = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy)
    if index_distance_name is not None:
        return index_distance_name
    raise Exception(f"{distance_strategy} is unsupported for distance_type")


_PG_DISTANCE_ALGO_NAMES = {
DistanceStrategy.COSINE: "spanner.cosine_distance",
Expand Down Expand Up @@ -276,6 +302,15 @@ def __init__(
self.staleness = {key: value}


# NOTE(review): not referenced anywhere in the visible part of this change;
# confirm where (or whether) this default is applied.
DEFAULT_ANN_TREE_DEPTH = 2
# The only positive tree_depth values the ANN index DDL generator accepts;
# any other positive value makes it raise.
ANN_ACCEPTABLE_TREE_DEPTHS = (2, 3)


class AlgoKind(Enum):
    """
    Selects which secondary-index DDL generator is used for a table.

    ANN produces CREATE VECTOR INDEX statements; any other value falls
    back to the plain CREATE INDEX (KNN) generator.
    """

    KNN = 0
    ANN = 1


class SpannerVectorStore(VectorStore):
GSQL_TYPES = {
CONTENT_COLUMN_NAME: ["STRING"],
Expand Down Expand Up @@ -306,6 +341,7 @@ def init_vector_store_table(
primary_key: Optional[str] = None,
vector_size: Optional[int] = None,
secondary_indexes: Optional[List[SecondaryIndex]] = None,
kind: AlgoKind = None,
) -> bool:
"""
Initialize the vector store new table in Google Cloud Spanner.
Expand Down Expand Up @@ -344,6 +380,7 @@ def init_vector_store_table(
metadata_columns,
primary_key,
secondary_indexes,
kind=kind,
)

operation = database.update_ddl(ddl)
Expand All @@ -363,6 +400,7 @@ def _generate_sql(
column_configs,
primary_key,
secondary_indexes: Optional[List[SecondaryIndex]] = None,
kind: Optional[AlgoKind] = AlgoKind.KNN,
):
"""
Generate SQL for creating the vector store table.
Expand All @@ -378,6 +416,40 @@ def _generate_sql(
Returns:
- str: The generated SQL.
"""

ddl_statements = [
SpannerVectorStore._generate_create_table_sql(
table_name,
id_column,
content_column,
embedding_column,
column_configs,
primary_key,
dialect,
)
]

if kind == AlgoKind.ANN:
ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN(
table_name, dialect, secondary_indexes
)
else:
ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN(
table_name, embedding_column, dialect, secondary_indexes
)

return ddl_statements

@staticmethod
def _generate_create_table_sql(
table_name,
id_column,
content_column,
embedding_column,
column_configs,
primary_key,
dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
):
create_table_statement = f"CREATE TABLE {table_name} (\n"

if not isinstance(id_column, TableColumn):
Expand Down Expand Up @@ -438,30 +510,66 @@ def _generate_sql(
+ ")"
)

return create_table_statement

@staticmethod
def _generate_secondary_indices_ddl_KNN(
    table_name, embedding_column, dialect, secondary_indexes=None
):
    """
    Generate the CREATE INDEX statements for plain (KNN) secondary indexes.

    Parameters:
    - table_name (str): Table the indexes are created on.
    - embedding_column: Column object whose `name` is always included in
      the stored columns of every index.
    - dialect: Database dialect; POSTGRESQL uses an INCLUDE clause, any
      other dialect uses STORING.
    - secondary_indexes (Optional[List[SecondaryIndex]]): Indexes to create.

    Returns:
    - List[str]: One DDL statement per secondary index; empty when no
      indexes are given.
    """
    if not secondary_indexes:
        return []

    # The dialect does not change per index, so pick the clause keyword once.
    if dialect == DatabaseDialect.POSTGRESQL:
        storing_clause = "INCLUDE ("
    else:
        storing_clause = "STORING ("

    secondary_index_ddl_statements = []
    for secondary_index in secondary_indexes:
        # Work on a copy so the caller's SecondaryIndex is left untouched,
        # and compare by column *name*: storing_columns holds strings, so
        # comparing the column object itself would never match and would
        # append the embedding column a second time.
        storing_columns = list(secondary_index.storing_columns or [])
        if embedding_column.name not in storing_columns:
            storing_columns.append(embedding_column.name)

        statement = (
            f"CREATE INDEX {secondary_index.index_name} ON {table_name}("
            + ",".join(secondary_index.columns)
            + ") "
            + storing_clause
            + ",".join(storing_columns)
            + ")"
        )
        secondary_index_ddl_statements.append(statement)

    return secondary_index_ddl_statements

@staticmethod
def _generate_secondary_indices_ddl_ANN(
    table_name, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, secondary_indexes=None
):
    """
    Generate the CREATE VECTOR INDEX statements for ANN secondary indexes.

    Parameters:
    - table_name (str): Table the vector indexes are created on.
    - dialect: Database dialect; only GOOGLE_STANDARD_SQL is accepted.
    - secondary_indexes (Optional[List[SecondaryIndex]]): Indexes to create.
      Only the first entry of each index's `columns` is used.

    Returns:
    - List[str]: One DDL statement per secondary index; empty when no
      indexes are given.

    Raises:
    - Exception: If the dialect is not GoogleSQL, if an index has no
      index_type, or if a positive tree_depth is not one of
      ANN_ACCEPTABLE_TREE_DEPTHS.
    """
    if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL:
        raise Exception(
            f"ANN is only supported for the GoogleSQL dialect not {dialect}"
        )

    # None (the value callers pass when no indexes were requested) and an
    # empty list both mean "nothing to generate".
    if not secondary_indexes:
        return []

    def _is_positive(value):
        # The numeric options default to None on SecondaryIndex; comparing
        # None with 0 would raise TypeError, so treat None as "not set".
        return value is not None and value > 0

    secondary_index_ddl_statements = []
    for secondary_index in secondary_indexes:
        statement = (
            f"CREATE VECTOR INDEX {secondary_index.index_name}\n"
            f"\tON {table_name}({secondary_index.columns[0]})"
        )

        if secondary_index.index_type is None:
            # Without this check the OPTIONS clause would be rendered with
            # distance_type='None'.
            raise Exception(
                f"index_type is required for ANN index {secondary_index.index_name}"
            )
        options_segments = [f"distance_type='{secondary_index.index_type}'"]

        tree_depth = secondary_index.tree_depth
        if _is_positive(tree_depth):
            if tree_depth not in ANN_ACCEPTABLE_TREE_DEPTHS:
                raise Exception(
                    f"tree_depth: {tree_depth} is not in the acceptable values: {ANN_ACCEPTABLE_TREE_DEPTHS}"
                )
            options_segments.append(f"tree_depth={tree_depth}")

        if _is_positive(secondary_index.num_branches):
            options_segments.append(f"num_branches={secondary_index.num_branches}")

        if _is_positive(secondary_index.num_leaves):
            options_segments.append(f"num_leaves={secondary_index.num_leaves}")

        statement += "\n\tOPTIONS(" + ", ".join(options_segments) + ")"
        secondary_index_ddl_statements.append(statement.strip())

    return secondary_index_ddl_statements

def __init__(
self,
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_vectore_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
DistanceStrategy,
GoogleSqlSemnatics,
PGSqlSemnatics,
SecondaryIndex,
SpannerVectorStore,
)


Expand Down Expand Up @@ -69,3 +71,37 @@ def test_distance_function_raises_exception_if_unknown(self):
for strategy in strategies:
with self.assertRaises(Exception):
sem.getDistanceFunction(strategy)


class TestSpannerVectorStore_KNN(unittest.TestCase):
    """Checks the DDL strings produced by SpannerVectorStore's static generators."""

    def test_generate_create_table_sql(self):
        got = SpannerVectorStore._generate_create_table_sql(
            "users",
            "id",
            "essays",
            "science_scores",
            [],
            "id",
        )
        want = "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX),\n science_scores ARRAY<FLOAT64>\n) PRIMARY KEY(id)"
        self.assertEqual(got, want)

    def test_generate_secondary_indices_ddl_ANN(self):
        got = SpannerVectorStore._generate_secondary_indices_ddl_ANN(
            "Documents",
            secondary_indexes=[
                SecondaryIndex(
                    index_name="DocEmbeddingIndex",
                    columns=["DocEmbedding"],
                    num_branches=1000,
                    tree_depth=3,
                    index_type=DistanceStrategy.COSINE,
                    num_leaves=100000,
                )
            ],
        )
        # The expectation mirrors what the generator emits: "OPTIONS(" with
        # no space before the parenthesis, and the same num_leaves value
        # that was passed in above.
        want = [
            "CREATE VECTOR INDEX DocEmbeddingIndex\n\tON Documents(DocEmbedding)\n\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)"
        ]
        self.assertEqual(got, want)

0 comments on commit dfeab0a

Please sign in to comment.