From 65657c21c11e0e0f4ec2c2a141229c570ee3c83b Mon Sep 17 00:00:00 2001 From: Kim Gustyr Date: Tue, 2 Jun 2026 19:47:11 +0100 Subject: [PATCH] implement http oauth resource server Pass-through OAuth 2.0 protected resource for HTTP transport: serves RFC 9728 metadata, 401 + WWW-Authenticate on missing credential, any-scheme Authorization forwarded to the API. Tests migrated to respx. beep boop --- mcp/pyproject.toml | 18 +++--- mcp/src/flagsmith_mcp/config.py | 5 ++ mcp/src/flagsmith_mcp/constants.py | 1 + mcp/src/flagsmith_mcp/oauth.py | 65 +++++++++++++++++++ mcp/src/flagsmith_mcp/server.py | 12 ++++ mcp/tests/integration/conftest.py | 25 ++++---- mcp/tests/integration/test_oauth.py | 98 +++++++++++++++++++++++++++++ mcp/tests/unit/test_server.py | 16 +++-- mcp/uv.lock | 29 +++++---- 9 files changed, 225 insertions(+), 44 deletions(-) create mode 100644 mcp/src/flagsmith_mcp/oauth.py create mode 100644 mcp/tests/integration/test_oauth.py diff --git a/mcp/pyproject.toml b/mcp/pyproject.toml index 8ed97e4ddfcd..0072afdbd6fa 100644 --- a/mcp/pyproject.toml +++ b/mcp/pyproject.toml @@ -6,8 +6,8 @@ authors = [{ name = "Flagsmith", email = "support@flagsmith.com" }] readme = "README.md" requires-python = ">=3.10" dependencies = [ - "fastmcp>=3.3.1,<4.0.0", # Base MCP functionality - "pydantic-settings>=2.0.0,<3.0.0", # Environment-driven configuration + "fastmcp>=3.3.1,<4.0.0", # Base MCP functionality + "pydantic-settings>=2.0.0,<3.0.0", # Environment-driven configuration ] [project.scripts] @@ -15,13 +15,13 @@ flagsmith-mcp = "flagsmith_mcp.server:run" [dependency-groups] dev = [ - "mypy>=2.1.0,<3.0.0", # Static type checking - "openapi-pydantic>=0.5.0,<1.0.0", # Build OpenAPI specs as fixtures - "pytest>=9.0.3,<10.0.0", # Run tests - "pytest-asyncio>=1.3.0,<2.0.0", # Run asynchronous tests - "pytest-cov>=7.0.0,<8.0.0", # Measure test coverage - "pytest-httpx>=0.35.0,<1.0.0", # Mock HTTP interactions - "ruff>=0.15.12,<0.16.0", # Lint and format + "mypy>=2.1.0,<3.0.0", # Static type checking + "openapi-pydantic>=0.5.0,<1.0.0", # Build OpenAPI specs as fixtures + "pytest>=9.0.3,<10.0.0", # Run tests + "pytest-asyncio>=1.3.0,<2.0.0", # Run asynchronous tests + "pytest-cov>=7.0.0,<8.0.0", # Measure test coverage + "respx>=0.22,<1.0", # Mock HTTP interactions + "ruff>=0.15.12,<0.16.0", # Lint and format ] [build-system] diff --git a/mcp/src/flagsmith_mcp/config.py b/mcp/src/flagsmith_mcp/config.py index 7bc5a58e5ab9..c1b0bb3227f2 100644 --- a/mcp/src/flagsmith_mcp/config.py +++ b/mcp/src/flagsmith_mcp/config.py @@ -21,6 +21,11 @@ class Settings(BaseSettings): default="http", ) """MCP transport to use.""" + mcp_server_url: str = Field( + default="http://127.0.0.1:8000", + ) + """Public base URL of this MCP server, advertised in OAuth protected-resource + metadata. Override for HTTP deployments behind a proxy/public hostname.""" @model_validator(mode="after") def validate_stdio_token(self) -> "Settings": diff --git a/mcp/src/flagsmith_mcp/constants.py b/mcp/src/flagsmith_mcp/constants.py index 46c45ed6f519..6ec4fc869c5a 100644 --- a/mcp/src/flagsmith_mcp/constants.py +++ b/mcp/src/flagsmith_mcp/constants.py @@ -1,2 +1,3 @@ # TODO: consume a version-controlled schema — https://github.com/Flagsmith/flagsmith/issues/7669 OPENAPI_SPEC_URL = "https://api.flagsmith.com/api/v1/swagger.json" +OAUTH_SCOPES = ["mcp"] diff --git a/mcp/src/flagsmith_mcp/oauth.py b/mcp/src/flagsmith_mcp/oauth.py new file mode 100644 index 000000000000..f34b561c3317 --- /dev/null +++ b/mcp/src/flagsmith_mcp/oauth.py @@ -0,0 +1,65 @@ +from fastmcp.server.auth.auth import AccessToken, RemoteAuthProvider, TokenVerifier +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from pydantic import AnyHttpUrl +from starlette.authentication import AuthCredentials, AuthenticationBackend +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import HTTPConnection + +from flagsmith_mcp.constants import OAUTH_SCOPES + + +class _AnySchemeBackend(AuthenticationBackend): + """Authenticates a request on the mere presence of an `Authorization` + header, regardless of scheme (`Bearer` OAuth token or `Api-Key`). The + credential is forwarded upstream verbatim and validated by the API.""" + + async def authenticate( + self, conn: HTTPConnection + ) -> tuple[AuthCredentials, AuthenticatedUser] | None: + if not ( + header := next( + ( + conn.headers.get(k) + for k in conn.headers + if k.lower() == "authorization" + ), + None, + ) + ): + return None + + token = header.split(" ", 1)[-1] + access = AccessToken( + token=token, client_id="flagsmith-mcp", scopes=OAUTH_SCOPES + ) + return AuthCredentials(OAUTH_SCOPES), AuthenticatedUser(access) + + +class FlagsmithResourceAuth(RemoteAuthProvider): + """OAuth 2.0 protected resource for HTTP transport. + + Serves Protected Resource Metadata (RFC 9728) pointing at the Flagsmith + authorization server and returns 401 + `WWW-Authenticate` when a request + carries no credential, so MCP clients can discover and complete the OAuth + flow. Any `Authorization` header is accepted and passed through — the API + validates it (no introspection here). + """ + + def __init__(self, *, resource_url: str, authorization_server: str) -> None: + token_verifier = TokenVerifier( + required_scopes=[] + ) # never consulted — introspection done by Core API. + super().__init__( + token_verifier=token_verifier, + authorization_servers=[AnyHttpUrl(authorization_server)], + base_url=resource_url, + scopes_supported=OAUTH_SCOPES, + ) + + def get_middleware(self) -> list[Middleware]: + return [ + Middleware(AuthenticationMiddleware, backend=_AnySchemeBackend()), + Middleware(AuthContextMiddleware), + ] diff --git a/mcp/src/flagsmith_mcp/server.py b/mcp/src/flagsmith_mcp/server.py index f1e04e5e489d..e755808c3da0 100644 --- a/mcp/src/flagsmith_mcp/server.py +++ b/mcp/src/flagsmith_mcp/server.py @@ -9,6 +9,7 @@ from flagsmith_mcp import config, constants from flagsmith_mcp.auth import FlagsmithAuth +from flagsmith_mcp.oauth import FlagsmithResourceAuth ROUTE_MAPS = [ RouteMap(tags={"mcp"}, mcp_type=MCPType.TOOL), @@ -40,6 +41,16 @@ def _fetch_spec() -> dict[str, Any]: def create_server(settings: config.Settings) -> FastMCP[None]: + # OAuth discovery is the credential fallback for HTTP transport: only when + # the server holds no static token does it advertise the AS and gate on a + # missing Authorization header. Otherwise (stdio, static token, or a + # forwarded --header) it's pure pass-through. + auth = None + if settings.transport == "http" and settings.flagsmith_api_token is None: + auth = FlagsmithResourceAuth( + resource_url=settings.mcp_server_url, + authorization_server=settings.flagsmith_api_url, + ) return FastMCP.from_openapi( openapi_spec=_fetch_spec(), client=httpx.AsyncClient( @@ -50,6 +61,7 @@ def create_server(settings: config.Settings) -> FastMCP[None]: route_maps=ROUTE_MAPS, mcp_component_fn=_customise, validate_output=False, # TODO https://github.com/Flagsmith/flagsmith/issues/7679 + auth=auth, ) diff --git a/mcp/tests/integration/conftest.py b/mcp/tests/integration/conftest.py index 4546eea28690..d9ea1edc5e14 100644 --- a/mcp/tests/integration/conftest.py +++ b/mcp/tests/integration/conftest.py @@ -1,19 +1,19 @@ from collections.abc import AsyncIterator -from typing import Any import openapi_pydantic as openapi import pytest from fastmcp import Client, FastMCP from fastmcp.client.transports import FastMCPTransport +from respx import MockRouter -from flagsmith_mcp import config +from flagsmith_mcp import config, constants from flagsmith_mcp import server as server_module @pytest.fixture -def openapi_spec() -> dict[str, Any]: +def openapi_spec() -> openapi.OpenAPI: ok = openapi.Response(description="OK") - spec = openapi.OpenAPI( + return openapi.OpenAPI( info=openapi.Info(title="Flagsmith API", version="1.0.0"), paths={ "/environments/": openapi.PathItem( @@ -35,16 +35,19 @@ def openapi_spec() -> dict[str, Any]: ), }, ) - return spec.model_dump(by_alias=True, exclude_none=True, mode="json") + + +@pytest.fixture(autouse=True) +def openapi_spec_mock(respx_mock: MockRouter, openapi_spec: openapi.OpenAPI) -> None: + # create_server fetches the OpenAPI spec over HTTP; mock that call (respx + # leaves the in-memory ASGI transport used by the tests untouched). + respx_mock.get(constants.OPENAPI_SPEC_URL).respond( + json=openapi_spec.model_dump(by_alias=True, exclude_none=True, mode="json") + ) @pytest.fixture -def server( - monkeypatch: pytest.MonkeyPatch, - openapi_spec: dict[str, Any], -) -> FastMCP: - monkeypatch.setenv("FLAGSMITH_API_URL", "https://flagsmith.example.com") - monkeypatch.setattr(server_module, "_fetch_spec", lambda: openapi_spec) +def server() -> FastMCP: return server_module.create_server(config.Settings()) diff --git a/mcp/tests/integration/test_oauth.py b/mcp/tests/integration/test_oauth.py new file mode 100644 index 000000000000..13a646ffd154 --- /dev/null +++ b/mcp/tests/integration/test_oauth.py @@ -0,0 +1,98 @@ +from collections.abc import AsyncIterator +from typing import Callable + +import httpx +import pytest +from fastmcp import FastMCP + +from flagsmith_mcp import config +from flagsmith_mcp import server as server_module + +HTTPClientFactoryFixture = Callable[[FastMCP], AsyncIterator[httpx.AsyncClient]] + + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" + + +@pytest.fixture +def http_client_factory() -> HTTPClientFactoryFixture: + async def factory(server: FastMCP) -> AsyncIterator[httpx.AsyncClient]: + transport = httpx.ASGITransport(app=server.http_app()) + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as connected: + yield connected + + return factory + + +@pytest.fixture +async def http_client( + server: FastMCP, + http_client_factory: HTTPClientFactoryFixture, +) -> AsyncIterator[httpx.AsyncClient]: + async for client in http_client_factory(server): + yield client + + +@pytest.fixture +def server_with_flagsmith_api_token() -> FastMCP: + return server_module.create_server( + config.Settings(flagsmith_api_token="secret.token") + ) + + +@pytest.fixture +async def http_client_with_flagsmith_api_token( + server_with_flagsmith_api_token: FastMCP, + http_client_factory: HTTPClientFactoryFixture, +) -> AsyncIterator[httpx.AsyncClient]: + async for client in http_client_factory(server_with_flagsmith_api_token): + yield client + + +async def test_http_no_token__serves_protected_resource_metadata( + http_client: httpx.AsyncClient, +) -> None: + # Given OAuth discovery is active (server fixture: http, no static token) + response = await http_client.get(PRM_PATH) + + # Then it advertises the Flagsmith AS and the mcp scope (RFC 9728) + assert response.status_code == 200 + body = response.json() + assert body["authorization_servers"] == ["https://api.flagsmith.com/"] + assert body["scopes_supported"] == ["mcp"] + + +async def test_http_no_token__missing_authorization__401_points_at_prm( + http_client: httpx.AsyncClient, +) -> None: + # When a request reaches the MCP endpoint with no credential + response = await http_client.get("/mcp") + + # Then it 401s (no API round-trip) and points the client at the PRM + assert response.status_code == 401 + assert PRM_PATH in response.headers["www-authenticate"] + + +async def test_http_no_token__non_bearer_credential__accepted_by_gate( + http_client: httpx.AsyncClient, +) -> None: + # When a request carries a non-Bearer (Api-Key) credential + response = await http_client.get( + PRM_PATH, headers={"Authorization": "Api-Key ser.secret"} + ) + + # Then the gate authenticates it (scheme-agnostic); end-to-end pass-through + # to the API is exercised against live SaaS. + assert response.status_code == 200 + + +async def test_http_static_token__no_oauth_resource( + http_client_with_flagsmith_api_token: httpx.AsyncClient, +) -> None: + # Given a static token (pure pass-through, OAuth disabled) + response = await http_client_with_flagsmith_api_token.get(PRM_PATH) + + # Then no protected-resource metadata is served + assert response.status_code == 404 diff --git a/mcp/tests/unit/test_server.py b/mcp/tests/unit/test_server.py index 4ab8125775e2..d83eb0bed850 100644 --- a/mcp/tests/unit/test_server.py +++ b/mcp/tests/unit/test_server.py @@ -4,7 +4,7 @@ import pytest from fastmcp import Client from mcp.types import ToolAnnotations -from pytest_httpx import HTTPXMock +from respx import MockRouter from flagsmith_mcp import config, constants, server @@ -63,7 +63,7 @@ ], ) async def test_create_server__mcp_route__annotates_tool_per_method( - httpx_mock: HTTPXMock, + respx_mock: MockRouter, method: str, expected: ToolAnnotations, ) -> None: @@ -77,9 +77,8 @@ async def test_create_server__mcp_route__annotates_tool_per_method( info=openapi.Info(title="Flagsmith API", version="1.0.0"), paths={"/things/": openapi.PathItem.model_validate({method: operation})}, ) - httpx_mock.add_response( - url=constants.OPENAPI_SPEC_URL, - json=spec.model_dump(by_alias=True, exclude_none=True, mode="json"), + respx_mock.get(constants.OPENAPI_SPEC_URL).respond( + json=spec.model_dump(by_alias=True, exclude_none=True, mode="json") ) # When @@ -91,7 +90,7 @@ async def test_create_server__mcp_route__annotates_tool_per_method( async def test_create_server__untagged_route__excluded_from_tools( - httpx_mock: HTTPXMock, + respx_mock: MockRouter, ) -> None: # Given a spec with one mcp-tagged route and one untagged route spec = openapi.OpenAPI( @@ -112,9 +111,8 @@ async def test_create_server__untagged_route__excluded_from_tools( ), }, ) - httpx_mock.add_response( - url=constants.OPENAPI_SPEC_URL, - json=spec.model_dump(by_alias=True, exclude_none=True, mode="json"), + respx_mock.get(constants.OPENAPI_SPEC_URL).respond( + json=spec.model_dump(by_alias=True, exclude_none=True, mode="json") ) # When diff --git a/mcp/uv.lock b/mcp/uv.lock index bf5df4b6b768..60d09059ea7d 100644 --- a/mcp/uv.lock +++ b/mcp/uv.lock @@ -614,7 +614,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, - { name = "pytest-httpx" }, + { name = "respx" }, { name = "ruff" }, ] @@ -631,7 +631,7 @@ dev = [ { name = "pytest", specifier = ">=9.0.3,<10.0.0" }, { name = "pytest-asyncio", specifier = ">=1.3.0,<2.0.0" }, { name = "pytest-cov", specifier = ">=7.0.0,<8.0.0" }, - { name = "pytest-httpx", specifier = ">=0.35.0,<1.0.0" }, + { name = "respx", specifier = ">=0.22,<1.0" }, { name = "ruff", specifier = ">=0.15.12,<0.16.0" }, ] @@ -1390,19 +1390,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" }, ] -[[package]] -name = "pytest-httpx" -version = "0.36.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "httpx" }, - { name = "pytest" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4e/42/f53c58570e80d503ade9dd42ce57f2915d14bcbe25f6308138143950d1d6/pytest_httpx-0.36.2.tar.gz", hash = "sha256:05a56527484f7f4e8c856419ea379b8dc359c36801c4992fdb330f294c690356", size = 57683, upload-time = "2026-04-09T13:57:19.837Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/55/1fa65f8e4fceb19dd6daa867c162ad845d547f6058cd92b4b02384a44777/pytest_httpx-0.36.2-py3-none-any.whl", hash = "sha256:d42ebd5679442dc7bfb0c48e0767b6562e9bc4534d805127b0084171886a5e22", size = 20315, upload-time = "2026-04-09T13:57:18.587Z" }, -] - [[package]] name = "python-dotenv" version = "1.2.2" @@ -1531,6 +1518,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, ] +[[package]] +name = "respx" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/98/4e55c9c486404ec12373708d015ebce157966965a5ebe7f28ff2c784d41b/respx-0.23.1.tar.gz", hash = "sha256:242dcc6ce6b5b9bf621f5870c82a63997e8e82bc7c947f9ffe272b8f3dd5a780", size = 29243, upload-time = "2026-04-08T14:37:16.008Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/4a/221da6ca167db45693d8d26c7dc79ccfc978a440251bf6721c9aaf251ac0/respx-0.23.1-py2.py3-none-any.whl", hash = "sha256:b18004b029935384bccfa6d7d9d74b4ec9af73a081cc28600fffc0447f4b8c1a", size = 25557, upload-time = "2026-04-08T14:37:14.613Z" }, +] + [[package]] name = "rich" version = "15.0.0"