diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index d3e03a0..0c6859c 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -18,9 +18,11 @@ import string from abc import abstractmethod from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +import json from google.cloud import spanner from google.cloud.spanner_v1 import param_types +from google.cloud.spanner_v1 import JsonObject from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_community.graphs.graph_store import GraphStore from requests.structures import CaseInsensitiveDict @@ -57,6 +59,7 @@ def __hash__(self): ( self.edge.source.id, self.edge.target.id, + self.edge.type, ) ) @@ -65,6 +68,7 @@ def __eq__(self, other: Any): return ( self.edge.source.id == other.edge.source.id and self.edge.target.id == other.edge.target.id + and self.edge.type == other.edge.type ) return False @@ -122,7 +126,7 @@ def to_identifiers(s: List[str]) -> Iterable[str]: @staticmethod def fixup_identifier(s: str) -> str: - return re.sub("[{}]".format(string.whitespace), "_", s) + return re.sub("[{}]".format(string.whitespace + string.punctuation), "_", s) @staticmethod def fixup_graph_documents(graph_documents: List[GraphDocument]) -> None: @@ -164,8 +168,12 @@ class ElementSchema(object): NODE_KEY_COLUMN_NAME: str = "id" TARGET_NODE_KEY_COLUMN_NAME: str = "target_id" + DYNAMIC_PROPERTY_COLUMN_NAME: str = "properties" + DYNAMIC_LABEL_COLUMN_NAME: str = "label" + UUID_COLUMN_NAME: str = "default_uuid" name: str + original_name: str kind: str key_columns: List[str] base_table_name: str @@ -175,12 +183,21 @@ class ElementSchema(object): source: NodeReference target: NodeReference + def is_dynamic_schema(self) -> bool: + return ( + self.types.get(ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, None) + == param_types.JSON + ) + @staticmethod - def from_nodes(name: str, nodes: List[Node]) -> ElementSchema: + def from_static_nodes( + name: str, graph_name: str, nodes: List[Node] + ) -> ElementSchema: """Builds ElementSchema from a list of nodes. Parameters: - name: name of the schema; + - graph_name: name of the graph; - nodes: a non-empty list of nodes. Returns: @@ -189,39 +206,91 @@ def from_nodes(name: str, nodes: List[Node]) -> ElementSchema: if len(nodes) == 0: raise ValueError("The list of nodes should not be empty") - props = set((key.casefold() for n in nodes for key in n.properties.keys())) - if ElementSchema.NODE_KEY_COLUMN_NAME in props: + node = ElementSchema() + node.types = CaseInsensitiveDict( + { + k: TypeUtility.value_to_param_type(v) + for n in nodes + for k, v in n.properties.items() + } + ) + if ElementSchema.NODE_KEY_COLUMN_NAME in node.types: raise ValueError( "Node properties should not contain property with name: `%s`" % ElementSchema.NODE_KEY_COLUMN_NAME ) - props.add(ElementSchema.NODE_KEY_COLUMN_NAME) + node.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( + TypeUtility.value_to_param_type(nodes[0].id) + ) - node = ElementSchema() - node.name = name + node.properties = CaseInsensitiveDict({prop: prop for prop in node.types}) + node.labels = [name] + node.base_table_name = "%s_%s" % (graph_name, name) + node.original_name = name + node.name = node.base_table_name node.kind = "NODE" node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME] - node.base_table_name = name - node.labels = [name] - node.properties = CaseInsensitiveDict({prop: prop for prop in props}) + return node + + @staticmethod + def from_dynamic_nodes( + name: str, nodes: List[Node], graph_schema: SpannerGraphSchema + ) -> ElementSchema: + """Builds ElementSchema from a list of nodes. + + Parameters: + - name: name of the schema; + - graph_name: name of the graph; + - nodes: a non-empty list of nodes. + + Returns: + - ElementSchema: schema representation of the nodes. + """ + if len(nodes) == 0: + raise ValueError("The list of nodes should not be empty") + + node = ElementSchema() node.types = CaseInsensitiveDict( { - k: TypeUtility.value_to_param_type(v) - for n in nodes - for k, v in n.properties.items() + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME: param_types.JSON, + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME: param_types.STRING, + ElementSchema.NODE_KEY_COLUMN_NAME: TypeUtility.value_to_param_type( + nodes[0].id + ), } ) - node.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(nodes[0].id) + node.types.update( + CaseInsensitiveDict( + { + k: TypeUtility.value_to_param_type(v) + for n in nodes + for k, v in n.properties.items() + if k in graph_schema.static_node_properties + } + ) ) + node.properties = CaseInsensitiveDict({prop: prop for prop in node.types}) + + node.labels = ["Node"] + node.base_table_name = "%s_Node" % graph_schema.graph_name + node.original_name = name + node.name = node.base_table_name + node.kind = "NODE" + node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME] return node @staticmethod - def from_edges(name: str, edges: List[Relationship]) -> ElementSchema: + def from_static_edges( + name: str, + graph_name: str, + edges: List[Relationship], + graph_schema: SpannerGraphSchema, + ) -> ElementSchema: """Builds ElementSchema from a list of edges. Parameters: - name: name of the schema; + - graph_name: name of the graph; - nodes: a non-empty list of edges. Returns: @@ -230,33 +299,7 @@ def from_edges(name: str, edges: List[Relationship]) -> ElementSchema: if len(edges) == 0: raise ValueError("The list of edges should not be empty") - props = set((key.casefold() for e in edges for key in e.properties.keys())) - if ElementSchema.NODE_KEY_COLUMN_NAME in props: - raise ValueError( - "Edge properties should not contain property with name: `%s`" - % ElementSchema.NODE_KEY_COLUMN_NAME - ) - if ElementSchema.TARGET_NODE_KEY_COLUMN_NAME in props: - raise ValueError( - "Edge properties should not contain property with name: `%s`" - % ElementSchema.TARGET_NODE_KEY_COLUMN_NAME - ) - props.add(ElementSchema.NODE_KEY_COLUMN_NAME) - props.add(ElementSchema.TARGET_NODE_KEY_COLUMN_NAME) - edge = ElementSchema() - edge.name = name - edge.kind = "EDGE" - edge.key_columns = [ - ElementSchema.NODE_KEY_COLUMN_NAME, - ElementSchema.TARGET_NODE_KEY_COLUMN_NAME, - ] - edge.base_table_name = name - - # Uses the type as label because the label can be shared by multiple base - # tables. - edge.labels = [edges[0].type] - edge.properties = CaseInsensitiveDict({prop: prop for prop in props}) edge.types = CaseInsensitiveDict( { k: TypeUtility.value_to_param_type(v) @@ -264,25 +307,218 @@ def from_edges(name: str, edges: List[Relationship]) -> ElementSchema: for k, v in e.properties.items() } ) + + for col_name in [ + ElementSchema.NODE_KEY_COLUMN_NAME, + ElementSchema.TARGET_NODE_KEY_COLUMN_NAME, + ElementSchema.UUID_COLUMN_NAME, + ]: + if col_name in edge.types: + raise ValueError( + "Edge properties should not contain property with name: `%s`" + % col_name + ) edge.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( TypeUtility.value_to_param_type(edges[0].source.id) ) edge.types[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = ( TypeUtility.value_to_param_type(edges[0].target.id) ) + edge.types[ElementSchema.UUID_COLUMN_NAME] = param_types.STRING + + edge.properties = CaseInsensitiveDict({prop: prop for prop in edge.types}) + + edge.labels = [edges[0].type] + edge.base_table_name = "%s_%s" % (graph_name, name) + edge.key_columns = [ + ElementSchema.NODE_KEY_COLUMN_NAME, + ElementSchema.TARGET_NODE_KEY_COLUMN_NAME, + ElementSchema.UUID_COLUMN_NAME, + ] + + edge.original_name = name + edge.name = edge.base_table_name + edge.kind = "EDGE" + + source_node_schema = graph_schema.get_node_schema(edges[0].source.type) + if source_node_schema is None: + raise ValueError("No source node schema `%s` found" % edges[0].source.type) + + target_node_schema = graph_schema.get_node_schema(edges[0].target.type) + if target_node_schema is None: + raise ValueError("No target node schema `%s` found" % edges[0].target.type) + + edge.source = NodeReference( + source_node_schema.name, + [ElementSchema.NODE_KEY_COLUMN_NAME], + [ElementSchema.NODE_KEY_COLUMN_NAME], + ) + edge.target = NodeReference( + target_node_schema.name, + [ElementSchema.NODE_KEY_COLUMN_NAME], + [ElementSchema.TARGET_NODE_KEY_COLUMN_NAME], + ) + return edge + + @staticmethod + def from_dynamic_edges( + name: str, + edges: List[Relationship], + graph_schema: SpannerGraphSchema, + ) -> ElementSchema: + """Builds ElementSchema from a list of edges. + + Parameters: + - name: name of the schema; + - graph_name: name of the graph; + - nodes: a non-empty list of edges. + + Returns: + - ElementSchema: schema representation of the edges. + """ + if len(edges) == 0: + raise ValueError("The list of edges should not be empty") + + edge = ElementSchema() + edge.types = CaseInsensitiveDict( + { + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME: param_types.JSON, + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME: param_types.STRING, + ElementSchema.UUID_COLUMN_NAME: param_types.STRING, + ElementSchema.NODE_KEY_COLUMN_NAME: TypeUtility.value_to_param_type( + edges[0].source.id + ), + ElementSchema.TARGET_NODE_KEY_COLUMN_NAME: TypeUtility.value_to_param_type( + edges[0].target.id + ), + } + ) + edge.types.update( + CaseInsensitiveDict( + { + k: TypeUtility.value_to_param_type(v) + for e in edges + for k, v in e.properties.items() + if k in graph_schema.static_edge_properties + } + ) + ) + edge.properties = CaseInsensitiveDict({prop: prop for prop in edge.types}) + + edge.labels = ["Edge"] + edge.base_table_name = "%s_Edge" % graph_schema.graph_name + edge.key_columns = [ + ElementSchema.NODE_KEY_COLUMN_NAME, + ElementSchema.TARGET_NODE_KEY_COLUMN_NAME, + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + ElementSchema.UUID_COLUMN_NAME, + ] + + edge.original_name = name + edge.name = edge.base_table_name + edge.kind = "EDGE" + + source_node_schema = graph_schema.get_node_schema(edges[0].source.type) + if source_node_schema is None: + raise ValueError("No source node schema `%s` found" % edges[0].source.type) + + target_node_schema = graph_schema.get_node_schema(edges[0].target.type) + if target_node_schema is None: + raise ValueError("No target node schema `%s` found" % edges[0].target.type) edge.source = NodeReference( - edges[0].source.type, + source_node_schema.name, [ElementSchema.NODE_KEY_COLUMN_NAME], [ElementSchema.NODE_KEY_COLUMN_NAME], ) edge.target = NodeReference( - edges[0].target.type, + target_node_schema.name, [ElementSchema.NODE_KEY_COLUMN_NAME], [ElementSchema.TARGET_NODE_KEY_COLUMN_NAME], ) return edge + def add_nodes( + self, name: str, nodes: List[Node] + ) -> Tuple[str, List[str], List[List[Any]]]: + """Builds the data required to add a list of nodes to Spanner. + + Parameters: + - name: type of name; + - nodes: a list of Nodes. + + Returns: + - str: a table name; + - List[str]: a list of column names; + - List[List[Any]]: a list of rows. + """ + if len(nodes) == 0: + raise ValueError("Empty list of nodes") + + columns = [k for k in self.types.keys()] + rows = [] + for node in nodes: + properties = node.properties.copy() + properties[ElementSchema.NODE_KEY_COLUMN_NAME] = node.id + + if self.is_dynamic_schema(): + dynamic_properties = { + k: TypeUtility.value_for_json(v) + for k, v in node.properties.items() + if k not in self.types + } + # Json loads and dumps handles some invalid characters + # that the JsonDecoder doesn't accept. + properties[ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME] = JsonObject( + json.loads(json.dumps(dynamic_properties)) + ) + properties[ElementSchema.DYNAMIC_LABEL_COLUMN_NAME] = node.type + + row = [properties.get(k, None) for k in columns] + rows.append(row) + return self.base_table_name, columns, rows + + def add_edges( + self, name: str, edges: List[Relationship] + ) -> Tuple[str, List[str], List[List[Any]]]: + """Builds the data required to add a list of edges to Spanner. + + Parameters: + - name: type of edge; + - edges: a list of Relationships. + + Returns: + - str: a table name; + - List[str]: a list of column names; + - List[List[Any]]: a list of rows. + """ + if len(edges) == 0: + raise ValueError("Empty list of edges") + + columns = [k for k in self.types.keys() if k != ElementSchema.UUID_COLUMN_NAME] + rows = [] + for edge in edges: + properties = edge.properties.copy() + properties[ElementSchema.NODE_KEY_COLUMN_NAME] = edge.source.id + properties[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = edge.target.id + + if self.is_dynamic_schema(): + dynamic_properties = { + k: TypeUtility.value_for_json(v) + for k, v in edge.properties.items() + if k not in self.types + } + # Json loads and dumps handles some invalid characters + # that the JsonDecoder doesn't accept. + properties[ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME] = JsonObject( + json.loads(json.dumps(dynamic_properties)) + ) + properties[ElementSchema.DYNAMIC_LABEL_COLUMN_NAME] = edge.type + + row = [properties.get(k, None) for k in columns] + rows.append(row) + return self.base_table_name, columns, rows + @staticmethod def from_info_schema( element_schema: Dict[str, Any], @@ -300,6 +536,7 @@ def from_info_schema( """ element = ElementSchema() element.name = element_schema["name"] + element.original_name = element.name element.kind = element_schema["kind"] element.key_columns = element_schema["keyColumns"] element.base_table_name = element_schema["baseTableName"] @@ -314,6 +551,7 @@ def from_info_schema( { decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"]) for decl in property_decls + if decl["name"] in element.properties } ) @@ -330,15 +568,23 @@ def from_info_schema( ) return element - def to_ddl(self) -> str: + def to_ddl(self, graph_schema: SpannerGraphSchema) -> str: """Returns a CREATE TABLE ddl that represents the element schema. Returns: - str: a string of CREATE TABLE ddl statement. + - graph_schema: Spanner Graph schema. """ to_identifier = GraphDocumentUtility.to_identifier to_identifiers = GraphDocumentUtility.to_identifiers + + def get_reference_node_table(name: str) -> str: + node_schema = graph_schema.node_tables.get(name, None) + if node_schema is None: + raise ValueError("No node schema `%s` found" % name) + return node_schema.base_table_name + return """CREATE TABLE {} ( {}{} ) PRIMARY KEY ({}){} @@ -346,11 +592,16 @@ def to_ddl(self) -> str: to_identifier(self.base_table_name), ",\n ".join( ( - "{} {}".format( + "{} {} {}".format( to_identifier(n), TypeUtility.spanner_type_to_schema_str( t, include_type_annotations=True ), + ( + "DEFAULT (GENERATE_UUID())" + if n == ElementSchema.UUID_COLUMN_NAME + else "" + ), ) for n, t in self.types.items() ) @@ -358,7 +609,7 @@ def to_ddl(self) -> str: ( ",\n FOREIGN KEY ({}) REFERENCES {}({})".format( ", ".join(to_identifiers(self.target.edge_keys)), - to_identifier(self.target.node_name), + to_identifier(get_reference_node_table(self.target.node_name)), ", ".join(to_identifiers(self.target.node_keys)), ) if self.kind == "EDGE" @@ -366,7 +617,9 @@ def to_ddl(self) -> str: ), ",".join(to_identifiers(self.key_columns)), ( - ", INTERLEAVE IN PARENT {}".format(to_identifier(self.source.node_name)) + ", INTERLEAVE IN PARENT {}".format( + to_identifier(get_reference_node_table(self.source.node_name)) + ) if self.kind == "EDGE" else "" ), @@ -381,12 +634,6 @@ def evolve(self, new_schema: ElementSchema) -> List[str]: Returns: - List[str]: a list of DDL statements. """ - if self.name.casefold() != new_schema.name.casefold(): - raise ValueError( - "Schema should have the same kind, got {}, expected {}".format( - new_schema.name, self.name - ) - ) if self.kind != new_schema.kind: raise ValueError( "Schema with name `{}` should have the same kind, got {}, expected {}".format( @@ -466,17 +713,34 @@ class SpannerGraphSchema(object): WHERE property_graph_name = '{}' """ - def __init__(self, graph_name: str): + def __init__( + self, + graph_name: str, + use_flexible_schema: bool, + static_node_properties: List[str] = [], + static_edge_properties: List[str] = [], + ): """Initializes the graph schema. Parameters: - - graph_name: the name of the graph. + - graph_name: the name of the graph; + - use_flexible_schema: whether to use the flexible schema which uses a + JSON blob to store node and edge properties; + - static_node_properties: in flexible schema, treat these node + properties as static; + - static_edge_properties: in flexible schema, treat these edge + properties as static. """ self.graph_name: str = graph_name self.nodes: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({}) self.edges: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({}) + self.node_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({}) + self.edge_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({}) self.labels: CaseInsensitiveDict[Label] = CaseInsensitiveDict({}) self.properties: CaseInsensitiveDict[param_types.Type] = CaseInsensitiveDict({}) + self.use_flexible_schema = use_flexible_schema + self.static_node_properties = set(static_node_properties) + self.static_edge_properties = set(static_edge_properties) def evolve(self, graph_documents: List[GraphDocument]) -> List[str]: """Evolves current schema into a schema representing the input documents. @@ -491,12 +755,20 @@ def evolve(self, graph_documents: List[GraphDocument]) -> List[str]: ddls = [] for k, ns in nodes.items(): - node_schema = ElementSchema.from_nodes(k, ns) + node_schema = ( + ElementSchema.from_static_nodes(k, self.graph_name, ns) + if not self.use_flexible_schema + else ElementSchema.from_dynamic_nodes(k, ns, self) + ) ddls.extend(self._update_node_schema(node_schema)) self._update_labels_and_properties(node_schema) for k, es in edges.items(): - edge_schema = ElementSchema.from_edges(k, es) + edge_schema = ( + ElementSchema.from_static_edges(k, self.graph_name, es, self) + if not self.use_flexible_schema + else ElementSchema.from_dynamic_edges(k, es, self) + ) ddls.extend(self._update_edge_schema(edge_schema)) self._update_labels_and_properties(edge_schema) @@ -511,32 +783,15 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None: - info_schema: the information schema represenation of a graph; """ property_decls = info_schema.get("propertyDeclarations", []) - self.nodes = CaseInsensitiveDict( - { - node["name"]: ElementSchema.from_info_schema(node, property_decls) - for node in info_schema["nodeTables"] - } - ) - self.edges = CaseInsensitiveDict( - { - edge["name"]: ElementSchema.from_info_schema(edge, property_decls) - for edge in info_schema.get("edgeTables", []) - } - ) - self.labels = CaseInsensitiveDict( - { - label["name"]: Label( - label["name"], set(label.get("propertyDeclarationNames", [])) - ) - for label in info_schema["labels"] - } - ) - self.properties = CaseInsensitiveDict( - { - decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"]) - for decl in property_decls - } - ) + for node in info_schema["nodeTables"]: + node_schema = ElementSchema.from_info_schema(node, property_decls) + self._update_node_schema(node_schema) + self._update_labels_and_properties(node_schema) + + for edge in info_schema.get("edgeTables", []): + edge_schema = ElementSchema.from_info_schema(edge, property_decls) + self._update_edge_schema(edge_schema) + self._update_labels_and_properties(edge_schema) def get_node_schema(self, name: str) -> Optional[ElementSchema]: """Gets the node schema by name. @@ -600,8 +855,6 @@ def __repr__(self) -> str: k: TypeUtility.spanner_type_to_schema_str(v) for k, v in self.properties.items() } - import json - return json.dumps( { "Name of graph": self.graph_name, @@ -693,7 +946,7 @@ def construct_node_reference( construct_columns(endpoint.node_keys), ) - def constuct_element_table( + def construct_element_table( element: ElementSchema, labels: CaseInsensitiveDict[Label] ) -> str: definition = [ @@ -718,15 +971,18 @@ def constuct_element_table( ) ddl += "\nNODE TABLES(\n " ddl += ",\n ".join( - (constuct_element_table(node, self.labels) for node in self.nodes.values()) + ( + construct_element_table(node, self.labels) + for node in self.node_tables.values() + ) ) ddl += "\n)" if len(self.edges) > 0: ddl += "\nEDGE TABLES(\n " ddl += ",\n ".join( ( - constuct_element_table(edge, self.labels) - for edge in self.edges.values() + construct_element_table(edge, self.labels) + for edge in self.edge_tables.values() ) ) ddl += "\n)" @@ -741,11 +997,15 @@ def _update_node_schema(self, node_schema: ElementSchema) -> List[str]: Returns: - List[str]: a list of DDL statements that requires to evolve the schema. """ - if node_schema.name not in self.nodes: - ddls = [node_schema.to_ddl()] - self.nodes[node_schema.name] = node_schema - return ddls - ddls = self.nodes[node_schema.name].evolve(node_schema) + + old_schema = self.node_tables.get(node_schema.name, None) + if old_schema is None: + ddls = [node_schema.to_ddl(self)] + self.node_tables[node_schema.name] = node_schema + else: + ddls = old_schema.evolve(node_schema) + + self.nodes[node_schema.original_name] = old_schema or node_schema return ddls def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]: @@ -757,11 +1017,15 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]: Returns: - List[str]: a list of DDL statements that requires to evolve the schema. """ - if edge_schema.name not in self.edges: - ddls = [edge_schema.to_ddl()] - self.edges[edge_schema.name] = edge_schema - return ddls - ddls = self.edges[edge_schema.name].evolve(edge_schema) + if edge_schema.base_table_name not in self.edge_tables: + ddls = [edge_schema.to_ddl(self)] + self.edge_tables[edge_schema.base_table_name] = edge_schema + else: + ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema) + + self.edges[edge_schema.original_name] = self.edge_tables[ + edge_schema.base_table_name + ] return ddls def _update_labels_and_properties(self, element_schema: ElementSchema) -> None: @@ -778,6 +1042,44 @@ def _update_labels_and_properties(self, element_schema: ElementSchema) -> None: self.properties.update(element_schema.types) + def add_nodes( + self, name: str, nodes: List[Node] + ) -> Tuple[str, List[str], List[List[Any]]]: + """Builds the data required to add a list of nodes to Spanner. + + Parameters: + - name: type of name; + - nodes: a list of Nodes. + + Returns: + - str: a table name; + - List[str]: a list of column names; + - List[List[Any]]: a list of rows. + """ + node_schema = self.get_node_schema(name) + if node_schema is None: + raise ValueError("Unknown node schema: `%s`" % name) + return node_schema.add_nodes(name, nodes) + + def add_edges( + self, name: str, edges: List[Relationship] + ) -> Tuple[str, List[str], List[List[Any]]]: + """Builds the data required to add a list of edges to Spanner. + + Parameters: + - name: type of edge; + - edges: a list of Relationships. + + Returns: + - str: a table name; + - List[str]: a list of column names; + - List[List[Any]]: a list of rows. + """ + edge_schema = self.get_edge_schema(name) + if edge_schema is None: + raise ValueError("Unknown edge schema `%s`" % name) + return edge_schema.add_edges(name, edges) + class SpannerImpl(object): """Wrapper of Spanner APIs.""" @@ -863,6 +1165,9 @@ def __init__( database_id: str, graph_name: str, client: Optional[spanner.Client] = None, + use_flexible_schema: bool = False, + static_node_properties: List[str] = [], + static_edge_properties: List[str] = [], ): """Parameters: @@ -870,9 +1175,20 @@ def __init__( - database_id: Google Cloud Spanner database id; - graph_name: Graph name; - client: an optional instance of Spanner client. + - use_flexible_schema: whether to use the flexible schema which uses a + JSON blob to store node and edge properties; + - static_node_properties: in flexible schema, treat these node + properties as static; + - static_edge_properties: in flexible schema, treat these edge + properties as static. """ self.impl = SpannerImpl(instance_id, database_id, client) - self.schema = SpannerGraphSchema(graph_name) + self.schema = SpannerGraphSchema( + graph_name, + use_flexible_schema, + static_node_properties, + static_edge_properties, + ) self.refresh_schema() @@ -907,7 +1223,7 @@ def add_graph_documents( for name, elements in nodes.items(): if len(elements) == 0: continue - table, columns, rows = self._add_nodes(name, elements) + table, columns, rows = self.schema.add_nodes(name, elements) print("Insert nodes of type `{}`...".format(name)) self.impl.insert_or_update(table, columns, rows) @@ -915,7 +1231,7 @@ def add_graph_documents( for name, elements in edges.items(): if len(elements) == 0: continue - table, columns, rows = self._add_edges(name, elements) + table, columns, rows = self.schema.add_edges(name, elements) print("Insert edges of type `{}`...".format(name)) self.impl.insert_or_update(table, columns, rows) @@ -939,12 +1255,7 @@ def get_schema(self) -> str: @property def get_structured_schema(self) -> Dict[str, Any]: - return { - "nodes": self.schema.nodes, - "edges": self.schema.edges, - "labels": self.schema.labels, - "properties": self.schema.properties, - } + return json.loads(repr(self.schema)) def get_ddl(self) -> str: return self.schema.to_ddl() @@ -968,64 +1279,6 @@ def refresh_schema(self) -> None: self.schema.from_information_schema(results[0]["property_graph_metadata_json"]) - def _add_nodes( - self, name: str, nodes: List[Node] - ) -> Tuple[str, List[str], List[List[Any]]]: - """Builds the data required to add a list of nodes to Spanner. - - Parameters: - - name: type of name; - - nodes: a list of Nodes. - - Returns: - - str: a table name; - - List[str]: a list of column names; - - List[List[Any]]: a list of rows. - """ - if len(nodes) == 0: - raise ValueError("Empty list of nodes") - - columns = list(set({k for node in nodes for k, v in node.properties.items()})) - - rows = [] - for node in nodes: - row = [node.properties.get(k, None) for k in columns] - row.append(node.id) - rows.append(row) - - columns.append(ElementSchema.NODE_KEY_COLUMN_NAME) - return name, columns, rows - - def _add_edges( - self, name: str, edges: List[Relationship] - ) -> Tuple[str, List[str], List[List[Any]]]: - """Builds the data required to add a list of edges to Spanner. - - Parameters: - - name: type of edge; - - edges: a list of Relationships. - - Returns: - - str: a table name; - - List[str]: a list of column names; - - List[List[Any]]: a list of rows. - """ - if len(edges) == 0: - raise ValueError("Empty list of edges") - - columns = list(set({k for edge in edges for k, v in edge.properties.items()})) - - rows = [] - for edge in edges: - row = [edge.properties.get(k, None) for k in columns] - row.append(edge.source.id) - row.append(edge.target.id) - rows.append(row) - - columns.append(ElementSchema.NODE_KEY_COLUMN_NAME) - columns.append(ElementSchema.TARGET_NODE_KEY_COLUMN_NAME) - return name, columns, rows - def cleanup(self): """Removes all data from your Spanner Graph. @@ -1053,4 +1306,6 @@ def cleanup(self): for node in self.schema.nodes.values() ] ) - self.schema = SpannerGraphSchema(self.schema.graph_name) + self.schema = SpannerGraphSchema( + self.schema.graph_name, self.schema.use_flexible_schema + ) diff --git a/src/langchain_google_spanner/type_utils.py b/src/langchain_google_spanner/type_utils.py index f03aacc..8397464 100644 --- a/src/langchain_google_spanner/type_utils.py +++ b/src/langchain_google_spanner/type_utils.py @@ -14,10 +14,12 @@ from __future__ import annotations +import base64 import datetime from typing import Any from google.cloud.spanner_v1 import param_types +from google.cloud.spanner_v1 import JsonObject class TypeUtility(object): @@ -59,6 +61,8 @@ def spanner_type_to_schema_str( return "FLOAT64" if t.code == param_types.TypeCode.TIMESTAMP: return "TIMESTAMP" + if t.code == param_types.TypeCode.JSON: + return "JSON" raise ValueError("Unsupported type: %s" % t) @staticmethod @@ -85,6 +89,8 @@ def schema_str_to_spanner_type(s: str) -> param_types.Type: return param_types.FLOAT32 if s == "TIMESTAMP": return param_types.TIMESTAMP + if s == "JSON": + return param_types.JSON if s.startswith("ARRAY<") and s.endswith(">"): return param_types.Array( TypeUtility.schema_str_to_spanner_type(s[len("ARRAY<") : -len(">")]) @@ -113,8 +119,40 @@ def value_to_param_type(v: Any) -> param_types.Type: return param_types.FLOAT64 if isinstance(v, datetime.datetime): return param_types.TIMESTAMP + if isinstance(v, JsonObject): + return param_types.JSON if isinstance(v, list): if len(v) == 0: raise ValueError("Unknown element type of empty array") return param_types.Array(TypeUtility.value_to_param_type(v[0])) raise ValueError("Unsupported type of param: {}".format(v)) + + @staticmethod + def value_for_json(v: Any) -> Any: + """Returns a value for JSON. + + Parameters: + - v: a python value. + """ + if isinstance(v, bool): + return v + if isinstance(v, int): + return v + if isinstance(v, str): + return v + if isinstance(v, float): + return str(v) + if isinstance(v, bytes): + return base64.b64encode(v).decode("utf-8") + if isinstance(v, datetime.datetime): + return str(v) + if isinstance(v, JsonObject): + return v + if isinstance(v, list): + return [TypeUtility.value_for_json(e) for e in v] + if isinstance(v, dict): + return { + TypeUtility.value_for_json(k): TypeUtility.value_for_json(v) + for k, v in v.items() + } + raise ValueError("Unsupported type of param: {}".format(v)) diff --git a/tests/integration/test_spanner_graph_store.py b/tests/integration/test_spanner_graph_store.py index 4fe6292..bb6447d 100644 --- a/tests/integration/test_spanner_graph_store.py +++ b/tests/integration/test_spanner_graph_store.py @@ -17,8 +17,10 @@ import os import random import string +import pytest from google.cloud.spanner import Client # type: ignore +from google.cloud.spanner_v1 import JsonObject from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document @@ -72,6 +74,10 @@ def random_none(): return None +def random_json(): + return JsonObject({random_string(exclude_whitespaces=True): random_int()}) + + def random_primitive_generators(): return [ random_int, @@ -87,7 +93,7 @@ def random_generators(): return ( random_primitive_generators() + [lambda: random_array(g) for g in random_primitive_generators()] - + [random_none] + + [random_none, random_json] ) @@ -97,6 +103,10 @@ def random_generators(): ] +def random_property_names(k): + return [k for k, _ in random.choices(properties, k=k)] + + def random_property(): k, vg = random.choice(properties) return k, vg() @@ -148,7 +158,8 @@ def random_graph_doc(suffix): class TestSpannerGraphStore: - def test_spanner_graph_random_doc(self): + @pytest.mark.parametrize("use_flexible_schema", [False, True]) + def test_spanner_graph_random_doc(self, use_flexible_schema): suffix = random_string(num_char=5, exclude_whitespaces=True) graph_name = "test_graph{}".format(suffix) graph = SpannerGraphStore( @@ -156,6 +167,13 @@ def test_spanner_graph_random_doc(self): google_database, graph_name, client=Client(project=project_id), + use_flexible_schema=use_flexible_schema, + static_node_properties=random_property_names( + random_int(l=0, u=len(properties)) + ), + static_edge_properties=random_property_names( + random_int(l=0, u=len(properties)) + ), ) graph.refresh_schema() @@ -212,7 +230,8 @@ def test_spanner_graph_random_doc(self): print(graph.get_ddl()) graph.cleanup() - def test_spanner_graph_doc_with_duplicate_elements(self): + @pytest.mark.parametrize("use_flexible_schema", [False, True]) + def test_spanner_graph_doc_with_duplicate_elements(self, use_flexible_schema): suffix = random_string(num_char=5, exclude_whitespaces=True) graph_name = "test_graph{}".format(suffix) graph = SpannerGraphStore( @@ -220,6 +239,13 @@ def test_spanner_graph_doc_with_duplicate_elements(self): google_database, graph_name, client=Client(project=project_id), + use_flexible_schema=use_flexible_schema, + static_node_properties=random_property_names( + random_int(l=0, u=len(properties)) + ), + static_edge_properties=random_property_names( + random_int(l=0, u=len(properties)) + ), ) graph.refresh_schema() @@ -239,12 +265,16 @@ def test_spanner_graph_doc_with_duplicate_elements(self): ) graph.add_graph_documents([doc]) + # In the case of flexible schema, `properties` is a nested json + # field. results = graph.query( """ GRAPH {} MATCH -[e]-> - RETURN TO_JSON(e)['properties'] AS properties + LET properties = TO_JSON(e)['properties'] + RETURN COALESCE(properties.properties, JSON "{{}}") AS dynamic_properties, + properties AS static_properties """.format( graph_name ), @@ -255,8 +285,13 @@ def test_spanner_graph_doc_with_duplicate_elements(self): edge_properties = edge0.properties edge_properties.update(edge1.properties) missing_properties = set(edge_properties.keys()).difference( - set(results[0]["properties"].keys()) + set(results[0]["dynamic_properties"].keys()).union( + set(results[0]["static_properties"].keys()) + ) ) + print(edge0.properties) + print(edge1.properties) + print(results) assert ( len(missing_properties) == 0 ), "Missing properties of edge: {}".format(missing_properties)