
Proposal for Neo4j Uploader Improvements (#357)
* Proposed changes to improve Neo4j operations
- use neo4j-rust-ext instead of the plain neo4j driver for a 10x performance improvement
- always match on a single node label (the one backed by the uniqueness constraint), never use label-less matches
- group relationships by type and by source and target node type
- increase the batch size
- use the vector property procedure to store embeddings as fp32 instead of fp64 (see the sketch after this list)
- add a method to select the main label for a node
- TODO: creating the vector index needs information from the embedder (dimensions) and the similarity function (from config)
- set extra labels
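
A minimal sketch of the batched, single-label merge these bullets describe, assuming a local Neo4j 5.26+ instance (dynamic labels require 5.26) and placeholder credentials; names like merge_chunk_nodes and PART_OF_DOCUMENT are illustrative, not the connector's API:

import asyncio

from neo4j import AsyncGraphDatabase

# Merge a batch of nodes on their main label only (backed by the uniqueness
# constraint), add any extra labels dynamically, and store the embedding as
# an fp32 vector via db.create.setNodeVectorProperty.
MERGE_NODES = """
UNWIND $nodes AS node
MERGE (n:Chunk {id: node.id})
SET n += node.properties
SET n:$(node.labels)
WITH * WHERE node.vector IS NOT NULL
CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
"""

# Edge-side counterpart of the grouping idea (not executed below): one
# statement per (relationship type, source label, target label) group, so
# both MATCH clauses stay anchored on constrained labels.
MERGE_EDGES = """
UNWIND $edges AS edge
MATCH (u:Chunk {id: edge.source})
MATCH (v:Document {id: edge.destination})
MERGE (u)-[:PART_OF_DOCUMENT]->(v)
"""

async def merge_chunk_nodes(nodes: list[dict]) -> None:
    # neo4j-rust-ext is a drop-in replacement: same 'neo4j' import, faster
    # packing/unpacking of the Bolt protocol.
    async with AsyncGraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")  # placeholders
    ) as driver:
        await driver.execute_query(MERGE_NODES, nodes=nodes, database_="neo4j")

asyncio.run(merge_chunk_nodes([
    {"id": "c1", "labels": ["UnstructuredElement"], "vector": [0.1, 0.2, 0.3],
     "properties": {"text": "example"}},
]))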

* Fixed one missing label.value, added created node/relationship counts to the log output

* Set default values for username and database

* Modify main label logic, deduplicate data shared by node and edge

Implement the main-label selection logic on the Node.
Validate that a Node has at least one label, ensuring there is always a
main label.
Remove node data from the Edge and refer to Node objects directly
instead (see the sketch below).
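
A condensed sketch of that Node-side logic (pydantic v2); the class name and fields mirror the private _Node model in the diff below but are not its exact definition:

from pydantic import BaseModel, Field, field_validator

class Node(BaseModel):
    labels: list[str]
    properties: dict = Field(default_factory=dict)

    @field_validator("labels", mode="after")
    @classmethod
    def require_at_least_one_label(cls, value: list[str]) -> list[str]:
        # Guarantees main_label below can never raise an IndexError.
        if not value:
            raise ValueError("Node must have at least one label.")
        return value

    @property
    def main_label(self) -> str:
        # The first label is treated as the node's main (constrained) label.
        return self.labels[0]

node = Node(labels=["Chunk", "UnstructuredElement"])
assert node.main_label == "Chunk"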

* Add vector index creation with the 'cosine' similarity function (sketched below)
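
A sketch of that index statement under the default 'cosine' setting, assuming the Chunk label and 'embedding' property used elsewhere in this change (vector indexes need Neo4j 5.13+); the index name, dimensions, and connection details are placeholders:

import asyncio

from neo4j import AsyncGraphDatabase

async def create_chunk_vector_index(dimensions: int) -> None:
    # Dimensions are inferred from the data (the length of one chunk's
    # embedding); the f-string mirrors the interpolation the connector uses.
    query = f"""
    CREATE VECTOR INDEX chunk_vector IF NOT EXISTS
    FOR (n:Chunk) ON n.embedding
    OPTIONS {{indexConfig: {{
        `vector.similarity_function`: 'cosine',
        `vector.dimensions`: {int(dimensions)}}}
    }}
    """
    async with AsyncGraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")  # placeholders
    ) as driver:
        await driver.execute_query(query, database_="neo4j")

asyncio.run(create_chunk_vector_index(1536))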

* Make index creation optional, resolve merge conflict

* Update changelog and version

* tidy

---------

Co-authored-by: Filip Knefel <[email protected]>
Co-authored-by: Ahmet Melek <[email protected]>
Co-authored-by: Ahmet Melek <[email protected]>
4 people authored Feb 13, 2025
1 parent 442dbfa commit 4edcd1b
Showing 5 changed files with 135 additions and 48 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,7 +1,8 @@
## 0.5.3-dev1
## 0.5.3-dev2

### Enhancements

* **Improvements to the Neo4j uploader and the ability to create a vector index**
* **Optimize embedder code** - Move duplicate code to base interface, exit early if no elements have text.

### Fixes
2 changes: 1 addition & 1 deletion requirements/connectors/neo4j.in
@@ -1,3 +1,3 @@
neo4j
neo4j-rust-ext
cymple
networkx
4 changes: 2 additions & 2 deletions requirements/connectors/neo4j.txt
@@ -2,9 +2,9 @@
# uv pip compile ./connectors/neo4j.in --output-file ./connectors/neo4j.txt --no-strip-extras --python-version 3.9
cymple==0.12.0
# via -r ./connectors/neo4j.in
neo4j==5.28.1
neo4j-rust-ext==5.27.0.0
# via -r ./connectors/neo4j.in
networkx==3.2.1
# via -r ./connectors/neo4j.in
pytz==2025.1
pytz==2024.2
# via neo4j
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
@@ -1 +1 @@
__version__ = "0.5.3-dev1" # pragma: no cover
__version__ = "0.5.3-dev2" # pragma: no cover
172 changes: 129 additions & 43 deletions unstructured_ingest/v2/processes/connectors/neo4j.py
Expand Up @@ -8,9 +8,9 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, Secret
from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator

from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.logger import logger
@@ -30,6 +30,8 @@
DestinationRegistryEntry,
)

SimilarityFunction = Literal["cosine"]

if TYPE_CHECKING:
from neo4j import AsyncDriver, Auth
from networkx import Graph, MultiDiGraph
@@ -44,9 +46,9 @@ class Neo4jAccessConfig(AccessConfig):
class Neo4jConnectionConfig(ConnectionConfig):
access_config: Secret[Neo4jAccessConfig]
connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
username: str
username: str = Field(default="neo4j")
uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
database: str = Field(description="Name of the target database")
database: str = Field(default="neo4j", description="Name of the target database")

@requires_dependencies(["neo4j"], extras="neo4j")
@asynccontextmanager
@@ -186,8 +188,8 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
nodes = list(nx_graph.nodes())
edges = [
_Edge(
source_id=u.id_,
destination_id=v.id_,
source=u,
destination=v,
relationship=Relationship(data_dict["relationship"]),
)
for u, v, data_dict in nx_graph.edges(data=True)
@@ -198,19 +200,30 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
class _Node(BaseModel):
model_config = ConfigDict()

id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
labels: list[Label] = Field(default_factory=list)
labels: list[Label]
properties: dict = Field(default_factory=dict)
id_: str = Field(default_factory=lambda: str(uuid.uuid4()))

def __hash__(self):
return hash(self.id_)

@property
def main_label(self) -> Label:
return self.labels[0]

@field_validator("labels", mode="after")
@classmethod
def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
if not value:
raise ValueError("Node must have at least one label.")
return value


class _Edge(BaseModel):
model_config = ConfigDict()

source_id: str
destination_id: str
source: _Node
destination: _Node
relationship: Relationship


@@ -229,7 +242,14 @@ class Relationship(Enum):

class Neo4jUploaderConfig(UploaderConfig):
batch_size: int = Field(
default=100, description="Maximal number of nodes/relationships created per transaction."
default=1000, description="Maximal number of nodes/relationships created per transaction."
)
similarity_function: SimilarityFunction = Field(
default="cosine",
description="Vector similarity function used to create index on Chunk nodes",
)
create_destination: bool = Field(
default=True, description="Create destination if it does not exist"
)


@@ -257,6 +277,13 @@ async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None: #
graph_data = _GraphData.model_validate(staged_data)
async with self.connection_config.get_client() as client:
await self._create_uniqueness_constraints(client)
embedding_dimensions = self._get_embedding_dimensions(graph_data)
if embedding_dimensions and self.upload_config.create_destination:
await self._create_vector_index(
client,
dimensions=embedding_dimensions,
similarity_function=self.upload_config.similarity_function,
)
await self._delete_old_data_if_exists(file_data, client=client)
await self._merge_graph(graph_data=graph_data, client=client)

@@ -274,13 +301,33 @@ async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
"""
)

async def _create_vector_index(
self, client: AsyncDriver, dimensions: int, similarity_function: SimilarityFunction
) -> None:
label = Label.CHUNK
logger.info(
f"Creating index on nodes labeled '{label.value}' if it does not already exist."
)
index_name = f"{label.value.lower()}_vector"
await client.execute_query(
f"""
CREATE VECTOR INDEX {index_name} IF NOT EXISTS
FOR (n:{label.value}) ON n.embedding
OPTIONS {{indexConfig: {{
`vector.similarity_function`: '{similarity_function}',
`vector.dimensions`: {dimensions}}}
}}
"""
)

async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
_, summary, _ = await client.execute_query(
f"""
MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
DETACH DELETE m""",
MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
DETACH DELETE m
DETACH DELETE n""",
identifier=file_data.identifier,
)
logger.info(
@@ -289,33 +336,39 @@ async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
)

async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
for node in graph_data.nodes:
nodes_by_labels[tuple(node.labels)].append(node)

nodes_by_labels[node.main_label].append(node)
logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
# NOTE: Processed in parallel as there's no overlap between accessed nodes
await self._execute_queries(
[
self._create_nodes_query(nodes_batch, labels)
for labels, nodes in nodes_by_labels.items()
self._create_nodes_query(nodes_batch, label)
for label, nodes in nodes_by_labels.items()
for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
],
client=client,
in_parallel=True,
)
logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")

edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
defaultdict(list)
)
for edge in graph_data.edges:
edges_by_relationship[edge.relationship].append(edge)
key = (edge.relationship, edge.source.main_label, edge.destination.main_label)
edges_by_relationship[key].append(edge)

logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
# NOTE: Processed sequentially to avoid queries locking node access to one another
await self._execute_queries(
[
self._create_edges_query(edges_batch, relationship)
for relationship, edges in edges_by_relationship.items()
self._create_edges_query(edges_batch, relationship, source_label, destination_label)
for (
relationship,
source_label,
destination_label,
), edges in edges_by_relationship.items()
for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
],
client=client,
@@ -328,53 +381,86 @@ async def _execute_queries(
client: AsyncDriver,
in_parallel: bool = False,
) -> None:
from neo4j import EagerResult

results: list[EagerResult] = []
logger.info(
f"Executing {len(queries_with_parameters)} "
+ f"{'parallel' if in_parallel else 'sequential'} Cypher statements."
)
if in_parallel:
logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
await asyncio.gather(
results = await asyncio.gather(
*[
client.execute_query(query, parameters_=parameters)
for query, parameters in queries_with_parameters
]
)
logger.info("Finished executing parallel queries.")
else:
logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
for i, (query, parameters) in enumerate(queries_with_parameters):
logger.info(f"Query #{i} started.")
await client.execute_query(query, parameters_=parameters)
logger.info(f"Query #{i} finished.")
logger.info(
f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
)
logger.info(f"Statement #{i} started.")
results.append(await client.execute_query(query, parameters_=parameters))
logger.info(f"Statement #{i} finished.")
node_count = sum(res.summary.counters.nodes_created for res in results)
rel_count = sum(res.summary.counters.relationships_created for res in results)
logger.info(
f"Finished executing all ({len(queries_with_parameters)}) "
+ f"{'parallel' if in_parallel else 'sequential'} Cypher statements. "
+ f"Created {node_count} nodes, {rel_count} relationships."
)

@staticmethod
def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
labels_string = ", ".join([label.value for label in labels])
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label}'.")
query_string = f"""
UNWIND $nodes AS node
MERGE (n: {labels_string} {{id: node.id}})
MERGE (n: `{label.value}` {{id: node.id}})
SET n += node.properties
SET n:$(node.labels)
WITH * WHERE node.vector IS NOT NULL
CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
"""
parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
parameters = {
"nodes": [
{
"id": node.id_,
"labels": [l.value for l in node.labels if l != label], # noqa: E741
"vector": node.properties.pop("embedding", None),
"properties": node.properties,
}
for node in nodes
]
}
return query_string, parameters

@staticmethod
def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
def _create_edges_query(
edges: list[_Edge],
relationship: Relationship,
source_label: Label,
destination_label: Label,
) -> tuple[str, dict]:
logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
query_string = f"""
UNWIND $edges AS edge
MATCH (u {{id: edge.source}})
MATCH (v {{id: edge.destination}})
MERGE (u)-[:{relationship.value}]->(v)
MATCH (u: `{source_label.value}` {{id: edge.source}})
MATCH (v: `{destination_label.value}` {{id: edge.destination}})
MERGE (u)-[:`{relationship.value}`]->(v)
"""
parameters = {
"edges": [
{"source": edge.source_id, "destination": edge.destination_id} for edge in edges
{"source": edge.source.id_, "destination": edge.destination.id_} for edge in edges
]
}
return query_string, parameters

def _get_embedding_dimensions(self, graph_data: _GraphData) -> Optional[int]:
"""Embedding dimensions inferred from chunk nodes or None if it can't be determined."""
for node in graph_data.nodes:
if Label.CHUNK in node.labels and "embeddings" in node.properties:
return len(node.properties["embeddings"])

return None


neo4j_destination_entry = DestinationRegistryEntry(
connection_config=Neo4jConnectionConfig,
