Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new basic pipeline runner #1565

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241227225850465466.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix gleanings loop check"
}
48 changes: 30 additions & 18 deletions graphrag/api/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.run.run_workflows import run_workflows
from graphrag.index.typing import PipelineRunResult
from graphrag.logger.base import ProgressLogger

Expand All @@ -27,6 +28,7 @@ async def build_index(
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_logger: ProgressLogger | None = None,
use_new_pipeline: bool = False,
natoverse marked this conversation as resolved.
Show resolved Hide resolved
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.

Expand Down Expand Up @@ -56,7 +58,6 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
)
Expand All @@ -65,21 +66,32 @@ async def build_index(
callbacks = callbacks or []
callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore
outputs: list[PipelineRunResult] = []
async for output in run_pipeline_with_config(
pipeline_config,
run_id=run_id,
memory_profile=memory_profile,
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
is_resume_run=is_resume_run,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_logger.error(output.workflow)
else:
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))

if use_new_pipeline:
await run_workflows(
config,
cache=pipeline_cache,
logger=progress_logger,
run_id=run_id,
)
else:
pipeline_config = create_pipeline_config(config)
async for output in run_pipeline_with_config(
pipeline_config,
run_id=run_id,
memory_profile=memory_profile,
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
is_resume_run=is_resume_run,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_logger.error(output.workflow)
else:
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))

return outputs
1 change: 1 addition & 0 deletions graphrag/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def _run_index(
is_resume_run=bool(resume),
memory_profile=memprofile,
progress_logger=progress_logger,
use_new_pipeline=True,
)
)
encountered_errors = any(
Expand Down
42 changes: 42 additions & 0 deletions graphrag/index/config/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

"""A module containing embeddings values."""

from graphrag.config.enums import TextEmbeddingTarget
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig

entity_title_embedding = "entity.title"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
Expand All @@ -27,3 +31,41 @@
community_full_content_embedding,
text_unit_text_embedding,
}


def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
"""Get the fields to embed based on the enum or specifically skipped embeddings."""
match settings.embeddings.target:
case TextEmbeddingTarget.all:
return all_embeddings.difference(settings.embeddings.skip)
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embeddings.target}"
raise ValueError(msg)


def get_embedding_settings(
settings: TextEmbeddingConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
vector_store_settings = settings.vector_store
if vector_store_settings is None:
return {"strategy": settings.resolved_strategy()}
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.resolved_strategy() # get the default strategy
strategy.update({
"vector_store": {**(vector_store_params or {}), **vector_store_settings}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}
46 changes: 3 additions & 43 deletions graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
InputFileType,
ReportingType,
StorageType,
TextEmbeddingTarget,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineCacheConfigTypes,
Expand All @@ -25,10 +23,7 @@
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from graphrag.index.config.embeddings import (
all_embeddings,
required_embeddings,
)
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.config.input import (
PipelineCSVInputConfig,
PipelineInputConfigTypes,
Expand Down Expand Up @@ -92,7 +87,7 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
_log_llm_settings(settings)

skip_workflows = settings.skip_workflows
embedded_fields = _get_embedded_fields(settings)
embedded_fields = get_embedded_fields(settings)
covariates_enabled = (
settings.claim_extraction.enabled
and create_final_covariates not in skip_workflows
Expand Down Expand Up @@ -123,19 +118,6 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
return result


def _get_embedded_fields(settings: GraphRagConfig) -> set[str]:
match settings.embeddings.target:
case TextEmbeddingTarget.all:
return all_embeddings.difference(settings.embeddings.skip)
case TextEmbeddingTarget.required:
return required_embeddings
case TextEmbeddingTarget.none:
return set()
case _:
msg = f"Unknown embeddings target: {settings.embeddings.target}"
raise ValueError(msg)


def _log_llm_settings(settings: GraphRagConfig) -> None:
log.info(
"Using LLM Config %s",
Expand Down Expand Up @@ -189,28 +171,6 @@ def _text_unit_workflows(
]


def _get_embedding_settings(
settings: TextEmbeddingConfig,
vector_store_params: dict | None = None,
) -> dict:
vector_store_settings = settings.vector_store
if vector_store_settings is None:
return {"strategy": settings.resolved_strategy()}
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.resolved_strategy() # get the default strategy
strategy.update({
"vector_store": {**(vector_store_params or {}), **vector_store_settings}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
}


def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
Expand Down Expand Up @@ -307,7 +267,7 @@ def _embeddings_workflows(
name=generate_text_embeddings,
config={
"snapshot_embeddings": settings.snapshots.embeddings,
"text_embed": _get_embedding_settings(settings.embeddings),
"text_embed": get_embedding_settings(settings.embeddings),
"embedded_fields": embedded_fields,
},
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ async def _process_document(
history=response.history,
model_parameters=self._loop_args,
)
if response.output != "YES":

if response.output.content != "YES":
break

return results
Expand Down
103 changes: 103 additions & 0 deletions graphrag/index/run/run_workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Different methods to run the pipeline."""

import logging
import time

from datashaper import NoopVerbCallbacks
from datashaper.progress.types import Progress

from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import create_input
from graphrag.index.run.profiling import _dump_stats
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.default_workflows import basic_workflows
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.factory import StorageFactory

log = logging.getLogger(__name__)


default_workflows = [
"create_base_text_units",
"create_final_documents",
"extract_graph",
"create_final_covariates",
"compute_communities",
"create_final_entities",
"create_final_relationships",
"create_final_nodes",
"create_final_communities",
"create_final_text_units",
"create_final_community_reports",
"generate_text_embeddings",
]


async def run_workflows(
config: GraphRagConfig,
cache: PipelineCache | None = None,
logger: ProgressLogger | None = None,
run_id: str | None = None,
):
"""Run all workflows using a simplified pipeline."""
log.info("RUNNING NEW WORKFLOWS WITHOUT DATASHAPER")
start_time = time.time()

run_id = run_id or time.strftime("%Y%m%d-%H%M%S")
root_dir = config.root_dir or ""
progress_logger = logger or NullProgressLogger()
storage_config = config.storage.model_dump() # type: ignore
storage = StorageFactory().create_storage(
storage_type=storage_config["type"], # type: ignore
kwargs=storage_config,
)
cache_config = config.cache.model_dump() # type: ignore
cache = cache or CacheFactory().create_cache(
cache_type=cache_config["type"], # type: ignore
root_dir=root_dir,
kwargs=cache_config,
)

context = create_run_context(storage=storage, cache=cache, stats=None)

dataset = await create_input(config.input, progress_logger, root_dir)
log.info("Final # of rows loaded: %s", len(dataset))
context.stats.num_documents = len(dataset)

await context.runtime_storage.set("input", dataset)

for workflow in default_workflows:
run_workflow = basic_workflows[workflow]
progress = progress_logger.child(workflow, transient=False)
verb_callbacks = DelegatingCallbacks(progress)
work_time = time.time()
await run_workflow(
config,
context,
verb_callbacks,
)
progress(Progress(percent=1))
context.stats.workflows[workflow] = {"overall": time.time() - work_time}

context.stats.total_runtime = time.time() - start_time
await _dump_stats(context.stats, context.storage)


class DelegatingCallbacks(NoopVerbCallbacks):
"""TEMP: this is to wrap into DataShaper callbacks that the flows expect, until we create our own callback system."""

_progress: ProgressLogger

def __init__(self, progress: ProgressLogger):
self._progress = progress
self._progress(Progress(percent=0))

def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
self._progress(progress)
Loading
Loading