diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index f71b5265b1..d0d4558272 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -1,5 +1,9 @@ __version__ = "0.0.35" +# Re-export from Pydantic +from pydantic import Discriminator as Discriminator +from pydantic import Tag as Tag + # Re-export from SQLAlchemy from sqlalchemy.engine import create_engine as create_engine from sqlalchemy.engine import create_mock_engine as create_mock_engine diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..b359a188ce 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -23,7 +23,7 @@ overload, ) -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, Discriminator, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( Boolean, @@ -261,7 +261,7 @@ def Field( max_length: int | None = None, allow_mutation: bool = True, regex: str | None = None, - discriminator: str | None = None, + discriminator: str | Discriminator | None = None, repr: bool = True, primary_key: bool | UndefinedType = Undefined, foreign_key: Any = Undefined, @@ -304,7 +304,7 @@ def Field( max_length: int | None = None, allow_mutation: bool = True, regex: str | None = None, - discriminator: str | None = None, + discriminator: str | Discriminator | None = None, repr: bool = True, primary_key: bool | UndefinedType = Undefined, foreign_key: str, @@ -356,7 +356,7 @@ def Field( max_length: int | None = None, allow_mutation: bool = True, regex: str | None = None, - discriminator: str | None = None, + discriminator: str | Discriminator | None = None, repr: bool = True, sa_column: Column[Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -389,7 +389,7 @@ def Field( max_length: int | None = None, allow_mutation: bool = True, regex: str | None = None, - discriminator: str | None = None, + discriminator: str | Discriminator | None = None, repr: bool = True, primary_key: bool | UndefinedType = Undefined, foreign_key: Any = Undefined, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 11f4150d98..1956f55bf2 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -1,9 +1,9 @@ from decimal import Decimal -from typing import Literal +from typing import Annotated, Any, Literal import pytest from pydantic import ValidationError -from sqlmodel import Field, SQLModel +from sqlmodel import Discriminator, Field, SQLModel, Tag def test_decimal(): @@ -47,6 +47,38 @@ class Model(SQLModel): Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type] +def test_discriminator_callable(): + # Example adapted from + # [Pydantic docs](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator): + + class Pie(SQLModel): + pass + + class ApplePie(Pie): + fruit: Literal["apple"] = "apple" + + class PumpkinPie(Pie): + filling: Literal["pumpkin"] = "pumpkin" + + def get_discriminator_value(v: Any) -> str: + if isinstance(v, dict): + return v.get("fruit", v.get("filling")) + return getattr(v, "fruit", getattr(v, "filling", None)) + + class ThanksgivingDinner(SQLModel): + dessert: ( + Annotated[ApplePie, Tag("apple")] | Annotated[PumpkinPie, Tag("pumpkin")] + ) = Field( + discriminator=Discriminator(get_discriminator_value), + ) + + apple_pie = ThanksgivingDinner.model_validate({"dessert": {"fruit": "apple"}}) + assert isinstance(apple_pie.dessert, ApplePie) + + pumpkin_pie = ThanksgivingDinner.model_validate({"dessert": {"filling": "pumpkin"}}) + assert isinstance(pumpkin_pie.dessert, PumpkinPie) + + def test_repr(): class Model(SQLModel): id: int | None = Field(primary_key=True)