diff --git a/README.rst b/README.rst index e9319eb9..ef6d00ba 100644 --- a/README.rst +++ b/README.rst @@ -133,6 +133,15 @@ The Cachier wrapper adds a ``clear_cache()`` function to each wrapped function. foo.clear_cache() +To clear only the cache entry for a specific call, pass the same arguments to ``clear_cache()`` that you would pass to the wrapped function: + +.. code-block:: python + + foo.clear_cache(arg1, arg2) + foo.clear_cache(arg1, arg2=arg2) + +The asynchronous ``aclear_cache()`` helper supports the same argument-specific form. + General Configuration ---------------------- diff --git a/src/cachier/core.py b/src/cachier/core.py index 5a4bc4aa..ea79c6f0 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -48,9 +48,9 @@ class _CachierWrappedFunc(Protocol[_P, _R_co]): def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ... # pragma: no cover - clear_cache: Callable[[], Any] + clear_cache: Callable[..., Any] clear_being_calculated: Callable[[], Any] - aclear_cache: Callable[[], Any] + aclear_cache: Callable[..., Any] aclear_being_calculated: Callable[[], Any] cache_dpath: Callable[[], Optional[str]] precache_value: Callable[..., Any] @@ -219,6 +219,13 @@ def _is_async_redis_client(client: Any) -> bool: return all(inspect.iscoroutinefunction(getattr(client, name, None)) for name in method_names) +def _convert_public_cache_args(func, _is_method: bool, args: tuple, kwds: dict) -> dict: + """Convert cache-management arguments to canonical cache-key kwargs.""" + if _is_method: + args = (None, *args) + return _convert_args_kwargs(func, _is_method=_is_method, args=args, kwds=kwds) + + def cachier( hash_func: Optional[HashFunc] = None, hash_params: Optional[HashFunc] = None, @@ -733,9 +740,14 @@ async def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: return _call(*args, **kwargs) # type: ignore[arg-type] - def _clear_cache(): - """Clear the cache.""" - core.clear_cache() + def _clear_cache(*args, **kwds): + """Clear the cache, or only the entry matching the provided arguments.""" + if args or kwds: + kwargs = _convert_public_cache_args(func, core.func_is_method, args, kwds) + key = core.get_key((), kwargs) + core.clear_cache_entry(key) + else: + core.clear_cache() if is_coroutine: return _ImmediateAwaitable() return None @@ -747,9 +759,14 @@ def _clear_being_calculated(): return _ImmediateAwaitable() return None - async def _aclear_cache(): - """Clear the cache asynchronously.""" - await core.aclear_cache() + async def _aclear_cache(*args, **kwds): + """Clear the cache asynchronously, or only the entry matching the provided arguments.""" + if args or kwds: + kwargs = _convert_public_cache_args(func, core.func_is_method, args, kwds) + key = core.get_key((), kwargs) + await core.aclear_cache_entry(key) + else: + await core.aclear_cache() async def _aclear_being_calculated(): """Mark all entries in this cache as not being calculated asynchronously.""" diff --git a/src/cachier/cores/base.py b/src/cachier/cores/base.py index a62a0d28..450fd4f6 100644 --- a/src/cachier/cores/base.py +++ b/src/cachier/cores/base.py @@ -274,6 +274,18 @@ async def aclear_cache(self) -> None: """ await asyncio.to_thread(self.clear_cache) + @abc.abstractmethod + def clear_cache_entry(self, key: str) -> None: + """Clear the cache entry mapped by the given key.""" + + async def aclear_cache_entry(self, key: str) -> None: + """Async-compatible variant of :meth:`clear_cache_entry`. + + By default this runs in a thread to avoid blocking the event loop. + + """ + await asyncio.to_thread(self.clear_cache_entry, key) + @abc.abstractmethod def clear_being_calculated(self) -> None: """Mark all entries in this cache as not being calculated.""" diff --git a/src/cachier/cores/memory.py b/src/cachier/cores/memory.py index 36fddf41..7b957fab 100644 --- a/src/cachier/cores/memory.py +++ b/src/cachier/cores/memory.py @@ -126,6 +126,11 @@ def clear_cache(self) -> None: # Update size metrics after clearing self._update_size_metrics() + def clear_cache_entry(self, key: str) -> None: + with self.lock: + self.cache.pop(self._hash_func_key(key), None) + self._update_size_metrics() + def clear_being_calculated(self) -> None: with self.lock: for entry in self.cache.values(): diff --git a/src/cachier/cores/mongo.py b/src/cachier/cores/mongo.py index 2ace3aaa..a45d3d42 100644 --- a/src/cachier/cores/mongo.py +++ b/src/cachier/cores/mongo.py @@ -249,6 +249,14 @@ async def aclear_cache(self) -> None: mongo_collection = await self._ensure_collection_async() await mongo_collection.delete_many(filter={"func": self._func_str}) + def clear_cache_entry(self, key: str) -> None: + mongo_collection = self._ensure_collection() + mongo_collection.delete_one(filter={"func": self._func_str, "key": key}) + + async def aclear_cache_entry(self, key: str) -> None: + mongo_collection = await self._ensure_collection_async() + await mongo_collection.delete_one(filter={"func": self._func_str, "key": key}) + def clear_being_calculated(self) -> None: mongo_collection = self._ensure_collection() mongo_collection.update_many( diff --git a/src/cachier/cores/pickle.py b/src/cachier/cores/pickle.py index fda3b308..3f21ee71 100644 --- a/src/cachier/cores/pickle.py +++ b/src/cachier/cores/pickle.py @@ -172,19 +172,22 @@ def _clear_all_cache_files(self) -> None: path, name = os.path.split(self.cache_fpath) for subpath in os.listdir(path): if subpath.startswith(f"{name}_"): - fpath = os.path.join(path, subpath) - # Retry loop to handle Windows mandatory file-locking (WinError 32): - # portalocker holds an exclusive lock while a thread is computing, - # so os.remove() may fail transiently until the lock is released. - for attempt in range(3): # pragma: no branch - try: - os.remove(fpath) - break - except PermissionError: - if attempt < 2: - time.sleep(0.1 * (attempt + 1)) - else: - raise + self._remove_cache_file_with_retries(os.path.join(path, subpath)) + + @staticmethod + def _remove_cache_file_with_retries(fpath: str) -> None: + # Retry loop to handle Windows mandatory file-locking (WinError 32): + # portalocker holds an exclusive lock while a thread is computing, + # so os.remove() may fail transiently until the lock is released. + for attempt in range(3): # pragma: no branch + try: + os.remove(fpath) + break + except PermissionError: + if attempt < 2: + time.sleep(0.1 * (attempt + 1)) + else: + raise def _clear_being_calculated_all_cache_files(self) -> None: path, name = os.path.split(self.cache_fpath) @@ -272,10 +275,14 @@ async def _aset_entry(self, key: str, func_res: Any) -> bool: return self._set_entry(key, func_res) def mark_entry_being_calculated_separate_files(self, key: str) -> None: - self._save_cache( - CacheEntry(value=None, time=datetime.now(), stale=False, _processing=True), - separate_file_key=key, + entry = self._load_cache_by_key(key) or CacheEntry( + value=None, + time=datetime.now(), + stale=False, + _processing=False, ) + entry._processing = True + self._save_cache(entry, separate_file_key=key) def _mark_entry_not_calculated_separate_files(self, key: str) -> None: _, entry = self.get_entry_by_key(key) @@ -416,6 +423,17 @@ def clear_cache(self) -> None: else: self._save_cache({}) + def clear_cache_entry(self, key: str) -> None: + if self.separate_files: + with suppress(FileNotFoundError): + self._remove_cache_file_with_retries(f"{self.cache_fpath}_{key}") + return + + with self.lock: + cache = self.get_cache_dict() + cache.pop(key, None) + self._save_cache(cache) + def clear_being_calculated(self) -> None: if self.separate_files: self._clear_being_calculated_all_cache_files() diff --git a/src/cachier/cores/redis.py b/src/cachier/cores/redis.py index df6eea00..829d9a3a 100644 --- a/src/cachier/cores/redis.py +++ b/src/cachier/cores/redis.py @@ -353,6 +353,16 @@ def clear_cache(self) -> None: except Exception as e: warnings.warn(f"Redis clear_cache failed: {e}", stacklevel=2) + def clear_cache_entry(self, key: str) -> None: + """Clear the cache entry mapped by the given key.""" + redis_client = self._resolve_redis_client() + redis_key = self._get_redis_key(key) + + try: + redis_client.delete(redis_key) + except Exception as e: + warnings.warn(f"Redis clear_cache_entry failed: {e}", stacklevel=2) + async def aclear_cache(self) -> None: """Clear the cache of this core asynchronously.""" redis_client = await self._resolve_redis_client_async() @@ -365,6 +375,16 @@ async def aclear_cache(self) -> None: except Exception as e: warnings.warn(f"Redis clear_cache failed: {e}", stacklevel=2) + async def aclear_cache_entry(self, key: str) -> None: + """Clear the cache entry mapped by the given key asynchronously.""" + redis_client = await self._resolve_redis_client_async() + redis_key = self._get_redis_key(key) + + try: + await redis_client.delete(redis_key) + except Exception as e: + warnings.warn(f"Redis clear_cache_entry failed: {e}", stacklevel=2) + def clear_being_calculated(self) -> None: """Mark all entries in this cache as not being calculated.""" redis_client = self._resolve_redis_client() diff --git a/src/cachier/cores/s3.py b/src/cachier/cores/s3.py index 3dda917e..5939e6fa 100644 --- a/src/cachier/cores/s3.py +++ b/src/cachier/cores/s3.py @@ -333,6 +333,15 @@ def clear_cache(self) -> None: except Exception as exc: _safe_warn(f"S3 clear_cache failed: {exc}") + def clear_cache_entry(self, key: str) -> None: + """Delete the cache entry mapped by the given key from S3.""" + client = self._get_s3_client() + s3_key = self._get_s3_key(key) + try: + client.delete_object(Bucket=self.s3_bucket, Key=s3_key) + except Exception as exc: + _safe_warn(f"S3 clear_cache_entry failed: {exc}") + def clear_being_calculated(self) -> None: """Reset the ``_processing`` flag on all entries for this function in S3.""" client = self._get_s3_client() diff --git a/src/cachier/cores/sql.py b/src/cachier/cores/sql.py index 28e5d721..178212e9 100644 --- a/src/cachier/cores/sql.py +++ b/src/cachier/cores/sql.py @@ -434,12 +434,28 @@ def clear_cache(self) -> None: session.execute(delete(CacheTable).where(CacheTable.function_id == self._func_str)) session.commit() + def clear_cache_entry(self, key: str) -> None: + session_factory = self._get_sync_session() + with self._lock, session_factory() as session: + session.execute( + delete(CacheTable).where(and_(CacheTable.function_id == self._func_str, CacheTable.key == key)) + ) + session.commit() + async def aclear_cache(self) -> None: session_factory = await self._get_async_session() async with session_factory() as session: await session.execute(delete(CacheTable).where(CacheTable.function_id == self._func_str)) await session.commit() + async def aclear_cache_entry(self, key: str) -> None: + session_factory = await self._get_async_session() + async with session_factory() as session: + await session.execute( + delete(CacheTable).where(and_(CacheTable.function_id == self._func_str, CacheTable.key == key)) + ) + await session.commit() + def clear_being_calculated(self) -> None: session_factory = self._get_sync_session() with self._lock, session_factory() as session: diff --git a/tests/mongo_tests/clients.py b/tests/mongo_tests/clients.py index 289a188a..fb216750 100644 --- a/tests/mongo_tests/clients.py +++ b/tests/mongo_tests/clients.py @@ -69,3 +69,11 @@ async def delete_many(self, query=None, **kwargs): del self._docs[key] deleted += 1 return {"deleted_count": deleted} + + async def delete_one(self, query=None, **kwargs): + if query is None: + query = kwargs.get("filter", {}) + key = (query.get("func"), query.get("key")) + existed = key in self._docs + self._docs.pop(key, None) + return {"deleted_count": int(existed)} diff --git a/tests/mongo_tests/test_async_mongo_core.py b/tests/mongo_tests/test_async_mongo_core.py index d92226bd..c35172d8 100644 --- a/tests/mongo_tests/test_async_mongo_core.py +++ b/tests/mongo_tests/test_async_mongo_core.py @@ -152,6 +152,10 @@ async def test_async_mongo_core_mark_clear_and_stale_paths(): await core.amark_entry_being_calculated("fresh") await core.amark_entry_not_calculated("fresh") + await core.aclear_cache_entry("fresh") + assert (core._func_str, "fresh") not in collection._docs + + await core.aset_entry("fresh", 2) await core.aclear_being_calculated() await core.adelete_stale_entries(timedelta(hours=1)) diff --git a/tests/mongo_tests/test_mongo_core.py b/tests/mongo_tests/test_mongo_core.py index f7cde7af..5fab18ce 100644 --- a/tests/mongo_tests/test_mongo_core.py +++ b/tests/mongo_tests/test_mongo_core.py @@ -139,6 +139,28 @@ def _test_mongo_caching(arg_1, arg_2): assert val6 == val5 +@pytest.mark.mongo +def test_mongo_clear_cache_for_specific_arguments(): + """clear_cache can remove one Mongo cache entry by function arguments.""" + + @cachier(mongetter=_test_mongetter) + def _test_mongo_caching(arg_1, arg_2): + """Some function.""" + return random() + arg_1 + arg_2 + + _test_mongo_caching.clear_cache() + val1 = _test_mongo_caching(1, arg_2=2) + val2 = _test_mongo_caching(3, arg_2=4) + assert _test_mongo_caching(1, arg_2=2) == val1 + assert _test_mongo_caching(3, arg_2=4) == val2 + + _test_mongo_caching.clear_cache(1, arg_2=2) + + assert _test_mongo_caching(1, arg_2=2) != val1 + assert _test_mongo_caching(3, arg_2=4) == val2 + _test_mongo_caching.clear_cache() + + @pytest.mark.mongo def test_mongo_stale_after(): """Testing MongoDB core stale_after functionality.""" diff --git a/tests/pickle_tests/test_pickle_core.py b/tests/pickle_tests/test_pickle_core.py index 451363da..fa34de1a 100644 --- a/tests/pickle_tests/test_pickle_core.py +++ b/tests/pickle_tests/test_pickle_core.py @@ -287,7 +287,7 @@ def test_being_calc_next_time(separate_files): separate_files=separate_files, ) _being_calc_next_time_decorated.clear_cache() - _being_calc_next_time(0.13, 0.02) + _being_calc_next_time_decorated(0.13, 0.02) sleep(1.1) res_queue = queue.Queue() thread1 = threading.Thread( @@ -1445,6 +1445,49 @@ def flaky_remove(path): assert not os.path.exists(dummy_file) +@pytest.mark.pickle +def test_clear_cache_entry_retries_on_permission_error(tmp_path): + """Test clear_cache_entry retries on PermissionError then succeeds.""" + core = _PickleCore( + hash_func=None, + cache_dir=tmp_path, + pickle_reload=False, + wait_for_calc_timeout=10, + separate_files=True, + ) + + def mock_func(): + pass + + core.set_func(mock_func) + + cache_fpath = core.cache_fpath + dummy_file = cache_fpath + "_dummykey" + with open(dummy_file, "wb") as f: + f.write(b"") + + real_remove = os.remove + call_count = 0 + + def flaky_remove(path): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise PermissionError("locked") + real_remove(path) + + with ( + patch("cachier.cores.pickle.os.remove", side_effect=flaky_remove), + patch("cachier.cores.pickle.time.sleep") as mock_sleep, + ): + core.clear_cache_entry("dummykey") + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(0.1) + mock_sleep.assert_any_call(0.2) + + assert not os.path.exists(dummy_file) + + @pytest.mark.pickle def test_clear_all_cache_files_raises_on_persistent_permission_error(tmp_path): """Test _clear_all_cache_files re-raises PermissionError after all retries.""" diff --git a/tests/redis_tests/test_async_redis_core.py b/tests/redis_tests/test_async_redis_core.py index e5dec8f7..e3aa1333 100644 --- a/tests/redis_tests/test_async_redis_core.py +++ b/tests/redis_tests/test_async_redis_core.py @@ -224,8 +224,14 @@ async def test_async_redis_core_mark_and_clear_paths(): assert clean1._processing is False assert clean2._processing is False + await core.aclear_cache_entry("cleanup-1") + _, cleared_one = await core.aget_entry_by_key("cleanup-1") + _, kept = await core.aget_entry_by_key("cleanup-2") + assert cleared_one is None + assert kept is not None + await core.aclear_cache() - _, cleared = await core.aget_entry_by_key("cleanup-1") + _, cleared = await core.aget_entry_by_key("cleanup-2") assert cleared is None await core.aclear_cache() @@ -252,6 +258,11 @@ async def test_async_redis_core_mark_and_clear_exceptions(): with pytest.warns(UserWarning, match="Redis clear_cache failed"): await core.aclear_cache() + client.fail_keys = False + client.fail_delete = True + with pytest.warns(UserWarning, match="Redis clear_cache_entry failed"): + await core.aclear_cache_entry("k") + @pytest.mark.redis @pytest.mark.asyncio diff --git a/tests/redis_tests/test_redis_core.py b/tests/redis_tests/test_redis_core.py index 6b2d9f39..6e4ea429 100644 --- a/tests/redis_tests/test_redis_core.py +++ b/tests/redis_tests/test_redis_core.py @@ -130,6 +130,28 @@ def _test_redis_caching(arg_1, arg_2): assert val6 == val5 +@pytest.mark.redis +def test_redis_clear_cache_for_specific_arguments(): + """clear_cache can remove one Redis cache entry by function arguments.""" + + @cachier(backend="redis", redis_client=_test_redis_getter) + def _test_redis_caching(arg_1, arg_2): + """Some function.""" + return random() + arg_1 + arg_2 + + _test_redis_caching.clear_cache() + val1 = _test_redis_caching(1, arg_2=2) + val2 = _test_redis_caching(3, arg_2=4) + assert _test_redis_caching(1, arg_2=2) == val1 + assert _test_redis_caching(3, arg_2=4) == val2 + + _test_redis_caching.clear_cache(1, arg_2=2) + + assert _test_redis_caching(1, arg_2=2) != val1 + assert _test_redis_caching(3, arg_2=4) == val2 + _test_redis_caching.clear_cache() + + @pytest.mark.redis def test_redis_stale_after(): """Testing Redis core stale_after functionality.""" diff --git a/tests/redis_tests/test_redis_core_exceptions.py b/tests/redis_tests/test_redis_core_exceptions.py index 7d51eb11..0988747e 100644 --- a/tests/redis_tests/test_redis_core_exceptions.py +++ b/tests/redis_tests/test_redis_core_exceptions.py @@ -96,6 +96,12 @@ def test_clear_cache_exceptions(self, core, mock_redis): with pytest.warns(UserWarning, match="Redis clear_cache failed"): core.clear_cache() + def test_clear_cache_entry_exceptions(self, core, mock_redis): + """Test clear_cache_entry Redis delete exception handling.""" + mock_redis.delete.side_effect = Exception("Redis error") + with pytest.warns(UserWarning, match="Redis clear_cache_entry failed"): + core.clear_cache_entry("key") + def test_clear_being_calculated_exceptions(self, core, mock_redis): """Test clear_being_calculated Redis keys exception handling.""" mock_redis.keys.side_effect = Exception("Redis error") diff --git a/tests/s3_tests/test_s3_core.py b/tests/s3_tests/test_s3_core.py index de9e7132..4bb03a1e 100644 --- a/tests/s3_tests/test_s3_core.py +++ b/tests/s3_tests/test_s3_core.py @@ -104,6 +104,27 @@ def _cached(x): assert val1 != val2 +@pytest.mark.s3 +def test_s3_clear_cache_for_specific_arguments(s3_bucket): + """clear_cache can remove one S3 cache entry by function arguments.""" + + @cachier(backend="s3", s3_bucket=s3_bucket, s3_region=TEST_REGION) + def _cached(x, y=1): + return random() + x + y + + _cached.clear_cache() + val1 = _cached(1, y=2) + val2 = _cached(3, y=4) + assert _cached(1, y=2) == val1 + assert _cached(3, y=4) == val2 + + _cached.clear_cache(1, y=2) + + assert _cached(1, y=2) != val1 + assert _cached(3, y=4) == val2 + _cached.clear_cache() + + @pytest.mark.s3 def test_s3_core_skip_cache(s3_bucket): """cachier__skip_cache bypasses the cache.""" @@ -678,6 +699,16 @@ def _dummy(x): core.clear_cache() assert client.delete_objects.call_count == 2 + client.delete_object = Mock(return_value=None) + core.clear_cache_entry("k") + assert client.delete_object.called + + client.delete_object = Mock(side_effect=RuntimeError("delete failed")) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + core.clear_cache_entry("k") + assert any("clear_cache_entry failed" in str(w.message) for w in caught) + client.get_paginator = Mock(side_effect=RuntimeError("paginate failed")) with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") diff --git a/tests/sql_tests/test_async_sql_core.py b/tests/sql_tests/test_async_sql_core.py index 555d5708..e96c4169 100644 --- a/tests/sql_tests/test_async_sql_core.py +++ b/tests/sql_tests/test_async_sql_core.py @@ -138,6 +138,45 @@ async def test_sqlcore_async_session_requires_async_engine(): await core._get_async_session() +@pytest.mark.sql +@pytest.mark.asyncio +async def test_async_sql_clear_cache_entry_executes_delete(monkeypatch): + core = _SQLCore.__new__(_SQLCore) + core._func_str = "test_func" + + class FakeSession: + def __init__(self): + self.executed = [] + self.committed = False + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def execute(self, stmt): + self.executed.append(stmt) + + async def commit(self): + self.committed = True + + session = FakeSession() + + def session_factory(): + return session + + async def fake_get_async_session(): + return session_factory + + monkeypatch.setattr(core, "_get_async_session", fake_get_async_session) + + await core.aclear_cache_entry("one-key") + + assert session.executed + assert session.committed is True + + @pytest.mark.sql @pytest.mark.asyncio async def test_sqlcore_async_session_creates_tables_once(async_sql_engine): diff --git a/tests/sql_tests/test_sql_core.py b/tests/sql_tests/test_sql_core.py index 4601f7b2..e879a52f 100644 --- a/tests/sql_tests/test_sql_core.py +++ b/tests/sql_tests/test_sql_core.py @@ -75,6 +75,27 @@ def f(x, y): f.clear_cache() +@pytest.mark.sql +def test_sql_clear_cache_for_specific_arguments(): + """clear_cache can remove one SQL cache entry by function arguments.""" + + @cachier(backend="sql", sql_engine=SQL_CONN_STR) + def f(x, y): + return random() + x + y + + f.clear_cache() + v1 = f(1, y=2) + v2 = f(3, y=4) + assert f(1, y=2) == v1 + assert f(3, y=4) == v2 + + f.clear_cache(1, y=2) + + assert f(1, y=2) != v1 + assert f(3, y=4) == v2 + f.clear_cache() + + @pytest.mark.sql def test_sql_stale_after(): @cachier( diff --git a/tests/test_async_core.py b/tests/test_async_core.py index 770c0ee4..c815495a 100644 --- a/tests/test_async_core.py +++ b/tests/test_async_core.py @@ -41,6 +41,31 @@ async def async_func(x): async_func.clear_cache() + @pytest.mark.memory + @pytest.mark.asyncio + async def test_clear_cache_for_specific_arguments(self): + """Test async wrappers can clear one cached argument set.""" + call_count = 0 + + @cachier(backend="memory") + async def async_func(x, y=1): + nonlocal call_count + call_count += 1 + return f"{x}:{y}:{call_count}" + + async_func.clear_cache() + + first = await async_func(1, y=2) + second = await async_func(3, y=4) + assert await async_func(1, y=2) == first + assert await async_func(3, y=4) == second + + await async_func.aclear_cache(1, y=2) + + assert await async_func(1, y=2) != first + assert await async_func(3, y=4) == second + await async_func.aclear_cache() + @pytest.mark.pickle @pytest.mark.asyncio async def test_pickle(self): diff --git a/tests/test_base_core.py b/tests/test_base_core.py index 589eac66..2ebf069f 100644 --- a/tests/test_base_core.py +++ b/tests/test_base_core.py @@ -19,6 +19,7 @@ def __init__(self, hash_func, wait_for_calc_timeout, entry_size_limit=None): self.last_mark_not_calc = None self.last_wait_key = None self.clear_cache_called = False + self.last_cleared_key = None self.clear_being_calculated_called = False self.last_deleted_stale_after = None @@ -48,6 +49,10 @@ def clear_cache(self): """Clear the cache.""" self.clear_cache_called = True + def clear_cache_entry(self, key): + """Clear one cache entry.""" + self.last_cleared_key = key + def clear_being_calculated(self): """Clear entries that are being calculated.""" self.clear_being_calculated_called = True @@ -112,6 +117,9 @@ async def fake_aset_entry(key, value): await core.aclear_cache() assert core.clear_cache_called is True + await core.aclear_cache_entry("one-key") + assert core.last_cleared_key == "one-key" + await core.aclear_being_calculated() assert core.clear_being_calculated_called is True diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 1f38afd3..890c59ed 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -98,6 +98,66 @@ def func(): func.clear_cache() +@pytest.mark.smoke +@pytest.mark.parametrize("backend", ["memory", "pickle"]) +@pytest.mark.parametrize("separate_files", [False, True]) +def test_clear_cache_for_specific_arguments(tmp_path, backend, separate_files): + """Test that clear_cache can reset one cached argument set.""" + if backend == "memory" and separate_files: + pytest.skip("separate_files only applies to pickle") + + call_count = 0 + + @cachier_decorator( + backend=backend, + cache_dir=tmp_path, + separate_files=separate_files, + ) + def func(x, y=1): + nonlocal call_count + call_count += 1 + return f"{x}:{y}:{call_count}" + + func.clear_cache() + + first = func(1, y=2) + second = func(3, y=4) + assert func(1, y=2) == first + assert func(3, y=4) == second + + func.clear_cache(1, y=2) + + assert func(1, y=2) != first + assert func(3, y=4) == second + func.clear_cache() + + +@pytest.mark.smoke +def test_clear_cache_for_specific_method_arguments(): + """Test that method cache entries can be cleared by method arguments.""" + call_count = 0 + + class Foo: + @cachier_decorator(backend="memory", allow_non_static_methods=True) + def method(self, x): + nonlocal call_count + call_count += 1 + return f"{x}:{call_count}" + + obj = Foo() + obj.method.clear_cache() + first = obj.method(1) + second = obj.method(2) + assert obj.method(1) == first + assert obj.method(2) == second + + obj.method.clear_cache(1) + + assert obj.method(1) != first + assert obj.method(2) == second + obj.method.clear_cache() + + @pytest.mark.smoke def test_pickle_backend_stale_after(tmp_path): """Test that stale_after=timedelta(0) always recalculates."""