diff --git a/lite_bootstrap/bootstrappers/fastapi_bootstrapper.py b/lite_bootstrap/bootstrappers/fastapi_bootstrapper.py index 516edc5..d8cd8d4 100644 --- a/lite_bootstrap/bootstrappers/fastapi_bootstrapper.py +++ b/lite_bootstrap/bootstrappers/fastapi_bootstrapper.py @@ -3,6 +3,7 @@ from lite_bootstrap import import_checker from lite_bootstrap.bootstrappers.base import BaseBootstrapper +from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument from lite_bootstrap.instruments.healthchecks_instrument import ( HealthChecksConfig, HealthChecksInstrument, @@ -16,6 +17,7 @@ if import_checker.is_fastapi_installed: import fastapi + from fastapi.middleware.cors import CORSMiddleware if import_checker.is_opentelemetry_installed: from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -26,7 +28,7 @@ @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) -class FastAPIConfig(HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusConfig, SentryConfig): +class FastAPIConfig(CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusConfig, SentryConfig): application: "fastapi.FastAPI" = dataclasses.field(default_factory=lambda: fastapi.FastAPI()) opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list) prometheus_instrumentator_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict) @@ -34,6 +36,23 @@ class FastAPIConfig(HealthChecksConfig, LoggingConfig, OpentelemetryConfig, Prom prometheus_expose_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict) +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class FastApiCorsInstrument(CorsInstrument): + bootstrap_config: FastAPIConfig + + def bootstrap(self) -> None: + self.bootstrap_config.application.add_middleware( + CORSMiddleware, + allow_origins=self.bootstrap_config.cors_allowed_origins, + allow_methods=self.bootstrap_config.cors_allowed_methods, + allow_headers=self.bootstrap_config.cors_allowed_headers, + allow_credentials=self.bootstrap_config.cors_allowed_credentials, + allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex, + expose_headers=self.bootstrap_config.cors_exposed_headers, + max_age=self.bootstrap_config.cors_max_age, + ) + + @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class FastAPIHealthChecksInstrument(HealthChecksInstrument): bootstrap_config: FastAPIConfig @@ -114,6 +133,7 @@ class FastAPIBootstrapper(BaseBootstrapper["fastapi.FastAPI"]): __slots__ = "bootstrap_config", "instruments" instruments_types: typing.ClassVar = [ + FastApiCorsInstrument, FastAPIOpenTelemetryInstrument, FastAPISentryInstrument, FastAPIHealthChecksInstrument, diff --git a/lite_bootstrap/bootstrappers/litestar_bootstrapper.py b/lite_bootstrap/bootstrappers/litestar_bootstrapper.py index d59de78..c7e4aaa 100644 --- a/lite_bootstrap/bootstrappers/litestar_bootstrapper.py +++ b/lite_bootstrap/bootstrappers/litestar_bootstrapper.py @@ -3,6 +3,7 @@ from lite_bootstrap import import_checker from lite_bootstrap.bootstrappers.base import BaseBootstrapper +from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument from lite_bootstrap.instruments.healthchecks_instrument import ( HealthChecksConfig, HealthChecksInstrument, @@ -22,6 +23,7 @@ if import_checker.is_litestar_installed: import litestar from litestar.config.app import AppConfig + from litestar.config.cors import CORSConfig from litestar.contrib.opentelemetry import OpenTelemetryConfig from litestar.plugins.prometheus import PrometheusConfig, PrometheusController @@ -31,13 +33,29 @@ @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class LitestarConfig( - HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig + CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig ): application_config: "AppConfig" = dataclasses.field(default_factory=lambda: AppConfig()) opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list) prometheus_additional_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict) +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class LitestarCorsInstrument(CorsInstrument): + bootstrap_config: LitestarConfig + + def bootstrap(self) -> None: + self.bootstrap_config.application_config.cors_config = CORSConfig( + allow_origins=self.bootstrap_config.cors_allowed_origins, + allow_methods=self.bootstrap_config.cors_allowed_methods, # type: ignore[arg-type] + allow_headers=self.bootstrap_config.cors_allowed_headers, + allow_credentials=self.bootstrap_config.cors_allowed_credentials, + allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex, + expose_headers=self.bootstrap_config.cors_exposed_headers, + max_age=self.bootstrap_config.cors_max_age, + ) + + @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class LitestarHealthChecksInstrument(HealthChecksInstrument): bootstrap_config: LitestarConfig @@ -116,6 +134,7 @@ class LitestarBootstrapper(BaseBootstrapper["litestar.Litestar"]): __slots__ = "bootstrap_config", "instruments" instruments_types: typing.ClassVar = [ + LitestarCorsInstrument, LitestarOpenTelemetryInstrument, LitestarSentryInstrument, LitestarHealthChecksInstrument, diff --git a/lite_bootstrap/instruments/cors_instrument.py b/lite_bootstrap/instruments/cors_instrument.py new file mode 100644 index 0000000..d00805c --- /dev/null +++ b/lite_bootstrap/instruments/cors_instrument.py @@ -0,0 +1,25 @@ +import dataclasses + +from lite_bootstrap.instruments.base import BaseConfig, BaseInstrument + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class CorsConfig(BaseConfig): + cors_allowed_origins: list[str] = dataclasses.field(default_factory=list) + cors_allowed_methods: list[str] = dataclasses.field(default_factory=list) + cors_allowed_headers: list[str] = dataclasses.field(default_factory=list) + cors_exposed_headers: list[str] = dataclasses.field(default_factory=list) + cors_allowed_credentials: bool = False + cors_allowed_origin_regex: str | None = None + cors_max_age: int = 600 + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class CorsInstrument(BaseInstrument): + bootstrap_config: CorsConfig + not_ready_message = "cors_allowed_origins or cors_allowed_origin_regex must be provided" + + def is_ready(self) -> bool: + return bool(self.bootstrap_config.cors_allowed_origins) or bool( + self.bootstrap_config.cors_allowed_origin_regex, + ) diff --git a/tests/test_fastapi_bootstrap.py b/tests/test_fastapi_bootstrap.py index 6001af3..0de53b6 100644 --- a/tests/test_fastapi_bootstrap.py +++ b/tests/test_fastapi_bootstrap.py @@ -18,6 +18,7 @@ def fastapi_config() -> FastAPIConfig: service_version="2.0.0", service_environment="test", service_debug=False, + cors_allowed_origins=["http://test"], opentelemetry_endpoint="otl", opentelemetry_instrumentors=[CustomInstrumentor()], opentelemetry_span_exporter=ConsoleSpanExporter(), diff --git a/tests/test_litestar_bootstrap.py b/tests/test_litestar_bootstrap.py index f824acb..144ee1e 100644 --- a/tests/test_litestar_bootstrap.py +++ b/tests/test_litestar_bootstrap.py @@ -18,6 +18,7 @@ def litestar_config() -> LitestarConfig: service_version="2.0.0", service_environment="test", service_debug=False, + cors_allowed_origins=["http://test"], opentelemetry_endpoint="otl", opentelemetry_instrumentors=[CustomInstrumentor()], opentelemetry_span_exporter=ConsoleSpanExporter(), @@ -35,6 +36,9 @@ def test_litestar_bootstrap(litestar_config: LitestarConfig) -> None: try: logger.info("testing logging", key="value") + assert application.cors_config + assert application.cors_config.allow_origins == litestar_config.cors_allowed_origins + with TestClient(app=application) as test_client: response = test_client.get(litestar_config.health_checks_path) assert response.status_code == status_codes.HTTP_200_OK