Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ The Cachier wrapper adds a ``clear_cache()`` function to each wrapped function.

foo.clear_cache()

To clear only the cache entry for a specific call, pass the same arguments to ``clear_cache()`` that you would pass to the wrapped function:

.. code-block:: python

foo.clear_cache(arg1, arg2)
foo.clear_cache(arg1, arg2=arg2)

The asynchronous ``aclear_cache()`` helper supports the same argument-specific form.

General Configuration
----------------------

Expand Down
33 changes: 25 additions & 8 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ class _CachierWrappedFunc(Protocol[_P, _R_co]):

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ... # pragma: no cover

clear_cache: Callable[[], Any]
clear_cache: Callable[..., Any]
clear_being_calculated: Callable[[], Any]
aclear_cache: Callable[[], Any]
aclear_cache: Callable[..., Any]
aclear_being_calculated: Callable[[], Any]
cache_dpath: Callable[[], Optional[str]]
precache_value: Callable[..., Any]
Expand Down Expand Up @@ -219,6 +219,13 @@ def _is_async_redis_client(client: Any) -> bool:
return all(inspect.iscoroutinefunction(getattr(client, name, None)) for name in method_names)


def _convert_public_cache_args(func, _is_method: bool, args: tuple, kwds: dict) -> dict:
"""Convert cache-management arguments to canonical cache-key kwargs."""
if _is_method:
args = (None, *args)
return _convert_args_kwargs(func, _is_method=_is_method, args=args, kwds=kwds)


def cachier(
hash_func: Optional[HashFunc] = None,
hash_params: Optional[HashFunc] = None,
Expand Down Expand Up @@ -733,9 +740,14 @@ async def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return _call(*args, **kwargs) # type: ignore[arg-type]

def _clear_cache():
"""Clear the cache."""
core.clear_cache()
def _clear_cache(*args, **kwds):
"""Clear the cache, or only the entry matching the provided arguments."""
if args or kwds:
kwargs = _convert_public_cache_args(func, core.func_is_method, args, kwds)
key = core.get_key((), kwargs)
core.clear_cache_entry(key)
else:
core.clear_cache()
if is_coroutine:
return _ImmediateAwaitable()
return None
Expand All @@ -747,9 +759,14 @@ def _clear_being_calculated():
return _ImmediateAwaitable()
return None

async def _aclear_cache():
"""Clear the cache asynchronously."""
await core.aclear_cache()
async def _aclear_cache(*args, **kwds):
"""Clear the cache asynchronously, or only the entry matching the provided arguments."""
if args or kwds:
kwargs = _convert_public_cache_args(func, core.func_is_method, args, kwds)
key = core.get_key((), kwargs)
await core.aclear_cache_entry(key)
else:
await core.aclear_cache()

async def _aclear_being_calculated():
"""Mark all entries in this cache as not being calculated asynchronously."""
Expand Down
12 changes: 12 additions & 0 deletions src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ async def aclear_cache(self) -> None:
"""
await asyncio.to_thread(self.clear_cache)

@abc.abstractmethod
def clear_cache_entry(self, key: str) -> None:
"""Clear the cache entry mapped by the given key."""

async def aclear_cache_entry(self, key: str) -> None:
"""Async-compatible variant of :meth:`clear_cache_entry`.

By default this runs in a thread to avoid blocking the event loop.

"""
await asyncio.to_thread(self.clear_cache_entry, key)

@abc.abstractmethod
def clear_being_calculated(self) -> None:
"""Mark all entries in this cache as not being calculated."""
Expand Down
5 changes: 5 additions & 0 deletions src/cachier/cores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def clear_cache(self) -> None:
# Update size metrics after clearing
self._update_size_metrics()

def clear_cache_entry(self, key: str) -> None:
with self.lock:
self.cache.pop(self._hash_func_key(key), None)
self._update_size_metrics()

def clear_being_calculated(self) -> None:
with self.lock:
for entry in self.cache.values():
Expand Down
8 changes: 8 additions & 0 deletions src/cachier/cores/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ async def aclear_cache(self) -> None:
mongo_collection = await self._ensure_collection_async()
await mongo_collection.delete_many(filter={"func": self._func_str})

def clear_cache_entry(self, key: str) -> None:
mongo_collection = self._ensure_collection()
mongo_collection.delete_one(filter={"func": self._func_str, "key": key})

async def aclear_cache_entry(self, key: str) -> None:
mongo_collection = await self._ensure_collection_async()
await mongo_collection.delete_one(filter={"func": self._func_str, "key": key})

def clear_being_calculated(self) -> None:
mongo_collection = self._ensure_collection()
mongo_collection.update_many(
Expand Down
50 changes: 34 additions & 16 deletions src/cachier/cores/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,22 @@ def _clear_all_cache_files(self) -> None:
path, name = os.path.split(self.cache_fpath)
for subpath in os.listdir(path):
if subpath.startswith(f"{name}_"):
fpath = os.path.join(path, subpath)
# Retry loop to handle Windows mandatory file-locking (WinError 32):
# portalocker holds an exclusive lock while a thread is computing,
# so os.remove() may fail transiently until the lock is released.
for attempt in range(3): # pragma: no branch
try:
os.remove(fpath)
break
except PermissionError:
if attempt < 2:
time.sleep(0.1 * (attempt + 1))
else:
raise
self._remove_cache_file_with_retries(os.path.join(path, subpath))

@staticmethod
def _remove_cache_file_with_retries(fpath: str) -> None:
# Retry loop to handle Windows mandatory file-locking (WinError 32):
# portalocker holds an exclusive lock while a thread is computing,
# so os.remove() may fail transiently until the lock is released.
for attempt in range(3): # pragma: no branch
try:
os.remove(fpath)
break
except PermissionError:
if attempt < 2:
time.sleep(0.1 * (attempt + 1))
else:
raise

def _clear_being_calculated_all_cache_files(self) -> None:
path, name = os.path.split(self.cache_fpath)
Expand Down Expand Up @@ -272,10 +275,14 @@ async def _aset_entry(self, key: str, func_res: Any) -> bool:
return self._set_entry(key, func_res)

def mark_entry_being_calculated_separate_files(self, key: str) -> None:
self._save_cache(
CacheEntry(value=None, time=datetime.now(), stale=False, _processing=True),
separate_file_key=key,
entry = self._load_cache_by_key(key) or CacheEntry(
value=None,
time=datetime.now(),
stale=False,
_processing=False,
)
entry._processing = True
self._save_cache(entry, separate_file_key=key)

def _mark_entry_not_calculated_separate_files(self, key: str) -> None:
_, entry = self.get_entry_by_key(key)
Expand Down Expand Up @@ -416,6 +423,17 @@ def clear_cache(self) -> None:
else:
self._save_cache({})

def clear_cache_entry(self, key: str) -> None:
if self.separate_files:
with suppress(FileNotFoundError):
self._remove_cache_file_with_retries(f"{self.cache_fpath}_{key}")
return

with self.lock:
cache = self.get_cache_dict()
cache.pop(key, None)
self._save_cache(cache)

def clear_being_calculated(self) -> None:
if self.separate_files:
self._clear_being_calculated_all_cache_files()
Expand Down
20 changes: 20 additions & 0 deletions src/cachier/cores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,16 @@ def clear_cache(self) -> None:
except Exception as e:
warnings.warn(f"Redis clear_cache failed: {e}", stacklevel=2)

def clear_cache_entry(self, key: str) -> None:
"""Clear the cache entry mapped by the given key."""
redis_client = self._resolve_redis_client()
redis_key = self._get_redis_key(key)

try:
redis_client.delete(redis_key)
except Exception as e:
warnings.warn(f"Redis clear_cache_entry failed: {e}", stacklevel=2)

async def aclear_cache(self) -> None:
"""Clear the cache of this core asynchronously."""
redis_client = await self._resolve_redis_client_async()
Expand All @@ -365,6 +375,16 @@ async def aclear_cache(self) -> None:
except Exception as e:
warnings.warn(f"Redis clear_cache failed: {e}", stacklevel=2)

async def aclear_cache_entry(self, key: str) -> None:
"""Clear the cache entry mapped by the given key asynchronously."""
redis_client = await self._resolve_redis_client_async()
redis_key = self._get_redis_key(key)

try:
await redis_client.delete(redis_key)
except Exception as e:
warnings.warn(f"Redis clear_cache_entry failed: {e}", stacklevel=2)

def clear_being_calculated(self) -> None:
"""Mark all entries in this cache as not being calculated."""
redis_client = self._resolve_redis_client()
Expand Down
9 changes: 9 additions & 0 deletions src/cachier/cores/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,15 @@ def clear_cache(self) -> None:
except Exception as exc:
_safe_warn(f"S3 clear_cache failed: {exc}")

def clear_cache_entry(self, key: str) -> None:
"""Delete the cache entry mapped by the given key from S3."""
client = self._get_s3_client()
s3_key = self._get_s3_key(key)
try:
client.delete_object(Bucket=self.s3_bucket, Key=s3_key)
except Exception as exc:
_safe_warn(f"S3 clear_cache_entry failed: {exc}")

def clear_being_calculated(self) -> None:
"""Reset the ``_processing`` flag on all entries for this function in S3."""
client = self._get_s3_client()
Expand Down
16 changes: 16 additions & 0 deletions src/cachier/cores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,28 @@ def clear_cache(self) -> None:
session.execute(delete(CacheTable).where(CacheTable.function_id == self._func_str))
session.commit()

def clear_cache_entry(self, key: str) -> None:
session_factory = self._get_sync_session()
with self._lock, session_factory() as session:
session.execute(
delete(CacheTable).where(and_(CacheTable.function_id == self._func_str, CacheTable.key == key))
)
session.commit()

async def aclear_cache(self) -> None:
session_factory = await self._get_async_session()
async with session_factory() as session:
await session.execute(delete(CacheTable).where(CacheTable.function_id == self._func_str))
await session.commit()

async def aclear_cache_entry(self, key: str) -> None:
session_factory = await self._get_async_session()
async with session_factory() as session:
await session.execute(
delete(CacheTable).where(and_(CacheTable.function_id == self._func_str, CacheTable.key == key))
)
await session.commit()

def clear_being_calculated(self) -> None:
session_factory = self._get_sync_session()
with self._lock, session_factory() as session:
Expand Down
8 changes: 8 additions & 0 deletions tests/mongo_tests/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ async def delete_many(self, query=None, **kwargs):
del self._docs[key]
deleted += 1
return {"deleted_count": deleted}

async def delete_one(self, query=None, **kwargs):
if query is None:
query = kwargs.get("filter", {})
key = (query.get("func"), query.get("key"))
existed = key in self._docs
self._docs.pop(key, None)
return {"deleted_count": int(existed)}
4 changes: 4 additions & 0 deletions tests/mongo_tests/test_async_mongo_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ async def test_async_mongo_core_mark_clear_and_stale_paths():

await core.amark_entry_being_calculated("fresh")
await core.amark_entry_not_calculated("fresh")
await core.aclear_cache_entry("fresh")
assert (core._func_str, "fresh") not in collection._docs

await core.aset_entry("fresh", 2)
await core.aclear_being_calculated()
await core.adelete_stale_entries(timedelta(hours=1))

Expand Down
22 changes: 22 additions & 0 deletions tests/mongo_tests/test_mongo_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,28 @@ def _test_mongo_caching(arg_1, arg_2):
assert val6 == val5


@pytest.mark.mongo
def test_mongo_clear_cache_for_specific_arguments():
"""clear_cache can remove one Mongo cache entry by function arguments."""

@cachier(mongetter=_test_mongetter)
def _test_mongo_caching(arg_1, arg_2):
"""Some function."""
return random() + arg_1 + arg_2

_test_mongo_caching.clear_cache()
val1 = _test_mongo_caching(1, arg_2=2)
val2 = _test_mongo_caching(3, arg_2=4)
assert _test_mongo_caching(1, arg_2=2) == val1
assert _test_mongo_caching(3, arg_2=4) == val2

_test_mongo_caching.clear_cache(1, arg_2=2)

assert _test_mongo_caching(1, arg_2=2) != val1
assert _test_mongo_caching(3, arg_2=4) == val2
_test_mongo_caching.clear_cache()


@pytest.mark.mongo
def test_mongo_stale_after():
"""Testing MongoDB core stale_after functionality."""
Expand Down
45 changes: 44 additions & 1 deletion tests/pickle_tests/test_pickle_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_being_calc_next_time(separate_files):
separate_files=separate_files,
)
_being_calc_next_time_decorated.clear_cache()
_being_calc_next_time(0.13, 0.02)
_being_calc_next_time_decorated(0.13, 0.02)
sleep(1.1)
res_queue = queue.Queue()
thread1 = threading.Thread(
Expand Down Expand Up @@ -1445,6 +1445,49 @@ def flaky_remove(path):
assert not os.path.exists(dummy_file)


@pytest.mark.pickle
def test_clear_cache_entry_retries_on_permission_error(tmp_path):
"""Test clear_cache_entry retries on PermissionError then succeeds."""
core = _PickleCore(
hash_func=None,
cache_dir=tmp_path,
pickle_reload=False,
wait_for_calc_timeout=10,
separate_files=True,
)

def mock_func():
pass

core.set_func(mock_func)

cache_fpath = core.cache_fpath
dummy_file = cache_fpath + "_dummykey"
with open(dummy_file, "wb") as f:
f.write(b"")

real_remove = os.remove
call_count = 0

def flaky_remove(path):
nonlocal call_count
call_count += 1
if call_count < 3:
raise PermissionError("locked")
real_remove(path)

with (
patch("cachier.cores.pickle.os.remove", side_effect=flaky_remove),
patch("cachier.cores.pickle.time.sleep") as mock_sleep,
):
core.clear_cache_entry("dummykey")
assert mock_sleep.call_count == 2
mock_sleep.assert_any_call(0.1)
mock_sleep.assert_any_call(0.2)

assert not os.path.exists(dummy_file)


@pytest.mark.pickle
def test_clear_all_cache_files_raises_on_persistent_permission_error(tmp_path):
"""Test _clear_all_cache_files re-raises PermissionError after all retries."""
Expand Down
Loading
Loading