Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion lite_bootstrap/bootstrappers/fastapi_bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from lite_bootstrap import import_checker
from lite_bootstrap.bootstrappers.base import BaseBootstrapper
from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument
from lite_bootstrap.instruments.healthchecks_instrument import (
HealthChecksConfig,
HealthChecksInstrument,
Expand All @@ -16,6 +17,7 @@

if import_checker.is_fastapi_installed:
import fastapi
from fastapi.middleware.cors import CORSMiddleware

if import_checker.is_opentelemetry_installed:
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
Expand All @@ -26,14 +28,31 @@


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class FastAPIConfig(HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusConfig, SentryConfig):
class FastAPIConfig(CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusConfig, SentryConfig):
application: "fastapi.FastAPI" = dataclasses.field(default_factory=lambda: fastapi.FastAPI())
opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list)
prometheus_instrumentator_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
prometheus_instrument_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
prometheus_expose_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class FastApiCorsInstrument(CorsInstrument):
bootstrap_config: FastAPIConfig

def bootstrap(self) -> None:
self.bootstrap_config.application.add_middleware(
CORSMiddleware,
allow_origins=self.bootstrap_config.cors_allowed_origins,
allow_methods=self.bootstrap_config.cors_allowed_methods,
allow_headers=self.bootstrap_config.cors_allowed_headers,
allow_credentials=self.bootstrap_config.cors_allowed_credentials,
allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex,
expose_headers=self.bootstrap_config.cors_exposed_headers,
max_age=self.bootstrap_config.cors_max_age,
)


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class FastAPIHealthChecksInstrument(HealthChecksInstrument):
bootstrap_config: FastAPIConfig
Expand Down Expand Up @@ -114,6 +133,7 @@ class FastAPIBootstrapper(BaseBootstrapper["fastapi.FastAPI"]):
__slots__ = "bootstrap_config", "instruments"

instruments_types: typing.ClassVar = [
FastApiCorsInstrument,
FastAPIOpenTelemetryInstrument,
FastAPISentryInstrument,
FastAPIHealthChecksInstrument,
Expand Down
21 changes: 20 additions & 1 deletion lite_bootstrap/bootstrappers/litestar_bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from lite_bootstrap import import_checker
from lite_bootstrap.bootstrappers.base import BaseBootstrapper
from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument
from lite_bootstrap.instruments.healthchecks_instrument import (
HealthChecksConfig,
HealthChecksInstrument,
Expand All @@ -22,6 +23,7 @@
if import_checker.is_litestar_installed:
import litestar
from litestar.config.app import AppConfig
from litestar.config.cors import CORSConfig
from litestar.contrib.opentelemetry import OpenTelemetryConfig
from litestar.plugins.prometheus import PrometheusConfig, PrometheusController

Expand All @@ -31,13 +33,29 @@

@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class LitestarConfig(
HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig
CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig
):
application_config: "AppConfig" = dataclasses.field(default_factory=lambda: AppConfig())
opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list)
prometheus_additional_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class LitestarCorsInstrument(CorsInstrument):
bootstrap_config: LitestarConfig

def bootstrap(self) -> None:
self.bootstrap_config.application_config.cors_config = CORSConfig(
allow_origins=self.bootstrap_config.cors_allowed_origins,
allow_methods=self.bootstrap_config.cors_allowed_methods, # type: ignore[arg-type]
allow_headers=self.bootstrap_config.cors_allowed_headers,
allow_credentials=self.bootstrap_config.cors_allowed_credentials,
allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex,
expose_headers=self.bootstrap_config.cors_exposed_headers,
max_age=self.bootstrap_config.cors_max_age,
)


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class LitestarHealthChecksInstrument(HealthChecksInstrument):
bootstrap_config: LitestarConfig
Expand Down Expand Up @@ -116,6 +134,7 @@ class LitestarBootstrapper(BaseBootstrapper["litestar.Litestar"]):
__slots__ = "bootstrap_config", "instruments"

instruments_types: typing.ClassVar = [
LitestarCorsInstrument,
LitestarOpenTelemetryInstrument,
LitestarSentryInstrument,
LitestarHealthChecksInstrument,
Expand Down
25 changes: 25 additions & 0 deletions lite_bootstrap/instruments/cors_instrument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import dataclasses

from lite_bootstrap.instruments.base import BaseConfig, BaseInstrument


@dataclasses.dataclass(kw_only=True, frozen=True)
class CorsConfig(BaseConfig):
cors_allowed_origins: list[str] = dataclasses.field(default_factory=list)
cors_allowed_methods: list[str] = dataclasses.field(default_factory=list)
cors_allowed_headers: list[str] = dataclasses.field(default_factory=list)
cors_exposed_headers: list[str] = dataclasses.field(default_factory=list)
cors_allowed_credentials: bool = False
cors_allowed_origin_regex: str | None = None
cors_max_age: int = 600


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class CorsInstrument(BaseInstrument):
bootstrap_config: CorsConfig
not_ready_message = "cors_allowed_origins or cors_allowed_origin_regex must be provided"

def is_ready(self) -> bool:
return bool(self.bootstrap_config.cors_allowed_origins) or bool(
self.bootstrap_config.cors_allowed_origin_regex,
)
1 change: 1 addition & 0 deletions tests/test_fastapi_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def fastapi_config() -> FastAPIConfig:
service_version="2.0.0",
service_environment="test",
service_debug=False,
cors_allowed_origins=["http://test"],
opentelemetry_endpoint="otl",
opentelemetry_instrumentors=[CustomInstrumentor()],
opentelemetry_span_exporter=ConsoleSpanExporter(),
Expand Down
4 changes: 4 additions & 0 deletions tests/test_litestar_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def litestar_config() -> LitestarConfig:
service_version="2.0.0",
service_environment="test",
service_debug=False,
cors_allowed_origins=["http://test"],
opentelemetry_endpoint="otl",
opentelemetry_instrumentors=[CustomInstrumentor()],
opentelemetry_span_exporter=ConsoleSpanExporter(),
Expand All @@ -35,6 +36,9 @@ def test_litestar_bootstrap(litestar_config: LitestarConfig) -> None:
try:
logger.info("testing logging", key="value")

assert application.cors_config
assert application.cors_config.allow_origins == litestar_config.cors_allowed_origins

with TestClient(app=application) as test_client:
response = test_client.get(litestar_config.health_checks_path)
assert response.status_code == status_codes.HTTP_200_OK
Expand Down