From c10261d05491e221728c5056eeed50b05180e9b1 Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Tue, 4 Nov 2025 10:45:30 -0800
Subject: [PATCH 1/7] Fix pipeline recursion

---
 packages/graphrag/graphrag/config/defaults.py |   8 +-
 .../config/models/graph_rag_config.py         |  10 +-
 .../graphrag/config/models/storage_config.py  |   2 +-
 .../graphrag/storage/blob_pipeline_storage.py | 100 +++++-------
 .../storage/cosmosdb_pipeline_storage.py      |  23 ++--
 .../graphrag/storage/file_pipeline_storage.py |  47 ++++----
 6 files changed, 73 insertions(+), 117 deletions(-)

diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py
index 8a04851682..60dd4d162f 100644
--- a/packages/graphrag/graphrag/config/defaults.py
+++ b/packages/graphrag/graphrag/config/defaults.py
@@ -229,7 +229,7 @@ class StorageDefaults:
     """Default values for storage."""

     type: ClassVar[StorageType] = StorageType.file
-    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    base_dir: str | None = None
     connection_string: None = None
     container_name: None = None
     storage_account_blob_url: None = None
@@ -240,7 +240,7 @@ class InputStorageDefaults(StorageDefaults):
     """Default values for input storage."""

-    base_dir: str = "input"
+    base_dir: str | None = "input"


 @dataclass
@@ -310,7 +310,7 @@ class LocalSearchDefaults:
 class OutputDefaults(StorageDefaults):
     """Default values for output."""

-    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    base_dir: str | None = DEFAULT_OUTPUT_BASE_DIR


 @dataclass
@@ -362,7 +362,7 @@ class SummarizeDescriptionsDefaults:
 class UpdateIndexOutputDefaults(StorageDefaults):
     """Default values for update index output."""

-    base_dir: str = "update_output"
+    base_dir: str | None = "update_output"


 @dataclass
diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py
index 6a6a98e973..71509f0176 100644
--- a/packages/graphrag/graphrag/config/models/graph_rag_config.py
+++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py
@@ -152,7 +152,7 @@ def _validate_input_pattern(self) -> None:
     def _validate_input_base_dir(self) -> None:
         """Validate the input base directory."""
         if self.input.storage.type == defs.StorageType.file:
-            if self.input.storage.base_dir.strip() == "":
+            if not self.input.storage.base_dir:
                 msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
                 raise ValueError(msg)
             self.input.storage.base_dir = str(
@@ -167,14 +167,16 @@ def _validate_input_base_dir(self) -> None:
     output: StorageConfig = Field(
         description="The output configuration.",
-        default=StorageConfig(),
+        default=StorageConfig(
+            base_dir=graphrag_config_defaults.output.base_dir,
+        ),
     )
     """The output configuration."""

     def _validate_output_base_dir(self) -> None:
         """Validate the output base directory."""
         if self.output.type == defs.StorageType.file:
-            if self.output.base_dir.strip() == "":
+            if not self.output.base_dir:
                 msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
                 raise ValueError(msg)
             self.output.base_dir = str(
@@ -192,7 +194,7 @@ def _validate_output_base_dir(self) -> None:
     def _validate_update_index_output_base_dir(self) -> None:
         """Validate the update index output base directory."""
         if self.update_index_output.type == defs.StorageType.file:
-            if self.update_index_output.base_dir.strip() == "":
+            if not self.update_index_output.base_dir:
                 msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
                 raise ValueError(msg)
             self.update_index_output.base_dir = str(
diff --git a/packages/graphrag/graphrag/config/models/storage_config.py b/packages/graphrag/graphrag/config/models/storage_config.py
index 3f01448c66..7491454c0a 100644
--- a/packages/graphrag/graphrag/config/models/storage_config.py
+++ b/packages/graphrag/graphrag/config/models/storage_config.py
@@ -18,7 +18,7 @@ class StorageConfig(BaseModel):
         description="The storage type to use.",
         default=graphrag_config_defaults.storage.type,
     )
-    base_dir: str = Field(
+    base_dir: str | None = Field(
         description="The base directory for the output.",
         default=graphrag_config_defaults.storage.base_dir,
     )
diff --git a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
index 5a00af85b5..0fa11d34b0 100644
--- a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
@@ -25,7 +25,7 @@ class BlobPipelineStorage(PipelineStorage):

     _connection_string: str | None
     _container_name: str
-    _path_prefix: str
+    _base_dir: str | None
     _encoding: str
     _storage_account_blob_url: str | None

@@ -33,7 +33,7 @@ def __init__(self, **kwargs: Any) -> None:
         """Create a new BlobStorage instance."""
         connection_string = kwargs.get("connection_string")
         storage_account_blob_url = kwargs.get("storage_account_blob_url")
-        path_prefix = kwargs.get("base_dir")
+        base_dir = kwargs.get("base_dir")
         container_name = kwargs["container_name"]
         if container_name is None:
             msg = "No container name provided for blob storage."
@@ -42,7 +42,9 @@ def __init__(self, **kwargs: Any) -> None:
             msg = "No storage account blob url provided for blob storage."
             raise ValueError(msg)
-        logger.info("Creating blob storage at %s", container_name)
+        logger.info(
+            "Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir
+        )
         if connection_string:
             self._blob_service_client = BlobServiceClient.from_connection_string(
                 connection_string
@@ -59,18 +61,13 @@ def __init__(self, **kwargs: Any) -> None:
         self._encoding = kwargs.get("encoding", "utf-8")
         self._container_name = container_name
         self._connection_string = connection_string
-        self._path_prefix = path_prefix or ""
+        self._base_dir = base_dir
         self._storage_account_blob_url = storage_account_blob_url
         self._storage_account_name = (
             storage_account_blob_url.split("//")[1].split(".")[0]
             if storage_account_blob_url
             else None
         )
-        logger.debug(
-            "creating blob storage at container=%s, path=%s",
-            self._container_name,
-            self._path_prefix,
-        )
         self._create_container()

     def _create_container(self) -> None:
@@ -82,6 +79,7 @@ def _create_container(self) -> None:
             for container in self._blob_service_client.list_containers()
         ]
         if container_name not in container_names:
+            logger.debug("Creating new container [%s]", container_name)
             self._blob_service_client.create_container(container_name)

     def _delete_container(self) -> None:
@@ -114,17 +112,18 @@ def find(
         -------
             An iterator of blob names and their corresponding regex matches.
         """
-        base_dir = base_dir or ""
+        base_dir = base_dir or self._base_dir

         logger.info(
-            "search container %s for files matching %s",
+            "Search container [%s] in base_dir [%s] for files matching [%s]",
             self._container_name,
+            base_dir,
             file_pattern.pattern,
         )

         def _blobname(blob_name: str) -> str:
-            if blob_name.startswith(self._path_prefix):
-                blob_name = blob_name.replace(self._path_prefix, "", 1)
+            if base_dir and blob_name.startswith(base_dir):
+                blob_name = blob_name.replace(base_dir, "", 1)
             if blob_name.startswith("/"):
                 blob_name = blob_name[1:]
             return blob_name
@@ -133,26 +132,26 @@ def _blobname(blob_name: str) -> str:
             container_client = self._blob_service_client.get_container_client(
                 self._container_name
             )
-            all_blobs = list(container_client.list_blobs())
-
+            all_blobs = list(container_client.list_blobs(base_dir))
+            logger.debug("All blobs: %s", [blob.name for blob in all_blobs])
             num_loaded = 0
             num_total = len(list(all_blobs))
             num_filtered = 0
             for blob in all_blobs:
                 match = file_pattern.search(blob.name)
-                if match and blob.name.startswith(base_dir):
+                if match:
                     yield _blobname(blob.name)
                     num_loaded += 1
                     if max_count > 0 and num_loaded >= max_count:
                         break
                 else:
                     num_filtered += 1
-                logger.debug(
-                    "Blobs loaded: %d, filtered: %d, total: %d",
-                    num_loaded,
-                    num_filtered,
-                    num_total,
-                )
+            logger.debug(
+                "Blobs loaded: %d, filtered: %d, total: %d",
+                num_loaded,
+                num_filtered,
+                num_total,
+            )
         except Exception:  # noqa: BLE001
             logger.warning(
                 "Error finding blobs: base_dir=%s, file_pattern=%s",
                 base_dir,
                 file_pattern,
@@ -163,7 +162,7 @@ def _blobname(blob_name: str) -> str:
     async def get(
         self, key: str, as_bytes: bool | None = False, encoding: str | None = None
     ) -> Any:
-        """Get a value from the cache."""
+        """Get a value from the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
@@ -181,7 +180,7 @@ async def get(
         return blob_data

     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
-        """Set a value in the cache."""
+        """Set a value in the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
@@ -196,46 +195,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         except Exception:
             logger.exception("Error setting key %s: %s", key)

-    def _set_df_json(self, key: str, dataframe: Any) -> None:
-        """Set a json dataframe."""
-        if self._connection_string is None and self._storage_account_name:
-            dataframe.to_json(
-                self._abfs_url(key),
-                storage_options={
-                    "account_name": self._storage_account_name,
-                    "credential": DefaultAzureCredential(),
-                },
-                orient="records",
-                lines=True,
-                force_ascii=False,
-            )
-        else:
-            dataframe.to_json(
-                self._abfs_url(key),
-                storage_options={"connection_string": self._connection_string},
-                orient="records",
-                lines=True,
-                force_ascii=False,
-            )
-
-    def _set_df_parquet(self, key: str, dataframe: Any) -> None:
-        """Set a parquet dataframe."""
-        if self._connection_string is None and self._storage_account_name:
-            dataframe.to_parquet(
-                self._abfs_url(key),
-                storage_options={
-                    "account_name": self._storage_account_name,
-                    "credential": DefaultAzureCredential(),
-                },
-            )
-        else:
-            dataframe.to_parquet(
-                self._abfs_url(key),
-                storage_options={"connection_string": self._connection_string},
-            )
-
     async def has(self, key: str) -> bool:
-        """Check if a key exists in the cache."""
+        """Check if a key exists in the blob."""
         key = self._keyname(key)
         container_client = self._blob_service_client.get_container_client(
@@ -244,7 +205,7 @@ async def has(self, key: str) -> bool:
         return blob_client.exists()

     async def delete(self, key: str) -> None:
-        """Delete a key from the cache."""
+        """Delete a key from the blob."""
         key = self._keyname(key)
         container_client = self._blob_service_client.get_container_client(
@@ -259,7 +220,7 @@ def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
         if name is None:
             return self
-        path = str(Path(self._path_prefix) / name)
+        path = str(Path(self._base_dir) / name) if self._base_dir else name
         return BlobPipelineStorage(
             connection_string=self._connection_string,
             container_name=self._container_name,
@@ -275,15 +236,10 @@ def keys(self) -> list[str]:

     def _keyname(self, key: str) -> str:
         """Get the key name."""
-        return str(Path(self._path_prefix) / key)
-
-    def _abfs_url(self, key: str) -> str:
-        """Get the ABFS URL."""
-        path = str(Path(self._container_name) / self._path_prefix / key)
-        return f"abfs://{path}"
+        return str(Path(self._base_dir) / key) if self._base_dir else key

     async def get_creation_date(self, key: str) -> str:
-        """Get a value from the cache."""
+        """Get creation date for the blob."""
         try:
             key = self._keyname(key)
             container_client = self._blob_service_client.get_container_client(
diff --git a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py b/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py
index 8d2673e89b..fbaa254274 100644
--- a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py
@@ -77,7 +77,7 @@ def __init__(self, **kwargs: Any) -> None:
         )
         self._no_id_prefixes = []
         logger.debug(
-            "creating cosmosdb storage with account: %s and database: %s and container: %s",
+            "Creating cosmosdb storage with account [%s] and database [%s] and container [%s]",
             self._cosmosdb_account_name,
             self._database_name,
             self._container_name,
@@ -136,7 +136,7 @@ def find(
         """
         base_dir = base_dir or ""
         logger.info(
-            "search container %s for documents matching %s",
+            "Search container [%s] for documents matching [%s]",
             self._container_name,
             file_pattern.pattern,
         )
@@ -156,6 +156,7 @@ def find(
                     enable_cross_partition_query=True,
                 )
             )
+            logger.debug("All items: %s", [item["id"] for item in items])
             num_loaded = 0
             num_total = len(items)
             if num_total == 0:
@@ -171,15 +172,15 @@ def find(
                 else:
                     num_filtered += 1
-                progress_status = _create_progress_status(
-                    num_loaded, num_filtered, num_total
-                )
-                logger.debug(
-                    "Progress: %s (%d/%d completed)",
-                    progress_status.description,
-                    progress_status.completed_items,
-                    progress_status.total_items,
-                )
+            progress_status = _create_progress_status(
+                num_loaded, num_filtered, num_total
+            )
+            logger.debug(
+                "Progress: %s (%d/%d completed)",
+                progress_status.description,
+                progress_status.completed_items,
+                progress_status.total_items,
+            )
         except Exception:  # noqa: BLE001
             logger.warning(
                 "An error occurred while searching for documents in Cosmos DB."
diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
index 15445b0d3a..bd1052d4b1 100644
--- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
@@ -27,15 +27,15 @@ class FilePipelineStorage(PipelineStorage):
     """File storage class definition."""

-    _root_dir: str
+    _base_dir: str
     _encoding: str

     def __init__(self, **kwargs: Any) -> None:
         """Create a file based storage."""
-        self._root_dir = kwargs.get("base_dir", "")
+        self._base_dir = kwargs.get("base_dir", "")
         self._encoding = kwargs.get("encoding", "utf-8")
-        logger.info("Creating file storage at %s", self._root_dir)
-        Path(self._root_dir).mkdir(parents=True, exist_ok=True)
+        logger.info("Creating file storage at [%s]", self._base_dir)
+        Path(self._base_dir).mkdir(parents=True, exist_ok=True)

     def find(
         self,
@@ -44,18 +44,19 @@ def find(
         max_count=-1,
     ) -> Iterator[str]:
         """Find files in the storage using a file pattern."""
-        search_path = Path(self._root_dir) / (base_dir or "")
+        search_path = Path(self._base_dir) / (base_dir or "")
         logger.info(
-            "search %s for files matching %s", search_path, file_pattern.pattern
+            "Search [%s] for files matching [%s]", search_path, file_pattern.pattern
         )
         all_files = list(search_path.rglob("**/*"))
+        logger.debug("All files and folders: %s", [file.name for file in all_files])
         num_loaded = 0
         num_total = len(all_files)
         num_filtered = 0
         for file in all_files:
             match = file_pattern.search(f"{file}")
             if match:
-                filename = f"{file}".replace(self._root_dir, "")
+                filename = f"{file}".replace(self._base_dir, "", 1)
                 if filename.startswith(os.sep):
                     filename = filename[1:]
                 yield filename
@@ -64,25 +65,21 @@ def find(
                     break
             else:
                 num_filtered += 1
-            logger.debug(
-                "Files loaded: %d, filtered: %d, total: %d",
-                num_loaded,
-                num_filtered,
-                num_total,
-            )
+        logger.debug(
+            "Files loaded: %d, filtered: %d, total: %d",
+            num_loaded,
+            num_filtered,
+            num_total,
+        )

     async def get(
         self, key: str, as_bytes: bool | None = False, encoding: str | None = None
     ) -> Any:
         """Get method definition."""
-        file_path = join_path(self._root_dir, key)
+        file_path = join_path(self._base_dir, key)
         if await self.has(key):
             return await self._read_file(file_path, as_bytes, encoding)
-        if await exists(key):
-            # Lookup for key, as it is pressumably a new file loaded from inputs
-            # and not yet written to storage
-            return await self._read_file(key, as_bytes, encoding)

         return None

@@ -109,7 +106,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         write_type = "wb" if is_bytes else "w"
         encoding = None if is_bytes else encoding or self._encoding
         async with aiofiles.open(
-            join_path(self._root_dir, key),
+            join_path(self._base_dir, key),
             cast("Any", write_type),
             encoding=encoding,
         ) as f:
@@ -117,16 +114,16 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:

     async def has(self, key: str) -> bool:
         """Has method definition."""
-        return await exists(join_path(self._root_dir, key))
+        return await exists(join_path(self._base_dir, key))

     async def delete(self, key: str) -> None:
         """Delete method definition."""
         if await self.has(key):
-            await remove(join_path(self._root_dir, key))
+            await remove(join_path(self._base_dir, key))

     async def clear(self) -> None:
         """Clear method definition."""
-        for file in Path(self._root_dir).glob("*"):
+        for file in Path(self._base_dir).glob("*"):
             if file.is_dir():
                 shutil.rmtree(file)
             else:
@@ -136,16 +133,16 @@ def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
         if name is None:
             return self
-        child_path = str(Path(self._root_dir) / Path(name))
+        child_path = str(Path(self._base_dir) / Path(name))
         return FilePipelineStorage(base_dir=child_path, encoding=self._encoding)

     def keys(self) -> list[str]:
         """Return the keys in the storage."""
-        return [item.name for item in Path(self._root_dir).iterdir() if item.is_file()]
+        return [item.name for item in Path(self._base_dir).iterdir() if item.is_file()]

     async def get_creation_date(self, key: str) -> str:
         """Get the creation date of a file."""
-        file_path = Path(join_path(self._root_dir, key))
+        file_path = Path(join_path(self._base_dir, key))
         creation_timestamp = file_path.stat().st_ctime
         creation_time_utc = datetime.fromtimestamp(creation_timestamp, tz=timezone.utc)

From 3eb1e688eda57246cb5f9f755758a7396fc1e30e Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Tue, 4 Nov 2025 11:33:52 -0800
Subject: [PATCH 2/7] Remove base_dir from storage.find

---
 .../graphrag/storage/blob_pipeline_storage.py     | 14 +++++---------
 .../graphrag/storage/cosmosdb_pipeline_storage.py |  3 ---
 .../graphrag/storage/file_pipeline_storage.py     |  6 ++----
 .../graphrag/graphrag/storage/pipeline_storage.py |  1 -
 .../storage/test_blob_pipeline_storage.py         |  8 ++------
 tests/integration/storage/test_factory.py         |  1 -
 .../storage/test_file_pipeline_storage.py         |  8 +++-----
 7 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
index 0fa11d34b0..a31e0c9595 100644
--- a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py
@@ -98,13 +98,11 @@ def _container_exists(self) -> bool:
     def find(
         self,
         file_pattern: re.Pattern[str],
-        base_dir: str | None = None,
         max_count=-1,
     ) -> Iterator[str]:
         """Find blobs in a container using a file pattern.

         Params:
-            base_dir: The name of the base container.
             file_pattern: The file pattern to use.
             max_count: The maximum number of blobs to return. If -1, all blobs are returned.

         Returns
         -------
             An iterator of blob names and their corresponding regex matches.
""" - base_dir = base_dir or self._base_dir - logger.info( "Search container [%s] in base_dir [%s] for files matching [%s]", self._container_name, - base_dir, + self._base_dir, file_pattern.pattern, ) def _blobname(blob_name: str) -> str: - if base_dir and blob_name.startswith(base_dir): - blob_name = blob_name.replace(base_dir, "", 1) + if self._base_dir and blob_name.startswith(self._base_dir): + blob_name = blob_name.replace(self._base_dir, "", 1) if blob_name.startswith("/"): blob_name = blob_name[1:] return blob_name @@ -132,7 +128,7 @@ def _blobname(blob_name: str) -> str: container_client = self._blob_service_client.get_container_client( self._container_name ) - all_blobs = list(container_client.list_blobs(base_dir)) + all_blobs = list(container_client.list_blobs(self._base_dir)) logger.debug("All blobs: %s", [blob.name for blob in all_blobs]) num_loaded = 0 num_total = len(list(all_blobs)) @@ -155,7 +151,7 @@ def _blobname(blob_name: str) -> str: except Exception: # noqa: BLE001 logger.warning( "Error finding blobs: base_dir=%s, file_pattern=%s", - base_dir, + self._base_dir, file_pattern, ) diff --git a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py b/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py index fbaa254274..407310013b 100644 --- a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py +++ b/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py @@ -120,13 +120,11 @@ def _delete_container(self) -> None: def find( self, file_pattern: re.Pattern[str], - base_dir: str | None = None, max_count=-1, ) -> Iterator[str]: """Find documents in a Cosmos DB container using a file pattern regex. Params: - base_dir: The name of the base directory (not used in Cosmos DB context). file_pattern: The file pattern to use. max_count: The maximum number of documents to return. If -1, all documents are returned. @@ -134,7 +132,6 @@ def find( ------- An iterator of document IDs and their corresponding regex matches. 
""" - base_dir = base_dir or "" logger.info( "Search container [%s] for documents matching [%s]", self._container_name, diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag/graphrag/storage/file_pipeline_storage.py index bd1052d4b1..661fc1d2b0 100644 --- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py +++ b/packages/graphrag/graphrag/storage/file_pipeline_storage.py @@ -40,15 +40,13 @@ def __init__(self, **kwargs: Any) -> None: def find( self, file_pattern: re.Pattern[str], - base_dir: str | None = None, max_count=-1, ) -> Iterator[str]: """Find files in the storage using a file pattern.""" - search_path = Path(self._base_dir) / (base_dir or "") logger.info( - "Search [%s] for files matching [%s]", search_path, file_pattern.pattern + "Search [%s] for files matching [%s]", self._base_dir, file_pattern.pattern ) - all_files = list(search_path.rglob("**/*")) + all_files = list(Path(self._base_dir).rglob("**/*")) logger.debug("All files and folders: %s", [file.name for file in all_files]) num_loaded = 0 num_total = len(all_files) diff --git a/packages/graphrag/graphrag/storage/pipeline_storage.py b/packages/graphrag/graphrag/storage/pipeline_storage.py index ba3ab86e97..b48bbccd7e 100644 --- a/packages/graphrag/graphrag/storage/pipeline_storage.py +++ b/packages/graphrag/graphrag/storage/pipeline_storage.py @@ -17,7 +17,6 @@ class PipelineStorage(metaclass=ABCMeta): def find( self, file_pattern: re.Pattern[str], - base_dir: str | None = None, max_count=-1, ) -> Iterator[str]: """Find files in the storage using a file pattern.""" diff --git a/tests/integration/storage/test_blob_pipeline_storage.py b/tests/integration/storage/test_blob_pipeline_storage.py index f99a74ff74..818b588bd6 100644 --- a/tests/integration/storage/test_blob_pipeline_storage.py +++ b/tests/integration/storage/test_blob_pipeline_storage.py @@ -18,17 +18,13 @@ async def test_find(): ) try: try: - items = list( - storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$")) - ) + items = list(storage.find(file_pattern=re.compile(r".*\.txt$"))) assert items == [] await storage.set( "input/christmas.txt", "Merry Christmas!", encoding="utf-8" ) - items = list( - storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$")) - ) + items = list(storage.find(file_pattern=re.compile(r".*\.txt$"))) assert items == ["input/christmas.txt"] await storage.set("test.txt", "Hello, World!", encoding="utf-8") diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 5cab7187b8..a4a1fdfbc1 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -113,7 +113,6 @@ def __init__(self, **kwargs): def find( self, file_pattern: re.Pattern[str], - base_dir: str | None = None, max_count=-1, ) -> Iterator[str]: return iter([]) diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_pipeline_storage.py index cc5b3f7c83..0585003dc0 100644 --- a/tests/integration/storage/test_file_pipeline_storage.py +++ b/tests/integration/storage/test_file_pipeline_storage.py @@ -15,12 +15,10 @@ async def test_find(): - storage = FilePipelineStorage() - items = list( - storage.find( - base_dir="tests/fixtures/text/input", file_pattern=re.compile(r".*\.txt$") - ) + storage = FilePipelineStorage( + base_dir="tests/fixtures/text/input", ) + items = list(storage.find(file_pattern=re.compile(r".*\.txt$"))) assert items == 
[str(Path("tests/fixtures/text/input/dulce.txt"))] output = await storage.get("tests/fixtures/text/input/dulce.txt") assert len(output) > 0 From a4e555e2a3efe13e2fa85e435818c333d3b2664f Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Tue, 4 Nov 2025 11:36:38 -0800 Subject: [PATCH 3/7] Remove max_count from storage.find --- packages/graphrag/graphrag/storage/blob_pipeline_storage.py | 4 ---- .../graphrag/graphrag/storage/cosmosdb_pipeline_storage.py | 4 ---- packages/graphrag/graphrag/storage/file_pipeline_storage.py | 3 --- packages/graphrag/graphrag/storage/pipeline_storage.py | 1 - tests/integration/storage/test_factory.py | 1 - 5 files changed, 13 deletions(-) diff --git a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py index a31e0c9595..1435cb387d 100644 --- a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py +++ b/packages/graphrag/graphrag/storage/blob_pipeline_storage.py @@ -98,13 +98,11 @@ def _container_exists(self) -> bool: def find( self, file_pattern: re.Pattern[str], - max_count=-1, ) -> Iterator[str]: """Find blobs in a container using a file pattern. Params: file_pattern: The file pattern to use. - max_count: The maximum number of blobs to return. If -1, all blobs are returned. Returns ------- @@ -138,8 +136,6 @@ def _blobname(blob_name: str) -> str: if match: yield _blobname(blob.name) num_loaded += 1 - if max_count > 0 and num_loaded >= max_count: - break else: num_filtered += 1 logger.debug( diff --git a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py b/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py index 407310013b..a12da0ee5f 100644 --- a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py +++ b/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py @@ -120,13 +120,11 @@ def _delete_container(self) -> None: def find( self, file_pattern: re.Pattern[str], - max_count=-1, ) -> Iterator[str]: """Find documents in a Cosmos DB container using a file pattern regex. Params: file_pattern: The file pattern to use. - max_count: The maximum number of documents to return. If -1, all documents are returned. 
         Returns
         -------
             An iterator of document IDs and their corresponding regex matches.
@@ -164,8 +162,6 @@ def find(
                 if match:
                     yield item["id"]
                     num_loaded += 1
-                    if max_count > 0 and num_loaded >= max_count:
-                        break
                 else:
                     num_filtered += 1

diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
index 661fc1d2b0..98289d06e7 100644
--- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
@@ -40,7 +40,6 @@ def __init__(self, **kwargs: Any) -> None:
     def find(
         self,
         file_pattern: re.Pattern[str],
-        max_count=-1,
     ) -> Iterator[str]:
         """Find files in the storage using a file pattern."""
         logger.info(
@@ -59,8 +58,6 @@ def find(
                 filename = filename[1:]
             yield filename
             num_loaded += 1
-            if max_count > 0 and num_loaded >= max_count:
-                break
         else:
             num_filtered += 1
         logger.debug(
diff --git a/packages/graphrag/graphrag/storage/pipeline_storage.py b/packages/graphrag/graphrag/storage/pipeline_storage.py
index b48bbccd7e..5c79921736 100644
--- a/packages/graphrag/graphrag/storage/pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/pipeline_storage.py
@@ -17,7 +17,6 @@ class PipelineStorage(metaclass=ABCMeta):
     def find(
         self,
         file_pattern: re.Pattern[str],
-        max_count=-1,
     ) -> Iterator[str]:
         """Find files in the storage using a file pattern."""
diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py
index a4a1fdfbc1..87a2960dbc 100644
--- a/tests/integration/storage/test_factory.py
+++ b/tests/integration/storage/test_factory.py
@@ -113,7 +113,6 @@ def __init__(self, **kwargs):
     def find(
         self,
         file_pattern: re.Pattern[str],
-        max_count=-1,
     ) -> Iterator[str]:
         return iter([])

From 2d8558ef8cca47fd2c44042e11f1310992a1ac1e Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Tue, 4 Nov 2025 12:35:54 -0800
Subject: [PATCH 4/7] Remove prefix on storage integ test

---
 tests/integration/storage/test_file_pipeline_storage.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_pipeline_storage.py
index 0585003dc0..5c17cc1324 100644
--- a/tests/integration/storage/test_file_pipeline_storage.py
+++ b/tests/integration/storage/test_file_pipeline_storage.py
@@ -19,8 +19,8 @@ async def test_find():
         base_dir="tests/fixtures/text/input",
     )
     items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
-    assert items == [str(Path("tests/fixtures/text/input/dulce.txt"))]
-    output = await storage.get("tests/fixtures/text/input/dulce.txt")
+    assert items == [str(Path("dulce.txt"))]
+    output = await storage.get("dulce.txt")
     assert len(output) > 0

     await storage.set("test.txt", "Hello, World!", encoding="utf-8")
@@ -35,7 +35,7 @@ async def test_get_creation_date():
     storage = FilePipelineStorage()

     creation_date = await storage.get_creation_date(
-        "tests/fixtures/text/input/dulce.txt"
+        "dulce.txt"
     )

     datetime_format = "%Y-%m-%d %H:%M:%S %z"

From fcb46e6527d3bd299a285b01ab69db5f2fd3c17c Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Tue, 4 Nov 2025 12:50:39 -0800
Subject: [PATCH 5/7] Add base_dir in creation_date test

---
 tests/integration/storage/test_file_pipeline_storage.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_pipeline_storage.py
index 5c17cc1324..9f76f0d9bb 100644
--- a/tests/integration/storage/test_file_pipeline_storage.py
+++ b/tests/integration/storage/test_file_pipeline_storage.py
@@ -32,7 +32,9 @@ async def test_find():


 async def test_get_creation_date():
-    storage = FilePipelineStorage()
+    storage = FilePipelineStorage(
+        base_dir="tests/fixtures/text/input",
+    )

     creation_date = await storage.get_creation_date(
         "dulce.txt"

From 8646bd505f5184b717a6d65a595ceee66d87141f Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Tue, 4 Nov 2025 13:30:55 -0800
Subject: [PATCH 6/7] Wrap base_dir in Path

---
 packages/graphrag/graphrag/storage/file_pipeline_storage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
index 98289d06e7..52402c8bd6 100644
--- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py
+++ b/packages/graphrag/graphrag/storage/file_pipeline_storage.py
@@ -53,7 +53,7 @@ def find(
         for file in all_files:
             match = file_pattern.search(f"{file}")
             if match:
-                filename = f"{file}".replace(self._base_dir, "", 1)
+                filename = f"{file}".replace(str(Path(self._base_dir)), "", 1)
                 if filename.startswith(os.sep):
                     filename = filename[1:]
                 yield filename

From 797e2bdc8631cddfac27d14ef92be3078ea30400 Mon Sep 17 00:00:00 2001
From: Nathan Evans
Date: Tue, 4 Nov 2025 14:39:55 -0800
Subject: [PATCH 7/7] Use constants for input/update directories

---
 packages/graphrag/graphrag/config/defaults.py           | 6 ++++--
 tests/integration/storage/test_file_pipeline_storage.py | 4 +---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py
index 60dd4d162f..f57bf610f9 100644
--- a/packages/graphrag/graphrag/config/defaults.py
+++ b/packages/graphrag/graphrag/config/defaults.py
@@ -24,7 +24,9 @@
     EN_STOP_WORDS,
 )

+DEFAULT_INPUT_BASE_DIR = "input"
 DEFAULT_OUTPUT_BASE_DIR = "output"
+DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"
 DEFAULT_CHAT_MODEL_ID = "default_chat_model"
 DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
 DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
@@ -240,7 +242,7 @@ class StorageDefaults:
 class InputStorageDefaults(StorageDefaults):
     """Default values for input storage."""

-    base_dir: str | None = "input"
+    base_dir: str | None = DEFAULT_INPUT_BASE_DIR


 @dataclass
@@ -362,7 +364,7 @@ class SummarizeDescriptionsDefaults:
 class UpdateIndexOutputDefaults(StorageDefaults):
     """Default values for update index output."""

-    base_dir: str | None = "update_output"
+    base_dir: str | None = DEFAULT_UPDATE_OUTPUT_BASE_DIR


 @dataclass
diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_pipeline_storage.py
index 9f76f0d9bb..95e329b6bf 100644
--- a/tests/integration/storage/test_file_pipeline_storage.py
+++ b/tests/integration/storage/test_file_pipeline_storage.py
@@ -36,9 +36,7 @@ async def test_get_creation_date():
         base_dir="tests/fixtures/text/input",
     )

-    creation_date = await storage.get_creation_date(
-        "dulce.txt"
-    )
+    creation_date = await storage.get_creation_date("dulce.txt")

     datetime_format = "%Y-%m-%d %H:%M:%S %z"
     parsed_datetime = datetime.strptime(creation_date, datetime_format).astimezone()
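
Usage note: after this series, PipelineStorage.find() takes only file_pattern;
directory scoping moves to the constructor's base_dir, and child() composes
nested paths on top of it. Below is a minimal sketch of the resulting call
pattern, assuming the module path shown in the diffs above. The fixture path
is the one used by the integration tests, the child name reuses the
update_output directory from the defaults, and the islice cap is an
illustrative stand-in for the removed max_count parameter, not part of the
storage API:

    import re
    from itertools import islice

    from graphrag.storage.file_pipeline_storage import FilePipelineStorage

    # Scope the storage at construction time; find() now searches only under
    # base_dir and yields base_dir-relative names (e.g. "dulce.txt").
    storage = FilePipelineStorage(base_dir="tests/fixtures/text/input")

    # All .txt files under base_dir, as relative names.
    txt_files = list(storage.find(file_pattern=re.compile(r".*\.txt$")))

    # child() composes the base path instead of a per-call base_dir argument.
    nested = storage.child("update_output")

    # Callers that relied on max_count can cap the iterator themselves.
    first_ten = list(islice(storage.find(file_pattern=re.compile(r".*\.txt$")), 10))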