Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions django/contrib/contenttypes/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,12 @@ def _apply_rel_filters(self, queryset):
Filter the queryset for the instance this manager is bound to.
"""
db = self._db or router.db_for_read(self.model, instance=self.instance)
return (
queryset.using(db)
.fetch_mode(self.instance._state.fetch_mode)
.filter(**self.core_filters)
)
with queryset._avoid_cloning():
return (
queryset.using(db)
.fetch_mode(self.instance._state.fetch_mode)
.filter(**self.core_filters)
)

def _remove_prefetched_objects(self):
try:
Expand All @@ -671,12 +672,17 @@ def get_queryset(self):
return self._apply_rel_filters(queryset)

def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0] if querysets else super().get_queryset()
_cloning_disabled = False
if querysets:
if len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0]
else:
_cloning_disabled = True
queryset = super().get_queryset()._disable_cloning()
queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
# Group instances by content types.
Expand All @@ -697,8 +703,12 @@ def get_prefetch_querysets(self, instances, querysets=None):
# instances' PK in order to match up instances:
object_id_converter = instances[0]._meta.pk.to_python
content_type_id_field_name = "%s_id" % self.content_type_field_name
queryset = queryset.filter(query)
# Restore subsequent cloning operations.
if _cloning_disabled:
queryset._enable_cloning()
return (
queryset.filter(query),
queryset,
lambda relobj: (
object_id_converter(getattr(relobj, self.object_id_field_name)),
getattr(relobj, content_type_id_field_name),
Expand Down
113 changes: 73 additions & 40 deletions django/db/models/fields/related_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,17 @@ def get_queryset(self, *, instance):
).fetch_mode(instance._state.fetch_mode)

def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a length "
"of 1."
)
queryset = (
querysets[0] if querysets else self.get_queryset(instance=instances[0])
)
_cloning_disabled = False
if querysets:
if len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0]
else:
_cloning_disabled = True
queryset = self.get_queryset(instance=instances[0])._disable_cloning()

rel_obj_attr = self.field.get_foreign_related_value
instance_attr = self.field.get_local_related_value
Expand All @@ -203,6 +206,9 @@ def get_prefetch_querysets(self, instances, querysets=None):
# There can be only one object prefetched for each instance so clear
# ordering if the query allows it without side effects.
queryset.query.clear_ordering()
# Restore subsequent cloning operations.
if _cloning_disabled:
queryset._enable_cloning()

# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
Expand Down Expand Up @@ -466,14 +472,17 @@ def get_queryset(self, *, instance):
).fetch_mode(instance._state.fetch_mode)

def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a length "
"of 1."
)
queryset = (
querysets[0] if querysets else self.get_queryset(instance=instances[0])
)
_cloning_disabled = False
if querysets:
if len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0]
else:
_cloning_disabled = True
queryset = self.get_queryset(instance=instances[0])._disable_cloning()

rel_obj_attr = self.related.field.get_local_related_value
instance_attr = self.related.field.get_foreign_related_value
Expand All @@ -483,6 +492,9 @@ def get_prefetch_querysets(self, instances, querysets=None):
# There can be only one object prefetched for each instance so clear
# ordering if the query allows it without side effects.
queryset.query.clear_ordering()
# Restore subsequent cloning operations.
if _cloning_disabled:
queryset._enable_cloning()

# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
Expand Down Expand Up @@ -738,12 +750,13 @@ def _apply_rel_filters(self, queryset):
empty_strings_as_null = connections[
db
].features.interprets_empty_strings_as_nulls
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
queryset._fetch_mode = self.instance._state.fetch_mode
queryset._defer_next_filter = True
queryset = queryset.filter(**self.core_filters)
with queryset._avoid_cloning():
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
queryset._fetch_mode = self.instance._state.fetch_mode
queryset._defer_next_filter = True
queryset = queryset.filter(**self.core_filters)
for field in self.field.foreign_related_fields:
val = getattr(self.instance, field.attname)
if val is None or (val == "" and empty_strings_as_null):
Expand Down Expand Up @@ -797,19 +810,28 @@ def get_queryset(self):
return self._apply_rel_filters(queryset)

def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0] if querysets else super().get_queryset()
_cloning_disabled = False
if querysets:
if len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0]
else:
_cloning_disabled = True
queryset = super().get_queryset()._disable_cloning()

queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)

rel_obj_attr = self.field.get_local_related_value
instance_attr = self.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
queryset = _filter_prefetch_queryset(queryset, self.field.name, instances)
# Restore subsequent cloning operations.
if _cloning_disabled:
queryset._enable_cloning()

# Since we just bypassed this class' get_queryset(), we must manage
# the reverse relation manually.
Expand Down Expand Up @@ -1140,12 +1162,13 @@ def _apply_rel_filters(self, queryset):
"""
Filter the queryset for the instance this manager is bound to.
"""
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
queryset._fetch_mode = self.instance._state.fetch_mode
queryset._defer_next_filter = True
return queryset._next_is_sticky().filter(**self.core_filters)
with queryset._avoid_cloning():
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
queryset._fetch_mode = self.instance._state.fetch_mode
queryset._defer_next_filter = True
return queryset._next_is_sticky().filter(**self.core_filters)

def get_prefetch_cache(self):
# Walk up the ancestor-chain (if cached) to try and find a prefetch
Expand Down Expand Up @@ -1174,12 +1197,18 @@ def get_queryset(self):
return self._apply_rel_filters(queryset)

def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0] if querysets else super().get_queryset()
_cloning_disabled = False
if querysets:
if len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0]
else:
_cloning_disabled = True
queryset = super().get_queryset()._disable_cloning()

queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
queryset = _filter_prefetch_queryset(
Expand All @@ -1205,6 +1234,10 @@ def get_prefetch_querysets(self, instances, querysets=None):
for f in fk.local_related_fields
}
)
# Restore subsequent cloning operations.
if _cloning_disabled:
queryset._enable_cloning()

return (
queryset,
lambda result: tuple(
Expand Down
65 changes: 63 additions & 2 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ def __iter__(self):
yield row[0]


class PreventQuerySetCloning:
"""
Temporarily prevent the given QuerySet from creating new QuerySet instances
on each mutating operation (e.g: filter(), exclude() etc), instead
modifying the QuerySet in-place.

@contextlib.contextmanager is intentionally not used for performance
reasons.
"""

__slots__ = ("queryset",)

def __init__(self, queryset):
self.queryset = queryset

def __enter__(self):
return self.queryset._disable_cloning()

def __exit__(self, exc_type, exc_value, traceback):
self.queryset._enable_cloning()


class QuerySet(AltersData):
"""Represent a lazy database lookup for a set of objects."""

Expand All @@ -319,6 +341,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
self._fields = None
self._defer_next_filter = False
self._deferred_filter = None
self._cloning_enabled = True

@property
def query(self):
Expand Down Expand Up @@ -2134,12 +2157,50 @@ def _batched_insert(
)
return inserted_rows

def _disable_cloning(self):
"""
Prevent calls to _chain() from creating a new QuerySet via _clone().
All subsequent QuerySet mutations will occur on this instance until
_enable_cloning() is used.
"""
self._cloning_enabled = False
return self

def _enable_cloning(self):
"""
Allow calls to _chain() to create a new QuerySet via _clone(). Restores
the default behavior where any QuerySet mutation will return a new
QuerySet instance. Necessary only when there has been a
_disable_cloning() call previously.
"""
self._cloning_enabled = True
return self

def _avoid_cloning(self):
"""
Temporarily prevent QuerySet _clone() operations, restoring the default
behavior on exit. For the duration of the context managed statement,
all operations (e.g. filter(), exclude(), etc.) will mutate the same
QuerySet instance.

@contextlib.contextmanager is intentionally not used for performance
reasons.
"""
return PreventQuerySetCloning(self)

def _chain(self):
"""
Return a copy of the current QuerySet that's ready for another
operation.

If the QuerySet has opted in to in-place mutations via
_disable_cloning() temporarily, the copy doesn't occur and instead the
same QuerySet instance will be modified.
"""
obj = self._clone()
if not self._cloning_enabled:
obj = self
else:
obj = self._clone()
if obj._sticky_filter:
obj.query.filter_is_sticky = True
obj._sticky_filter = False
Expand Down Expand Up @@ -2845,7 +2906,7 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
else:
manager = getattr(obj, to_attr)
if leaf and lookup.queryset is not None:
qs = manager._apply_rel_filters(lookup.queryset)
qs = manager._apply_rel_filters(lookup.queryset._chain())
else:
qs = manager.get_queryset()
qs._result_cache = vals
Expand Down
46 changes: 46 additions & 0 deletions tests/queries/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4629,3 +4629,49 @@ def test_ticket_23622(self):
set(Ticket23605A.objects.filter(qy).values_list("pk", flat=True)),
)
self.assertSequenceEqual(Ticket23605A.objects.filter(qx), [a2])


class QuerySetCloningTests(TestCase):
@classmethod
def setUpTestData(cls):
SimpleCategory.objects.bulk_create(
[
SimpleCategory(name="first"),
SimpleCategory(name="second"),
SimpleCategory(name="third"),
SimpleCategory(name="fourth"),
]
)

def test_context_manager(self):
"""
_avoid_cloning() makes modifications apply to the original QuerySet.
"""
qs = SimpleCategory.objects.all()
with qs._avoid_cloning():
qs2 = qs.filter(name__in={"first", "second"}).exclude(name="second")
self.assertIs(qs2, qs)
qs3 = qs2.exclude(name__in={"third", "fourth"})
# qs3 is not a mutation of qs2 (which is actually also qs) but a new
# instance entirely.
self.assertIsNot(qs3, qs)
self.assertIsNot(qs3, qs2)

def test_explicit_toggling(self):
qs = SimpleCategory.objects.filter(name__in={"first", "second"})
qs2 = qs._disable_cloning()
# The _disable_cloning() method doesn't return a new QuerySet, but
# toggles the value on the current instance. qs2 can be ignored.
self.assertIs(qs2, qs)
qs3 = qs.filter(name__in={"first", "second"})
qs3 = qs3.exclude(name="second")
qs3._enable_cloning()
# These are still both references to the same QuerySet, despite
# re-binding as if they were normal chained operations providing new
# QuerySet instances.
self.assertIs(qs3, qs)
qs3 = qs3.filter(name="second")
# Cloning has been re-enabled so subsequent operations yield a new
# QuerySet. qs3 is now all of the filters applied to qs + an additional
# filter.
self.assertIsNot(qs3, qs)
Loading