Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/tag_based_filtering.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
default: minor
---

# Add `include_tags` / `exclude_tags` to filter generated endpoints by tag

Endpoints can now be limited to (or excluded by) OpenAPI tags via the `include_tags` / `exclude_tags` config keys or the `--include-tags` / `--exclude-tags` CLI options. Schemas that become unused after filtering are pruned automatically.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ _Be forewarned, this is a beta-level feature in the sense that the API exposed i
for calling the functions in the `api` module.
2. An `api` module which will contain one module for each tag in your OpenAPI spec, as well as a `default` module
for endpoints without a tag. Each of these modules in turn contains one function for calling each endpoint.
You can limit which tags are generated with [`include_tags` / `exclude_tags`](#include_tags-and-exclude_tags).
3. A `models` module which has all the classes defined by the various schemas in your OpenAPI spec
4. A `setup.py` file _if_ you use `--meta=setup` (default is `--meta=poetry`)

Expand Down Expand Up @@ -128,6 +129,21 @@ listed, you can enable this option:
generate_all_tags: true
```

### include_tags and exclude_tags

Generate only part of a large API, by tag. Unused schemas are pruned automatically.

```yaml
include_tags: [billing, users] # keep only these
# or
exclude_tags: [admin] # drop these
```

On the CLI, pass them comma-separated: `--include-tags billing,users`. A CLI flag overrides the matching config key.

- `include_tags` and `exclude_tags` are mutually exclusive — use one, not both.
- Tags match the spec exactly (case-sensitive). Untagged endpoints count as `default`.

### project_name_override and package_name_override

Used to change the name of generated client library project/package. If the project name is changed but an override for the package name
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import re

from end_to_end_tests.functional_tests.helpers import (
inline_spec_should_fail,
with_generated_client_fixture,
)

MULTI_TAG_SPEC = """
paths:
"/billing":
post:
operationId: createInvoice
tags: ["billing"]
requestBody:
content:
application/json:
schema: {"$ref": "#/components/schemas/BillingModel"}
responses:
"200":
description: OK
content:
application/json:
schema: {"$ref": "#/components/schemas/SharedModel"}
"/users/me":
get:
operationId: getCurrentUser
tags: ["users"]
responses:
"200":
description: OK
content:
application/json:
schema: {"$ref": "#/components/schemas/SharedModel"}
"/admin/settings":
get:
operationId: getAdminSettings
tags: ["admin"]
responses:
"200":
description: OK
content:
application/json:
schema: {"$ref": "#/components/schemas/AdminModel"}
"/health":
get:
operationId: getHealth
responses:
"200":
description: OK
components:
schemas:
SharedModel:
type: object
properties:
id: {type: string}
status: {"$ref": "#/components/schemas/OrderStatus"}
address: {"$ref": "#/components/schemas/Address"}
Address:
type: object
properties:
city: {type: string}
BillingModel:
type: object
properties:
amount: {type: number}
AdminModel:
type: object
properties:
secret: {type: string}
OrderStatus:
type: string
enum: ["active", "inactive"]
"""


def _generated_package(generated_client):
return generated_client.output_path / generated_client.base_module


def _api_tag_dirs(generated_client) -> set[str]:
api_dir = _generated_package(generated_client) / "api"
return {child.name for child in api_dir.iterdir() if child.is_dir() and child.name != "__pycache__"}


def _model_modules(generated_client) -> set[str]:
models_dir = _generated_package(generated_client) / "models"
return {path.stem for path in models_dir.glob("*.py") if path.stem != "__init__"}


def _dangling_model_imports(generated_client) -> list[str]:
package = _generated_package(generated_client)
existing = {path.stem for path in (package / "models").glob("*.py")}
dangling: list[str] = []
for path in package.rglob("*.py"):
for match in re.finditer(r"from \.+models\.(\w+) import", path.read_text()):
if match.group(1) not in existing:
dangling.append(f"{path.relative_to(package)} -> models.{match.group(1)}")
return sorted(dangling)


@with_generated_client_fixture(MULTI_TAG_SPEC, extra_args=["--include-tags", "billing"])
class TestIncludeTagsViaCli:
def test_only_included_tag_api_module_is_generated(self, generated_client):
assert _api_tag_dirs(generated_client) == {"billing"}

def test_unused_models_are_pruned(self, generated_client):
assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"}

def test_pruned_client_has_no_dangling_imports(self, generated_client):
generated_client.import_module(".models")
assert _dangling_model_imports(generated_client) == []


@with_generated_client_fixture(MULTI_TAG_SPEC, config="include_tags: [billing]")
class TestIncludeTagsViaConfigFile:
def test_only_included_tag_api_module_is_generated(self, generated_client):
assert _api_tag_dirs(generated_client) == {"billing"}

def test_unused_models_are_pruned(self, generated_client):
assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"}


@with_generated_client_fixture(MULTI_TAG_SPEC, extra_args=["--exclude-tags", "admin"])
class TestExcludeTagsViaCli:
def test_excluded_tag_api_module_is_dropped(self, generated_client):
assert _api_tag_dirs(generated_client) == {"billing", "users", "default"}

def test_only_admin_models_are_pruned(self, generated_client):
assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"}


@with_generated_client_fixture(MULTI_TAG_SPEC, config="exclude_tags: [admin]")
class TestExcludeTagsViaConfigFile:
def test_excluded_tag_api_module_is_dropped(self, generated_client):
assert _api_tag_dirs(generated_client) == {"billing", "users", "default"}

def test_only_admin_models_are_pruned(self, generated_client):
assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"}


class TestMutuallyExclusiveTagFlags:
def test_both_flags_exits_nonzero(self):
result = inline_spec_should_fail(
MULTI_TAG_SPEC,
extra_args=["--include-tags", "billing", "--exclude-tags", "admin"],
)
assert "Provide either include_tags or exclude_tags, not both" in result.output
56 changes: 49 additions & 7 deletions openapi_python_client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@ def _version_callback(value: bool) -> None:
raise typer.Exit()


def _split_comma_separated(value: str | None) -> list[str]:
if not value:
return []
return [item.strip() for item in value.split(",") if item.strip()]


def _load_config_file(
*,
config_path: Path | None,
) -> ConfigFile:
if not config_path:
config_file = ConfigFile()
else:
try:
config_file = ConfigFile.load_from_path(path=config_path)
except Exception as err:
raise typer.BadParameter("Unable to parse config") from err

return config_file


def _process_config(
*,
url: str | None,
Expand All @@ -27,6 +48,8 @@ def _process_config(
file_encoding: str,
overwrite: bool,
output_path: Path | None,
include_tags: str | None = None,
exclude_tags: str | None = None,
) -> Config:
source: Path | str
if url and not path:
Expand All @@ -46,13 +69,16 @@ def _process_config(
typer.secho(f"Unknown encoding : {file_encoding}", fg=typer.colors.RED)
raise typer.Exit(code=1) from err

if not config_path:
config_file = ConfigFile()
else:
try:
config_file = ConfigFile.load_from_path(path=config_path)
except Exception as err:
raise typer.BadParameter("Unable to parse config") from err
config_file = _load_config_file(config_path=config_path)

if include_tags is not None:
config_file.include_tags = _split_comma_separated(include_tags)
if exclude_tags is not None:
config_file.exclude_tags = _split_comma_separated(exclude_tags)

if config_file.include_tags and config_file.exclude_tags:
typer.secho("Provide either include_tags or exclude_tags, not both", fg=typer.colors.RED)
raise typer.Exit(code=1)

return Config.from_sources(config_file, meta_type, source, file_encoding, overwrite, output_path=output_path)

Expand Down Expand Up @@ -148,6 +174,20 @@ def generate(
"Defaults to the OpenAPI document title converted to kebab or snake case (depending on meta type). "
"Can also be overridden with `project_name_override` or `package_name_override` in config.",
),
include_tags: str | None = typer.Option(
None,
"--include-tags",
help="Comma-separated tags to generate. "
"Keeps matching endpoints, drops the rest, prunes unused schemas. "
"Case-sensitive. Overrides config. Can't combine with --exclude-tags.",
),
exclude_tags: str | None = typer.Option(
None,
"--exclude-tags",
help="Comma-separated tags to skip. "
"Drops matching endpoints, keeps the rest, prunes unused schemas. "
"Case-sensitive. Overrides config. Can't combine with --include-tags.",
),
) -> None:
"""Generate a new OpenAPI Client library"""
from . import generate # noqa: PLC0415
Expand All @@ -160,6 +200,8 @@ def generate(
file_encoding=file_encoding,
overwrite=overwrite,
output_path=output_path,
include_tags=include_tags,
exclude_tags=exclude_tags,
)
errors = generate(
custom_template_path=custom_template_path,
Expand Down
6 changes: 6 additions & 0 deletions openapi_python_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ConfigFile(BaseModel):
generate_all_tags: bool = False
http_timeout: int = 5
literal_enums: bool = False
include_tags: list[str] | None = None
exclude_tags: list[str] | None = None

@staticmethod
def load_from_path(path: Path) -> "ConfigFile":
Expand Down Expand Up @@ -76,6 +78,8 @@ class Config:
generate_all_tags: bool
http_timeout: int
literal_enums: bool
include_tags: list[str]
exclude_tags: list[str]
document_source: Path | str
file_encoding: str
content_type_overrides: dict[str, str]
Expand Down Expand Up @@ -118,6 +122,8 @@ def from_sources(
generate_all_tags=config_file.generate_all_tags,
http_timeout=config_file.http_timeout,
literal_enums=config_file.literal_enums,
include_tags=(config_file.include_tags or []),
exclude_tags=(config_file.exclude_tags or []),
document_source=document_source,
file_encoding=file_encoding,
overwrite=overwrite,
Expand Down
60 changes: 60 additions & 0 deletions openapi_python_client/parser/_pruning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

__all__ = ["get_reachable_classes"]

from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING

from ..utils import ClassName
from .properties import (
EnumProperty,
ListProperty,
LiteralEnumProperty,
ModelProperty,
Property,
UnionProperty,
)
from .properties.protocol import PropertyProtocol

if TYPE_CHECKING: # pragma: no cover
from .openapi import Endpoint


def get_reachable_classes(
*,
endpoints: Iterable[Endpoint],
classes_by_name: Mapping[ClassName, Property],
) -> set[ClassName]:
"""Class names reachable from the given endpoints. Anything else is safe to prune.

Walks each endpoint's properties transitively, collecting models and enums by name.
Re-fetches each model from ``classes_by_name`` on descent, since the copy an endpoint
holds may have empty properties.
"""
reachable: set[ClassName] = set()
stack: list[PropertyProtocol] = []
for endpoint in endpoints:
stack.extend(endpoint.list_all_parameters())
stack.extend(response.prop for response in endpoint.responses)

while stack:
prop = stack.pop()
if isinstance(prop, ModelProperty):
name = prop.class_info.name
if name in reachable:
continue
reachable.add(name)
canonical = classes_by_name.get(name, prop)
if isinstance(canonical, ModelProperty):
stack.extend(canonical.required_properties or [])
stack.extend(canonical.optional_properties or [])
if canonical.additional_properties is not None:
stack.append(canonical.additional_properties)
elif isinstance(prop, EnumProperty | LiteralEnumProperty):
reachable.add(prop.class_info.name)
elif isinstance(prop, ListProperty):
stack.append(prop.inner_property)
elif isinstance(prop, UnionProperty):
stack.extend(prop.inner_properties)

return reachable
Loading