From 5e428fd848ede87fbc65b4886d1a0e28d25bdabd Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Fri, 29 Aug 2025 21:10:22 -0700 Subject: [PATCH 01/10] first try --- tests/unit/test_persistent_cache.py | 330 ++++++++++++++++++++++++++-- 1 file changed, 310 insertions(+), 20 deletions(-) diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py index 32fef006..c21a7fde 100644 --- a/tests/unit/test_persistent_cache.py +++ b/tests/unit/test_persistent_cache.py @@ -1,7 +1,11 @@ +import concurrent.futures import dbm +import multiprocessing import os +import tempfile import threading import time +from pathlib import Path import pytest @@ -9,7 +13,8 @@ # DBM backends to test with -DBM_BACKENDS = ["dbm.sqlite3", "dbm.gnu", "dbm.ndbm", "dbm.dumb"] +DBM_BACKENDS = ["dbm.sqlite3"] +# , "dbm.gnu", "dbm.ndbm", "dbm.dumb"] @pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) @@ -223,21 +228,23 @@ def test_corrupted_data(tmpdir, dbm_backend): @pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) def test_concurrency(tmpdir, dbm_backend): - """Test concurrent access to the cache.""" + """Test concurrent access to the cache - fixed version.""" cache_file = os.path.join(tmpdir, f"cache_concurrency_{dbm_backend}") cache = PersistentCache(cache_file) - num_threads = 10 - num_operations = 50 + num_threads = 20 + num_operations = 10 results = [] # Store assertion failures for pytest to check after threads complete def worker(thread_id): + # Fixed: Don't overwrite thread_id parameter for i in range(num_operations): key = f"key_{thread_id}_{i}" value = f"value_{thread_id}_{i}" - cache.set(key, value) + if cache.get(key) is None: + cache.set(key, value) # Occasionally read a previously written value - if i > 0 and i % 5 == 0: + if i > 0 and i % 2 == 0: prev_key = f"key_{thread_id}_{i - 1}" prev_value = cache.get(prev_key) if prev_value != f"value_{thread_id}_{i - 1}": @@ -245,24 +252,307 @@ def worker(thread_id): f"Expected {prev_key} to be value_{thread_id}_{i - 1}, got {prev_value}" ) - threads = [] - for i in range(num_threads): - t = threading.Thread(target=worker, args=(i,)) - threads.append(t) - t.start() - - for t in threads: - t.join() + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + list(executor.map(worker, range(num_threads))) # Check for any failures in threads assert not results, f"Thread assertions failed: {results}" - # Verify all values were written correctly - for i in range(num_threads): - for j in range(num_operations): - key = f"key_{i}_{j}" - expected_value = f"value_{i}_{j}" - assert cache.get(key) == expected_value + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_aggressive_concurrency_database_lock(tmpdir, dbm_backend): + """Test aggressive concurrent access that might trigger database lock issues.""" + cache_file = os.path.join(tmpdir, f"cache_aggressive_{dbm_backend}") + + # Use a higher number of threads and operations to stress test + num_threads = 50 + num_operations = 50 + errors = [] + + def aggressive_worker(thread_id): + """Worker that performs rapid database operations.""" + try: + # Create a new cache instance per thread to simulate real-world usage + thread_cache = PersistentCache(cache_file) + + for i in range(num_operations): + key = f"thread_{thread_id}_op_{i}" + value = f"value_{thread_id}_{i}" + + # Rapid set/get operations + thread_cache.set(key, value, expires_in=1) + retrieved = thread_cache.get(key) + + if retrieved != value: + errors.append( + f"Thread {thread_id}: Expected {value}, got {retrieved}" 
+ ) + + # Perform some operations that might cause contention + if i % 5 == 0: + thread_cache.clear_expired() + + # Try to access keys from other threads + if i % 3 == 0 and thread_id > 0: + other_key = f"thread_{thread_id - 1}_op_{i}" + thread_cache.get(other_key) + + except Exception as e: + errors.append(f"Thread {thread_id} error: {str(e)}") + + # Run with high concurrency + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(aggressive_worker, i) for i in range(num_threads)] + concurrent.futures.wait(futures) + + # Check for database lock errors or other issues + database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + if database_lock_errors: + pytest.fail(f"Database lock errors detected: {database_lock_errors}") + + if errors: + pytest.fail( + f"Concurrency errors detected: {errors[:10]}" + ) # Show first 10 errors + + +# Global function for multiprocessing (needed for pickling) +def _multiprocess_worker(process_id, cache_file, num_ops): + """Worker function for multiprocessing test.""" + import sys + import traceback + + try: + # Each process creates its own cache instance + from mapillary_tools.history import PersistentCache + + cache = PersistentCache(cache_file) + + for i in range(num_ops): + key = f"proc_{process_id}_op_{i}" + value = f"value_{process_id}_{i}" + + # Rapid operations that might cause database locking + cache.set(key, value, expires_in=2) + retrieved = cache.get(key) + + if retrieved != value: + return f"Process {process_id}: Expected {value}, got {retrieved}" + + # Operations that might cause contention + if i % 3 == 0: + cache.clear_expired() + + # Try to read from other processes + if i % 5 == 0 and process_id > 0: + other_key = f"proc_{process_id - 1}_op_{i}" + cache.get(other_key) + + except Exception as e: + return f"Process {process_id} error: {str(e)} - {traceback.format_exc()}" + + return None + + +def test_multiprocess_database_lock(tmpdir): + """Test multiprocess access that might trigger database lock issues.""" + cache_file = os.path.join(tmpdir, "cache_multiprocess") + + # Use multiprocessing to create real process contention + num_processes = 8 + num_operations = 20 + + with multiprocessing.Pool(processes=num_processes) as pool: + results = pool.starmap( + _multiprocess_worker, + [(i, cache_file, num_operations) for i in range(num_processes)], + ) + + # Check for errors + errors = [r for r in results if r is not None] + database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + + if database_lock_errors: + pytest.fail( + f"Database lock errors in multiprocess test: {database_lock_errors}" + ) + + if errors: + pytest.fail(f"Multiprocess errors: {errors}") + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_rapid_file_creation_database_lock(tmpdir, dbm_backend): + """Test rapid database file creation that might trigger lock issues.""" + base_path = os.path.join(tmpdir, f"rapid_creation_{dbm_backend}") + + def rapid_creator(thread_id): + """Create and use cache files rapidly.""" + errors = [] + try: + for i in range(10): + # Create a new cache file for each operation + cache_file = f"{base_path}_{thread_id}_{i}" + cache = PersistentCache(cache_file) + + # Perform operations immediately after creation + cache.set("test_key", f"test_value_{thread_id}_{i}") + result = cache.get("test_key") + + if result != f"test_value_{thread_id}_{i}": + errors.append( + f"Thread {thread_id}, iteration {i}: Expected 
test_value_{thread_id}_{i}, got {result}" + ) + + except Exception as e: + errors.append(f"Thread {thread_id} error: {str(e)}") + + return errors + + # Run multiple threads creating cache files rapidly + num_threads = 20 + all_errors = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(rapid_creator, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + errors = future.result() + all_errors.extend(errors) + + # Check for database lock errors + database_lock_errors = [ + e for e in all_errors if "database is locked" in str(e).lower() + ] + if database_lock_errors: + pytest.fail( + f"Database lock errors in rapid creation test: {database_lock_errors}" + ) + + if all_errors: + pytest.fail(f"Rapid creation errors: {all_errors[:5]}") # Show first 5 errors + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_simultaneous_database_operations(tmpdir, dbm_backend): + """Test simultaneous database operations that might cause locking.""" + cache_file = os.path.join(tmpdir, f"cache_simultaneous_{dbm_backend}") + + # Barrier to synchronize thread start + barrier = threading.Barrier(10) + errors = [] + + def synchronized_worker(thread_id): + """Worker that starts operations simultaneously.""" + try: + cache = PersistentCache(cache_file) + + # Wait for all threads to be ready + barrier.wait() + + # All threads perform operations at the same time + for i in range(20): + key = f"sync_{thread_id}_{i}" + value = f"value_{thread_id}_{i}" + + # Simultaneous write operations + cache.set(key, value) + + # Immediate read back + result = cache.get(key) + if result != value: + errors.append(f"Thread {thread_id}: Expected {value}, got {result}") + + # Mixed operations + if i % 2 == 0: + cache.clear_expired() + + except Exception as e: + errors.append(f"Thread {thread_id} error: {str(e)}") + + # Start all threads simultaneously + threads = [] + for i in range(10): + thread = threading.Thread(target=synchronized_worker, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Check for database lock errors + database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + if database_lock_errors: + pytest.fail( + f"Database lock errors in simultaneous operations test: {database_lock_errors}" + ) + + if errors: + pytest.fail(f"Simultaneous operation errors: {errors[:10]}") + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_stress_database_with_exceptions(tmpdir, dbm_backend): + """Stress test that might trigger database lock issues with exception handling.""" + cache_file = os.path.join(tmpdir, f"cache_stress_{dbm_backend}") + + def stress_worker(thread_id): + """Worker that performs operations and handles exceptions.""" + database_lock_count = 0 + other_errors = [] + + cache = PersistentCache(cache_file) + + for i in range(100): # More operations to increase chance of lock + try: + key = f"stress_{thread_id}_{i}" + value = f"value_{thread_id}_{i}" + + # Rapid operations + cache.set(key, value, expires_in=1) + cache.get(key) + + # Operations that might cause contention + if i % 10 == 0: + cache.clear_expired() + + # Try to access the database file directly (might cause issues) + if i % 15 == 0: + try: + with dbm.open(cache_file, flag="r") as db: + list(db.keys()) + except Exception: + pass # Ignore direct access errors + + except Exception as e: + error_msg = str(e).lower() + if 
"database is locked" in error_msg: + database_lock_count += 1 + else: + other_errors.append(f"Thread {thread_id}, op {i}: {str(e)}") + + return database_lock_count, other_errors + + # Run stress test + num_threads = 15 + total_lock_errors = 0 + all_other_errors = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(stress_worker, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + lock_count, other_errors = future.result() + total_lock_errors += lock_count + all_other_errors.extend(other_errors) + + # Report results + if total_lock_errors > 0: + pytest.fail( + f"Database lock errors detected: {total_lock_errors} total lock errors" + ) + + if all_other_errors: + pytest.fail(f"Other stress test errors: {all_other_errors[:5]}") @pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) From b9f9f7ff39fa9800f2f476ce0c4c6f1ae3ee2f6b Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Sat, 30 Aug 2025 15:25:06 -0700 Subject: [PATCH 02/10] fix --- tests/unit/test_persistent_cache.py | 204 ++++++++++++++++++++++------ 1 file changed, 160 insertions(+), 44 deletions(-) diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py index c21a7fde..9e608899 100644 --- a/tests/unit/test_persistent_cache.py +++ b/tests/unit/test_persistent_cache.py @@ -1,11 +1,7 @@ import concurrent.futures -import dbm -import multiprocessing import os -import tempfile -import threading import time -from pathlib import Path +import traceback import pytest @@ -210,19 +206,18 @@ def test_corrupted_data(tmpdir, dbm_backend): # Set valid entry cache.set("key1", "value1") - # Corrupt the data by directly writing invalid JSON - with dbm.open(cache_file, "c") as db: - db["corrupted"] = b"not valid json" - db["corrupted_dict"] = b'"not a dict"' + # Test the _decode method directly with corrupted data to simulate corruption + # This tests the error handling without directly manipulating the database + corrupted_result = cache._decode(b"not valid json") + assert corrupted_result == {} - # Check that corrupted entries are handled gracefully - assert cache.get("corrupted") is None - assert cache.get("corrupted_dict") is None + corrupted_dict_result = cache._decode(b'"not a dict"') + assert corrupted_dict_result == {} # Valid entries should still work assert cache.get("key1") == "value1" - # Clear expired should not crash on corrupted entries + # Clear expired should not crash cache.clear_expired() @@ -319,13 +314,8 @@ def aggressive_worker(thread_id): # Global function for multiprocessing (needed for pickling) def _multiprocess_worker(process_id, cache_file, num_ops): """Worker function for multiprocessing test.""" - import sys - import traceback - try: # Each process creates its own cache instance - from mapillary_tools.history import PersistentCache - cache = PersistentCache(cache_file) for i in range(num_ops): @@ -362,11 +352,14 @@ def test_multiprocess_database_lock(tmpdir): num_processes = 8 num_operations = 20 - with multiprocessing.Pool(processes=num_processes) as pool: - results = pool.starmap( - _multiprocess_worker, - [(i, cache_file, num_operations) for i in range(num_processes)], - ) + with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + futures = [ + executor.submit(_multiprocess_worker, i, cache_file, num_operations) + for i in range(num_processes) + ] + results = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] # Check for errors errors = 
[r for r in results if r is not None] @@ -437,8 +430,6 @@ def test_simultaneous_database_operations(tmpdir, dbm_backend): """Test simultaneous database operations that might cause locking.""" cache_file = os.path.join(tmpdir, f"cache_simultaneous_{dbm_backend}") - # Barrier to synchronize thread start - barrier = threading.Barrier(10) errors = [] def synchronized_worker(thread_id): @@ -446,9 +437,6 @@ def synchronized_worker(thread_id): try: cache = PersistentCache(cache_file) - # Wait for all threads to be ready - barrier.wait() - # All threads perform operations at the same time for i in range(20): key = f"sync_{thread_id}_{i}" @@ -469,16 +457,10 @@ def synchronized_worker(thread_id): except Exception as e: errors.append(f"Thread {thread_id} error: {str(e)}") - # Start all threads simultaneously - threads = [] - for i in range(10): - thread = threading.Thread(target=synchronized_worker, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() + # Start all threads simultaneously using ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(synchronized_worker, i) for i in range(10)] + concurrent.futures.wait(futures) # Check for database lock errors database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] @@ -495,14 +477,13 @@ def synchronized_worker(thread_id): def test_stress_database_with_exceptions(tmpdir, dbm_backend): """Stress test that might trigger database lock issues with exception handling.""" cache_file = os.path.join(tmpdir, f"cache_stress_{dbm_backend}") + cache = PersistentCache(cache_file) def stress_worker(thread_id): """Worker that performs operations and handles exceptions.""" database_lock_count = 0 other_errors = [] - cache = PersistentCache(cache_file) - for i in range(100): # More operations to increase chance of lock try: key = f"stress_{thread_id}_{i}" @@ -516,15 +497,16 @@ def stress_worker(thread_id): if i % 10 == 0: cache.clear_expired() - # Try to access the database file directly (might cause issues) + # Additional operations that might cause contention if i % 15 == 0: + # Use cache.keys() instead of direct dbm access try: - with dbm.open(cache_file, flag="r") as db: - list(db.keys()) + list(cache.keys()) except Exception: - pass # Ignore direct access errors + pass # Ignore access errors except Exception as e: + # raise e error_msg = str(e).lower() if "database is locked" in error_msg: database_lock_count += 1 @@ -555,6 +537,140 @@ def stress_worker(thread_id): pytest.fail(f"Other stress test errors: {all_other_errors[:5]}") +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_shared_cache_instance_database_lock(tmpdir, dbm_backend): + """Test shared cache instance across threads - reproduces real uploader.py usage pattern.""" + cache_file = os.path.join(tmpdir, f"cache_shared_{dbm_backend}") + + # Create a single shared cache instance (like in uploader.py) + shared_cache = PersistentCache(cache_file) + shared_cache.clear_expired() + + # Use higher numbers to increase chance of database lock + num_threads = 30 + num_operations = 100 + errors = [] + + def shared_cache_worker(thread_id): + """Worker that uses the shared cache instance (like CachedImageUploader.upload).""" + try: + for i in range(num_operations): + key = f"shared_thread_{thread_id}_op_{i}" + value = f"shared_value_{thread_id}_{i}" + + # Simulate the pattern from CachedImageUploader: + # 1. 
Check cache first (_get_cached_file_handle) + cached_value = shared_cache.get(key) + + if cached_value is None: + # 2. Set new value (_set_file_handle_cache) + shared_cache.set(key, value, expires_in=2) + retrieved = shared_cache.get(key) + + if retrieved != value: + errors.append( + f"Thread {thread_id}: Expected {value}, got {retrieved}" + ) + else: + # 3. Update cache with existing value (_set_file_handle_cache) + shared_cache.set(key, cached_value, expires_in=2) + + # Occasional cleanup operations + if i % 20 == 0: + shared_cache.clear_expired() + + # Cross-thread access pattern + if i % 7 == 0 and thread_id > 0: + other_key = f"shared_thread_{thread_id - 1}_op_{i}" + shared_cache.get(other_key) + + except Exception as e: + errors.append(f"Thread {thread_id} error: {str(e)}") + + # Run all threads using the same shared cache instance + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(shared_cache_worker, i) for i in range(num_threads)] + concurrent.futures.wait(futures) + + # Check for database lock errors + database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + if database_lock_errors: + pytest.fail( + f"Database lock errors with shared cache instance: {database_lock_errors}" + ) + + # Check for data consistency errors (values not persisting correctly) + data_consistency_errors = [e for e in errors if "Expected" in e and "got None" in e] + if data_consistency_errors: + pytest.fail( + f"Data consistency errors with shared cache (race conditions): {data_consistency_errors[:5]}" + ) + + if errors: + pytest.fail(f"Other shared cache errors: {errors[:10]}") # Show first 10 errors + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_extreme_shared_cache_database_lock(tmpdir, dbm_backend): + """Extreme test with shared cache instance to try to trigger database lock errors.""" + cache_file = os.path.join(tmpdir, f"cache_extreme_shared_{dbm_backend}") + + # Create a single shared cache instance + shared_cache = PersistentCache(cache_file) + + # Use even more aggressive settings + num_threads = 50 + num_operations = 200 + errors = [] + + def extreme_shared_worker(thread_id): + """Worker that hammers the shared cache instance.""" + try: + for i in range(num_operations): + key = f"extreme_{thread_id}_{i}" + value = f"val_{thread_id}_{i}" + + # Rapid fire operations without any delays + shared_cache.set(key, value, expires_in=1) + shared_cache.get(key) + + # More frequent cleanup to increase contention + if i % 5 == 0: + shared_cache.clear_expired() + + # More cross-thread access + if i % 2 == 0 and thread_id > 0: + other_key = f"extreme_{thread_id - 1}_{i}" + shared_cache.get(other_key) + + # Additional operations to increase database pressure + if i % 3 == 0: + list(shared_cache.keys()) + + except Exception as e: + errors.append(f"Thread {thread_id} error: {str(e)}") + + # Run with maximum concurrency + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(extreme_shared_worker, i) for i in range(num_threads) + ] + concurrent.futures.wait(futures) + + # Check specifically for database lock errors + database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + if database_lock_errors: + pytest.fail( + f"SUCCESS: Database lock errors reproduced with shared cache: {database_lock_errors[:5]}" + ) + + # Report other errors but don't fail the test for them + if errors: + print( + f"Other errors (not database 
locks): {len(errors)} total, first 5: {errors[:5]}" + ) + + @pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) def test_decode_invalid_data(tmpdir, dbm_backend): """Test _decode method with invalid data.""" From 5eddfc37dea4aa497d6cd4b4f7d960616fd3326f Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Sun, 31 Aug 2025 18:48:33 -0700 Subject: [PATCH 03/10] add more tests --- tests/unit/test_persistent_cache.py | 590 ++++++++++++++++++++-------- 1 file changed, 417 insertions(+), 173 deletions(-) diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py index 9e608899..22424fae 100644 --- a/tests/unit/test_persistent_cache.py +++ b/tests/unit/test_persistent_cache.py @@ -8,19 +8,13 @@ from mapillary_tools.history import PersistentCache -# DBM backends to test with -DBM_BACKENDS = ["dbm.sqlite3"] -# , "dbm.gnu", "dbm.ndbm", "dbm.dumb"] - - -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_basic_operations_with_backend(tmpdir, dbm_backend): +def test_basic_operations_with_backend(tmpdir): """Test basic operations with different DBM backends. Note: This is a demonstration of pytest's parametrize feature. The actual PersistentCache class might not support specifying backends. """ - cache_file = os.path.join(tmpdir, dbm_backend) + cache_file = os.path.join(tmpdir, "cache") # Here you would use the backend if the cache implementation supported it cache = PersistentCache(cache_file) @@ -32,20 +26,18 @@ def test_basic_operations_with_backend(tmpdir, dbm_backend): # This is just a placeholder to demonstrate pytest's parametrization -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_get_set(tmpdir, dbm_backend): +def test_get_set(tmpdir): """Test basic get and set operations.""" - cache_file = os.path.join(tmpdir, f"cache_get_set_{dbm_backend}") + cache_file = os.path.join(tmpdir, "cache") cache = PersistentCache(cache_file) cache.set("key1", "value1") assert cache.get("key1") == "value1" assert cache.get("nonexistent_key") is None -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_expiration(tmpdir, dbm_backend): +def test_expiration(tmpdir): """Test that entries expire correctly.""" - cache_file = os.path.join(tmpdir, f"cache_expiration_{dbm_backend}") + cache_file = os.path.join(tmpdir, "cache") cache = PersistentCache(cache_file) # Set with short expiration @@ -65,7 +57,6 @@ def test_expiration(tmpdir, dbm_backend): assert cache.get("long_lived") == "value" -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) @pytest.mark.parametrize( "expire_time,sleep_time,should_exist", [ @@ -74,13 +65,9 @@ def test_expiration(tmpdir, dbm_backend): (5, 2, True), # Should not expire yet ], ) -def test_parametrized_expiration( - tmpdir, dbm_backend, expire_time, sleep_time, should_exist -): +def test_parametrized_expiration(tmpdir, expire_time, sleep_time, should_exist): """Test expiration with different timing combinations.""" - cache_file = os.path.join( - tmpdir, f"cache_param_exp_{dbm_backend}_{expire_time}_{sleep_time}" - ) + cache_file = os.path.join(tmpdir, f"cache_param_exp_{expire_time}_{sleep_time}") cache = PersistentCache(cache_file) key = f"key_expires_in_{expire_time}_sleeps_{sleep_time}" @@ -94,10 +81,9 @@ def test_parametrized_expiration( assert cache.get(key) is None -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_clear_expired(tmpdir, dbm_backend): +def test_clear_expired(tmpdir): """Test clearing expired entries.""" - cache_file = os.path.join(tmpdir, f"cache_clear_expired_{dbm_backend}") + cache_file = 
os.path.join(tmpdir, f"cache_clear_expired") cache = PersistentCache(cache_file) # Test 1: Single expired key @@ -117,10 +103,9 @@ def test_clear_expired(tmpdir, dbm_backend): assert cache.get("not_expired") == "value2" -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_clear_expired_multiple(tmpdir, dbm_backend): +def test_clear_expired_multiple(tmpdir): """Test clearing multiple expired entries.""" - cache_file = os.path.join(tmpdir, f"cache_clear_multiple_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_clear_multiple") cache = PersistentCache(cache_file) # Test 2: Multiple expired keys @@ -143,10 +128,9 @@ def test_clear_expired_multiple(tmpdir, dbm_backend): assert cache.get("not_expired") == "value3" -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_clear_expired_all(tmpdir, dbm_backend): +def test_clear_expired_all(tmpdir): """Test clearing all expired entries.""" - cache_file = os.path.join(tmpdir, f"cache_clear_all_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_clear_all") cache = PersistentCache(cache_file) # Test 3: All entries expired @@ -165,10 +149,9 @@ def test_clear_expired_all(tmpdir, dbm_backend): assert b"key2" in expired_keys -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_clear_expired_none(tmpdir, dbm_backend): +def test_clear_expired_none(tmpdir): """Test clearing when no entries are expired.""" - cache_file = os.path.join(tmpdir, f"cache_clear_none_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_clear_none") cache = PersistentCache(cache_file) # Test 4: No entries expired @@ -184,10 +167,9 @@ def test_clear_expired_none(tmpdir, dbm_backend): assert cache.get("key2") == "value2" -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_clear_expired_empty(tmpdir, dbm_backend): +def test_clear_expired_empty(tmpdir): """Test clearing expired entries on an empty cache.""" - cache_file = os.path.join(tmpdir, f"cache_clear_empty_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_clear_empty") cache = PersistentCache(cache_file) # Test 5: Empty cache @@ -197,10 +179,9 @@ def test_clear_expired_empty(tmpdir, dbm_backend): assert len(expired_keys) == 0 -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_corrupted_data(tmpdir, dbm_backend): +def test_corrupted_data(tmpdir): """Test handling of corrupted data.""" - cache_file = os.path.join(tmpdir, f"cache_corrupted_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_corrupted") cache = PersistentCache(cache_file) # Set valid entry @@ -221,10 +202,10 @@ def test_corrupted_data(tmpdir, dbm_backend): cache.clear_expired() -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_concurrency(tmpdir, dbm_backend): +def test_concurrency(tmpdir): """Test concurrent access to the cache - fixed version.""" - cache_file = os.path.join(tmpdir, f"cache_concurrency_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_concurrency") + cache = PersistentCache(cache_file) num_threads = 20 num_operations = 10 @@ -254,63 +235,6 @@ def worker(thread_id): assert not results, f"Thread assertions failed: {results}" -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_aggressive_concurrency_database_lock(tmpdir, dbm_backend): - """Test aggressive concurrent access that might trigger database lock issues.""" - cache_file = os.path.join(tmpdir, f"cache_aggressive_{dbm_backend}") - - # Use a higher number of threads and operations to stress test - num_threads = 50 - num_operations = 50 - errors = [] - - def 
aggressive_worker(thread_id): - """Worker that performs rapid database operations.""" - try: - # Create a new cache instance per thread to simulate real-world usage - thread_cache = PersistentCache(cache_file) - - for i in range(num_operations): - key = f"thread_{thread_id}_op_{i}" - value = f"value_{thread_id}_{i}" - - # Rapid set/get operations - thread_cache.set(key, value, expires_in=1) - retrieved = thread_cache.get(key) - - if retrieved != value: - errors.append( - f"Thread {thread_id}: Expected {value}, got {retrieved}" - ) - - # Perform some operations that might cause contention - if i % 5 == 0: - thread_cache.clear_expired() - - # Try to access keys from other threads - if i % 3 == 0 and thread_id > 0: - other_key = f"thread_{thread_id - 1}_op_{i}" - thread_cache.get(other_key) - - except Exception as e: - errors.append(f"Thread {thread_id} error: {str(e)}") - - # Run with high concurrency - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(aggressive_worker, i) for i in range(num_threads)] - concurrent.futures.wait(futures) - - # Check for database lock errors or other issues - database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] - if database_lock_errors: - pytest.fail(f"Database lock errors detected: {database_lock_errors}") - - if errors: - pytest.fail( - f"Concurrency errors detected: {errors[:10]}" - ) # Show first 10 errors - - # Global function for multiprocessing (needed for pickling) def _multiprocess_worker(process_id, cache_file, num_ops): """Worker function for multiprocessing test.""" @@ -374,10 +298,9 @@ def test_multiprocess_database_lock(tmpdir): pytest.fail(f"Multiprocess errors: {errors}") -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_rapid_file_creation_database_lock(tmpdir, dbm_backend): +def test_rapid_file_creation_database_lock(tmpdir): """Test rapid database file creation that might trigger lock issues.""" - base_path = os.path.join(tmpdir, f"rapid_creation_{dbm_backend}") + base_path = os.path.join(tmpdir, f"rapid_creation") def rapid_creator(thread_id): """Create and use cache files rapidly.""" @@ -425,10 +348,9 @@ def rapid_creator(thread_id): pytest.fail(f"Rapid creation errors: {all_errors[:5]}") # Show first 5 errors -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_simultaneous_database_operations(tmpdir, dbm_backend): +def test_simultaneous_database_operations(tmpdir): """Test simultaneous database operations that might cause locking.""" - cache_file = os.path.join(tmpdir, f"cache_simultaneous_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_simultaneous") errors = [] @@ -473,10 +395,9 @@ def synchronized_worker(thread_id): pytest.fail(f"Simultaneous operation errors: {errors[:10]}") -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_stress_database_with_exceptions(tmpdir, dbm_backend): +def test_stress_database_with_exceptions(tmpdir): """Stress test that might trigger database lock issues with exception handling.""" - cache_file = os.path.join(tmpdir, f"cache_stress_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_stress") cache = PersistentCache(cache_file) def stress_worker(thread_id): @@ -537,10 +458,9 @@ def stress_worker(thread_id): pytest.fail(f"Other stress test errors: {all_other_errors[:5]}") -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_shared_cache_instance_database_lock(tmpdir, dbm_backend): +def test_shared_cache_instance_database_lock(tmpdir): """Test 
shared cache instance across threads - reproduces real uploader.py usage pattern.""" - cache_file = os.path.join(tmpdir, f"cache_shared_{dbm_backend}") + cache_file = os.path.join(tmpdir, f"cache_shared") # Create a single shared cache instance (like in uploader.py) shared_cache = PersistentCache(cache_file) @@ -586,6 +506,7 @@ def shared_cache_worker(thread_id): except Exception as e: errors.append(f"Thread {thread_id} error: {str(e)}") + print(e, traceback.format_exc()) # Run all threads using the same shared cache instance with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: @@ -610,90 +531,413 @@ def shared_cache_worker(thread_id): pytest.fail(f"Other shared cache errors: {errors[:10]}") # Show first 10 errors -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_extreme_shared_cache_database_lock(tmpdir, dbm_backend): - """Extreme test with shared cache instance to try to trigger database lock errors.""" - cache_file = os.path.join(tmpdir, f"cache_extreme_shared_{dbm_backend}") +def test_decode_invalid_data(tmpdir): + """Test _decode method with invalid data.""" + cache_file = os.path.join(tmpdir, f"cache_decode_invalid") + cache = PersistentCache(cache_file) - # Create a single shared cache instance - shared_cache = PersistentCache(cache_file) + # Test with various invalid inputs + result = cache._decode(b"not valid json") + assert result == {} + + result = cache._decode(b'"string instead of dict"') + assert result == {} + + +def test_is_expired(tmpdir): + """Test _is_expired method.""" + cache_file = os.path.join(tmpdir, f"cache_is_expired") + cache = PersistentCache(cache_file) + + # Test with various payloads + assert cache._is_expired({"expires_at": time.time() - 10}) is True + assert cache._is_expired({"expires_at": time.time() + 10}) is False + assert cache._is_expired({}) is False + assert cache._is_expired({"expires_at": "not a number"}) is False + assert cache._is_expired({"expires_at": None}) is False - # Use even more aggressive settings - num_threads = 50 - num_operations = 200 + +# Shared worker functions for concurrency tests +def _shared_worker_get_set_pattern( + worker_id, cache, num_operations, key_prefix="worker" +): + """Shared worker implementation: get key -> if not exist then set key=value.""" errors = [] + try: + for i in range(num_operations): + key = f"{key_prefix}_{worker_id}_{i}" + value = f"value_{worker_id}_{i}" - def extreme_shared_worker(thread_id): - """Worker that hammers the shared cache instance.""" - try: - for i in range(num_operations): - key = f"extreme_{thread_id}_{i}" - value = f"val_{thread_id}_{i}" + # Pattern: get a key -> if not exist then set key=value + existing_value = cache.get(key) + if existing_value is None: + cache.set(key, value, expires_in=10) - # Rapid fire operations without any delays - shared_cache.set(key, value, expires_in=1) - shared_cache.get(key) + # Verify the value was set correctly + retrieved_value = cache.get(key) + if retrieved_value != value: + errors.append( + f"Worker {worker_id}: Expected {value}, got {retrieved_value}" + ) - # More frequent cleanup to increase contention - if i % 5 == 0: - shared_cache.clear_expired() + except Exception as e: + errors.append(f"Worker {worker_id} error: {str(e)}") - # More cross-thread access - if i % 2 == 0 and thread_id > 0: - other_key = f"extreme_{thread_id - 1}_{i}" - shared_cache.get(other_key) + return errors - # Additional operations to increase database pressure - if i % 3 == 0: - list(shared_cache.keys()) - except Exception as e: 
- errors.append(f"Thread {thread_id} error: {str(e)}") +def _multiprocess_worker_get_set_pattern(args): + """Multiprocess worker wrapper for the shared worker pattern.""" + worker_id, cache_file, num_operations, key_prefix = args + try: + # Each process creates its own cache instance pointing to the same file + cache = PersistentCache(cache_file) + return _shared_worker_get_set_pattern( + worker_id, cache, num_operations, key_prefix + ) + except Exception as e: + return [f"Process {worker_id} error: {str(e)} - {traceback.format_exc()}"] + + +def _mixed_operations_worker(args): + """Worker that performs mixed operations with get->set pattern.""" + worker_id, cache_file, num_operations, run_prefix = args + errors = [] + try: + cache = PersistentCache(cache_file) + + for i in range(num_operations): + key = f"{run_prefix}_mixed_{worker_id}_{i}" + value = f"mixed_value_{worker_id}_{i}" + + # Pattern: get a key -> if not exist then set key=value + existing_value = cache.get(key) + if existing_value is None: + cache.set(key, value, expires_in=15) + + # Mixed operations: also try to read from other workers + if i % 5 == 0 and worker_id > 0: + other_key = f"{run_prefix}_mixed_{worker_id - 1}_{i}" + cache.get(other_key) + + # Periodic cleanup + if i % 10 == 0: + cache.clear_expired() + + except Exception as e: + errors.append(f"Process {worker_id} error: {str(e)}") + + return errors + + +def test_multithread_shared_cache_get_set_pattern(tmpdir): + """Test multithreaded access with shared cache instance using get->set pattern.""" + cache_file = os.path.join(tmpdir, "cache_multithread_shared") + + # Initialize cache once and share across all workers + shared_cache = PersistentCache(cache_file) - # Run with maximum concurrency + # Clear expired before first concurrent run + shared_cache.clear_expired() + + num_threads = 10 + num_operations = 20 + all_errors = [] + + # First concurrent run with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: futures = [ - executor.submit(extreme_shared_worker, i) for i in range(num_threads) + executor.submit( + _shared_worker_get_set_pattern, i, shared_cache, num_operations, "run1" + ) + for i in range(num_threads) ] - concurrent.futures.wait(futures) + for future in concurrent.futures.as_completed(futures): + errors = future.result() + all_errors.extend(errors) - # Check specifically for database lock errors - database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + # Clear expired between concurrent runs + expired_keys_1 = shared_cache.clear_expired() + assert isinstance(expired_keys_1, list), "clear_expired should return a list" + + # Second concurrent run + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit( + _shared_worker_get_set_pattern, i, shared_cache, num_operations, "run2" + ) + for i in range(num_threads) + ] + for future in concurrent.futures.as_completed(futures): + errors = future.result() + all_errors.extend(errors) + + # Clear expired after second concurrent run + expired_keys_2 = shared_cache.clear_expired() + assert isinstance(expired_keys_2, list), "clear_expired should return a list" + + # Assertions using keys() and counts + all_keys = list(shared_cache.keys()) + expected_keys_count = num_threads * num_operations * 2 # Two runs + + assert ( + len(all_keys) == expected_keys_count + ), f"Expected {expected_keys_count} keys, got {len(all_keys)}" + + # Verify key patterns + run1_keys = [k for k in all_keys if b"run1" in k] + 
run2_keys = [k for k in all_keys if b"run2" in k] + + assert len(run1_keys) == num_threads * num_operations + assert len(run2_keys) == num_threads * num_operations + + # Check for any worker errors + assert not all_errors, f"Worker errors occurred: {all_errors}" + + +def test_multiprocess_shared_file_get_set_pattern(tmpdir): + """Test multiprocess access with shared cache file using get->set pattern.""" + cache_file = os.path.join(tmpdir, "cache_multiprocess_shared") + + # Initialize cache and clear expired before first concurrent run + init_cache = PersistentCache(cache_file) + init_cache.clear_expired() + + num_processes = 6 + num_operations = 15 + + # First concurrent run + with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + args_list = [ + (i, cache_file, num_operations, "proc_run1") for i in range(num_processes) + ] + futures = [ + executor.submit(_multiprocess_worker_get_set_pattern, args) + for args in args_list + ] + results_1 = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + # Clear expired between concurrent runs + expired_keys_1 = init_cache.clear_expired() + assert isinstance(expired_keys_1, list), "clear_expired should return a list" + + # Second concurrent run + with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + args_list = [ + (i, cache_file, num_operations, "proc_run2") for i in range(num_processes) + ] + futures = [ + executor.submit(_multiprocess_worker_get_set_pattern, args) + for args in args_list + ] + results_2 = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + # Clear expired after second concurrent run + expired_keys_2 = init_cache.clear_expired() + assert isinstance(expired_keys_2, list), "clear_expired should return a list" + + # Collect all errors from both runs + all_errors = [] + for result in results_1 + results_2: + all_errors.extend(result) + + # Assertions using keys() and counts + all_keys = list(init_cache.keys()) + expected_keys_count = num_processes * num_operations * 2 # Two runs + + assert ( + len(all_keys) == expected_keys_count + ), f"Expected {expected_keys_count} keys, got {len(all_keys)}" + + # Verify key patterns + run1_keys = [k for k in all_keys if b"proc_run1" in k] + run2_keys = [k for k in all_keys if b"proc_run2" in k] + + assert len(run1_keys) == num_processes * num_operations + assert len(run2_keys) == num_processes * num_operations + + # Check for database lock errors specifically + database_lock_errors = [ + e for e in all_errors if "database is locked" in str(e).lower() + ] if database_lock_errors: pytest.fail( - f"SUCCESS: Database lock errors reproduced with shared cache: {database_lock_errors[:5]}" + f"Database lock errors in multiprocess test: {database_lock_errors}" ) - # Report other errors but don't fail the test for them - if errors: - print( - f"Other errors (not database locks): {len(errors)} total, first 5: {errors[:5]}" - ) + # Check for any other errors + assert not all_errors, f"Process errors occurred: {all_errors}" -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_decode_invalid_data(tmpdir, dbm_backend): - """Test _decode method with invalid data.""" - cache_file = os.path.join(tmpdir, f"cache_decode_invalid_{dbm_backend}") - cache = PersistentCache(cache_file) +def test_multithread_high_contention_get_set_pattern(tmpdir): + """Test high contention multithreaded access with shared cache using get->set pattern.""" + cache_file = os.path.join(tmpdir, 
"cache_multithread_contention") - # Test with various invalid inputs - result = cache._decode(b"not valid json") - assert result == {} + # Initialize cache once and share across all workers + shared_cache = PersistentCache(cache_file) - result = cache._decode(b'"string instead of dict"') - assert result == {} + # Clear expired before first concurrent run + shared_cache.clear_expired() + num_threads = 20 + num_operations = 50 + all_errors = [] -@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) -def test_is_expired(tmpdir, dbm_backend): - """Test _is_expired method.""" - cache_file = os.path.join(tmpdir, f"cache_is_expired_{dbm_backend}") - cache = PersistentCache(cache_file) + def high_contention_worker(worker_id, run_prefix): + """Worker with higher contention - accessing overlapping keys.""" + errors = [] + try: + for i in range(num_operations): + # Use overlapping keys to increase contention + key = f"{run_prefix}_shared_{i % 10}" # Only 10 unique keys per run + value = f"value_{worker_id}_{i}" - # Test with various payloads - assert cache._is_expired({"expires_at": time.time() - 10}) is True - assert cache._is_expired({"expires_at": time.time() + 10}) is False - assert cache._is_expired({}) is False - assert cache._is_expired({"expires_at": "not a number"}) is False - assert cache._is_expired({"expires_at": None}) is False + # Pattern: get a key -> if not exist then set key=value + existing_value = shared_cache.get(key) + if existing_value is None: + shared_cache.set(key, value, expires_in=10) + + # Occasional clear_expired to add more contention + if i % 25 == 0: + shared_cache.clear_expired() + + except Exception as e: + errors.append(f"Worker {worker_id} error: {str(e)}") + + return errors + + # First concurrent run with high contention + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(high_contention_worker, i, "contention_run1") + for i in range(num_threads) + ] + for future in concurrent.futures.as_completed(futures): + errors = future.result() + all_errors.extend(errors) + + # Clear expired between concurrent runs + expired_keys_1 = shared_cache.clear_expired() + assert isinstance(expired_keys_1, list), "clear_expired should return a list" + keys_after_run1 = list(shared_cache.keys()) + + # Second concurrent run with high contention + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(high_contention_worker, i, "contention_run2") + for i in range(num_threads) + ] + for future in concurrent.futures.as_completed(futures): + errors = future.result() + all_errors.extend(errors) + + # Clear expired after second concurrent run + expired_keys_2 = shared_cache.clear_expired() + assert isinstance(expired_keys_2, list), "clear_expired should return a list" + keys_after_run2 = list(shared_cache.keys()) + + # Assertions using keys() and counts + # With overlapping keys, we expect at most 10 keys per run (due to key overlap) + assert ( + len(keys_after_run1) <= 10 + ), f"Expected at most 10 keys after run1, got {len(keys_after_run1)}" + assert ( + len(keys_after_run2) <= 20 + ), f"Expected at most 20 keys after run2, got {len(keys_after_run2)}" + + # Verify key patterns exist + run1_keys = [k for k in keys_after_run2 if b"contention_run1" in k] + run2_keys = [k for k in keys_after_run2 if b"contention_run2" in k] + + assert len(run1_keys) > 0, "Should have some keys from run1" + assert len(run2_keys) > 0, "Should have some keys from run2" + + # Check for any 
worker errors + assert not all_errors, f"Worker errors occurred: {all_errors}" + + +def test_multiprocess_mixed_operations_get_set_pattern(tmpdir): + """Test multiprocess with mixed operations using get->set pattern.""" + cache_file = os.path.join(tmpdir, "cache_multiprocess_mixed") + + # Initialize cache and clear expired before first concurrent run + init_cache = PersistentCache(cache_file) + init_cache.clear_expired() + + num_processes = 8 + num_operations = 25 + + # First concurrent run + with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + args_list = [ + (i, cache_file, num_operations, "mixed_run1") for i in range(num_processes) + ] + futures = [ + executor.submit(_mixed_operations_worker, args) for args in args_list + ] + results_1 = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + # Clear expired between concurrent runs + expired_keys_1 = init_cache.clear_expired() + assert isinstance(expired_keys_1, list), "clear_expired should return a list" + keys_after_run1 = list(init_cache.keys()) + + # Verify first run results + expected_keys_run1 = num_processes * num_operations + assert ( + len(keys_after_run1) == expected_keys_run1 + ), f"Expected {expected_keys_run1} keys after run1, got {len(keys_after_run1)}" + + # Second concurrent run + with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + args_list = [ + (i, cache_file, num_operations, "mixed_run2") for i in range(num_processes) + ] + futures = [ + executor.submit(_mixed_operations_worker, args) for args in args_list + ] + results_2 = [ + future.result() for future in concurrent.futures.as_completed(futures) + ] + + # Clear expired after second concurrent run + expired_keys_2 = init_cache.clear_expired() + assert isinstance(expired_keys_2, list), "clear_expired should return a list" + keys_after_run2 = list(init_cache.keys()) + + # Collect all errors from both runs + all_errors = [] + for result in results_1 + results_2: + all_errors.extend(result) + + # Assertions using keys() and counts + expected_keys_count = num_processes * num_operations * 2 # Two runs + + assert ( + len(keys_after_run2) == expected_keys_count + ), f"Expected {expected_keys_count} keys, got {len(keys_after_run2)}" + + # Verify key patterns + run1_keys = [k for k in keys_after_run2 if b"mixed_run1" in k] + run2_keys = [k for k in keys_after_run2 if b"mixed_run2" in k] + + assert len(run1_keys) == num_processes * num_operations + assert len(run2_keys) == num_processes * num_operations + + # Check for database lock errors specifically + database_lock_errors = [ + e for e in all_errors if "database is locked" in str(e).lower() + ] + if database_lock_errors: + pytest.fail( + f"Database lock errors in mixed operations test: {database_lock_errors}" + ) + + # Check for any other errors + assert not all_errors, f"Mixed operations errors occurred: {all_errors}" From 75356817dac160bd6e069f41ab109f6b4f84cd0a Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Sun, 31 Aug 2025 19:19:13 -0700 Subject: [PATCH 04/10] tests --- tests/unit/test_persistent_cache.py | 644 +++++++++++++++------------- 1 file changed, 350 insertions(+), 294 deletions(-) diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py index 22424fae..baa3bdc8 100644 --- a/tests/unit/test_persistent_cache.py +++ b/tests/unit/test_persistent_cache.py @@ -202,9 +202,9 @@ def test_corrupted_data(tmpdir): cache.clear_expired() -def test_concurrency(tmpdir): - """Test concurrent access to the 
cache - fixed version.""" - cache_file = os.path.join(tmpdir, f"cache_concurrency") +def test_multithread_basic_concurrency(tmpdir): + """Test basic concurrent access to the cache.""" + cache_file = os.path.join(tmpdir, f"cache_basic_concurrency") cache = PersistentCache(cache_file) num_threads = 20 @@ -268,96 +268,43 @@ def _multiprocess_worker(process_id, cache_file, num_ops): return None -def test_multiprocess_database_lock(tmpdir): - """Test multiprocess access that might trigger database lock issues.""" - cache_file = os.path.join(tmpdir, "cache_multiprocess") +def test_multithread_database_lock_detection(tmpdir): + """Comprehensive test for database lock detection across multiple scenarios.""" - # Use multiprocessing to create real process contention + # Phase 1: Multiprocess access that might trigger database lock issues + cache_file_mp = os.path.join(tmpdir, "cache_multiprocess") num_processes = 8 num_operations = 20 with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: futures = [ - executor.submit(_multiprocess_worker, i, cache_file, num_operations) + executor.submit(_multiprocess_worker, i, cache_file_mp, num_operations) for i in range(num_processes) ] results = [ future.result() for future in concurrent.futures.as_completed(futures) ] - # Check for errors + # Check for errors in multiprocess phase errors = [r for r in results if r is not None] database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] if database_lock_errors: pytest.fail( - f"Database lock errors in multiprocess test: {database_lock_errors}" + f"Database lock errors in multiprocess phase: {database_lock_errors}" ) if errors: - pytest.fail(f"Multiprocess errors: {errors}") + pytest.fail(f"Multiprocess phase errors: {errors}") - -def test_rapid_file_creation_database_lock(tmpdir): - """Test rapid database file creation that might trigger lock issues.""" - base_path = os.path.join(tmpdir, f"rapid_creation") - - def rapid_creator(thread_id): - """Create and use cache files rapidly.""" - errors = [] - try: - for i in range(10): - # Create a new cache file for each operation - cache_file = f"{base_path}_{thread_id}_{i}" - cache = PersistentCache(cache_file) - - # Perform operations immediately after creation - cache.set("test_key", f"test_value_{thread_id}_{i}") - result = cache.get("test_key") - - if result != f"test_value_{thread_id}_{i}": - errors.append( - f"Thread {thread_id}, iteration {i}: Expected test_value_{thread_id}_{i}, got {result}" - ) - - except Exception as e: - errors.append(f"Thread {thread_id} error: {str(e)}") - - return errors - - # Run multiple threads creating cache files rapidly - num_threads = 20 - all_errors = [] - - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(rapid_creator, i) for i in range(num_threads)] - for future in concurrent.futures.as_completed(futures): - errors = future.result() - all_errors.extend(errors) - - # Check for database lock errors - database_lock_errors = [ - e for e in all_errors if "database is locked" in str(e).lower() - ] - if database_lock_errors: - pytest.fail( - f"Database lock errors in rapid creation test: {database_lock_errors}" - ) - - if all_errors: - pytest.fail(f"Rapid creation errors: {all_errors[:5]}") # Show first 5 errors - - -def test_simultaneous_database_operations(tmpdir): - """Test simultaneous database operations that might cause locking.""" - cache_file = os.path.join(tmpdir, f"cache_simultaneous") - - errors = [] + # Phase 
2: Simultaneous database operations that might cause locking + cache_file_sim = os.path.join(tmpdir, "cache_simultaneous") + sim_errors = [] def synchronized_worker(thread_id): """Worker that starts operations simultaneously.""" try: - cache = PersistentCache(cache_file) + cache = PersistentCache(cache_file_sim) # All threads perform operations at the same time for i in range(20): @@ -370,35 +317,37 @@ def synchronized_worker(thread_id): # Immediate read back result = cache.get(key) if result != value: - errors.append(f"Thread {thread_id}: Expected {value}, got {result}") + sim_errors.append( + f"Thread {thread_id}: Expected {value}, got {result}" + ) # Mixed operations if i % 2 == 0: cache.clear_expired() except Exception as e: - errors.append(f"Thread {thread_id} error: {str(e)}") + sim_errors.append(f"Thread {thread_id} error: {str(e)}") # Start all threads simultaneously using ThreadPoolExecutor with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(synchronized_worker, i) for i in range(10)] concurrent.futures.wait(futures) - # Check for database lock errors - database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] - if database_lock_errors: + # Check for database lock errors in simultaneous phase + sim_database_lock_errors = [ + e for e in sim_errors if "database is locked" in str(e).lower() + ] + if sim_database_lock_errors: pytest.fail( - f"Database lock errors in simultaneous operations test: {database_lock_errors}" + f"Database lock errors in simultaneous operations phase: {sim_database_lock_errors}" ) - if errors: - pytest.fail(f"Simultaneous operation errors: {errors[:10]}") + if sim_errors: + pytest.fail(f"Simultaneous operation errors: {sim_errors[:10]}") - -def test_stress_database_with_exceptions(tmpdir): - """Stress test that might trigger database lock issues with exception handling.""" - cache_file = os.path.join(tmpdir, f"cache_stress") - cache = PersistentCache(cache_file) + # Phase 3: Stress test with exception handling + cache_file_stress = os.path.join(tmpdir, "cache_stress") + cache = PersistentCache(cache_file_stress) def stress_worker(thread_id): """Worker that performs operations and handles exceptions.""" @@ -427,7 +376,6 @@ def stress_worker(thread_id): pass # Ignore access errors except Exception as e: - # raise e error_msg = str(e).lower() if "database is locked" in error_msg: database_lock_count += 1 @@ -448,33 +396,83 @@ def stress_worker(thread_id): total_lock_errors += lock_count all_other_errors.extend(other_errors) - # Report results + # Report results from stress phase if total_lock_errors > 0: pytest.fail( - f"Database lock errors detected: {total_lock_errors} total lock errors" + f"Database lock errors detected in stress phase: {total_lock_errors} total lock errors" ) if all_other_errors: pytest.fail(f"Other stress test errors: {all_other_errors[:5]}") -def test_shared_cache_instance_database_lock(tmpdir): - """Test shared cache instance across threads - reproduces real uploader.py usage pattern.""" - cache_file = os.path.join(tmpdir, f"cache_shared") +def test_multithread_rapid_file_creation(tmpdir): + """Test rapid database file creation that might trigger lock issues.""" + base_path = os.path.join(tmpdir, f"rapid_creation") + + def rapid_creator(thread_id): + """Create and use cache files rapidly.""" + errors = [] + try: + for i in range(10): + # Create a new cache file for each operation + cache_file = f"{base_path}_{thread_id}_{i}" + cache = PersistentCache(cache_file) - # 
Create a single shared cache instance (like in uploader.py) + # Perform operations immediately after creation + cache.set("test_key", f"test_value_{thread_id}_{i}") + result = cache.get("test_key") + + if result != f"test_value_{thread_id}_{i}": + errors.append( + f"Thread {thread_id}, iteration {i}: Expected test_value_{thread_id}_{i}, got {result}" + ) + + except Exception as e: + errors.append(f"Thread {thread_id} error: {str(e)}") + + return errors + + # Run multiple threads creating cache files rapidly + num_threads = 20 + all_errors = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(rapid_creator, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + errors = future.result() + all_errors.extend(errors) + + # Check for database lock errors + database_lock_errors = [ + e for e in all_errors if "database is locked" in str(e).lower() + ] + if database_lock_errors: + pytest.fail( + f"Database lock errors in rapid creation test: {database_lock_errors}" + ) + + if all_errors: + pytest.fail(f"Rapid creation errors: {all_errors[:5]}") # Show first 5 errors + + +def test_multithread_shared_cache_comprehensive(tmpdir): + """Comprehensive test for shared cache instance across threads with get->set pattern validation.""" + cache_file = os.path.join(tmpdir, "cache_shared_comprehensive") + + # Initialize cache once and share across all workers shared_cache = PersistentCache(cache_file) shared_cache.clear_expired() - # Use higher numbers to increase chance of database lock - num_threads = 30 - num_operations = 100 - errors = [] + # Phase 1: Database lock detection with shared cache instance (reproduces real uploader.py usage) + num_threads_phase1 = 30 + num_operations_phase1 = 100 + phase1_errors = [] def shared_cache_worker(thread_id): """Worker that uses the shared cache instance (like CachedImageUploader.upload).""" try: - for i in range(num_operations): + for i in range(num_operations_phase1): key = f"shared_thread_{thread_id}_op_{i}" value = f"shared_value_{thread_id}_{i}" @@ -488,7 +486,7 @@ def shared_cache_worker(thread_id): retrieved = shared_cache.get(key) if retrieved != value: - errors.append( + phase1_errors.append( f"Thread {thread_id}: Expected {value}, got {retrieved}" ) else: @@ -505,30 +503,124 @@ def shared_cache_worker(thread_id): shared_cache.get(other_key) except Exception as e: - errors.append(f"Thread {thread_id} error: {str(e)}") - print(e, traceback.format_exc()) + phase1_errors.append(f"Thread {thread_id} error: {str(e)}") - # Run all threads using the same shared cache instance - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(shared_cache_worker, i) for i in range(num_threads)] + # Run phase 1: Database lock detection + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads_phase1 + ) as executor: + futures = [ + executor.submit(shared_cache_worker, i) for i in range(num_threads_phase1) + ] concurrent.futures.wait(futures) - # Check for database lock errors - database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] + # Check for database lock errors in phase 1 + database_lock_errors = [ + e for e in phase1_errors if "database is locked" in str(e).lower() + ] if database_lock_errors: pytest.fail( f"Database lock errors with shared cache instance: {database_lock_errors}" ) # Check for data consistency errors (values not persisting correctly) - 
data_consistency_errors = [e for e in errors if "Expected" in e and "got None" in e] + data_consistency_errors = [ + e for e in phase1_errors if "Expected" in e and "got None" in e + ] if data_consistency_errors: pytest.fail( f"Data consistency errors with shared cache (race conditions): {data_consistency_errors[:5]}" ) - if errors: - pytest.fail(f"Other shared cache errors: {errors[:10]}") # Show first 10 errors + if phase1_errors: + pytest.fail(f"Phase 1 shared cache errors: {phase1_errors[:10]}") + + # Phase 2: Get->set pattern validation with cache correctness verification + shared_cache.clear_expired() + + num_threads_phase2 = 10 + num_operations_phase2 = 20 + all_set_values = [] + + # First concurrent run with get->set pattern + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads_phase2 + ) as executor: + futures = [ + executor.submit( + _shared_worker_get_set_pattern, + i, + shared_cache, + num_operations_phase2, + "phase2_run1", + ) + for i in range(num_threads_phase2) + ] + for future in concurrent.futures.as_completed(futures): + set_values = future.result() + all_set_values.extend(set_values) + + # Verify that all set values are cached correctly + for key, expected_value in all_set_values: + cached_value = shared_cache.get(key) + assert ( + cached_value == expected_value + ), f"Phase2 Run1 - Key {key}: expected {expected_value}, got {cached_value}" + + # Clear expired between concurrent runs + expired_keys_1 = shared_cache.clear_expired() + assert isinstance(expired_keys_1, list), "clear_expired should return a list" + + # Second concurrent run with get->set pattern + run2_set_values = [] + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads_phase2 + ) as executor: + futures = [ + executor.submit( + _shared_worker_get_set_pattern, + i, + shared_cache, + num_operations_phase2, + "phase2_run2", + ) + for i in range(num_threads_phase2) + ] + for future in concurrent.futures.as_completed(futures): + set_values = future.result() + run2_set_values.extend(set_values) + + # Verify that all run2 set values are cached correctly + for key, expected_value in run2_set_values: + cached_value = shared_cache.get(key) + assert ( + cached_value == expected_value + ), f"Phase2 Run2 - Key {key}: expected {expected_value}, got {cached_value}" + + # Final validation + expired_keys_2 = shared_cache.clear_expired() + assert isinstance(expired_keys_2, list), "clear_expired should return a list" + + # Assertions using keys() and counts + all_keys = list(shared_cache.keys()) + expected_keys_count = num_threads_phase2 * num_operations_phase2 * 2 # Two runs + + assert ( + len(all_keys) == expected_keys_count + ), f"Expected {expected_keys_count} keys, got {len(all_keys)}" + + # Verify key patterns + run1_keys = [k for k in all_keys if b"phase2_run1" in k] + run2_keys = [k for k in all_keys if b"phase2_run2" in k] + + assert len(run1_keys) == num_threads_phase2 * num_operations_phase2 + assert len(run2_keys) == num_threads_phase2 * num_operations_phase2 + + # Verify that the total number of set operations matches expectations + total_set_operations = len(all_set_values) + len(run2_set_values) + assert ( + total_set_operations == expected_keys_count + ), f"Expected {expected_keys_count} set operations, got {total_set_operations}" def test_decode_invalid_data(tmpdir): @@ -562,28 +654,24 @@ def _shared_worker_get_set_pattern( worker_id, cache, num_operations, key_prefix="worker" ): """Shared worker implementation: get key -> if not exist then set key=value.""" - errors = [] - 
try: - for i in range(num_operations): - key = f"{key_prefix}_{worker_id}_{i}" - value = f"value_{worker_id}_{i}" + set_values = [] + for i in range(num_operations): + key = f"{key_prefix}_{worker_id}_{i}" + value = f"value_{worker_id}_{i}" - # Pattern: get a key -> if not exist then set key=value - existing_value = cache.get(key) - if existing_value is None: - cache.set(key, value, expires_in=10) + # Pattern: get a key -> if not exist then set key=value + existing_value = cache.get(key) + if existing_value is None: + cache.set(key, value, expires_in=10) + set_values.append((key, value)) - # Verify the value was set correctly - retrieved_value = cache.get(key) - if retrieved_value != value: - errors.append( - f"Worker {worker_id}: Expected {value}, got {retrieved_value}" - ) + # Verify the value was set correctly + retrieved_value = cache.get(key) + assert ( + retrieved_value == value + ), f"Worker {worker_id}: Expected {value}, got {retrieved_value}" - except Exception as e: - errors.append(f"Worker {worker_id} error: {str(e)}") - - return errors + return set_values def _multiprocess_worker_get_set_pattern(args): @@ -630,147 +718,198 @@ def _mixed_operations_worker(args): return errors -def test_multithread_shared_cache_get_set_pattern(tmpdir): - """Test multithreaded access with shared cache instance using get->set pattern.""" - cache_file = os.path.join(tmpdir, "cache_multithread_shared") +def test_multiprocess_comprehensive_patterns(tmpdir): + """Comprehensive test for multiprocess access simulating multiple independent applications. - # Initialize cache once and share across all workers - shared_cache = PersistentCache(cache_file) + Each process creates its own PersistentCache instance to simulate running multiple apps + that access the same cache file independently. 
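+
+    Each phase runs two ProcessPoolExecutor rounds of the get-key -> set-if-missing
+    pattern against a single cache file, then re-opens that file with a separate
+    validation PersistentCache instance to confirm that values written by any
+    process are visible and that keys()/clear_expired() report consistent counts.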
+ """ - # Clear expired before first concurrent run - shared_cache.clear_expired() + # Phase 1: Basic multiprocess get->set pattern with cache correctness validation + # Each process will create its own PersistentCache instance (simulating multiple apps) + cache_file_basic = os.path.join(tmpdir, "cache_multiprocess_basic") - num_threads = 10 - num_operations = 20 - all_errors = [] + num_processes_basic = 6 + num_operations_basic = 15 - # First concurrent run - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + # First concurrent run - basic pattern (each process creates its own cache instance) + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_processes_basic + ) as executor: + args_list = [ + (i, cache_file_basic, num_operations_basic, "basic_run1") + for i in range(num_processes_basic) + ] futures = [ - executor.submit( - _shared_worker_get_set_pattern, i, shared_cache, num_operations, "run1" - ) - for i in range(num_threads) + executor.submit(_multiprocess_worker_get_set_pattern, args) + for args in args_list + ] + results_1 = [ + future.result() for future in concurrent.futures.as_completed(futures) ] - for future in concurrent.futures.as_completed(futures): - errors = future.result() - all_errors.extend(errors) - # Clear expired between concurrent runs - expired_keys_1 = shared_cache.clear_expired() + # Collect all set values from first run and verify they are cached correctly + # Create a separate validation cache instance (simulating another app checking the results) + validation_cache_basic = PersistentCache(cache_file_basic) + all_set_values_run1 = [] + for result in results_1: + all_set_values_run1.extend(result) + + # Verify that all set values from run1 are cached correctly + for key, expected_value in all_set_values_run1: + cached_value = validation_cache_basic.get(key) + assert ( + cached_value == expected_value + ), f"Basic Run1 - Key {key}: expected {expected_value}, got {cached_value}" + + # Clear expired between concurrent runs (using validation cache instance) + expired_keys_1 = validation_cache_basic.clear_expired() assert isinstance(expired_keys_1, list), "clear_expired should return a list" - # Second concurrent run - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + # Second concurrent run - basic pattern (each process creates its own cache instance) + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_processes_basic + ) as executor: + args_list = [ + (i, cache_file_basic, num_operations_basic, "basic_run2") + for i in range(num_processes_basic) + ] futures = [ - executor.submit( - _shared_worker_get_set_pattern, i, shared_cache, num_operations, "run2" - ) - for i in range(num_threads) + executor.submit(_multiprocess_worker_get_set_pattern, args) + for args in args_list + ] + results_2 = [ + future.result() for future in concurrent.futures.as_completed(futures) ] - for future in concurrent.futures.as_completed(futures): - errors = future.result() - all_errors.extend(errors) - # Clear expired after second concurrent run - expired_keys_2 = shared_cache.clear_expired() - assert isinstance(expired_keys_2, list), "clear_expired should return a list" + # Collect all set values from second run and verify they are cached correctly + all_set_values_run2 = [] + for result in results_2: + all_set_values_run2.extend(result) - # Assertions using keys() and counts - all_keys = list(shared_cache.keys()) - expected_keys_count = num_threads * num_operations * 2 # Two runs + # Verify that all 
set values from run2 are cached correctly + for key, expected_value in all_set_values_run2: + cached_value = validation_cache_basic.get(key) + assert ( + cached_value == expected_value + ), f"Basic Run2 - Key {key}: expected {expected_value}, got {cached_value}" - assert ( - len(all_keys) == expected_keys_count - ), f"Expected {expected_keys_count} keys, got {len(all_keys)}" + # Validate basic pattern results + expired_keys_2 = validation_cache_basic.clear_expired() + assert isinstance(expired_keys_2, list), "clear_expired should return a list" - # Verify key patterns - run1_keys = [k for k in all_keys if b"run1" in k] - run2_keys = [k for k in all_keys if b"run2" in k] + all_keys_basic = list(validation_cache_basic.keys()) + expected_keys_count_basic = ( + num_processes_basic * num_operations_basic * 2 + ) # Two runs - assert len(run1_keys) == num_threads * num_operations - assert len(run2_keys) == num_threads * num_operations + assert ( + len(all_keys_basic) == expected_keys_count_basic + ), f"Expected {expected_keys_count_basic} keys, got {len(all_keys_basic)}" - # Check for any worker errors - assert not all_errors, f"Worker errors occurred: {all_errors}" + # Verify key patterns for basic test + run1_keys = [k for k in all_keys_basic if b"basic_run1" in k] + run2_keys = [k for k in all_keys_basic if b"basic_run2" in k] + assert len(run1_keys) == num_processes_basic * num_operations_basic + assert len(run2_keys) == num_processes_basic * num_operations_basic -def test_multiprocess_shared_file_get_set_pattern(tmpdir): - """Test multiprocess access with shared cache file using get->set pattern.""" - cache_file = os.path.join(tmpdir, "cache_multiprocess_shared") + # Verify that the total number of set operations matches expectations + total_set_operations = len(all_set_values_run1) + len(all_set_values_run2) + assert ( + total_set_operations == expected_keys_count_basic + ), f"Expected {expected_keys_count_basic} set operations, got {total_set_operations}" - # Initialize cache and clear expired before first concurrent run - init_cache = PersistentCache(cache_file) - init_cache.clear_expired() + # Phase 2: Mixed operations pattern with cross-process reads and cleanup + # Each process will create its own PersistentCache instance (simulating multiple apps) + cache_file_mixed = os.path.join(tmpdir, "cache_multiprocess_mixed") - num_processes = 6 - num_operations = 15 + num_processes_mixed = 8 + num_operations_mixed = 25 - # First concurrent run - with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + # First concurrent run - mixed operations (each process creates its own cache instance) + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_processes_mixed + ) as executor: args_list = [ - (i, cache_file, num_operations, "proc_run1") for i in range(num_processes) + (i, cache_file_mixed, num_operations_mixed, "mixed_run1") + for i in range(num_processes_mixed) ] futures = [ - executor.submit(_multiprocess_worker_get_set_pattern, args) - for args in args_list + executor.submit(_mixed_operations_worker, args) for args in args_list ] - results_1 = [ + results_mixed_1 = [ future.result() for future in concurrent.futures.as_completed(futures) ] + # Create a separate validation cache instance for mixed operations + validation_cache_mixed = PersistentCache(cache_file_mixed) + # Clear expired between concurrent runs - expired_keys_1 = init_cache.clear_expired() - assert isinstance(expired_keys_1, list), "clear_expired should return a list" + expired_keys_mixed_1 = 
validation_cache_mixed.clear_expired() + assert isinstance(expired_keys_mixed_1, list), "clear_expired should return a list" + keys_after_mixed_run1 = list(validation_cache_mixed.keys()) - # Second concurrent run - with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + # Verify first mixed run results + expected_keys_mixed_run1 = num_processes_mixed * num_operations_mixed + assert ( + len(keys_after_mixed_run1) == expected_keys_mixed_run1 + ), f"Expected {expected_keys_mixed_run1} keys after mixed run1, got {len(keys_after_mixed_run1)}" + + # Second concurrent run - mixed operations (each process creates its own cache instance) + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_processes_mixed + ) as executor: args_list = [ - (i, cache_file, num_operations, "proc_run2") for i in range(num_processes) + (i, cache_file_mixed, num_operations_mixed, "mixed_run2") + for i in range(num_processes_mixed) ] futures = [ - executor.submit(_multiprocess_worker_get_set_pattern, args) - for args in args_list + executor.submit(_mixed_operations_worker, args) for args in args_list ] - results_2 = [ + results_mixed_2 = [ future.result() for future in concurrent.futures.as_completed(futures) ] - # Clear expired after second concurrent run - expired_keys_2 = init_cache.clear_expired() - assert isinstance(expired_keys_2, list), "clear_expired should return a list" + # Validate mixed operations results + expired_keys_mixed_2 = validation_cache_mixed.clear_expired() + assert isinstance(expired_keys_mixed_2, list), "clear_expired should return a list" + keys_after_mixed_run2 = list(validation_cache_mixed.keys()) - # Collect all errors from both runs - all_errors = [] - for result in results_1 + results_2: - all_errors.extend(result) + # Collect all errors from both mixed runs + all_mixed_errors = [] + for result in results_mixed_1 + results_mixed_2: + all_mixed_errors.extend(result) - # Assertions using keys() and counts - all_keys = list(init_cache.keys()) - expected_keys_count = num_processes * num_operations * 2 # Two runs + # Assertions for mixed operations + expected_keys_count_mixed = ( + num_processes_mixed * num_operations_mixed * 2 + ) # Two runs assert ( - len(all_keys) == expected_keys_count - ), f"Expected {expected_keys_count} keys, got {len(all_keys)}" + len(keys_after_mixed_run2) == expected_keys_count_mixed + ), f"Expected {expected_keys_count_mixed} keys, got {len(keys_after_mixed_run2)}" - # Verify key patterns - run1_keys = [k for k in all_keys if b"proc_run1" in k] - run2_keys = [k for k in all_keys if b"proc_run2" in k] + # Verify key patterns for mixed operations + mixed_run1_keys = [k for k in keys_after_mixed_run2 if b"mixed_run1" in k] + mixed_run2_keys = [k for k in keys_after_mixed_run2 if b"mixed_run2" in k] - assert len(run1_keys) == num_processes * num_operations - assert len(run2_keys) == num_processes * num_operations + assert len(mixed_run1_keys) == num_processes_mixed * num_operations_mixed + assert len(mixed_run2_keys) == num_processes_mixed * num_operations_mixed # Check for database lock errors specifically database_lock_errors = [ - e for e in all_errors if "database is locked" in str(e).lower() + e for e in all_mixed_errors if "database is locked" in str(e).lower() ] if database_lock_errors: pytest.fail( - f"Database lock errors in multiprocess test: {database_lock_errors}" + f"Database lock errors in multiprocess comprehensive test: {database_lock_errors}" ) # Check for any other errors - assert not all_errors, f"Process errors 
occurred: {all_errors}" + assert ( + not all_mixed_errors + ), f"Multiprocess comprehensive errors occurred: {all_mixed_errors}" def test_multithread_high_contention_get_set_pattern(tmpdir): @@ -858,86 +997,3 @@ def high_contention_worker(worker_id, run_prefix): # Check for any worker errors assert not all_errors, f"Worker errors occurred: {all_errors}" - - -def test_multiprocess_mixed_operations_get_set_pattern(tmpdir): - """Test multiprocess with mixed operations using get->set pattern.""" - cache_file = os.path.join(tmpdir, "cache_multiprocess_mixed") - - # Initialize cache and clear expired before first concurrent run - init_cache = PersistentCache(cache_file) - init_cache.clear_expired() - - num_processes = 8 - num_operations = 25 - - # First concurrent run - with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: - args_list = [ - (i, cache_file, num_operations, "mixed_run1") for i in range(num_processes) - ] - futures = [ - executor.submit(_mixed_operations_worker, args) for args in args_list - ] - results_1 = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Clear expired between concurrent runs - expired_keys_1 = init_cache.clear_expired() - assert isinstance(expired_keys_1, list), "clear_expired should return a list" - keys_after_run1 = list(init_cache.keys()) - - # Verify first run results - expected_keys_run1 = num_processes * num_operations - assert ( - len(keys_after_run1) == expected_keys_run1 - ), f"Expected {expected_keys_run1} keys after run1, got {len(keys_after_run1)}" - - # Second concurrent run - with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: - args_list = [ - (i, cache_file, num_operations, "mixed_run2") for i in range(num_processes) - ] - futures = [ - executor.submit(_mixed_operations_worker, args) for args in args_list - ] - results_2 = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Clear expired after second concurrent run - expired_keys_2 = init_cache.clear_expired() - assert isinstance(expired_keys_2, list), "clear_expired should return a list" - keys_after_run2 = list(init_cache.keys()) - - # Collect all errors from both runs - all_errors = [] - for result in results_1 + results_2: - all_errors.extend(result) - - # Assertions using keys() and counts - expected_keys_count = num_processes * num_operations * 2 # Two runs - - assert ( - len(keys_after_run2) == expected_keys_count - ), f"Expected {expected_keys_count} keys, got {len(keys_after_run2)}" - - # Verify key patterns - run1_keys = [k for k in keys_after_run2 if b"mixed_run1" in k] - run2_keys = [k for k in keys_after_run2 if b"mixed_run2" in k] - - assert len(run1_keys) == num_processes * num_operations - assert len(run2_keys) == num_processes * num_operations - - # Check for database lock errors specifically - database_lock_errors = [ - e for e in all_errors if "database is locked" in str(e).lower() - ] - if database_lock_errors: - pytest.fail( - f"Database lock errors in mixed operations test: {database_lock_errors}" - ) - - # Check for any other errors - assert not all_errors, f"Mixed operations errors occurred: {all_errors}" From 3eb126ea18d5b19f9cbdd9ec10f06701bcdf2beb Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Mon, 1 Sep 2025 21:32:29 -0700 Subject: [PATCH 05/10] add store.py --- mapillary_tools/store.py | 112 +++++++++++++ tests/unit/test_store.py | 345 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 457 insertions(+) create mode 100644 
mapillary_tools/store.py create mode 100644 tests/unit/test_store.py diff --git a/mapillary_tools/store.py b/mapillary_tools/store.py new file mode 100644 index 00000000..c75b560a --- /dev/null +++ b/mapillary_tools/store.py @@ -0,0 +1,112 @@ +import os +import sqlite3 +from collections.abc import MutableMapping +from contextlib import closing, suppress +from pathlib import Path + +BUILD_TABLE = """ + CREATE TABLE IF NOT EXISTS Dict ( + key BLOB UNIQUE NOT NULL, + value BLOB NOT NULL + ) +""" +GET_SIZE = "SELECT COUNT (key) FROM Dict" +LOOKUP_KEY = "SELECT value FROM Dict WHERE key = CAST(? AS BLOB)" +STORE_KV = "REPLACE INTO Dict (key, value) VALUES (CAST(? AS BLOB), CAST(? AS BLOB))" +DELETE_KEY = "DELETE FROM Dict WHERE key = CAST(? AS BLOB)" +ITER_KEYS = "SELECT key FROM Dict" + + +_ERR_CLOSED = "KeyValueStore object has already been closed" + + +def _normalize_uri(path): + path = Path(path) + uri = path.absolute().as_uri() + while "//" in uri: + uri = uri.replace("//", "/") + return uri + + +class KeyValueStore(MutableMapping): + def __init__(self, path, /, *, flag="r", mode=0o666): + """Open a key-value database and return the object. + + The 'path' parameter is the name of the database file. + + The optional 'flag' parameter can be one of ...: + 'r' (default): open an existing database for read only access + 'w': open an existing database for read/write access + 'c': create a database if it does not exist; open for read/write access + 'n': always create a new, empty database; open for read/write access + + The optional 'mode' parameter is the Unix file access mode of the database; + only used when creating a new database. Default: 0o666. + """ + path = os.fsdecode(path) + if flag == "r": + flag = "ro" + elif flag == "w": + flag = "rw" + elif flag == "c": + flag = "rwc" + Path(path).touch(mode=mode, exist_ok=True) + elif flag == "n": + flag = "rwc" + Path(path).unlink(missing_ok=True) + Path(path).touch(mode=mode) + else: + raise ValueError(f"Flag must be one of 'r', 'w', 'c', or 'n', not {flag!r}") + + # We use the URI format when opening the database. + uri = _normalize_uri(path) + uri = f"{uri}?mode={flag}" + + self._cx = sqlite3.connect(uri, autocommit=True, uri=True) + + # This is an optimization only; it's ok if it fails. 
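+        # WAL mode is attempted because it lets readers proceed while a single
+        # writer holds the lock, which suits the concurrent access patterns
+        # exercised by the tests; it can legitimately fail (for example on a
+        # read-only open), hence the suppress() around the pragma.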
+ with suppress(sqlite3.OperationalError): + self._cx.execute("PRAGMA journal_mode = wal") + + if flag == "rwc": + self._execute(BUILD_TABLE) + + def _execute(self, *args, **kwargs): + return closing(self._cx.execute(*args, **kwargs)) + + def __len__(self): + with self._execute(GET_SIZE) as cu: + row = cu.fetchone() + return row[0] + + def __getitem__(self, key): + with self._execute(LOOKUP_KEY, (key,)) as cu: + row = cu.fetchone() + if not row: + raise KeyError(key) + return row[0] + + def __setitem__(self, key, value): + self._execute(STORE_KV, (key, value)) + + def __delitem__(self, key): + with self._execute(DELETE_KEY, (key,)) as cu: + if not cu.rowcount: + raise KeyError(key) + + def __iter__(self): + with self._execute(ITER_KEYS) as cu: + for row in cu: + yield row[0] + + def close(self): + self._cx.close() + + def keys(self): + return list(super().keys()) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() diff --git a/tests/unit/test_store.py b/tests/unit/test_store.py new file mode 100644 index 00000000..a4a68643 --- /dev/null +++ b/tests/unit/test_store.py @@ -0,0 +1,345 @@ +import os +import sqlite3 +import tempfile + +import pytest + +from mapillary_tools.store import KeyValueStore + + +def test_basic_dict_operations(tmpdir): + """Test that KeyValueStore behaves like a dict for basic operations.""" + db_path = tmpdir.join("test.db") + + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + # Test setting and getting - strings are stored as bytes + store["key1"] = "value1" + store["key2"] = "value2" + + assert store["key1"] == b"value1" + assert store["key2"] == b"value2" + + # Test length + assert len(store) == 2 + + # Test iteration - keys come back as bytes + keys = list(store) + assert set(keys) == {b"key1", b"key2"} + + # Test keys() method + assert set(store.keys()) == {b"key1", b"key2"} + + # Test deletion + del store["key1"] + assert len(store) == 1 + assert "key1" not in store + assert "key2" in store + + +def test_keyerror_on_missing_key(tmpdir): + """Test that KeyError is raised for missing keys.""" + db_path = tmpdir.join("test.db") + + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + with pytest.raises(KeyError): + _ = store["nonexistent"] + + with pytest.raises(KeyError): + del store["nonexistent"] + + +def test_flag_modes(tmpdir): + """Test different flag modes.""" + db_path = tmpdir.join("test.db") + + # Test 'n' flag - creates new database + with KeyValueStore(str(db_path), flag="n", mode=0o666) as store: + store["key"] = "value" + assert store["key"] == b"value" + + # Test 'c' flag - opens existing or creates new + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + assert store["key"] == b"value" # Should still exist + store["key2"] = "value2" + + # Test 'w' flag - opens existing for read/write + with KeyValueStore(str(db_path), flag="w", mode=0o666) as store: + assert store["key"] == b"value" + assert store["key2"] == b"value2" + store["key3"] = "value3" + + # Test 'r' flag - read-only mode + with KeyValueStore(str(db_path), flag="r", mode=0o666) as store: + # Getter should work in readonly mode + assert store["key"] == b"value" + assert store["key2"] == b"value2" + assert store["key3"] == b"value3" + assert len(store) == 3 + + # Iteration should work in readonly mode + keys = set(store.keys()) + assert keys == {b"key", b"key2", b"key3"} + + # But setter should fail + with pytest.raises(sqlite3.OperationalError): + store["new_key"] = "new_value" + + +def 
test_readonly_mode_comprehensive(tmpdir): + """Test comprehensive readonly mode functionality.""" + db_path = tmpdir.join("readonly_test.db") + + # First, create and populate the database + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + store["test_key"] = "test_value" + store["another_key"] = "another_value" + store[b"binary_key"] = b"binary_value" + store[123] = "numeric_key_value" + + # Now open in readonly mode and test all read operations + with KeyValueStore(str(db_path), flag="r", mode=0o666) as readonly_store: + # Test basic getitem - strings come back as bytes + assert readonly_store["test_key"] == b"test_value" + assert readonly_store["another_key"] == b"another_value" + assert readonly_store[b"binary_key"] == b"binary_value" + assert readonly_store[123] == b"numeric_key_value" + + # Test len + assert len(readonly_store) == 4 + + # Test iteration - keys come back as bytes + all_keys = set(readonly_store) + assert all_keys == {b"test_key", b"another_key", b"binary_key", b"123"} + + # Test keys() method + assert set(readonly_store.keys()) == all_keys + + # Test containment + assert "test_key" in readonly_store + assert "nonexistent" not in readonly_store + + # Test that write operations fail + with pytest.raises( + sqlite3.OperationalError, match="attempt to write a readonly database" + ): + readonly_store["new_key"] = "should_fail" + + with pytest.raises( + sqlite3.OperationalError, match="attempt to write a readonly database" + ): + del readonly_store["test_key"] + + +def test_invalid_flag(): + """Test that invalid flags raise ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + + with pytest.raises(ValueError, match="Flag must be one of"): + KeyValueStore(db_path, flag="x", mode=0o666) + + +def test_context_manager(tmpdir): + """Test context manager functionality.""" + db_path = tmpdir.join("context_test.db") + + # Test normal context manager usage + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + store["key"] = "value" + assert store["key"] == b"value" + + # After exiting context, store should be closed + with pytest.raises( + sqlite3.ProgrammingError, match="Cannot operate on a closed database" + ): + _ = store["key"] + + +def test_manual_close(tmpdir): + """Test manual close functionality.""" + db_path = tmpdir.join("close_test.db") + + store = KeyValueStore(str(db_path), flag="c", mode=0o666) + store["key"] = "value" + assert store["key"] == b"value" + + store.close() + + # After closing, operations should fail + with pytest.raises( + sqlite3.ProgrammingError, match="Cannot operate on a closed database" + ): + _ = store["key"] + + with pytest.raises( + sqlite3.ProgrammingError, match="Cannot operate on a closed database" + ): + store["new_key"] = "new_value" + + +def test_open_function(tmpdir): + """Test the open() function.""" + db_path = tmpdir.join("open_test.db") + + # Test default parameters + with KeyValueStore(str(db_path), flag="c") as store: + store["key"] = "value" + assert store["key"] == b"value" + + # Test with explicit parameters + with KeyValueStore(str(db_path), flag="w", mode=0o644) as store: + assert store["key"] == b"value" + store["key2"] = "value2" + + # Test readonly + with KeyValueStore(str(db_path), flag="r") as store: + assert store["key"] == b"value" + assert store["key2"] == b"value2" + + +def test_binary_and_various_key_types(tmpdir): + """Test that the store can handle various key and value types.""" + db_path = tmpdir.join("types_test.db") + + with 
KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + # String keys and values + store["string_key"] = "string_value" + + # Binary keys and values + store[b"binary_key"] = b"binary_value" + + # Numeric keys + store[123] = "numeric_key" + store["numeric_value"] = 456 + + # Mixed types - use different key names to avoid collision + store[b"mixed_binary"] = "string_value_for_binary_key" + store["mixed_string"] = b"binary_value_for_string_key" + + # Verify all work - strings come back as bytes + assert store["string_key"] == b"string_value" + assert store[b"binary_key"] == b"binary_value" + assert store[123] == b"numeric_key" + assert store["numeric_value"] == b"456" + assert store[b"mixed_binary"] == b"string_value_for_binary_key" + assert store["mixed_string"] == b"binary_value_for_string_key" + + +def test_overwrite_behavior(tmpdir): + """Test that values can be overwritten.""" + db_path = tmpdir.join("overwrite_test.db") + + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + store["key"] = "original_value" + assert store["key"] == b"original_value" + + store["key"] = "new_value" + assert store["key"] == b"new_value" + + # Length should still be 1 + assert len(store) == 1 + + +def test_empty_store(tmpdir): + """Test behavior with empty store.""" + db_path = tmpdir.join("empty_test.db") + + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + assert len(store) == 0 + assert list(store) == [] + assert store.keys() == [] + + with pytest.raises(KeyError): + _ = store["any_key"] + + +def test_new_flag_overwrites_existing(tmpdir): + """Test that 'n' flag creates a new database, overwriting existing.""" + db_path = tmpdir.join("new_flag_test.db") + + # Create initial database + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + store["key"] = "original" + + # Open with 'n' flag should create new database + with KeyValueStore(str(db_path), flag="n", mode=0o666) as store: + assert len(store) == 0 # Should be empty + store["key"] = "new" + + # Verify the old data is gone + with KeyValueStore(str(db_path), flag="r", mode=0o666) as store: + assert store["key"] == b"new" + assert len(store) == 1 + + +def test_persistence_across_sessions(tmpdir): + """Test that data persists across different sessions.""" + db_path = tmpdir.join("persistence_test.db") + + # First session + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + store["persistent_key"] = "persistent_value" + store["another_key"] = "another_value" + + # Second session + with KeyValueStore(str(db_path), flag="r", mode=0o666) as store: + assert store["persistent_key"] == b"persistent_value" + assert store["another_key"] == b"another_value" + assert len(store) == 2 + + +def test_dict_like_membership(tmpdir): + """Test membership operations work like dict.""" + db_path = tmpdir.join("membership_test.db") + + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + store["exists"] = "value" + + assert "exists" in store + assert "does_not_exist" not in store + + # Test with different key types + store[123] = "numeric" + store[b"binary"] = "binary_value" + + assert 123 in store + assert b"binary" in store + assert 456 not in store + assert b"not_there" not in store + + +def test_readonly_mode_nonexistent_file(tmpdir): + """Test readonly mode with a non-existent file.""" + db_path = tmpdir.join("nonexistent.db") + + # Try to open a non-existent file in readonly mode + # This should raise an error since the file doesn't exist + with pytest.raises(sqlite3.OperationalError, 
match="unable to open database file"): + KeyValueStore(str(db_path), flag="r", mode=0o666) + + +def test_readonly_mode_empty_database(tmpdir): + """Test readonly mode with an empty database.""" + db_path = tmpdir.join("empty_readonly.db") + + # First create an empty database + with KeyValueStore(str(db_path), flag="c", mode=0o666) as store: + pass # Create empty database + + # Now open it in readonly mode + with KeyValueStore(str(db_path), flag="r", mode=0o666) as readonly_store: + # Getting a non-existent value should raise KeyError + with pytest.raises(KeyError): + _ = readonly_store["nonexistent_key"] + + # Length should be 0 + assert len(readonly_store) == 0 + + # Keys should be empty + assert list(readonly_store.keys()) == [] + + # Iteration should be empty + assert list(readonly_store) == [] + + # Membership test should return False + assert "any_key" not in readonly_store From 3e0c4fec89c8af618da070e1403d42528a21ca59 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Mon, 1 Sep 2025 22:44:19 -0700 Subject: [PATCH 06/10] use key value store --- mapillary_tools/history.py | 140 +++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 59 deletions(-) diff --git a/mapillary_tools/history.py b/mapillary_tools/history.py index a0cf1311..4436bb57 100644 --- a/mapillary_tools/history.py +++ b/mapillary_tools/history.py @@ -1,24 +1,17 @@ from __future__ import annotations -import contextlib -import dbm import json import logging +import os + +import sqlite3 import string import threading import time import typing as T from pathlib import Path -# dbm modules are dynamically imported, so here we explicitly import dbm.sqlite3 to make sure pyinstaller include it -# Otherwise you will see: ImportError: no dbm clone found; tried ['dbm.sqlite3', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb'] -try: - import dbm.sqlite3 # type: ignore -except ImportError: - pass - - -from . import constants, types +from . 
import constants, store, types from .serializer.description import DescriptionJSONSerializer JSONDict = T.Dict[str, T.Union[str, int, float, None]] @@ -86,102 +79,131 @@ def write_history( class PersistentCache: - _lock: contextlib.nullcontext | threading.Lock + _lock: threading.Lock def __init__(self, file: str): - # SQLite3 backend supports concurrent access without a lock - if dbm.whichdb(file) == "dbm.sqlite3": - self._lock = contextlib.nullcontext() - else: - self._lock = threading.Lock() self._file = file + self._lock = threading.Lock() def get(self, key: str) -> str | None: + if not self._does_db_exist(): + return None + s = time.perf_counter() - with self._lock: - with dbm.open(self._file, flag="c") as db: - value: bytes | None = db.get(key) + with store.KeyValueStore(self._file, flag="r") as db: + try: + raw_payload: bytes | None = db.get(key) # data retrieved from db[key] + except Exception as ex: + if self._does_table_exist(ex): + return None + raise ex - if value is None: + if raw_payload is None: return None - payload = self._decode(value) + data: JSONDict = self._decode(raw_payload) # JSON dict decoded from db[key] - if self._is_expired(payload): + if self._is_expired(data): return None - file_handle = payload.get("file_handle") + cached_value = data.get("value") # value in the JSON dict decoded from db[key] LOG.debug( f"Found file handle for {key} in cache ({(time.perf_counter() - s) * 1000:.0f} ms)" ) - return T.cast(str, file_handle) + return T.cast(str, cached_value) - def set(self, key: str, file_handle: str, expires_in: int = 3600 * 24 * 2) -> None: + def set(self, key: str, value: str, expires_in: int = 3600 * 24 * 2) -> None: s = time.perf_counter() - payload = { + data = { "expires_at": time.time() + expires_in, - "file_handle": file_handle, + "value": value, } - value: bytes = json.dumps(payload).encode("utf-8") - - with self._lock: - with dbm.open(self._file, flag="c") as db: - db[key] = value + payload: bytes = json.dumps(data).encode("utf-8") + + while True: + try: + with self._lock: + with store.KeyValueStore(self._file, flag="c") as db: + # Assume db exists + db[key] = payload + except sqlite3.OperationalError as ex: + if "database is locked" in str(ex).lower(): + LOG.warning( + f"{str(ex)}: {self._file} (are you running multiple instances?)" + ) + LOG.info("Retrying in 1 second...") + time.sleep(1) + continue + else: + raise ex + else: + break LOG.debug( f"Cached file handle for {key} ({(time.perf_counter() - s) * 1000:.0f} ms)" ) def clear_expired(self) -> list[str]: - s = time.perf_counter() - expired_keys: list[str] = [] - with self._lock: - with dbm.open(self._file, flag="c") as db: - if hasattr(db, "items"): - items: T.Iterable[tuple[str | bytes, bytes]] = db.items() - else: - items = ((key, db[key]) for key in db.keys()) + s = time.perf_counter() - for key, value in items: - payload = self._decode(value) - if self._is_expired(payload): + with self._lock: + with store.KeyValueStore(self._file, flag="c") as db: + # Assume db and table exist here + for key, raw_payload in db.items(): + data = self._decode(raw_payload) + if self._is_expired(data): del db[key] expired_keys.append(T.cast(str, key)) - if expired_keys: - LOG.debug( - f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)" - ) + LOG.debug( + f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)" + ) return expired_keys - def keys(self): - with self._lock: - with dbm.open(self._file, flag="c") as 
db: - return db.keys() + def keys(self) -> list[str]: + if not self._does_db_exist(): + return [] - def _is_expired(self, payload: JSONDict) -> bool: - expires_at = payload.get("expires_at") + try: + with store.KeyValueStore(self._file, flag="r") as db: + return [key.decode("utf-8") for key in db.keys()] + except Exception as ex: + if self._does_table_exist(ex): + return [] + raise ex + + def _is_expired(self, data: JSONDict) -> bool: + expires_at = data.get("expires_at") if isinstance(expires_at, (int, float)): return expires_at is None or expires_at <= time.time() return False - def _decode(self, value: bytes) -> JSONDict: + def _decode(self, raw_payload: bytes) -> JSONDict: try: - payload = json.loads(value.decode("utf-8")) + data = json.loads(raw_payload.decode("utf-8")) except json.JSONDecodeError as ex: LOG.warning(f"Failed to decode cache value: {ex}") return {} - if not isinstance(payload, dict): - LOG.warning(f"Invalid cache value format: {payload}") + if not isinstance(data, dict): + LOG.warning(f"Invalid cache value format: {raw_payload}") return {} - return payload + return data + + def _does_db_exist(self) -> bool: + return os.path.exists(self._file) + + def _does_table_exist(self, ex: Exception) -> bool: + if isinstance(ex, sqlite3.OperationalError): + if "no such table" in str(ex): + return True + return False From a14ad66ef8d9e53a40c89aa1ec9bf304e03cd515 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Mon, 1 Sep 2025 22:46:42 -0700 Subject: [PATCH 07/10] tests --- tests/unit/test_persistent_cache.py | 861 ++++------------------------ 1 file changed, 99 insertions(+), 762 deletions(-) diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py index baa3bdc8..0d4463ee 100644 --- a/tests/unit/test_persistent_cache.py +++ b/tests/unit/test_persistent_cache.py @@ -1,7 +1,6 @@ import concurrent.futures import os import time -import traceback import pytest @@ -180,21 +179,13 @@ def test_clear_expired_empty(tmpdir): def test_corrupted_data(tmpdir): - """Test handling of corrupted data.""" + """Test handling of corrupted data through public interface.""" cache_file = os.path.join(tmpdir, f"cache_corrupted") cache = PersistentCache(cache_file) # Set valid entry cache.set("key1", "value1") - # Test the _decode method directly with corrupted data to simulate corruption - # This tests the error handling without directly manipulating the database - corrupted_result = cache._decode(b"not valid json") - assert corrupted_result == {} - - corrupted_dict_result = cache._decode(b'"not a dict"') - assert corrupted_dict_result == {} - # Valid entries should still work assert cache.get("key1") == "value1" @@ -202,798 +193,144 @@ def test_corrupted_data(tmpdir): cache.clear_expired() -def test_multithread_basic_concurrency(tmpdir): - """Test basic concurrent access to the cache.""" - cache_file = os.path.join(tmpdir, f"cache_basic_concurrency") - - cache = PersistentCache(cache_file) - num_threads = 20 - num_operations = 10 - - results = [] # Store assertion failures for pytest to check after threads complete - - def worker(thread_id): - # Fixed: Don't overwrite thread_id parameter - for i in range(num_operations): - key = f"key_{thread_id}_{i}" - value = f"value_{thread_id}_{i}" - if cache.get(key) is None: - cache.set(key, value) - # Occasionally read a previously written value - if i > 0 and i % 2 == 0: - prev_key = f"key_{thread_id}_{i - 1}" - prev_value = cache.get(prev_key) - if prev_value != f"value_{thread_id}_{i - 1}": - results.append( - f"Expected {prev_key} to 
be value_{thread_id}_{i - 1}, got {prev_value}" - ) - - with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: - list(executor.map(worker, range(num_threads))) - - # Check for any failures in threads - assert not results, f"Thread assertions failed: {results}" - - -# Global function for multiprocessing (needed for pickling) -def _multiprocess_worker(process_id, cache_file, num_ops): - """Worker function for multiprocessing test.""" - try: - # Each process creates its own cache instance - cache = PersistentCache(cache_file) - - for i in range(num_ops): - key = f"proc_{process_id}_op_{i}" - value = f"value_{process_id}_{i}" - - # Rapid operations that might cause database locking - cache.set(key, value, expires_in=2) - retrieved = cache.get(key) - - if retrieved != value: - return f"Process {process_id}: Expected {value}, got {retrieved}" - - # Operations that might cause contention - if i % 3 == 0: - cache.clear_expired() - - # Try to read from other processes - if i % 5 == 0 and process_id > 0: - other_key = f"proc_{process_id - 1}_op_{i}" - cache.get(other_key) - - except Exception as e: - return f"Process {process_id} error: {str(e)} - {traceback.format_exc()}" - - return None - - -def test_multithread_database_lock_detection(tmpdir): - """Comprehensive test for database lock detection across multiple scenarios.""" - - # Phase 1: Multiprocess access that might trigger database lock issues - cache_file_mp = os.path.join(tmpdir, "cache_multiprocess") - num_processes = 8 - num_operations = 20 - - with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: - futures = [ - executor.submit(_multiprocess_worker, i, cache_file_mp, num_operations) - for i in range(num_processes) - ] - results = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Check for errors in multiprocess phase - errors = [r for r in results if r is not None] - database_lock_errors = [e for e in errors if "database is locked" in str(e).lower()] - - if database_lock_errors: - pytest.fail( - f"Database lock errors in multiprocess phase: {database_lock_errors}" - ) - - if errors: - pytest.fail(f"Multiprocess phase errors: {errors}") - - # Phase 2: Simultaneous database operations that might cause locking - cache_file_sim = os.path.join(tmpdir, "cache_simultaneous") - sim_errors = [] - - def synchronized_worker(thread_id): - """Worker that starts operations simultaneously.""" - try: - cache = PersistentCache(cache_file_sim) - - # All threads perform operations at the same time - for i in range(20): - key = f"sync_{thread_id}_{i}" - value = f"value_{thread_id}_{i}" - - # Simultaneous write operations - cache.set(key, value) - - # Immediate read back - result = cache.get(key) - if result != value: - sim_errors.append( - f"Thread {thread_id}: Expected {value}, got {result}" - ) - - # Mixed operations - if i % 2 == 0: - cache.clear_expired() - - except Exception as e: - sim_errors.append(f"Thread {thread_id} error: {str(e)}") - - # Start all threads simultaneously using ThreadPoolExecutor - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - futures = [executor.submit(synchronized_worker, i) for i in range(10)] - concurrent.futures.wait(futures) - - # Check for database lock errors in simultaneous phase - sim_database_lock_errors = [ - e for e in sim_errors if "database is locked" in str(e).lower() - ] - if sim_database_lock_errors: - pytest.fail( - f"Database lock errors in simultaneous operations phase: {sim_database_lock_errors}" - ) 
- - if sim_errors: - pytest.fail(f"Simultaneous operation errors: {sim_errors[:10]}") - - # Phase 3: Stress test with exception handling - cache_file_stress = os.path.join(tmpdir, "cache_stress") - cache = PersistentCache(cache_file_stress) - - def stress_worker(thread_id): - """Worker that performs operations and handles exceptions.""" - database_lock_count = 0 - other_errors = [] - - for i in range(100): # More operations to increase chance of lock - try: - key = f"stress_{thread_id}_{i}" - value = f"value_{thread_id}_{i}" - - # Rapid operations - cache.set(key, value, expires_in=1) - cache.get(key) - - # Operations that might cause contention - if i % 10 == 0: - cache.clear_expired() - - # Additional operations that might cause contention - if i % 15 == 0: - # Use cache.keys() instead of direct dbm access - try: - list(cache.keys()) - except Exception: - pass # Ignore access errors - - except Exception as e: - error_msg = str(e).lower() - if "database is locked" in error_msg: - database_lock_count += 1 - else: - other_errors.append(f"Thread {thread_id}, op {i}: {str(e)}") - - return database_lock_count, other_errors - - # Run stress test - num_threads = 15 - total_lock_errors = 0 - all_other_errors = [] - - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(stress_worker, i) for i in range(num_threads)] - for future in concurrent.futures.as_completed(futures): - lock_count, other_errors = future.result() - total_lock_errors += lock_count - all_other_errors.extend(other_errors) - - # Report results from stress phase - if total_lock_errors > 0: - pytest.fail( - f"Database lock errors detected in stress phase: {total_lock_errors} total lock errors" - ) - - if all_other_errors: - pytest.fail(f"Other stress test errors: {all_other_errors[:5]}") - - -def test_multithread_rapid_file_creation(tmpdir): - """Test rapid database file creation that might trigger lock issues.""" - base_path = os.path.join(tmpdir, f"rapid_creation") - - def rapid_creator(thread_id): - """Create and use cache files rapidly.""" - errors = [] - try: - for i in range(10): - # Create a new cache file for each operation - cache_file = f"{base_path}_{thread_id}_{i}" - cache = PersistentCache(cache_file) - - # Perform operations immediately after creation - cache.set("test_key", f"test_value_{thread_id}_{i}") - result = cache.get("test_key") - - if result != f"test_value_{thread_id}_{i}": - errors.append( - f"Thread {thread_id}, iteration {i}: Expected test_value_{thread_id}_{i}, got {result}" - ) - - except Exception as e: - errors.append(f"Thread {thread_id} error: {str(e)}") - - return errors - - # Run multiple threads creating cache files rapidly - num_threads = 20 - all_errors = [] - - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(rapid_creator, i) for i in range(num_threads)] - for future in concurrent.futures.as_completed(futures): - errors = future.result() - all_errors.extend(errors) - - # Check for database lock errors - database_lock_errors = [ - e for e in all_errors if "database is locked" in str(e).lower() - ] - if database_lock_errors: - pytest.fail( - f"Database lock errors in rapid creation test: {database_lock_errors}" - ) - - if all_errors: - pytest.fail(f"Rapid creation errors: {all_errors[:5]}") # Show first 5 errors - - def test_multithread_shared_cache_comprehensive(tmpdir): - """Comprehensive test for shared cache instance across threads with get->set pattern validation.""" + 
"""Test shared cache instance across multiple threads using get->set pattern. + + Tests multithread scenarios using a single shared PersistentCache instance, + which simulates real-world usage patterns like CachedImageUploader.upload. + This test covers the case where values_a and values_b can intersect (overlapping keys). + """ cache_file = os.path.join(tmpdir, "cache_shared_comprehensive") # Initialize cache once and share across all workers shared_cache = PersistentCache(cache_file) shared_cache.clear_expired() - # Phase 1: Database lock detection with shared cache instance (reproduces real uploader.py usage) - num_threads_phase1 = 30 - num_operations_phase1 = 100 - phase1_errors = [] - - def shared_cache_worker(thread_id): - """Worker that uses the shared cache instance (like CachedImageUploader.upload).""" - try: - for i in range(num_operations_phase1): - key = f"shared_thread_{thread_id}_op_{i}" - value = f"shared_value_{thread_id}_{i}" - - # Simulate the pattern from CachedImageUploader: - # 1. Check cache first (_get_cached_file_handle) - cached_value = shared_cache.get(key) - - if cached_value is None: - # 2. Set new value (_set_file_handle_cache) - shared_cache.set(key, value, expires_in=2) - retrieved = shared_cache.get(key) - - if retrieved != value: - phase1_errors.append( - f"Thread {thread_id}: Expected {value}, got {retrieved}" - ) - else: - # 3. Update cache with existing value (_set_file_handle_cache) - shared_cache.set(key, cached_value, expires_in=2) - - # Occasional cleanup operations - if i % 20 == 0: - shared_cache.clear_expired() - - # Cross-thread access pattern - if i % 7 == 0 and thread_id > 0: - other_key = f"shared_thread_{thread_id - 1}_op_{i}" - shared_cache.get(other_key) - - except Exception as e: - phase1_errors.append(f"Thread {thread_id} error: {str(e)}") - - # Run phase 1: Database lock detection - with concurrent.futures.ThreadPoolExecutor( - max_workers=num_threads_phase1 - ) as executor: - futures = [ - executor.submit(shared_cache_worker, i) for i in range(num_threads_phase1) - ] - concurrent.futures.wait(futures) - - # Check for database lock errors in phase 1 - database_lock_errors = [ - e for e in phase1_errors if "database is locked" in str(e).lower() - ] - if database_lock_errors: - pytest.fail( - f"Database lock errors with shared cache instance: {database_lock_errors}" - ) + # Generate key-value pairs for first run (overlapping patterns to ensure intersections) + first_dict = {f"key_{i}": f"first_value_{i}" for i in range(5_000)} + assert len(first_dict) == 5_000 - # Check for data consistency errors (values not persisting correctly) - data_consistency_errors = [ - e for e in phase1_errors if "Expected" in e and "got None" in e - ] - if data_consistency_errors: - pytest.fail( - f"Data consistency errors with shared cache (race conditions): {data_consistency_errors[:5]}" + s = time.perf_counter() + # First concurrent run with get->set pattern + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + list( + executor.map( + lambda kv: _shared_worker_get_set_pattern(shared_cache, [kv]), + first_dict.items(), + ) ) + print(f"First run time: {(time.perf_counter() - s) * 1000:.0f} ms") - if phase1_errors: - pytest.fail(f"Phase 1 shared cache errors: {phase1_errors[:10]}") - - # Phase 2: Get->set pattern validation with cache correctness verification shared_cache.clear_expired() + assert len(shared_cache.keys()) == len(first_dict) + for key in shared_cache.keys(): + assert shared_cache.get(key) == first_dict[key] - 
num_threads_phase2 = 10 - num_operations_phase2 = 20 - all_set_values = [] + # Generate key-value pairs for first run (overlapping patterns to ensure intersections) + second_dict = {f"key_{i}": f"second_value_{i}" for i in range(2_500, 7_500)} + s = time.perf_counter() # First concurrent run with get->set pattern - with concurrent.futures.ThreadPoolExecutor( - max_workers=num_threads_phase2 - ) as executor: - futures = [ - executor.submit( - _shared_worker_get_set_pattern, - i, - shared_cache, - num_operations_phase2, - "phase2_run1", - ) - for i in range(num_threads_phase2) - ] - for future in concurrent.futures.as_completed(futures): - set_values = future.result() - all_set_values.extend(set_values) - - # Verify that all set values are cached correctly - for key, expected_value in all_set_values: - cached_value = shared_cache.get(key) - assert ( - cached_value == expected_value - ), f"Phase2 Run1 - Key {key}: expected {expected_value}, got {cached_value}" - - # Clear expired between concurrent runs - expired_keys_1 = shared_cache.clear_expired() - assert isinstance(expired_keys_1, list), "clear_expired should return a list" - - # Second concurrent run with get->set pattern - run2_set_values = [] - with concurrent.futures.ThreadPoolExecutor( - max_workers=num_threads_phase2 - ) as executor: - futures = [ - executor.submit( - _shared_worker_get_set_pattern, - i, - shared_cache, - num_operations_phase2, - "phase2_run2", + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + list( + executor.map( + lambda kv: _shared_worker_get_set_pattern(shared_cache, [kv]), + second_dict.items(), ) - for i in range(num_threads_phase2) - ] - for future in concurrent.futures.as_completed(futures): - set_values = future.result() - run2_set_values.extend(set_values) - - # Verify that all run2 set values are cached correctly - for key, expected_value in run2_set_values: - cached_value = shared_cache.get(key) - assert ( - cached_value == expected_value - ), f"Phase2 Run2 - Key {key}: expected {expected_value}, got {cached_value}" - - # Final validation - expired_keys_2 = shared_cache.clear_expired() - assert isinstance(expired_keys_2, list), "clear_expired should return a list" - - # Assertions using keys() and counts - all_keys = list(shared_cache.keys()) - expected_keys_count = num_threads_phase2 * num_operations_phase2 * 2 # Two runs - - assert ( - len(all_keys) == expected_keys_count - ), f"Expected {expected_keys_count} keys, got {len(all_keys)}" - - # Verify key patterns - run1_keys = [k for k in all_keys if b"phase2_run1" in k] - run2_keys = [k for k in all_keys if b"phase2_run2" in k] - - assert len(run1_keys) == num_threads_phase2 * num_operations_phase2 - assert len(run2_keys) == num_threads_phase2 * num_operations_phase2 - - # Verify that the total number of set operations matches expectations - total_set_operations = len(all_set_values) + len(run2_set_values) - assert ( - total_set_operations == expected_keys_count - ), f"Expected {expected_keys_count} set operations, got {total_set_operations}" - - -def test_decode_invalid_data(tmpdir): - """Test _decode method with invalid data.""" - cache_file = os.path.join(tmpdir, f"cache_decode_invalid") - cache = PersistentCache(cache_file) - - # Test with various invalid inputs - result = cache._decode(b"not valid json") - assert result == {} + ) + print(f"Second run time: {(time.perf_counter() - s) * 1000:.0f} ms") - result = cache._decode(b'"string instead of dict"') - assert result == {} + shared_cache.clear_expired() + merged_dict = 
{**second_dict, **first_dict} + assert len(merged_dict) < len(first_dict) + len(second_dict) -def test_is_expired(tmpdir): - """Test _is_expired method.""" - cache_file = os.path.join(tmpdir, f"cache_is_expired") - cache = PersistentCache(cache_file) - - # Test with various payloads - assert cache._is_expired({"expires_at": time.time() - 10}) is True - assert cache._is_expired({"expires_at": time.time() + 10}) is False - assert cache._is_expired({}) is False - assert cache._is_expired({"expires_at": "not a number"}) is False - assert cache._is_expired({"expires_at": None}) is False + assert len(shared_cache.keys()) == len(merged_dict) + for key in shared_cache.keys(): + assert shared_cache.get(key) == merged_dict[key] # Shared worker functions for concurrency tests -def _shared_worker_get_set_pattern( - worker_id, cache, num_operations, key_prefix="worker" -): +def _shared_worker_get_set_pattern(cache, key_value_pairs, expires_in=1000): """Shared worker implementation: get key -> if not exist then set key=value.""" - set_values = [] - for i in range(num_operations): - key = f"{key_prefix}_{worker_id}_{i}" - value = f"value_{worker_id}_{i}" - + for key, value in key_value_pairs: # Pattern: get a key -> if not exist then set key=value existing_value = cache.get(key) if existing_value is None: - cache.set(key, value, expires_in=10) - set_values.append((key, value)) + cache.set(key, value, expires_in=expires_in) + else: + value = existing_value # Verify the value was set correctly retrieved_value = cache.get(key) - assert ( - retrieved_value == value - ), f"Worker {worker_id}: Expected {value}, got {retrieved_value}" - - return set_values - - -def _multiprocess_worker_get_set_pattern(args): - """Multiprocess worker wrapper for the shared worker pattern.""" - worker_id, cache_file, num_operations, key_prefix = args - try: - # Each process creates its own cache instance pointing to the same file - cache = PersistentCache(cache_file) - return _shared_worker_get_set_pattern( - worker_id, cache, num_operations, key_prefix + assert retrieved_value == value, ( + f"Expected {value}, got {retrieved_value} for key {key}" ) - except Exception as e: - return [f"Process {worker_id} error: {str(e)} - {traceback.format_exc()}"] - -def _mixed_operations_worker(args): - """Worker that performs mixed operations with get->set pattern.""" - worker_id, cache_file, num_operations, run_prefix = args - errors = [] - try: - cache = PersistentCache(cache_file) - for i in range(num_operations): - key = f"{run_prefix}_mixed_{worker_id}_{i}" - value = f"mixed_value_{worker_id}_{i}" - - # Pattern: get a key -> if not exist then set key=value - existing_value = cache.get(key) - if existing_value is None: - cache.set(key, value, expires_in=15) - - # Mixed operations: also try to read from other workers - if i % 5 == 0 and worker_id > 0: - other_key = f"{run_prefix}_mixed_{worker_id - 1}_{i}" - cache.get(other_key) - - # Periodic cleanup - if i % 10 == 0: - cache.clear_expired() +def test_keys_basic(tmpdir): + """Test keys() method in read mode with empty cache.""" + cache_file = os.path.join(tmpdir, "cache_keys_empty") + cache = PersistentCache(cache_file) + cache.set("key1", "value1") - except Exception as e: - errors.append(f"Process {worker_id} error: {str(e)}") + # Test keys on non-existent cache file + keys = cache.keys() + assert keys == ["key1"] - return errors +def test_keys_read_mode_empty_cache(tmpdir): + """Test keys() method in read mode with empty cache.""" + cache_file = os.path.join(tmpdir, "cache_keys_empty") 
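+    # PersistentCache now creates the database lazily: keys() and get() return
+    # [] / None while the file is missing, and the first set() opens the store
+    # with flag="c", creating the file on demand (see the history.py patch above).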
+ cache = PersistentCache(cache_file) -def test_multiprocess_comprehensive_patterns(tmpdir): - """Comprehensive test for multiprocess access simulating multiple independent applications. + # Test keys on non-existent cache file + keys = cache.keys() + assert keys == [] - Each process creates its own PersistentCache instance to simulate running multiple apps - that access the same cache file independently. - """ - # Phase 1: Basic multiprocess get->set pattern with cache correctness validation - # Each process will create its own PersistentCache instance (simulating multiple apps) - cache_file_basic = os.path.join(tmpdir, "cache_multiprocess_basic") - - num_processes_basic = 6 - num_operations_basic = 15 - - # First concurrent run - basic pattern (each process creates its own cache instance) - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_processes_basic - ) as executor: - args_list = [ - (i, cache_file_basic, num_operations_basic, "basic_run1") - for i in range(num_processes_basic) - ] - futures = [ - executor.submit(_multiprocess_worker_get_set_pattern, args) - for args in args_list - ] - results_1 = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Collect all set values from first run and verify they are cached correctly - # Create a separate validation cache instance (simulating another app checking the results) - validation_cache_basic = PersistentCache(cache_file_basic) - all_set_values_run1 = [] - for result in results_1: - all_set_values_run1.extend(result) - - # Verify that all set values from run1 are cached correctly - for key, expected_value in all_set_values_run1: - cached_value = validation_cache_basic.get(key) - assert ( - cached_value == expected_value - ), f"Basic Run1 - Key {key}: expected {expected_value}, got {cached_value}" - - # Clear expired between concurrent runs (using validation cache instance) - expired_keys_1 = validation_cache_basic.clear_expired() - assert isinstance(expired_keys_1, list), "clear_expired should return a list" - - # Second concurrent run - basic pattern (each process creates its own cache instance) - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_processes_basic - ) as executor: - args_list = [ - (i, cache_file_basic, num_operations_basic, "basic_run2") - for i in range(num_processes_basic) - ] - futures = [ - executor.submit(_multiprocess_worker_get_set_pattern, args) - for args in args_list - ] - results_2 = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Collect all set values from second run and verify they are cached correctly - all_set_values_run2 = [] - for result in results_2: - all_set_values_run2.extend(result) - - # Verify that all set values from run2 are cached correctly - for key, expected_value in all_set_values_run2: - cached_value = validation_cache_basic.get(key) - assert ( - cached_value == expected_value - ), f"Basic Run2 - Key {key}: expected {expected_value}, got {cached_value}" - - # Validate basic pattern results - expired_keys_2 = validation_cache_basic.clear_expired() - assert isinstance(expired_keys_2, list), "clear_expired should return a list" - - all_keys_basic = list(validation_cache_basic.keys()) - expected_keys_count_basic = ( - num_processes_basic * num_operations_basic * 2 - ) # Two runs - - assert ( - len(all_keys_basic) == expected_keys_count_basic - ), f"Expected {expected_keys_count_basic} keys, got {len(all_keys_basic)}" - - # Verify key patterns for basic test - run1_keys = [k for k in all_keys_basic 
if b"basic_run1" in k] - run2_keys = [k for k in all_keys_basic if b"basic_run2" in k] - - assert len(run1_keys) == num_processes_basic * num_operations_basic - assert len(run2_keys) == num_processes_basic * num_operations_basic - - # Verify that the total number of set operations matches expectations - total_set_operations = len(all_set_values_run1) + len(all_set_values_run2) - assert ( - total_set_operations == expected_keys_count_basic - ), f"Expected {expected_keys_count_basic} set operations, got {total_set_operations}" - - # Phase 2: Mixed operations pattern with cross-process reads and cleanup - # Each process will create its own PersistentCache instance (simulating multiple apps) - cache_file_mixed = os.path.join(tmpdir, "cache_multiprocess_mixed") - - num_processes_mixed = 8 - num_operations_mixed = 25 - - # First concurrent run - mixed operations (each process creates its own cache instance) - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_processes_mixed - ) as executor: - args_list = [ - (i, cache_file_mixed, num_operations_mixed, "mixed_run1") - for i in range(num_processes_mixed) - ] - futures = [ - executor.submit(_mixed_operations_worker, args) for args in args_list - ] - results_mixed_1 = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Create a separate validation cache instance for mixed operations - validation_cache_mixed = PersistentCache(cache_file_mixed) - - # Clear expired between concurrent runs - expired_keys_mixed_1 = validation_cache_mixed.clear_expired() - assert isinstance(expired_keys_mixed_1, list), "clear_expired should return a list" - keys_after_mixed_run1 = list(validation_cache_mixed.keys()) - - # Verify first mixed run results - expected_keys_mixed_run1 = num_processes_mixed * num_operations_mixed - assert ( - len(keys_after_mixed_run1) == expected_keys_mixed_run1 - ), f"Expected {expected_keys_mixed_run1} keys after mixed run1, got {len(keys_after_mixed_run1)}" - - # Second concurrent run - mixed operations (each process creates its own cache instance) - with concurrent.futures.ProcessPoolExecutor( - max_workers=num_processes_mixed - ) as executor: - args_list = [ - (i, cache_file_mixed, num_operations_mixed, "mixed_run2") - for i in range(num_processes_mixed) - ] - futures = [ - executor.submit(_mixed_operations_worker, args) for args in args_list - ] - results_mixed_2 = [ - future.result() for future in concurrent.futures.as_completed(futures) - ] - - # Validate mixed operations results - expired_keys_mixed_2 = validation_cache_mixed.clear_expired() - assert isinstance(expired_keys_mixed_2, list), "clear_expired should return a list" - keys_after_mixed_run2 = list(validation_cache_mixed.keys()) - - # Collect all errors from both mixed runs - all_mixed_errors = [] - for result in results_mixed_1 + results_mixed_2: - all_mixed_errors.extend(result) - - # Assertions for mixed operations - expected_keys_count_mixed = ( - num_processes_mixed * num_operations_mixed * 2 - ) # Two runs - - assert ( - len(keys_after_mixed_run2) == expected_keys_count_mixed - ), f"Expected {expected_keys_count_mixed} keys, got {len(keys_after_mixed_run2)}" - - # Verify key patterns for mixed operations - mixed_run1_keys = [k for k in keys_after_mixed_run2 if b"mixed_run1" in k] - mixed_run2_keys = [k for k in keys_after_mixed_run2 if b"mixed_run2" in k] - - assert len(mixed_run1_keys) == num_processes_mixed * num_operations_mixed - assert len(mixed_run2_keys) == num_processes_mixed * num_operations_mixed - - # Check for 
database lock errors specifically - database_lock_errors = [ - e for e in all_mixed_errors if "database is locked" in str(e).lower() - ] - if database_lock_errors: - pytest.fail( - f"Database lock errors in multiprocess comprehensive test: {database_lock_errors}" - ) +def test_sqlite_database_locking(tmpdir): + """Test database locking with multiple threads accessing the same cache file.""" + import sqlite3 - # Check for any other errors - assert ( - not all_mixed_errors - ), f"Multiprocess comprehensive errors occurred: {all_mixed_errors}" + cache_file = os.path.join(tmpdir, "cache_sqlite_locking") + def create_table(value): + while True: + try: + with sqlite3.connect( + cache_file, autocommit=True + ) as conn: # Create the database file + conn.execute( + "CREATE TABLE IF NOT EXISTS cache (key TEXT, value TEXT)" + ) + conn.execute( + "INSERT INTO cache (key, value) VALUES (?, ?)", ("key1", value) + ) + except sqlite3.OperationalError as e: + if "database is locked" in str(e): + print(f"Thread {value} failed to acquire lock: {e}") + time.sleep(1) + continue + else: + raise + else: + break -def test_multithread_high_contention_get_set_pattern(tmpdir): - """Test high contention multithreaded access with shared cache using get->set pattern.""" - cache_file = os.path.join(tmpdir, "cache_multithread_contention") + with sqlite3.connect( + cache_file, autocommit=True + ) as conn: # Create the database file + rows = [row for row in conn.execute("select * from cache")] - # Initialize cache once and share across all workers - shared_cache = PersistentCache(cache_file) + r = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(create_table, str(_)) for _ in range(1000)] + for f in futures: + r.append(f.result()) - # Clear expired before first concurrent run - shared_cache.clear_expired() + with sqlite3.connect( + cache_file, autocommit=True + ) as conn: # Create the database file + row_count = len([row for row in conn.execute("select * from cache")]) - num_threads = 20 - num_operations = 50 - all_errors = [] - - def high_contention_worker(worker_id, run_prefix): - """Worker with higher contention - accessing overlapping keys.""" - errors = [] - try: - for i in range(num_operations): - # Use overlapping keys to increase contention - key = f"{run_prefix}_shared_{i % 10}" # Only 10 unique keys per run - value = f"value_{worker_id}_{i}" - - # Pattern: get a key -> if not exist then set key=value - existing_value = shared_cache.get(key) - if existing_value is None: - shared_cache.set(key, value, expires_in=10) - - # Occasional clear_expired to add more contention - if i % 25 == 0: - shared_cache.clear_expired() - - except Exception as e: - errors.append(f"Worker {worker_id} error: {str(e)}") - - return errors - - # First concurrent run with high contention - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [ - executor.submit(high_contention_worker, i, "contention_run1") - for i in range(num_threads) - ] - for future in concurrent.futures.as_completed(futures): - errors = future.result() - all_errors.extend(errors) - - # Clear expired between concurrent runs - expired_keys_1 = shared_cache.clear_expired() - assert isinstance(expired_keys_1, list), "clear_expired should return a list" - keys_after_run1 = list(shared_cache.keys()) - - # Second concurrent run with high contention - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [ - 
executor.submit(high_contention_worker, i, "contention_run2") - for i in range(num_threads) - ] - for future in concurrent.futures.as_completed(futures): - errors = future.result() - all_errors.extend(errors) - - # Clear expired after second concurrent run - expired_keys_2 = shared_cache.clear_expired() - assert isinstance(expired_keys_2, list), "clear_expired should return a list" - keys_after_run2 = list(shared_cache.keys()) - - # Assertions using keys() and counts - # With overlapping keys, we expect at most 10 keys per run (due to key overlap) - assert ( - len(keys_after_run1) <= 10 - ), f"Expected at most 10 keys after run1, got {len(keys_after_run1)}" - assert ( - len(keys_after_run2) <= 20 - ), f"Expected at most 20 keys after run2, got {len(keys_after_run2)}" - - # Verify key patterns exist - run1_keys = [k for k in keys_after_run2 if b"contention_run1" in k] - run2_keys = [k for k in keys_after_run2 if b"contention_run2" in k] - - assert len(run1_keys) > 0, "Should have some keys from run1" - assert len(run2_keys) > 0, "Should have some keys from run2" - - # Check for any worker errors - assert not all_errors, f"Worker errors occurred: {all_errors}" + assert row_count == 3000, f"Expected 3000 rows, got {row_count}" From d0148cb8f4e7b278623f5bc49a450b88ca98b108 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Wed, 3 Sep 2025 20:59:21 -0700 Subject: [PATCH 08/10] refactor --- mapillary_tools/history.py | 60 +++++++++++++++++++++---------------- mapillary_tools/store.py | 26 ++++++++++++---- mapillary_tools/uploader.py | 2 +- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/mapillary_tools/history.py b/mapillary_tools/history.py index 4436bb57..eb07175c 100644 --- a/mapillary_tools/history.py +++ b/mapillary_tools/history.py @@ -3,12 +3,12 @@ import json import logging import os - import sqlite3 import string import threading import time import typing as T +from functools import wraps from pathlib import Path from . import constants, store, types @@ -78,6 +78,28 @@ def write_history( fp.write(json.dumps(history)) +def _retry_on_database_lock_error(fn): + """ + Decorator to retry a function if it raises a sqlite3.OperationalError with + "database is locked" in the message. 
+ """ + + @wraps(fn) + def wrapper(*args, **kwargs): + while True: + try: + return fn(*args, **kwargs) + except sqlite3.OperationalError as ex: + if "database is locked" in str(ex).lower(): + LOG.warning(f"{str(ex)}") + LOG.info("Retrying in 1 second...") + time.sleep(1) + else: + raise ex + + return wrapper + + class PersistentCache: _lock: threading.Lock @@ -86,7 +108,7 @@ def __init__(self, file: str): self._lock = threading.Lock() def get(self, key: str) -> str | None: - if not self._does_db_exist(): + if not self._db_existed(): return None s = time.perf_counter() @@ -95,7 +117,7 @@ def get(self, key: str) -> str | None: try: raw_payload: bytes | None = db.get(key) # data retrieved from db[key] except Exception as ex: - if self._does_table_exist(ex): + if self._table_not_found(ex): return None raise ex @@ -115,6 +137,7 @@ def get(self, key: str) -> str | None: return T.cast(str, cached_value) + @_retry_on_database_lock_error def set(self, key: str, value: str, expires_in: int = 3600 * 24 * 2) -> None: s = time.perf_counter() @@ -125,29 +148,15 @@ def set(self, key: str, value: str, expires_in: int = 3600 * 24 * 2) -> None: payload: bytes = json.dumps(data).encode("utf-8") - while True: - try: - with self._lock: - with store.KeyValueStore(self._file, flag="c") as db: - # Assume db exists - db[key] = payload - except sqlite3.OperationalError as ex: - if "database is locked" in str(ex).lower(): - LOG.warning( - f"{str(ex)}: {self._file} (are you running multiple instances?)" - ) - LOG.info("Retrying in 1 second...") - time.sleep(1) - continue - else: - raise ex - else: - break + with self._lock: + with store.KeyValueStore(self._file, flag="c") as db: + db[key] = payload LOG.debug( f"Cached file handle for {key} ({(time.perf_counter() - s) * 1000:.0f} ms)" ) + @_retry_on_database_lock_error def clear_expired(self) -> list[str]: expired_keys: list[str] = [] @@ -155,7 +164,6 @@ def clear_expired(self) -> list[str]: with self._lock: with store.KeyValueStore(self._file, flag="c") as db: - # Assume db and table exist here for key, raw_payload in db.items(): data = self._decode(raw_payload) if self._is_expired(data): @@ -169,14 +177,14 @@ def clear_expired(self) -> list[str]: return expired_keys def keys(self) -> list[str]: - if not self._does_db_exist(): + if not self._db_existed(): return [] try: with store.KeyValueStore(self._file, flag="r") as db: return [key.decode("utf-8") for key in db.keys()] except Exception as ex: - if self._does_table_exist(ex): + if self._table_not_found(ex): return [] raise ex @@ -199,10 +207,10 @@ def _decode(self, raw_payload: bytes) -> JSONDict: return data - def _does_db_exist(self) -> bool: + def _db_existed(self) -> bool: return os.path.exists(self._file) - def _does_table_exist(self, ex: Exception) -> bool: + def _table_not_found(self, ex: Exception) -> bool: if isinstance(ex, sqlite3.OperationalError): if "no such table" in str(ex): return True diff --git a/mapillary_tools/store.py b/mapillary_tools/store.py index c75b560a..32253665 100644 --- a/mapillary_tools/store.py +++ b/mapillary_tools/store.py @@ -1,5 +1,15 @@ +""" +This module provides a persistent key-value store based on SQLite. + +This implementation is mostly copied from dbm.sqlite3 in the Python standard library, +but works for Python >= 3.9, whereas dbm.sqlite3 is only available for Python 3.13. 
+
+Source: https://github.com/python/cpython/blob/3.13/Lib/dbm/sqlite3.py
+"""
+
 import os
 import sqlite3
+import sys
 from collections.abc import MutableMapping
 from contextlib import closing, suppress
 from pathlib import Path
@@ -17,9 +27,6 @@
 ITER_KEYS = "SELECT key FROM Dict"


-_ERR_CLOSED = "KeyValueStore object has already been closed"
-
-
 def _normalize_uri(path):
     path = Path(path)
     uri = path.absolute().as_uri()
@@ -62,7 +69,11 @@ def __init__(self, path, /, *, flag="r", mode=0o666):
         uri = _normalize_uri(path)
         uri = f"{uri}?mode={flag}"

-        self._cx = sqlite3.connect(uri, autocommit=True, uri=True)
+        if sys.version_info >= (3, 12):
+            # This is the preferred way, but autocommit is only available in Python 3.12 and newer.
+            self._cx = sqlite3.connect(uri, autocommit=True, uri=True)
+        else:
+            self._cx = sqlite3.connect(uri, uri=True)

         # This is an optimization only; it's ok if it fails.
         with suppress(sqlite3.OperationalError):
@@ -72,7 +83,12 @@ def __init__(self, path, /, *, flag="r", mode=0o666):
             self._execute(BUILD_TABLE)

     def _execute(self, *args, **kwargs):
-        return closing(self._cx.execute(*args, **kwargs))
+        if sys.version_info >= (3, 12):
+            return closing(self._cx.execute(*args, **kwargs))
+        else:
+            # Use a context manager to commit the changes
+            with self._cx:
+                return closing(self._cx.execute(*args, **kwargs))

     def __len__(self):
         with self._execute(GET_SIZE) as cu:
diff --git a/mapillary_tools/uploader.py b/mapillary_tools/uploader.py
index 71a37b6a..2107d610 100644
--- a/mapillary_tools/uploader.py
+++ b/mapillary_tools/uploader.py
@@ -1311,7 +1311,7 @@ def _is_uuid(key: str) -> bool:


 def _build_upload_cache_path(upload_options: UploadOptions) -> Path:
-    # Different python/CLI versions use different cache (dbm) formats.
+    # Different python/CLI versions use different cache formats.
# Separate them to avoid conflicts py_version_parts = [str(part) for part in sys.version_info[:3]] version = f"py_{'_'.join(py_version_parts)}_{VERSION}" From 10dd6749449ba09acb3e75361e7b9c30b36d2bd6 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Wed, 3 Sep 2025 21:41:45 -0700 Subject: [PATCH 09/10] tests --- tests/unit/test_persistent_cache.py | 474 +++++++++++++++++++++++++--- 1 file changed, 423 insertions(+), 51 deletions(-) diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py index 0d4463ee..8faaf76d 100644 --- a/tests/unit/test_persistent_cache.py +++ b/tests/unit/test_persistent_cache.py @@ -1,6 +1,9 @@ import concurrent.futures +import multiprocessing import os +import sqlite3 import time +import traceback import pytest @@ -193,6 +196,179 @@ def test_corrupted_data(tmpdir): cache.clear_expired() +def test_keys_basic(tmpdir): + """Test keys() method in read mode with empty cache.""" + cache_file = os.path.join(tmpdir, "cache_keys_empty") + cache = PersistentCache(cache_file) + cache.set("key1", "value1") + + # Test keys on non-existent cache file + keys = cache.keys() + assert keys == ["key1"] + + +def test_keys_read_mode_empty_cache(tmpdir): + """Test keys() method in read mode with empty cache.""" + cache_file = os.path.join(tmpdir, "cache_keys_empty") + cache = PersistentCache(cache_file) + + # Test keys on non-existent cache file + keys = cache.keys() + assert keys == [] + + +def test_keys_read_mode_with_data(tmpdir): + """Test keys() method in read mode with existing data.""" + cache_file = os.path.join(tmpdir, "cache_keys_data") + cache = PersistentCache(cache_file) + + # Add some data first + cache.set("key1", "value1") + cache.set("key2", "value2") + cache.set("key3", "value3") + + # Test keys retrieval in read mode + keys = cache.keys() + assert len(keys) == 3 + assert "key1" in keys + assert "key2" in keys + assert "key3" in keys + assert set(keys) == {"key1", "key2", "key3"} + + +def test_keys_read_mode_with_expired_data(tmpdir): + """Test keys() method in read mode includes expired entries.""" + cache_file = os.path.join(tmpdir, "cache_keys_expired") + cache = PersistentCache(cache_file) + + # Add data with short expiration + cache.set("expired_key", "value1", expires_in=1) + cache.set("valid_key", "value2", expires_in=10) + + # Wait for expiration + time.sleep(1.1) + + # keys() should still return expired keys (it doesn't filter by expiration) + keys = cache.keys() + assert len(keys) == 2 + assert "expired_key" in keys + assert "valid_key" in keys + + # But get() should return None for expired key + assert cache.get("expired_key") is None + assert cache.get("valid_key") == "value2" + + +def test_keys_write_mode_concurrent_operations(tmpdir): + """Test keys() method during concurrent write operations.""" + cache_file = os.path.join(tmpdir, "cache_keys_concurrent") + cache = PersistentCache(cache_file) + + # Initial data + cache.set("initial_key", "initial_value") + + def write_worker(worker_id): + """Worker function that writes data and checks keys.""" + key = f"worker_{worker_id}_key" + value = f"worker_{worker_id}_value" + cache.set(key, value) + + # Get keys after writing + keys = cache.keys() + assert key in keys + return keys + + # Run concurrent write operations + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(write_worker, i) for i in range(10)] + # Wait for all futures to complete + for f in futures: + f.result() + + # Final check - should have all keys + final_keys = 
cache.keys() + assert "initial_key" in final_keys + assert len(final_keys) >= 11 # initial + 10 worker keys + + # Verify all worker keys are present + for i in range(10): + assert f"worker_{i}_key" in final_keys + + +def test_keys_write_mode_after_clear_expired(tmpdir): + """Test keys() method after clear_expired() operations.""" + cache_file = os.path.join(tmpdir, "cache_keys_clear") + cache = PersistentCache(cache_file) + + # Add mixed data - some that will expire, some that won't + cache.set("short_lived_1", "value1", expires_in=1) + cache.set("short_lived_2", "value2", expires_in=1) + cache.set("long_lived_1", "value3", expires_in=10) + cache.set("long_lived_2", "value4", expires_in=10) + + # Verify all keys are present initially + initial_keys = cache.keys() + assert len(initial_keys) == 4 + + # Wait for expiration + time.sleep(1.1) + + # Keys should still show all entries (including expired) + keys_before_clear = cache.keys() + assert len(keys_before_clear) == 4 + + # Clear expired entries + expired_keys = cache.clear_expired() + assert len(expired_keys) == 2 + + # Now keys should only show non-expired entries + keys_after_clear = cache.keys() + assert len(keys_after_clear) == 2 + assert "long_lived_1" in keys_after_clear + assert "long_lived_2" in keys_after_clear + assert "short_lived_1" not in keys_after_clear + assert "short_lived_2" not in keys_after_clear + + +def test_keys_write_mode_large_dataset(tmpdir): + """Test keys() method with large dataset in write mode.""" + cache_file = os.path.join(tmpdir, "cache_keys_large") + cache = PersistentCache(cache_file) + + # Add a large number of entries + num_entries = 1000 + for i in range(num_entries): + cache.set(f"key_{i:04d}", f"value_{i}") + + # Test keys retrieval + keys = cache.keys() + assert len(keys) == num_entries + + # Verify all keys are present + expected_keys = {f"key_{i:04d}" for i in range(num_entries)} + actual_keys = set(keys) + assert actual_keys == expected_keys + + +def test_keys_read_mode_corrupted_database(tmpdir): + """Test keys() method handles corrupted database gracefully.""" + cache_file = os.path.join(tmpdir, "cache_keys_corrupted") + cache = PersistentCache(cache_file) + + # Add some valid data first + cache.set("valid_key", "valid_value") + + # Verify keys work with valid data + keys = cache.keys() + assert "valid_key" in keys + + # The keys() method should handle database issues gracefully + # (specific corruption testing would require manipulating the database file) + # For now, we test that it doesn't crash with normal operations + assert isinstance(keys, list) + assert all(isinstance(key, str) for key in keys) + + def test_multithread_shared_cache_comprehensive(tmpdir): """Test shared cache instance across multiple threads using get->set pattern. 
@@ -206,9 +382,11 @@ def test_multithread_shared_cache_comprehensive(tmpdir): shared_cache = PersistentCache(cache_file) shared_cache.clear_expired() + num_keys = 5_000 + # Generate key-value pairs for first run (overlapping patterns to ensure intersections) - first_dict = {f"key_{i}": f"first_value_{i}" for i in range(5_000)} - assert len(first_dict) == 5_000 + first_dict = {f"key_{i}": f"first_value_{i}" for i in range(num_keys)} + assert len(first_dict) == num_keys s = time.perf_counter() # First concurrent run with get->set pattern @@ -221,13 +399,21 @@ def test_multithread_shared_cache_comprehensive(tmpdir): ) print(f"First run time: {(time.perf_counter() - s) * 1000:.0f} ms") + assert len(shared_cache.keys()) == len(first_dict) + for key in shared_cache.keys(): + assert shared_cache.get(key) == first_dict[key] + shared_cache.clear_expired() + assert len(shared_cache.keys()) == len(first_dict) for key in shared_cache.keys(): assert shared_cache.get(key) == first_dict[key] # Generate key-value pairs for first run (overlapping patterns to ensure intersections) - second_dict = {f"key_{i}": f"second_value_{i}" for i in range(2_500, 7_500)} + second_dict = { + f"key_{i}": f"second_value_{i}" + for i in range(num_keys // 2, num_keys // 2 + num_keys) + } s = time.perf_counter() # First concurrent run with get->set pattern @@ -251,7 +437,7 @@ def test_multithread_shared_cache_comprehensive(tmpdir): # Shared worker functions for concurrency tests -def _shared_worker_get_set_pattern(cache, key_value_pairs, expires_in=1000): +def _shared_worker_get_set_pattern(cache, key_value_pairs, expires_in=36_000): """Shared worker implementation: get key -> if not exist then set key=value.""" for key, value in key_value_pairs: # Pattern: get a key -> if not exist then set key=value @@ -268,69 +454,255 @@ def _shared_worker_get_set_pattern(cache, key_value_pairs, expires_in=1000): ) -def test_keys_basic(tmpdir): - """Test keys() method in read mode with empty cache.""" - cache_file = os.path.join(tmpdir, "cache_keys_empty") +def _multiprocess_worker_comprehensive(args): + """Worker function for multiprocess comprehensive test. + + Each process creates its own PersistentCache instance but uses the same cache file. + Keys are prefixed with process_id to avoid conflicts. + """ + cache_file, process_id, num_keys = args + + # Create cache instance per process (but same file) cache = PersistentCache(cache_file) - cache.set("key1", "value1") - # Test keys on non-existent cache file - keys = cache.keys() - assert keys == ["key1"] + # Generate process-specific key-value pairs + process_dict = { + f"process_{process_id}_key_{i}": f"process_{process_id}_value_{i}" + for i in range(num_keys) + } + # Use the same get->set pattern as the original test + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + list( + executor.map( + lambda kv: _shared_worker_get_set_pattern(cache, [kv]), + process_dict.items(), + ) + ) + + # Verify all process-specific keys were set correctly + for key, expected_value in process_dict.items(): + actual_value = cache.get(key) + assert actual_value == expected_value, ( + f"Process {process_id}: Expected {expected_value}, got {actual_value} for key {key}" + ) + + # Return process results for verification + return process_id, process_dict + + +def test_multiprocess_shared_cache_comprehensive(tmpdir): + """Test shared cache file across multiple processes using get->set pattern. 
+ + This test runs the comprehensive cache test across multiple processes where: + - All processes use the same cache file but create their own PersistentCache instance + - Each process has its own keys prefixed with process_id to avoid conflicts + - Reuses the logic from test_multithread_shared_cache_comprehensive + """ + cache_file = os.path.join(tmpdir, "cache_multiprocess_comprehensive") + + # Initialize cache and clear any existing data + init_cache = PersistentCache(cache_file) + init_cache.clear_expired() + + num_processes = 4 + keys_per_process = 1000 + + # Prepare arguments for each process + process_args = [ + (cache_file, process_id, keys_per_process) + for process_id in range(num_processes) + ] + + s = time.perf_counter() + + # Run multiple processes concurrently + with multiprocessing.Pool(processes=num_processes) as pool: + results = pool.map(_multiprocess_worker_comprehensive, process_args) + + print(f"Multiprocess run time: {(time.perf_counter() - s) * 1000:.0f} ms") + + # Verify results from all processes + final_cache = PersistentCache(cache_file) + final_cache.clear_expired() + + # Collect all expected keys and values from all processes + all_expected_keys = {} + for process_id, process_dict in results: + all_expected_keys.update(process_dict) + + # Verify total number of keys + final_keys = final_cache.keys() + assert len(final_keys) == len(all_expected_keys), ( + f"Expected {len(all_expected_keys)} keys, got {len(final_keys)}" + ) + + # Verify all keys from all processes are present and have correct values + for expected_key, expected_value in all_expected_keys.items(): + actual_value = final_cache.get(expected_key) + assert actual_value == expected_value, ( + f"Expected {expected_value}, got {actual_value} for key {expected_key}" + ) + + # Verify keys are properly distributed across processes + for process_id in range(num_processes): + process_keys = [ + key for key in final_keys if key.startswith(f"process_{process_id}_") + ] + assert len(process_keys) == keys_per_process, ( + f"Process {process_id} should have {keys_per_process} keys, got {len(process_keys)}" + ) + + print( + f"Successfully verified {len(all_expected_keys)} keys across {num_processes} processes" + ) + + +def test_multithread_write_without_database_lock_errors(tmpdir): + cache_file = os.path.join(tmpdir, "cache_locking") + assert not os.path.exists(cache_file) -def test_keys_read_mode_empty_cache(tmpdir): - """Test keys() method in read mode with empty cache.""" - cache_file = os.path.join(tmpdir, "cache_keys_empty") cache = PersistentCache(cache_file) - # Test keys on non-existent cache file - keys = cache.keys() - assert keys == [] + def _cache_set(cache, key, num_sets): + for _ in range(num_sets): + cache.set(key, "value1") + with concurrent.futures.ThreadPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit(_cache_set, cache, str(key), num_sets=1) + for key in range(1000) + ] + r = [f.result() for f in futures] -def test_sqlite_database_locking(tmpdir): - """Test database locking with multiple threads accessing the same cache file.""" - import sqlite3 + assert 1000 == len(cache.keys()) - cache_file = os.path.join(tmpdir, "cache_sqlite_locking") - def create_table(value): - while True: - try: - with sqlite3.connect( - cache_file, autocommit=True - ) as conn: # Create the database file - conn.execute( - "CREATE TABLE IF NOT EXISTS cache (key TEXT, value TEXT)" - ) +def _cache_set(cache_file, key, num_sets): + cache = PersistentCache(cache_file) + for i in range(num_sets): + 
cache.set(f"key_{key}_{i}", f"value_{key}_{i}") + + +def test_multiprocess_write_without_database_lock_errors(tmpdir): + """Test no database locking errors with multiple processes accessing the same cache file.""" + + cache_file = os.path.join(tmpdir, "cache_locking") + assert not os.path.exists(cache_file) + + with concurrent.futures.ProcessPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit(_cache_set, cache_file, str(key), num_sets=1) + for key in range(10000) + ] + r = [f.result() for f in futures] + + cache = PersistentCache(cache_file) + assert 10000 == len(cache.keys()) + + +def _sqlite_insert_rows(cache_file, value, num_inserts=1): + while True: + try: + with sqlite3.connect(cache_file) as conn: + conn.execute("PRAGMA journal_mode = wal") + conn.execute("CREATE TABLE IF NOT EXISTS cache (key TEXT, value TEXT)") + + with sqlite3.connect(cache_file) as conn: + for _ in range(num_inserts): conn.execute( "INSERT INTO cache (key, value) VALUES (?, ?)", ("key1", value) ) - except sqlite3.OperationalError as e: - if "database is locked" in str(e): - print(f"Thread {value} failed to acquire lock: {e}") - time.sleep(1) - continue - else: - raise + except sqlite3.OperationalError as e: + traceback.print_exc() + if "database is locked" in str(e): + time.sleep(1) + continue else: - break + raise + else: + break - with sqlite3.connect( - cache_file, autocommit=True - ) as conn: # Create the database file - rows = [row for row in conn.execute("select * from cache")] - r = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: - futures = [executor.submit(create_table, str(_)) for _ in range(1000)] - for f in futures: - r.append(f.result()) +def test_multiprocess_sqlite_database_locking(tmpdir): + """Test database locking with multiple threads accessing the same cache file.""" + + cache_file = os.path.join(tmpdir, "cache_sqlite_locking") + assert not os.path.exists(cache_file) + + num_items = 4000 + num_inserts = 1 + + with concurrent.futures.ProcessPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit( + _sqlite_insert_rows, cache_file, str(val), num_inserts=num_inserts + ) + for val in range(num_items) + ] + r = [f.result() for f in futures] + + with sqlite3.connect(cache_file) as conn: + row_count = len([row for row in conn.execute("select * from cache")]) + + assert row_count == num_items * num_inserts, ( + f"Expected {num_items * num_inserts} rows, got {row_count}" + ) + + with concurrent.futures.ProcessPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit( + _sqlite_insert_rows, cache_file, str(val), num_inserts=num_inserts + ) + for val in range(num_items) + ] + r = [f.result() for f in futures] + + with sqlite3.connect(cache_file) as conn: + row_count = len([row for row in conn.execute("select * from cache")]) + + assert row_count == (num_items * num_inserts) * 2, ( + f"Expected {(num_items * num_inserts) * 2} rows, got {row_count}" + ) + + +def test_multithread_sqlite_database_locking(tmpdir): + """Test database locking with multiple threads accessing the same cache file.""" + + cache_file = os.path.join(tmpdir, "cache_sqlite_locking") + assert not os.path.exists(cache_file) + + num_items = 4000 + num_inserts = 1 + + with concurrent.futures.ThreadPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit( + _sqlite_insert_rows, cache_file, str(val), num_inserts=num_inserts + ) + for val in range(num_items) + ] + r = [f.result() for f in futures] + + with sqlite3.connect(cache_file) as conn: 
+ row_count = len([row for row in conn.execute("select * from cache")]) + + assert row_count == num_items * num_inserts, ( + f"Expected {num_items * num_inserts} rows, got {row_count}" + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=40) as executor: + futures = [ + executor.submit( + _sqlite_insert_rows, cache_file, str(val), num_inserts=num_inserts + ) + for val in range(num_items) + ] + r = [f.result() for f in futures] - with sqlite3.connect( - cache_file, autocommit=True - ) as conn: # Create the database file + with sqlite3.connect(cache_file) as conn: row_count = len([row for row in conn.execute("select * from cache")]) - assert row_count == 3000, f"Expected 3000 rows, got {row_count}" + assert row_count == num_items * num_inserts * 2, ( + f"Expected {num_items * num_inserts * 2} rows, got {row_count}" + ) From 7aca1d82b6760d6d0848be7959c1d8717986a4e6 Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Wed, 3 Sep 2025 21:46:33 -0700 Subject: [PATCH 10/10] types --- mapillary_tools/history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapillary_tools/history.py b/mapillary_tools/history.py index eb07175c..10515313 100644 --- a/mapillary_tools/history.py +++ b/mapillary_tools/history.py @@ -202,7 +202,7 @@ def _decode(self, raw_payload: bytes) -> JSONDict: return {} if not isinstance(data, dict): - LOG.warning(f"Invalid cache value format: {raw_payload}") + LOG.warning(f"Invalid cache value format: {raw_payload!r}") return {} return data
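
Note: the retry-on-lock pattern that PATCH 08 factors into _retry_on_database_lock_error, and that _sqlite_insert_rows exercises directly in the tests, can be read in isolation as the sketch below. This is a minimal illustration rather than code from the repository; the helper name, table schema, and one-second backoff are assumptions for the example.

import sqlite3
import time


def insert_with_retry(db_path: str, key: str, value: str) -> None:
    # Keep retrying while another writer holds the SQLite write lock.
    while True:
        try:
            with sqlite3.connect(db_path) as conn:
                # WAL mode lets readers proceed while a single writer commits.
                conn.execute("PRAGMA journal_mode = wal")
                conn.execute(
                    "CREATE TABLE IF NOT EXISTS cache (key TEXT, value TEXT)"
                )
                conn.execute(
                    "INSERT INTO cache (key, value) VALUES (?, ?)", (key, value)
                )
            return
        except sqlite3.OperationalError as ex:
            if "database is locked" in str(ex).lower():
                time.sleep(1)
                continue
            raise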
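
Note: the version gate added to store.KeyValueStore exists because the autocommit keyword of sqlite3.connect() only appeared in Python 3.12; on older interpreters the store relies on the connection's transaction handling instead, where "with conn:" commits on success. A rough standalone sketch of that split follows, using an illustrative table name rather than the store's real schema.

import sqlite3
import sys


def open_kv_connection(uri: str) -> sqlite3.Connection:
    if sys.version_info >= (3, 12):
        # Each statement is committed as soon as it executes.
        return sqlite3.connect(uri, autocommit=True, uri=True)
    # Fallback for Python < 3.12: commit via the connection context manager instead.
    return sqlite3.connect(uri, uri=True)


def put(conn: sqlite3.Connection, key: str, value: str) -> None:
    # "with conn:" commits the open transaction on success and rolls back on error;
    # under autocommit=True it is effectively a no-op, so this works on both paths.
    with conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS kv (key TEXT PRIMARY KEY, value TEXT)"
        )
        conn.execute("REPLACE INTO kv (key, value) VALUES (?, ?)", (key, value))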
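
Note on PATCH 10: raw_payload is bytes, and formatting bytes through str() inside an f-string emits BytesWarning when Python runs with -b (and raises under -bb); the !r conversion logs the explicit repr and avoids that. A tiny illustration with a made-up payload:

raw_payload = b'{"not": "json"'
message = f"Invalid cache value format: {raw_payload!r}"
# message == 'Invalid cache value format: b\'{"not": "json"\''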