diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py
index 8a04851682..f57bf610f9 100644
--- a/packages/graphrag/graphrag/config/defaults.py
+++ b/packages/graphrag/graphrag/config/defaults.py
@@ -24,7 +24,9 @@
     EN_STOP_WORDS,
 )
 
+DEFAULT_INPUT_BASE_DIR = "input"
 DEFAULT_OUTPUT_BASE_DIR = "output"
+DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
 DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
 DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
@@ -229,7 +231,7 @@ class StorageDefaults:
     """Default values for storage."""
 
     type: ClassVar[StorageType] = StorageType.file
-    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    base_dir: str | None = None
     connection_string: None = None
     container_name: None = None
     storage_account_blob_url: None = None
@@ -240,7 +242,7 @@ class StorageDefaults:
 class InputStorageDefaults(StorageDefaults):
     """Default values for input storage."""
 
-    base_dir: str = "input"
+    base_dir: str | None = DEFAULT_INPUT_BASE_DIR
 
 
 @dataclass
@@ -310,7 +312,7 @@ class LocalSearchDefaults:
 class OutputDefaults(StorageDefaults):
     """Default values for output."""
 
-    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    base_dir: str | None = DEFAULT_OUTPUT_BASE_DIR
 
 
 @dataclass
@@ -362,7 +364,7 @@ class SummarizeDescriptionsDefaults:
 class UpdateIndexOutputDefaults(StorageDefaults):
     """Default values for update index output."""
 
-    base_dir: str = "update_output"
+    base_dir: str | None = DEFAULT_UPDATE_OUTPUT_BASE_DIR
 
 
 @dataclass
diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py
index 6a6a98e973..71509f0176 100644
--- a/packages/graphrag/graphrag/config/models/graph_rag_config.py
+++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py
@@ -152,7 +152,7 @@ def _validate_input_pattern(self) -> None:
     def _validate_input_base_dir(self) -> None:
         """Validate the input base directory."""
         if self.input.storage.type == defs.StorageType.file:
-            if self.input.storage.base_dir.strip() == "":
+            if not self.input.storage.base_dir:
                 msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
                 raise ValueError(msg)
             self.input.storage.base_dir = str(
@@ -167,14 +167,16 @@ def _validate_input_base_dir(self) -> None:
 
     output: StorageConfig = Field(
         description="The output configuration.",
-        default=StorageConfig(),
+        default=StorageConfig(
+            base_dir=graphrag_config_defaults.output.base_dir,
+        ),
     )
     """The output configuration."""
 
     def _validate_output_base_dir(self) -> None:
         """Validate the output base directory."""
         if self.output.type == defs.StorageType.file:
-            if self.output.base_dir.strip() == "":
+            if not self.output.base_dir:
                 msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
                 raise ValueError(msg)
             self.output.base_dir = str(
@@ -192,7 +194,7 @@ def _validate_output_base_dir(self) -> None:
     def _validate_update_index_output_base_dir(self) -> None:
         """Validate the update index output base directory."""
         if self.update_index_output.type == defs.StorageType.file:
-            if self.update_index_output.base_dir.strip() == "":
+            if not self.update_index_output.base_dir:
                 msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
                 raise ValueError(msg)
             self.update_index_output.base_dir = str(
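Note on the validator change above: `not base_dir` is None-safe, which matters now that `base_dir` is typed `str | None`; the old `base_dir.strip() == ""` check would raise `AttributeError` on `None`. The trade-off is that whitespace-only strings are no longer rejected. A minimal standalone sketch of the difference (hypothetical helper, not code from this patch):

```python
def base_dir_missing(base_dir: str | None) -> bool:
    # Old check: base_dir.strip() == ""  -> AttributeError when base_dir is None.
    # New check: one falsy test rejects both None and "".
    return not base_dir

assert base_dir_missing(None) and base_dir_missing("")
assert not base_dir_missing("output")
assert not base_dir_missing("   ")  # whitespace-only now passes validation
```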
diff --git a/packages/graphrag/graphrag/config/models/storage_config.py b/packages/graphrag/graphrag/config/models/storage_config.py
index 3f01448c66..7491454c0a 100644
--- a/packages/graphrag/graphrag/config/models/storage_config.py
+++ b/packages/graphrag/graphrag/config/models/storage_config.py
@@ -18,7 +18,7 @@ class StorageConfig(BaseModel):
         description="The storage type to use.",
         default=graphrag_config_defaults.storage.type,
     )
-    base_dir: str = Field(
+    base_dir: str | None = Field(
         description="The base directory for the output.",
         default=graphrag_config_defaults.storage.base_dir,
     )
diff --git a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
index 5a00af85b5..1435cb387d 100644
--- a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
@@ -25,7 +25,7 @@ class BlobPipelineStorage(PipelineStorage):
 
     _connection_string: str | None
     _container_name: str
-    _path_prefix: str
+    _base_dir: str | None
     _encoding: str
     _storage_account_blob_url: str | None
 
@@ -33,7 +33,7 @@ def __init__(self, **kwargs: Any) -> None:
         """Create a new BlobStorage instance."""
         connection_string = kwargs.get("connection_string")
         storage_account_blob_url = kwargs.get("storage_account_blob_url")
-        path_prefix = kwargs.get("base_dir")
+        base_dir = kwargs.get("base_dir")
         container_name = kwargs["container_name"]
         if container_name is None:
             msg = "No container name provided for blob storage."
@@ -42,7 +42,9 @@ def __init__(self, **kwargs: Any) -> None:
             msg = "No storage account blob url provided for blob storage."
             raise ValueError(msg)
-        logger.info("Creating blob storage at %s", container_name)
+        logger.info(
+            "Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir
+        )
         if connection_string:
             self._blob_service_client = BlobServiceClient.from_connection_string(
                 connection_string
@@ -59,18 +61,13 @@ def __init__(self, **kwargs: Any) -> None:
         self._encoding = kwargs.get("encoding", "utf-8")
         self._container_name = container_name
         self._connection_string = connection_string
-        self._path_prefix = path_prefix or ""
+        self._base_dir = base_dir
         self._storage_account_blob_url = storage_account_blob_url
         self._storage_account_name = (
             storage_account_blob_url.split("//")[1].split(".")[0]
             if storage_account_blob_url
             else None
         )
-        logger.debug(
-            "creating blob storage at container=%s, path=%s",
-            self._container_name,
-            self._path_prefix,
-        )
         self._create_container()
 
     def _create_container(self) -> None:
@@ -82,6 +79,7 @@ def _create_container(self) -> None:
             for container in self._blob_service_client.list_containers()
         ]
         if container_name not in container_names:
+            logger.debug("Creating new container [%s]", container_name)
             self._blob_service_client.create_container(container_name)
 
     def _delete_container(self) -> None:
@@ -100,31 +98,26 @@ def _container_exists(self) -> bool:
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
-        max_count=-1,
     ) -> Iterator[str]:
         """Find blobs in a container using a file pattern.
 
         Params:
-            base_dir: The name of the base container.
             file_pattern: The file pattern to use.
-            max_count: The maximum number of blobs to return. If -1, all blobs are returned.
 
         Returns
         -------
         An iterator of blob names and their corresponding regex matches.
         """
-        base_dir = base_dir or ""
         logger.info(
-            "search container %s for files matching %s",
+            "Search container [%s] in base_dir [%s] for files matching [%s]",
             self._container_name,
+            self._base_dir,
             file_pattern.pattern,
         )
 
         def _blobname(blob_name: str) -> str:
-            if blob_name.startswith(self._path_prefix):
-                blob_name = blob_name.replace(self._path_prefix, "", 1)
+            if self._base_dir and blob_name.startswith(self._base_dir):
+                blob_name = blob_name.replace(self._base_dir, "", 1)
             if blob_name.startswith("/"):
                 blob_name = blob_name[1:]
             return blob_name
@@ -133,37 +126,35 @@ def _blobname(blob_name: str) -> str:
             container_client = self._blob_service_client.get_container_client(
                 self._container_name
             )
-            all_blobs = list(container_client.list_blobs())
-
+            all_blobs = list(container_client.list_blobs(self._base_dir))
+            logger.debug("All blobs: %s", [blob.name for blob in all_blobs])
             num_loaded = 0
             num_total = len(list(all_blobs))
             num_filtered = 0
             for blob in all_blobs:
                 match = file_pattern.search(blob.name)
-                if match and blob.name.startswith(base_dir):
+                if match:
                     yield _blobname(blob.name)
                     num_loaded += 1
-                    if max_count > 0 and num_loaded >= max_count:
-                        break
                 else:
                     num_filtered += 1
-                logger.debug(
-                    "Blobs loaded: %d, filtered: %d, total: %d",
-                    num_loaded,
-                    num_filtered,
-                    num_total,
-                )
+            logger.debug(
+                "Blobs loaded: %d, filtered: %d, total: %d",
+                num_loaded,
+                num_filtered,
+                num_total,
+            )
         except Exception:  # noqa: BLE001
             logger.warning(
                 "Error finding blobs: base_dir=%s, file_pattern=%s",
-                base_dir,
+                self._base_dir,
                 file_pattern,
             )
 
     async def get(
         self, key: str, as_bytes: bool | None = False, encoding: str | None = None
     ) -> Any:
-        """Get a value from the cache."""
+        """Get a value from the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
@@ -181,7 +172,7 @@ async def get(
         return blob_data
 
     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
-        """Set a value in the cache."""
+        """Set a value in the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
@@ -196,46 +187,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         except Exception:
             logger.exception("Error setting key %s: %s", key)
 
-    def _set_df_json(self, key: str, dataframe: Any) -> None:
-        """Set a json dataframe."""
-        if self._connection_string is None and self._storage_account_name:
-            dataframe.to_json(
-                self._abfs_url(key),
-                storage_options={
-                    "account_name": self._storage_account_name,
-                    "credential": DefaultAzureCredential(),
-                },
-                orient="records",
-                lines=True,
-                force_ascii=False,
-            )
-        else:
-            dataframe.to_json(
-                self._abfs_url(key),
-                storage_options={"connection_string": self._connection_string},
-                orient="records",
-                lines=True,
-                force_ascii=False,
-            )
-
-    def _set_df_parquet(self, key: str, dataframe: Any) -> None:
-        """Set a parquet dataframe."""
-        if self._connection_string is None and self._storage_account_name:
-            dataframe.to_parquet(
-                self._abfs_url(key),
-                storage_options={
-                    "account_name": self._storage_account_name,
-                    "credential": DefaultAzureCredential(),
-                },
-            )
-        else:
-            dataframe.to_parquet(
-                self._abfs_url(key),
-                storage_options={"connection_string": self._connection_string},
-            )
-
     async def has(self, key: str) -> bool:
-        """Check if a key exists in the cache."""
+        """Check if a key exists in the blob."""
         key = self._keyname(key)
         container_client = self._blob_service_client.get_container_client(
             self._container_name
@@ -244,7 +197,7 @@ async def has(self, key: str) -> bool:
         return blob_client.exists()
 
     async def delete(self, key: str) -> None:
-        """Delete a key from the cache."""
+        """Delete a key from the blob."""
         key = self._keyname(key)
         container_client = self._blob_service_client.get_container_client(
             self._container_name
@@ -259,7 +212,7 @@ def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
         if name is None:
             return self
-        path = str(Path(self._path_prefix) / name)
+        path = str(Path(self._base_dir) / name) if self._base_dir else name
         return BlobPipelineStorage(
             connection_string=self._connection_string,
             container_name=self._container_name,
@@ -275,15 +228,10 @@ def keys(self) -> list[str]:
 
     def _keyname(self, key: str) -> str:
         """Get the key name."""
-        return str(Path(self._path_prefix) / key)
-
-    def _abfs_url(self, key: str) -> str:
-        """Get the ABFS URL."""
-        path = str(Path(self._container_name) / self._path_prefix / key)
-        return f"abfs://{path}"
+        return str(Path(self._base_dir) / key) if self._base_dir else key
 
     async def get_creation_date(self, key: str) -> str:
-        """Get a value from the cache."""
+        """Get creation date for the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
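Note on the blob storage changes: with `base_dir` optional, key prefixing must be conditional, since `Path(None) / key` raises a `TypeError`. A standalone sketch mirroring the new `_keyname` logic (illustrative helper only):

```python
from pathlib import Path

def keyname(base_dir: str | None, key: str) -> str:
    # Prefix keys only when a base_dir is configured; otherwise use the raw key.
    return str(Path(base_dir) / key) if base_dir else key

assert keyname("input", "christmas.txt") == str(Path("input") / "christmas.txt")
assert keyname(None, "christmas.txt") == "christmas.txt"
```

Passing `self._base_dir` to `list_blobs` also narrows the listing server-side: the first positional argument of `ContainerClient.list_blobs` is a `name_starts_with` prefix, so only blobs under the base directory are enumerated.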
""" - base_dir = base_dir or "" logger.info( - "search container %s for documents matching %s", + "Search container [%s] for documents matching [%s]", self._container_name, file_pattern.pattern, ) @@ -156,6 +151,7 @@ def find( enable_cross_partition_query=True, ) ) + logger.debug("All items: %s", [item["id"] for item in items]) num_loaded = 0 num_total = len(items) if num_total == 0: @@ -166,20 +162,18 @@ def find( if match: yield item["id"] num_loaded += 1 - if max_count > 0 and num_loaded >= max_count: - break else: num_filtered += 1 - progress_status = _create_progress_status( - num_loaded, num_filtered, num_total - ) - logger.debug( - "Progress: %s (%d/%d completed)", - progress_status.description, - progress_status.completed_items, - progress_status.total_items, - ) + progress_status = _create_progress_status( + num_loaded, num_filtered, num_total + ) + logger.debug( + "Progress: %s (%d/%d completed)", + progress_status.description, + progress_status.completed_items, + progress_status.total_items, + ) except Exception: # noqa: BLE001 logger.warning( "An error occurred while searching for documents in Cosmos DB." diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag/graphrag/storage/file_pipeline_storage.py index 15445b0d3a..52402c8bd6 100644 --- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py +++ b/packages/graphrag/graphrag/storage/file_pipeline_storage.py @@ -27,62 +27,54 @@ class FilePipelineStorage(PipelineStorage): """File storage class definition.""" - _root_dir: str + _base_dir: str _encoding: str def __init__(self, **kwargs: Any) -> None: """Create a file based storage.""" - self._root_dir = kwargs.get("base_dir", "") + self._base_dir = kwargs.get("base_dir", "") self._encoding = kwargs.get("encoding", "utf-8") - logger.info("Creating file storage at %s", self._root_dir) - Path(self._root_dir).mkdir(parents=True, exist_ok=True) + logger.info("Creating file storage at [%s]", self._base_dir) + Path(self._base_dir).mkdir(parents=True, exist_ok=True) def find( self, file_pattern: re.Pattern[str], - base_dir: str | None = None, - max_count=-1, ) -> Iterator[str]: """Find files in the storage using a file pattern.""" - search_path = Path(self._root_dir) / (base_dir or "") logger.info( - "search %s for files matching %s", search_path, file_pattern.pattern + "Search [%s] for files matching [%s]", self._base_dir, file_pattern.pattern ) - all_files = list(search_path.rglob("**/*")) + all_files = list(Path(self._base_dir).rglob("**/*")) + logger.debug("All files and folders: %s", [file.name for file in all_files]) num_loaded = 0 num_total = len(all_files) num_filtered = 0 for file in all_files: match = file_pattern.search(f"{file}") if match: - filename = f"{file}".replace(self._root_dir, "") + filename = f"{file}".replace(str(Path(self._base_dir)), "", 1) if filename.startswith(os.sep): filename = filename[1:] yield filename num_loaded += 1 - if max_count > 0 and num_loaded >= max_count: - break else: num_filtered += 1 - logger.debug( - "Files loaded: %d, filtered: %d, total: %d", - num_loaded, - num_filtered, - num_total, - ) + logger.debug( + "Files loaded: %d, filtered: %d, total: %d", + num_loaded, + num_filtered, + num_total, + ) async def get( self, key: str, as_bytes: bool | None = False, encoding: str | None = None ) -> Any: """Get method definition.""" - file_path = join_path(self._root_dir, key) + file_path = join_path(self._base_dir, key) if await self.has(key): return await self._read_file(file_path, as_bytes, encoding) - 
diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
index 15445b0d3a..52402c8bd6 100644
--- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
@@ -27,62 +27,54 @@ class FilePipelineStorage(PipelineStorage):
     """File storage class definition."""
 
-    _root_dir: str
+    _base_dir: str
     _encoding: str
 
     def __init__(self, **kwargs: Any) -> None:
         """Create a file based storage."""
-        self._root_dir = kwargs.get("base_dir", "")
+        self._base_dir = kwargs.get("base_dir", "")
         self._encoding = kwargs.get("encoding", "utf-8")
-        logger.info("Creating file storage at %s", self._root_dir)
-        Path(self._root_dir).mkdir(parents=True, exist_ok=True)
+        logger.info("Creating file storage at [%s]", self._base_dir)
+        Path(self._base_dir).mkdir(parents=True, exist_ok=True)
 
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
-        max_count=-1,
     ) -> Iterator[str]:
         """Find files in the storage using a file pattern."""
-        search_path = Path(self._root_dir) / (base_dir or "")
         logger.info(
-            "search %s for files matching %s", search_path, file_pattern.pattern
+            "Search [%s] for files matching [%s]", self._base_dir, file_pattern.pattern
         )
-        all_files = list(search_path.rglob("**/*"))
+        all_files = list(Path(self._base_dir).rglob("**/*"))
+        logger.debug("All files and folders: %s", [file.name for file in all_files])
         num_loaded = 0
         num_total = len(all_files)
         num_filtered = 0
         for file in all_files:
             match = file_pattern.search(f"{file}")
             if match:
-                filename = f"{file}".replace(self._root_dir, "")
+                filename = f"{file}".replace(str(Path(self._base_dir)), "", 1)
                 if filename.startswith(os.sep):
                     filename = filename[1:]
                 yield filename
                 num_loaded += 1
-                if max_count > 0 and num_loaded >= max_count:
-                    break
             else:
                 num_filtered += 1
-            logger.debug(
-                "Files loaded: %d, filtered: %d, total: %d",
-                num_loaded,
-                num_filtered,
-                num_total,
-            )
+        logger.debug(
+            "Files loaded: %d, filtered: %d, total: %d",
+            num_loaded,
+            num_filtered,
+            num_total,
+        )
 
     async def get(
         self, key: str, as_bytes: bool | None = False, encoding: str | None = None
     ) -> Any:
         """Get method definition."""
-        file_path = join_path(self._root_dir, key)
+        file_path = join_path(self._base_dir, key)
 
         if await self.has(key):
             return await self._read_file(file_path, as_bytes, encoding)
-
-        if await exists(key):
-            # Lookup for key, as it is pressumably a new file loaded from inputs
-            # and not yet written to storage
-            return await self._read_file(key, as_bytes, encoding)
 
         return None
 
@@ -109,7 +101,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         write_type = "wb" if is_bytes else "w"
         encoding = None if is_bytes else encoding or self._encoding
         async with aiofiles.open(
-            join_path(self._root_dir, key),
+            join_path(self._base_dir, key),
             cast("Any", write_type),
             encoding=encoding,
         ) as f:
@@ -117,16 +109,16 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
 
     async def has(self, key: str) -> bool:
         """Has method definition."""
-        return await exists(join_path(self._root_dir, key))
+        return await exists(join_path(self._base_dir, key))
 
     async def delete(self, key: str) -> None:
         """Delete method definition."""
         if await self.has(key):
-            await remove(join_path(self._root_dir, key))
+            await remove(join_path(self._base_dir, key))
 
     async def clear(self) -> None:
         """Clear method definition."""
-        for file in Path(self._root_dir).glob("*"):
+        for file in Path(self._base_dir).glob("*"):
             if file.is_dir():
                 shutil.rmtree(file)
             else:
@@ -136,16 +128,16 @@ def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
         if name is None:
             return self
-        child_path = str(Path(self._root_dir) / Path(name))
+        child_path = str(Path(self._base_dir) / Path(name))
         return FilePipelineStorage(base_dir=child_path, encoding=self._encoding)
 
     def keys(self) -> list[str]:
         """Return the keys in the storage."""
-        return [item.name for item in Path(self._root_dir).iterdir() if item.is_file()]
+        return [item.name for item in Path(self._base_dir).iterdir() if item.is_file()]
 
     async def get_creation_date(self, key: str) -> str:
         """Get the creation date of a file."""
-        file_path = Path(join_path(self._root_dir, key))
+        file_path = Path(join_path(self._base_dir, key))
 
         creation_timestamp = file_path.stat().st_ctime
         creation_time_utc = datetime.fromtimestamp(creation_timestamp, tz=timezone.utc)
diff --git a/packages/graphrag/graphrag/storage/pipeline_storage.py b/packages/graphrag/graphrag/storage/pipeline_storage.py
index ba3ab86e97..5c79921736 100644
--- a/packages/graphrag/graphrag/storage/pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/pipeline_storage.py
@@ -17,8 +17,6 @@ class PipelineStorage(metaclass=ABCMeta):
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
-        max_count=-1,
     ) -> Iterator[str]:
         """Find files in the storage using a file pattern."""
diff --git a/tests/integration/storage/test_blob_pipeline_storage.py b/tests/integration/storage/test_blob_pipeline_storage.py
index f99a74ff74..818b588bd6 100644
--- a/tests/integration/storage/test_blob_pipeline_storage.py
+++ b/tests/integration/storage/test_blob_pipeline_storage.py
@@ -18,17 +18,13 @@ async def test_find():
     )
     try:
         try:
-            items = list(
-                storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$"))
-            )
+            items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
             assert items == []
 
             await storage.set(
                 "input/christmas.txt", "Merry Christmas!", encoding="utf-8"
             )
-            items = list(
-                storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$"))
-            )
+            items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
             assert items == ["input/christmas.txt"]
 
             await storage.set("test.txt", "Hello, World!", encoding="utf-8")
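Note on the `find()` signature change: callers that relied on the removed `max_count` argument can cap results at the call site instead. Since `find()` is a generator, `itertools.islice` preserves the old early-exit behavior. A sketch (the fixture path is borrowed from the tests below):

```python
import itertools
import re

from graphrag.storage.file_pipeline_storage import FilePipelineStorage

storage = FilePipelineStorage(base_dir="tests/fixtures/text/input")
# Equivalent to the old find(..., max_count=10): stops consuming after 10 matches.
first_ten = list(itertools.islice(storage.find(re.compile(r".*\.txt$")), 10))
```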
diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py
index 5cab7187b8..87a2960dbc 100644
--- a/tests/integration/storage/test_factory.py
+++ b/tests/integration/storage/test_factory.py
@@ -113,8 +113,6 @@ def __init__(self, **kwargs):
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
-        max_count=-1,
     ) -> Iterator[str]:
         return iter([])
diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_pipeline_storage.py
index cc5b3f7c83..95e329b6bf 100644
--- a/tests/integration/storage/test_file_pipeline_storage.py
+++ b/tests/integration/storage/test_file_pipeline_storage.py
@@ -15,14 +15,12 @@ async def test_find():
-    storage = FilePipelineStorage()
-    items = list(
-        storage.find(
-            base_dir="tests/fixtures/text/input", file_pattern=re.compile(r".*\.txt$")
-        )
+    storage = FilePipelineStorage(
+        base_dir="tests/fixtures/text/input",
     )
-    assert items == [str(Path("tests/fixtures/text/input/dulce.txt"))]
-    output = await storage.get("tests/fixtures/text/input/dulce.txt")
+    items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
+    assert items == [str(Path("dulce.txt"))]
+    output = await storage.get("dulce.txt")
     assert len(output) > 0
 
     await storage.set("test.txt", "Hello, World!", encoding="utf-8")
@@ -34,12 +32,12 @@ async def test_find():
 
 
 async def test_get_creation_date():
-    storage = FilePipelineStorage()
-
-    creation_date = await storage.get_creation_date(
-        "tests/fixtures/text/input/dulce.txt"
+    storage = FilePipelineStorage(
+        base_dir="tests/fixtures/text/input",
     )
+    creation_date = await storage.get_creation_date("dulce.txt")
+
     datetime_format = "%Y-%m-%d %H:%M:%S %z"
     parsed_datetime = datetime.strptime(creation_date, datetime_format).astimezone()
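The updated tests capture the intended contract: storage is rooted at construction time and every key is relative to `base_dir`. An end-to-end sketch of that usage (fixture path borrowed from the tests; assumes the fixture file exists):

```python
import asyncio
import re

from graphrag.storage.file_pipeline_storage import FilePipelineStorage

async def main() -> None:
    storage = FilePipelineStorage(base_dir="tests/fixtures/text/input")
    # find() yields paths relative to base_dir, e.g. "dulce.txt" ...
    names = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
    # ... so get()/has()/delete() take the same relative keys.
    content = await storage.get(names[0])
    print(names[0], len(content))

asyncio.run(main())
```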