From a1473fa9e62ffdee8d410fb7359b1c824cc4cf1f Mon Sep 17 00:00:00 2001 From: avisnakovs Date: Thu, 4 Jun 2026 15:01:40 +0300 Subject: [PATCH] Add tag filtering with include_tags and exclude_tags --- .changeset/tag_based_filtering.md | 7 + README.md | 16 ++ .../test_tag_filtering.py | 147 ++++++++++++++++++ openapi_python_client/cli.py | 56 ++++++- openapi_python_client/config.py | 6 + openapi_python_client/parser/_pruning.py | 60 +++++++ openapi_python_client/parser/openapi.py | 25 ++- tests/test_cli.py | 76 +++++++++ tests/test_config.py | 22 +++ tests/test_parser/test_openapi.py | 103 +++++++++++- tests/test_parser/test_pruning.py | 142 +++++++++++++++++ 11 files changed, 651 insertions(+), 9 deletions(-) create mode 100644 .changeset/tag_based_filtering.md create mode 100644 end_to_end_tests/functional_tests/generated_code_execution/test_tag_filtering.py create mode 100644 openapi_python_client/parser/_pruning.py create mode 100644 tests/test_parser/test_pruning.py diff --git a/.changeset/tag_based_filtering.md b/.changeset/tag_based_filtering.md new file mode 100644 index 000000000..54709c8f4 --- /dev/null +++ b/.changeset/tag_based_filtering.md @@ -0,0 +1,7 @@ +--- +default: minor +--- + +# Add `include_tags` / `exclude_tags` to filter generated endpoints by tag + +Endpoints can now be limited to (or excluded by) OpenAPI tags via the `include_tags` / `exclude_tags` config keys or the `--include-tags` / `--exclude-tags` CLI options. Schemas that become unused after filtering are pruned automatically. diff --git a/README.md b/README.md index 4a886d299..cf3710979 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ _Be forewarned, this is a beta-level feature in the sense that the API exposed i for calling the functions in the `api` module. 2. An `api` module which will contain one module for each tag in your OpenAPI spec, as well as a `default` module for endpoints without a tag. Each of these modules in turn contains one function for calling each endpoint. 
+ You can limit which tags are generated with [`include_tags` / `exclude_tags`](#include_tags-and-exclude_tags). 3. A `models` module which has all the classes defined by the various schemas in your OpenAPI spec 4. A `setup.py` file _if_ you use `--meta=setup` (default is `--meta=poetry`) @@ -128,6 +129,21 @@ listed, you can enable this option: generate_all_tags: true ``` +### include_tags and exclude_tags + +Generate only part of a large API, by tag. Unused schemas are pruned automatically. + +```yaml +include_tags: [billing, users] # keep only these +# or +exclude_tags: [admin] # drop these +``` + +On the CLI, pass them comma-separated: `--include-tags billing,users`. A CLI flag overrides the matching config key. + +- `include_tags` and `exclude_tags` are mutually exclusive — use one, not both. The check runs after CLI flags are merged with the config file, so `--include-tags` combined with `exclude_tags` in the config (or the reverse) is also an error. +- Tags match the spec exactly (case-sensitive). Untagged endpoints count as `default`. An endpoint with several tags is kept if at least one of its tags passes the filter, and is generated only under the tag(s) that pass. + ### project_name_override and package_name_override Used to change the name of generated client library project/package.
If the project name is changed but an override for the package name diff --git a/end_to_end_tests/functional_tests/generated_code_execution/test_tag_filtering.py b/end_to_end_tests/functional_tests/generated_code_execution/test_tag_filtering.py new file mode 100644 index 000000000..e4bb4dd7f --- /dev/null +++ b/end_to_end_tests/functional_tests/generated_code_execution/test_tag_filtering.py @@ -0,0 +1,147 @@ +import re + +from end_to_end_tests.functional_tests.helpers import ( + inline_spec_should_fail, + with_generated_client_fixture, +) + +MULTI_TAG_SPEC = """ +paths: + "/billing": + post: + operationId: createInvoice + tags: ["billing"] + requestBody: + content: + application/json: + schema: {"$ref": "#/components/schemas/BillingModel"} + responses: + "200": + description: OK + content: + application/json: + schema: {"$ref": "#/components/schemas/SharedModel"} + "/users/me": + get: + operationId: getCurrentUser + tags: ["users"] + responses: + "200": + description: OK + content: + application/json: + schema: {"$ref": "#/components/schemas/SharedModel"} + "/admin/settings": + get: + operationId: getAdminSettings + tags: ["admin"] + responses: + "200": + description: OK + content: + application/json: + schema: {"$ref": "#/components/schemas/AdminModel"} + "/health": + get: + operationId: getHealth + responses: + "200": + description: OK +components: + schemas: + SharedModel: + type: object + properties: + id: {type: string} + status: {"$ref": "#/components/schemas/OrderStatus"} + address: {"$ref": "#/components/schemas/Address"} + Address: + type: object + properties: + city: {type: string} + BillingModel: + type: object + properties: + amount: {type: number} + AdminModel: + type: object + properties: + secret: {type: string} + OrderStatus: + type: string + enum: ["active", "inactive"] +""" + + +def _generated_package(generated_client): + return generated_client.output_path / generated_client.base_module + + +def _api_tag_dirs(generated_client) -> set[str]: + 
api_dir = _generated_package(generated_client) / "api" + return {child.name for child in api_dir.iterdir() if child.is_dir() and child.name != "__pycache__"} + + +def _model_modules(generated_client) -> set[str]: + models_dir = _generated_package(generated_client) / "models" + return {path.stem for path in models_dir.glob("*.py") if path.stem != "__init__"} + + +def _dangling_model_imports(generated_client) -> list[str]: + package = _generated_package(generated_client) + existing = {path.stem for path in (package / "models").glob("*.py")} + dangling: list[str] = [] + for path in package.rglob("*.py"): + for match in re.finditer(r"from \.+models\.(\w+) import", path.read_text()): + if match.group(1) not in existing: + dangling.append(f"{path.relative_to(package)} -> models.{match.group(1)}") + return sorted(dangling) + + +@with_generated_client_fixture(MULTI_TAG_SPEC, extra_args=["--include-tags", "billing"]) +class TestIncludeTagsViaCli: + def test_only_included_tag_api_module_is_generated(self, generated_client): + assert _api_tag_dirs(generated_client) == {"billing"} + + def test_unused_models_are_pruned(self, generated_client): + assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"} + + def test_pruned_client_has_no_dangling_imports(self, generated_client): + generated_client.import_module(".models") + assert _dangling_model_imports(generated_client) == [] + + +@with_generated_client_fixture(MULTI_TAG_SPEC, config="include_tags: [billing]") +class TestIncludeTagsViaConfigFile: + def test_only_included_tag_api_module_is_generated(self, generated_client): + assert _api_tag_dirs(generated_client) == {"billing"} + + def test_unused_models_are_pruned(self, generated_client): + assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"} + + +@with_generated_client_fixture(MULTI_TAG_SPEC, extra_args=["--exclude-tags", "admin"]) +class TestExcludeTagsViaCli: + def 
test_excluded_tag_api_module_is_dropped(self, generated_client): + assert _api_tag_dirs(generated_client) == {"billing", "users", "default"} + + def test_only_admin_models_are_pruned(self, generated_client): + assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"} + + +@with_generated_client_fixture(MULTI_TAG_SPEC, config="exclude_tags: [admin]") +class TestExcludeTagsViaConfigFile: + def test_excluded_tag_api_module_is_dropped(self, generated_client): + assert _api_tag_dirs(generated_client) == {"billing", "users", "default"} + + def test_only_admin_models_are_pruned(self, generated_client): + assert _model_modules(generated_client) == {"billing_model", "shared_model", "order_status", "address"} + + +class TestMutuallyExclusiveTagFlags: + def test_both_flags_exits_nonzero(self): + result = inline_spec_should_fail( + MULTI_TAG_SPEC, + extra_args=["--include-tags", "billing", "--exclude-tags", "admin"], + ) + assert "Provide either include_tags or exclude_tags, not both" in result.output diff --git a/openapi_python_client/cli.py b/openapi_python_client/cli.py index 3972703ae..7cd3b390f 100644 --- a/openapi_python_client/cli.py +++ b/openapi_python_client/cli.py @@ -18,6 +18,27 @@ def _version_callback(value: bool) -> None: raise typer.Exit() +def _split_comma_separated(value: str | None) -> list[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + +def _load_config_file( + *, + config_path: Path | None, +) -> ConfigFile: + if not config_path: + config_file = ConfigFile() + else: + try: + config_file = ConfigFile.load_from_path(path=config_path) + except Exception as err: + raise typer.BadParameter("Unable to parse config") from err + + return config_file + + def _process_config( *, url: str | None, @@ -27,6 +48,8 @@ def _process_config( file_encoding: str, overwrite: bool, output_path: Path | None, + include_tags: str | None = None, + exclude_tags: str | None = 
None, ) -> Config: source: Path | str if url and not path: @@ -46,13 +69,16 @@ def _process_config( typer.secho(f"Unknown encoding : {file_encoding}", fg=typer.colors.RED) raise typer.Exit(code=1) from err - if not config_path: - config_file = ConfigFile() - else: - try: - config_file = ConfigFile.load_from_path(path=config_path) - except Exception as err: - raise typer.BadParameter("Unable to parse config") from err + config_file = _load_config_file(config_path=config_path) + + if include_tags is not None: + config_file.include_tags = _split_comma_separated(include_tags) + if exclude_tags is not None: + config_file.exclude_tags = _split_comma_separated(exclude_tags) + + if config_file.include_tags and config_file.exclude_tags: + typer.secho("Provide either include_tags or exclude_tags, not both", fg=typer.colors.RED) + raise typer.Exit(code=1) return Config.from_sources(config_file, meta_type, source, file_encoding, overwrite, output_path=output_path) @@ -148,6 +174,20 @@ def generate( "Defaults to the OpenAPI document title converted to kebab or snake case (depending on meta type). " "Can also be overridden with `project_name_override` or `package_name_override` in config.", ), + include_tags: str | None = typer.Option( + None, + "--include-tags", + help="Comma-separated tags to generate. " + "Keeps matching endpoints, drops the rest, prunes unused schemas. " + "Case-sensitive. Overrides config. Can't combine with --exclude-tags.", + ), + exclude_tags: str | None = typer.Option( + None, + "--exclude-tags", + help="Comma-separated tags to skip. " + "Drops matching endpoints, keeps the rest, prunes unused schemas. " + "Case-sensitive. Overrides config. Can't combine with --include-tags.", + ), ) -> None: """Generate a new OpenAPI Client library""" from . 
import generate # noqa: PLC0415 @@ -160,6 +200,8 @@ def generate( file_encoding=file_encoding, overwrite=overwrite, output_path=output_path, + include_tags=include_tags, + exclude_tags=exclude_tags, ) errors = generate( custom_template_path=custom_template_path, diff --git a/openapi_python_client/config.py b/openapi_python_client/config.py index a30d2ddbf..6d4c37728 100644 --- a/openapi_python_client/config.py +++ b/openapi_python_client/config.py @@ -46,6 +46,8 @@ class ConfigFile(BaseModel): generate_all_tags: bool = False http_timeout: int = 5 literal_enums: bool = False + include_tags: list[str] | None = None + exclude_tags: list[str] | None = None @staticmethod def load_from_path(path: Path) -> "ConfigFile": @@ -76,6 +78,8 @@ class Config: generate_all_tags: bool http_timeout: int literal_enums: bool + include_tags: list[str] + exclude_tags: list[str] document_source: Path | str file_encoding: str content_type_overrides: dict[str, str] @@ -118,6 +122,8 @@ def from_sources( generate_all_tags=config_file.generate_all_tags, http_timeout=config_file.http_timeout, literal_enums=config_file.literal_enums, + include_tags=(config_file.include_tags or []), + exclude_tags=(config_file.exclude_tags or []), document_source=document_source, file_encoding=file_encoding, overwrite=overwrite, diff --git a/openapi_python_client/parser/_pruning.py b/openapi_python_client/parser/_pruning.py new file mode 100644 index 000000000..961d68adc --- /dev/null +++ b/openapi_python_client/parser/_pruning.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +__all__ = ["get_reachable_classes"] + +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING + +from ..utils import ClassName +from .properties import ( + EnumProperty, + ListProperty, + LiteralEnumProperty, + ModelProperty, + Property, + UnionProperty, +) +from .properties.protocol import PropertyProtocol + +if TYPE_CHECKING: # pragma: no cover + from .openapi import Endpoint + + +def 
get_reachable_classes( + *, + endpoints: Iterable[Endpoint], + classes_by_name: Mapping[ClassName, Property], +) -> set[ClassName]: + """Class names reachable from the given endpoints. Anything else is safe to prune. + + Walks each endpoint's properties transitively, collecting models and enums by name. + Re-fetches each model from ``classes_by_name`` on descent, since the copy an endpoint + holds may have empty properties. + """ + reachable: set[ClassName] = set() + stack: list[PropertyProtocol] = [] + for endpoint in endpoints: + stack.extend(endpoint.list_all_parameters()) + stack.extend(response.prop for response in endpoint.responses) + + while stack: + prop = stack.pop() + if isinstance(prop, ModelProperty): + name = prop.class_info.name + if name in reachable: + continue + reachable.add(name) + canonical = classes_by_name.get(name, prop) + if isinstance(canonical, ModelProperty): + stack.extend(canonical.required_properties or []) + stack.extend(canonical.optional_properties or []) + if canonical.additional_properties is not None: + stack.append(canonical.additional_properties) + elif isinstance(prop, EnumProperty | LiteralEnumProperty): + reachable.add(prop.class_info.name) + elif isinstance(prop, ListProperty): + stack.append(prop.inner_property) + elif isinstance(prop, UnionProperty): + stack.extend(prop.inner_properties) + + return reachable diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index 4f83ae93e..48cf695b3 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -10,6 +10,7 @@ from .. 
import utils from ..config import Config from ..utils import PythonIdentifier +from ._pruning import get_reachable_classes from .bodies import Body, body_from_data from .errors import GeneratorError, ParseError, PropertyError from .properties import ( @@ -35,6 +36,14 @@ def import_string_from_class(class_: Class, prefix: str = "") -> str: return f"from {prefix}.{class_.module_name} import {class_.name}" +def _filter_tags(tags: list[str], config: Config) -> list[str]: + if config.include_tags: + return [tag for tag in tags if tag in config.include_tags] + if config.exclude_tags: + return [tag for tag in tags if tag not in config.exclude_tags] + return tags + + @dataclass class EndpointCollection: """A bunch of endpoints grouped under a tag that will become a module""" @@ -64,7 +73,11 @@ def from_data( if operation is None: continue - tags = [utils.PythonIdentifier(value=tag, prefix="tag") for tag in operation.tags or ["default"]] + filtered_tags = _filter_tags(operation.tags or ["default"], config) + if not filtered_tags: + continue + + tags = [utils.PythonIdentifier(value=tag, prefix="tag") for tag in filtered_tags] if not config.generate_all_tags: tags = tags[:1] @@ -544,6 +557,16 @@ def from_dict(data: dict[str, Any], *, config: Config) -> "GeneratorData | Gener ] models = [prop for prop in schemas.classes_by_name.values() if isinstance(prop, ModelProperty)] + if config.include_tags or config.exclude_tags: + reachable = get_reachable_classes( + endpoints=( + endpoint for collection in endpoint_collections_by_tag.values() for endpoint in collection.endpoints + ), + classes_by_name=schemas.classes_by_name, + ) + models = [model for model in models if model.class_info.name in reachable] + enums = [enum for enum in enums if enum.class_info.name in reachable] + return GeneratorData( title=openapi.info.title, description=openapi.info.description, diff --git a/tests/test_cli.py b/tests/test_cli.py index 9712b2a04..9cee936ad 100644 --- a/tests/test_cli.py +++ 
b/tests/test_cli.py @@ -1,3 +1,5 @@ +import json + from typer.testing import CliRunner from openapi_python_client.cli import app @@ -41,3 +43,77 @@ def test_generate_encoding_errors(self) -> None: assert result.exit_code == 1 assert result.output == f"Unknown encoding : {file_encoding}\n" + + +class TestTagFilterOptions: + def _config_from_invoke(self, mocker, args): + generate = mocker.patch("openapi_python_client.generate", return_value=[]) + result = runner.invoke(app, ["generate", "--path=openapi.json", *args]) + assert result.exit_code == 0, result.output + return generate.call_args.kwargs["config"] + + def test_include_tags(self, mocker) -> None: + config = self._config_from_invoke(mocker, ["--include-tags=billing,users"]) + assert config.include_tags == ["billing", "users"] + assert config.exclude_tags == [] + + def test_exclude_tags(self, mocker) -> None: + config = self._config_from_invoke(mocker, ["--exclude-tags=admin"]) + assert config.exclude_tags == ["admin"] + assert config.include_tags == [] + + def test_include_tags_whitespace_trimmed(self, mocker) -> None: + config = self._config_from_invoke(mocker, ["--include-tags=a, b"]) + assert config.include_tags == ["a", "b"] + assert config.exclude_tags == [] + + def test_empty_include_tags_yields_no_filter(self, mocker) -> None: + config = self._config_from_invoke(mocker, ["--include-tags="]) + assert config.include_tags == [] + assert config.exclude_tags == [] + + def test_both_flags_exits_with_error(self, mocker) -> None: + generate = mocker.patch("openapi_python_client.generate", return_value=[]) + result = runner.invoke( + app, ["generate", "--path=openapi.json", "--include-tags=billing", "--exclude-tags=admin"] + ) + assert result.exit_code == 1 + assert result.output == "Provide either include_tags or exclude_tags, not both\n" + generate.assert_not_called() + + def test_cli_overrides_config_file_value(self, mocker, tmp_path) -> None: + config_path = tmp_path / "config.json" + 
config_path.write_text(json.dumps({"include_tags": ["from_config"]})) + config = self._config_from_invoke(mocker, [f"--config={config_path}", "--include-tags=from_cli"]) + assert config.include_tags == ["from_cli"] + + def test_cli_include_with_opposite_config_exclude_errors(self, mocker, tmp_path) -> None: + generate = mocker.patch("openapi_python_client.generate", return_value=[]) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"exclude_tags": ["admin"]})) + result = runner.invoke( + app, ["generate", "--path=openapi.json", f"--config={config_path}", "--include-tags=billing"] + ) + assert result.exit_code == 1 + assert result.output == "Provide either include_tags or exclude_tags, not both\n" + generate.assert_not_called() + + def test_cli_exclude_with_opposite_config_include_errors(self, mocker, tmp_path) -> None: + generate = mocker.patch("openapi_python_client.generate", return_value=[]) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"include_tags": ["billing"]})) + result = runner.invoke( + app, ["generate", "--path=openapi.json", f"--config={config_path}", "--exclude-tags=admin"] + ) + assert result.exit_code == 1 + assert result.output == "Provide either include_tags or exclude_tags, not both\n" + generate.assert_not_called() + + def test_config_with_both_tags_and_no_cli_flags_errors(self, mocker, tmp_path) -> None: + generate = mocker.patch("openapi_python_client.generate", return_value=[]) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({"include_tags": ["billing"], "exclude_tags": ["admin"]})) + result = runner.invoke(app, ["generate", "--path=openapi.json", f"--config={config_path}"]) + assert result.exit_code == 1 + assert result.output == "Provide either include_tags or exclude_tags, not both\n" + generate.assert_not_called() diff --git a/tests/test_config.py b/tests/test_config.py index be2e8bf59..780ed441b 100644 --- a/tests/test_config.py +++ 
b/tests/test_config.py @@ -7,6 +7,7 @@ import pytest from ruamel.yaml import YAML as _YAML +from openapi_python_client import Config, MetaType from openapi_python_client.config import ConfigFile @@ -59,3 +60,24 @@ def test_load_from_path(tmp_path: Path, filename, dump, relative) -> None: assert config.project_name_override == "project-name" assert config.package_name_override == "package_name" assert config.package_version_override == "package_version" + + +def _config_from_sources(config_file: ConfigFile) -> Config: + return Config.from_sources( + config_file, + MetaType.NONE, + document_source=Path("openapi.yaml"), + file_encoding="utf-8", + overwrite=False, + output_path=None, + ) + + +def test_config_from_sources_normalizes_tag_filters() -> None: + defaults = _config_from_sources(ConfigFile()) + assert defaults.include_tags == [] + assert defaults.exclude_tags == [] + + both = _config_from_sources(ConfigFile(include_tags=["billing"], exclude_tags=["admin"])) + assert both.include_tags == ["billing"] + assert both.exclude_tags == ["admin"] diff --git a/tests/test_parser/test_openapi.py b/tests/test_parser/test_openapi.py index 3bd743804..bede31cf3 100644 --- a/tests/test_parser/test_openapi.py +++ b/tests/test_parser/test_openapi.py @@ -2,11 +2,17 @@ import pydantic import pytest +from attr import evolve import openapi_python_client.schema as oai from openapi_python_client.parser.errors import ParseError from openapi_python_client.parser.openapi import Endpoint, EndpointCollection, import_string_from_class -from openapi_python_client.parser.properties import Class, IntProperty, Parameters, Schemas +from openapi_python_client.parser.properties import ( + Class, + IntProperty, + Parameters, + Schemas, +) from openapi_python_client.schema import DataType MODULE_NAME = "openapi_python_client.parser.openapi" @@ -661,3 +667,98 @@ def test_from_data_overrides_path_item_params_with_operation_params(self, config ) collection: EndpointCollection = collections["default"] 
assert isinstance(collection.endpoints[0].query_parameters[0], IntProperty) + + @staticmethod + def _tagged_data() -> dict: + def _op(tags: list[str] | None) -> oai.Operation: + return oai.Operation.model_construct( + tags=tags, + responses={"200": oai.Response.model_construct(description="ok")}, + ) + + return { + "/billing": oai.PathItem.model_construct(get=_op(["billing"])), + "/admin": oai.PathItem.model_construct(get=_op(["admin"])), + "/users": oai.PathItem.model_construct(get=_op(["users"])), + "/untagged": oai.PathItem.model_construct(get=_op(None)), + } + + def _collect(self, data: dict, config) -> dict: + collections, _schemas, _parameters = EndpointCollection.from_data( + data=data, + schemas=Schemas(), + parameters=Parameters(), + config=config, + request_bodies={}, + responses={}, + ) + return collections + + def test_from_data_no_filter_keeps_every_tag(self, config): + collections = self._collect(self._tagged_data(), config) + assert set(collections.keys()) == {"billing", "admin", "users", "default"} + + def test_from_data_include_tags_keeps_only_included(self, config): + collections = self._collect(self._tagged_data(), evolve(config, include_tags=["billing"])) + assert set(collections.keys()) == {"billing"} + + def test_from_data_exclude_tags_drops_excluded(self, config): + collections = self._collect(self._tagged_data(), evolve(config, exclude_tags=["admin"])) + assert "admin" not in collections + assert {"billing", "users", "default"} == set(collections.keys()) + + def test_from_data_untagged_excluded_when_default_not_included(self, config): + collections = self._collect(self._tagged_data(), evolve(config, include_tags=["billing"])) + assert "default" not in collections + + def test_from_data_include_default_keeps_untagged(self, config): + collections = self._collect(self._tagged_data(), evolve(config, include_tags=["default"])) + assert set(collections.keys()) == {"default"} + + def test_from_data_filters_before_first_tag_truncation(self, config): 
+ data = { + "/multi": oai.PathItem.model_construct( + get=oai.Operation.model_construct( + tags=["admin", "billing"], + responses={"200": oai.Response.model_construct(description="ok")}, + ), + ), + } + collections = self._collect(data, evolve(config, include_tags=["billing"])) + assert set(collections.keys()) == {"billing"} + + def test_from_data_generate_all_tags_keeps_only_surviving_tags(self, config): + data = { + "/multi": oai.PathItem.model_construct( + get=oai.Operation.model_construct( + tags=["admin", "billing", "users"], + responses={"200": oai.Response.model_construct(description="ok")}, + ), + ), + } + collections = self._collect(data, evolve(config, include_tags=["billing", "users"], generate_all_tags=True)) + assert set(collections.keys()) == {"billing", "users"} + + def test_from_data_exclude_tags_with_generate_all_tags(self, config): + data = { + "/multi": oai.PathItem.model_construct( + get=oai.Operation.model_construct( + tags=["admin", "billing", "users"], + responses={"200": oai.Response.model_construct(description="ok")}, + ), + ), + } + collections = self._collect(data, evolve(config, exclude_tags=["admin"], generate_all_tags=True)) + assert set(collections.keys()) == {"billing", "users"} + + def test_from_data_tag_matching_is_case_sensitive(self, config): + data = { + "/billing": oai.PathItem.model_construct( + get=oai.Operation.model_construct( + tags=["Billing"], + responses={"200": oai.Response.model_construct(description="ok")}, + ), + ), + } + collections = self._collect(data, evolve(config, include_tags=["billing"])) + assert collections == {} diff --git a/tests/test_parser/test_pruning.py b/tests/test_parser/test_pruning.py new file mode 100644 index 000000000..cb55a230b --- /dev/null +++ b/tests/test_parser/test_pruning.py @@ -0,0 +1,142 @@ +import openapi_python_client.schema as oai +from openapi_python_client.parser._pruning import get_reachable_classes +from openapi_python_client.parser.bodies import Body, BodyType +from 
openapi_python_client.parser.openapi import Endpoint +from openapi_python_client.parser.properties import Class +from openapi_python_client.parser.responses import NONE_SOURCE, HTTPStatusPattern, Response, Responses +from openapi_python_client.utils import ClassName, PythonIdentifier + + +def _class(name: str) -> Class: + return Class(name=ClassName(name, ""), module_name=PythonIdentifier(name, "")) + + +def _endpoint(**kwargs) -> Endpoint: + return Endpoint( + path="/x", + method="get", + description=None, + name="x", + requires_security=False, + tags=[], + **kwargs, + ) + + +def _response(prop) -> Response: + return Response( + status_code=HTTPStatusPattern(pattern="200", code_range=(200, 200)), + prop=prop, + source=NONE_SOURCE, + data=oai.Response.model_construct(description="ok"), + ) + + +class TestGetReachableClasses: + def test_no_endpoints_reaches_nothing(self): + assert get_reachable_classes(endpoints=[], classes_by_name={}) == set() + + def test_ignores_scalar_properties(self, string_property_factory): + result = get_reachable_classes( + endpoints=[_endpoint(query_parameters=[string_property_factory()])], + classes_by_name={}, + ) + assert result == set() + + def test_collects_model_from_parameter(self, model_property_factory): + thing = model_property_factory(class_info=_class("Thing"), required_properties=[], optional_properties=[]) + classes = {ClassName("Thing", ""): thing} + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[thing])], classes_by_name=classes) + assert result == {"Thing"} + + def test_collects_required_optional_and_additional_properties(self, model_property_factory, enum_property_factory): + status = enum_property_factory(name="status", class_info=_class("Status")) + shared = model_property_factory(class_info=_class("Shared"), required_properties=[], optional_properties=[]) + extra = model_property_factory(class_info=_class("Extra"), required_properties=[], optional_properties=[]) + thing = model_property_factory( 
+ class_info=_class("Thing"), + required_properties=[status], + optional_properties=[shared], + additional_properties=extra, + ) + classes = { + ClassName(name, ""): prop + for name, prop in (("Thing", thing), ("Status", status), ("Shared", shared), ("Extra", extra)) + } + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[thing])], classes_by_name=classes) + assert result == {"Thing", "Status", "Shared", "Extra"} + + def test_descends_through_list_inner_property(self, model_property_factory, list_property_factory): + item = model_property_factory(class_info=_class("Item"), required_properties=[], optional_properties=[]) + listed = list_property_factory(inner_property=item) + classes = {ClassName("Item", ""): item} + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[listed])], classes_by_name=classes) + assert result == {"Item"} + + def test_descends_through_union_inner_properties( + self, model_property_factory, literal_enum_property_factory, union_property_factory + ): + member = model_property_factory(class_info=_class("Member"), required_properties=[], optional_properties=[]) + kind = literal_enum_property_factory(name="kind", class_info=_class("Kind")) + union = union_property_factory(inner_properties=[member, kind]) + classes = {ClassName("Member", ""): member, ClassName("Kind", ""): kind} + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[union])], classes_by_name=classes) + assert result == {"Member", "Kind"} + + def test_descends_into_canonical_definition_not_stale_copy(self, model_property_factory, enum_property_factory): + currency = enum_property_factory(name="currency", class_info=_class("Currency")) + canonical_detail = model_property_factory( + class_info=_class("Detail"), required_properties=[currency], optional_properties=[] + ) + stale_detail = model_property_factory( + class_info=_class("Detail"), required_properties=None, optional_properties=None + ) + thing = model_property_factory( + 
class_info=_class("Thing"), required_properties=[stale_detail], optional_properties=[] + ) + classes = { + ClassName("Thing", ""): thing, + ClassName("Detail", ""): canonical_detail, + ClassName("Currency", ""): currency, + } + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[thing])], classes_by_name=classes) + assert result == {"Thing", "Detail", "Currency"} + + def test_handles_self_referential_model(self, model_property_factory): + node = model_property_factory(class_info=_class("Node"), required_properties=[], optional_properties=[]) + object.__setattr__(node, "optional_properties", [node]) + classes = {ClassName("Node", ""): node} + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[node])], classes_by_name=classes) + assert result == {"Node"} + + def test_handles_mutual_cycle(self, model_property_factory): + a = model_property_factory(class_info=_class("A"), required_properties=[], optional_properties=[]) + b = model_property_factory(class_info=_class("B"), required_properties=[], optional_properties=[]) + object.__setattr__(a, "optional_properties", [b]) + object.__setattr__(b, "optional_properties", [a]) + classes = {ClassName("A", ""): a, ClassName("B", ""): b} + result = get_reachable_classes(endpoints=[_endpoint(query_parameters=[a])], classes_by_name=classes) + assert result == {"A", "B"} + + def test_seeds_from_responses(self, model_property_factory): + thing = model_property_factory(class_info=_class("RespModel"), required_properties=[], optional_properties=[]) + classes = {ClassName("RespModel", ""): thing} + endpoint = _endpoint(responses=Responses(patterns=[_response(thing)], default=None)) + result = get_reachable_classes(endpoints=[endpoint], classes_by_name=classes) + assert result == {"RespModel"} + + def test_seeds_from_default_response(self, model_property_factory): + thing = model_property_factory( + class_info=_class("DefaultModel"), required_properties=[], optional_properties=[] + ) + classes = 
{ClassName("DefaultModel", ""): thing} + endpoint = _endpoint(responses=Responses(patterns=[], default=_response(thing))) + result = get_reachable_classes(endpoints=[endpoint], classes_by_name=classes) + assert result == {"DefaultModel"} + + def test_seeds_from_bodies(self, model_property_factory): + thing = model_property_factory(class_info=_class("BodyModel"), required_properties=[], optional_properties=[]) + classes = {ClassName("BodyModel", ""): thing} + endpoint = _endpoint(bodies=[Body(content_type="application/json", prop=thing, body_type=BodyType.JSON)]) + result = get_reachable_classes(endpoints=[endpoint], classes_by_name=classes) + assert result == {"BodyModel"}