diff --git a/packages/django-google-spanner/django_spanner/__init__.py b/packages/django-google-spanner/django_spanner/__init__.py index 23ea56b5db15..30240ae98ca6 100644 --- a/packages/django-google-spanner/django_spanner/__init__.py +++ b/packages/django-google-spanner/django_spanner/__init__.py @@ -84,19 +84,15 @@ def autofield_init(self, *args, **kwargs): == "true" ): self.default = gen_rand_int64 + self.db_returning = False + self.validators = [] break AutoField.__init__ = autofield_init -AutoField.db_returning = False -AutoField.validators = [] SmallAutoField.__init__ = autofield_init BigAutoField.__init__ = autofield_init -SmallAutoField.db_returning = False -BigAutoField.db_returning = False -SmallAutoField.validators = [] -BigAutoField.validators = [] def get_prep_value(self, value): diff --git a/packages/django-google-spanner/tests/unit/django_spanner/test_schema.py b/packages/django-google-spanner/tests/unit/django_spanner/test_schema.py index 5ef6f74395af..5e4fca90adbe 100644 --- a/packages/django-google-spanner/tests/unit/django_spanner/test_schema.py +++ b/packages/django-google-spanner/tests/unit/django_spanner/test_schema.py @@ -408,6 +408,8 @@ def test_autofield_no_default(self): """Spanner, default is not provided.""" field = AutoField(name="field_name") assert gen_rand_int64 == field.default + # db_returning must be explicitly False because Spanner is handling ID generation client-side + assert getattr(field, "db_returning", True) is False def test_autofield_default(self): """Spanner, default provided.""" @@ -415,12 +417,18 @@ def test_autofield_default(self): field = AutoField(name="field_name", default=mock_func) assert gen_rand_int64 != field.default assert mock_func == field.default + # A default was already provided, so Spanner does not generate random IDs client-side. + # Therefore, db_returning does not need to be overridden to False. + assert field.db_returning is True def test_autofield_not_spanner(self): """Not Spanner, default not provided.""" connection.settings_dict["ENGINE"] = "another_db" field = AutoField(name="field_name") assert gen_rand_int64 != field.default + # db_returning should remain untouched (implicitly True) for non-Spanner databases + # so that Django retrieves the auto-increment ID correctly. + assert field.db_returning is True connection.settings_dict["ENGINE"] = "django_spanner" def test_autofield_not_spanner_w_default(self): @@ -430,6 +438,8 @@ def test_autofield_not_spanner_w_default(self): field = AutoField(name="field_name", default=mock_func) assert gen_rand_int64 != field.default assert mock_func == field.default + # Because it's not a Spanner database, the behavior shouldn't be altered in any way. + assert field.db_returning is True connection.settings_dict["ENGINE"] = "django_spanner" def test_autofield_spanner_as_non_default_db_random_generation_enabled( @@ -441,6 +451,9 @@ def test_autofield_spanner_as_non_default_db_random_generation_enabled( connections.settings["secondary"]["RANDOM_ID_GENERATION_ENABLED"] = "true" field = AutoField(name="field_name") assert gen_rand_int64 == field.default + # Since this specific connection explicitly enables client-side random generation, + # we must tell Django not to attempt retrieving the DB's returned ID. + assert getattr(field, "db_returning", True) is False connections.settings["default"]["ENGINE"] = "django_spanner" connections.settings["secondary"]["ENGINE"] = "django_spanner" del connections.settings["secondary"]["RANDOM_ID_GENERATION_ENABLED"] @@ -450,4 +463,7 @@ def test_autofield_random_generation_disabled(self): connections.settings["default"]["RANDOM_ID_GENERATION_ENABLED"] = "false" field = AutoField(name="field_name") assert gen_rand_int64 != field.default + # Because we're delegating ID generation back to the database backend, + # Django needs to be able to retrieve the assigned ID. + assert field.db_returning is True del connections.settings["default"]["RANDOM_ID_GENERATION_ENABLED"]