diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 49d3cfdf7f9..776c46d004f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -112,6 +112,8 @@ def __init__( self._stop_event = threading.Event() self._downloads_changed_event = threading.Event() self._install_completed_event = threading.Event() + self._restore_completed_event = threading.Event() + self._restore_completed_event.set() self._download_queue = download_queue self._download_cache: Dict[int, ModelInstallJob] = {} self._running = False @@ -264,6 +266,8 @@ def _restore_incomplete_installs(self) -> None: self._safe_rmtree(job._install_tmpdir, self._logger) def _restore_incomplete_installs_async(self) -> None: + self._restore_completed_event.clear() + def _run() -> None: try: self._logger.info("Restoring incomplete installs") @@ -271,9 +275,14 @@ def _run() -> None: self._logger.info("Finished restoring incomplete installs") except Exception as e: self._logger.error(f"Failed to restore incomplete installs: {e}") + finally: + self._restore_completed_event.set() threading.Thread(target=_run, daemon=True).start() + def _wait_for_restore_complete(self) -> None: + self._restore_completed_event.wait() + def _resume_remote_download(self, job: ModelInstallJob) -> None: job.status = InstallStatus.WAITING if job.download_parts: @@ -459,6 +468,8 @@ def heuristic_import( return self.import_model(source_obj, config) def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102 + self._wait_for_restore_complete() + similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state] if similar_jobs: self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.") @@ -506,6 +517,8 @@ def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJo def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" + self._wait_for_restore_complete() + start = time.time() while len(self._download_cache) > 0: if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index de001178adc..2742734d95d 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -5,6 +5,8 @@ import gc import platform import shutil +import threading +import time import uuid from pathlib import Path from typing import Any, Dict @@ -321,6 +323,55 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed +def test_import_waits_for_startup_restore( + mm2_app_config: InvokeAIAppConfig, + mm2_record_store, + mm2_download_queue, + mm2_session, + embedding_file: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + installer = ModelInstallService( + app_config=mm2_app_config, + record_store=mm2_record_store, + download_queue=mm2_download_queue, + event_bus=TestEventService(), + session=mm2_session, + ) + restore_started = threading.Event() + release_restore = threading.Event() + imported = threading.Event() + + def _blocked_restore() -> None: + restore_started.set() + assert release_restore.wait(timeout=5) + + monkeypatch.setattr(installer, "_restore_incomplete_installs", _blocked_restore) + + try: + installer.start() + assert restore_started.wait(timeout=5) + + import_thread = threading.Thread( + target=lambda: ( + installer.import_model(LocalModelSource(path=embedding_file)), + imported.set(), + ) + ) + import_thread.start() + + time.sleep(0.1) + assert not imported.is_set() + + release_restore.set() + import_thread.join(timeout=5) + assert imported.is_set() + installer.wait_for_installs(timeout=5) + finally: + release_restore.set() + installer.stop() + + @pytest.mark.timeout(timeout=10, method="thread") def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))