diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 230f8cc362..e204b1778a 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -179,6 +179,18 @@ def get_relationship_to( elif origin is list: use_annotation = get_args(annotation)[0] + # If a dict or Mapping, get the value type (second type argument) + elif origin is dict or origin is Mapping: + args = get_args(annotation) + if len(args) >= 2: + # For Dict[K, V] or Mapping[K, V], we want the value type (V) + use_annotation = args[1] + else: + raise ValueError( + f"Dict/Mapping relationship field '{name}' must have both key " + "and value type arguments (e.g., Dict[str, Model])" + ) + return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation ) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..529e20388e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -54,7 +54,14 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + deprecated, + get_args, + get_origin, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -562,7 +569,8 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) + original_annotation = new_cls.__annotations__.get(k) + col = get_column_from_field(v, original_annotation) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. @@ -646,12 +654,44 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: Any) -> Any: +def _get_sqlmodel_field_info_from_annotation(annotation: Any) -> Optional[FieldInfo]: + """Extract SQLModel FieldInfo from an Annotated type's metadata. + + When using Annotated[type, Field(...), Validator(...)], Pydantic V2 may create + a new pydantic.fields.FieldInfo that doesn't preserve SQLModel-specific attributes + like sa_column and sa_type. This function looks through the Annotated metadata + to find the original SQLModel FieldInfo. + """ + if get_origin(annotation) is not Annotated: + return None + for arg in get_args(annotation)[1:]: # Skip the first arg (the actual type) + if isinstance(arg, FieldInfo): + return arg + return None + + +def get_sqlalchemy_type(field: Any, original_annotation: Any = None) -> Any: if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + # If sa_type not found on field_info, check if it's in the Annotated metadata + # This handles the case where Pydantic V2 creates a new FieldInfo losing SQLModel attrs + if sa_type is Undefined and IS_PYDANTIC_V2: + # First try field_info.annotation (may be unpacked by Pydantic) + annotation = getattr(field_info, "annotation", None) + if annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation(annotation) + if sqlmodel_field_info is not None: + sa_type = getattr(sqlmodel_field_info, "sa_type", Undefined) + # If still not found, try the original annotation from the class + if sa_type is Undefined and original_annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation( + original_annotation + ) + if sqlmodel_field_info is not None: + sa_type = getattr(sqlmodel_field_info, "sa_type", Undefined) if sa_type is not Undefined: return sa_type @@ -703,15 +743,31 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any, original_annotation: Any = None) -> Column: # type: ignore if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) + # If sa_column not found on field_info, check if it's in the Annotated metadata + # This handles the case where Pydantic V2 creates a new FieldInfo losing SQLModel attrs + if sa_column is Undefined and IS_PYDANTIC_V2: + # First try field_info.annotation (may be unpacked by Pydantic) + annotation = getattr(field_info, "annotation", None) + if annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation(annotation) + if sqlmodel_field_info is not None: + sa_column = getattr(sqlmodel_field_info, "sa_column", Undefined) + # If still not found, try the original annotation from the class + if sa_column is Undefined and original_annotation is not None: + sqlmodel_field_info = _get_sqlmodel_field_info_from_annotation( + original_annotation + ) + if sqlmodel_field_info is not None: + sa_column = getattr(sqlmodel_field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column - sa_type = get_sqlalchemy_type(field) + sa_type = get_sqlalchemy_type(field, original_annotation) primary_key = getattr(field_info, "primary_key", Undefined) if primary_key is Undefined: primary_key = False diff --git a/tests/test_annotated_sa_column.py b/tests/test_annotated_sa_column.py new file mode 100644 index 0000000000..9ec01c41e5 --- /dev/null +++ b/tests/test_annotated_sa_column.py @@ -0,0 +1,92 @@ +"""Tests for Annotated fields with sa_column and Pydantic validators. + +When using Annotated[type, Field(sa_column=...), Validator(...)], Pydantic V2 may +create a new FieldInfo that doesn't preserve SQLModel-specific attributes like +sa_column. These tests ensure the sa_column is properly extracted from the +Annotated metadata. +""" + +from datetime import datetime +from typing import Annotated, Optional + +from pydantic import AfterValidator, BeforeValidator +from sqlalchemy import Column, DateTime, String +from sqlmodel import Field, SQLModel + + +def test_annotated_sa_column_with_validators() -> None: + """Test that sa_column is preserved when using Annotated with validators.""" + + def before_validate(v: datetime) -> datetime: + return v + + def after_validate(v: datetime) -> datetime: + return v + + class Position(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + timestamp: Annotated[ + datetime, + Field( + sa_column=Column(DateTime(timezone=True), nullable=False, index=True) + ), + BeforeValidator(before_validate), + AfterValidator(after_validate), + ] + + # Verify the column type has timezone=True + assert Position.__table__.c.timestamp.type.timezone is True + assert Position.__table__.c.timestamp.nullable is False + assert Position.__table__.c.timestamp.index is True + + +def test_annotated_sa_column_with_single_validator() -> None: + """Test sa_column with just one validator.""" + + def validate_name(v: str) -> str: + return v.strip() + + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: Annotated[ + str, + Field(sa_column=Column(String(100), nullable=False, unique=True)), + AfterValidator(validate_name), + ] + + assert isinstance(Item.__table__.c.name.type, String) + assert Item.__table__.c.name.type.length == 100 + assert Item.__table__.c.name.nullable is False + assert Item.__table__.c.name.unique is True + + +def test_annotated_sa_column_without_validators() -> None: + """Test that sa_column still works with Annotated but no validators.""" + + class Record(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Annotated[ + datetime, + Field(sa_column=Column(DateTime(timezone=True), nullable=False)), + ] + + assert Record.__table__.c.created_at.type.timezone is True + assert Record.__table__.c.created_at.nullable is False + + +def test_annotated_sa_type_with_validators() -> None: + """Test that sa_type is preserved when using Annotated with validators.""" + + def validate_timestamp(v: datetime) -> datetime: + return v + + class Event(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + occurred_at: Annotated[ + datetime, + Field(sa_type=DateTime(timezone=True)), + AfterValidator(validate_timestamp), + ] + + # Verify the column type has timezone=True + assert Event.__table__.c.occurred_at.type.timezone is True diff --git a/tests/test_dict_relationship_recursion.py b/tests/test_dict_relationship_recursion.py new file mode 100644 index 0000000000..9993b6843d --- /dev/null +++ b/tests/test_dict_relationship_recursion.py @@ -0,0 +1,60 @@ +"""Test for Dict relationship recursion bug fix.""" + +from typing import Dict + +from sqlalchemy.orm.collections import attribute_mapped_collection +from sqlmodel import Field, Relationship, SQLModel + + +def test_dict_relationship_pattern(): + """Test that Dict relationships with attribute_mapped_collection work.""" + + # Create a minimal reproduction of the pattern + # This should not raise a RecursionError + + class TestChild(SQLModel, table=True): + __tablename__ = "test_child" + id: int = Field(primary_key=True) + key: str = Field(nullable=False) + parent_id: int = Field(foreign_key="test_parent.id") + parent: "TestParent" = Relationship(back_populates="children") + + class TestParent(SQLModel, table=True): + __tablename__ = "test_parent" + id: int = Field(primary_key=True) + children: Dict[str, "TestChild"] = Relationship( + back_populates="parent", + sa_relationship_kwargs={ + "collection_class": attribute_mapped_collection("key") + }, + ) + + # If we got here without RecursionError, the bug is fixed + assert TestParent.__tablename__ == "test_parent" + assert TestChild.__tablename__ == "test_child" + + +def test_dict_relationship_with_optional(): + """Test that Optional[Dict[...]] relationships also work.""" + from typing import Optional + + class Child(SQLModel, table=True): + __tablename__ = "child" + id: int = Field(primary_key=True) + key: str = Field(nullable=False) + parent_id: int = Field(foreign_key="parent.id") + parent: Optional["Parent"] = Relationship(back_populates="children") + + class Parent(SQLModel, table=True): + __tablename__ = "parent" + id: int = Field(primary_key=True) + children: Optional[Dict[str, "Child"]] = Relationship( + back_populates="parent", + sa_relationship_kwargs={ + "collection_class": attribute_mapped_collection("key") + }, + ) + + # If we got here without RecursionError, the bug is fixed + assert Parent.__tablename__ == "parent" + assert Child.__tablename__ == "child"