Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 18 additions & 20 deletions django/contrib/gis/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,36 +115,34 @@ def get_distance(self, f, value, lookup_type):
"Distance operations not available on this spatial backend."
)

def get_geom_placeholder(self, f, value, compiler):
@staticmethod
def _must_transform_value(value, field):
return value is not None and value.srid != field.srid

def get_geom_placeholder_sql(self, f, value, compiler):
"""
Return the placeholder for the given geometry field with the given
value. Depending on the spatial backend, the placeholder may contain a
stored procedure call to the transformation function of the spatial
backend.
"""

def transform_value(value, field):
return value is not None and value.srid != field.srid

if hasattr(value, "as_sql"):
return (
"%s(%%s, %s)" % (self.spatial_function_name("Transform"), f.srid)
if transform_value(value.output_field, f)
else "%s"
)
if transform_value(value, f):
# Add Transform() to the SQL placeholder.
return "%s(%s(%%s,%s), %s)" % (
self.spatial_function_name("Transform"),
self.from_text,
value.srid,
f.srid,
)
sql, params = compiler.compile(value)
if self._must_transform_value(value.output_field, f):
transform_func = self.spatial_function_name("Transform")
sql = f"{transform_func}({sql}, %s)"
params = (*params, f.srid)
return sql, params
elif self._must_transform_value(value, f):
transform_func = self.spatial_function_name("Transform")
sql = f"{transform_func}({self.from_text}(%s, %s), %s)"
params = (value, value.srid, f.srid)
return sql, params
elif self.connection.features.has_spatialrefsys_table:
return "%s(%%s,%s)" % (self.from_text, f.srid)
return f"{self.from_text}(%s, %s)", (value, f.srid)
else:
# For backwards compatibility on MySQL (#27464).
return "%s(%%s)" % self.from_text
return f"{self.from_text}(%s)", (value,)

def check_expression_support(self, expression):
if isinstance(expression, self.disallowed_aggregates):
Expand Down
6 changes: 3 additions & 3 deletions django/contrib/gis/db/backends/oracle/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ def get_distance(self, f, value, lookup_type):

return [dist_param]

def get_geom_placeholder(self, f, value, compiler):
def get_geom_placeholder_sql(self, f, value, compiler):
if value is None:
return "NULL"
return super().get_geom_placeholder(f, value, compiler)
return "NULL", ()
return super().get_geom_placeholder_sql(f, value, compiler)

def spatial_aggregate_name(self, agg_name):
"""
Expand Down
18 changes: 8 additions & 10 deletions django/contrib/gis/db/backends/postgis/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,19 @@ def get_distance(self, f, dist_val, lookup_type):

return [dist_param]

def get_geom_placeholder(self, f, value, compiler):
def get_geom_placeholder_sql(self, f, value, compiler):
"""
Provide a proper substitution value for Geometries or rasters that are
not in the SRID of the field. Specifically, this routine will
substitute in the ST_Transform() function call.
"""
transform_func = self.spatial_function_name("Transform")
if hasattr(value, "as_sql"):
if value.field.srid == f.srid:
placeholder = "%s"
else:
placeholder = "%s(%%s, %s)" % (transform_func, f.srid)
return placeholder
sql, params = compiler.compile(value)
if value.field.srid != f.srid:
sql = f"{transform_func}({sql}, %s)"
params = (*params, f.srid)
return sql, params

# Get the srid for this object
if value is None:
Expand All @@ -327,11 +327,9 @@ def get_geom_placeholder(self, f, value, compiler):
# Adding Transform() to the SQL placeholder if the value srid
# is not equal to the field srid.
if value_srid is None or value_srid == f.srid:
placeholder = "%s"
return "%s", (value,)
else:
placeholder = "%s(%%s, %s)" % (transform_func, f.srid)

return placeholder
return f"{transform_func}(%s, %s)", (value, f.srid)

def _get_postgis_func(self, func):
"""
Expand Down
4 changes: 2 additions & 2 deletions django/contrib/gis/db/models/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def geodetic(self, connection):
"""
return get_srid_info(self.srid, connection).geodetic

def get_placeholder(self, value, compiler, connection):
def get_placeholder_sql(self, value, compiler, connection):
"""
Return the placeholder for the spatial column for the
given value.
"""
return connection.ops.get_geom_placeholder(self, value, compiler)
return connection.ops.get_geom_placeholder_sql(self, value, compiler)

def get_srid(self, obj):
"""
Expand Down
10 changes: 5 additions & 5 deletions django/contrib/gis/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def process_rhs(self, compiler, connection):
# If rhs is some Query, don't touch it.
return super().process_rhs(compiler, connection)
if isinstance(self.rhs, Expression):
self.rhs = self.rhs.resolve_expression(compiler.query)
rhs, rhs_params = super().process_rhs(compiler, connection)
placeholder = connection.ops.get_geom_placeholder(
self.lhs.output_field, self.rhs, compiler
rhs = self.rhs.resolve_expression(compiler.query)
else:
rhs = connection.ops.Adapter(self.rhs)
return connection.ops.get_geom_placeholder_sql(
self.lhs.output_field, rhs, compiler
)
return placeholder % rhs, rhs_params

def get_rhs_op(self, connection, rhs):
# Unlike BuiltinLookup, the GIS get_rhs_op() implementation should
Expand Down
8 changes: 6 additions & 2 deletions django/contrib/postgres/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,12 @@ def db_parameters(self, connection):
db_params["collation"] = self.db_collation
return db_params

def get_placeholder(self, value, compiler, connection):
return "%s::{}".format(self.db_type(connection))
def get_placeholder_sql(self, value, compiler, connection):
db_type = self.db_type(connection)
if hasattr(value, "as_sql"):
sql, params = compiler.compile(value)
return f"{sql}::{db_type}", params
return f"%s::{db_type}", (value,)

def get_db_prep_value(self, value, connection, prepared=False):
if isinstance(value, (list, tuple)):
Expand Down
8 changes: 6 additions & 2 deletions django/contrib/postgres/fields/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ def model(self, model):
def _choices_is_value(cls, value):
return isinstance(value, (list, tuple)) or super()._choices_is_value(value)

def get_placeholder(self, value, compiler, connection):
return "%s::{}".format(self.db_type(connection))
def get_placeholder_sql(self, value, compiler, connection):
db_type = self.db_type(connection)
if hasattr(value, "as_sql"):
sql, params = compiler.compile(value)
return f"{sql}::{db_type}", params
return f"%s::{db_type}", (value,)

def get_prep_value(self, value):
if value is None:
Expand Down
6 changes: 4 additions & 2 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,14 @@ def combine_expression(self, connector, sub_expressions):
def combine_duration_expression(self, connector, sub_expressions):
return self.combine_expression(connector, sub_expressions)

def binary_placeholder_sql(self, value):
def binary_placeholder_sql(self, value, compiler):
"""
Some backends require special syntax to insert binary content (MySQL
for example uses '_binary %s').
"""
return "%s"
if hasattr(value, "as_sql"):
return compiler.compile(value)
return "%s", (value,)

def modify_insert_params(self, placeholder, params):
"""
Expand Down
10 changes: 6 additions & 4 deletions django/db/backends/mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,12 @@ def convert_uuidfield_value(self, value, expression, connection):
value = uuid.UUID(value)
return value

def binary_placeholder_sql(self, value):
return (
"_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
)
def binary_placeholder_sql(self, value, compiler):
if value is None:
return "%s", (None,)
elif hasattr(value, "as_sql"):
return compiler.compile(value)
return "_binary %s", (value,)

def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
Expand Down
4 changes: 2 additions & 2 deletions django/db/backends/postgresql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def assemble_as_sql(self, fields, value_rows):
# Lack of fields denote the usage of the DEFAULT keyword
# for the insertion of empty rows.
or any(field is None for field in fields)
# Field.get_placeholder takes value as an argument, so the
# Field.get_placeholder_sql takes value as an argument, so the
# resulting placeholder might be dependent on the value.
# in UNNEST requires a single placeholder to "fit all values" in
# the array.
or any(hasattr(field, "get_placeholder") for field in fields)
or any(hasattr(field, "get_placeholder_sql") for field in fields)
# Fields that don't use standard internal types might not be
# unnest'able (e.g. array and geometry types are known to be
# problematic).
Expand Down
20 changes: 14 additions & 6 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,14 @@ def register_combinable_fields(lhs, connector, rhs, result):
_connector_combinators[connector].append((lhs, rhs, result))


for d in _connector_combinations:
for connector, field_types in d.items():
for lhs, rhs, result in field_types:
register_combinable_fields(lhs, connector, rhs, result)
def _register_combinable_fields():
for d in _connector_combinations:
for connector, field_types in d.items():
for lhs, rhs, result in field_types:
register_combinable_fields(lhs, connector, rhs, result)


_register_combinable_fields()


@functools.lru_cache(maxsize=128)
Expand Down Expand Up @@ -1173,8 +1177,12 @@ def as_sql(self, compiler, connection):
val = output_field.get_db_prep_save(val, connection=connection)
else:
val = output_field.get_db_prep_value(val, connection=connection)
if hasattr(output_field, "get_placeholder"):
return output_field.get_placeholder(val, compiler, connection), [val]
try:
get_placeholder_sql = output_field.get_placeholder_sql
except AttributeError:
pass
else:
return get_placeholder_sql(val, compiler, connection)
if val is None:
# oracledb does not always convert None to the appropriate
# NULL type (like in case expressions using numbers), so we
Expand Down
31 changes: 29 additions & 2 deletions django/db/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
parse_duration,
parse_time,
)
from django.utils.deprecation import RemovedInDjango70Warning, django_file_prefixes
from django.utils.duration import duration_string
from django.utils.functional import Promise, cached_property
from django.utils.ipv6 import MAX_IPV6_ADDRESS_LENGTH, clean_ipv6_address
Expand Down Expand Up @@ -181,6 +182,32 @@ def _description(self):

description = property(_description)

def __init_subclass__(cls, **kwargs):
# RemovedInDjango70Warning: When the deprecation ends, remove
# completely.
# Allow for both `get_placeholder` and `get_placeholder_sql` to
# be declared to ease the deprecation process for third-party apps.
if (
get_placeholder := cls.__dict__.get("get_placeholder")
) is not None and "get_placeholder_sql" not in cls.__dict__:
warnings.warn(
"Field.get_placeholder is deprecated in favor of get_placeholder_sql. "
f"Define {cls.__module__}.{cls.__qualname__}.get_placeholder_sql "
"to return both SQL and parameters instead.",
category=RemovedInDjango70Warning,
skip_file_prefixes=django_file_prefixes(),
)

def get_placeholder_sql(self, value, compiler, connection):
placeholder = get_placeholder(self, value, compiler, connection)
if hasattr(value, "as_sql"):
sql, params = compiler.compile(value)
return placeholder % sql, params
return placeholder, (value,)

setattr(cls, "get_placeholder_sql", get_placeholder_sql)
return super().__init_subclass__(**kwargs)

def __init__(
self,
verbose_name=None,
Expand Down Expand Up @@ -2735,8 +2762,8 @@ def deconstruct(self):
def get_internal_type(self):
return "BinaryField"

def get_placeholder(self, value, compiler, connection):
return connection.ops.binary_placeholder_sql(value)
def get_placeholder_sql(self, value, compiler, connection):
return connection.ops.binary_placeholder_sql(value, compiler)

def get_default(self):
if self.has_default() and not callable(self.default):
Expand Down
46 changes: 25 additions & 21 deletions django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,25 +1690,26 @@ class SQLInsertCompiler(SQLCompiler):
returning_fields = None
returning_params = ()

def field_as_sql(self, field, get_placeholder, val):
def field_as_sql(self, field, get_placeholder_sql, val):
"""
Take a field and a value intended to be saved on that field, and
return placeholder SQL and accompanying params. Check for raw values,
expressions, and fields with get_placeholder() defined in that order.
fields with get_placeholder_sql(), and compilable defined in that
order.
When field is None, consider the value raw and use it as the
placeholder, with no corresponding parameters returned.
"""
if field is None:
# A field value of None means the value is raw.
sql, params = val, []
elif get_placeholder_sql is not None:
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
sql, params = get_placeholder_sql(val, self, self.connection)
elif hasattr(val, "as_sql"):
# This is an expression, let's compile it.
sql, params = self.compile(val)
elif get_placeholder is not None:
# Some fields (e.g. geo fields) need special munging before
# they can be inserted.
sql, params = get_placeholder(val, self, self.connection), [val]
else:
# Return the common case for the placeholder
sql, params = "%s", [val]
Expand Down Expand Up @@ -1777,11 +1778,15 @@ def assemble_as_sql(self, fields, value_rows):

# list of (sql, [params]) tuples for each object to be saved
# Shape: [n_objs][n_fields][2]
get_placeholders = [getattr(field, "get_placeholder", None) for field in fields]
get_placeholder_sqls = [
getattr(field, "get_placeholder_sql", None) for field in fields
]
rows_of_fields_as_sql = (
(
self.field_as_sql(field, get_placeholder, value)
for field, get_placeholder, value in zip(fields, get_placeholders, row)
self.field_as_sql(field, get_placeholder_sql, value)
for field, get_placeholder_sql, value in zip(
fields, get_placeholder_sqls, row
)
)
for row in value_rows
)
Expand Down Expand Up @@ -2078,21 +2083,20 @@ def as_sql(self):
)
val = field.get_db_prep_save(val, connection=self.connection)

# Getting the placeholder for the field.
if hasattr(field, "get_placeholder"):
placeholder = field.get_placeholder(val, self, self.connection)
else:
placeholder = "%s"
name = field.column
if hasattr(val, "as_sql"):
quoted_name = qn(field.column)
if (
get_placeholder_sql := getattr(field, "get_placeholder_sql", None)
) is not None:
sql, params = get_placeholder_sql(val, self, self.connection)
values.append(f"{quoted_name} = {sql}")
update_params.extend(params)
elif hasattr(val, "as_sql"):
sql, params = self.compile(val)
values.append("%s = %s" % (qn(name), placeholder % sql))
values.append(f"{quoted_name} = {sql}")
update_params.extend(params)
elif val is not None:
values.append("%s = %s" % (qn(name), placeholder))
update_params.append(val)
else:
values.append("%s = NULL" % qn(name))
values.append(f"{quoted_name} = %s")
update_params.append(val)
table = self.query.base_table
result = [
"UPDATE %s SET" % qn(table),
Expand Down
Loading
Loading