Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def __init__(
self._stop_event = threading.Event()
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._restore_completed_event = threading.Event()
self._restore_completed_event.set()
self._download_queue = download_queue
self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False
Expand Down Expand Up @@ -264,16 +266,23 @@ def _restore_incomplete_installs(self) -> None:
self._safe_rmtree(job._install_tmpdir, self._logger)

def _restore_incomplete_installs_async(self) -> None:
self._restore_completed_event.clear()

def _run() -> None:
try:
self._logger.info("Restoring incomplete installs")
self._restore_incomplete_installs()
self._logger.info("Finished restoring incomplete installs")
except Exception as e:
self._logger.error(f"Failed to restore incomplete installs: {e}")
finally:
self._restore_completed_event.set()

threading.Thread(target=_run, daemon=True).start()

def _wait_for_restore_complete(self) -> None:
self._restore_completed_event.wait()

def _resume_remote_download(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.WAITING
if job.download_parts:
Expand Down Expand Up @@ -459,6 +468,8 @@ def heuristic_import(
return self.import_model(source_obj, config)

def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
self._wait_for_restore_complete()

similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
Expand Down Expand Up @@ -506,6 +517,8 @@ def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJo

def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
self._wait_for_restore_complete()

start = time.time()
while len(self._download_cache) > 0:
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
Expand Down
51 changes: 51 additions & 0 deletions tests/app/services/model_install/test_model_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import gc
import platform
import shutil
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Dict
Expand Down Expand Up @@ -321,6 +323,55 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed


def test_import_waits_for_startup_restore(
mm2_app_config: InvokeAIAppConfig,
mm2_record_store,
mm2_download_queue,
mm2_session,
embedding_file: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
installer = ModelInstallService(
app_config=mm2_app_config,
record_store=mm2_record_store,
download_queue=mm2_download_queue,
event_bus=TestEventService(),
session=mm2_session,
)
restore_started = threading.Event()
release_restore = threading.Event()
imported = threading.Event()

def _blocked_restore() -> None:
restore_started.set()
assert release_restore.wait(timeout=5)

monkeypatch.setattr(installer, "_restore_incomplete_installs", _blocked_restore)

try:
installer.start()
assert restore_started.wait(timeout=5)

import_thread = threading.Thread(
target=lambda: (
installer.import_model(LocalModelSource(path=embedding_file)),
imported.set(),
)
)
import_thread.start()

time.sleep(0.1)
assert not imported.is_set()

release_restore.set()
import_thread.join(timeout=5)
assert imported.is_set()
installer.wait_for_installs(timeout=5)
finally:
release_restore.set()
installer.stop()


@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
Expand Down
Loading