10 changes: 6 additions & 4 deletions packages/graphrag/graphrag/config/defaults.py
@@ -24,7 +24,9 @@
EN_STOP_WORDS,
)

DEFAULT_INPUT_BASE_DIR = "input"
DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
@@ -229,7 +231,7 @@ class StorageDefaults:
"""Default values for storage."""

type: ClassVar[StorageType] = StorageType.file
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
base_dir: str | None = None
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None
@@ -240,7 +242,7 @@ class StorageDefaults:
class InputStorageDefaults(StorageDefaults):
"""Default values for input storage."""

base_dir: str = "input"
base_dir: str | None = DEFAULT_INPUT_BASE_DIR


@dataclass
@@ -310,7 +312,7 @@ class LocalSearchDefaults:
class OutputDefaults(StorageDefaults):
"""Default values for output."""

base_dir: str = DEFAULT_OUTPUT_BASE_DIR
base_dir: str | None = DEFAULT_OUTPUT_BASE_DIR


@dataclass
@@ -362,7 +364,7 @@ class SummarizeDescriptionsDefaults:
class UpdateIndexOutputDefaults(StorageDefaults):
"""Default values for update index output."""

base_dir: str = "update_output"
base_dir: str | None = DEFAULT_UPDATE_OUTPUT_BASE_DIR


@dataclass
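Taken together, the defaults.py changes replace the hard-coded directory strings with module-level constants and widen base_dir to str | None: the generic StorageDefaults no longer assumes any directory, while the input, output, and update-output subclasses keep their named defaults. Below is a minimal standalone sketch of the resulting dataclass behavior (simplified; ClassVar members and the connection-related fields are omitted, and only the constant names shown in the diff are real):

from dataclasses import dataclass

DEFAULT_INPUT_BASE_DIR = "input"
DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"

@dataclass
class StorageDefaults:
    # Generic storage no longer assumes a directory; None means "not set".
    base_dir: str | None = None

@dataclass
class InputStorageDefaults(StorageDefaults):
    base_dir: str | None = DEFAULT_INPUT_BASE_DIR

@dataclass
class OutputDefaults(StorageDefaults):
    base_dir: str | None = DEFAULT_OUTPUT_BASE_DIR

@dataclass
class UpdateIndexOutputDefaults(StorageDefaults):
    base_dir: str | None = DEFAULT_UPDATE_OUTPUT_BASE_DIR

assert StorageDefaults().base_dir is None
assert InputStorageDefaults().base_dir == "input"
assert OutputDefaults().base_dir == "output"
assert UpdateIndexOutputDefaults().base_dir == "update_output"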
10 changes: 6 additions & 4 deletions packages/graphrag/graphrag/config/models/graph_rag_config.py
@@ -152,7 +152,7 @@ def _validate_input_pattern(self) -> None:
def _validate_input_base_dir(self) -> None:
"""Validate the input base directory."""
if self.input.storage.type == defs.StorageType.file:
if self.input.storage.base_dir.strip() == "":
if not self.input.storage.base_dir:
msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
raise ValueError(msg)
self.input.storage.base_dir = str(
@@ -167,14 +167,16 @@ def _validate_input_base_dir(self) -> None:

output: StorageConfig = Field(
description="The output configuration.",
default=StorageConfig(),
default=StorageConfig(
base_dir=graphrag_config_defaults.output.base_dir,
),
)
"""The output configuration."""

def _validate_output_base_dir(self) -> None:
"""Validate the output base directory."""
if self.output.type == defs.StorageType.file:
if self.output.base_dir.strip() == "":
if not self.output.base_dir:
msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
raise ValueError(msg)
self.output.base_dir = str(
@@ -192,7 +194,7 @@ def _validate_output_base_dir(self) -> None:
def _validate_update_index_output_base_dir(self) -> None:
"""Validate the update index output base directory."""
if self.update_index_output.type == defs.StorageType.file:
if self.update_index_output.base_dir.strip() == "":
if not self.update_index_output.base_dir:
msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
raise ValueError(msg)
self.update_index_output.base_dir = str(
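The validator changes in graph_rag_config.py mirror the new optional type: "not base_dir" rejects both the new None default and an empty string, whereas the old .strip() == "" comparison would raise AttributeError once base_dir can be None. A small standalone illustration of the difference (not the actual validator; note that a whitespace-only string now passes where it previously failed):

def old_check(base_dir: str) -> bool:
    # Pre-change test: assumes a str, so None would raise AttributeError.
    return base_dir.strip() == ""

def new_check(base_dir: str | None) -> bool:
    # Post-change test: rejects both None and "".
    return not base_dir

print(new_check(None))      # True  -> validator raises ValueError
print(new_check(""))        # True  -> validator raises ValueError
print(new_check("output"))  # False -> accepted
print(new_check("   "))     # False -> accepted (old_check("   ") returned True)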
@@ -18,7 +18,7 @@ class StorageConfig(BaseModel):
description="The storage type to use.",
default=graphrag_config_defaults.storage.type,
)
base_dir: str = Field(
base_dir: str | None = Field(
description="The base directory for the output.",
default=graphrag_config_defaults.storage.base_dir,
)
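In the StorageConfig model the Pydantic field is widened the same way, so a config that omits base_dir loads as None rather than silently inheriting the old "output" default; whether a directory is then required is left to the validators above. A tiny sketch, assuming Pydantic v2 and omitting the model's other fields:

from pydantic import BaseModel, Field

class StorageConfigSketch(BaseModel):
    # Sketch only: the default mirrors graphrag_config_defaults.storage.base_dir,
    # which is now None after this change.
    base_dir: str | None = Field(
        description="The base directory for the output.",
        default=None,
    )

print(StorageConfigSketch().base_dir)                   # None
print(StorageConfigSketch(base_dir="output").base_dir)  # output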
108 changes: 28 additions & 80 deletions packages/graphrag/graphrag/storage/blob_pipeline_storage.py
@@ -25,15 +25,15 @@ class BlobPipelineStorage(PipelineStorage):

_connection_string: str | None
_container_name: str
_path_prefix: str
_base_dir: str | None
_encoding: str
_storage_account_blob_url: str | None

def __init__(self, **kwargs: Any) -> None:
"""Create a new BlobStorage instance."""
connection_string = kwargs.get("connection_string")
storage_account_blob_url = kwargs.get("storage_account_blob_url")
path_prefix = kwargs.get("base_dir")
base_dir = kwargs.get("base_dir")
container_name = kwargs["container_name"]
if container_name is None:
msg = "No container name provided for blob storage."
@@ -42,7 +42,9 @@ def __init__(self, **kwargs: Any) -> None:
msg = "No storage account blob url provided for blob storage."
raise ValueError(msg)

logger.info("Creating blob storage at %s", container_name)
logger.info(
"Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir
)
if connection_string:
self._blob_service_client = BlobServiceClient.from_connection_string(
connection_string
@@ -59,18 +61,13 @@ def __init__(self, **kwargs: Any) -> None:
self._encoding = kwargs.get("encoding", "utf-8")
self._container_name = container_name
self._connection_string = connection_string
self._path_prefix = path_prefix or ""
self._base_dir = base_dir
self._storage_account_blob_url = storage_account_blob_url
self._storage_account_name = (
storage_account_blob_url.split("//")[1].split(".")[0]
if storage_account_blob_url
else None
)
logger.debug(
"creating blob storage at container=%s, path=%s",
self._container_name,
self._path_prefix,
)
self._create_container()

def _create_container(self) -> None:
@@ -82,6 +79,7 @@ def _create_container(self) -> None:
for container in self._blob_service_client.list_containers()
]
if container_name not in container_names:
logger.debug("Creating new container [%s]", container_name)
self._blob_service_client.create_container(container_name)

def _delete_container(self) -> None:
@@ -100,31 +98,26 @@ def _container_exists(self) -> bool:
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
max_count=-1,
) -> Iterator[str]:
"""Find blobs in a container using a file pattern.

Params:
base_dir: The name of the base container.
file_pattern: The file pattern to use.
max_count: The maximum number of blobs to return. If -1, all blobs are returned.

Returns
-------
An iterator of blob names and their corresponding regex matches.
"""
base_dir = base_dir or ""

logger.info(
"search container %s for files matching %s",
"Search container [%s] in base_dir [%s] for files matching [%s]",
self._container_name,
self._base_dir,
file_pattern.pattern,
)

def _blobname(blob_name: str) -> str:
if blob_name.startswith(self._path_prefix):
blob_name = blob_name.replace(self._path_prefix, "", 1)
if self._base_dir and blob_name.startswith(self._base_dir):
blob_name = blob_name.replace(self._base_dir, "", 1)
if blob_name.startswith("/"):
blob_name = blob_name[1:]
return blob_name
@@ -133,37 +126,35 @@ def _blobname(blob_name: str) -> str:
container_client = self._blob_service_client.get_container_client(
self._container_name
)
all_blobs = list(container_client.list_blobs())

all_blobs = list(container_client.list_blobs(self._base_dir))
logger.debug("All blobs: %s", [blob.name for blob in all_blobs])
num_loaded = 0
num_total = len(list(all_blobs))
num_filtered = 0
for blob in all_blobs:
match = file_pattern.search(blob.name)
if match and blob.name.startswith(base_dir):
if match:
yield _blobname(blob.name)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
logger.debug(
"Blobs loaded: %d, filtered: %d, total: %d",
num_loaded,
num_filtered,
num_total,
)
logger.debug(
"Blobs loaded: %d, filtered: %d, total: %d",
num_loaded,
num_filtered,
num_total,
)
except Exception: # noqa: BLE001
logger.warning(
"Error finding blobs: base_dir=%s, file_pattern=%s",
base_dir,
self._base_dir,
file_pattern,
)

async def get(
self, key: str, as_bytes: bool | None = False, encoding: str | None = None
) -> Any:
"""Get a value from the cache."""
"""Get a value from the blob."""
try:
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
@@ -181,7 +172,7 @@ async def get(
return blob_data

async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set a value in the cache."""
"""Set a value in the blob."""
try:
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
@@ -196,46 +187,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
except Exception:
logger.exception("Error setting key %s: %s", key)

def _set_df_json(self, key: str, dataframe: Any) -> None:
"""Set a json dataframe."""
if self._connection_string is None and self._storage_account_name:
dataframe.to_json(
self._abfs_url(key),
storage_options={
"account_name": self._storage_account_name,
"credential": DefaultAzureCredential(),
},
orient="records",
lines=True,
force_ascii=False,
)
else:
dataframe.to_json(
self._abfs_url(key),
storage_options={"connection_string": self._connection_string},
orient="records",
lines=True,
force_ascii=False,
)

def _set_df_parquet(self, key: str, dataframe: Any) -> None:
"""Set a parquet dataframe."""
if self._connection_string is None and self._storage_account_name:
dataframe.to_parquet(
self._abfs_url(key),
storage_options={
"account_name": self._storage_account_name,
"credential": DefaultAzureCredential(),
},
)
else:
dataframe.to_parquet(
self._abfs_url(key),
storage_options={"connection_string": self._connection_string},
)

async def has(self, key: str) -> bool:
"""Check if a key exists in the cache."""
"""Check if a key exists in the blob."""
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
self._container_name
@@ -244,7 +197,7 @@ async def has(self, key: str) -> bool:
return blob_client.exists()

async def delete(self, key: str) -> None:
"""Delete a key from the cache."""
"""Delete a key from the blob."""
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
self._container_name
@@ -259,7 +212,7 @@ def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
if name is None:
return self
path = str(Path(self._path_prefix) / name)
path = str(Path(self._base_dir) / name) if self._base_dir else name
return BlobPipelineStorage(
connection_string=self._connection_string,
container_name=self._container_name,
@@ -275,15 +228,10 @@ def keys(self) -> list[str]:

def _keyname(self, key: str) -> str:
"""Get the key name."""
return str(Path(self._path_prefix) / key)

def _abfs_url(self, key: str) -> str:
"""Get the ABFS URL."""
path = str(Path(self._container_name) / self._path_prefix / key)
return f"abfs://{path}"
return str(Path(self._base_dir) / key) if self._base_dir else key

async def get_creation_date(self, key: str) -> str:
"""Get a value from the cache."""
"""Get creation date for the blob."""
try:
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
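Net effect of the blob_pipeline_storage.py changes: the prefix is stored once as _base_dir (replacing _path_prefix), find() no longer takes a base_dir argument and instead passes the stored prefix to list_blobs() so listing is prefix-filtered up front, and the unused _set_df_json, _set_df_parquet, and _abfs_url helpers are dropped. The remaining prefix handling reduces to two small path rules; below is a standalone sketch of them (hypothetical free functions mirroring _keyname and the inner _blobname from the diff, POSIX-style separators assumed):

from pathlib import PurePosixPath

def keyname(base_dir: str | None, key: str) -> str:
    # Mirrors _keyname: keys get the base_dir prefix only when one is configured.
    return str(PurePosixPath(base_dir) / key) if base_dir else key

def blobname(base_dir: str | None, blob_name: str) -> str:
    # Mirrors _blobname inside find(): strip the prefix back off so callers
    # see names relative to base_dir.
    if base_dir and blob_name.startswith(base_dir):
        blob_name = blob_name.replace(base_dir, "", 1)
    if blob_name.startswith("/"):
        blob_name = blob_name[1:]
    return blob_name

print(keyname("output", "entities.parquet"))          # output/entities.parquet
print(keyname(None, "entities.parquet"))              # entities.parquet
print(blobname("output", "output/entities.parquet"))  # entities.parquet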
30 changes: 12 additions & 18 deletions packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py
@@ -77,7 +77,7 @@ def __init__(self, **kwargs: Any) -> None:
)
self._no_id_prefixes = []
logger.debug(
"creating cosmosdb storage with account: %s and database: %s and container: %s",
"Creating cosmosdb storage with account [%s] and database [%s] and container [%s]",
self._cosmosdb_account_name,
self._database_name,
self._container_name,
@@ -120,23 +120,18 @@ def _delete_container(self) -> None:
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
max_count=-1,
) -> Iterator[str]:
"""Find documents in a Cosmos DB container using a file pattern regex.

Params:
base_dir: The name of the base directory (not used in Cosmos DB context).
file_pattern: The file pattern to use.
max_count: The maximum number of documents to return. If -1, all documents are returned.

Returns
-------
An iterator of document IDs and their corresponding regex matches.
"""
base_dir = base_dir or ""
logger.info(
"search container %s for documents matching %s",
"Search container [%s] for documents matching [%s]",
self._container_name,
file_pattern.pattern,
)
@@ -156,6 +151,7 @@ def find(
enable_cross_partition_query=True,
)
)
logger.debug("All items: %s", [item["id"] for item in items])
num_loaded = 0
num_total = len(items)
if num_total == 0:
@@ -166,20 +162,18 @@ def find(
if match:
yield item["id"]
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1

progress_status = _create_progress_status(
num_loaded, num_filtered, num_total
)
logger.debug(
"Progress: %s (%d/%d completed)",
progress_status.description,
progress_status.completed_items,
progress_status.total_items,
)
progress_status = _create_progress_status(
num_loaded, num_filtered, num_total
)
logger.debug(
"Progress: %s (%d/%d completed)",
progress_status.description,
progress_status.completed_items,
progress_status.total_items,
)
except Exception: # noqa: BLE001
logger.warning(
"An error occurred while searching for documents in Cosmos DB."
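In cosmosdb_pipeline_storage.py, find() likewise drops the base_dir parameter (which was unused in the Cosmos DB context), and the progress logging appears to move out of the item loop, so it is emitted once per search rather than once per document. Below is a minimal sketch of the resulting loop shape, with a stand-in item list instead of a Cosmos DB query and a plain print in place of _create_progress_status plus logger.debug:

import re
from collections.abc import Iterator

def find_ids(
    items: list[dict],
    file_pattern: re.Pattern[str],
    max_count: int = -1,
) -> Iterator[str]:
    # Yields ids that match the pattern, mirroring the post-change loop.
    num_loaded = num_filtered = 0
    for item in items:
        if file_pattern.search(item["id"]):
            yield item["id"]
            num_loaded += 1
            if 0 < max_count <= num_loaded:
                break
        else:
            num_filtered += 1
    # Progress is reported once here, after the loop, not per item.
    print(f"loaded={num_loaded} filtered={num_filtered} total={len(items)}")

docs = [{"id": "entities.parquet"}, {"id": "relationships.parquet"}, {"id": "notes.txt"}]
print(list(find_ids(docs, re.compile(r"\.parquet$"))))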