Chore/lib updates (#1477)
* Update dependencies and fix issues

* Format

* Semver

* Fix Pyright

* Pyright

* More Pyright

* Pyright
AlonsoGuevara authored Dec 6, 2024
1 parent b1f2ca7 commit 1c3b0f3
Showing 71 changed files with 554 additions and 537 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241206190229362643.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Dependency updates"
}
@@ -209,15 +209,16 @@
"source": [
"# create constraints, idempotent operation\n",
"\n",
"statements = \"\"\"\n",
"create constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique;\n",
"create constraint document_id if not exists for (d:__Document__) require d.id is unique;\n",
"create constraint entity_id if not exists for (c:__Community__) require c.community is unique;\n",
"create constraint entity_id if not exists for (e:__Entity__) require e.id is unique;\n",
"create constraint entity_title if not exists for (e:__Entity__) require e.name is unique;\n",
"create constraint entity_title if not exists for (e:__Covariate__) require e.title is unique;\n",
"create constraint related_id if not exists for ()-[rel:RELATED]->() require rel.id is unique;\n",
"\"\"\".split(\";\")\n",
"statements = [\n",
" \"\\ncreate constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique\",\n",
" \"\\ncreate constraint document_id if not exists for (d:__Document__) require d.id is unique\",\n",
" \"\\ncreate constraint entity_id if not exists for (c:__Community__) require c.community is unique\",\n",
" \"\\ncreate constraint entity_id if not exists for (e:__Entity__) require e.id is unique\",\n",
" \"\\ncreate constraint entity_title if not exists for (e:__Entity__) require e.name is unique\",\n",
" \"\\ncreate constraint entity_title if not exists for (e:__Covariate__) require e.title is unique\",\n",
" \"\\ncreate constraint related_id if not exists for ()-[rel:RELATED]->() require rel.id is unique\",\n",
" \"\\n\",\n",
"]\n",
"\n",
"for statement in statements:\n",
" if len((statement or \"\").strip()) > 0:\n",
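Note: the rewritten notebook cell builds the Cypher constraint statements as an explicit list instead of splitting one multi-line string on ";". Below is a minimal sketch of running such statements against Neo4j, assuming the official `neo4j` Python driver and placeholder connection details (the notebook's own driver setup is not shown in this hunk):

```python
# Sketch only: placeholder URI, credentials, and database name, not values from the notebook.
from neo4j import GraphDatabase

NEO4J_URI = "bolt://localhost:7687"
NEO4J_AUTH = ("neo4j", "password")
NEO4J_DATABASE = "neo4j"

statements = [
    "create constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique",
    "create constraint document_id if not exists for (d:__Document__) require d.id is unique",
]

with GraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH) as driver:
    for statement in statements:
        if statement.strip():  # skip blank entries such as the trailing "\n"
            driver.execute_query(statement, database_=NEO4J_DATABASE)
```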
52 changes: 26 additions & 26 deletions graphrag/api/query.py
@@ -44,7 +44,7 @@
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.query.structured_search.base import SearchResult # noqa: TC001
from graphrag.utils.cli import redact
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores.base import BaseVectorStore
@@ -90,14 +90,14 @@ async def global_search(
------
TODO: Document any exceptions to expect.
"""
_communities = read_indexer_communities(communities, nodes, community_reports)
communities_ = read_indexer_communities(communities, nodes, community_reports)
reports = read_indexer_reports(
community_reports,
nodes,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
entities_ = read_indexer_entities(nodes, entities, community_level=community_level)
map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
@@ -109,8 +109,8 @@
search_engine = get_global_search_engine(
config,
reports=reports,
entities=_entities,
communities=_communities,
entities=entities_,
communities=communities_,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
@@ -159,14 +159,14 @@ async def global_search_streaming(
------
TODO: Document any exceptions to expect.
"""
_communities = read_indexer_communities(communities, nodes, community_reports)
communities_ = read_indexer_communities(communities, nodes, community_reports)
reports = read_indexer_reports(
community_reports,
nodes,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
entities_ = read_indexer_entities(nodes, entities, community_level=community_level)
map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
@@ -178,8 +178,8 @@
search_engine = get_global_search_engine(
config,
reports=reports,
entities=_entities,
communities=_communities,
entities=entities_,
communities=communities_,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
@@ -258,17 +258,17 @@ async def local_search(
embedding_name=entity_description_embedding,
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
entities_ = read_indexer_entities(nodes, entities, community_level)
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, nodes, community_level),
text_units=read_indexer_text_units(text_units),
entities=_entities,
entities=entities_,
relationships=read_indexer_relationships(relationships),
covariates={"claims": _covariates},
covariates={"claims": covariates_},
description_embedding_store=description_embedding_store, # type: ignore
response_type=response_type,
system_prompt=prompt,
@@ -334,17 +334,17 @@ async def local_search_streaming(
embedding_name=entity_description_embedding,
)

_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
entities_ = read_indexer_entities(nodes, entities, community_level)
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, nodes, community_level),
text_units=read_indexer_text_units(text_units),
entities=_entities,
entities=entities_,
relationships=read_indexer_relationships(relationships),
covariates={"claims": _covariates},
covariates={"claims": covariates_},
description_embedding_store=description_embedding_store, # type: ignore
response_type=response_type,
system_prompt=prompt,
@@ -424,15 +424,15 @@ async def drift_search(
embedding_name=community_full_content_embedding,
)

_entities = read_indexer_entities(nodes, entities, community_level)
_reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(_reports, full_content_embedding_store)
entities_ = read_indexer_entities(nodes, entities, community_level)
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
search_engine = get_drift_search_engine(
config=config,
reports=_reports,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=_entities,
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
@@ -492,9 +492,9 @@ def _patch_vector_store(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
_entities = read_indexer_entities(nodes, entities, community_level)
entities_ = read_indexer_entities(nodes, entities, community_level)
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
entities=entities_, vectorstore=description_embedding_store
)

if with_reports is not None:
@@ -506,7 +506,7 @@ def _patch_vector_store(
community_reports = with_reports
container_name = config.embeddings.vector_store["container_name"]
# Store report embeddings
_reports = read_indexer_reports(
reports = read_indexer_reports(
community_reports,
nodes,
community_level,
@@ -526,7 +526,7 @@
)
# dump embeddings from the reports list to the full_content_embedding_store
store_reports_semantic_embeddings(
reports=_reports, vectorstore=full_content_embedding_store
reports=reports, vectorstore=full_content_embedding_store
)

return config
10 changes: 4 additions & 6 deletions graphrag/cache/factory.py
@@ -8,17 +8,15 @@
from typing import TYPE_CHECKING, cast

from graphrag.config.enums import CacheType
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineFileCacheConfig,
)
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage

if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineCacheConfig,
PipelineFileCacheConfig,
)

from graphrag.cache.json_pipeline_cache import JsonPipelineCache
@@ -39,11 +37,11 @@ def create_cache(
case CacheType.memory:
return InMemoryCache()
case CacheType.file:
config = cast(PipelineFileCacheConfig, config)
config = cast("PipelineFileCacheConfig", config)
storage = FilePipelineStorage(root_dir).child(config.base_dir)
return JsonPipelineCache(storage)
case CacheType.blob:
config = cast(PipelineBlobCacheConfig, config)
config = cast("PipelineBlobCacheConfig", config)
storage = BlobPipelineStorage(
config.connection_string,
config.container_name,
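Note: these factory changes keep the cache config classes behind `if TYPE_CHECKING:` and pass `cast()` the type name as a string, so the classes are never imported at runtime. A minimal sketch of the pattern, with hypothetical module and class names:

```python
# Hypothetical names (mypackage.config, BaseCacheConfig, FileCacheConfig):
# typing-only imports live under TYPE_CHECKING, and cast() takes a string,
# so nothing from mypackage.config is imported when this module runs.
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    from mypackage.config import BaseCacheConfig, FileCacheConfig


def cache_dir(config: "BaseCacheConfig") -> str:
    file_config = cast("FileCacheConfig", config)  # no runtime class reference
    return file_config.base_dir
```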
4 changes: 2 additions & 2 deletions graphrag/callbacks/factory.py
@@ -27,14 +27,14 @@ def create_pipeline_reporter(

match config.type:
case ReportingType.file:
config = cast(PipelineFileReportingConfig, config)
config = cast("PipelineFileReportingConfig", config)
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
config = cast(PipelineBlobReportingConfig, config)
config = cast("PipelineBlobReportingConfig", config)
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
19 changes: 15 additions & 4 deletions graphrag/cli/main.py
@@ -46,26 +46,37 @@ def wildcard_match(string: str, pattern: str) -> bool:
regex = re.escape(pattern).replace(r"\?", ".").replace(r"\*", ".*")
return re.fullmatch(regex, string) is not None

from pathlib import Path

def completer(incomplete: str) -> list[str]:
items = os.listdir()
# List items in the current directory as Path objects
items = Path().iterdir()
completions = []

for item in items:
if not file_okay and Path(item).is_file():
# Filter based on file/directory properties
if not file_okay and item.is_file():
continue
if not dir_okay and Path(item).is_dir():
if not dir_okay and item.is_dir():
continue
if readable and not os.access(item, os.R_OK):
continue
if writable and not os.access(item, os.W_OK):
continue
completions.append(item)

# Append the name of the matching item
completions.append(item.name)

# Apply wildcard matching if required
if match_wildcard:
completions = filter(
lambda i: wildcard_match(i, match_wildcard)
if match_wildcard
else False,
completions,
)

# Return completions that start with the given incomplete string
return [i for i in completions if i.startswith(incomplete)]

return completer
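Note: the rewritten completer iterates `pathlib.Path.iterdir()` and filters the `Path` objects directly before matching on names. A hedged sketch of wiring such a completer factory to a Typer option via `autocompletion` follows (command and parameter names are illustrative, not taken from the repo):

```python
# Illustrative only: a simplified directory completer attached to a Typer option.
import os
from pathlib import Path

import typer

app = typer.Typer()


def make_completer(dir_okay: bool = True, file_okay: bool = True):
    def completer(incomplete: str) -> list[str]:
        names = []
        for item in Path().iterdir():
            if not file_okay and item.is_file():
                continue
            if not dir_okay and item.is_dir():
                continue
            if not os.access(item, os.R_OK):
                continue
            names.append(item.name)
        return [name for name in names if name.startswith(incomplete)]

    return completer


@app.command()
def init(
    root: Path = typer.Option(Path(), autocompletion=make_completer(file_okay=False)),
):
    typer.echo(f"Initializing in {root}")
```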
2 changes: 1 addition & 1 deletion graphrag/config/create_graphrag_config.py
@@ -64,7 +64,7 @@ def create_graphrag_config(
values = values or {}
root_dir = root_dir or str(Path.cwd())
env = _make_env(root_dir)
_token_replace(cast(dict, values))
_token_replace(cast("dict", values))
InputModelValidator.validate_python(values, strict=True)

reader = EnvironmentReader(env)
7 changes: 7 additions & 0 deletions graphrag/config/defaults.py
@@ -97,6 +97,13 @@
overwrite: true\
"""

VECTOR_STORE_DICT = {
"type": VectorStoreType.LanceDB.value,
"db_uri": str(Path(STORAGE_BASE_DIR) / "lancedb"),
"container_name": "default",
"overwrite": True,
}

# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
LOCAL_SEARCH_COMMUNITY_PROP = 0.1
2 changes: 1 addition & 1 deletion graphrag/config/models/graph_rag_config.py
@@ -41,7 +41,7 @@ def __str__(self):
return self.model_dump_json(indent=4)

root_dir: str = Field(
description="The root directory for the configuration.", default=None
description="The root directory for the configuration.", default="."
)

reporting: ReportingConfig = Field(
2 changes: 1 addition & 1 deletion graphrag/config/models/text_embedding_config.py
@@ -26,7 +26,7 @@ class TextEmbeddingConfig(LLMConfig):
)
skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
vector_store: dict | None = Field(
description="The vector storage configuration", default=defs.VECTOR_STORE
description="The vector storage configuration", default=defs.VECTOR_STORE_DICT
)
strategy: dict | None = Field(
description="The override strategy to use.", default=None
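Note: using a plain dict as the `vector_store` Field default works here because Pydantic copies mutable Field defaults per model instance, so instances do not share state. A small illustrative check with a made-up model:

```python
# Illustration only (hypothetical model): Pydantic copies mutable Field defaults
# per instance, so mutating one instance's dict does not affect another.
from pydantic import BaseModel, Field


class EmbeddingConfig(BaseModel):
    vector_store: dict = Field(default={"type": "lancedb", "overwrite": True})


a = EmbeddingConfig()
b = EmbeddingConfig()
a.vector_store["overwrite"] = False
assert b.vector_store["overwrite"] is True
```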
15 changes: 6 additions & 9 deletions graphrag/index/config/cache.py
@@ -7,8 +7,7 @@

from typing import Generic, Literal, TypeVar

from pydantic import BaseModel
from pydantic import Field as pydantic_Field
from pydantic import BaseModel, Field

from graphrag.config.enums import CacheType

@@ -27,7 +26,7 @@ class PipelineFileCacheConfig(PipelineCacheConfig[Literal[CacheType.file]]):
type: Literal[CacheType.file] = CacheType.file
"""The type of cache."""

base_dir: str | None = pydantic_Field(
base_dir: str | None = Field(
description="The base directory for the cache.", default=None
)
"""The base directory for the cache."""
@@ -53,22 +52,20 @@ class PipelineBlobCacheConfig(PipelineCacheConfig[Literal[CacheType.blob]]):
type: Literal[CacheType.blob] = CacheType.blob
"""The type of cache."""

base_dir: str | None = pydantic_Field(
base_dir: str | None = Field(
description="The base directory for the cache.", default=None
)
"""The base directory for the cache."""

connection_string: str | None = pydantic_Field(
connection_string: str | None = Field(
description="The blob cache connection string for the cache.", default=None
)
"""The blob cache connection string for the cache."""

container_name: str = pydantic_Field(
description="The container name for cache", default=None
)
container_name: str = Field(description="The container name for cache", default="")
"""The container name for cache"""

storage_account_blob_url: str | None = pydantic_Field(
storage_account_blob_url: str | None = Field(
description="The storage account blob url for cache", default=None
)
"""The storage account blob url for cache"""