diff --git a/CHANGELOG.md b/CHANGELOG.md index 71d082e62..bf6f950a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,13 +5,291 @@ All notable changes to Stability Matrix will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2.0.0.html). -## v2.15.8 +## v2.16.0 +### Added +#### New Feature: πŸ§ͺ Image Lab - Conversational Image Generation for ComfyUI +- We've added a brand new conversational interface for image generation! Image Lab lets you iterate on images naturally through chat, rather than just one-off prompts. + - Local-First Power: Native support for Flux Kontext, Qwen Image Edit, and the Apache 2.0-licensed Flux.2 Klein running entirely locally via your ComfyUI backend. + - Smart Setup: Stability Matrix automatically detects and helps you download the specific models and LoRAs needed for these local workflows. + - Interactive Tools: Drag-and-drop image inputs, use the built-in annotation tool to draw on images, and keep persistent conversation history. + - Cloud Option: Includes optional support for Nano Banana (Gemini 3 Pro / 2.5) and Nano Banana 2 (Gemini 3.1 Flash) for users who want to leverage external reasoning models. +- Added Regional Prompting addon to Inference - paint detailed masks to apply different prompts, strengths, and settings to specific regions of your image + - Multi-layer mask editor with Photoshop-style interface for managing layers with independent masks, prompts, colors, and opacity + - Professional brush tools: freehand brush/eraser with pressure sensitivity, rectangle/ellipse shapes with fill/stroke modes, paint bucket flood fill + - Brush feathering/softness control for smooth, blended mask edges (0 = hard edge, 1 = soft/blurred) + - Per-layer prompt and strength controls, export/import masks as PNG, duplicate layers, image reference layers for tracing + - GPU-accelerated rendering with compact gzip-compressed metadata serialization +- Added official Inference support for the **Z-Image** (Base + Turbo), **Anima**, and **Flux.2** model architectures β€” workflow-appropriate text encoders, latent shapes, schedulers, and model sampling (AuraFlow for Z-Image, `Flux2Scheduler` for Flux.2) are wired up automatically across Text-to-Image and Image-to-Image +- Added an Inference **Workflow** selector to the Model card with profiles for Default/Checkpoint, Flux, Flux.2, Z-Image Base/Turbo, Anima, HiDream, and Custom + - **Auto** (default) detects the workflow from the model's CivitAI metadata, with filename fallbacks for models without metadata, and shows the resolved profile inline below the selector + - Sparkle button applies recommended sampler / scheduler / steps / CFG presets for the active workflow β€” e.g. `res_multistep` / `simple` / 8 steps / CFG 1 for Z-Image Turbo, `er_sde` / `simple` / 30 steps / CFG 4 for Anima, `euler` / 20 steps / CFG 5 for Flux.2 + - Choosing a non-Auto profile reveals a manual Encoder Type selector for advanced overrides (e.g. running Z-Image Turbo with the `sd3` encoder) + - Opening the model browser from the Model card pre-filters to the workflow's compatible base models, without overwriting your saved picker filters +- Added CivArchive model browser with details page, image viewer, version selector, trigger words, and in-app downloads with tracked progress +- Added a checkpoint organizer for previewing and reorganizing local models using connected metadata-driven folder and filename patterns (requested in [#280](https://github.com/LykosAI/StabilityMatrix/issues/280), [#424](https://github.com/LykosAI/StabilityMatrix/issues/424)) +- Added a new Model Picker dialog for Inference with grid/list views, search, filtering, and NSFW overlay +- Added browse buttons to all model dropdowns in Inference (Model, Refiner, VAE, Text Encoders, CLIP Vision) +- Added an inline search box to model combo box dropdowns with fuzzy matching +- Added a **Source** button in the Inference SamplerCard that one-click matches your generation Width/Height to the loaded source image β€” available in Image-to-Image whenever a source image is selected +- Added popularity counts to booru-style tag completions in the prompt editor; descriptions now show entries like `12.3K Β· artist` so the more common tags are easier to spot at a glance +- Added a settings gear button to the CivitAI browser's Base Models filter flyout that jumps straight to the base model filter configuration in Settings +- Added `er_sde` and `res_multistep` to the Inference sampler list +- Added `stable_diffusion`, `flux2`, and `lumina2` Encoder Type options for UNet workflows +- Added a **Bitsandbytes NF4** launch option to Stable Diffusion WebUI Forge - Neo for low-bit (`--bnb`) inference +- Added an **Activity center**: the sidebar download panel now has a **Notifications** tab alongside **In Progress**. Toasts are clickable β€” jumping to the downloaded folder, the originating page (e.g. Inference), or the activity panel β€” and persist into a session notification history (every notification is recorded, even ones suppressed by your settings) with read/unread indicators and a combined unread + active-download badge on the sidebar item +- Added an **"Always Show Scrollbars"** toggle under **Settings β†’ Appearance**. Defaults on β€” vertical scrollbars stay visible at their full thickness and reserve real layout space instead of fading to a thin overlay-style bar that only thickens on hover. Toggle off to restore Avalonia's classic auto-hide behavior. Single-line numeric inputs (e.g. SamplerCard Width/Height) keep their auto-hide regardless so spin-buttons aren't followed by a phantom bar +- Added new shared model folder categories β€” **Style Models**, **Audio Encoders**, **Model Patches**, and **Background Removal** β€” for ComfyUI's `style_models`, `audio_encoders`, `model_patches`, and `background_removal` directories. Models in these folders are now indexed and symlinked alongside everything else (e.g. Flux Redux / B-Lora style models, audio encoders for video/audio workflows, BiRefNet background-removal models) +- Added Intel GPU support for ComfyUI +- Added "Run Python Command" option to the package card's 3-dots menu for running arbitrary Python code in the package's virtual environment +- Added a recoverable error dialog for UI thread exceptions, with option to continue instead of exiting +- Added enable/disable toggle for environment variables in Settings, allowing variables to be temporarily disabled without deleting them +### Changed +- Promoted the Encoder Type selector in the Inference Model card out of Advanced Options up to the main card body, so it's visible whenever a non-Auto workflow profile is active (and always when **Custom** is selected) +- Tidied up the Inference SamplerCard dimensions section β€” Source/Presets actions are shown as labeled buttons below the dimension row +- The Inference checkpoint dropdown no longer **resets its scroll position** every time the model list refreshes. The refresh now applies a single combined (local + remote) diff to the underlying source cache, rather than first resetting to local-only and then re-adding remote entries β€” which previously caused the open dropdown to scroll back to the top +- Local model autocomplete in the prompt editor now uses substring matching instead of prefix-only β€” typing any part of a model's filename surfaces it, with names that start with your search still ranked first +- Single-encoder UNet workflows (Anima, Flux.2, Z-Image) now use the matching CLIPLoader instead of assuming Flux-style dual encoders +- The CivitAI model details page now collapses the preview-image area and shows a small "No preview images available" hint when a model has no images to display, letting the description card take the full vertical space instead of leaving a large empty region above it +- Improved the Gemini API error message in Image Lab when the API returns 401/403 to point users at Google's API key restriction policy (which starts blocking unrestricted keys on June 19 2026) +- Improved safetensor checkpoint classification to correctly detect UNet-only models for Wan Video, HiDream, Z-Image, Hunyuan3D, and diffusers-format Flux architectures, ensuring they are routed to the DiffusionModels folder +- GGUF checkpoint downloads now go directly to the DiffusionModels folder instead of StableDiffusion +- Updated AI-Toolkit to install torch 2.9.1 / torchvision 0.24.1 / torchaudio 2.9.1 from the cu128 index to match upstream (ostris/ai-toolkit), with a cu126 fallback for legacy NVIDIA GPUs; also pin numpy to 1.26.4 to avoid a numpy 2.x ABI break in scipy/diffusers that crashed training runs +- Pinned kohya_ss torch to 2.7.0 / torchvision 0.22.0 (cu128) to match upstream's requirements_pytorch_windows.txt instead of resolving an untested latest, keeping the cu126 legacy-GPU fallback +- Pinned reForge torch to 2.9.0 to match upstream (modules/launch_utils.py) +- Updated ComfyUI installs to cu130 (cu126 for legacy NVIDIA GPUs) / rocm7.2 torch indexes depending on GPU +- Upgraded the bundled Visual C++ redistributable from 2015–2019 (v16) to 2015–2022 (v17, build 14.40.33810+), required by modern native dependencies such as PyTorch and ONNX Runtime +- Video files can now be opened directly from the Output browser +- Videos will now appear with thumbnails in the Output browser +- Configured portable Git to suppress detached HEAD advice messages +### Fixed +- Fixed Inference text encoder selections being cleared when navigating away from and back to the Inference tab β€” encoder slots now ignore the transient null the model dropdown reports while its list refreshes +- Fixed UNet-only Inference model selection sometimes clearing during model-list refreshes β€” text encoder slots no longer disappear after generating, cancelling a generation, or reconnecting to ComfyUI +- Fixed [#1585](https://github.com/LykosAI/StabilityMatrix/issues/1585) - FluxGym installs/updates pulling an incompatible `transformers` version β€” installs now pin `transformers==4.54.1` and exclude it from the default requirements pass +- Fixed [#1641](https://github.com/LykosAI/StabilityMatrix/issues/1641) - Cogstudio failing to set up its `inference/gradio_composite_demo` directory when the parent path didn't already exist +- Fixed [#1650](https://github.com/LykosAI/StabilityMatrix/issues/1650) - ComfyUI-Manager extension installs failing on Linux with `File not found: venv/uv-build-constraints.txt` by no longer leaking the relative build-constraints path into the running package's environment +- Fixed [#1645](https://github.com/LykosAI/StabilityMatrix/issues/1645) - Strix Halo / Radeon 8060S and other `Display controller`-class integrated GPUs not appearing in the GPU list on Linux +- Fixed [#1643](https://github.com/LykosAI/StabilityMatrix/issues/1643) - package install and launch failures when `sitecustomize.py` or its compiled bytecode was corrupted by external software (e.g. some antivirus suites); the file now self-heals when out of date and its startup actions can no longer abort interpreter startup +- Fixed the CivitAI model browser requiring two clicks of Search to show results when all base-model filters were selected β€” a leftover post-response sanity check from the old single-select base-model UI was rejecting the response the first time, requiring a second search to surface the cached results +- Fixed CivitAI model cards showing "No versions available" when clicked for some models (typically recently uploaded or updated) even though the model has downloadable versions on the website β€” the app now retries with a different lookup path when the initial response comes back missing version data +- Fixed `$#1234` and `civitai.com/models/1234` URL searches returning zero results for some models that exist and are downloadable on the website β€” the app now retries via a per-model lookup when the batch search misses a requested ID +- Fixed `$#1234` searches with non-LORA / non-Checkpoint targets returning no results when the **Model Type** dropdown wasn't set to **All** β€” ID searches intentionally bypass the type and base-model filters in the request, but the post-response check was still rejecting the returned model when its type didn't match the dropdown +- Fixed clicking a CivitAI model card with an empty version list appearing to do nothing for ~1–2s while the recovery round-trip runs β€” the clicked card now shows a "Loading..." state during the recovery, and the recovered version data is cached on the card so subsequent clicks are instant +- Fixed "Invalid download link" error when using the browser extension +- Fixed downloaded checkpoint going to StableDiffusion folder when a saved download preference existed, even for GGUF files that should always go to DiffusionModels +- Fixed potential crash when adding metadata to malformed or non-PNG image data in Inference +- Fixed non-Latin-1 characters (e.g. Japanese, Chinese, Korean, emoji) in image generation parameters being stored in PNG tEXt chunks, violating the PNG specification and causing character corruption (mojibake) in standard-compliant parsers. Non-Latin-1 content now uses spec-compliant iTXt chunks with proper UTF-8 encoding ([#1535](https://github.com/LykosAI/StabilityMatrix/issues/1535)) +- Fixed batch notification firing when only one image is generated +### Security +- Updated the bundled 7-Zip binaries (Windows, Linux, macOS) to **26.01**, which includes the fix for the NTFS heap buffer overflow CVE-2026-48095 ([GitHub Security Lab GHSL-2026-140](https://securitylab.github.com/advisories/GHSL-2026-140_7-Zip/), CVSS 8.8) and brings years of accumulated upstream security fixes β€” the Windows binary in particular had been pinned at the 2018 18.01 release +- Package updates now re-run the prerequisite setup step (as installs already do), so the bundled 7-Zip binary is refreshed on update instead of only on a fresh package install +### Supporters +#### 🌟 Visionaries +An enormous thank you to our incredible Visionaries: **Waterclouds**, **bluepopsicle**, **Ibixat**, **Droolguy**, **snotty**, **LG**, **whudunit**, **MrMxyzptlk12836**, **Psilocyfer18731**, **KalAbaddon**, and **moon_milky2843**! This was a huge release, and every bit of it rests on your generosity. Whether you've been cheering us on for years or only just joined, having you in our corner is what makes all of this possible. We're so grateful for you. πŸ’› +#### πŸš€ Pioneers +And what a Pioneer crew this release! A heartfelt thank you to the regulars who keep showing up for us: **Szir777**, **[USA]TechDude**, **SinthCore**, **Jisuren**, **Tigon**, **jweg79**, **rwx14662**, **Hurbie53**, **ahnhj.al**, **drew.lukas**, **Tuskaruho**, **Cjloha**, **Alligator1907**, **Bitti**, **damianpointdexter**, **Ghislain G**, and **tmdcks**! Your steady support release after release is what keeps us at the keyboard. And the warmest of welcomes to our newest Pioneers: **CommissarGiygas16050**, **qob97515211**, **bastardofbethlehem**, and **Zombop** β€” we're thrilled you've joined us, and we can't wait to get to know you! (And to our anonymous Pioneer out there too β€” our thanks reaches you. πŸ’›) + +## v2.16.0-pre.2 +### Security +- Updated the bundled 7-Zip binaries (Windows, Linux, macOS) to **26.01**, which includes the fix for the NTFS heap buffer overflow CVE-2026-48095 ([GitHub Security Lab GHSL-2026-140](https://securitylab.github.com/advisories/GHSL-2026-140_7-Zip/), CVSS 8.8) and brings years of accumulated upstream security fixes β€” the Windows binary in particular had been pinned at the 2018 18.01 release +- Package updates now re-run the prerequisite setup step (as installs already do), so the bundled 7-Zip binary is refreshed on update instead of only on a fresh package install +### Added +- Added a **Source** button in the Inference SamplerCard that one-click matches your generation Width/Height to the loaded source image β€” available in Image-to-Image whenever a source image is selected +- Added popularity counts to booru-style tag completions in the prompt editor; descriptions now show entries like `12.3K Β· artist` so the more common tags are easier to spot at a glance +- Added a settings gear button to the CivitAI browser's Base Models filter flyout that jumps straight to the base model filter configuration in Settings +- Added a **Bitsandbytes NF4** launch option to Stable Diffusion WebUI Forge - Neo for low-bit (`--bnb`) inference +- Added an **Activity center**: the sidebar download panel now has a **Notifications** tab alongside **In Progress**. Toasts are clickable β€” jumping to the downloaded folder, the originating page (e.g. Inference), or the activity panel β€” and persist into a session notification history (every notification is recorded, even ones suppressed by your settings) with read/unread indicators and a combined unread + active-download badge on the sidebar item +- Added **Gemini 3.1 Flash (Nano Banana 2)** as a new cloud provider in Image Lab β€” Google's latest fast image model, sitting between Gemini 2.5 Flash and Gemini 3 Pro in the dropdown. Uses the newer `thinking_level` config when thinking is enabled, and falls back to the model's default behavior otherwise +- Added **Flux.2 Klein** as a new local provider in Image Lab. Klein 4B is **Apache 2.0 licensed** (commercial use free, unlike Flux.1 Kontext's non-commercial license), runs in just 4 sampling steps via the distilled variant, and supports up to 4 reference images per edit. Users with the Klein 9B UNET + matching Qwen3 8B text encoder dropped into their model folders will see those picked up automatically by the model dropdown +- Added **Steps and CFG sliders** to the Flux.2 Klein settings panel. The defaults snap automatically to the recommended values for the selected variant (4 / 1 for distilled, 20 / 5 for base β€” including community 9B fine-tunes and "base & distilled" merge listings), and you can override either at any time +- Added an **"Always Show Scrollbars"** toggle under **Settings β†’ Appearance**. Defaults on β€” vertical scrollbars stay visible at their full thickness and reserve real layout space instead of fading to a thin overlay-style bar that only thickens on hover. Toggle off to restore Avalonia's classic auto-hide behavior. Single-line numeric inputs (e.g. SamplerCard Width/Height) keep their auto-hide regardless so spin-buttons aren't followed by a phantom bar +- Added new shared model folder categories β€” **Style Models**, **Audio Encoders**, **Model Patches**, and **Background Removal** β€” for ComfyUI's `style_models`, `audio_encoders`, `model_patches`, and `background_removal` directories. Models in these folders are now indexed and symlinked alongside everything else (e.g. Flux Redux / B-Lora style models, audio encoders for video/audio workflows, BiRefNet background-removal models) +- Added a **Download progress indicator** to the Image Lab status banner. While a model-download batch is in flight, the banner shows "⬇️ Downloading models (N/Total)..." with the count bumping as each file completes, instead of continuing to display the "missing models" warning +- Greatly expanded native **Windows ROCm (AMD GPU)** support ([#1629](https://github.com/LykosAI/StabilityMatrix/pull/1629)) β€” the GPU detection matrix now spans **Vega / GCN5** (Vega 56/64, Radeon VII) through the entire **RDNA1/2/3/3.5** lineup and into **RDNA4** (RX 9070 / R9700), using TheRock ROCm Technical Preview PyTorch builds. ROCm install and launch now run through a shared helper, so the same Windows-native path is available to **ComfyUI, SwarmUI** (its Comfy backend), **reForge, InvokeAI, and Wan2GP** - thanks to @NeuralFault! +- Added optional **ROCm package commands** for ComfyUI on Windows β€” one-click install of **SageAttention**, **Flash Attention**, **bitsandbytes**, and the **ROCm SDK devel** module (for compiling extensions/modules against your installed ROCm) from the package's command menu - thanks to @NeuralFault! +### Changed +- Tidied up the Inference SamplerCard dimensions section β€” Source/Presets actions are shown as labeled buttons below the dimension row +- Promoted the Encoder Type selector in the Inference Model card out of Advanced Options up to the main card body, so it's visible whenever a non-Auto workflow profile is active (and always when **Custom** is selected) +- Local model autocomplete in the prompt editor now uses substring matching instead of prefix-only β€” typing any part of a model's filename surfaces it, with names that start with your search still ranked first +- Updated the bundled Qwen Image Edit model used by Image Lab to the **2511** build (Alibaba's November 2025 refresh) for better edit consistency, reduced drift on multi-turn edits, and stronger character/geometry preservation. Existing 2509 downloads continue to work β€” only newly installed setups pull the new file +- Improved the Gemini API error message in Image Lab when the API returns 401/403 to point users at Google's API key restriction policy (which starts blocking unrestricted keys on June 19 2026) +- Updated AI-Toolkit to install torch 2.9.1 / torchvision 0.24.1 / torchaudio 2.9.1 from the cu128 index to match upstream (ostris/ai-toolkit), with a cu126 fallback for legacy NVIDIA GPUs; also pin numpy to 1.26.4 to avoid a numpy 2.x ABI break in scipy/diffusers that crashed training runs +- Pinned kohya_ss torch to 2.7.0 / torchvision 0.22.0 (cu128) to match upstream's requirements_pytorch_windows.txt instead of resolving an untested latest, keeping the cu126 legacy-GPU fallback +- Pinned reForge torch to 2.9.0 to match upstream (modules/launch_utils.py) +- Upgraded the bundled Visual C++ redistributable from 2015–2019 (v16) to 2015–2022 (v17, build 14.40.33810+), required by modern native dependencies such as PyTorch and ONNX Runtime +- The CivitAI model details page now collapses the preview-image area and shows a small "No preview images available" hint when a model has no images to display, letting the description card take the full vertical space instead of leaving a large empty region above it +- Restructured the Image Lab provider settings panels (Flux Kontext, Qwen Image Edit, Flux.2 Klein) β€” every provider now shows the model dropdown on top and a bordered **LoRA card** below it spanning full width, with header + Add button + selected-LoRA list all inside one consistent card. Klein gets a second column with the Steps/CFG sliders to the right of its model dropdown. All three panels collapse to a single stacked column when the chat panel is narrow (< 720 px), driven by a pure-XAML `Classes.compact` width-binding via a new `WidthLessThanConverter` +- The Inference checkpoint dropdown no longer **resets its scroll position** every time the model list refreshes. The refresh now applies a single combined (local + remote) diff to the underlying source cache, rather than first resetting to local-only and then re-adding remote entries β€” which previously caused the open dropdown to scroll back to the top +- **PyTorch TunableOp** is now disabled by default on Windows ROCm. If you have an existing TunableOp cache and want to keep tuning GEMM kernels, add `PYTORCH_TUNABLEOP_ENABLED=1` under **Settings β†’ Environment Variables** - thanks to @NeuralFault! +### Fixed +- Fixed Inference text encoder selections being cleared when navigating away from and back to the Inference tab β€” encoder slots now ignore the transient null the model dropdown reports while its list refreshes +- Fixed [#1585](https://github.com/LykosAI/StabilityMatrix/issues/1585) - FluxGym installs/updates pulling an incompatible `transformers` version β€” installs now pin `transformers==4.54.1` and exclude it from the default requirements pass +- Fixed [#1641](https://github.com/LykosAI/StabilityMatrix/issues/1641) - Cogstudio failing to set up its `inference/gradio_composite_demo` directory when the parent path didn't already exist +- Fixed [#1650](https://github.com/LykosAI/StabilityMatrix/issues/1650) - ComfyUI-Manager extension installs failing on Linux with `File not found: venv/uv-build-constraints.txt` by no longer leaking the relative build-constraints path into the running package's environment +- Fixed [#1645](https://github.com/LykosAI/StabilityMatrix/issues/1645) - Strix Halo / Radeon 8060S and other `Display controller`-class integrated GPUs not appearing in the GPU list on Linux +- Fixed [#1643](https://github.com/LykosAI/StabilityMatrix/issues/1643) - package install and launch failures when `sitecustomize.py` or its compiled bytecode was corrupted by external software (e.g. some antivirus suites); the file now self-heals when out of date and its startup actions can no longer abort interpreter startup +- Fixed the CivitAI model browser requiring two clicks of Search to show results when all base-model filters were selected β€” a leftover post-response sanity check from the old single-select base-model UI was rejecting the response the first time, requiring a second search to surface the cached results +- Fixed CivitAI model cards showing "No versions available" when clicked for some models (typically recently uploaded or updated) even though the model has downloadable versions on the website β€” the app now retries with a different lookup path when the initial response comes back missing version data +- Fixed `$#1234` and `civitai.com/models/1234` URL searches returning zero results for some models that exist and are downloadable on the website β€” the app now retries via a per-model lookup when the batch search misses a requested ID +- Fixed `$#1234` searches with non-LORA / non-Checkpoint targets returning no results when the **Model Type** dropdown wasn't set to **All** β€” ID searches intentionally bypass the type and base-model filters in the request, but the post-response check was still rejecting the returned model when its type didn't match the dropdown +- Fixed clicking a CivitAI model card with an empty version list appearing to do nothing for ~1–2s while the recovery round-trip runs β€” the clicked card now shows a "Loading..." state during the recovery, and the recovered version data is cached on the card so subsequent clicks are instant +- Fixed "Invalid download link" error when using the browser extension +- Fixed the Image Lab status banner **flickering** between "Click Connect" and the button+text variants every second while ComfyUI was starting up β€” the per-attempt retry loop was toggling `IsConnecting` on/off, which propagated through `CanUserConnect` to the button's visibility. The retry loop now holds `IsWaitingForConnection = true` for its entire duration so the banner stays parked on "πŸ”„ Connecting to ComfyUI..." +- Fixed Image Lab thinking/reasoning content (Gemini 3 Pro / Gemini 3.1 Flash) where the **last few characters of long lines were clipped** under the vertical scrollbar. The HTML body now reserves a 12 px right gutter via CSS padding so text always wraps before reaching the scrollbar zone +- Fixed ComfyUI workflow rejections in Image Lab being surfaced as just "400 Bad Request" with no detail β€” the provider now logs the full JSON `node_errors` payload returned by ComfyUI so the failing node and validation error are visible in the log and the user-facing error message +- Fixed a phantom up/down arrow pair appearing next to the spin buttons on `NumberBox` / `NumericUpDown` controls (e.g. SamplerCard Width/Height) after the always-show scrollbar change went in β€” TextBox-derived controls have an internal vertical `ScrollBar` in their template that should stay hidden when content fits, and the global override was forcing it visible. A `TextBox ScrollBar:vertical` carve-out now keeps those on auto-hide regardless of the global setting +- Fixed user-set environment variables not overriding package-configured ROCm variables on Windows β€” launch env vars now layer as **helper defaults β†’ package config β†’ user-set**, so anything you set under **Settings β†’ Environment Variables** always wins - thanks to @NeuralFault! +- Fixed `pip show` "package not found" results being raised as exceptions instead of treated as a missing-package state, which could block installing optional modules (e.g. SageAttention, ROCm SDK) into a package's venv - thanks to @NeuralFault! +### Supporters +#### 🌟 Visionaries +An enormous thank you to our incredible Visionaries: **Waterclouds**, **bluepopsicle**, **Ibixat**, **Droolguy**, **snotty**, **LG**, **whudunit**, **MrMxyzptlk12836**, **Psilocyfer18731**, **KalAbaddon**, and **moon_milky2843**! Your steadfast generosity is the foundation every feature and fix in this release is built on. Whether you've been with us for ages or joined just last release, we're endlessly grateful you're in our corner. Thank you for believing in what we're building. πŸ’› + +## v2.16.0-pre.1 ### Added +- Added CivArchive model browser with details page, image viewer, version selector, trigger words, and in-app downloads with tracked progress - Added support for the civitai.red (mature-content) domain β€” NSFW CivitAI links now open and copy as civitai.red URLs, and pasting a civitai.red URL into the CivitAI model browser search works the same as a civitai.com URL +- Added official Inference support for the **Z-Image** (Base + Turbo), **Anima**, and **Flux.2** model architectures β€” workflow-appropriate text encoders, latent shapes, schedulers, and model sampling (AuraFlow for Z-Image, `Flux2Scheduler` for Flux.2) are wired up automatically across Text-to-Image and Image-to-Image +- Added an Inference **Workflow** selector to the Model card with profiles for Default/Checkpoint, Flux, Flux.2, Z-Image Base/Turbo, Anima, HiDream, and Custom + - **Auto** (default) detects the workflow from the model's CivitAI metadata, with filename fallbacks for models without metadata, and shows the resolved profile inline below the selector + - Sparkle button applies recommended sampler / scheduler / steps / CFG presets for the active workflow β€” e.g. `res_multistep` / `simple` / 8 steps / CFG 1 for Z-Image Turbo, `er_sde` / `simple` / 30 steps / CFG 4 for Anima, `euler` / 20 steps / CFG 5 for Flux.2 + - Choosing a non-Auto profile reveals a manual Encoder Type selector for advanced overrides (e.g. running Z-Image Turbo with the `sd3` encoder) + - Opening the model browser from the Model card pre-filters to the workflow's compatible base models, without overwriting your saved picker filters +- Added `er_sde` and `res_multistep` to the Inference sampler list +- Added `stable_diffusion`, `flux2`, and `lumina2` Encoder Type options for UNet workflows +- Added a checkpoint organizer for previewing and reorganizing local models using connected metadata-driven folder and filename patterns (requested in [#280](https://github.com/LykosAI/StabilityMatrix/issues/280), [#424](https://github.com/LykosAI/StabilityMatrix/issues/424)) ### Changed - The CivitAI base model type filter now uses CivitAI's official `/api/v1/enums` endpoint, with fallbacks to the previous technique and a built-in list, so the filter stays populated even if the CivitAI response format changes or the service is unreachable +- Single-encoder UNet workflows (Anima, Flux.2, Z-Image) now use the matching CLIPLoader instead of assuming Flux-style dual encoders ### Fixed +- Fixed CivitAI model browsing breaking during Discovery API outages β€” the browser now falls back to the direct CivitAI API when Discovery returns a server error, authentication failure, or times out +- Fixed UNet-only Inference model selection sometimes clearing during model-list refreshes β€” text encoder slots no longer disappear after generating, cancelling a generation, or reconnecting to ComfyUI +- Fixed SwarmUI user settings (theme, output format, server configuration, etc.) and any user-added backend entries being overwritten when the install flow ran over an existing install β€” `Settings.fds` and `Backends.fds` are now merged with their existing contents instead of being rewritten from a stale template +- Fixed pip requirements handling for environment-marker dependencies - thanks to @NeuralFault! - Fixed [#1608](https://github.com/LykosAI/StabilityMatrix/issues/1608) - Crash when cdn fetch fails due to error notification not being shown on UI Thread - thanks to @NeuralFault! +- Fixed ComfyUI-Zluda inheriting `--enable-manager` from the base ComfyUI launch options, which blocked the bundled custom-node manager from initializing - thanks to @NeuralFault! +### Supporters +#### 🌟 Visionaries +So much love to our Visionaries β€” **Waterclouds**, **bluepopsicle**, **Ibixat**, **Droolguy**, **snotty**, **LG**, and **whudunit** β€” thank you for your continued enthusiasm, kindness, and sheer staying-power. You've been with us through some big changes, and we're so lucky to have you in our corner. And the warmest welcome to our newest Visionaries **MrMxyzptlk12836**, **Psilocyfer18731**, **KalAbaddon**, **RustCupcake**, and **moon_milky2843** β€” we're so happy you're here, and we can't wait to get to know you. πŸ’› + +## v2.16.0-dev.3 +### Added +- Added enable/disable toggle for environment variables in Settings, allowing variables to be temporarily disabled without deleting them +- Added single-instance window activation signaling so reopening the app restores and focuses the existing desktop window instead of launching a duplicate instance +- Added notification system with localizable banner and markdown detail dialog UI +- Added warning in data directory selector when a OneDrive folder is selected +- Added support in the Checkpoints page to distinguish standard updates from Early Access-only updates - thanks to @x0x0b! +- Added torch index for Strix/Gorgon Point Ryzen AI APUs on Windows - thanks to @NeuralFault! +- Added retry button to failed downloads - thanks to @NeuralFault! +- Added new Membership support in Account Settings with Patreon migration prompt +### Changed +- Improved safetensor checkpoint classification to correctly detect UNet-only models for Wan Video, HiDream, Z-Image, Hunyuan3D, and diffusers-format Flux architectures, ensuring they are routed to the DiffusionModels folder +- GGUF checkpoint downloads now go directly to the DiffusionModels folder instead of StableDiffusion +- Configured portable Git to suppress detached HEAD advice messages +- Settings file saves are now atomic to prevent corruption from interrupted writes +- Updated torch indexes for A1111, ComfyUI, InvokeAI, and Forge-based UIs to rocm7.2 / cu128 depending on GPU - thanks to @NeuralFault! +- Replaced the "Become a Patron" footer button with "Support Us", linking to the new direct Lykos support page at lykos.ai/membership +- Updated the prompt dialog shown when enabling features like Accelerated Model Discovery to use Lykos accounts instead of Patreon linking +- Moved the Patreon connection in Account Settings to a new "Legacy Connections" section, only shown for users with an existing Patreon link +- Localized previously hardcoded strings on the Account Settings page (menu items, descriptions, section headers) and added Japanese, Korean, German, and French translations +### Fixed +- Fixed the Package Manager "Add Package" teaching tip opening inopportunely while packages were still loading or after opening the add-package dialog +- Fixed downloaded checkpoint going to StableDiffusion folder when a saved download preference existed, even for GGUF files that should always go to DiffusionModels +- Fixed potential crash when adding metadata to malformed or non-PNG image data in Inference +- Fixed non-Latin-1 characters (e.g. Japanese, Chinese, Korean, emoji) in image generation parameters being stored in PNG tEXt chunks, violating the PNG specification and causing character corruption (mojibake) in standard-compliant parsers. Non-Latin-1 content now uses spec-compliant iTXt chunks with proper UTF-8 encoding ([#1535](https://github.com/LykosAI/StabilityMatrix/issues/1535)) +- Fixed an issue where `Align Your Steps` scheduler and Unet Loader workflows ignored Regional Prompting (and other addon) conditioning modifiers. +- Fixed bold text not rendering in markdown dialogs on Windows 11 due to Avalonia 11.3.x variable font regression with Segoe UI Variable Text +- Fixed Japanese text appearing compressed/squished in markdown dialogs by ensuring the bundled NotoSansJP font is used for CTextBlock rendering +- Fixed ContentDialog title and buttons not using the correct font for Japanese locale (NotoSansJP) when shown as overlay +- Added missing `CBold` and `CItalic` inline styles to the markdown style sheet +- Fixed downloads failing with "The request message was already sent" when the server doesn't return Content-Length on the first attempt, caused by reusing a consumed HttpRequestMessage in the retry loop +- Fixed downloads from sources that redirect to CivitAI/HuggingFace (e.g. CivArchive) failing with Unauthorized by resolving the redirect target URL and applying auth headers for the correct domain +- Fixed dropdown menu overlayed in Inference UI Model Cards not being scrollable on Linux - thanks to @NeuralFault! +- Fixed model downloads failing on VPN connections - thanks to @NeuralFault! +- Fixed [#1598](https://github.com/LykosAI/StabilityMatrix/issues/1598) - download progress bar showing 100% immediately for fresh downloads due to missing Content-Length fallback when Content-Range header is absent +- Fixed [#1597](https://github.com/LykosAI/StabilityMatrix/issues/1597) - reForge launch failing due to setuptools version +- Fixed [#1596](https://github.com/LykosAI/StabilityMatrix/issues/1596) - package installs and managed embedded Python startup being poisoned by inherited shell Python activation variables such as `PYTHONHOME`, `PYTHONPATH`, `VIRTUAL_ENV`, and Conda environment variables +- Fixed [#1590](https://github.com/LykosAI/StabilityMatrix/issues/1590) - Startup crash when settings file is corrupted. Settings files are now self-healing with automatic recovery from null bytes, truncated JSON, and missing brackets +- Potentially fixed [#1578](https://github.com/LykosAI/StabilityMatrix/issues/1578) - `SocketException: Address already in use` on Linux startup by cleaning stale interprocess socket files and reactivating the existing window +- Fixed [#1397](https://github.com/LykosAI/StabilityMatrix/issues/1397), [#610](https://github.com/LykosAI/StabilityMatrix/issues/610) - duplicate pip package entries in results - thanks to @e-nord! +### Supporters +#### 🌟 Visionaries +A heartfelt thank you to our incredible Visionaries: **Waterclouds**, **JungleDragon**, **bluepopsicle**, **Bob S**, and **whudunit** - every feature, fix, and late-night breakthrough in this release carries your fingerprints. A huge welcome to our newest Visionaries **Droolguy** and **snotty** (leveling up from the Pioneer ranks!), a warm welcome back to longtime Visionary **Ibixat**, and an equally huge welcome to **LG**, making their Stability Matrix debut straight at the Visionary tier! You're the reason we can keep building bold things - and an extra-special thank you to everyone now supporting us directly through our new platform. Your trust in this next chapter means the world! + +## v2.16.0-dev.2 +### Added +- Added Regional Prompting addon to Inference - paint detailed masks to apply different prompts, strengths, and settings to specific regions of your image + - Multi-layer mask editor with Photoshop-style interface for managing layers with independent masks, prompts, colors, and opacity + - Professional brush tools: freehand brush/eraser with pressure sensitivity, rectangle/ellipse shapes with fill/stroke modes, paint bucket flood fill + - Brush feathering/softness control for smooth, blended mask edges (0 = hard edge, 1 = soft/blurred) + - Per-layer prompt and strength controls, export/import masks as PNG, duplicate layers, image reference layers for tracing + - GPU-accelerated rendering with compact gzip-compressed metadata serialization +- Added new Model Picker dialog for Inference with grid/list views, search, filtering, and NSFW overlay +- Added browse buttons to all model dropdowns in Inference (Model, Refiner, VAE, Text Encoders, CLIP Vision) +- Added inline search box to model combo box dropdowns with fuzzy matching +- Added NVIDIA driver version warning when launching ComfyUI with CUDA 13.0 (cu130) and driver versions below 580.x +- Added legacy Python warning when launching InvokeAI installations using Python 3.10.11 +- Added Tiled VAE Decode to the Inference video workflows - thanks to @NeuralFault! +- Added recoverable error dialog for UI thread exceptions, with option to continue instead of exiting +### Changed +- Disabled update checking for legacy InvokeAI installations using Python 3.10.11 +- Hide rating stars in the Civitai browser page if no rating is available +- Updated uv to v0.9.30 +- Updated PortableGit to v2.52.0.windows.1 +- Updated Sage/Triton/Nunchaku installers to use GitHub API to fetch latest releases +- Updated ComfyUI installations and updates to automatically install ComfyUI Manager +- Updated gfx110X Windows ROCm nightly index - thanks to @NeuralFault! +- Updated ComfyUI-Zluda install to more closely match the author's intended installation method - thanks to @NeuralFault! +- Updated Forge Classic installs/updates to use the upstream install script for better version compatibility with torch/sage/triton/nunchaku +### Fixed +- Fixed parsing of escape sequences in Inference such as `\\` +- Fixed batch notification firing when only one image is generated +- Fixed [#1546](https://github.com/LykosAI/StabilityMatrix/issues/1546), [#1541](https://github.com/LykosAI/StabilityMatrix/issues/1541) - "No module named 'pkg_resources'" error when installing Automatic1111/Forge/reForge packages +- Fixed [#1545](https://github.com/LykosAI/StabilityMatrix/issues/1545), [#1518](https://github.com/LykosAI/StabilityMatrix/issues/1518), [#1513](https://github.com/LykosAI/StabilityMatrix/issues/1513), [#1488](https://github.com/LykosAI/StabilityMatrix/issues/1488) - Forge Neo update breaking things +- Fixed [#1529](https://github.com/LykosAI/StabilityMatrix/issues/1529) - "Selected commit is null" error when installing packages and rate limited by GitHub +- Fixed [#1525](https://github.com/LykosAI/StabilityMatrix/issues/1525) - Crash after downloading a model +- Fixed [#1523](https://github.com/LykosAI/StabilityMatrix/issues/1523), [#1499](https://github.com/LykosAI/StabilityMatrix/issues/1499), [#1494](https://github.com/LykosAI/StabilityMatrix/issues/1494) - Automatic1111 using old stable diffusion repo +- Fixed [#1505](https://github.com/LykosAI/StabilityMatrix/issues/1505) - incorrect port argument for Wan2GP +- Possibly fix [#1502](https://github.com/LykosAI/StabilityMatrix/issues/1502) - English fonts not displaying correctly on Linux in Chinese environments +- Fixed [#1476](https://github.com/LykosAI/StabilityMatrix/issues/1476) - Incorrect shared output folder for Forge Classic/Neo +- Fixed [#1466](https://github.com/LykosAI/StabilityMatrix/issues/1466) - crash after moving portable install +- Fixed [#1445](https://github.com/LykosAI/StabilityMatrix/issues/1445) - Linux app updates not actually updating - thanks to @NeuralFault! +### Supporters +#### 🌟 Visionaries +Huge shoutout to our amazing Visionaries: **Waterclouds**, **JungleDragon**, **bluepopsicle**, **Bob S**, and **whudunit**! Your continued support fuels every new feature and improvement in Stability Matrix. We couldn't do it without you - thank you for believing in what we're building! + +## v2.16.0-dev.1 +### Added +#### New Feature: πŸ§ͺ Image Lab - Conversational Image Generation for ComfyUI +- We've added a brand new conversational interface for image generation! Image Lab lets you iterate on images naturally through chat, rather than just one-off prompts. + - Local-First Power: Native support for Flux Kontext and Qwen Image Edit running entirely locally via your ComfyUI backend. + - Smart Setup: Stability Matrix automatically detects and helps you download the specific models and LoRAs needed for these local workflows. + - Interactive Tools: Drag-and-drop image inputs, use the built-in annotation tool to draw on images, and keep persistent conversation history. + - Cloud Option: Includes optional support for Nano Banana (Gemini 3 Pro/2.5) for users who want to leverage external reasoning models. +- Added new package - [Wan2GP](https://github.com/deepbeepmeep/Wan2GP) +- Added [Stable Diffusion WebUI Forge - Neo](https://github.com/Haoming02/sd-webui-forge-classic/tree/neo) as a separate package for convenience +- Added Intel GPU support for ComfyUI +- Added "Run Python Command" option to the package card's 3-dots menu for running arbitrary Python code in the package's virtual environment +- Added togglable `--uv` argument to the SD.Next launch options +- Added Tiled VAE decoding as an Inference addon thanks to @NeuralFault! +### Changed +- Moved the original Stable Diffusion WebUI Forge to the "Legacy" packages tab due to inactivity +- Updated to cu130 torch index for ComfyUI installs with Nvidia GPUs +- Consolidated and fixed AMD GPU architecture detection +- Updated SageAttention installer to latest v2.2.0-windows.post4 version +- Video files can now be opened directly from the Output browser +- Videos will now appear with thumbnails in the Output browser +### Fixed +- Fixed [#1450](https://github.com/LykosAI/StabilityMatrix/issues/1450) - Older SD.Next not launching due to forced `--uv` argument +- Fixed duplicate custom node installations when installing workflows from the Workflow Browser - thanks again to @NeuralFault! +### Supporters +#### 🌟 Visionaries +A massive thank you to our esteemed Visionaries: **Waterclouds**, **JungleDragon**, **bluepopsicle**, **Bob S**, and **whudunit**! Your generosity is the powerhouse behind Stability Matrix, enabling us to keep building and refining with confidence. We are truly grateful for your partnership! + +## v2.15.8 +### Added +- Added support for the civitai.red (mature-content) domain β€” NSFW CivitAI links now open and copy as civitai.red URLs, and pasting a civitai.red URL into the CivitAI model browser search works the same as a civitai.com URL +### Changed +- The CivitAI base model type filter now uses CivitAI's official `/api/v1/enums` endpoint, with fallbacks to the previous technique and a built-in list, so the filter stays populated even if the CivitAI response format changes or the service is unreachable +### Fixed - Fixed CivitAI model browsing breaking during Discovery API outages β€” the browser now falls back to the direct CivitAI API when Discovery returns a server error, authentication failure, or times out - Fixed SwarmUI user settings (theme, output format, server configuration, etc.) and any user-added backend entries being overwritten when the install flow ran over an existing install β€” `Settings.fds` and `Backends.fds` are now merged with their existing contents instead of being rewritten from a stale template - Fixed pip requirements handling for environment-marker dependencies - thanks to @NeuralFault! diff --git a/StabilityMatrix.Avalonia/App.axaml b/StabilityMatrix.Avalonia/App.axaml index 3bbd19816..5c1cbbf4f 100644 --- a/StabilityMatrix.Avalonia/App.axaml +++ b/StabilityMatrix.Avalonia/App.axaml @@ -56,6 +56,7 @@ + @@ -89,6 +90,7 @@ + diff --git a/StabilityMatrix.Avalonia/App.axaml.cs b/StabilityMatrix.Avalonia/App.axaml.cs index 60ffe0216..0b3d14f8f 100644 --- a/StabilityMatrix.Avalonia/App.axaml.cs +++ b/StabilityMatrix.Avalonia/App.axaml.cs @@ -51,6 +51,7 @@ using StabilityMatrix.Avalonia.ViewModels.Progress; using StabilityMatrix.Avalonia.Views; using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Api.Handlers; using StabilityMatrix.Core.Api.LykosAuthApi; using StabilityMatrix.Core.Api.PromptGenApi; using StabilityMatrix.Core.Attributes; @@ -65,6 +66,7 @@ using StabilityMatrix.Core.Models.Settings; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.ImageGeneration; using StabilityMatrix.Core.Updater; using ApiOptions = StabilityMatrix.Core.Models.Configs.ApiOptions; using Application = Avalonia.Application; @@ -149,6 +151,11 @@ public override void OnFrameworkInitializationCompleted() { base.OnFrameworkInitializationCompleted(); + if (!Debugger.IsAttached || Program.Args.DebugExceptionDialog) + { + Dispatcher.UIThread.UnhandledException += Dispatcher_UnhandledException; + } + if (Design.IsDesignMode) { DesignData.DesignData.Initialize(); @@ -389,6 +396,7 @@ internal static void ConfigurePageViewModels(IServiceCollection services) { provider.GetRequiredService(), provider.GetRequiredService(), + provider.GetRequiredService(), provider.GetRequiredService(), provider.GetRequiredService(), provider.GetRequiredService(), @@ -503,8 +511,21 @@ internal static IServiceCollection ConfigureServices(bool disableMessagePipeInte { services.AddSingleton(); services.AddSingleton(p => p.GetRequiredService()); + + // BananaVision has its own database to preserve conversations when main DB is cleared + services.AddSingleton(); + services.AddSingleton(p => p.GetRequiredService()); } + // Image generation services + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddTransient(_ => { var client = new GitHubClient(new ProductHeaderValue("StabilityMatrix")); @@ -728,7 +749,7 @@ internal static IServiceCollection ConfigureServices(bool disableMessagePipeInte } ) .ConfigurePrimaryHttpMessageHandler(() => new HttpClientHandler { AllowAutoRedirect = false }) - .AddPolicyHandler(retryPolicy) + .AddPolicyHandler(retryPolicyLonger) .AddHttpMessageHandler(serviceProvider => new TokenAuthHeaderHandler( serviceProvider.GetRequiredService() )); @@ -764,6 +785,19 @@ internal static IServiceCollection ConfigureServices(bool disableMessagePipeInte }) .AddPolicyHandler(retryPolicy); // Assuming retryPolicy is suitable + services + .AddRefitClient(defaultRefitSettings) + .ConfigureHttpClient(c => + { + c.BaseAddress = new Uri("https://generativelanguage.googleapis.com"); + c.Timeout = TimeSpan.FromMinutes(5); // Higher timeout for image generation + }) + .AddHttpMessageHandler() + .AddPolicyHandler(retryPolicyLonger); + + // Register GeminiApiKeyHandler + services.AddTransient(); + // Apizr clients services.AddApizrManagerFor(options => { @@ -1039,6 +1073,14 @@ private static void OnServiceProviderDisposing(ServiceProvider serviceProvider) Logger.Trace("Disposing {Count} Disposables", disposables.Count); } + private static void Dispatcher_UnhandledException(object? sender, DispatcherUnhandledExceptionEventArgs e) + { + if (Program.ShowExceptionDialog(e.Exception, true)) + { + e.Handled = true; + } + } + private static void TaskScheduler_UnobservedTaskException( object? sender, UnobservedTaskExceptionEventArgs e diff --git a/StabilityMatrix.Avalonia/Assets.cs b/StabilityMatrix.Avalonia/Assets.cs index c44005101..62047b196 100644 --- a/StabilityMatrix.Avalonia/Assets.cs +++ b/StabilityMatrix.Avalonia/Assets.cs @@ -105,7 +105,7 @@ internal static class Assets new RemoteResource { Url = new Uri("https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip"), - HashSha256 = "608619f8619075629c9c69f361352a0da6ed7e62f83a0e19c63e0ea32eb7629d" + HashSha256 = "608619f8619075629c9c69f361352a0da6ed7e62f83a0e19c63e0ea32eb7629d", } ), ( @@ -115,7 +115,7 @@ internal static class Assets Url = new Uri( "https://github.com/indygreg/python-build-standalone/releases/download/20230507/cpython-3.10.11+20230507-x86_64-unknown-linux-gnu-install_only.tar.gz" ), - HashSha256 = "c5bcaac91bc80bfc29cf510669ecad12d506035ecb3ad85ef213416d54aecd79" + HashSha256 = "c5bcaac91bc80bfc29cf510669ecad12d506035ecb3ad85ef213416d54aecd79", } ), ( @@ -124,7 +124,48 @@ internal static class Assets { // Requires our distribution with signed dylib for gatekeeper Url = new Uri("https://cdn.lykos.ai/cpython-3.10.11-macos-arm64.zip"), - HashSha256 = "83c00486e0af9c460604a425e519d58e4b9604fbe7a4448efda0f648f86fb6e3" + HashSha256 = "83c00486e0af9c460604a425e519d58e4b9604fbe7a4448efda0f648f86fb6e3", + } + ) + ); + + /// + /// FFmpeg LGPL builds for video thumbnail generation. + /// + [SupportedOSPlatform("windows")] + [SupportedOSPlatform("linux")] + [SupportedOSPlatform("macos")] + public static RemoteResource FfmpegDownloadUrl => + Compat.Switch( + ( + PlatformKind.Windows | PlatformKind.X64, + new RemoteResource + { + // BtbN LGPL build - ffmpeg-n7.1-latest-win64-lgpl-7.1 + Url = new Uri( + "https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-n7.1-latest-win64-lgpl-7.1.zip" + ), + HashSha256 = "a77ecdc794d67401f3e4976f8856065f7762d74afd16f9c7b777ff0291a7bcaa", + } + ), + ( + PlatformKind.Linux | PlatformKind.X64, + new RemoteResource + { + // BtbN LGPL build - linux + Url = new Uri( + "https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-n7.1-latest-linux64-lgpl-7.1.tar.xz" + ), + HashSha256 = "d7d691dfa3a6d0a75362c02274a80a1f9635bd67908561aae31ee538853ab8ce", + } + ), + ( + PlatformKind.MacOS | PlatformKind.Arm, + new RemoteResource + { + // evermeet.cx build for macOS arm64 + Url = new Uri("https://evermeet.cx/ffmpeg/ffmpeg-7.1.1.zip"), + HashSha256 = "8d7917c1cebd7a29e68c0a0a6cc4ecc3fe05c7fffed958636c7018b319afdda4", } ) ); @@ -135,18 +176,18 @@ internal static class Assets new RemoteResource { Url = new Uri("https://cdn.lykos.ai/tags/danbooru.csv"), - HashSha256 = "b84a879f1d9c47bf4758d66542598faa565b1571122ae12e7b145da8e7a4c1c6" + HashSha256 = "b84a879f1d9c47bf4758d66542598faa565b1571122ae12e7b145da8e7a4c1c6", }, new RemoteResource { Url = new Uri("https://cdn.lykos.ai/tags/e621.csv"), - HashSha256 = "ef7ea148ad865ad936d0c1ee57f0f83de723b43056c70b07fd67dbdbb89cae35" + HashSha256 = "ef7ea148ad865ad936d0c1ee57f0f83de723b43056c70b07fd67dbdbb89cae35", }, new RemoteResource { Url = new Uri("https://cdn.lykos.ai/tags/danbooru_e621_merged.csv"), - HashSha256 = "ac405ebce8b0caae363a7ef91f89beb4b8f60a7e218deb5078833686da6d497d" - } + HashSha256 = "ac405ebce8b0caae363a7ef91f89beb4b8f60a7e218deb5078833686da6d497d", + }, }; public static Uri DiscordServerUrl { get; } = new("https://discord.com/invite/TUrgfECxHz"); diff --git a/StabilityMatrix.Avalonia/Assets/ImagePrompt.tmLanguage.json b/StabilityMatrix.Avalonia/Assets/ImagePrompt.tmLanguage.json index 41d7c00c9..84a76b4c8 100644 --- a/StabilityMatrix.Avalonia/Assets/ImagePrompt.tmLanguage.json +++ b/StabilityMatrix.Avalonia/Assets/ImagePrompt.tmLanguage.json @@ -320,4 +320,4 @@ ] } } -} +} \ No newline at end of file diff --git a/StabilityMatrix.Avalonia/Assets/hf-packages.json b/StabilityMatrix.Avalonia/Assets/hf-packages.json index c3dd9c135..457952327 100644 --- a/StabilityMatrix.Avalonia/Assets/hf-packages.json +++ b/StabilityMatrix.Avalonia/Assets/hf-packages.json @@ -1223,12 +1223,11 @@ { "ModelCategory": "Vae", "ModelName": "Flux.1 VAE", - "RepositoryPath": "black-forest-labs/FLUX.1-schnell", + "RepositoryPath": "Comfy-Org/Lumina_Image_2.0_Repackaged", "Files": [ - "ae.safetensors" + "split_files/vae/ae.safetensors" ], - "LicenseType": "Apache 2.0", - "LoginRequired": true + "LicenseType": "Flux.1 Dev NonCommercial" }, { "ModelCategory": "Vae", diff --git a/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs b/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs index c27d649d6..5246b857a 100644 Binary files a/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs and b/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs differ diff --git a/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs - LICENSE.txt b/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs - LICENSE.txt index 8650d994b..d3cd9b0d6 100644 --- a/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs - LICENSE.txt +++ b/StabilityMatrix.Avalonia/Assets/linux-x64/7zzs - LICENSE.txt @@ -1,15 +1,16 @@ - 7-Zip - ~~~~~ + 7-Zip for Linux and macOS + ~~~~~~~~~~~~~~~~~~~~~~~~~ License for use and distribution ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - 7-Zip Copyright (C) 1999-2023 Igor Pavlov. + 7-Zip Copyright (C) 1999-2026 Igor Pavlov. The licenses for 7zz and 7zzs files are: - The "GNU LGPL" as main license for most of the code - The "GNU LGPL" with "unRAR license restriction" for some code - The "BSD 3-clause License" for some code + - The "BSD 2-clause License" for some code Redistributions in binary form must reproduce related license information from this file. @@ -18,8 +19,8 @@ organization. You don't need to register or pay for 7-Zip. - GNU LGPL information - -------------------- +GNU LGPL information +-------------------- This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public @@ -37,52 +38,107 @@ - BSD 3-clause License - -------------------- +BSD 3-clause License in 7-Zip code +---------------------------------- - The "BSD 3-clause License" is used for the code in 7z.dll that implements LZFSE data decompression. - That code was derived from the code in the "LZFSE compression library" developed by Apple Inc, - that also uses the "BSD 3-clause License": + The "BSD 3-clause License" is used for the following code in 7z.dll + 1) LZFSE data decompression. + That code was derived from the code in the "LZFSE compression library" developed by Apple Inc, + that also uses the "BSD 3-clause License". + 2) ZSTD data decompression. + that code was developed using original zstd decoder code as reference code. + The original zstd decoder code was developed by Facebook Inc, + that also uses the "BSD 3-clause License". - ---- - Copyright (c) 2015-2016, Apple Inc. All rights reserved. + Copyright (c) 2015-2016, Apple Inc. All rights reserved. + Copyright (c) Facebook, Inc. All rights reserved. + Copyright (c) 2023-2026 Igor Pavlov. - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +Text of the "BSD 3-clause License" +---------------------------------- - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer - in the documentation and/or other materials provided with the distribution. +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. - 3. Neither the name of the copyright holder(s) nor the names of any contributors may be used to endorse or promote products derived - from this software without specific prior written permission. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE - COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ---- +3. Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--- - unRAR license restriction - ------------------------- - The decompression engine for RAR archives was developed using source - code of unRAR program. - All copyrights to original unRAR code are owned by Alexander Roshal. - The license for original unRAR code has the following restriction: +BSD 2-clause License in 7-Zip code +---------------------------------- - The unRAR sources cannot be used to re-create the RAR compression algorithm, - which is proprietary. Distribution of modified unRAR sources in separate form - or as a part of other software is permitted, provided that it is clearly - stated in the documentation and source comments that the code may - not be used to develop a RAR (WinRAR) compatible archiver. + The "BSD 2-clause License" is used for the XXH64 code in 7-Zip. + XXH64 code in 7-Zip was derived from the original XXH64 code developed by Yann Collet. - -- - Igor Pavlov + Copyright (c) 2012-2021 Yann Collet. + Copyright (c) 2023-2026 Igor Pavlov. + +Text of the "BSD 2-clause License" +---------------------------------- + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--- + + + + +unRAR license restriction +------------------------- + +The decompression engine for RAR archives was developed using source +code of unRAR program. +All copyrights to original unRAR code are owned by Alexander Roshal. + +The license for original unRAR code has the following restriction: + + The unRAR sources cannot be used to re-create the RAR compression algorithm, + which is proprietary. Distribution of modified unRAR sources in separate form + or as a part of other software is permitted, provided that it is clearly + stated in the documentation and source comments that the code may + not be used to develop a RAR (WinRAR) compatible archiver. + +-- diff --git a/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz b/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz index a7ea6fde5..4d038c946 100755 Binary files a/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz and b/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz differ diff --git a/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz - LICENSE.txt b/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz - LICENSE.txt index 8650d994b..d3cd9b0d6 100644 --- a/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz - LICENSE.txt +++ b/StabilityMatrix.Avalonia/Assets/macos-arm64/7zz - LICENSE.txt @@ -1,15 +1,16 @@ - 7-Zip - ~~~~~ + 7-Zip for Linux and macOS + ~~~~~~~~~~~~~~~~~~~~~~~~~ License for use and distribution ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - 7-Zip Copyright (C) 1999-2023 Igor Pavlov. + 7-Zip Copyright (C) 1999-2026 Igor Pavlov. The licenses for 7zz and 7zzs files are: - The "GNU LGPL" as main license for most of the code - The "GNU LGPL" with "unRAR license restriction" for some code - The "BSD 3-clause License" for some code + - The "BSD 2-clause License" for some code Redistributions in binary form must reproduce related license information from this file. @@ -18,8 +19,8 @@ organization. You don't need to register or pay for 7-Zip. - GNU LGPL information - -------------------- +GNU LGPL information +-------------------- This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public @@ -37,52 +38,107 @@ - BSD 3-clause License - -------------------- +BSD 3-clause License in 7-Zip code +---------------------------------- - The "BSD 3-clause License" is used for the code in 7z.dll that implements LZFSE data decompression. - That code was derived from the code in the "LZFSE compression library" developed by Apple Inc, - that also uses the "BSD 3-clause License": + The "BSD 3-clause License" is used for the following code in 7z.dll + 1) LZFSE data decompression. + That code was derived from the code in the "LZFSE compression library" developed by Apple Inc, + that also uses the "BSD 3-clause License". + 2) ZSTD data decompression. + that code was developed using original zstd decoder code as reference code. + The original zstd decoder code was developed by Facebook Inc, + that also uses the "BSD 3-clause License". - ---- - Copyright (c) 2015-2016, Apple Inc. All rights reserved. + Copyright (c) 2015-2016, Apple Inc. All rights reserved. + Copyright (c) Facebook, Inc. All rights reserved. + Copyright (c) 2023-2026 Igor Pavlov. - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +Text of the "BSD 3-clause License" +---------------------------------- - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer - in the documentation and/or other materials provided with the distribution. +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. - 3. Neither the name of the copyright holder(s) nor the names of any contributors may be used to endorse or promote products derived - from this software without specific prior written permission. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE - COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ---- +3. Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--- - unRAR license restriction - ------------------------- - The decompression engine for RAR archives was developed using source - code of unRAR program. - All copyrights to original unRAR code are owned by Alexander Roshal. - The license for original unRAR code has the following restriction: +BSD 2-clause License in 7-Zip code +---------------------------------- - The unRAR sources cannot be used to re-create the RAR compression algorithm, - which is proprietary. Distribution of modified unRAR sources in separate form - or as a part of other software is permitted, provided that it is clearly - stated in the documentation and source comments that the code may - not be used to develop a RAR (WinRAR) compatible archiver. + The "BSD 2-clause License" is used for the XXH64 code in 7-Zip. + XXH64 code in 7-Zip was derived from the original XXH64 code developed by Yann Collet. - -- - Igor Pavlov + Copyright (c) 2012-2021 Yann Collet. + Copyright (c) 2023-2026 Igor Pavlov. + +Text of the "BSD 2-clause License" +---------------------------------- + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--- + + + + +unRAR license restriction +------------------------- + +The decompression engine for RAR archives was developed using source +code of unRAR program. +All copyrights to original unRAR code are owned by Alexander Roshal. + +The license for original unRAR code has the following restriction: + + The unRAR sources cannot be used to re-create the RAR compression algorithm, + which is proprietary. Distribution of modified unRAR sources in separate form + or as a part of other software is permitted, provided that it is clearly + stated in the documentation and source comments that the code may + not be used to develop a RAR (WinRAR) compatible archiver. + +-- diff --git a/StabilityMatrix.Avalonia/Assets/sitecustomize.py b/StabilityMatrix.Avalonia/Assets/sitecustomize.py index d154c13a0..7f9278f8d 100644 --- a/StabilityMatrix.Avalonia/Assets/sitecustomize.py +++ b/StabilityMatrix.Avalonia/Assets/sitecustomize.py @@ -46,12 +46,10 @@ def audit(event: str, *args): # Reconfigure stdout to UTF-8 # noinspection PyUnresolvedReferences -sys.stdin.reconfigure(encoding="utf-8") -sys.stdout.reconfigure(encoding="utf-8") -sys.stderr.reconfigure(encoding="utf-8") - -# Install the audit hook -sys.addaudithook(audit) +def _reconfigure_streams(): + sys.stdin.reconfigure(encoding="utf-8") + sys.stdout.reconfigure(encoding="utf-8") + sys.stderr.reconfigure(encoding="utf-8") # Patch Rich terminal detection def _patch_rich_console(): @@ -81,9 +79,7 @@ def is_terminal(self) -> bool: except ImportError: pass except Exception as e: - print("[sitecustomize error]:", e) - -_patch_rich_console() + print("[sitecustomize error]:", e) # Patch tqdm to use stdout instead of stderr def _patch_tqdm(): @@ -97,4 +93,19 @@ def _patch_tqdm(): except Exception as e: print("[sitecustomize error]:", e) -_patch_tqdm() +# Run startup customizations. Each is isolated so that a failure in one (or an +# unusual host environment, e.g. an interpreter probe with no real stdio) can +# never raise out of sitecustomize and abort interpreter startup. +def _run_safely(func): + try: + func() + except Exception as e: + try: + print("[sitecustomize error]:", e) + except Exception: + pass + +_run_safely(_reconfigure_streams) +_run_safely(lambda: sys.addaudithook(audit)) +_run_safely(_patch_rich_console) +_run_safely(_patch_tqdm) diff --git a/StabilityMatrix.Avalonia/Assets/win-x64/7za - LICENSE.txt b/StabilityMatrix.Avalonia/Assets/win-x64/7za - LICENSE.txt index 80473a66f..dae57cb4f 100644 --- a/StabilityMatrix.Avalonia/Assets/win-x64/7za - LICENSE.txt +++ b/StabilityMatrix.Avalonia/Assets/win-x64/7za - LICENSE.txt @@ -1,43 +1,123 @@ -7-Zip Extra 18.01 ------------------ + 7-Zip Extra + ~~~~~~~~~~~ + License for use and distribution + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -7-Zip Extra is package of extra modules of 7-Zip. + Copyright (C) 1999-2026 Igor Pavlov. -7-Zip Copyright (C) 1999-2018 Igor Pavlov. + The licenses for files are: -7-Zip is free software. Read License.txt for more information about license. + - 7za.exe: + - The "GNU LGPL" as main license for most of the code + - The "BSD 3-clause License" for some code + - The "BSD 2-clause License" for some code + - All other files: the "GNU LGPL". -Source code of binaries can be found at: - http://www.7-zip.org/ + Redistributions in binary form must reproduce related license information from this file. + Note: + You can use 7-Zip Extra on any computer, including a computer in a commercial + organization. You don't need to register or pay for 7-Zip. -7-Zip Extra -~~~~~~~~~~~ -License for use and distribution -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + It is allowed to digitally sign DLL and EXE files included into this package + with arbitrary signatures of third parties. -Copyright (C) 1999-2018 Igor Pavlov. -7-Zip Extra files are under the GNU LGPL license. +GNU LGPL information +-------------------- + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. -Notes: - You can use 7-Zip Extra on any computer, including a computer in a commercial - organization. You don't need to register or pay for 7-Zip. + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + You can receive a copy of the GNU Lesser General Public License from + http://www.gnu.org/ -GNU LGPL information --------------------- - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. +BSD 3-clause License in 7-Zip code +---------------------------------- + + The "BSD 3-clause License" is used for the following code in 7za.exe + - ZSTD data decompression. + that code was developed using original zstd decoder code as reference code. + The original zstd decoder code was developed by Facebook Inc, + that also uses the "BSD 3-clause License". + + Copyright (c) Facebook, Inc. All rights reserved. + Copyright (c) 2023-2025 Igor Pavlov. + +Text of the "BSD 3-clause License" +---------------------------------- + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--- + + + + +BSD 2-clause License in 7-Zip code +---------------------------------- + + The "BSD 2-clause License" is used for the XXH64 code in 7za.exe. + + XXH64 code in 7-Zip was derived from the original XXH64 code developed by Yann Collet. + + Copyright (c) 2012-2021 Yann Collet. + Copyright (c) 2023-2025 Igor Pavlov. + +Text of the "BSD 2-clause License" +---------------------------------- + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - You can receive a copy of the GNU Lesser General Public License from - http://www.gnu.org/ +--- diff --git a/StabilityMatrix.Avalonia/Assets/win-x64/7za.exe b/StabilityMatrix.Avalonia/Assets/win-x64/7za.exe index a67de9158..25773795e 100644 Binary files a/StabilityMatrix.Avalonia/Assets/win-x64/7za.exe and b/StabilityMatrix.Avalonia/Assets/win-x64/7za.exe differ diff --git a/StabilityMatrix.Avalonia/Controls/BetterComboBox.cs b/StabilityMatrix.Avalonia/Controls/BetterComboBox.cs index 08126a7ab..a114dc50f 100644 --- a/StabilityMatrix.Avalonia/Controls/BetterComboBox.cs +++ b/StabilityMatrix.Avalonia/Controls/BetterComboBox.cs @@ -1,76 +1,142 @@ -ο»Ώusing System.Reactive.Linq; +using System.Reactive.Linq; using System.Reactive.Subjects; using Avalonia; +using Avalonia.Automation; using Avalonia.Controls; using Avalonia.Controls.Presenters; using Avalonia.Controls.Primitives; using Avalonia.Controls.Primitives.PopupPositioning; using Avalonia.Input; -using Avalonia.Media; using Avalonia.Threading; +using FuzzySharp; +using Microsoft.Extensions.DependencyInjection; using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper.Cache; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api.Comfy; +using StabilityMatrix.Core.Models.Settings; +using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.Controls; public class BetterComboBox : ComboBox { + private static readonly TimeSpan LegacySearchIdleResetDelay = TimeSpan.FromMilliseconds(1200); + + public static readonly StyledProperty SearchWatermarkProperty = AvaloniaProperty.Register< + BetterComboBox, + string + >(nameof(SearchWatermark), defaultValue: "Search..."); + public static readonly StyledProperty UseLegacySearchProperty = AvaloniaProperty.Register< + BetterComboBox, + bool + >(nameof(UseLegacySearch)); + public static readonly DirectProperty SearchTextProperty = + AvaloniaProperty.RegisterDirect(nameof(SearchText), o => o.SearchText); + + public string SearchWatermark + { + get => GetValue(SearchWatermarkProperty); + set => SetValue(SearchWatermarkProperty, value); + } + + public bool UseLegacySearch + { + get => GetValue(UseLegacySearchProperty); + set => SetValue(UseLegacySearchProperty, value); + } + + public string SearchText + { + get => searchText; + private set => SetAndRaise(SearchTextProperty, ref searchText, value); + } + private readonly Subject inputSubject = new(); private readonly IDisposable subscription; - private readonly Popup inputPopup; - private readonly TextBlock inputTextBlock; - private string currentInput = string.Empty; + private readonly LRUCache searchCache = new(50); + private readonly ISettingsManager? settingsManager; + private readonly Popup legacyInputPopup; + private readonly TextBlock legacyInputTextBlock; + private readonly DispatcherTimer legacySearchResetTimer = new() { Interval = LegacySearchIdleResetDelay }; + private TextBox? searchTextBox; + private string keyboardSearchText = string.Empty; + private string searchText = string.Empty; + private string lastAppliedFilter = string.Empty; + private bool isUpdatingSearchText; public BetterComboBox() { - // Create an observable that buffers input over a short period + DropDownOpened += OnDropDownOpened; + DropDownClosed += OnDropDownClosed; + ContainerPrepared += OnContainerPrepared; + ContainerIndexChanged += OnContainerIndexChanged; + var inputObservable = inputSubject - .Do(text => currentInput += text) - .Throttle(TimeSpan.FromMilliseconds(500)) - .Where(_ => !string.IsNullOrEmpty(currentInput)) - .Select(_ => currentInput); + .Select(text => text.Trim()) + .Throttle(TimeSpan.FromMilliseconds(200)) + .DistinctUntilChanged(); - // Subscribe to the observable to filter the ComboBox items subscription = inputObservable .ObserveOn(SynchronizationContext.Current) - .Subscribe(OnInputReceived, _ => ResetPopupText()); + .Subscribe(OnInputReceived, _ => ResetSearchText()); + legacySearchResetTimer.Tick += OnLegacySearchResetTimerTick; - // Initialize the popup - inputPopup = new Popup + legacyInputTextBlock = new TextBlock { FontSize = 13 }; + legacyInputTextBlock.Bind( + TextBlock.ForegroundProperty, + this.GetResourceObservable("ComboBoxForeground") + ); + var popupBorder = new Border { Padding = new Thickness(8, 4), Child = legacyInputTextBlock }; + popupBorder.Bind(Border.BackgroundProperty, this.GetResourceObservable("ComboBoxDropDownBackground")); + legacyInputPopup = new Popup { IsLightDismissEnabled = true, Placement = PlacementMode.AnchorAndGravity, PlacementAnchor = PopupAnchor.Bottom, PlacementGravity = PopupGravity.Top, + VerticalOffset = -6, + Child = popupBorder, }; - // Initialize the TextBlock with custom styling - inputTextBlock = new TextBlock + if (!Design.IsDesignMode) { - Foreground = Brushes.White, // White text color - Background = Brush.Parse("#333333"), // Dark gray background - Padding = new Thickness(8), // Add padding - FontSize = 14 // Optional: adjust font size - }; - - inputPopup.Child = inputTextBlock; + settingsManager = App.Services.GetService(); + if (settingsManager is not null) + { + ApplyGlobalLegacySearchOverride(settingsManager.Settings.UseLegacySearch); + settingsManager.SettingsPropertyChanged += OnSettingsPropertyChanged; + } + } } /// protected override void OnApplyTemplate(TemplateAppliedEventArgs e) { base.OnApplyTemplate(e); + legacyInputPopup.PlacementTarget = this; - // Set the Popup's anchor to the ComboBox itself - inputPopup.PlacementTarget = this; - - if (e.NameScope.Find("ContentPresenter") is { } contentPresenter) + if (e.NameScope.Find("ContentPresenter") is { } contentPresenter) { if (SelectionBoxItemTemplate is { } template) { contentPresenter.ContentTemplate = template; } } + + if (searchTextBox is not null) + { + searchTextBox.TextChanged -= SearchTextBoxOnTextChanged; + searchTextBox.KeyDown -= SearchTextBoxOnKeyDown; + } + + searchTextBox = e.NameScope.Find("PART_SearchTextBox"); + if (searchTextBox is not null) + { + AutomationProperties.SetName(searchTextBox, "Search models"); + searchTextBox.TextChanged += SearchTextBoxOnTextChanged; + searchTextBox.KeyDown += SearchTextBoxOnKeyDown; + } } protected override void OnTextInput(TextInputEventArgs e) @@ -78,72 +144,438 @@ protected override void OnTextInput(TextInputEventArgs e) if (e.Handled) return; + if (searchTextBox?.IsFocused == true) + { + base.OnTextInput(e); + return; + } + if (!string.IsNullOrWhiteSpace(e.Text)) { - // Push the input text to the subject - inputSubject.OnNext(e.Text); - UpdatePopupText(e.Text); + keyboardSearchText += e.Text; + inputSubject.OnNext(keyboardSearchText); + RestartLegacySearchResetTimer(); + UpdateLegacySearchPopupText(keyboardSearchText); + + if (IsDropDownOpen) + { + UpdateSearchTextBoxText(keyboardSearchText); + if (!UseLegacySearch) + { + Dispatcher.UIThread.Post(() => searchTextBox?.Focus(), DispatcherPriority.Input); + } + } + e.Handled = true; } base.OnTextInput(e); } - private void OnInputReceived(string input) + private void SearchTextBoxOnTextChanged(object? sender, TextChangedEventArgs e) + { + if (isUpdatingSearchText || sender is not TextBox textBox) + return; + + keyboardSearchText = textBox.Text ?? string.Empty; + SearchText = keyboardSearchText; + inputSubject.OnNext(keyboardSearchText); + RestartLegacySearchResetTimer(); + UpdateLegacySearchPopupText(keyboardSearchText); + } + + private void SearchTextBoxOnKeyDown(object? sender, KeyEventArgs e) + { + if (e.Key != Key.Escape) + return; + + StopLegacySearchResetTimer(); + IsDropDownOpen = false; + e.Handled = true; + } + + private void OnDropDownOpened(object? sender, EventArgs e) { - if (Items.OfType().ToList() is { Count: > 0 } enumItems) + StopLegacySearchResetTimer(); + ResetSearchText(); + ApplyFilter(string.Empty); + if (!UseLegacySearch) { - var foundEnum = enumItems.FirstOrDefault( - x => x.GetStringValue().StartsWith(input, StringComparison.OrdinalIgnoreCase) - ); + Dispatcher.UIThread.Post(() => searchTextBox?.Focus(), DispatcherPriority.Input); + } + } + + private void OnDropDownClosed(object? sender, EventArgs e) + { + StopLegacySearchResetTimer(); + ResetSearchText(); + ApplyFilter(string.Empty); + } + + private void UpdateSearchTextBoxText(string text) + { + SearchText = text; + UpdateLegacySearchPopupText(text); + + if (searchTextBox is null) + return; - if (foundEnum is not null) + isUpdatingSearchText = true; + searchTextBox.Text = text; + searchTextBox.CaretIndex = searchTextBox.Text?.Length ?? 0; + isUpdatingSearchText = false; + } + + private void ResetSearchText() + { + StopLegacySearchResetTimer(); + keyboardSearchText = string.Empty; + UpdateSearchTextBoxText(string.Empty); + } + + private void RestartLegacySearchResetTimer() + { + if (!UseLegacySearch || string.IsNullOrEmpty(keyboardSearchText)) + return; + + legacySearchResetTimer.Stop(); + legacySearchResetTimer.Start(); + } + + private void StopLegacySearchResetTimer() + { + legacySearchResetTimer.Stop(); + } + + private void UpdateLegacySearchPopupText(string text) + { + if (!UseLegacySearch || string.IsNullOrWhiteSpace(text)) + { + HideLegacySearchPopup(); + return; + } + + legacyInputTextBlock.Text = text; + + if (legacyInputPopup.PlacementTarget is null) + { + legacyInputPopup.PlacementTarget = this; + } + + if (!legacyInputPopup.IsOpen) + { + legacyInputPopup.IsOpen = true; + } + } + + private void HideLegacySearchPopup() + { + legacyInputTextBlock.Text = string.Empty; + legacyInputPopup.IsOpen = false; + } + + private void OnLegacySearchResetTimerTick(object? sender, EventArgs e) + { + legacySearchResetTimer.Stop(); + + if (!UseLegacySearch || string.IsNullOrWhiteSpace(keyboardSearchText)) + return; + + ResetSearchText(); + } + + private void OnInputReceived(string input) + { + if (IsDropDownOpen) + { + if (UseLegacySearch) { - Dispatcher.UIThread.Post(() => + var query = input.Trim(); + if (string.IsNullOrWhiteSpace(query)) + return; + + var legacyMatch = FindLegacyMatch(query); + if (legacyMatch is not null) { - SelectedItem = foundEnum; - }); + Dispatcher.UIThread.Post(() => + { + SelectedItem = legacyMatch; + ScrollIntoView(legacyMatch); + }); + } + } + else + { + Dispatcher.UIThread.Post(() => ApplyFilter(input)); + } + return; + } + + if (string.IsNullOrWhiteSpace(input)) + return; + + if (searchCache.Get(input, out var cachedResult) && cachedResult is not null) + { + Dispatcher.UIThread.Post(() => SelectedItem = cachedResult); + return; + } + + if (UseLegacySearch) + { + var legacyMatch = FindLegacyMatch(input); + if (legacyMatch is null) + return; + + searchCache.Add(input, legacyMatch); + Dispatcher.UIThread.Post(() => SelectedItem = legacyMatch); + return; + } + + object? found = null; + + var enumBestMatch = FindBestMatch(input, Items.OfType(), e => e.GetStringValue()); + if (enumBestMatch.Score > 50) + { + found = enumBestMatch.Item; + } + else + { + var modelBestMatch = FindBestMatch(input, Items.OfType(), m => GetItemSearchText(m)); + if (modelBestMatch.Score > 50) + { + found = modelBestMatch.Item; } } - else if (Items.OfType().ToList() is { } modelFiles) + + if (found is not null) { - var found = modelFiles.FirstOrDefault( - x => x.SearchText.StartsWith(input, StringComparison.OrdinalIgnoreCase) - ); + searchCache.Add(input, found); + Dispatcher.UIThread.Post(() => SelectedItem = found); + } + } + + private void ApplyFilter(string input) + { + var query = input.Trim(); + var filterChanged = !string.Equals(lastAppliedFilter, query, StringComparison.Ordinal); + lastAppliedFilter = query; - if (found is not null) + var hasQuery = !string.IsNullOrWhiteSpace(query); + object? firstMatch = null; + + foreach (var item in Items.Cast()) + { + var isMatch = !hasQuery || IsItemMatch(item, query); + if (isMatch && firstMatch is null) { - Dispatcher.UIThread.Post(() => + firstMatch = item; + } + + if (ContainerFromItem(item) is not Control container) + continue; + + container.IsVisible = isMatch; + } + + if (!IsDropDownOpen || firstMatch is null) + { + return; + } + + if (!filterChanged) + { + return; + } + + // Keep the first matching result pinned near the top when virtualizing. + Dispatcher.UIThread.Post(() => ScrollIntoView(firstMatch), DispatcherPriority.Background); + } + + private bool IsItemMatch(object item, string query) + { + var itemText = GetItemSearchText(item, UseLegacySearch); + + if (UseLegacySearch) + { + return itemText.Contains(query, StringComparison.OrdinalIgnoreCase); + } + + if (itemText.Contains(query, StringComparison.OrdinalIgnoreCase)) + return true; + + // Allow approximate matching for typos while filtering. + return Fuzz.PartialRatio(query, itemText) >= 70; + } + + private object? FindLegacyMatch(string query) + { + var trimmedQuery = query.Trim(); + if (string.IsNullOrWhiteSpace(trimmedQuery)) + return null; + + object? firstSearchTextMatch = null; + + foreach (var item in Items) + { + if (item is Enum enumItem) + { + if (enumItem.GetStringValue().Contains(trimmedQuery, StringComparison.OrdinalIgnoreCase)) + { + return enumItem; + } + } + else if (firstSearchTextMatch is null && item is ISearchText or ComfySampler or ComfyScheduler) + { + if (GetItemSearchText(item, true).Contains(trimmedQuery, StringComparison.OrdinalIgnoreCase)) { - SelectedItem = found; - }); + firstSearchTextMatch = item; + } } } - Dispatcher.UIThread.Post(ResetPopupText); + return firstSearchTextMatch; } - private void UpdatePopupText(string text) + private static string GetItemSearchText(object item, bool useLegacySearch = false) { - inputTextBlock.Text += text; // Accumulate text in the popup + return item switch + { + HybridModelFile hybridModel => useLegacySearch + ? hybridModel.SearchText + : hybridModel.DetailedSearchText, + Enum enumItem => enumItem.GetStringValue(), + ComfySampler sampler => $"{sampler.DisplayName} {sampler.Name}", + ComfyScheduler scheduler => $"{scheduler.DisplayName} {scheduler.Name}", + ISearchText searchable => searchable.SearchText, + _ => item.ToString() ?? string.Empty, + }; + } - if (!inputPopup.IsOpen) + private static (TItem? Item, int Score) FindBestMatch( + string input, + IEnumerable items, + Func getSearchText + ) + { + TItem? bestItem = default; + var bestScore = 0; + + foreach (var item in items) { - inputPopup.IsOpen = true; + var score = Fuzz.WeightedRatio(input, getSearchText(item)); + if (score <= bestScore) + continue; + + bestScore = score; + bestItem = item; + } + + return (bestItem, bestScore); + } + + private void OnContainerPrepared(object? sender, ContainerPreparedEventArgs e) + { + if (!IsDropDownOpen || UseLegacySearch) + return; + + var query = keyboardSearchText.Trim(); + if (string.IsNullOrWhiteSpace(query)) + { + e.Container.IsVisible = true; + return; + } + + if (e.Index >= 0 && e.Index < ItemsView.Count && ItemsView[e.Index] is { } item) + { + e.Container.IsVisible = IsItemMatch(item, query); + } + } + + private void OnContainerIndexChanged(object? sender, ContainerIndexChangedEventArgs e) + { + if (!IsDropDownOpen || UseLegacySearch) + return; + + var query = keyboardSearchText.Trim(); + if (string.IsNullOrWhiteSpace(query)) + { + e.Container.IsVisible = true; + return; + } + + if (e.NewIndex >= 0 && e.NewIndex < ItemsView.Count && ItemsView[e.NewIndex] is { } item) + { + e.Container.IsVisible = IsItemMatch(item, query); + } + } + + protected override void OnPropertyChanged(AvaloniaPropertyChangedEventArgs change) + { + base.OnPropertyChanged(change); + + if (change.Property == ItemsSourceProperty) + { + searchCache.Clear(); + } + + if (change.Property == UseLegacySearchProperty && !UseLegacySearch) + { + StopLegacySearchResetTimer(); + HideLegacySearchPopup(); + } + } + + private void ApplyGlobalLegacySearchOverride(bool globalOverride) + { + if (globalOverride) + { + SetValue(UseLegacySearchProperty, true); + } + else + { + ClearValue(UseLegacySearchProperty); } } - private void ResetPopupText() + private void OnSettingsPropertyChanged(object? sender, RelayPropertyChangedEventArgs e) { - currentInput = string.Empty; - inputTextBlock.Text = string.Empty; - inputPopup.IsOpen = false; + if (e.PropertyName != nameof(Settings.UseLegacySearch) || settingsManager is null) + return; + + Dispatcher.UIThread.Post(() => + { + ApplyGlobalLegacySearchOverride(settingsManager.Settings.UseLegacySearch); + if (!UseLegacySearch) + { + StopLegacySearchResetTimer(); + HideLegacySearchPopup(); + } + }); } - // Ensure proper disposal of resources protected override void OnDetachedFromVisualTree(VisualTreeAttachmentEventArgs e) { base.OnDetachedFromVisualTree(e); + + DropDownOpened -= OnDropDownOpened; + DropDownClosed -= OnDropDownClosed; + ContainerPrepared -= OnContainerPrepared; + ContainerIndexChanged -= OnContainerIndexChanged; + + if (searchTextBox is not null) + { + searchTextBox.TextChanged -= SearchTextBoxOnTextChanged; + searchTextBox.KeyDown -= SearchTextBoxOnKeyDown; + } + + if (settingsManager is not null) + { + settingsManager.SettingsPropertyChanged -= OnSettingsPropertyChanged; + } + + legacySearchResetTimer.Tick -= OnLegacySearchResetTimerTick; + StopLegacySearchResetTimer(); + HideLegacySearchPopup(); subscription.Dispose(); } } diff --git a/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs b/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs index 15014b36c..03513fcf6 100644 --- a/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs +++ b/StabilityMatrix.Avalonia/Controls/BetterContentDialog.cs @@ -99,6 +99,8 @@ static BetterContentDialog() #endregion private Border? backgroundPart; + private ContentDialogViewModelBase? boundDialogViewModel; + private ContentDialogProgressViewModelBase? boundProgressViewModel; protected override Type StyleKeyOverride { get; } = typeof(ContentDialog); @@ -154,8 +156,8 @@ public double MaxDialogWidth public double MinDialogHeight { - get => GetValue(MaxDialogHeightProperty); - set => SetValue(MaxDialogHeightProperty, value); + get => GetValue(MinDialogHeightProperty); + set => SetValue(MinDialogHeightProperty, value); } public static readonly StyledProperty MaxDialogHeightProperty = AvaloniaProperty.Register< @@ -205,6 +207,7 @@ public BetterContentDialog() } AddHandler(LoadedEvent, OnLoaded); + AddHandler(UnloadedEvent, OnUnloaded); } /// @@ -283,23 +286,47 @@ private void TrySetButtonCommands() private void TryBindButtonEvents() { + UnbindButtonEvents(); + if ((Content as Control)?.DataContext is ContentDialogViewModelBase viewModel) { viewModel.PrimaryButtonClick += OnDialogButtonClick; viewModel.SecondaryButtonClick += OnDialogButtonClick; viewModel.CloseButtonClick += OnDialogButtonClick; + boundDialogViewModel = viewModel; } else if (Content is ContentDialogViewModelBase viewModelDirect) { viewModelDirect.PrimaryButtonClick += OnDialogButtonClick; viewModelDirect.SecondaryButtonClick += OnDialogButtonClick; viewModelDirect.CloseButtonClick += OnDialogButtonClick; + boundDialogViewModel = viewModelDirect; } else if ((Content as Control)?.DataContext is ContentDialogProgressViewModelBase progressViewModel) { progressViewModel.PrimaryButtonClick += OnDialogButtonClick; progressViewModel.SecondaryButtonClick += OnDialogButtonClick; progressViewModel.CloseButtonClick += OnDialogButtonClick; + boundProgressViewModel = progressViewModel; + } + } + + private void UnbindButtonEvents() + { + if (boundDialogViewModel is not null) + { + boundDialogViewModel.PrimaryButtonClick -= OnDialogButtonClick; + boundDialogViewModel.SecondaryButtonClick -= OnDialogButtonClick; + boundDialogViewModel.CloseButtonClick -= OnDialogButtonClick; + boundDialogViewModel = null; + } + + if (boundProgressViewModel is not null) + { + boundProgressViewModel.PrimaryButtonClick -= OnDialogButtonClick; + boundProgressViewModel.SecondaryButtonClick -= OnDialogButtonClick; + boundProgressViewModel.CloseButtonClick -= OnDialogButtonClick; + boundProgressViewModel = null; } } @@ -406,4 +433,9 @@ private void OnLoaded(object? sender, RoutedEventArgs? e) Dispatcher.UIThread.InvokeAsync(viewModel.OnLoadedAsync).SafeFireAndForget(); }*/ } + + private void OnUnloaded(object? sender, RoutedEventArgs? e) + { + UnbindButtonEvents(); + } } diff --git a/StabilityMatrix.Avalonia/Controls/Inference/ExtraNetworkCard.axaml b/StabilityMatrix.Avalonia/Controls/Inference/ExtraNetworkCard.axaml index 0c8bd3afe..620a86272 100644 --- a/StabilityMatrix.Avalonia/Controls/Inference/ExtraNetworkCard.axaml +++ b/StabilityMatrix.Avalonia/Controls/Inference/ExtraNetworkCard.axaml @@ -4,6 +4,7 @@ xmlns:avalonia="https://github.com/projektanker/icons.avalonia" xmlns:controls="clr-namespace:StabilityMatrix.Avalonia.Controls" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" + xmlns:fluent="clr-namespace:FluentIcons.Avalonia.Fluent;assembly=FluentIcons.Avalonia.Fluent" xmlns:lang="clr-namespace:StabilityMatrix.Avalonia.Languages" xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData" xmlns:sg="clr-namespace:SpacedGridControl.Avalonia;assembly=SpacedGridControl.Avalonia" @@ -22,9 +23,9 @@ + RowDefinitions="Auto,Auto,Auto,Auto,Auto"> + + + + IsVisible="True"> @@ -44,359 +41,354 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + - + + - - + + + + + + + - + - - - - - - - - - - - - - - + ColumnDefinitions="70,*,Auto" + IsVisible="{Binding IsRefinerSelectionEnabled}"> + + + + - + + + + + + + + + + + + + + + + + + + + - + Margin="90,4,0,0" + FontSize="11" + Foreground="{DynamicResource TextFillColorSecondaryBrush}" + IsVisible="{Binding ShowWorkflowProfileStatus}" + Text="{Binding WorkflowProfileStatusText}" /> - - + + ColumnDefinitions="70,*" + IsVisible="{Binding ShowEncoderTypeSelection}"> + + + - + - - - + Header="{Binding AdvancedOptionsHeader}" + IsExpanded="{Binding IsAdvancedOptionsExpanded}" + IsVisible="{Binding HasActiveAdvancedOptions}"> + + + + + + - + + + + + + - - - - - - - - - + + + + + + + - + - - - - - - - + IsExpanded="{Binding IsTextEncodersExpanded}" + IsVisible="{Binding ShowEncoderSection}"> + + + + + + + + + + + + + + + + + + + + + + - - + + + + + + + - + - - - + ColumnDefinitions="90,*" + IsVisible="{Binding IsModelLoaderSelectionEnabled}"> + + + + + + + + + + + + + + + + + - - + + - + diff --git a/StabilityMatrix.Avalonia/Controls/Inference/RegionalPromptCard.axaml b/StabilityMatrix.Avalonia/Controls/Inference/RegionalPromptCard.axaml new file mode 100644 index 000000000..151d5cd2f --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Inference/RegionalPromptCard.axaml @@ -0,0 +1,105 @@ +ο»Ώ + + + diff --git a/StabilityMatrix.Avalonia/Controls/Inference/RegionalPromptCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/Inference/RegionalPromptCard.axaml.cs new file mode 100644 index 000000000..322f2aef7 --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Inference/RegionalPromptCard.axaml.cs @@ -0,0 +1,6 @@ +using Injectio.Attributes; + +namespace StabilityMatrix.Avalonia.Controls; + +[RegisterTransient] +public class RegionalPromptCard : TemplatedControlBase; diff --git a/StabilityMatrix.Avalonia/Controls/Inference/SamplerCard.axaml b/StabilityMatrix.Avalonia/Controls/Inference/SamplerCard.axaml index 5cabb0b4d..5d9e7dac6 100644 --- a/StabilityMatrix.Avalonia/Controls/Inference/SamplerCard.axaml +++ b/StabilityMatrix.Avalonia/Controls/Inference/SamplerCard.axaml @@ -3,6 +3,7 @@ xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:avalonia="https://github.com/projektanker/icons.avalonia" xmlns:controls="using:StabilityMatrix.Avalonia.Controls" + xmlns:converters="clr-namespace:StabilityMatrix.Avalonia.Converters" xmlns:generic="clr-namespace:System.Collections.Generic;assembly=System.Runtime" xmlns:generic1="clr-namespace:System.Collections.Generic;assembly=System.Collections" xmlns:input="clr-namespace:FluentAvalonia.UI.Input;assembly=FluentAvalonia" @@ -46,7 +47,8 @@ DisplayMemberBinding="{Binding DisplayName}" IsVisible="{Binding IsSamplerSelectionEnabled}" ItemsSource="{Binding ClientManager.Samplers}" - SelectedItem="{Binding SelectedSampler}" /> + SelectedItem="{Binding SelectedSampler}" + UseLegacySearch="True" /> + SelectedItem="{Binding SelectedScheduler}" + UseLegacySearch="True" /> - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Controls/Inference/SelectImageCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/Inference/SelectImageCard.axaml.cs index 71b8bd7bf..a0725f9cf 100644 --- a/StabilityMatrix.Avalonia/Controls/Inference/SelectImageCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/Inference/SelectImageCard.axaml.cs @@ -50,7 +50,10 @@ protected override void OnApplyTemplate(TemplateAppliedEventArgs e) } } - vm.CurrentBitmapSize = System.Drawing.Size.Empty; + if (vm.ImageSource is null) + { + vm.CurrentBitmapSize = System.Drawing.Size.Empty; + } }); } } diff --git a/StabilityMatrix.Avalonia/Controls/Inference/StackEditableCard.axaml.cs b/StabilityMatrix.Avalonia/Controls/Inference/StackEditableCard.axaml.cs index cc635d7a3..2153f2ae8 100644 --- a/StabilityMatrix.Avalonia/Controls/Inference/StackEditableCard.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/Inference/StackEditableCard.axaml.cs @@ -18,6 +18,7 @@ namespace StabilityMatrix.Avalonia.Controls; public class StackEditableCard : TemplatedControlBase { private ListBox? listBoxPart; + private Button? addButtonPart; // ReSharper disable once MemberCanBePrivate.Global public static readonly StyledProperty IsListBoxEditEnabledProperty = AvaloniaProperty.Register< @@ -51,10 +52,8 @@ protected override void OnApplyTemplate(TemplateAppliedEventArgs e) }; } - if (e.NameScope.Find + + + + + Grid.ColumnSpan="3"> diff --git a/StabilityMatrix.Avalonia/Controls/MarkdownViewer.axaml.cs b/StabilityMatrix.Avalonia/Controls/MarkdownViewer.axaml.cs index 773cb3eef..9f8797816 100644 --- a/StabilityMatrix.Avalonia/Controls/MarkdownViewer.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/MarkdownViewer.axaml.cs @@ -1,4 +1,4 @@ -ο»Ώusing System.IO; +using System.IO; using Avalonia; using Avalonia.Controls.Primitives; using Markdig; @@ -50,7 +50,7 @@ private void ParseText(string value) if (string.IsNullOrWhiteSpace(value)) return; - var pipeline = new MarkdownPipelineBuilder().UseAdvancedExtensions().Build(); + var pipeline = new MarkdownPipelineBuilder().UseAdvancedExtensions().UseEmojiAndSmiley().Build(); var html = $"""{Markdig.Markdown.ToHtml(value, pipeline)}"""; Html = html; diff --git a/StabilityMatrix.Avalonia/Controls/Models/PenPath.cs b/StabilityMatrix.Avalonia/Controls/Models/PenPath.cs index ba830c7f9..877ded502 100644 --- a/StabilityMatrix.Avalonia/Controls/Models/PenPath.cs +++ b/StabilityMatrix.Avalonia/Controls/Models/PenPath.cs @@ -1,19 +1,262 @@ -ο»Ώusing System.Collections.Generic; +ο»Ώusing System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Text.Json; using System.Text.Json.Serialization; using SkiaSharp; using StabilityMatrix.Core.Converters.Json; namespace StabilityMatrix.Avalonia.Controls.Models; +/// +/// Type of path - determines how the path is rendered. +/// +public enum PenPathType +{ + /// + /// Freehand brush strokes (default). + /// + Freehand, + + /// + /// Filled rectangle shape. + /// + Rectangle, + + /// + /// Filled ellipse/oval shape. + /// + Ellipse, + + /// + /// Bitmap image (used for flood fill results). + /// + Bitmap, +} + +/// +/// Custom JSON converter for PenPath that handles both legacy (JSON array) +/// and new (compressed base64 string) formats for backwards compatibility. +/// +public class PenPathJsonConverter : JsonConverter +{ + public override PenPath Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + return default; + + var penPath = new PenPath(); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + return penPath; + + if (reader.TokenType != JsonTokenType.PropertyName) + continue; + + var propertyName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propertyName) + { + case "points": + // Handle both legacy (array) and new (compressed string) formats + if (reader.TokenType == JsonTokenType.String) + { + // New compressed format + var compressed = reader.GetString(); + var decompressedPoints = PenPath.DecompressPointsPublic(compressed); + penPath = penPath with { Points = decompressedPoints ?? [] }; + } + else if (reader.TokenType == JsonTokenType.StartArray) + { + // Legacy format - manually deserialize array of PenPoint objects + // (Can't use JsonSerializer.Deserialize due to source-gen context limitations) + var points = new List(); + var penPointConverter = new PenPointJsonConverter(); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndArray) + break; + + if (reader.TokenType == JsonTokenType.StartObject) + { + var point = penPointConverter.Read(ref reader, typeof(PenPoint), options); + points.Add(point); + } + } + + penPath = penPath with { Points = points }; + } + break; + + case "fillcolor": + var colorConverter = new SKColorJsonConverter(); + var color = colorConverter.Read(ref reader, typeof(SKColor), options); + penPath = penPath with { FillColor = color }; + break; + + case "iserase": + penPath = penPath with { IsErase = reader.GetBoolean() }; + break; + + case "pathtype": + // Handle both string and number formats for backward compatibility + if (reader.TokenType == JsonTokenType.String) + { + if (Enum.TryParse(reader.GetString(), out var pathType)) + penPath = penPath with { PathType = pathType }; + } + else if (reader.TokenType == JsonTokenType.Number) + { + var pathTypeInt = reader.GetInt32(); + if (Enum.IsDefined(typeof(PenPathType), pathTypeInt)) + penPath = penPath with { PathType = (PenPathType)pathTypeInt }; + } + break; + + case "bounds": + var rectConverter = new SKRectJsonConverter(); + var bounds = rectConverter.Read(ref reader, typeof(SKRect), options); + penPath = penPath with { Bounds = bounds }; + break; + + case "isstrokeonly": + penPath = penPath with { IsStrokeOnly = reader.GetBoolean() }; + break; + + case "strokewidth": + penPath = penPath with { StrokeWidth = (float)reader.GetDouble() }; + break; + + case "radius": + penPath = penPath with { Radius = (float)reader.GetDouble() }; + break; + + case "feathering": + penPath = penPath with { Feathering = (float)reader.GetDouble() }; + break; + + case "bitmapdata": + var base64 = reader.GetString(); + if (!string.IsNullOrEmpty(base64)) + { + var bytes = Convert.FromBase64String(base64); + var bmp = SKBitmap.Decode(bytes); + if (bmp is not null) + { + penPath = penPath with { BitmapData = bmp }; + } + } + break; + + default: + reader.Skip(); + break; + } + } + + return penPath; + } + + public override void Write(Utf8JsonWriter writer, PenPath value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + // Write FillColor + var colorConverter = new SKColorJsonConverter(); + writer.WritePropertyName("fillColor"); + colorConverter.Write(writer, value.FillColor, options); + + writer.WriteBoolean("isErase", value.IsErase); + writer.WriteString("pathType", value.PathType.ToString()); + + // Write Bounds + var rectConverter = new SKRectJsonConverter(); + writer.WritePropertyName("bounds"); + rectConverter.Write(writer, value.Bounds, options); + + writer.WriteBoolean("isStrokeOnly", value.IsStrokeOnly); + writer.WriteNumber("strokeWidth", value.StrokeWidth); + writer.WriteNumber("radius", value.Radius); + writer.WriteNumber("feathering", value.Feathering); + + // Write points in compressed format + var compressedPoints = PenPath.CompressPointsPublic(value.Points); + if (compressedPoints != null) + { + writer.WriteString("points", compressedPoints); + } + + // Write bitmap data (for flood fill paths) as PNG base64 + if (value.BitmapData is { } bitmap) + { + using var image = SKImage.FromBitmap(bitmap); + using var data = image.Encode(SKEncodedImageFormat.Png, 100); + writer.WriteString("bitmapData", Convert.ToBase64String(data.AsSpan())); + } + + writer.WriteEndObject(); + } +} + +[JsonConverter(typeof(PenPathJsonConverter))] public readonly record struct PenPath() { - [JsonConverter(typeof(SKColorJsonConverter))] public SKColor FillColor { get; init; } public bool IsErase { get; init; } + /// + /// Type of path (Freehand, Rectangle, or Ellipse). + /// + public PenPathType PathType { get; init; } = PenPathType.Freehand; + + /// + /// Bounding rectangle for shape paths (Rectangle, Ellipse). + /// For Freehand paths, this is ignored. + /// + public SKRect Bounds { get; init; } + + /// + /// If true, draws shape outline only (stroke). If false, fills the shape. + /// Only applies to Rectangle and Ellipse path types. + /// + public bool IsStrokeOnly { get; init; } + + /// + /// Stroke width for stroke-only shapes. Only used when IsStrokeOnly is true. + /// + public float StrokeWidth { get; init; } = 5f; + + /// + /// Brush radius for this stroke. All points in the stroke share this radius. + /// + public float Radius { get; init; } + + /// + /// Feathering amount for soft brush edges. 0 = hard edge, 1 = fully soft/blurred. + /// The blur radius is calculated as: effectiveRadius * feathering. + /// + public float Feathering { get; init; } + + /// + /// Points for rendering. Serialization is handled by the custom JsonConverter. + /// + [JsonIgnore] public List Points { get; init; } = []; + /// + /// Bitmap data for flood fill paths. + /// + [JsonIgnore] + public SKBitmap? BitmapData { get; init; } + public SKPath ToSKPath() { var skPath = new SKPath(); @@ -34,4 +277,126 @@ public SKPath ToSKPath() return skPath; } + + /// + /// Gets the effective radius for rendering. Returns Radius if set, otherwise falls back to first point's radius for backward compatibility. + /// + public float GetEffectiveRadius() + { + if (Radius > 0) + return Radius; + + // Backward compatibility: check first point + if (Points.Count > 0 && Points[0].Radius > 0) + return (float)Points[0].Radius; + + return 1f; // Default fallback + } + + /// + /// Compresses points to a base64-encoded gzip string. Public for use by JsonConverter. + /// + public static string? CompressPointsPublic(List points) + { + if (points.Count == 0) + return null; + + // Calculate buffer size: 4 bytes count + 12 bytes per point (3 floats: x, y, pressure) + var bufferSize = 4 + (points.Count * 12); + var buffer = ArrayPool.Shared.Rent(bufferSize); + + try + { + var offset = 0; + + // Write point count + BitConverter.TryWriteBytes(buffer.AsSpan(offset), points.Count); + offset += 4; + + // Write each point as 3 floats + foreach (var point in points) + { + BitConverter.TryWriteBytes(buffer.AsSpan(offset), (float)point.X); + offset += 4; + BitConverter.TryWriteBytes(buffer.AsSpan(offset), (float)point.Y); + offset += 4; + BitConverter.TryWriteBytes(buffer.AsSpan(offset), (float)(point.Pressure ?? 1.0)); + offset += 4; + } + + // Compress with gzip + using var outputStream = new MemoryStream(); + using (var gzipStream = new GZipStream(outputStream, CompressionLevel.Optimal, leaveOpen: true)) + { + gzipStream.Write(buffer, 0, offset); + } + + return Convert.ToBase64String(outputStream.ToArray()); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + /// + /// Decompresses points from a base64-encoded gzip string. Public for use by JsonConverter. + /// + public static List? DecompressPointsPublic(string? compressed) + { + if (string.IsNullOrEmpty(compressed)) + return null; + + try + { + var compressedBytes = Convert.FromBase64String(compressed); + + using var inputStream = new MemoryStream(compressedBytes); + using var gzipStream = new GZipStream(inputStream, CompressionMode.Decompress); + using var outputStream = new MemoryStream(); + + gzipStream.CopyTo(outputStream); + var buffer = outputStream.ToArray(); + + if (buffer.Length < 4) + return null; + + var offset = 0; + + // Read point count + var count = BitConverter.ToInt32(buffer, offset); + offset += 4; + + // Validate we have enough data + if (buffer.Length < 4 + (count * 12)) + return null; + + var points = new List(count); + + for (var i = 0; i < count; i++) + { + var x = BitConverter.ToSingle(buffer, offset); + offset += 4; + var y = BitConverter.ToSingle(buffer, offset); + offset += 4; + var pressure = BitConverter.ToSingle(buffer, offset); + offset += 4; + + points.Add( + new PenPoint(x, y) + { + Pressure = pressure >= 0 && pressure <= 1 ? pressure : null, + IsPen = true, // Mark as pen point so it renders correctly + } + ); + } + + return points; + } + catch + { + // If decompression fails, return null (caller will handle as legacy format) + return null; + } + } } diff --git a/StabilityMatrix.Avalonia/Controls/Models/PenPoint.cs b/StabilityMatrix.Avalonia/Controls/Models/PenPoint.cs index b3f004492..de6ad27bc 100644 --- a/StabilityMatrix.Avalonia/Controls/Models/PenPoint.cs +++ b/StabilityMatrix.Avalonia/Controls/Models/PenPoint.cs @@ -1,8 +1,118 @@ ο»Ώusing System; +using System.Text.Json; +using System.Text.Json.Serialization; using SkiaSharp; namespace StabilityMatrix.Avalonia.Controls.Models; +/// +/// Custom JSON converter for PenPoint to handle serialization of ulong coordinates +/// and legacy double-based formats. +/// +public class PenPointJsonConverter : JsonConverter +{ + public override PenPoint Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options + ) + { + if (reader.TokenType != JsonTokenType.StartObject) + return default; + + ulong x = 0; + ulong y = 0; + double? pressure = null; + double radius = 1; // Default radius, legacy format stored per-point + bool isPen = true; // Default to true for rendering + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + break; + + if (reader.TokenType == JsonTokenType.PropertyName) + { + var propertyName = reader.GetString(); + reader.Read(); + + switch (propertyName?.ToLowerInvariant()) + { + case "x": + // Handle both double and ulong formats + if (reader.TokenType == JsonTokenType.Number) + { + if (reader.TryGetUInt64(out var ulongX)) + x = ulongX; + else if (reader.TryGetDouble(out var doubleX)) + x = Convert.ToUInt64(doubleX); + } + break; + + case "y": + // Handle both double and ulong formats + if (reader.TokenType == JsonTokenType.Number) + { + if (reader.TryGetUInt64(out var ulongY)) + y = ulongY; + else if (reader.TryGetDouble(out var doubleY)) + y = Convert.ToUInt64(doubleY); + } + break; + + case "pressure": + if (reader.TokenType == JsonTokenType.Number) + { + pressure = reader.GetDouble(); + } + break; + + case "ispen": + // Legacy format had IsPen serialized - read it but we'll set true anyway + if (reader.TokenType == JsonTokenType.True || reader.TokenType == JsonTokenType.False) + { + isPen = reader.GetBoolean(); + } + break; + + case "radius": + // Legacy format had Radius on each point - read it for backward compatibility + // GetEffectiveRadius() on PenPath will check Points[0].Radius as fallback + if (reader.TokenType == JsonTokenType.Number) + { + radius = reader.GetDouble(); + } + break; + + default: + reader.Skip(); + break; + } + } + } + + return new PenPoint(x, y) + { + Pressure = pressure, + IsPen = isPen, + Radius = radius, + }; + } + + public override void Write(Utf8JsonWriter writer, PenPoint value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + writer.WriteNumber("x", value.X); + writer.WriteNumber("y", value.Y); + if (value.Pressure.HasValue) + { + writer.WriteNumber("pressure", value.Pressure.Value); + } + writer.WriteEndObject(); + } +} + +[JsonConverter(typeof(PenPointJsonConverter))] public readonly record struct PenPoint(ulong X, ulong Y) { public PenPoint(double x, double y) @@ -14,6 +124,10 @@ public PenPoint(SKPoint skPoint) /// /// Radius of the point. /// + /// + /// Legacy property for backward compatibility. New paths store Radius at the PenPath level. + /// + [JsonIgnore] public double Radius { get; init; } = 1; /// @@ -24,6 +138,10 @@ public PenPoint(SKPoint skPoint) /// /// True if the point was created by a pen, false if it was created by a mouse. /// + /// + /// Runtime-only property for pressure-sensitive rendering. Not persisted. + /// + [JsonIgnore] public bool IsPen { get; init; } public SKPoint ToSKPoint() => new(X, Y); diff --git a/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml b/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml index cfd62d17a..59b32a321 100644 --- a/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml +++ b/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml @@ -2,19 +2,18 @@ xmlns="https://github.com/avaloniaui" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:controls="using:StabilityMatrix.Avalonia.Controls" + xmlns:converters="clr-namespace:StabilityMatrix.Avalonia.Converters" xmlns:fluentIcons="clr-namespace:FluentIcons.Avalonia.Fluent;assembly=FluentIcons.Avalonia.Fluent" xmlns:input="clr-namespace:FluentAvalonia.UI.Input;assembly=FluentAvalonia" xmlns:mocks="clr-namespace:StabilityMatrix.Avalonia.DesignData" + xmlns:models="clr-namespace:StabilityMatrix.Avalonia.Models" xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia" xmlns:vmControls="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Controls" - xmlns:faIcons="https://github.com/projektanker/icons.avalonia" - xmlns:converters="clr-namespace:StabilityMatrix.Avalonia.Converters" - xmlns:models="clr-namespace:StabilityMatrix.Avalonia.Models" x:DataType="vmControls:PaintCanvasViewModel"> - + @@ -23,11 +22,11 @@ - + - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - diff --git a/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml.cs b/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml.cs index 773213500..2a960ed0f 100644 --- a/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml.cs +++ b/StabilityMatrix.Avalonia/Controls/Painting/PaintCanvas.axaml.cs @@ -35,8 +35,6 @@ private ImmutableList Paths private IDisposable? viewModelSubscription; - private bool isPenDown; - private PaintCanvasViewModel? ViewModel { get; set; } private SkiaCustomCanvas? MainCanvas { get; set; } @@ -134,9 +132,9 @@ protected override void OnPropertyChanged(AvaloniaPropertyChangedEventArgs chang { var newIsEnabled = change.GetNewValue(); - if (!newIsEnabled) + if (!newIsEnabled && ViewModel is { } vm) { - isPenDown = false; + vm.IsPenDown = false; } // On any enabled change, flush temporary paths @@ -164,9 +162,27 @@ private void HandlePointerEvent(PointerEventArgs e) return; } + if (ViewModel is not { } vm) + { + return; + } + + // Handle Move tool separately - it doesn't require drawing to be enabled + if (vm.IsMoveTool) + { + HandleMoveToolEvent(e, vm); + return; + } + + if (!vm.IsDrawingEnabled) + { + return; + } + if (e.RoutedEvent == PointerReleasedEvent && e.Pointer.Type == PointerType.Touch) { TemporaryPaths.TryRemove(e.Pointer.Id, out _); + vm.CancelShapeDrawing(); return; } @@ -176,11 +192,6 @@ private void HandlePointerEvent(PointerEventArgs e) // https://github.com/AvaloniaUI/Avalonia/issues/12289#issuecomment-1695620412 e.PreventGestureRecognition(); - if (DataContext is not PaintCanvasViewModel viewModel) - { - return; - } - var currentPoint = e.GetCurrentPoint(this); if (e.RoutedEvent == PointerPressedEvent) @@ -191,35 +202,115 @@ private void HandlePointerEvent(PointerEventArgs e) return; } - isPenDown = true; + vm.IsPenDown = true; - HandlePointerMoved(e); + if (vm.SelectedTool == PaintCanvasTool.PaintBucket) + { + // Paint bucket: perform flood fill on click + var position = e.GetPosition(MainCanvas); + var fillColor = vm.PaintBrushSKColor.WithAlpha((byte)(vm.PaintBrushAlpha * 255)); + vm.FloodFillAt(new SKPoint((float)position.X, (float)position.Y), fillColor); + vm.IsPenDown = false; + } + else if (vm.IsShapeTool) + { + var position = e.GetPosition(MainCanvas); + vm.StartShapeDrawing(new SKPoint((float)position.X, (float)position.Y), e.Pointer.Id); + } + else + { + HandlePointerMoved(e); + } } else if (e.RoutedEvent == PointerReleasedEvent) { - if (isPenDown) + if (vm.IsPenDown) { - HandlePointerMoved(e); + if (vm.IsShapeTool && vm.ShapeStartPoint.HasValue) + { + var endPoint = e.GetPosition(MainCanvas); + vm.FinalizeShape(new SKPoint((float)endPoint.X, (float)endPoint.Y)); + } + else + { + HandlePointerMoved(e); + } - isPenDown = false; + vm.IsPenDown = false; } - if (TemporaryPaths.TryGetValue(e.Pointer.Id, out var path)) + if (!vm.IsShapeTool && TemporaryPaths.TryGetValue(e.Pointer.Id, out var path)) { Paths = Paths.Add(path); + vm.ClearRedoStack(); // New path added, clear redo history } - TemporaryPaths.TryRemove(e.Pointer.Id, out _); + if (!vm.IsShapeTool) + { + TemporaryPaths.TryRemove(e.Pointer.Id, out _); + } } else { // Moved event - if (!isPenDown || currentPoint.Properties.Pressure == 0) + if (!vm.IsPenDown) { return; } - HandlePointerMoved(e); + if (vm.IsShapeTool && vm.ShapeStartPoint.HasValue) + { + var endPoint = e.GetPosition(MainCanvas); + vm.UpdateShapePreview(new SKPoint((float)endPoint.X, (float)endPoint.Y)); + } + else if (currentPoint.Properties.Pressure != 0) + { + HandlePointerMoved(e); + } + } + + Dispatcher.UIThread.Post(() => MainCanvas?.InvalidateVisual(), DispatcherPriority.Render); + } + + private void HandleMoveToolEvent(PointerEventArgs e, PaintCanvasViewModel vm) + { + e.Handled = true; + e.PreventGestureRecognition(); + + var currentPoint = e.GetCurrentPoint(this); + + if (e.RoutedEvent == PointerPressedEvent) + { + // Ignore if mouse and not left button + if (e.Pointer.Type == PointerType.Mouse && !currentPoint.Properties.IsLeftButtonPressed) + { + return; + } + + vm.IsPenDown = true; + var position = e.GetPosition(MainCanvas); + // Get current offsets from the callback, or use (0, 0) if not set + var currentOffset = vm.GetCurrentMoveOffset?.Invoke() ?? (0, 0); + vm.StartMove(new SKPoint((float)position.X, (float)position.Y), currentOffset.X, currentOffset.Y); + } + else if (e.RoutedEvent == PointerReleasedEvent) + { + if (vm.IsPenDown) + { + vm.EndMove(); + vm.IsPenDown = false; + } + } + else + { + // Moved event + if (!vm.IsPenDown || !vm.MoveStartPoint.HasValue) + { + return; + } + + var position = e.GetPosition(MainCanvas); + vm.UpdateMove(new SKPoint((float)position.X, (float)position.Y)); } Dispatcher.UIThread.Post(() => MainCanvas?.InvalidateVisual(), DispatcherPriority.Render); @@ -235,8 +326,6 @@ private void HandlePointerMoved(PointerEventArgs e) // Use intermediate points to include past events we missed var points = e.GetIntermediatePoints(MainCanvas); - Debug.WriteLine($"Points: {string.Join(",", points.Select(p => p.Position.ToString()))}"); - if (points.Count == 0) { return; @@ -250,7 +339,9 @@ private void HandlePointerMoved(PointerEventArgs e) penPath = new PenPath { FillColor = viewModel.PaintBrushSKColor.WithAlpha((byte)(viewModel.PaintBrushAlpha * 255)), - IsErase = viewModel.SelectedTool == PaintCanvasTool.Eraser + IsErase = viewModel.SelectedTool == PaintCanvasTool.Eraser, + Radius = (float)viewModel.PaintBrushSize, + Feathering = (float)viewModel.PaintBrushFeathering, }; TemporaryPaths[e.Pointer.Id] = penPath; } @@ -275,7 +366,7 @@ private void HandlePointerMoved(PointerEventArgs e) { Pressure = point.Pointer.Type == PointerType.Mouse ? null : point.Properties.Pressure, Radius = viewModel.PaintBrushSize, - IsPen = point.Pointer.Type == PointerType.Pen + IsPen = point.Pointer.Type == PointerType.Pen, }; penPath.Points.Add(penPoint); @@ -311,7 +402,113 @@ protected override void OnKeyDown(KeyEventArgs e) if (e.Key == Key.Escape) { e.Handled = true; + return; + } + + // Keyboard shortcuts for paint canvas + if (ViewModel is not { } vm) + return; + + // Check for modifier keys + var isCtrl = e.KeyModifiers.HasFlag(KeyModifiers.Control); + + // Ctrl+Z: Undo + if (isCtrl && e.Key == Key.Z && !e.KeyModifiers.HasFlag(KeyModifiers.Shift)) + { + if (vm.UndoCommand.CanExecute(null)) + { + vm.UndoCommand.Execute(null); + RefreshCanvas(); + } + e.Handled = true; + return; + } + + // Ctrl+Y or Ctrl+Shift+Z: Redo + if ( + (isCtrl && e.Key == Key.Y) + || (isCtrl && e.KeyModifiers.HasFlag(KeyModifiers.Shift) && e.Key == Key.Z) + ) + { + if (vm.RedoCommand.CanExecute(null)) + { + vm.RedoCommand.Execute(null); + RefreshCanvas(); + } + e.Handled = true; + return; } + + // Arrow key nudging when Move tool is selected + if (vm.IsMoveTool) + { + var nudgeAmount = e.KeyModifiers.HasFlag(KeyModifiers.Shift) ? 10.0 : 1.0; + double deltaX = 0, + deltaY = 0; + + switch (e.Key) + { + case Key.Left: + deltaX = -nudgeAmount; + break; + case Key.Right: + deltaX = nudgeAmount; + break; + case Key.Up: + deltaY = -nudgeAmount; + break; + case Key.Down: + deltaY = nudgeAmount; + break; + } + + if (deltaX != 0 || deltaY != 0) + { + // Get current offset, apply delta, and invoke callback + var currentOffset = vm.GetCurrentMoveOffset?.Invoke() ?? (0, 0); + vm.OnMoveToolDrag?.Invoke(currentOffset.X + deltaX, currentOffset.Y + deltaY); + RefreshCanvas(); + e.Handled = true; + return; + } + } + + // Skip tool shortcuts if modifiers are held (to not interfere with other shortcuts) + // But allow Shift for arrow key nudging (handled above) + if (e.KeyModifiers != KeyModifiers.None && e.KeyModifiers != KeyModifiers.Shift) + return; + + switch (e.Key) + { + case Key.B: + vm.SelectBrushToolCommand.Execute(null); + break; + case Key.E: + vm.SelectEraserToolCommand.Execute(null); + break; + case Key.R: + vm.SelectRectangleToolCommand.Execute(null); + break; + case Key.O: + vm.SelectEllipseToolCommand.Execute(null); + break; + case Key.OemOpenBrackets: + vm.DecreaseBrushSizeCommand.Execute(null); + break; + case Key.OemCloseBrackets: + vm.IncreaseBrushSizeCommand.Execute(null); + break; + case Key.G: + vm.SelectPaintBucketToolCommand.Execute(null); + break; + case Key.V: + vm.SelectMoveToolCommand.Execute(null); + break; + default: + return; + } + UpdateCanvasCursor(); + e.Handled = true; } /// @@ -340,6 +537,7 @@ private void UpdateMainCanvasBounds() private int lastCanvasCursorRadius; private Cursor? lastCanvasCursor; + private PaintCanvasTool? lastCanvasCursorTool; private void UpdateCanvasCursor() { @@ -348,6 +546,39 @@ private void UpdateCanvasCursor() return; } + var selectedTool = ViewModel?.SelectedTool ?? PaintCanvasTool.PaintBrush; + + // Use crosshair for shape tools and paint bucket + if ( + selectedTool + is PaintCanvasTool.Rectangle + or PaintCanvasTool.Ellipse + or PaintCanvasTool.PaintBucket + ) + { + if (lastCanvasCursorTool != selectedTool) + { + lastCanvasCursor?.Dispose(); + lastCanvasCursor = new Cursor(StandardCursorType.Cross); + lastCanvasCursorTool = selectedTool; + } + canvas.Cursor = lastCanvasCursor; + return; + } + + // Use SizeAll (move) cursor for Move tool + if (selectedTool == PaintCanvasTool.Move) + { + if (lastCanvasCursorTool != selectedTool) + { + lastCanvasCursor?.Dispose(); + lastCanvasCursor = new Cursor(StandardCursorType.SizeAll); + lastCanvasCursorTool = selectedTool; + } + canvas.Cursor = lastCanvasCursor; + return; + } + var currentZoom = ViewModel?.CurrentZoom ?? 1; // Get brush size @@ -355,13 +586,14 @@ private void UpdateCanvasCursor() var brushRadius = (int)Math.Ceiling(currentBrushSize * 2 * currentZoom); // Only update cursor if brush size has changed - if (brushRadius == lastCanvasCursorRadius) + if (brushRadius == lastCanvasCursorRadius && lastCanvasCursorTool == selectedTool) { canvas.Cursor = lastCanvasCursor; return; } lastCanvasCursorRadius = brushRadius; + lastCanvasCursorTool = selectedTool; var brushDiameter = brushRadius * 2; @@ -386,7 +618,7 @@ private void UpdateCanvasCursor() StrokeCap = SKStrokeCap.Round, StrokeJoin = SKStrokeJoin.Round, IsDither = true, - IsAntialias = true + IsAntialias = true, } ); cursorCanvas.Flush(); @@ -415,26 +647,6 @@ private void MainCanvas_OnPointerExited(object? sender, PointerEventArgs e) } } - private Point GetRelativePosition(Point pt, Visual? relativeTo) - { - if (VisualRoot is not Visual visualRoot) - return default; - if (relativeTo == null) - return pt; - - return pt * visualRoot.TransformToVisual(relativeTo) ?? default; - } - - public AsyncRelayCommand ClearCanvasCommand => new(ClearCanvasAsync); - - public async Task ClearCanvasAsync() - { - Paths = ImmutableList.Empty; - TemporaryPaths.Clear(); - - await Dispatcher.UIThread.InvokeAsync(() => MainCanvas?.InvalidateVisual()); - } - private void OnRenderSkia(SKSurface surface) { ViewModel?.RenderToSurface(surface, renderBackgroundFill: true, renderBackgroundImage: true); diff --git a/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml index bfbc310e2..ea08be458 100644 --- a/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml +++ b/StabilityMatrix.Avalonia/Controls/SelectableImageCard/SelectableImageButton.axaml @@ -1,30 +1,29 @@ -ο»Ώ +ο»Ώ - - + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Styles/ScrollBarStyles.axaml b/StabilityMatrix.Avalonia/Styles/ScrollBarStyles.axaml new file mode 100644 index 000000000..b1c4629db --- /dev/null +++ b/StabilityMatrix.Avalonia/Styles/ScrollBarStyles.axaml @@ -0,0 +1,65 @@ +ο»Ώ + + + + False + 14 + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.Downloads.cs b/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.Downloads.cs new file mode 100644 index 000000000..39f6c96ce --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.Downloads.cs @@ -0,0 +1,283 @@ +using System.Threading; +using AsyncAwaitBestPractices; +using Avalonia.Controls.Notifications; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Services.ImageGeneration; + +namespace StabilityMatrix.Avalonia.ViewModels; + +public partial class BananaVisionPageViewModel +{ + /// + /// Whether there are missing models that can be downloaded + /// + [ObservableProperty] + public partial bool HasMissingModels { get; set; } + + /// + /// Whether a model-download batch is currently in progress. + /// While true, the status banner shows download progress instead of the missing-models warning. + /// + [ObservableProperty] + public partial bool IsDownloadingModels { get; set; } + + /// + /// Human-readable progress text for the in-flight download batch (e.g. "Downloading models (2/4)..."). + /// + [ObservableProperty] + public partial string? DownloadProgressText { get; set; } + + partial void OnIsDownloadingModelsChanged(bool value) + { + UpdateProviderStatus(); + } + + partial void OnDownloadProgressTextChanged(string? value) + { + if (IsDownloadingModels) + { + UpdateProviderStatus(); + } + } + + /// + /// Check for missing models and auto-show the download dialog if needed + /// + private async Task CheckAndShowMissingModelsDialogAsync() + { + // Don't show if we've already shown it this session + if (hasShownMissingModelsDialog) + return; + + // Wait a moment for connection status to settle + await Task.Delay(500); + + // Only show if connected and models are missing + if (!ClientManager.IsConnected || !HasMissingModels) + return; + + hasShownMissingModelsDialog = true; + await ShowMissingModelsDialogAsync(); + } + + /// + /// Show the missing models download dialog + /// + [RelayCommand] + private async Task ShowMissingModelsDialogAsync() + { + if (!ClientManager.IsConnected) + { + notificationService.Show( + "Not Connected", + "Please connect to ComfyUI first to check for missing models.", + NotificationType.Warning + ); + return; + } + + // Get the model manager for the current provider + var modelManager = LocalProviderModelManagerRegistry.GetManager(SelectedProviderId); + if (modelManager == null) + { + logger.LogWarning("No model manager found for provider {ProviderId}", SelectedProviderId); + return; + } + + var missingModels = modelManager.GetMissingModels(ClientManager).ToList(); + + if (missingModels.Count == 0) + { + notificationService.Show( + "All Models Present", + "All required models are already installed!", + NotificationType.Success + ); + return; + } + + logger.LogInformation( + "Showing missing models dialog for {Provider} with {Count} models", + modelManager.ProviderDisplayName, + missingModels.Count + ); + + // Create and configure the dialog using manager's properties + var dialogVm = vmFactory.Get(); + dialogVm.DialogTitle = $"{modelManager.ProviderDisplayName} Setup"; + dialogVm.Description = modelManager.DownloadDialogDescription; + dialogVm.SetModels(missingModels); + + var dialog = dialogVm.GetDialog(); + var result = await dialog.ShowAsync(); + + // If user clicked Download, start the downloads + if (result == ContentDialogResult.Primary && dialogVm.SelectedCount > 0) + { + // Start downloads (runs in background via TrackedDownloadService) + var downloads = await dialogVm.StartDownloadsAsync(); + + if (downloads.Count > 0) + { + // Switch the status banner over to a download-progress view so it doesn't + // keep showing "⚠️ Missing: X, Y, Z" with a Download button while the + // download is already running. + DownloadProgressText = $"⬇️ Downloading models (0/{downloads.Count})..."; + IsDownloadingModels = true; + + notificationService.Show( + "Downloads Started", + $"Downloading {downloads.Count} model(s). Check the progress panel for status.", + NotificationType.Information + ); + + // Track completion of all downloads + TrackDownloadCompletionAsync(downloads, modelManager.ProviderDisplayName) + .SafeFireAndForget(ex => + { + logger.LogError(ex, "Failed to track download completion"); + }); + } + } + } + + /// + /// Track when all downloads complete and show notification + /// + private async Task TrackDownloadCompletionAsync( + List downloads, + string providerDisplayName + ) + { + var totalCount = downloads.Count; + var completedCount = 0; + + void BumpProgress(ProgressState state) + { + // Each terminal-state event bumps the completed count; UI update is marshaled + // because ProgressStateChanged may fire from a background thread. + var newCompleted = Interlocked.Increment(ref completedCount); + Dispatcher.UIThread.Post(() => + { + if (IsDownloadingModels) + { + DownloadProgressText = $"⬇️ Downloading models ({newCompleted}/{totalCount})..."; + } + }); + } + + var completionTasks = downloads + .Select(d => + { + var tcs = new TaskCompletionSource(); + var counted = 0; // Guard against double-counting if both handler + already-completed fire + + void OnTerminal(ProgressState state) + { + if (Interlocked.Exchange(ref counted, 1) == 0) + { + BumpProgress(state); + } + tcs.TrySetResult(state == ProgressState.Success); + } + + d.ProgressStateChanged += (s, state) => + { + if (state is ProgressState.Success or ProgressState.Failed or ProgressState.Cancelled) + { + OnTerminal(state); + } + }; + + // Check if already completed + if ( + d.ProgressState + is ProgressState.Success + or ProgressState.Failed + or ProgressState.Cancelled + ) + { + OnTerminal(d.ProgressState); + } + + return tcs.Task; + }) + .ToList(); + + // Wait for all downloads to complete + var results = await Task.WhenAll(completionTasks); + var successCount = results.Count(r => r); + var failCount = results.Count(r => !r); + + logger.LogInformation( + "Model downloads completed: {Success} succeeded, {Failed} failed", + successCount, + failCount + ); + + // Refresh model index + await modelIndexService.RefreshIndex(); + + // Reconnect to ComfyUI to refresh model lists + if (ClientManager.IsConnected) + { + try + { + await ClientManager.ConnectAsync(); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to reconnect after model download"); + } + } + + // Update status on UI thread + await Dispatcher.UIThread.InvokeAsync(() => + { + // Clear the download-in-progress flag before recomputing status so the banner + // returns to its normal state ("βœ… ready" or "⚠️ Missing: ...") immediately. + IsDownloadingModels = false; + DownloadProgressText = null; + UpdateProviderStatus(); + LoadAvailableFluxModels(); + LoadAvailableQwenModels(); + LoadAvailableKleinModels(); + }); + + // Show completion notification + if (failCount == 0 && successCount > 0) + { + notificationService.Show( + "Models Ready! πŸŽ‰", + $"All required models have been downloaded. {providerDisplayName} is ready to use!", + NotificationType.Success, + TimeSpan.FromSeconds(8) + ); + } + else if (successCount > 0) + { + notificationService.Show( + "Downloads Partially Complete", + $"{successCount} model(s) downloaded, {failCount} failed. Check the progress panel for details.", + NotificationType.Warning + ); + } + else + { + notificationService.Show( + "Downloads Failed", + "All model downloads failed. Please check your connection and try again.", + NotificationType.Error + ); + } + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.Models.cs b/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.Models.cs new file mode 100644 index 000000000..67bf7aafe --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.Models.cs @@ -0,0 +1,417 @@ +using System.Collections.ObjectModel; +using Avalonia; +using Avalonia.Controls; +using Avalonia.Controls.Notifications; +using Avalonia.Layout; +using Avalonia.Styling; +using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models.BananaVision; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Services.ImageGeneration; + +namespace StabilityMatrix.Avalonia.ViewModels; + +public partial class BananaVisionPageViewModel +{ + /// + /// Sorts models by connected status first, then alphabetically by display name + /// + private static IOrderedEnumerable SortModelsByConnectedThenName( + IEnumerable models + ) + { + return models + .OrderByDescending(m => m.Local?.ConnectedModelInfo != null) + .ThenBy(m => m.Local?.DisplayModelName ?? m.ShortDisplayName); + } + + /// + /// Populates a collection with sorted models from multiple priority groups + /// + private static void PopulateModelCollection( + ObservableCollection collection, + params IEnumerable[] modelGroups + ) + { + collection.Clear(); + foreach (var group in modelGroups) + { + foreach (var model in SortModelsByConnectedThenName(group)) + { + collection.Add(model); + } + } + } + + /// + /// Categorizes models from a folder type based on search terms + /// + /// The folder type to search + /// Primary search terms (highest priority) + /// Secondary search terms (medium priority, optional) + /// Tuple of (matched models, secondary matched models, untagged models) + private ( + List Primary, + List Secondary, + List Untagged + ) CategorizeModelsByTerms( + SharedFolderType folderType, + string[] primaryTerms, + string[]? secondaryTerms = null + ) + { + var primaryModels = new List(); + var secondaryModels = new List(); + var untaggedModels = new List(); + + foreach (var model in modelIndexService.FindByModelType(folderType).Select(HybridModelFile.FromLocal)) + { + var baseModel = model.Local?.ConnectedModelInfo?.BaseModel; + + // Check primary terms first + if ( + primaryTerms.Any(term => + baseModel?.Contains(term, StringComparison.OrdinalIgnoreCase) == true + ) + ) + { + primaryModels.Add(model); + } + // Check secondary terms + else if ( + secondaryTerms?.Any(term => + baseModel?.Contains(term, StringComparison.OrdinalIgnoreCase) == true + ) == true + ) + { + secondaryModels.Add(model); + } + // Check filename fallback for untagged models + else if (string.IsNullOrEmpty(baseModel)) + { + if ( + primaryTerms.Any(term => + model.FileName.Contains(term, StringComparison.OrdinalIgnoreCase) + ) + ) + { + primaryModels.Add(model); + } + else + { + untaggedModels.Add(model); + } + } + } + + return (primaryModels, secondaryModels, untaggedModels); + } + + /// + /// Loads available Flux Kontext models from the DiffusionModels folder using local model index + /// + private void LoadAvailableFluxModels() + { + // Load UNet models - prioritize Kontext + var (kontextModels, _, untaggedModels) = CategorizeModelsByTerms( + SharedFolderType.DiffusionModels, + ["Kontext"] + ); + + PopulateModelCollection(AvailableFluxModels, kontextModels, untaggedModels); + + // Auto-select first Kontext model if available + if (SelectedFluxModel == null && AvailableFluxModels.Count > 0) + { + SelectedFluxModel = + AvailableFluxModels.FirstOrDefault(m => + m.FileName.Contains("kontext", StringComparison.OrdinalIgnoreCase) + ) ?? AvailableFluxModels.First(); + } + + // Load LoRA models - prioritize Kontext, then Flux, then untagged + var (kontextLoras, fluxLoras, untaggedLoras) = CategorizeModelsByTerms( + SharedFolderType.Lora | SharedFolderType.LyCORIS, + ["Kontext"], + ["Flux"] + ); + + PopulateModelCollection(AvailableFluxLoras, kontextLoras, fluxLoras, untaggedLoras); + + logger.LogInformation( + "Loaded {ModelCount} Flux models and {LoraCount} LoRAs from local index", + AvailableFluxModels.Count, + AvailableFluxLoras.Count + ); + } + + /// + /// Loads available Qwen Image Edit models from the DiffusionModels folder using local model index + /// + private void LoadAvailableQwenModels() + { + // Load UNet models - prioritize Qwen + var (qwenModels, _, untaggedModels) = CategorizeModelsByTerms( + SharedFolderType.DiffusionModels, + ["Qwen"] + ); + + PopulateModelCollection(AvailableQwenModels, qwenModels, untaggedModels); + + // Auto-select first Qwen model if available + if (SelectedQwenModel == null && AvailableQwenModels.Count > 0) + { + SelectedQwenModel = + AvailableQwenModels.FirstOrDefault(m => + m.FileName.Contains("qwen", StringComparison.OrdinalIgnoreCase) + ) ?? AvailableQwenModels.First(); + } + + // Load LoRA models - prioritize Qwen, then untagged + var (qwenLoras, _, untaggedLoras) = CategorizeModelsByTerms( + SharedFolderType.Lora | SharedFolderType.LyCORIS, + ["Qwen"] + ); + + PopulateModelCollection(AvailableQwenLoras, qwenLoras, untaggedLoras); + + logger.LogInformation( + "Loaded {ModelCount} Qwen models and {LoraCount} LoRAs from local index", + AvailableQwenModels.Count, + AvailableQwenLoras.Count + ); + } + + /// + /// Loads available Flux.2 Klein models from the DiffusionModels folder using local model index. + /// Picks up both Klein 4B and Klein 9B variants for the dropdown selector. + /// + private void LoadAvailableKleinModels() + { + // Load UNet models - prioritize Klein, then any Flux.2 (catches future variants), then untagged + var (kleinModels, flux2Models, untaggedModels) = CategorizeModelsByTerms( + SharedFolderType.DiffusionModels, + ["Klein", "flux-2-klein", "flux2-klein"], + ["Flux.2", "flux2"] + ); + + PopulateModelCollection(AvailableKleinModels, kleinModels, flux2Models, untaggedModels); + + // Auto-select first Klein model if available β€” prefer 4B since it's the auto-downloaded + // Apache 2.0 default, then any other Klein variant the user has dropped in + if (SelectedKleinModel == null && AvailableKleinModels.Count > 0) + { + SelectedKleinModel = + AvailableKleinModels.FirstOrDefault(m => + m.FileName.Contains("klein-4b", StringComparison.OrdinalIgnoreCase) + || m.FileName.Contains("klein_4b", StringComparison.OrdinalIgnoreCase) + ) + ?? AvailableKleinModels.FirstOrDefault(m => + m.FileName.Contains("klein", StringComparison.OrdinalIgnoreCase) + ) + ?? AvailableKleinModels.First(); + } + + // Load LoRA models - prioritize Klein, then any Flux LoRA, then untagged + var (kleinLoras, fluxLoras, untaggedLoras) = CategorizeModelsByTerms( + SharedFolderType.Lora | SharedFolderType.LyCORIS, + ["Klein", "Flux.2"], + ["Flux"] + ); + + PopulateModelCollection(AvailableKleinLoras, kleinLoras, fluxLoras, untaggedLoras); + + logger.LogInformation( + "Loaded {ModelCount} Klein models and {LoraCount} LoRAs from local index", + AvailableKleinModels.Count, + AvailableKleinLoras.Count + ); + } + + [RelayCommand] + private async Task AddLoraAsync() + { + // Get available LoRAs based on current provider + var availableLoras = SelectedProviderId switch + { + BananaVisionProviderIds.QwenImageEdit => AvailableQwenLoras, + BananaVisionProviderIds.Flux2Klein => AvailableKleinLoras, + _ => AvailableFluxLoras, + }; + + if (availableLoras.Count == 0) + { + notificationService.Show( + "No LoRAs Available", + "No compatible LoRA models found.", + NotificationType.Warning + ); + return; + } + + // Create a styled selection dialog using BetterComboBox with HybridModel theme + var comboBox = new BetterComboBox + { + ItemsSource = availableLoras, + SelectedIndex = 0, + MinWidth = 350, + Padding = new Thickness(8, 6, 4, 6), + HorizontalAlignment = HorizontalAlignment.Stretch, + }; + + // Apply the HybridModel theme + if ( + App.Current?.Resources.TryGetResource( + "BetterComboBoxHybridModelTheme", + App.Current.ActualThemeVariant, + out var theme + ) == true + && theme is ControlTheme controlTheme + ) + { + comboBox.Theme = controlTheme; + } + + var dialog = new ContentDialog + { + Title = "Add LoRA", + Content = comboBox, + PrimaryButtonText = "Add", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Primary, + }; + + var result = await dialog.ShowAsync(); + + if (result == ContentDialogResult.Primary && comboBox.SelectedItem is HybridModelFile selectedLora) + { + // Check if already added + if (SelectedLoras.Any(l => l.Model.RelativePath == selectedLora.RelativePath)) + { + notificationService.Show( + "Already Added", + "This LoRA is already in the list.", + NotificationType.Warning + ); + return; + } + + SelectedLoras.Add(new SelectedLora { Model = selectedLora }); + } + } + + [RelayCommand] + private void RemoveLora(SelectedLora lora) + { + SelectedLoras.Remove(lora); + } + + [RelayCommand] + private void ToggleFluxSettings() + { + IsFluxSettingsExpanded = !IsFluxSettingsExpanded; + } + + [RelayCommand] + private void ToggleQwenSettings() + { + IsQwenSettingsExpanded = !IsQwenSettingsExpanded; + } + + [RelayCommand] + private void ToggleKleinSettings() + { + IsKleinSettingsExpanded = !IsKleinSettingsExpanded; + } + + /// + /// When the user picks a different Klein model, snap Steps/CFG to the recommended + /// defaults for that variant. Distilled = 4 steps / CFG 1, Base = 20 steps / CFG 5. + /// The user can still override afterwards; this just sets sane starting values. + /// + partial void OnSelectedKleinModelChanged(HybridModelFile? value) + { + if (value == null) + return; + + var (recommendedSteps, recommendedCfg) = DetectKleinDefaults(value); + KleinSteps = recommendedSteps; + KleinCfg = recommendedCfg; + } + + /// + /// Returns the recommended Steps and CFG for a Klein UNET, based on filename and + /// CivitAI metadata. Base variants need 20 steps / CFG 5; distilled needs 4 / 1. + /// 9B models without an explicit "distilled" tag are assumed to be base, since + /// Klein 9B distilled isn't publicly shipped β€” almost all 9B installs are base + /// (or fine-tunes of base). 4B without signals defaults to distilled, matching + /// our auto-downloaded Apache 2.0 default. + /// + private static (int Steps, double Cfg) DetectKleinDefaults(HybridModelFile model) + { + var info = model.Local?.ConnectedModelInfo; + + var haystacks = new List { model.FileName }; + if (info != null) + { + if (!string.IsNullOrEmpty(info.BaseModel)) + haystacks.Add(info.BaseModel); + if (!string.IsNullOrEmpty(info.ModelName)) + haystacks.Add(info.ModelName); + if (!string.IsNullOrEmpty(info.VersionName)) + haystacks.Add(info.VersionName); + if (!string.IsNullOrEmpty(info.VersionDescription)) + haystacks.Add(info.VersionDescription); + if (info.TrainedWords != null) + haystacks.AddRange(info.TrainedWords); + } + + bool LooksLikeBase(string s) => + s.Contains("base", StringComparison.OrdinalIgnoreCase) + || s.Contains("non-distilled", StringComparison.OrdinalIgnoreCase) + || s.Contains("non_distilled", StringComparison.OrdinalIgnoreCase) + || s.Contains("nondistilled", StringComparison.OrdinalIgnoreCase) + || s.Contains("foundation", StringComparison.OrdinalIgnoreCase); + + bool LooksLikeDistilled(string s) => + s.Contains("distilled", StringComparison.OrdinalIgnoreCase) + || s.Contains("turbo", StringComparison.OrdinalIgnoreCase); + + bool LooksLikeNineB(string s) => + s.Contains("9b", StringComparison.OrdinalIgnoreCase) + || s.Contains("9 b", StringComparison.OrdinalIgnoreCase) + || s.Contains("9-b", StringComparison.OrdinalIgnoreCase) + || s.Contains("klein 9", StringComparison.OrdinalIgnoreCase) + || s.Contains("klein-9", StringComparison.OrdinalIgnoreCase) + || s.Contains("klein_9", StringComparison.OrdinalIgnoreCase); + + var hasBaseSignal = haystacks.Any(LooksLikeBase); + var hasDistilledSignal = haystacks.Any(LooksLikeDistilled); + var hasNineBSignal = haystacks.Any(LooksLikeNineB); + + // Ambiguous case: BOTH "base" and "distilled" appear (common for community uploads + // labeled e.g. "Klein 9B Base & Distilled" that cover both variants). Prefer base + // for 9B (distilled 9B isn't publicly shipped) and distilled for 4B (matches our + // auto-download default). + if (hasBaseSignal && hasDistilledSignal) + return hasNineBSignal ? (20, 5.0) : (4, 1.0); + + // Unambiguous explicit tags. + if (hasDistilledSignal) + return (4, 1.0); + if (hasBaseSignal) + return (20, 5.0); + + // No explicit base/distilled signal, but it's a 9B variant β€” default to base. + // Klein 9B distilled isn't publicly shipped, so 9B installs (including merges and + // fine-tunes) are almost always base-derived and need 20 steps / CFG 5. + if (hasNineBSignal) + return (20, 5.0); + + // Default: distilled (matches the auto-downloaded Apache 2.0 Klein 4B). + return (4, 1.0); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.cs new file mode 100644 index 000000000..b78c4d6c1 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/BananaVisionPageViewModel.cs @@ -0,0 +1,2595 @@ +ο»Ώusing System.Collections.ObjectModel; +using System.Collections.Specialized; +using System.Net.Http; +using System.Net.Sockets; +using System.Reactive.Linq; +using AsyncAwaitBestPractices; +using Avalonia.Controls; +using Avalonia.Controls.Notifications; +using Avalonia.Input; +using Avalonia.Media.Imaging; +using Avalonia.Platform.Storage; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; +using FluentAvalonia.UI.Media.Animation; +using Injectio.Attributes; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Helpers; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.BananaVision; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Avalonia.ViewModels.Settings; +using StabilityMatrix.Avalonia.Views; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Models.Packages; +using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.ImageGeneration; + +namespace StabilityMatrix.Avalonia.ViewModels; + +[View(typeof(BananaVisionPage))] +[RegisterSingleton] +public partial class BananaVisionPageViewModel : PageViewModelBase +{ + private readonly ILogger logger; + private readonly IImageGenerationChatService chatService; + private readonly ISecretsManager secretsManager; + private readonly INotificationService notificationService; + private readonly IServiceManager vmFactory; + private readonly IModelIndexService modelIndexService; + private readonly INavigationService navigationService; + private readonly INavigationService settingsNavigationService; + + public override string Title => "Image Lab"; + public override IconSource IconSource => new FASymbolIconSource { Symbol = "fa-solid fa-flask" }; + + public IInferenceClientManager ClientManager { get; } + + [ObservableProperty] + public partial string? NewMessageText { get; set; } + + [ObservableProperty] + [NotifyCanExecuteChangedFor(nameof(SendMessageCommand))] + [NotifyPropertyChangedFor(nameof(IsCurrentConversationGenerating))] + public partial bool IsGenerating { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsCurrentConversationGenerating))] + public partial Guid? GeneratingConversationId { get; set; } + + /// + /// True if the currently selected conversation is the one that's generating. + /// Used to scope the progress indicator to the active conversation. + /// + public bool IsCurrentConversationGenerating => + IsGenerating && CurrentConversation != null && GeneratingConversationId == CurrentConversation.Id; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasGenerationProgress))] + [NotifyPropertyChangedFor(nameof(GenerationProgressText))] + public partial int? GenerationProgressPercent { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasGenerationProgress))] + [NotifyPropertyChangedFor(nameof(GenerationProgressText))] + public partial string? GenerationProgressStage { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasGenerationProgress))] + [NotifyPropertyChangedFor(nameof(GenerationProgressText))] + public partial string? GenerationProgressRunningNode { get; set; } + + public bool HasGenerationProgress => + RequiresLocalBackend + && ( + GenerationProgressPercent != null + || !string.IsNullOrEmpty(GenerationProgressStage) + || !string.IsNullOrEmpty(GenerationProgressRunningNode) + ); + + public string GenerationProgressText + { + get + { + if (!IsGenerating) + return "Ready"; + + if (!RequiresLocalBackend) + return "Creating your image..."; + + var stage = string.IsNullOrWhiteSpace(GenerationProgressStage) + ? "Creating your image..." + : GenerationProgressStage; + var node = string.IsNullOrWhiteSpace(GenerationProgressRunningNode) + ? null + : GenerationProgressRunningNode.Replace('_', ' '); + var percent = GenerationProgressPercent; + + if (percent is >= 0 and <= 100 && !string.IsNullOrWhiteSpace(node)) + return $"{stage} ({percent}%) β€’ {node}"; + if (percent is >= 0 and <= 100) + return $"{stage} ({percent}%)"; + if (!string.IsNullOrWhiteSpace(node)) + return $"{stage} β€’ {node}"; + + return stage; + } + } + + [ObservableProperty] + public partial string? ErrorMessage { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsCurrentConversationGenerating))] + public partial ImageGenerationConversation? CurrentConversation { get; set; } + + partial void OnCurrentConversationChanged( + ImageGenerationConversation? oldValue, + ImageGenerationConversation? newValue + ) + { + // Cancel any pending message load from the previous conversation + loadMessagesCts?.Cancel(); + loadMessagesCts?.Dispose(); + loadMessagesCts = null; + + if (newValue != null) + { + logger.LogInformation( + "Current conversation changed to: {ConversationId} - {Title} (provider: {ProviderId})", + newValue.Id, + newValue.Title, + newValue.ProviderId + ); + + // Auto-switch to the conversation's last-used provider for convenience. + // Users can still freely change it afterwards, and that change will be + // remembered when they send the next message. + if (newValue.ProviderId != SelectedProviderId) + { + SelectedProviderId = newValue.ProviderId; + } + + // Create new cancellation token for this load operation + loadMessagesCts = new(); + var token = loadMessagesCts.Token; + + // Load messages for the new conversation (fire and forget with error handling) + LoadMessagesForConversationAsync(newValue, token) + .SafeFireAndForget(ex => + { + logger.LogError( + ex, + "Unhandled error loading messages for conversation {Id}", + newValue.Id + ); + }); + } + else + { + logger.LogWarning("Current conversation set to null"); + ClearMessages(); + } + } + + /// + /// Loads messages for a conversation without changing CurrentConversation + /// + private async Task LoadMessagesForConversationAsync( + ImageGenerationConversation conversation, + CancellationToken cancellationToken = default + ) + { + // Clear on UI thread + await Dispatcher.UIThread.InvokeAsync(ClearMessages); + + try + { + var messages = await chatService.GetMessagesAsync(conversation.Id); + + // Check if cancelled before updating UI (user may have switched conversations) + cancellationToken.ThrowIfCancellationRequested(); + + logger.LogInformation( + "Loaded {Count} messages for conversation {Id}", + messages.Count, + conversation.Id + ); + + // Update UI on the UI thread + await Dispatcher.UIThread.InvokeAsync(() => + { + foreach (var message in messages) + { + AddMessageToUI(message); + } + + // Notify gallery that images may have changed + OnPropertyChanged(nameof(ConversationImages)); + OnPropertyChanged(nameof(HasConversationImages)); + + // If this conversation is currently generating, re-add the loading placeholder + if (GeneratingConversationId == conversation.Id && IsGenerating) + { + currentLoadingMessage = new LoadingImageMessage + { + TargetWidth = (SelectedAspectRatio?.Width ?? 300) / 3, + TargetHeight = (SelectedAspectRatio?.Height ?? 300) / 3, + }; + Messages.Add(currentLoadingMessage); + } + + // Start the view at the bottom when switching to a (potentially long) conversation. + // Guard against late completion after the user already switched away. + if (CurrentConversation?.Id == conversation.Id) + { + Dispatcher.UIThread.Post( + () => ScrollToEndForcedRequested?.Invoke(this, EventArgs.Empty), + DispatcherPriority.Background + ); + } + }); + } + catch (OperationCanceledException) + { + // Conversation switch cancelled this load - this is expected, don't log as error + logger.LogDebug("Message loading cancelled for conversation {ConversationId}", conversation.Id); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to load messages for conversation {ConversationId}", conversation.Id); + await Dispatcher.UIThread.InvokeAsync(() => + { + ErrorMessage = $"Failed to load messages: {ex.Message}"; + }); + } + } + + [ObservableProperty] + public partial string? SelectedProviderId { get; set; } + + [ObservableProperty] + public partial string? ProviderStatusMessage { get; set; } + + [ObservableProperty] + public partial bool IsFluxKontextAvailable { get; set; } + + [ObservableProperty] + public partial bool CanRetryLastMessage { get; set; } + + /// + /// Whether we can regenerate the last assistant response (true when there's at least one assistant message) + /// + public bool CanRegenerateLastResponse => + Messages.OfType().Any(m => !m.IsMyMessage) + || Messages.OfType().Any(m => !m.IsMyMessage); + + /// + /// Whether to show thinking/reasoning output from Gemini 3 Pro + /// + [ObservableProperty] + public partial bool ShowThinkingOutput { get; set; } = true; + + /// + /// Whether the selected provider supports thinking output + /// + public bool SupportsThinking => BananaVisionProviderIds.SupportsThinking(SelectedProviderId); + + /// + /// Whether the selected provider requires a local backend (ComfyUI) + /// + public bool RequiresLocalBackend => BananaVisionProviderIds.IsLocalProvider(SelectedProviderId); + + /// + /// Whether the selected provider is a cloud/API provider (Gemini) + /// + public bool IsCloudProvider => BananaVisionProviderIds.IsCloudProvider(SelectedProviderId); + + /// + /// Whether to show the Flux Kontext settings panel + /// + public bool ShowFluxSettings => SelectedProviderId == BananaVisionProviderIds.FluxKontext; + + /// + /// Whether to show the Qwen Image Edit settings panel + /// + public bool ShowQwenSettings => SelectedProviderId == BananaVisionProviderIds.QwenImageEdit; + + /// + /// Whether to show the Flux.2 Klein settings panel + /// + public bool ShowKleinSettings => SelectedProviderId == BananaVisionProviderIds.Flux2Klein; + + /// + /// Whether the Flux settings panel is expanded + /// + [ObservableProperty] + public partial bool IsFluxSettingsExpanded { get; set; } = true; + + /// + /// Whether the Qwen settings panel is expanded + /// + [ObservableProperty] + public partial bool IsQwenSettingsExpanded { get; set; } = true; + + /// + /// Whether the Klein settings panel is expanded + /// + [ObservableProperty] + public partial bool IsKleinSettingsExpanded { get; set; } = true; + + /// + /// Selected Flux Kontext model + /// + [ObservableProperty] + public partial HybridModelFile? SelectedFluxModel { get; set; } + + /// + /// Selected Qwen Image Edit model + /// + [ObservableProperty] + public partial HybridModelFile? SelectedQwenModel { get; set; } + + /// + /// Selected Flux.2 Klein model + /// + [ObservableProperty] + public partial HybridModelFile? SelectedKleinModel { get; set; } + + /// + /// Sampling steps for Klein. Auto-set when the model changes (4 for distilled, + /// 20 for base); user can override via the Klein settings panel. + /// + [ObservableProperty] + public partial int KleinSteps { get; set; } = 4; + + /// + /// CFG scale for Klein. Auto-set when the model changes (1 for distilled, + /// 5 for base); user can override via the Klein settings panel. + /// + [ObservableProperty] + public partial double KleinCfg { get; set; } = 1.0; + + /// + /// Available Flux Kontext models (filtered by BaseModel metadata or untagged) + /// + public ObservableCollection AvailableFluxModels { get; } = []; + + /// + /// Available Qwen Image Edit models (filtered by BaseModel metadata or filename) + /// + public ObservableCollection AvailableQwenModels { get; } = []; + + /// + /// Available Flux.2 Klein models (filtered by BaseModel metadata or filename) + /// + public ObservableCollection AvailableKleinModels { get; } = []; + + /// + /// Available LoRA models for Flux Kontext + /// + public ObservableCollection AvailableFluxLoras { get; } = []; + + /// + /// Available LoRA models for Qwen Image Edit + /// + public ObservableCollection AvailableQwenLoras { get; } = []; + + /// + /// Available LoRA models for Flux.2 Klein + /// + public ObservableCollection AvailableKleinLoras { get; } = []; + + /// + /// Selected LoRAs with weights + /// + public ObservableCollection SelectedLoras { get; } = []; + + /// + /// Available aspect ratio presets + /// + public ObservableCollection AvailableAspectRatios { get; } = + [ + new("1:1", "Square", 1024, 1024), + new("16:9", "Landscape Wide", 1344, 768), + new("9:16", "Portrait Tall", 768, 1344), + new("4:3", "Landscape", 1152, 896), + new("3:4", "Portrait", 896, 1152), + new("3:2", "Photo Landscape", 1216, 832), + new("2:3", "Photo Portrait", 832, 1216), + new("21:9", "Ultrawide", 1536, 640), + new("9:21", "Ultra Tall", 640, 1536), + ]; + + /// + /// Selected aspect ratio + /// + [ObservableProperty] + public partial AspectRatioOption? SelectedAspectRatio { get; set; } + + /// + /// Whether to use custom resolution instead of aspect ratio presets + /// + [ObservableProperty] + public partial bool UseCustomResolution { get; set; } + + /// + /// Custom width when UseCustomResolution is true + /// + [ObservableProperty] + public partial int CustomWidth { get; set; } = 1024; + + /// + /// Custom height when UseCustomResolution is true + /// + [ObservableProperty] + public partial int CustomHeight { get; set; } = 1024; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsComfyRunning))] + public partial PackagePair? RunningPackage { get; set; } + + [ObservableProperty] + public partial bool IsWaitingForConnection { get; set; } + + /// + /// Indicates whether the user is dragging an image over the page + /// + [ObservableProperty] + public partial bool IsDragOverImage { get; set; } + + /// + /// Whether the image gallery sidebar is visible + /// + [ObservableProperty] + public partial bool IsGalleryVisible { get; set; } + + /// + /// Gets all images in the current conversation for the gallery view + /// + public IEnumerable ConversationImages => + Messages.OfType().Where(m => !m.IsMyMessage); + + /// + /// Whether there are any images in the conversation + /// + public bool HasConversationImages => ConversationImages.Any(); + + partial void OnIsWaitingForConnectionChanged(bool value) + { + UpdateProviderStatus(); + } + + public bool IsComfyRunning => RunningPackage?.BasePackage is ComfyUI; + + private string? lastMessageText; + private List? lastMessageImagePaths; + private IDisposable? startupCompleteSubscription; + private bool hasShownMissingModelsDialog; + private CancellationTokenSource? loadMessagesCts; + + /// + /// Tracks the current loading message so it can be reliably removed on cancellation. + /// + private LoadingImageMessage? currentLoadingMessage; + + /// + /// Messages in the current conversation. Can contain MessageBase or ThinkingMessage. + /// + public ObservableCollection Messages { get; } + + /// + /// Event raised when the message list should scroll to the end + /// + public event EventHandler? ScrollToEndRequested; + + /// + /// Event raised when the message list should force-scroll to the end. + /// Used after switching conversations so users start at the bottom immediately. + /// + public event EventHandler? ScrollToEndForcedRequested; + + public ObservableCollection Conversations { get; set; } = []; + public ObservableCollection AvailableProviders { get; set; } = []; + + /// + /// Pending images to be sent with the next message + /// + public ObservableCollection PendingImages { get; set; } = []; + + // Will be set by the view + public IStorageProvider? StorageProvider { get; set; } + + public BananaVisionPageViewModel( + ILogger logger, + IImageGenerationChatService chatService, + ISecretsManager secretsManager, + INotificationService notificationService, + IInferenceClientManager inferenceClientManager, + RunningPackageService runningPackageService, + IServiceManager vmFactory, + IModelIndexService modelIndexService, + INavigationService navigationService, + INavigationService settingsNavigationService + ) + { + this.logger = logger; + this.chatService = chatService; + this.secretsManager = secretsManager; + this.notificationService = notificationService; + this.vmFactory = vmFactory; + this.navigationService = navigationService; + this.settingsNavigationService = settingsNavigationService; + this.modelIndexService = modelIndexService; + + ClientManager = inferenceClientManager; + + // Initialize Messages collection and subscribe to changes for auto-scroll + Messages = []; + Messages.CollectionChanged += OnMessagesCollectionChanged; + + // Load available providers + var providers = chatService.GetAvailableProviders(); + foreach (var provider in providers) + { + AvailableProviders.Add(new(provider.ProviderId, provider.ProviderName)); + } + + // Set default provider (use the first provider's ID) + SelectedProviderId = AvailableProviders.FirstOrDefault()?.Id; + + // Set default aspect ratio (1:1 Square) + SelectedAspectRatio = AvailableAspectRatios.FirstOrDefault(); + + // Subscribe to connection status changes + ClientManager.PropertyChanged += (s, e) => + { + if (e.PropertyName != nameof(IInferenceClientManager.IsConnected)) + return; + + UpdateProviderStatus(); + + // When connected and using a local provider, check for missing models + if (ClientManager.IsConnected && RequiresLocalBackend) + { + CheckAndShowMissingModelsDialogAsync() + .SafeFireAndForget(ex => + { + logger.LogError(ex, "Failed to check for missing models"); + }); + } + + // When disconnected during generation, cancel the pending operation + if (!ClientManager.IsConnected && IsGenerating && RequiresLocalBackend) + { + logger.LogWarning("ComfyUI disconnected during generation, cancelling..."); + CancelGeneration(); + } + }; + + // Subscribe to running package changes + runningPackageService.RunningPackages.CollectionChanged += (s, e) => + { + // ComfyZluda inherits from ComfyUI, so this check covers both + var comfyPackage = runningPackageService + .RunningPackages.FirstOrDefault(p => p.Value.RunningPackage.BasePackage is ComfyUI) + .Value?.RunningPackage; + + // Handle package startup - auto-connect when ComfyUI starts + if (comfyPackage != null && RunningPackage == null) + { + RunningPackage = comfyPackage; + + // Dispose previous subscription if any + startupCompleteSubscription?.Dispose(); + + // Subscribe to StartupComplete event for auto-connect + IsWaitingForConnection = true; + startupCompleteSubscription = Observable + .FromEventPattern( + comfyPackage.BasePackage, + nameof(comfyPackage.BasePackage.StartupComplete) + ) + .Take(1) + .Subscribe(_ => + { + Dispatcher.UIThread.InvokeAsync(async () => + { + // Only auto-connect for local providers (Flux Kontext, Qwen Image Edit, etc.) + if (RequiresLocalBackend && ClientManager.CanUserConnect) + { + logger.LogInformation( + "ComfyUI startup complete, auto-connecting for local provider..." + ); + await ConnectAsync(); + } + + IsWaitingForConnection = false; + }); + }); + } + else if (comfyPackage == null && RunningPackage != null) + { + // Package stopped + startupCompleteSubscription?.Dispose(); + startupCompleteSubscription = null; + IsWaitingForConnection = false; + } + + RunningPackage = comfyPackage; + UpdateProviderStatus(); + }; + + // Initial status update + var initialComfyPackage = runningPackageService + .RunningPackages.FirstOrDefault(p => p.Value.RunningPackage.BasePackage is ComfyUI) + .Value?.RunningPackage; + + RunningPackage = initialComfyPackage; + + // If ComfyUI is already running and we're using a local provider, try to connect + if (initialComfyPackage != null && RequiresLocalBackend && !ClientManager.IsConnected) + { + Dispatcher.UIThread.InvokeAsync(async () => + { + await Task.Delay(500); // Small delay to ensure ComfyUI is ready + if (ClientManager.CanUserConnect) + { + logger.LogInformation("ComfyUI already running on load, attempting connection..."); + await ConnectAsync(); + } + }); + } + + UpdateProviderStatus(); + } + + private void ResetGenerationProgress() + { + GenerationProgressPercent = null; + GenerationProgressStage = null; + GenerationProgressRunningNode = null; + OnPropertyChanged(nameof(GenerationProgressText)); + } + + private IProgress CreateProgressReporter(string providerId) + { + return new Progress(progress => + { + // Only show progress for the active local generation session/provider. + if (!RequiresLocalBackend || SelectedProviderId != providerId) + return; + + Dispatcher.UIThread.Post(() => + { + GenerationProgressPercent = progress.Percent; + GenerationProgressStage = progress.Stage; + GenerationProgressRunningNode = progress.RunningNode; + OnPropertyChanged(nameof(GenerationProgressText)); + }); + }); + } + + public override async Task OnLoadedAsync() + { + await base.OnLoadedAsync(); + + logger.LogInformation("BananaVisionPage loaded, initializing..."); + + // Load conversations + logger.LogInformation("Loading conversations from database..."); + await LoadConversationsAsync(); + logger.LogInformation("Loaded {Count} conversations", Conversations.Count); + + // Create or load a conversation + if (Conversations.Count == 0 && SelectedProviderId != null) + { + logger.LogInformation("No conversations found, creating new conversation"); + await NewConversationAsync(); + } + else if (Conversations.Count > 0) + { + logger.LogInformation("Loading most recent conversation: {ConversationId}", Conversations[0].Id); + await LoadConversationAsync(Conversations[0]); + } + } + + private async Task LoadConversationsAsync() + { + try + { + var conversations = await chatService.GetConversationsAsync(); + + // Update UI on the UI thread + await Dispatcher.UIThread.InvokeAsync(() => + { + Conversations.Clear(); + foreach (var conversation in conversations) + { + Conversations.Add(conversation); + } + }); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to load conversations"); + await Dispatcher.UIThread.InvokeAsync(() => + { + ErrorMessage = $"Failed to load conversations: {ex.Message}"; + }); + } + } + + [RelayCommand] + private async Task NewConversationAsync() + { + if (string.IsNullOrEmpty(SelectedProviderId)) + { + notificationService.Show("Error", "Please select a provider", NotificationType.Error); + return; + } + + try + { + var conversation = await chatService.CreateConversationAsync(SelectedProviderId); + Conversations.Insert(0, conversation); + await LoadConversationAsync(conversation); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to create conversation"); + notificationService.Show( + "Error", + $"Failed to create conversation: {ex.Message}", + NotificationType.Error + ); + } + } + + [RelayCommand] + private Task LoadConversationAsync(ImageGenerationConversation conversation) + { + // Setting CurrentConversation triggers OnCurrentConversationChanged which loads the messages + CurrentConversation = conversation; + return Task.CompletedTask; + } + + [RelayCommand] + private async Task DeleteConversationAsync(ImageGenerationConversation conversation) + { + // Show confirmation dialog + var dialog = new ContentDialog + { + Title = "Delete Conversation", + Content = + $"Are you sure you want to delete \"{conversation.Title}\"?\n\nThis will also delete all messages and generated images in this conversation.", + PrimaryButtonText = "Delete", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Close, + }; + + var result = await dialog.ShowAsync(); + if (result != ContentDialogResult.Primary) + { + return; + } + + try + { + await chatService.DeleteConversationAsync(conversation.Id); + Conversations.Remove(conversation); + + if (CurrentConversation?.Id == conversation.Id) + { + ClearMessages(); + CurrentConversation = null; + + // Load first conversation if available + if (Conversations.Count > 0) + { + await LoadConversationAsync(Conversations[0]); + } + } + + notificationService.Show("Success", "Conversation deleted", NotificationType.Success); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to delete conversation {ConversationId}", conversation.Id); + notificationService.Show( + "Error", + $"Failed to delete conversation: {ex.Message}", + NotificationType.Error + ); + } + } + + [RelayCommand] + private async Task RenameConversationAsync(ImageGenerationConversation conversation) + { + try + { + var textBox = new TextBox + { + Text = conversation.Title, + Watermark = "Enter conversation name...", + MinWidth = 300, + }; + + var dialog = new ContentDialog + { + Title = "Rename Conversation", + Content = textBox, + PrimaryButtonText = "Save", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Primary, + }; + + var result = await dialog.ShowAsync(); + + if (result == ContentDialogResult.Primary && !string.IsNullOrWhiteSpace(textBox.Text)) + { + conversation.Title = textBox.Text.Trim(); + await chatService.UpdateConversationAsync(conversation); + + // Refresh the list to update UI + var index = Conversations.IndexOf(conversation); + if (index >= 0) + { + Conversations[index] = conversation; + } + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to rename conversation {ConversationId}", conversation.Id); + notificationService.Show( + "Error", + $"Failed to rename conversation: {ex.Message}", + NotificationType.Error + ); + } + } + + /// + /// Gets the display name for a provider ID + /// + public string GetProviderDisplayName(string? providerId) + { + if (string.IsNullOrEmpty(providerId)) + return "Unknown"; + return AvailableProviders.FirstOrDefault(p => p.Id == providerId)?.DisplayName ?? providerId; + } + + [RelayCommand] + private async Task ConnectAsync() + { + const int maxRetries = 5; + const int retryDelayMs = 1000; + + // Hold IsWaitingForConnection for the whole retry loop so the status banner and + // Connect button don't flicker as ClientManager.IsConnecting toggles per attempt + // (each failed attempt + retry delay would otherwise expose CanUserConnect briefly). + IsWaitingForConnection = true; + try + { + for (var attempt = 1; attempt <= maxRetries; attempt++) + { + try + { + logger.LogInformation( + "Attempting to connect to ComfyUI (attempt {Attempt}/{MaxRetries})...", + attempt, + maxRetries + ); + await ClientManager.ConnectAsync(); + notificationService.Show( + "Connected", + "Successfully connected to ComfyUI", + NotificationType.Success + ); + return; // Success - exit the method + } + catch (HttpRequestException ex) + when (ex.InnerException + is SocketException { SocketErrorCode: SocketError.ConnectionRefused } + ) + { + // Connection refused - ComfyUI might still be starting up + if (attempt < maxRetries) + { + logger.LogDebug( + "Connection refused (attempt {Attempt}/{MaxRetries}), retrying in {Delay}ms...", + attempt, + maxRetries, + retryDelayMs + ); + await Task.Delay(retryDelayMs); + } + else + { + logger.LogWarning( + ex, + "Failed to connect to ComfyUI after {MaxRetries} attempts", + maxRetries + ); + notificationService.Show( + "Connection Failed", + "Could not connect to ComfyUI. Make sure it's running and try again.", + NotificationType.Warning + ); + } + } + catch (Exception ex) + { + // Other errors - don't retry + logger.LogError(ex, "Failed to connect to ComfyUI"); + notificationService.Show("Connection Failed", ex.Message, NotificationType.Error); + return; + } + } + } + finally + { + IsWaitingForConnection = false; + } + } + + [RelayCommand] + private async Task ShowConnectionHelpAsync() + { + var viewModel = App.Services.GetRequiredService(); + var dialog = viewModel.CreateDialog(); + + await dialog.ShowAsync(); + + // After dialog closes, check if we should connect + if (IsComfyRunning && ClientManager.CanUserConnect) + { + await ConnectAsync(); + } + } + + [RelayCommand(IncludeCancelCommand = true)] + private async Task SendMessageAsync(CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(NewMessageText) && PendingImages.Count == 0) + return; + + if (CurrentConversation == null) + { + notificationService.Show("Error", "No conversation selected", NotificationType.Error); + return; + } + + if (string.IsNullOrEmpty(SelectedProviderId)) + { + notificationService.Show("Error", "Please select a provider", NotificationType.Error); + return; + } + + var messageText = NewMessageText; + var imagePaths = PendingImages.Select(p => p.FilePath).ToList(); + + // Store for retry + lastMessageText = messageText; + lastMessageImagePaths = imagePaths.Count > 0 ? imagePaths : null; + + NewMessageText = string.Empty; + ErrorMessage = null; + CanRetryLastMessage = false; + + // Add user message to UI immediately + var provisionalUiItems = new List(); + if (!string.IsNullOrWhiteSpace(messageText)) + { + var uiText = new TextMessage(messageText, true); + provisionalUiItems.Add(uiText); + Messages.Add(uiText); + } + + // Show pending images in chat (provisional; will be replaced by persisted copies after DB save) + foreach (var pendingImage in PendingImages) + { + var uiImage = new ImageMessage(pendingImage.Bitmap, true); + provisionalUiItems.Add(uiImage); + Messages.Add(uiImage); + } + + // Clear pending images + PendingImages.Clear(); + + IsGenerating = true; + ResetGenerationProgress(); + if (RequiresLocalBackend && !string.IsNullOrEmpty(SelectedProviderId)) + { + GenerationProgressStage = "Starting..."; + } + + // Track which conversation is generating (for restoring placeholder on switch back) + GeneratingConversationId = CurrentConversation.Id; + + // Add loading placeholder (scaled to 1/3 of target size for compact display) + currentLoadingMessage = new LoadingImageMessage + { + TargetWidth = (SelectedAspectRatio?.Width ?? 300) / 3, + TargetHeight = (SelectedAspectRatio?.Height ?? 300) / 3, + }; + Messages.Add(currentLoadingMessage); + + try + { + // Build provider options + var providerOptions = BuildProviderOptions(); + var progress = + RequiresLocalBackend && SelectedProviderId != null + ? CreateProgressReporter(SelectedProviderId) + : null; + + var (userMessage, assistantMessage) = await chatService.SendMessageAsync( + CurrentConversation.Id, + SelectedProviderId!, + messageText, + imagePaths.Count > 0 ? imagePaths : null, + providerOptions, + progress, + cancellationToken + ); + + // Remove loading placeholder + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + + // Replace provisional user UI items with canonical DB-backed messages (with IDs and persisted image paths). + foreach (var item in provisionalUiItems) + { + Messages.Remove(item); + if (item is ImageMessage imageMessage) + { + imageMessage.Image?.Dispose(); + } + } + + AddUserMessageToUI(userMessage); + + // Add assistant response to UI + if (assistantMessage != null) + { + AddAssistantMessageToUI(assistantMessage); + } + + // Reload conversations to update timestamps and titles + await LoadConversationsAsync(); + + // Update current conversation reference to reflect title changes + if (CurrentConversation != null) + { + var updatedConversation = Conversations.FirstOrDefault(c => c.Id == CurrentConversation.Id); + if (updatedConversation != null) + { + CurrentConversation = updatedConversation; + } + } + } + catch (OperationCanceledException) + { + // Check if cancellation was due to connection loss + if (RequiresLocalBackend && !ClientManager.IsConnected) + { + logger.LogWarning("Message generation cancelled due to connection loss"); + ErrorMessage = "Connection to ComfyUI was lost during generation."; + notificationService.Show( + "Connection Lost", + "ComfyUI disconnected during generation", + NotificationType.Warning + ); + } + else + { + logger.LogInformation("Message generation cancelled"); + ErrorMessage = "Cancelled"; + } + CanRetryLastMessage = true; + } + catch (ImageGenerationException ex) + { + // Expected error from generation (provider error, API error, etc.) + logger.LogWarning("Image generation failed: {Message}", ex.Message); + + // Check if this is an API key error + if (ex.Message.Contains("API key", StringComparison.OrdinalIgnoreCase)) + { + await ShowApiKeyRequiredDialogAsync(); + CanRetryLastMessage = true; + } + else + { + ErrorMessage = ex.Message; + notificationService.Show("Generation Failed", ex.Message, NotificationType.Warning); + CanRetryLastMessage = true; + } + } + catch (Exception ex) + { + // Unexpected error + logger.LogError(ex, "Unexpected error sending message"); + ErrorMessage = $"Unexpected error: {ex.Message}"; + notificationService.Show("Error", ex.Message, NotificationType.Error); + CanRetryLastMessage = true; + } + finally + { + IsGenerating = false; + GeneratingConversationId = null; + ResetGenerationProgress(); + // Ensure loading placeholder is removed on cancel/error + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + } + } + + /// + /// Shows a dialog prompting the user to add their Gemini API key in settings + /// + private async Task ShowApiKeyRequiredDialogAsync() + { + var dialog = new ContentDialog + { + Title = "API Key Required", + Content = + "Gemini API key not configured. Please add your Gemini API key in Account Settings to use cloud providers.", + PrimaryButtonText = "Open Settings", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Primary, + }; + + var result = await dialog.ShowAsync(); + + if (result == ContentDialogResult.Primary) + { + // Navigate to Settings -> Account Settings + navigationService.NavigateTo(new SuppressNavigationTransitionInfo()); + await Task.Delay(100); + settingsNavigationService.NavigateTo( + new SuppressNavigationTransitionInfo() + ); + } + } + + [RelayCommand(IncludeCancelCommand = true)] + private async Task RetryLastMessageAsync(CancellationToken cancellationToken) + { + if (CurrentConversation == null) + return; + + if (string.IsNullOrEmpty(SelectedProviderId)) + { + notificationService.Show("Error", "Please select a provider", NotificationType.Error); + return; + } + + // Clear error state + ErrorMessage = null; + CanRetryLastMessage = false; + IsGenerating = true; + ResetGenerationProgress(); + if (RequiresLocalBackend && !string.IsNullOrEmpty(SelectedProviderId)) + { + GenerationProgressStage = "Starting..."; + } + + // Track which conversation is generating (for restoring placeholder on switch back) + GeneratingConversationId = CurrentConversation.Id; + + try + { + // Build provider options + var providerOptions = BuildProviderOptions(); + var progress = + RequiresLocalBackend && SelectedProviderId != null + ? CreateProgressReporter(SelectedProviderId) + : null; + + // Add loading placeholder (scaled to 1/3 of target size for compact display) + currentLoadingMessage = new LoadingImageMessage + { + TargetWidth = (SelectedAspectRatio?.Width ?? 300) / 3, + TargetHeight = (SelectedAspectRatio?.Height ?? 300) / 3, + }; + Messages.Add(currentLoadingMessage); + + // Retry generation - this doesn't create a new user message + var assistantMessage = await chatService.RetryGenerationAsync( + CurrentConversation.Id, + SelectedProviderId, + providerOptions, + progress, + cancellationToken + ); + + // Remove loading placeholder + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + + // Add only the assistant response to UI + AddAssistantMessageToUI(assistantMessage, includeDbId: false); + + // Reload conversations to update timestamps + await LoadConversationsAsync(); + } + catch (OperationCanceledException) + { + // Check if cancellation was due to connection loss + if (RequiresLocalBackend && !ClientManager.IsConnected) + { + logger.LogWarning("Retry generation cancelled due to connection loss"); + ErrorMessage = "Connection to ComfyUI was lost during generation."; + notificationService.Show( + "Connection Lost", + "ComfyUI disconnected during generation", + NotificationType.Warning + ); + } + else + { + logger.LogInformation("Retry generation cancelled"); + ErrorMessage = "Cancelled"; + } + CanRetryLastMessage = true; + } + catch (ImageGenerationException ex) + { + logger.LogWarning("Retry generation failed: {Message}", ex.Message); + + // Check if this is an API key error + if (ex.Message.Contains("API key", StringComparison.OrdinalIgnoreCase)) + { + await ShowApiKeyRequiredDialogAsync(); + CanRetryLastMessage = true; + } + else + { + ErrorMessage = ex.Message; + notificationService.Show("Generation Failed", ex.Message, NotificationType.Warning); + CanRetryLastMessage = true; + } + } + catch (Exception ex) + { + logger.LogError(ex, "Unexpected error during retry"); + ErrorMessage = $"Unexpected error: {ex.Message}"; + notificationService.Show("Error", ex.Message, NotificationType.Error); + CanRetryLastMessage = true; + } + finally + { + IsGenerating = false; + GeneratingConversationId = null; + ResetGenerationProgress(); + // Ensure loading placeholder is removed on cancel/error + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + } + } + + [RelayCommand] + private void DismissError() + { + ErrorMessage = null; + CanRetryLastMessage = false; + } + + /// + /// Regenerates the last assistant response (without an error context) + /// + [RelayCommand(IncludeCancelCommand = true)] + private async Task RegenerateLastResponseAsync(CancellationToken cancellationToken) + { + if (CurrentConversation == null) + return; + + if (string.IsNullOrEmpty(SelectedProviderId)) + { + notificationService.Show("Error", "Please select a provider", NotificationType.Error); + return; + } + + // Clear error state + ErrorMessage = null; + CanRetryLastMessage = false; + IsGenerating = true; + ResetGenerationProgress(); + if (RequiresLocalBackend && !string.IsNullOrEmpty(SelectedProviderId)) + { + GenerationProgressStage = "Starting..."; + } + + // Track which conversation is generating (for restoring placeholder on switch back) + GeneratingConversationId = CurrentConversation.Id; + + try + { + // Remove the last assistant message(s) from UI before regenerating + var messagesToRemove = new List(); + for (var i = Messages.Count - 1; i >= 0; i--) + { + var message = Messages[i]; + // Stop when we hit a user message + if (message is TextMessage tm && tm.IsMyMessage) + break; + if (message is ImageMessage im && im.IsMyMessage) + break; + + messagesToRemove.Add(message); + } + + // Remove in reverse order to avoid index issues + foreach (var message in messagesToRemove) + { + Messages.Remove(message); + // Dispose image if needed + if (message is ImageMessage imageMessage) + { + imageMessage.Image?.Dispose(); + } + } + + // Delete old assistant messages from database (but preserve their image files) + var dbMessages = await chatService.GetMessagesAsync(CurrentConversation.Id); + var lastUserMessage = dbMessages.LastOrDefault(m => m.Role == MessageRole.User); + if (lastUserMessage != null) + { + // Find all assistant messages after the last user message + var oldAssistantMessages = dbMessages + .Where(m => m.Role == MessageRole.Assistant && m.Timestamp > lastUserMessage.Timestamp) + .ToList(); + + // Delete them from database but preserve image files for the output browser + foreach (var oldMsg in oldAssistantMessages) + { + await chatService.DeleteMessageAsync(oldMsg.Id, preserveImageFile: true); + } + + if (oldAssistantMessages.Count > 0) + { + logger.LogInformation( + "Removed {Count} old assistant message(s) from database, preserved image files", + oldAssistantMessages.Count + ); + } + } + + // Build provider options + var providerOptions = BuildProviderOptions(); + var progress = + RequiresLocalBackend && SelectedProviderId != null + ? CreateProgressReporter(SelectedProviderId) + : null; + + // Add loading placeholder (scaled to 1/3 of target size for compact display) + currentLoadingMessage = new LoadingImageMessage + { + TargetWidth = (SelectedAspectRatio?.Width ?? 300) / 3, + TargetHeight = (SelectedAspectRatio?.Height ?? 300) / 3, + }; + Messages.Add(currentLoadingMessage); + + // Retry generation - this doesn't create a new user message + var assistantMessage = await chatService.RetryGenerationAsync( + CurrentConversation.Id, + SelectedProviderId, + providerOptions, + progress, + cancellationToken + ); + + // Remove loading placeholder + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + + // Add the new assistant response to UI + var addedAnyImages = AddAssistantMessageToUI(assistantMessage, includeDbId: false); + + if (addedAnyImages) + { + // Notify gallery + OnPropertyChanged(nameof(ConversationImages)); + OnPropertyChanged(nameof(HasConversationImages)); + } + + // Reload conversations to update timestamps + await LoadConversationsAsync(); + + // Notify property change + OnPropertyChanged(nameof(CanRegenerateLastResponse)); + } + catch (OperationCanceledException) + { + // Check if cancellation was due to connection loss + if (RequiresLocalBackend && !ClientManager.IsConnected) + { + logger.LogWarning("Regenerate cancelled due to connection loss"); + ErrorMessage = "Connection to ComfyUI was lost during generation."; + notificationService.Show( + "Connection Lost", + "ComfyUI disconnected during generation", + NotificationType.Warning + ); + } + else + { + logger.LogInformation("Regenerate cancelled"); + ErrorMessage = "Cancelled"; + } + CanRetryLastMessage = true; + } + catch (ImageGenerationException ex) + { + logger.LogWarning("Regenerate failed: {Message}", ex.Message); + + // Check if this is an API key error + if (ex.Message.Contains("API key", StringComparison.OrdinalIgnoreCase)) + { + await ShowApiKeyRequiredDialogAsync(); + CanRetryLastMessage = true; + } + else + { + ErrorMessage = ex.Message; + notificationService.Show("Generation Failed", ex.Message, NotificationType.Warning); + CanRetryLastMessage = true; + } + } + catch (Exception ex) + { + logger.LogError(ex, "Unexpected error during regenerate"); + ErrorMessage = $"Unexpected error: {ex.Message}"; + notificationService.Show("Error", ex.Message, NotificationType.Error); + CanRetryLastMessage = true; + } + finally + { + IsGenerating = false; + GeneratingConversationId = null; + ResetGenerationProgress(); + // Ensure loading placeholder is removed on cancel/error + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + } + } + + /// + /// Edits a user message with option to save only or save and regenerate + /// + [RelayCommand] + private async Task EditUserMessageAsync(TextMessage? message) + { + if (message == null || !message.IsMyMessage || CurrentConversation == null) + return; + + try + { + var existingMessageId = message.DatabaseMessageId; + + // Show edit dialog with two action options + var textBox = new TextBox + { + Text = message.Text, + Watermark = "Edit your message...", + MinWidth = 400, + MinHeight = 100, + AcceptsReturn = true, + TextWrapping = global::Avalonia.Media.TextWrapping.Wrap, + }; + + var dialog = new ContentDialog + { + Title = "Edit Message", + Content = textBox, + PrimaryButtonText = "Save & Regenerate", + SecondaryButtonText = "Save Only", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Primary, + }; + + var result = await dialog.ShowAsync(); + + if (result == ContentDialogResult.None || string.IsNullOrWhiteSpace(textBox.Text)) + return; + + var editedText = textBox.Text.Trim(); + var shouldRegenerate = result == ContentDialogResult.Primary; + + // Get all messages from database + var dbMessages = await chatService.GetMessagesAsync(CurrentConversation.Id); + var dbMessage = + existingMessageId != null + ? dbMessages.FirstOrDefault(m => m.Id == existingMessageId.Value) + : null; + + if (dbMessage == null) + { + // Message doesn't have a DatabaseMessageId - this is legacy data from before we tracked IDs. + // We cannot safely edit these messages because mapping UI messages to database entries + // is unreliable (a single database message can contain both text and images, but they + // appear as separate UI elements). Refuse to edit to prevent data corruption. + logger.LogWarning( + "Cannot edit message without DatabaseMessageId - legacy message from before ID tracking" + ); + notificationService.Show( + "Cannot Edit", + "This message cannot be edited because it was created before message tracking was added. " + + "You can still send new messages normally.", + NotificationType.Warning + ); + return; + } + + if (shouldRegenerate) + { + // Original behavior: delete from this point and regenerate + await EditAndRegenerateAsync(message, dbMessage, dbMessages, editedText); + } + else + { + // New behavior: just update the text without regenerating + await EditMessageOnlyAsync(message, dbMessage, editedText); + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to edit message"); + notificationService.Show( + "Error", + $"Failed to edit message: {ex.Message}", + NotificationType.Error + ); + } + } + + /// + /// Updates a message's text without regenerating subsequent messages + /// + private async Task EditMessageOnlyAsync( + TextMessage uiMessage, + ImageGenerationMessage dbMessage, + string newText + ) + { + // Update in database + var updatedMessage = await chatService.UpdateMessageTextAsync(dbMessage.Id, newText); + if (updatedMessage == null) + { + notificationService.Show("Error", "Failed to update message", NotificationType.Error); + return; + } + + // Replace the UI message (TextMessage.Text is read-only) + var index = Messages.IndexOf(uiMessage); + if (index >= 0) + { + Messages[index] = new TextMessage(newText, uiMessage.IsMyMessage) + { + DatabaseMessageId = dbMessage.Id, + }; + } + + logger.LogInformation("Updated message {MessageId} text without regeneration", dbMessage.Id); + notificationService.Show("Message Updated", "Your message has been saved.", NotificationType.Success); + } + + /// + /// Edits a message and regenerates the conversation from that point + /// + private async Task EditAndRegenerateAsync( + TextMessage uiMessage, + ImageGenerationMessage dbMessage, + List allDbMessages, + string editedText + ) + { + // Delete all UI messages from this point onward + var firstUiIndexToDelete = -1; + for (var i = 0; i < Messages.Count; i++) + { + if (GetDatabaseMessageId(Messages[i]) == dbMessage.Id) + { + firstUiIndexToDelete = i; + break; + } + } + + if (firstUiIndexToDelete < 0) + { + firstUiIndexToDelete = Messages.IndexOf(uiMessage); + } + + var messagesToRemove = Messages.Skip(firstUiIndexToDelete).ToList(); + foreach (var msg in messagesToRemove) + { + Messages.Remove(msg); + if (msg is ImageMessage im) + { + im.Image?.Dispose(); + } + } + + // Delete all database messages from this point onward + var messagesToDelete = allDbMessages + .Where(m => m.Timestamp >= dbMessage.Timestamp) + .OrderBy(m => m.Timestamp) + .ToList(); + + foreach (var msg in messagesToDelete) + { + await chatService.DeleteMessageAsync(msg.Id); + } + + // Now send the edited message + IsGenerating = true; + ErrorMessage = null; + + try + { + // Add edited user message to UI + Messages.Add(new TextMessage(editedText, true)); + + // Build provider options + var providerOptions = BuildProviderOptions(); + + // Send the edited message + var (userMessage, assistantMessage) = await chatService.SendMessageAsync( + CurrentConversation!.Id, + SelectedProviderId!, + editedText, + null, + providerOptions, + progress: null, + CancellationToken.None + ); + + // Add assistant response to UI + if (assistantMessage != null) + { + AddAssistantMessageToUI(assistantMessage); + } + + // Reload conversations to update timestamps + await LoadConversationsAsync(); + + notificationService.Show( + "Message Edited", + "Your message has been edited and the conversation regenerated.", + NotificationType.Success + ); + } + catch (ImageGenerationException ex) + { + logger.LogWarning("Failed to regenerate after edit: {Message}", ex.Message); + + if (ex.Message.Contains("API key", StringComparison.OrdinalIgnoreCase)) + { + await ShowApiKeyRequiredDialogAsync(); + CanRetryLastMessage = true; + } + else + { + ErrorMessage = ex.Message; + notificationService.Show("Generation Failed", ex.Message, NotificationType.Warning); + CanRetryLastMessage = true; + } + } + catch (Exception ex) + { + logger.LogError(ex, "Unexpected error regenerating after edit"); + ErrorMessage = $"Unexpected error: {ex.Message}"; + notificationService.Show("Error", ex.Message, NotificationType.Error); + CanRetryLastMessage = true; + } + finally + { + IsGenerating = false; + } + } + + /// + /// Builds the provider options dictionary based on current settings + /// + private Dictionary BuildProviderOptions() + { + Dictionary? providerOptions = null; + + if (SupportsThinking && ShowThinkingOutput) + { + providerOptions = new() { ["enableThinking"] = true, ["thinkingBudget"] = 2048 }; + } + + if (SelectedProviderId == BananaVisionProviderIds.FluxKontext) + { + providerOptions ??= new(); + if (SelectedFluxModel != null) + providerOptions["CustomUnetModel"] = SelectedFluxModel; + if (SelectedLoras.Count > 0) + providerOptions["SelectedLoras"] = SelectedLoras.ToList(); + } + + if (SelectedProviderId == BananaVisionProviderIds.QwenImageEdit) + { + providerOptions ??= new(); + if (SelectedQwenModel != null) + providerOptions["CustomUnetModel"] = SelectedQwenModel; + if (SelectedLoras.Count > 0) + providerOptions["SelectedLoras"] = SelectedLoras.ToList(); + } + + if (SelectedProviderId == BananaVisionProviderIds.Flux2Klein) + { + providerOptions ??= new(); + if (SelectedKleinModel != null) + providerOptions["CustomUnetModel"] = SelectedKleinModel; + if (SelectedLoras.Count > 0) + providerOptions["SelectedLoras"] = SelectedLoras.ToList(); + providerOptions["Steps"] = KleinSteps; + providerOptions["CfgScale"] = KleinCfg; + } + + providerOptions ??= new(); + + if (UseCustomResolution) + { + providerOptions["Width"] = CustomWidth; + providerOptions["Height"] = CustomHeight; + // Marker that the user explicitly opted into a specific resolution. Providers + // doing img2img edits (e.g. Klein) use this to decide whether to override the + // reference-image-derived dimensions. + providerOptions["ExplicitDimensions"] = true; + } + else if (SelectedAspectRatio != null) + { + providerOptions["aspectRatio"] = SelectedAspectRatio.Ratio; + providerOptions["Width"] = SelectedAspectRatio.Width; + providerOptions["Height"] = SelectedAspectRatio.Height; + } + + return providerOptions; + } + + /// + /// Handles key down events from the message input TextBox. + /// Enter sends the message, Shift+Enter adds a new line. + /// + [RelayCommand] + private void TextBoxKeyDown(KeyEventArgs? e) + { + if (e?.Key != Key.Enter) + return; + + // Shift+Enter = let TextBox handle it naturally (insert newline at cursor position) + if (e.KeyModifiers.HasFlag(KeyModifiers.Shift)) + { + // Don't handle it - let the TextBox process the newline naturally + return; + } + + // Plain Enter = send message (but only if not already generating) + if (!IsGenerating && SendMessageCommand.CanExecute(null)) + { + e.Handled = true; + SendMessageCommand.Execute(null); + } + else + { + // Prevent the Enter from doing anything if we're generating + e.Handled = true; + } + } + + [RelayCommand] + private async Task AddImageAsync() + { + if (StorageProvider == null) + { + notificationService.Show("Error", "Storage provider not available", NotificationType.Error); + return; + } + + try + { + var files = await StorageProvider.OpenFilePickerAsync( + new() + { + Title = "Select Images", + AllowMultiple = true, + FileTypeFilter = + [ + new("Images") { Patterns = ["*.png", "*.jpg", "*.jpeg", "*.webp", "*.gif"] }, + ], + } + ); + + if (files.Count == 0) + return; + + foreach (var file in files) + { + var imagePath = file.Path.LocalPath; + var bitmap = new Bitmap(imagePath); + + PendingImages.Add(new() { FilePath = imagePath, Bitmap = bitmap }); + } + + notificationService.Show( + "Images Added", + $"Added {files.Count} image(s). They will be sent with your next message.", + NotificationType.Success + ); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to add images"); + notificationService.Show("Error", $"Failed to add images: {ex.Message}", NotificationType.Error); + } + } + + [RelayCommand] + private void RemovePendingImage(PendingImage image) + { + PendingImages.Remove(image); + image.Dispose(); + } + + /// + /// Adds images from file paths (used by drag and drop) + /// + public void AddImagesFromPaths(IEnumerable imagePaths) + { + try + { + var pathsList = imagePaths.ToList(); + var addedCount = 0; + + foreach (var imagePath in pathsList) + { + if (!File.Exists(imagePath)) + continue; + + var bitmap = new Bitmap(imagePath); + PendingImages.Add(new() { FilePath = imagePath, Bitmap = bitmap }); + addedCount++; + } + + if (addedCount > 0) + { + notificationService.Show( + "Images Added", + $"Added {addedCount} image(s). They will be sent with your next message.", + NotificationType.Success + ); + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to add images from drag and drop"); + notificationService.Show("Error", $"Failed to add images: {ex.Message}", NotificationType.Error); + } + } + + /// + /// Supported image extensions for clipboard paste + /// + private static readonly HashSet SupportedImageExtensions = new(StringComparer.OrdinalIgnoreCase) + { + ".png", + ".jpg", + ".jpeg", + ".webp", + ".gif", + ".bmp", + }; + + /// + /// Tries to paste images from the clipboard. Returns true if images were pasted. + /// + public async Task TryPasteImagesFromClipboardAsync() + { + try + { + var clipboard = App.Clipboard; + if (clipboard == null) + return false; + + // First, check for files in clipboard (e.g., copied from file explorer) + var formats = await clipboard.GetFormatsAsync(); + + if (formats.Contains(DataFormats.Files)) + { + var data = await clipboard.GetDataAsync(DataFormats.Files); + if (data is IEnumerable files) + { + var imagePaths = files + .Select(f => f.Path.LocalPath) + .Where(p => SupportedImageExtensions.Contains(Path.GetExtension(p))) + .ToList(); + + if (imagePaths.Count > 0) + { + AddImagesFromPaths(imagePaths); + return true; + } + } + } + + // Check for bitmap/image data in clipboard (e.g., screenshots, copied images) + // Try common image formats + foreach ( + var format in new[] { "PNG", "image/png", "Bitmap", "DeviceIndependentBitmap", "image/bmp" } + ) + { + if (!formats.Contains(format)) + continue; + + var data = await clipboard.GetDataAsync(format); + if (data is byte[] { Length: > 0 } imageBytes) + { + var tempPath = await SaveClipboardImageToTempFileAsync(imageBytes, format); + if (tempPath != null) + { + AddImagesFromPaths([tempPath]); + return true; + } + } + else if (data is Stream stream) + { + using var ms = new MemoryStream(); + await stream.CopyToAsync(ms); + var bytes = ms.ToArray(); + + if (bytes.Length > 0) + { + var tempPath = await SaveClipboardImageToTempFileAsync(bytes, format); + if (tempPath != null) + { + AddImagesFromPaths([tempPath]); + return true; + } + } + } + } + + return false; + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to paste images from clipboard"); + return false; + } + } + + /// + /// Saves clipboard image bytes to a temporary file + /// + private async Task SaveClipboardImageToTempFileAsync(byte[] imageBytes, string format) + { + try + { + var extension = format.ToLowerInvariant() switch + { + "png" or "image/png" => ".png", + "image/jpeg" or "jpeg" or "jpg" => ".jpg", + "image/bmp" or "bitmap" or "deviceindependentbitmap" => ".bmp", + _ => ".png", + }; + + var tempDir = Path.Combine(Path.GetTempPath(), "StabilityMatrix", "ClipboardImages"); + Directory.CreateDirectory(tempDir); + + var shortGuid = Guid.NewGuid().ToString("N")[..8]; + var fileName = $"clipboard_{DateTime.Now:yyyyMMdd_HHmmss}_{shortGuid}{extension}"; + var tempPath = Path.Combine(tempDir, fileName); + + await File.WriteAllBytesAsync(tempPath, imageBytes); + + logger.LogInformation("Saved clipboard image to temp file: {Path}", tempPath); + return tempPath; + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to save clipboard image to temp file"); + return null; + } + } + + [RelayCommand] + private void ClearPendingImages() + { + foreach (var image in PendingImages) + { + image.Dispose(); + } + PendingImages.Clear(); + } + + [RelayCommand] + private async Task CopyMessageAsync(TextMessage message) + { + try + { + if (string.IsNullOrEmpty(message.Text)) + return; + + var clipboard = App.Clipboard; + if (clipboard != null) + { + await clipboard.SetTextAsync(message.Text); + notificationService.Show("Copied", "Message copied to clipboard", NotificationType.Success); + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to copy message to clipboard"); + notificationService.Show("Error", "Failed to copy message", NotificationType.Error); + } + } + + [RelayCommand] + private async Task CopyImageToClipboardAsync(Bitmap? image) + { + if (image == null) + return; + + try + { + if (Compat.IsWindows) + { + await WindowsClipboard.SetBitmapAsync(image); + notificationService.Show("Copied", "Image copied to clipboard", NotificationType.Success); + } + else + { + notificationService.Show( + "Not Supported", + "Image clipboard is only supported on Windows", + NotificationType.Warning + ); + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to copy image to clipboard"); + notificationService.Show("Error", "Failed to copy image", NotificationType.Error); + } + } + + /// + /// Unified cancel command that stops any ongoing generation (Send, Retry, or Regenerate) + /// + [RelayCommand] + private void CancelGeneration() + { + // Immediately remove loading placeholder for instant UI feedback + if (currentLoadingMessage != null) + { + Messages.Remove(currentLoadingMessage); + currentLoadingMessage = null; + } + + // Cancel whichever operation is in progress + if (SendMessageCancelCommand.CanExecute(null)) + { + SendMessageCancelCommand.Execute(null); + } + if (RetryLastMessageCancelCommand.CanExecute(null)) + { + RetryLastMessageCancelCommand.Execute(null); + } + if (RegenerateLastResponseCancelCommand.CanExecute(null)) + { + RegenerateLastResponseCancelCommand.Execute(null); + } + } + + [RelayCommand] + private void ToggleGallery() + { + IsGalleryVisible = !IsGalleryVisible; + if (IsGalleryVisible) + { + OnPropertyChanged(nameof(ConversationImages)); + } + } + + [RelayCommand] + private async Task EditPendingImageAsync(PendingImage image) + { + try + { + var editorVm = vmFactory.Get(); + editorVm.LoadImage(image.Bitmap, image.FilePath); + + var dialog = editorVm.GetDialog(); + var result = await dialog.ShowAsync(); + + if (result == FluentAvalonia.UI.Controls.ContentDialogResult.Primary && editorVm.HasAnnotations) + { + // Save the annotated image to a temp file + var annotatedPath = await editorVm.SaveAnnotatedImageAsync(); + + if (annotatedPath != null) + { + // Replace the pending image with the annotated version + var index = PendingImages.IndexOf(image); + if (index >= 0) + { + var annotatedBitmap = new Bitmap(annotatedPath); + var oldImage = PendingImages[index]; + PendingImages[index] = new() { FilePath = annotatedPath, Bitmap = annotatedBitmap }; + oldImage.Dispose(); // Dispose the old bitmap + + notificationService.Show( + "Image Updated", + "Your annotations have been applied to the image.", + NotificationType.Success + ); + } + } + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to edit image"); + notificationService.Show("Error", $"Failed to edit image: {ex.Message}", NotificationType.Error); + } + } + + /// + /// Preview an image in a full-size dialog + /// + [RelayCommand] + private async Task PreviewImageAsync(Bitmap? bitmap) + { + if (bitmap == null) + return; + + try + { + var viewerVm = vmFactory.Get(); + viewerVm.ImageSource = new(bitmap); + + var dialog = viewerVm.GetDialog(); + await dialog.ShowAsync(); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to preview image"); + } + } + + partial void OnSelectedProviderIdChanged(string? value) + { + // Log provider change - the actual conversation update happens when sending a message + if (CurrentConversation != null && value != null && value != CurrentConversation.ProviderId) + { + logger.LogInformation( + "Provider selection changed from {OldProvider} to {NewProvider} for conversation {ConversationId}", + CurrentConversation.ProviderId, + value, + CurrentConversation.Id + ); + } + + // If switching away from local providers, clean up any pending connection + if (!BananaVisionProviderIds.IsLocalProvider(value)) + { + startupCompleteSubscription?.Dispose(); + startupCompleteSubscription = null; + IsWaitingForConnection = false; + hasShownMissingModelsDialog = false; // Reset for next time + } + + // Update provider status for the new provider + UpdateProviderStatus(); + + // Notify that provider-related properties may have changed + OnPropertyChanged(nameof(SupportsThinking)); + OnPropertyChanged(nameof(RequiresLocalBackend)); + OnPropertyChanged(nameof(IsCloudProvider)); + OnPropertyChanged(nameof(ShowFluxSettings)); + OnPropertyChanged(nameof(ShowQwenSettings)); + OnPropertyChanged(nameof(ShowKleinSettings)); + + // Load available Flux models when switching to Flux Kontext + if (value == BananaVisionProviderIds.FluxKontext) + { + LoadAvailableFluxModels(); + + // Auto-show missing models dialog if connected and models are missing + CheckAndShowMissingModelsDialogAsync() + .SafeFireAndForget(ex => + { + logger.LogError(ex, "Failed to check for missing Flux models"); + }); + } + + // Load available Qwen models when switching to Qwen Image Edit + if (value == BananaVisionProviderIds.QwenImageEdit) + { + LoadAvailableQwenModels(); + + // Auto-show missing models dialog if connected and models are missing + CheckAndShowMissingModelsDialogAsync() + .SafeFireAndForget(ex => + { + logger.LogError(ex, "Failed to check for missing Qwen models"); + }); + } + + // Load available Klein models when switching to Flux.2 Klein + if (value == BananaVisionProviderIds.Flux2Klein) + { + LoadAvailableKleinModels(); + + // Auto-show missing models dialog if connected and models are missing + CheckAndShowMissingModelsDialogAsync() + .SafeFireAndForget(ex => + { + logger.LogError(ex, "Failed to check for missing Klein models"); + }); + } + } + + private void UpdateProviderStatus() + { + // Check if this is a local provider with model requirements + var modelManager = LocalProviderModelManagerRegistry.GetManager(SelectedProviderId); + + if (modelManager != null) + { + // This is a local provider - check ComfyUI and model status + + // Check if ComfyUI is running + if (!IsComfyRunning) + { + ProviderStatusMessage = "⚠️ ComfyUI is not running. Click Launch to start."; + IsFluxKontextAvailable = false; + HasMissingModels = false; + return; + } + + // Check if we're waiting for connection + if (IsWaitingForConnection) + { + ProviderStatusMessage = "πŸ”„ Connecting to ComfyUI..."; + IsFluxKontextAvailable = false; + HasMissingModels = false; + return; + } + + // Check ComfyUI connection status + if (!ClientManager.IsConnected) + { + ProviderStatusMessage = "⚠️ Not connected to ComfyUI. Click Connect."; + IsFluxKontextAvailable = false; + HasMissingModels = false; + return; + } + + // While a model-download batch is running, show progress instead of the + // missing-models warning. Models are still technically missing on disk until + // the download finishes, but the user has already acted on that β€” surfacing + // the same warning + Download button would be misleading. + if (IsDownloadingModels) + { + ProviderStatusMessage = DownloadProgressText ?? "⬇️ Downloading models..."; + IsFluxKontextAvailable = false; + HasMissingModels = false; + return; + } + + // Check if required models are available + if (!modelManager.AreModelsAvailable(ClientManager)) + { + var missingModelNames = modelManager.GetMissingModelNames(ClientManager).ToList(); + var modelsList = string.Join(", ", missingModelNames); + ProviderStatusMessage = $"⚠️ Missing: {modelsList}"; + IsFluxKontextAvailable = false; + HasMissingModels = true; + return; + } + + // All good + ProviderStatusMessage = $"βœ… {modelManager.ProviderDisplayName} is ready"; + IsFluxKontextAvailable = true; + HasMissingModels = false; + } + else + { + // Cloud providers or providers without model requirements + ProviderStatusMessage = null; + IsFluxKontextAvailable = false; + HasMissingModels = false; + } + } + + /// + /// Gets all valid image paths from a message, handling both ImagePath and ImagePaths properties + /// + private static List GetMessageImagePaths(ImageGenerationMessage message) + { + var paths = + message.ImagePaths?.Where(p => !string.IsNullOrWhiteSpace(p)).ToList() + ?? (!string.IsNullOrEmpty(message.ImagePath) ? [message.ImagePath] : []); + + return paths.Distinct(StringComparer.OrdinalIgnoreCase).ToList(); + } + + /// + /// Adds a message (user or assistant) to the Messages collection + /// + /// The message to add + /// Whether to include the database message ID for tracking + /// True if any images were added + private bool AddMessageToUI(ImageGenerationMessage message, bool includeDbId = true) + { + var isUserMessage = message.Role == MessageRole.User; + var dbId = includeDbId ? message.Id : (Guid?)null; + + // Show thinking content first (only for assistant messages) + if (!isUserMessage && ShowThinkingOutput && !string.IsNullOrEmpty(message.ThinkingContent)) + { + Messages.Add(new ThinkingMessage(message.ThinkingContent) { DatabaseMessageId = dbId }); + } + + if (!string.IsNullOrEmpty(message.TextContent)) + { + Messages.Add(new TextMessage(message.TextContent, isUserMessage) { DatabaseMessageId = dbId }); + } + + var addedAnyImages = false; + foreach (var imagePath in GetMessageImagePaths(message).Where(File.Exists)) + { + var bitmap = new Bitmap(imagePath); + Messages.Add( + new ImageMessage(bitmap, isUserMessage) { DatabaseMessageId = dbId, FilePath = imagePath } + ); + addedAnyImages = true; + } + + return addedAnyImages; + } + + /// + /// Adds an assistant message (thinking, text, and images) to the Messages collection + /// + private bool AddAssistantMessageToUI(ImageGenerationMessage message, bool includeDbId = true) + { + return AddMessageToUI(message, includeDbId); + } + + /// + /// Adds a user message (text and images) to the Messages collection + /// + private void AddUserMessageToUI(ImageGenerationMessage message) + { + AddMessageToUI(message, includeDbId: true); + } + + /// + /// Clears all messages and disposes any image bitmaps to prevent memory leaks + /// + private void ClearMessages() + { + foreach (var message in Messages) + { + if (message is ImageMessage imageMessage) + { + imageMessage.Image?.Dispose(); + } + } + Messages.Clear(); + } + + private static Guid? GetDatabaseMessageId(object? message) + { + return message switch + { + MessageBase m => m.DatabaseMessageId, + ThinkingMessage tm => tm.DatabaseMessageId, + _ => null, + }; + } + + private void RemoveUiMessagesForDatabaseMessageId(Guid messageId) + { + var toRemove = Messages.Where(m => GetDatabaseMessageId(m) == messageId).ToList(); + + foreach (var item in toRemove) + { + Messages.Remove(item); + if (item is ImageMessage imageMessage) + { + imageMessage.Image?.Dispose(); + } + } + + // Notify gallery that images may have changed + OnPropertyChanged(nameof(ConversationImages)); + OnPropertyChanged(nameof(HasConversationImages)); + OnPropertyChanged(nameof(CanRegenerateLastResponse)); + } + + [RelayCommand] + private async Task DeleteMessageAsync(object? messageItem) + { + if (CurrentConversation == null) + return; + + var messageId = GetDatabaseMessageId(messageItem); + if (messageId == null) + return; + + try + { + // Check if this is an image from a multi-image message + var isImageMessage = + messageItem is ImageMessage imageMsg && !string.IsNullOrEmpty(imageMsg.FilePath); + var dbMessage = isImageMessage ? await chatService.GetMessageAsync(messageId.Value) : null; + var imageCount = + dbMessage != null + ? (dbMessage.ImagePaths?.Count ?? (string.IsNullOrEmpty(dbMessage.ImagePath) ? 0 : 1)) + : 0; + var isMultiImageMessage = imageCount > 1; + + string dialogContent; + if (isMultiImageMessage) + { + dialogContent = + "Delete this image from the message?\n\n" + + $"The message has {imageCount} images. Only this image will be removed."; + } + else + { + dialogContent = + "This will permanently delete the selected message from this conversation.\n\n" + + "Note: deleting a message in the middle of a conversation may change context for future generations."; + } + + var dialog = new ContentDialog + { + Title = isMultiImageMessage ? "Delete image?" : "Delete message?", + Content = new TextBlock + { + Text = dialogContent, + TextWrapping = global::Avalonia.Media.TextWrapping.Wrap, + MaxWidth = 420, + }, + PrimaryButtonText = "Delete", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Close, + }; + + if (await dialog.ShowAsync() != ContentDialogResult.Primary) + return; + + // Handle multi-image message: only remove the specific image + if ( + isMultiImageMessage + && messageItem is ImageMessage imgToDelete + && !string.IsNullOrEmpty(imgToDelete.FilePath) + ) + { + var wasFullyDeleted = await chatService.RemoveImageFromMessageAsync( + messageId.Value, + imgToDelete.FilePath + ); + + if (wasFullyDeleted) + { + // Whole message was deleted (was the last image) + RemoveUiMessagesForDatabaseMessageId(messageId.Value); + } + else + { + // Only remove this specific UI element + Messages.Remove(messageItem); + } + } + else + { + // Regular deletion - remove entire message + await chatService.DeleteMessageAsync(messageId.Value); + RemoveUiMessagesForDatabaseMessageId(messageId.Value); + } + + // Reload conversations to update timestamps + await LoadConversationsAsync(); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to delete message {MessageId}", messageId); + notificationService.Show( + "Error", + $"Failed to delete message: {ex.Message}", + NotificationType.Error + ); + } + } + + /// + /// Handles collection changes to trigger auto-scroll + /// + private void OnMessagesCollectionChanged(object? sender, NotifyCollectionChangedEventArgs e) + { + // Request scroll to end when new messages are added + if (e.Action == NotifyCollectionChangedAction.Add) + { + Dispatcher.UIThread.Post( + () => + { + ScrollToEndRequested?.Invoke(this, EventArgs.Empty); + }, + DispatcherPriority.Background + ); + } + + // Notify that regenerate availability may have changed + OnPropertyChanged(nameof(CanRegenerateLastResponse)); + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + // Cancel and dispose any pending message load + loadMessagesCts?.Cancel(); + loadMessagesCts?.Dispose(); + loadMessagesCts = null; + + // Dispose startup subscription + startupCompleteSubscription?.Dispose(); + startupCompleteSubscription = null; + + // Dispose pending images + foreach (var image in PendingImages) + { + image.Dispose(); + } + PendingImages.Clear(); + + // Dispose message bitmaps and clear + ClearMessages(); + + // Unsubscribe from collection changed + Messages.CollectionChanged -= OnMessagesCollectionChanged; + } + + base.Dispose(disposing); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs index 41537a3e4..0868417f9 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/InferenceGenerationViewModelBase.cs @@ -39,6 +39,7 @@ using StabilityMatrix.Core.Models.Api.Comfy.WebSocketData; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Inference; +using StabilityMatrix.Core.Models.Notifications; using StabilityMatrix.Core.Models.PackageModification; using StabilityMatrix.Core.Models.Packages.Extensions; using StabilityMatrix.Core.Models.Settings; @@ -60,8 +61,8 @@ public abstract partial class InferenceGenerationViewModelBase private readonly ISettingsManager settingsManager; private readonly RunningPackageService runningPackageService; - private readonly INotificationService notificationService; private readonly IServiceManager vmFactory; + private readonly INotificationService notificationService; [JsonPropertyName("ImageGallery")] public ImageGalleryCardViewModel ImageGalleryCardViewModel { get; } @@ -411,7 +412,8 @@ await notificationService.ShowAsync( Title = "Prompt Completed", Body = $"Prompt [{promptTask.Id[..7].ToLower()}] completed successfully", BodyImagePath = notificationImage?.FullPath, - } + }, + action: new NavigateToPageAction(typeof(InferenceViewModel).AssemblyQualifiedName!) ); } finally diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs index 26a875029..ec75b4b32 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs @@ -45,6 +45,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base; [JsonDerivedType(typeof(NRSModule))] [JsonDerivedType(typeof(CfzCudnnToggleModule))] [JsonDerivedType(typeof(TiledVAEModule))] +[JsonDerivedType(typeof(RegionalPromptModule))] +[JsonDerivedType(typeof(RegionalPromptCardViewModel), RegionalPromptCardViewModel.ModuleKey)] public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivArchiveBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivArchiveBrowserViewModel.cs new file mode 100644 index 000000000..f39219a87 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivArchiveBrowserViewModel.cs @@ -0,0 +1,890 @@ +using System.Collections.ObjectModel; +using System.ComponentModel; +using System.Reactive.Disposables; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Animations; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.CheckpointManager; +using StabilityMatrix.Avalonia.Views; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models.Api.CivArchive; +using StabilityMatrix.Core.Models.Settings; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; + +[View(typeof(CivArchiveBrowserPage))] +[RegisterSingleton] +public sealed partial class CivArchiveBrowserViewModel( + ICivArchiveApiClient civArchiveApiClient, + ISettingsManager settingsManager, + IServiceManager viewModelFactory, + INavigationService navigationService, + IModelIndexService modelIndexService +) : TabViewModelBase, IInfinitelyScroll +{ + private bool suppressSearch; + private bool filterOptionsLoaded; + private bool searchQueued; + private int currentPage = 1; + private CancellationTokenSource? searchDebounceCts; + + /// + /// How long to wait after the most recent filter change before firing the search. + /// Kept settable (not a const) so unit tests can collapse it to + /// instead of waiting hundreds of ms per assertion. + /// + public TimeSpan SearchDebounceInterval { get; set; } = TimeSpan.FromMilliseconds(300); + + /// + /// All search results we've fetched so far across pages, regardless of client-side filters. + /// Used as the source for rebuilds and dedupe checks. + /// + private readonly List rawResults = []; + + [ObservableProperty] + private ObservableCollection results = []; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsEndOfResults), nameof(HasResultCount))] + private int totalHits; + + [ObservableProperty] + private bool hideInstalledModels; + + [ObservableProperty] + private double resizeFactor = 1.0; + + /// + /// True (default) renders card images in Fit mode (whole image visible with a blurred + /// edge-fill behind it). False switches to Fill mode (UniformToFill, may crop edges). + /// + [ObservableProperty] + private bool fitCardImages = true; + + public bool IsEndOfResults => HasSearched && TotalHits > 0 && rawResults.Count >= TotalHits; + + public bool HasResultCount => HasSearched && TotalHits > 0; + + [ObservableProperty] + private ObservableCollection allModelTypes = []; + + [ObservableProperty] + private ObservableCollection allBaseModels = []; + + [ObservableProperty] + private ObservableCollection filteredModelTypes = []; + + [ObservableProperty] + private ObservableCollection filteredBaseModels = []; + + [ObservableProperty] + private string modelTypeFilter = string.Empty; + + [ObservableProperty] + private string baseModelFilter = string.Empty; + + public string ModelTypeSelectionSummary => + AllModelTypes.Count == 0 + ? string.Empty + : $"{AllModelTypes.Count(x => x.IsSelected)} of {AllModelTypes.Count} selected"; + + public string BaseModelSelectionSummary => + AllBaseModels.Count == 0 + ? string.Empty + : $"{AllBaseModels.Count(x => x.IsSelected)} of {AllBaseModels.Count} selected"; + + /// + /// True when a partial selection is active β€” at least one selected and at least one + /// deselected. The "all selected" stock state and the "none selected" degenerate state + /// both render plain (no badge), since neither is a meaningful filter to surface. + /// + public bool HasModelTypeFilter => + AllModelTypes.Count > 0 + && AllModelTypes.Any(x => x.IsSelected) + && AllModelTypes.Any(x => !x.IsSelected); + + public bool HasBaseModelFilter => + AllBaseModels.Count > 0 + && AllBaseModels.Any(x => x.IsSelected) + && AllBaseModels.Any(x => !x.IsSelected); + + public int SelectedModelTypeCount => AllModelTypes.Count(x => x.IsSelected); + + public int SelectedBaseModelCount => AllBaseModels.Count(x => x.IsSelected); + + [ObservableProperty] + private string searchQuery = string.Empty; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasAdvancedFilters))] + private string tagQuery = string.Empty; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasAdvancedFilters))] + private string usernameQuery = string.Empty; + + /// + /// True when the user has set explicit Tags or Username via the "More filters" flyout + /// (not the inline @/# tokens β€” those are visible in the search box itself). + /// Used to show a small dot indicator on the More Filters button. + /// + public bool HasAdvancedFilters => + !string.IsNullOrWhiteSpace(TagQuery) || !string.IsNullOrWhiteSpace(UsernameQuery); + + [ObservableProperty] + private bool isLoading; + + [ObservableProperty] + private bool noResultsFound; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(IsEndOfResults), nameof(HasResultCount))] + private bool hasSearched; + + [ObservableProperty] + private string noResultsText = "No results found"; + + [ObservableProperty] + private NamedOption? selectedPlatform; + + [ObservableProperty] + private NamedOption? selectedSort; + + [ObservableProperty] + private NamedOption? selectedPeriod; + + [ObservableProperty] + private NamedOption? selectedRating; + + [ObservableProperty] + private NamedOption? selectedPlatformStatus; + + [ObservableProperty] + private NamedOption? selectedKind; + + public IReadOnlyList> AllPlatforms { get; } = + [ + new("All Platforms", CivArchivePlatformOption.All), + new("CivitAI", CivArchivePlatformOption.Civitai), + new("TensorArt", CivArchivePlatformOption.Tensorart), + new("TensorHub", CivArchivePlatformOption.Tensorhub), + new("SeaArt", CivArchivePlatformOption.Seaart), + new("Civision", CivArchivePlatformOption.Civision), + new("PixAI", CivArchivePlatformOption.Pixai), + new("Tungsten", CivArchivePlatformOption.Tungsten), + new("Yodayo", CivArchivePlatformOption.Yodayo), + new("Moescape", CivArchivePlatformOption.Moescape), + new("Shakker", CivArchivePlatformOption.Shakker), + new("HuggingFace", CivArchivePlatformOption.Huggingface), + new("ModelScope", CivArchivePlatformOption.Modelscope), + new("ModelScope CN", CivArchivePlatformOption.ModelscopeCn), + ]; + + public IReadOnlyList> AllSorts { get; } = + [ + new("Top", CivArchiveSortOption.Top), + new("Newest", CivArchiveSortOption.Newest), + new("Oldest", CivArchiveSortOption.Oldest), + new("Relevance", CivArchiveSortOption.Relevance), + new("Deleted Newest", CivArchiveSortOption.DeletedNewest), + new("Deleted Oldest", CivArchiveSortOption.DeletedOldest), + ]; + + public IReadOnlyList> AllPeriods { get; } = + [ + new("All", CivArchivePeriodOption.All), + new("Week", CivArchivePeriodOption.Week), + new("Month", CivArchivePeriodOption.Month), + new("Quarter", CivArchivePeriodOption.Quarter), + new("Half", CivArchivePeriodOption.Half), + new("Year", CivArchivePeriodOption.Year), + ]; + + public IReadOnlyList> AllRatings { get; } = + [ + new("Safe", CivArchiveRatingOption.Safe), + new("All", CivArchiveRatingOption.All), + new("Explicit", CivArchiveRatingOption.Explicit), + ]; + + public IReadOnlyList> AllPlatformStatuses { get; } = + [ + new("All", CivArchivePlatformStatusOption.All), + new("Available", CivArchivePlatformStatusOption.Available), + new("Deleted", CivArchivePlatformStatusOption.Deleted), + ]; + + public IReadOnlyList> AllKinds { get; } = + [ + new("All", CivArchiveKindOption.All), + new("Version", CivArchiveKindOption.Version), + new("User", CivArchiveKindOption.User), + new("File", CivArchiveKindOption.File), + ]; + + public override string Header => "CivArchive"; + + public override void OnLoaded() + { + if (!ViewModelState.HasFlag(ViewModelState.InitialLoaded)) + { + RestoreSettings(); + } + + base.OnLoaded(); + } + + protected override async Task OnInitialLoadedAsync() + { + await base.OnInitialLoadedAsync(); + + AddDisposable( + settingsManager.RelayPropertyFor( + this, + vm => vm.ResizeFactor, + s => s.CivArchiveBrowserResizeFactor, + true + ) + ); + + AddDisposable( + settingsManager.RelayPropertyFor( + this, + vm => vm.HideInstalledModels, + s => s.HideInstalledModelsInModelBrowser, + true + ) + ); + + AddDisposable( + settingsManager.RelayPropertyFor( + this, + vm => vm.FitCardImages, + s => s.CivArchiveBrowserFitCardImages, + true + ) + ); + + EventHandler indexHandler = (_, _) => Dispatcher.UIThread.Post(OnLocalModelIndexChanged); + EventManager.Instance.ModelIndexChanged += indexHandler; + AddDisposable(Disposable.Create(() => EventManager.Instance.ModelIndexChanged -= indexHandler)); + + // Cancel any pending debounced search when the VM is disposed. + AddDisposable( + Disposable.Create(() => + { + searchDebounceCts?.Cancel(); + searchDebounceCts?.Dispose(); + searchDebounceCts = null; + }) + ); + + // Filter options (Model Type / Base Model dropdown contents) only come back + // populated when the URL has no query string, so we have to fetch them with a + // dedicated parameterless call before the first filtered search runs. + await LoadFilterOptionsAsync(); + + await SearchModels(); + } + + private async Task LoadFilterOptionsAsync() + { + if (filterOptionsLoaded) + { + return; + } + + try + { + var options = await civArchiveApiClient.GetFilterOptionsAsync(); + ApplyFilterOptions(options); + filterOptionsLoaded = true; + } + catch (Exception ex) + { + // Don't block the search itself β€” failure here just means the multi-select + // dropdowns stay empty until the user reloads the page. + NoResultsText = ex.Message; + } + finally + { + // ApplyFilterOptions sets suppressSearch=true while it populates the option + // collections; release it before the real search runs so user input flows. + suppressSearch = false; + } + } + + partial void OnHideInstalledModelsChanged(bool value) => RebuildVisibleResults(); + + private void OnLocalModelIndexChanged() + { + // The local index changed (download finished or model deleted) β€” re-evaluate every + // cached result's IsInstalled flag, then rebuild Results so the badge / hide-installed + // filter reflect reality without needing a re-search. + var hashes = modelIndexService.ModelIndexSha256Hashes; + var urls = modelIndexService.ModelIndexCivArchiveUrls; + foreach (var item in rawResults) + { + item.IsInstalled = + (!string.IsNullOrEmpty(item.Url) && urls.Contains(item.Url)) + || (!string.IsNullOrEmpty(item.Sha256FromUrl) && hashes.Contains(item.Sha256FromUrl)); + } + RebuildVisibleResults(); + } + + private void RebuildVisibleResults() + { + // Replace the collection wholesale instead of Clear + NΓ—Add β€” one + // PropertyChanged on Results vs N CollectionChanged notifications, + // which keeps ItemsRepeater from churning containers per item. + Results = new ObservableCollection( + HideInstalledModels ? rawResults.Where(item => !item.IsInstalled) : rawResults + ); + NoResultsFound = HasSearched && Results.Count == 0; + OnPropertyChanged(nameof(IsEndOfResults)); + } + + /// + /// Cancel any pending debounced search and start a new wait window. The actual + /// search runs once the user pauses for β€” + /// multiple rapid filter changes within that window collapse into a single fetch. + /// + private void RequestDebouncedSearch() + { + if (suppressSearch) + { + return; + } + + SaveSettings(); + + if (!HasSearched) + { + return; + } + + searchDebounceCts?.Cancel(); + searchDebounceCts?.Dispose(); + + var cts = new CancellationTokenSource(); + searchDebounceCts = cts; + + _ = RunDebouncedSearchAsync(cts.Token); + } + + private async Task RunDebouncedSearchAsync(CancellationToken token) + { + try + { + await Task.Delay(SearchDebounceInterval, token); + } + catch (TaskCanceledException) + { + return; + } + + if (token.IsCancellationRequested) + { + return; + } + + await SearchModels(); + } + + [RelayCommand] + private async Task SearchModels(bool isInfiniteScroll = false) + { + // Cancel any pending debounced search so an explicit invocation (search button, + // ResetFilters, etc.) doesn't get shadowed by a redundant fire moments later. + searchDebounceCts?.Cancel(); + + if (IsLoading) + { + if (!isInfiniteScroll) + { + searchQueued = true; + } + + return; + } + + if (!isInfiniteScroll) + { + searchQueued = false; + } + + if (!isInfiniteScroll) + { + currentPage = 1; + TotalHits = 0; + rawResults.Clear(); + Results.Clear(); + } + + var filters = BuildFilters(isInfiniteScroll ? currentPage + 1 : currentPage); + + IsLoading = true; + NoResultsFound = false; + NoResultsText = "No results found"; + + try + { + var response = await civArchiveApiClient.SearchAsync(filters); + + TotalHits = response.TotalHits; + currentPage = response.EffectiveFilters.Page; + + // O(1) dedupe lookup against everything we've already fetched, instead of a + // linear scan per incoming item (which becomes O(NΒ²) as paged results grow). + var existingIds = isInfiniteScroll ? rawResults.Select(x => x.Id).ToHashSet() : []; + var installedHashes = modelIndexService.ModelIndexSha256Hashes; + var installedUrls = modelIndexService.ModelIndexCivArchiveUrls; + foreach (var item in response.Results) + { + if (isInfiniteScroll && !existingIds.Add(item.Id)) + { + continue; + } + + // URL match works for any CivArchive download with a stored SourceUrl; + // SHA256 fallback covers the rare File-kind result with hash in URL. + if (!string.IsNullOrEmpty(item.Url) && installedUrls.Contains(item.Url)) + { + item.IsInstalled = true; + } + else if ( + !string.IsNullOrEmpty(item.Sha256FromUrl) && installedHashes.Contains(item.Sha256FromUrl) + ) + { + item.IsInstalled = true; + } + + rawResults.Add(item); + if (!HideInstalledModels || !item.IsInstalled) + { + Results.Add(item); + } + } + + HasSearched = true; + NoResultsFound = Results.Count == 0; + OnPropertyChanged(nameof(IsEndOfResults)); + } + catch (Exception ex) + { + NoResultsFound = Results.Count == 0; + NoResultsText = ex.Message; + } + finally + { + IsLoading = false; + SaveSettings(); + } + + if (searchQueued) + { + searchQueued = false; + await SearchModels(); + } + } + + [RelayCommand] + private void ClearSearchQuery() + { + SearchQuery = string.Empty; + } + + [RelayCommand] + private async Task OpenResult(CivArchiveSearchResult? result) + { + if (result is null) + { + return; + } + + switch (result.Kind) + { + case CivArchiveKindOption.User: + await PivotToUser(result); + break; + case CivArchiveKindOption.File: + await OpenFileResult(result); + break; + default: + NavigateToDetails(result.Url); + break; + } + } + + /// + /// File-kind results have a /sha256/{hash} URL whose endpoint returns a different + /// shape than the model details page. Resolve the SHA256 to its linked model + version + /// and navigate there in-app; if no model is linked (orphaned hash), fall back to opening + /// the URL externally. + /// + private async Task OpenFileResult(CivArchiveSearchResult result) + { + var resolvedUrl = await civArchiveApiClient.ResolveFileUrlAsync(result.Url); + if (!string.IsNullOrWhiteSpace(resolvedUrl)) + { + NavigateToDetails(resolvedUrl); + } + else + { + ProcessRunner.OpenUrl(civArchiveApiClient.GetAbsoluteUri(result.Url).ToString()); + } + } + + private void NavigateToDetails(string relativeUrl) + { + var detailsVm = viewModelFactory.Get(vm => + { + vm.RelativeUrl = relativeUrl; + return vm; + }); + navigationService.NavigateTo(detailsVm, BetterSlideNavigationTransition.PageSlideFromRight); + } + + [RelayCommand] + private void OpenOnCivArchive(CivArchiveSearchResult? result) + { + if (result is not null) + { + ProcessRunner.OpenUrl(civArchiveApiClient.GetAbsoluteUri(result.Url).ToString()); + } + } + + [RelayCommand] + private async Task SearchByCreator(CivArchiveSearchResult? result) + { + if (string.IsNullOrWhiteSpace(result?.Username)) + { + return; + } + + UsernameQuery = result.Username; + await SearchModels(); + } + + [RelayCommand] + private async Task CopySha256(CivArchiveSearchResult? result) + { + if (!string.IsNullOrWhiteSpace(result?.Sha256FromUrl) && App.Clipboard is not null) + { + await App.Clipboard.SetTextAsync(result.Sha256FromUrl); + } + } + + [RelayCommand] + private void ToggleAllModelTypes() + { + var shouldSelectAll = AllModelTypes.Any(x => !x.IsSelected); + suppressSearch = true; + try + { + foreach (var option in AllModelTypes) + { + option.IsSelected = shouldSelectAll; + } + } + finally + { + suppressSearch = false; + } + TriggerFilterSearch(); + } + + [RelayCommand] + private void ToggleAllBaseModels() + { + var shouldSelectAll = AllBaseModels.Any(x => !x.IsSelected); + suppressSearch = true; + try + { + foreach (var option in AllBaseModels) + { + option.IsSelected = shouldSelectAll; + } + } + finally + { + suppressSearch = false; + } + TriggerFilterSearch(); + } + + /// + /// Reset every filter back to its default. Single property setter at the end re-triggers + /// the search instead of one fetch per change. + /// + [RelayCommand] + private async Task ResetFilters() + { + suppressSearch = true; + try + { + SearchQuery = string.Empty; + TagQuery = string.Empty; + UsernameQuery = string.Empty; + SelectedPlatform = AllPlatforms.First(x => x.Value == CivArchivePlatformOption.All); + SelectedSort = AllSorts.First(x => x.Value == CivArchiveSortOption.Top); + SelectedPeriod = AllPeriods.First(x => x.Value == CivArchivePeriodOption.All); + SelectedRating = AllRatings.First(x => x.Value == CivArchiveRatingOption.Safe); + SelectedPlatformStatus = AllPlatformStatuses.First(x => + x.Value == CivArchivePlatformStatusOption.All + ); + SelectedKind = AllKinds.First(x => x.Value == CivArchiveKindOption.All); + foreach (var option in AllModelTypes) + option.IsSelected = true; + foreach (var option in AllBaseModels) + option.IsSelected = true; + } + finally + { + suppressSearch = false; + } + + await SearchModels(); + } + + public async Task LoadNextPageAsync() + { + // Compare against rawResults so infinite-scroll keeps fetching even when + // HideInstalledModels filters items out of Results. + if (!IsLoading && rawResults.Count < TotalHits) + { + await SearchModels(true); + } + } + + private async Task PivotToUser(CivArchiveSearchResult result) + { + UsernameQuery = !string.IsNullOrWhiteSpace(result.Username) ? result.Username : result.Name; + await SearchModels(); + } + + private void ApplyFilterOptions(CivArchiveFilterOptions options) + { + var savedOptions = settingsManager.Settings.CivArchiveBrowserOptions; + + suppressSearch = true; + SetSelectableOptions(AllModelTypes, options.ModelTypes, savedOptions.SelectedModelTypes); + SetSelectableOptions(AllBaseModels, options.BaseModels, savedOptions.SelectedBaseModels); + } + + private void SetSelectableOptions( + ObservableCollection target, + IEnumerable values, + IReadOnlyCollection selectedValues + ) + { + target.Clear(); + + var sortedValues = values.OrderBy(x => x, StringComparer.OrdinalIgnoreCase).ToList(); + var selectAll = selectedValues.Count == 0; + + foreach (var value in sortedValues) + { + var option = new BaseModelOptionViewModel + { + ModelType = value, + IsSelected = selectAll || selectedValues.Contains(value), + }; + option.PropertyChanged += OnSelectableOptionChanged; + target.Add(option); + } + + if (ReferenceEquals(target, AllModelTypes)) + { + ApplyModelTypeFilter(); + OnPropertyChanged(nameof(ModelTypeSelectionSummary)); + OnPropertyChanged(nameof(HasModelTypeFilter)); + OnPropertyChanged(nameof(SelectedModelTypeCount)); + } + else if (ReferenceEquals(target, AllBaseModels)) + { + ApplyBaseModelFilter(); + OnPropertyChanged(nameof(BaseModelSelectionSummary)); + OnPropertyChanged(nameof(HasBaseModelFilter)); + OnPropertyChanged(nameof(SelectedBaseModelCount)); + } + } + + partial void OnModelTypeFilterChanged(string value) => ApplyModelTypeFilter(); + + partial void OnBaseModelFilterChanged(string value) => ApplyBaseModelFilter(); + + private void ApplyModelTypeFilter() => + RefreshFilteredOptions(AllModelTypes, FilteredModelTypes, ModelTypeFilter); + + private void ApplyBaseModelFilter() => + RefreshFilteredOptions(AllBaseModels, FilteredBaseModels, BaseModelFilter); + + private static void RefreshFilteredOptions( + ObservableCollection source, + ObservableCollection target, + string filter + ) + { + var query = filter?.Trim() ?? string.Empty; + var matches = string.IsNullOrEmpty(query) + ? source + : source.Where(x => x.ModelType.Contains(query, StringComparison.OrdinalIgnoreCase)); + + target.Clear(); + foreach (var item in matches) + { + target.Add(item); + } + } + + private void RestoreSettings() + { + var options = settingsManager.Settings.CivArchiveBrowserOptions; + + suppressSearch = true; + SearchQuery = options.Query; + TagQuery = options.Tags; + UsernameQuery = options.Username; + SelectedPlatform = AllPlatforms.First(x => x.Value == options.Platform); + SelectedSort = AllSorts.First(x => x.Value == options.Sort); + SelectedPeriod = AllPeriods.First(x => x.Value == options.Period); + SelectedRating = AllRatings.First(x => x.Value == options.Rating); + SelectedPlatformStatus = AllPlatformStatuses.First(x => x.Value == options.PlatformStatus); + SelectedKind = AllKinds.First(x => x.Value == options.Kind); + suppressSearch = false; + } + + private CivArchiveSearchFilters BuildFilters(int page) + { + var selectedTypes = GetSelectedValues(AllModelTypes); + var selectedBaseModels = GetSelectedValues(AllBaseModels); + + // Parse @user / #tag tokens inline from the search box and merge with + // explicit values from the More Filters flyout. Inline tokens win for username + // (only one allowed); tags are merged additively. + var (cleanedQuery, parsedTags, parsedUsername) = ParseSearchQuery(SearchQuery); + var combinedTags = string.Join( + ",", + new[] { TagQuery, parsedTags }.Where(s => !string.IsNullOrWhiteSpace(s)) + ); + var combinedUsername = !string.IsNullOrWhiteSpace(parsedUsername) ? parsedUsername : UsernameQuery; + + return new CivArchiveSearchFilters + { + Query = cleanedQuery, + Tags = combinedTags, + Username = combinedUsername, + Platform = SelectedPlatform?.Value ?? CivArchivePlatformOption.All, + Sort = SelectedSort?.Value ?? CivArchiveSortOption.Top, + Period = SelectedPeriod?.Value ?? CivArchivePeriodOption.All, + Rating = SelectedRating?.Value ?? CivArchiveRatingOption.Safe, + PlatformStatus = SelectedPlatformStatus?.Value ?? CivArchivePlatformStatusOption.All, + Kind = SelectedKind?.Value ?? CivArchiveKindOption.All, + Types = selectedTypes.Count == AllModelTypes.Count ? [] : selectedTypes, + BaseModels = selectedBaseModels.Count == AllBaseModels.Count ? [] : selectedBaseModels, + Page = page, + }; + } + + /// + /// Pull @user and #tag tokens out of a free-form search string. + /// Returns the leftover query (model name search), comma-joined tag list, and the + /// parsed username (last-wins if multiple @ tokens are present). + /// + internal static (string query, string tags, string username) ParseSearchQuery(string raw) + { + if (string.IsNullOrWhiteSpace(raw)) + return (string.Empty, string.Empty, string.Empty); + + var tokens = raw.Split(' ', StringSplitOptions.RemoveEmptyEntries); + var queryParts = new List(); + var tags = new List(); + string username = string.Empty; + + foreach (var token in tokens) + { + if (token.Length > 1 && token[0] == '@') + username = token[1..]; // last @user wins + else if (token.Length > 1 && token[0] == '#') + tags.Add(token[1..]); + else + queryParts.Add(token); + } + + return (string.Join(' ', queryParts), string.Join(',', tags), username); + } + + private static List GetSelectedValues(IEnumerable options) + { + return options.Where(x => x.IsSelected).Select(x => x.ModelType).ToList(); + } + + private void SaveSettings() + { + if (!settingsManager.IsLibraryDirSet) + { + return; + } + + settingsManager.Transaction(s => + s.CivArchiveBrowserOptions = new CivArchiveBrowserOptions + { + Query = SearchQuery, + Tags = TagQuery, + Username = UsernameQuery, + Platform = SelectedPlatform?.Value ?? CivArchivePlatformOption.All, + Sort = SelectedSort?.Value ?? CivArchiveSortOption.Top, + Period = SelectedPeriod?.Value ?? CivArchivePeriodOption.All, + Rating = SelectedRating?.Value ?? CivArchiveRatingOption.Safe, + PlatformStatus = SelectedPlatformStatus?.Value ?? CivArchivePlatformStatusOption.All, + Kind = SelectedKind?.Value ?? CivArchiveKindOption.All, + SelectedModelTypes = GetSelectedValues(AllModelTypes), + SelectedBaseModels = GetSelectedValues(AllBaseModels), + } + ); + } + + private void OnSelectableOptionChanged(object? sender, PropertyChangedEventArgs e) + { + if (e.PropertyName != nameof(BaseModelOptionViewModel.IsSelected)) + { + return; + } + + OnPropertyChanged(nameof(ModelTypeSelectionSummary)); + OnPropertyChanged(nameof(BaseModelSelectionSummary)); + OnPropertyChanged(nameof(HasModelTypeFilter)); + OnPropertyChanged(nameof(HasBaseModelFilter)); + OnPropertyChanged(nameof(SelectedModelTypeCount)); + OnPropertyChanged(nameof(SelectedBaseModelCount)); + + RequestDebouncedSearch(); + } + + partial void OnSelectedPlatformChanged(NamedOption? value) => + TriggerFilterSearch(); + + partial void OnSelectedSortChanged(NamedOption? value) => TriggerFilterSearch(); + + partial void OnSelectedPeriodChanged(NamedOption? value) => TriggerFilterSearch(); + + partial void OnSelectedRatingChanged(NamedOption? value) => TriggerFilterSearch(); + + partial void OnSelectedPlatformStatusChanged(NamedOption? value) => + TriggerFilterSearch(); + + partial void OnSelectedKindChanged(NamedOption? value) => TriggerFilterSearch(); + + private void TriggerFilterSearch() => RequestDebouncedSearch(); +} diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivArchiveDetailsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivArchiveDetailsPageViewModel.cs new file mode 100644 index 000000000..d84c50186 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivArchiveDetailsPageViewModel.cs @@ -0,0 +1,922 @@ +using System.Collections.ObjectModel; +using System.ComponentModel.DataAnnotations; +using System.IO; +using System.Reactive.Disposables; +using System.Reactive.Linq; +using System.Threading; +using AsyncAwaitBestPractices; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using DynamicData.Binding; +using FluentAvalonia.Core; +using FluentAvalonia.UI.Controls; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Avalonia.Views; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api; +using StabilityMatrix.Core.Models.Api.CivArchive; +using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; + +[View(typeof(CivArchiveDetailsPage))] +[ManagedService] +[RegisterTransient] +public partial class CivArchiveDetailsPageViewModel( + ICivArchiveApiClient civArchiveApiClient, + INavigationService navigationService, + IServiceManager vmFactory, + IModelImportService modelImportService, + ISettingsManager settingsManager, + INotificationService notificationService, + IModelIndexService modelIndexService +) : DisposableViewModelBase +{ + private static readonly string[] IgnoredFileNameFormatVars = + [ + "seed", + "prompt", + "negative_prompt", + "model_hash", + "sampler", + "cfgscale", + "steps", + "width", + "height", + "project_type", + "project_name", + ]; + + public IEnumerable ModelFileNameFormatVars => + FileNameFormatProvider + .GetSampleForModelBrowser() + .Substitutions.Where(kv => !IgnoredFileNameFormatVars.Contains(kv.Key)) + .Select(kv => new FileNameFormatVar { Variable = $"{{{kv.Key}}}", Example = kv.Value.Invoke() }); + + [ObservableProperty] + public partial string RelativeUrl { get; set; } = string.Empty; + + [ObservableProperty] + public partial CivArchiveModelDetails? Model { get; set; } + + [ObservableProperty] + public partial bool IsLoading { get; set; } + + [ObservableProperty] + public partial string ErrorText { get; set; } = string.Empty; + + [ObservableProperty] + public partial CivArchiveVersionReference? SelectedVersion { get; set; } + + [ObservableProperty] + public partial string ModelDescriptionHtml { get; set; } = string.Empty; + + [ObservableProperty] + public partial string VersionDescriptionHtml { get; set; } = string.Empty; + + [ObservableProperty] + public partial bool HasDownloadUrl { get; set; } + + [ObservableProperty] + public partial bool IsInstalled { get; set; } + + [ObservableProperty] + public partial string? InstalledLocation { get; set; } + + [ObservableProperty] + [CustomValidation(typeof(CivArchiveDetailsPageViewModel), nameof(ValidateModelFileNameFormat))] + public partial string? ModelFileNameFormat { get; set; } + + [ObservableProperty] + public partial string? ModelNameFormatSample { get; set; } + + public ObservableCollection Images { get; } = []; + public ObservableCollection Files { get; } = []; + public ObservableCollection Mirrors { get; } = []; + + private static readonly Dictionary< + string, + (SharedFolderType Folder, CivitModelType ModelType) + > ModelTypeMap = new(StringComparer.OrdinalIgnoreCase) + { + ["Checkpoint"] = (SharedFolderType.StableDiffusion, CivitModelType.Checkpoint), + ["LORA"] = (SharedFolderType.Lora, CivitModelType.LORA), + ["DoRA"] = (SharedFolderType.Lora, CivitModelType.DoRA), + ["LoCon"] = (SharedFolderType.LyCORIS, CivitModelType.LoCon), + ["TextualInversion"] = (SharedFolderType.Embeddings, CivitModelType.TextualInversion), + ["Hypernetwork"] = (SharedFolderType.Hypernetwork, CivitModelType.Hypernetwork), + ["Controlnet"] = (SharedFolderType.ControlNet, CivitModelType.Controlnet), + ["VAE"] = (SharedFolderType.VAE, CivitModelType.VAE), + ["Upscaler"] = (SharedFolderType.ESRGAN, CivitModelType.Upscaler), + }; + + protected override async Task OnInitialLoadedAsync() + { + await base.OnInitialLoadedAsync(); + + AddDisposable( + settingsManager.RelayPropertyFor( + this, + vm => vm.ModelFileNameFormat, + settings => settings.CivitModelBrowserFileNamePattern, + true + ) + ); + + AddDisposable( + this.WhenPropertyChanged(vm => vm.ModelFileNameFormat) + .Throttle(TimeSpan.FromMilliseconds(50)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => UpdateNameFormatSample()) + ); + + AddDisposable( + this.WhenPropertyChanged(vm => vm.Model) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => UpdateNameFormatSample()) + ); + + // Refresh the IsInstalled badge / Download button label when the user downloads + // (or deletes) this model β€” without forcing them to navigate away and back. + EventHandler indexChangedHandler = (_, _) => + Dispatcher.UIThread.Post(() => UpdateInstalledStatus(Model?.Version)); + EventManager.Instance.ModelIndexChanged += indexChangedHandler; + AddDisposable( + Disposable.Create(() => EventManager.Instance.ModelIndexChanged -= indexChangedHandler) + ); + } + + public override async Task OnLoadedAsync() + { + await base.OnLoadedAsync(); + + if (IsLoading || string.IsNullOrWhiteSpace(RelativeUrl)) + { + return; + } + + await LoadModelAsync(); + } + + public static ValidationResult ValidateModelFileNameFormat(string? format, ValidationContext context) + { + return FileNameFormatProvider.GetSampleForModelBrowser().Validate(format ?? string.Empty); + } + + private void UpdateNameFormatSample() + { + var provider = BuildFormatProvider(Model?.Version, GetPrimaryFile(Model?.Version)); + var format = ParseFormatOrDefault(ModelFileNameFormat, provider); + + var sample = NormalizePathSegments(format.GetFileName()); + ModelNameFormatSample = string.IsNullOrWhiteSpace(sample) + ? null + : "Example: " + sample + ".safetensors"; + } + + /// + /// Strip empty path segments left behind by null/empty substitutions, so a pattern + /// like {base_model}/{file_name} with an empty base_model collapses to + /// file_name instead of /file_name. + /// Also drops .. / . traversal segments so a pattern variable that + /// resolves to .. can't escape the destination folder. + /// + private static string NormalizePathSegments(string raw) + { + if (string.IsNullOrEmpty(raw) || (!raw.Contains('/') && !raw.Contains('\\'))) + return raw; + + var parts = raw.Split( + ['/', '\\'], + StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries + ) + .Where(p => p != ".." && p != "."); + return string.Join('/', parts); + } + + /// + /// Parse a template against the provider, falling back to the default template if the input is empty, + /// references unknown variables, or fails to parse for any reason. Validate() must be called first + /// because Parse() throws KeyNotFoundException on unknown variables (e.g. mid-typing "{base}" before + /// the user finishes "{base_model}"). + /// + private static FileNameFormat ParseFormatOrDefault(string? template, FileNameFormatProvider provider) + { + if ( + !string.IsNullOrEmpty(template) + && provider.Validate(template) == ValidationResult.Success + && FileNameFormat.TryParse(template, provider, out var format) + ) + { + return format; + } + + return FileNameFormat.Parse(FileNameFormat.DefaultModelBrowserTemplate, provider); + } + + private FileNameFormatProvider BuildFormatProvider( + CivArchiveModelVersion? version, + CivArchiveModelFile? primaryFile + ) + { + if (Model is null) + { + return new FileNameFormatProvider(); + } + + // Build CivitModel-shaped stubs so FileNameFormatProvider can resolve {model_name}, {file_name}, etc. + var modelType = CivitModelType.Unknown; + if (Model.Type is not null && ModelTypeMap.TryGetValue(Model.Type, out var mapping)) + { + modelType = mapping.ModelType; + } + + var synthesizedFileName = string.IsNullOrWhiteSpace(version?.Name) + ? $"{Model.Name}.safetensors" + : $"{Model.Name}_{version.Name}.safetensors"; + + var civitModel = new CivitModel + { + Id = int.TryParse(Model.Id, out var modelId) ? modelId : 0, + Name = Model.Name, + Type = modelType, + Creator = new CivitCreator { Username = Model.CreatorUsername ?? Model.Username }, + }; + + var civitVersion = version is null + ? null + : new CivitModelVersion + { + Id = int.TryParse(version.Id, out var versionId) ? versionId : 0, + Name = version.Name, + BaseModel = version.BaseModel, + }; + + var civitFile = new CivitFile + { + Id = int.TryParse(primaryFile?.Id, out var fileId) ? fileId : 0, + Name = !string.IsNullOrWhiteSpace(primaryFile?.Name) ? primaryFile.Name : synthesizedFileName, + }; + + return new FileNameFormatProvider + { + CivitModel = civitModel, + CivitModelVersion = civitVersion, + CivitFile = civitFile, + }; + } + + private async Task LoadModelAsync() + { + IsLoading = true; + ErrorText = string.Empty; + + try + { + var response = await civArchiveApiClient.GetModelDetailsAsync(RelativeUrl); + Model = response.Model; + + ModelDescriptionHtml = WrapHtml(response.Model.Description); + PopulateVersionData(response.Model.Version); + } + catch (Exception ex) + { + ErrorText = ex.Message; + } + finally + { + IsLoading = false; + } + } + + private void PopulateVersionData(CivArchiveModelVersion? version) + { + VersionDescriptionHtml = WrapHtml(version?.Description); + HasDownloadUrl = GetDownloadUris(version).Count > 0; + + Images.Clear(); + foreach (var image in version?.Images.Where(IsUsableImage) ?? []) + { + Images.Add(image); + } + + Files.Clear(); + foreach (var file in version?.Files ?? []) + { + Files.Add(file); + } + + Mirrors.Clear(); + foreach (var mirror in version?.Mirrors ?? []) + { + Mirrors.Add(mirror); + } + + UpdateInstalledStatus(version); + } + + private void UpdateInstalledStatus(CivArchiveModelVersion? version) + { + // First try URL match β€” works for every platform, including ones where file hashes are missing. + var installedUrls = modelIndexService.ModelIndexCivArchiveUrls; + if (!string.IsNullOrWhiteSpace(RelativeUrl) && installedUrls.Contains(RelativeUrl)) + { + IsInstalled = true; + InstalledLocation = LookupInstalledLocationByUrl(RelativeUrl); + return; + } + + // Fallback to SHA256 match β€” catches CivitAI mirrors with full file hashes, + // including models downloaded via SM before SourceUrl tracking existed. + var hashes = version + ?.Files.Select(f => f.Sha256) + .Where(s => !string.IsNullOrWhiteSpace(s)) + .Cast() + .ToList(); + + if (hashes is null || hashes.Count == 0) + { + IsInstalled = false; + InstalledLocation = null; + return; + } + + var installedHashes = modelIndexService.ModelIndexSha256Hashes; + var matchedHash = hashes.FirstOrDefault(h => installedHashes.Contains(h)); + + if (matchedHash is null) + { + IsInstalled = false; + InstalledLocation = null; + return; + } + + IsInstalled = true; + _ = LoadInstalledLocationAsync(matchedHash); + } + + private string? LookupInstalledLocationByUrl(string sourceUrl) + { + return modelIndexService + .ModelIndex.Values.SelectMany(x => x) + .FirstOrDefault(m => + m.HasCivArchiveMetadata + && string.Equals( + m.ConnectedModelInfo.SourceUrl, + sourceUrl, + StringComparison.OrdinalIgnoreCase + ) + ) + ?.RelativePath; + } + + private async Task LoadInstalledLocationAsync(string sha256) + { + try + { + var matches = await modelIndexService.FindBySha256Async(sha256); + var first = matches?.FirstOrDefault(); + InstalledLocation = first?.RelativePath; + } + catch + { + InstalledLocation = null; + } + } + + private static string WrapHtml(string? html) + { + if (string.IsNullOrWhiteSpace(html)) + { + return string.Empty; + } + + return $"""{html}"""; + } + + [RelayCommand] + private async Task ShowImageDialog(CivArchiveModelImage? image) + { + if (image?.Url is null) + { + return; + } + + var currentIndex = Images.IndexOf(image); + var imageSource = await PrepareImageSourceAsync(image.Url); + if (imageSource is null) + return; + + var vm = vmFactory.Get(); + vm.ImageSource = imageSource; + + using var onNav = Observable + .FromEventPattern( + vm, + nameof(ImageViewerViewModel.NavigationRequested) + ) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(ctx => + { + Dispatcher + .UIThread.InvokeAsync(async () => + { + var sender = (ImageViewerViewModel)ctx.Sender!; + var newIndex = currentIndex + (ctx.EventArgs.IsNext ? 1 : -1); + + if (newIndex >= 0 && newIndex < Images.Count) + { + var newImage = Images[newIndex]; + if (newImage.Url is null) + { + return; + } + + var newSource = await PrepareImageSourceAsync(newImage.Url); + if (newSource is null) + return; + + sender.ImageSource = newSource; + currentIndex = newIndex; + } + }) + .SafeFireAndForget(); + }); + + await vm.GetDialog().ShowAsync(); + } + + /// + /// Build an ready for the image viewer to render. + /// The viewer's template selector keys off ImageSource.TemplateKey; if that's + /// Default, the selector renders the literal "Unsupported Format" text. The + /// URL-construction path leaves TemplateKey as Default until a Task-based binding + /// resolves it, which races with the viewer's first paint on extensionless CivArchive + /// CDN URLs (e.g. img.genur.art/sig/.../base64). Use the bitmap-only constructor + /// instead β€” it sets TemplateKey to Image synchronously, which the AdvancedImageBox + /// can render whether the bytes were JPEG, PNG, or WebP. + /// + private static async Task PrepareImageSourceAsync(string url) + { + try + { + var loader = new ImageSource(new Uri(url)); + var bitmap = await loader.GetBitmapAsync(); + return bitmap is not null ? new ImageSource(bitmap) { RemoteUrl = new Uri(url) } : null; + } + catch + { + return null; + } + } + + [RelayCommand] + private async Task SelectVersion(CivArchiveVersionReference? versionRef) + { + if (versionRef is null || string.IsNullOrWhiteSpace(versionRef.Href) || IsLoading) + { + return; + } + + SelectedVersion = versionRef; + RelativeUrl = versionRef.Href; + await LoadModelAsync(); + } + + [RelayCommand] + private void GoBack() + { + if (!navigationService.GoBack()) + { + navigationService.NavigateTo(); + } + } + + [RelayCommand] + private void OpenOnCivArchive() + { + ProcessRunner.OpenUrl(civArchiveApiClient.GetAbsoluteUri(RelativeUrl).ToString()); + } + + [RelayCommand] + private async Task DownloadModel() + { + var version = Model?.Version; + if (version is null) + return; + + var primaryFile = GetPrimaryFile(version); + await ExecuteDownloadAsync(version, primaryFile, GetDownloadUris(version), sourceLabel: null); + } + + [RelayCommand] + private async Task DeleteModel() + { + var localFiles = FindLocallyInstalledFiles(); + if (localFiles.Count == 0) + return; + + var pathsToDelete = new List(); + foreach (var localFile in localFiles) + { + var checkpointPath = new FilePath(localFile.GetFullPath(settingsManager.ModelsDirectory)); + if (File.Exists(checkpointPath)) + pathsToDelete.Add(checkpointPath); + + var previewPath = localFile.GetPreviewImageFullPath(settingsManager.ModelsDirectory); + if (!string.IsNullOrEmpty(previewPath) && File.Exists(previewPath)) + pathsToDelete.Add(previewPath); + + var cmInfoPath = checkpointPath + .ToString() + .Replace(checkpointPath.Extension, ConnectedModelInfo.FileExtension); + if (File.Exists(cmInfoPath)) + pathsToDelete.Add(cmInfoPath); + } + + if (pathsToDelete.Count == 0) + return; + + var confirmDeleteVm = vmFactory.Get(); + confirmDeleteVm.PathsToDelete = pathsToDelete; + + var dialog = confirmDeleteVm.GetDialog(); + var result = await dialog.ShowAsync(); + + if (result != ContentDialogResult.Primary) + return; + + try + { + await confirmDeleteVm.ExecuteCurrentDeleteOperationAsync(failFast: true); + } + catch (Exception ex) + { + notificationService.Show("Delete failed", ex.Message); + } + finally + { + await modelIndexService.RefreshIndex(); + // RefreshIndex fires ModelIndexChanged β†’ our handler updates IsInstalled + // and the UI flips back to "Download" automatically. + } + } + + /// + /// Find every locally installed file matching this model β€” try the SourceUrl first, + /// fall back to SHA256 hash matches for legacy downloads without SourceUrl. + /// + private List FindLocallyInstalledFiles() + { + var matches = new List(); + var allLocal = modelIndexService.ModelIndex.Values.SelectMany(x => x).ToList(); + + if (!string.IsNullOrWhiteSpace(RelativeUrl)) + { + matches.AddRange( + allLocal.Where(m => + m.HasCivArchiveMetadata + && string.Equals( + m.ConnectedModelInfo.SourceUrl, + RelativeUrl, + StringComparison.OrdinalIgnoreCase + ) + ) + ); + } + + if (matches.Count == 0 && Model?.Version is { } version) + { + var hashes = version + .Files.Select(f => f.Sha256) + .Where(s => !string.IsNullOrWhiteSpace(s)) + .Cast() + .ToHashSet(StringComparer.OrdinalIgnoreCase); + + if (hashes.Count > 0) + { + matches.AddRange( + allLocal.Where(m => + !string.IsNullOrWhiteSpace(m.HashSha256) && hashes.Contains(m.HashSha256) + ) + ); + } + } + + return matches; + } + + [RelayCommand] + private async Task DownloadFile(CivArchiveModelFile? file) + { + var version = Model?.Version; + if (version is null || file is null) + return; + + await ExecuteDownloadAsync(version, file, GetDownloadUrisForFile(file), sourceLabel: null); + } + + [RelayCommand] + private async Task DownloadFromMirror(CivArchiveFileMirror? mirror) + { + if (mirror is null || string.IsNullOrWhiteSpace(mirror.Url)) + return; + + // Gated/paid mirrors require auth or payment we can't handle in-app β€” open externally. + if (mirror.IsGated || mirror.IsPaid) + { + ProcessRunner.OpenUrl(mirror.Url); + return; + } + + var version = Model?.Version; + if (version is null) + return; + + // Find the parent file so we can attach hash + cm-info correctly. + var parentFile = version.Files.FirstOrDefault(f => f.Mirrors.Contains(mirror)); + + await ExecuteDownloadAsync( + version, + parentFile, + [civArchiveApiClient.GetAbsoluteUri(mirror.Url)], + sourceLabel: mirror.Source + ); + } + + private async Task ExecuteDownloadAsync( + CivArchiveModelVersion version, + CivArchiveModelFile? file, + IReadOnlyList downloadUris, + string? sourceLabel + ) + { + if (Model is null) + return; + + if (downloadUris.Count == 0) + { + notificationService.Show( + "No download available", + "This file has no usable download URL β€” every mirror was either missing or gated/paid." + ); + return; + } + + if (!settingsManager.IsLibraryDirSet) + { + notificationService.Show("Download Failed", "Please set a library directory in settings first."); + return; + } + + var destinationDir = GetDefaultDownloadFolder(); + var fileName = BuildDownloadFileName(version, file); + + Uri? previewImageUri = null; + string? previewImageExtension = null; + var firstImage = version.Images.FirstOrDefault(IsUsableImage); + if (firstImage?.Url is not null) + { + previewImageUri = new Uri(firstImage.Url); + previewImageExtension = ResolvePreviewImageExtension(previewImageUri); + } + + var connectedModelInfo = BuildConnectedModelInfo(Model, version, RelativeUrl); + // Override hash so the cm-info matches the specific file being downloaded + // (BuildConnectedModelInfo defaults to the primary file's hash). + if (!string.IsNullOrWhiteSpace(file?.Sha256)) + { + connectedModelInfo.Hashes = new CivitFileHashes { SHA256 = file.Sha256 }; + } + + await modelImportService.DoCustomImport( + downloadUris, + fileName, + destinationDir, + previewImageUri, + previewImageFileExtension: previewImageExtension, + connectedModelInfo: connectedModelInfo, + configureDownload: download => + { + if (!string.IsNullOrWhiteSpace(file?.Sha256)) + { + download.ExpectedHashSha256 = file.Sha256; + } + + // The CivitAI flow uses CivitPostDownloadContextAction to refresh the + // model index post-download; we don't have an analogous context action + // (we rely on cm-info instead of Blake3 hash), so subscribe directly to + // ProgressStateChanged. Refreshing the index fires ModelIndexChanged, + // which our OnInitialLoadedAsync subscription uses to flip the Installed + // badge / "Download again" label live. + download.ProgressStateChanged += (_, state) => + { + if (state == ProgressState.Success) + { + modelIndexService.BackgroundRefreshIndex(); + } + }; + } + ); + + var finalPath = destinationDir.JoinFile(fileName); + var sourceText = string.IsNullOrEmpty(sourceLabel) ? string.Empty : $" from {sourceLabel}"; + notificationService.Show( + "Download Started", + $"{finalPath.Name}{sourceText} will be saved to {finalPath.Directory}" + ); + } + + /// + /// CivArchive aggregates images from many platforms; some URLs don't end in a recognizable + /// extension (e.g. CivitAI's "/width=512/img" style paths or extension-less CDN URLs). Try + /// Path.GetExtension first, then scan the raw URL for a known image extension, then fall back + /// to ".jpeg" so the import never fails at the preview-image step. + /// + private static string ResolvePreviewImageExtension(Uri previewImageUri) + { + var fromPath = Path.GetExtension(previewImageUri.LocalPath); + if (!string.IsNullOrWhiteSpace(fromPath)) + return fromPath; + + ReadOnlySpan known = [".jpeg", ".jpg", ".png", ".webp", ".gif", ".avif"]; + var raw = previewImageUri.ToString(); + foreach (var ext in known) + { + if (raw.Contains(ext, StringComparison.OrdinalIgnoreCase)) + return ext; + } + + return ".jpeg"; + } + + private IReadOnlyList GetDownloadUrisForFile(CivArchiveModelFile file) + { + var urlCandidates = new List { file.DownloadUrl }; + if (file.Mirrors is not null) + { + urlCandidates.AddRange(file.Mirrors.Where(m => !m.IsGated && !m.IsPaid).Select(m => m.Url)); + } + + return urlCandidates + .Where(url => !string.IsNullOrWhiteSpace(url)) + .Select(url => civArchiveApiClient.GetAbsoluteUri(url!)) + .Distinct() + .ToList(); + } + + private IReadOnlyList GetDownloadUris(CivArchiveModelVersion? version) + { + if (version is null) + { + return []; + } + + var primaryFile = GetPrimaryFile(version); + var urlCandidates = new List { version.DownloadUrl, primaryFile?.DownloadUrl }; + + if (primaryFile?.Mirrors is not null) + { + urlCandidates.AddRange(primaryFile.Mirrors.Select(mirror => mirror.Url)); + } + + return urlCandidates + .Where(url => !string.IsNullOrWhiteSpace(url)) + .Select(url => civArchiveApiClient.GetAbsoluteUri(url!)) + .Distinct() + .ToList(); + } + + private static CivArchiveModelFile? GetPrimaryFile(CivArchiveModelVersion? version) + { + if (version is null) + { + return null; + } + + return version.Files.FirstOrDefault(f => f.IsPrimary) ?? version.Files.FirstOrDefault(); + } + + private string BuildDownloadFileName(CivArchiveModelVersion version, CivArchiveModelFile? primaryFile) + { + var extension = !string.IsNullOrWhiteSpace(primaryFile?.Name) + ? Path.GetExtension(primaryFile.Name) + : ".safetensors"; + if (string.IsNullOrEmpty(extension)) + { + extension = ".safetensors"; + } + + var provider = BuildFormatProvider(version, primaryFile); + var format = ParseFormatOrDefault(ModelFileNameFormat, provider); + + // Normalize so a leading "/" from an empty {base_model} doesn't make Path.Combine + // treat the name as rooted and drop the destination folder. + var stem = NormalizePathSegments(format.GetFileName()); + + if (string.IsNullOrWhiteSpace(stem)) + { + // Pattern resolved to empty (e.g. only {file_name} on a non-CivitAI mirror with no primary file). + // Fall back to a sensible synthesized name. + stem = string.IsNullOrWhiteSpace(version.Name) + ? Model?.Name ?? "model" + : $"{Model?.Name ?? "model"}_{version.Name}"; + } + + return stem + extension; + } + + private DirectoryPath GetDefaultDownloadFolder() + { + var modelType = Model?.Type; + if (modelType is not null && ModelTypeMap.TryGetValue(modelType, out var mapping)) + { + return new DirectoryPath(settingsManager.ModelsDirectory, mapping.Folder.GetStringValue()); + } + + return new DirectoryPath(settingsManager.ModelsDirectory); + } + + private static ConnectedModelInfo BuildConnectedModelInfo( + CivArchiveModelDetails model, + CivArchiveModelVersion version, + string sourceUrl + ) + { + var civitModelType = CivitModelType.Unknown; + if (model.Type is not null && ModelTypeMap.TryGetValue(model.Type, out var mapping)) + { + civitModelType = mapping.ModelType; + } + + var primaryFile = version.Files.FirstOrDefault(f => f.IsPrimary) ?? version.Files.FirstOrDefault(); + + return new ConnectedModelInfo + { + ModelName = model.Name, + ModelDescription = model.Description ?? string.Empty, + Nsfw = model.IsNsfw, + Tags = model.Tags.ToArray(), + ModelType = civitModelType, + VersionName = version.Name, + VersionDescription = version.Description, + BaseModel = version.BaseModel, + ImportedAt = DateTimeOffset.UtcNow, + Hashes = new CivitFileHashes { SHA256 = primaryFile?.Sha256 }, + TrainedWords = version.Trigger.ToArray(), + ThumbnailImageUrl = version.Images.FirstOrDefault(IsUsableImage)?.Url, + Source = ConnectedModelSource.CivArchive, + SourceUrl = sourceUrl, + Stats = new CivitModelStats + { + DownloadCount = (int)model.DownloadCount, + FavoriteCount = (int)model.FavoriteCount, + CommentCount = (int)model.CommentCount, + RatingCount = (int)model.RatingCount, + Rating = model.Rating, + }, + }; + } + + private static bool IsUsableImage(CivArchiveModelImage image) + { + return !string.IsNullOrWhiteSpace(image.Url) + && ( + string.IsNullOrWhiteSpace(image.Type) + || string.Equals(image.Type, "image", StringComparison.OrdinalIgnoreCase) + ); + } + + [RelayCommand] + private void OpenVersionMirror(CivArchiveVersionMirror? mirror) + { + if (!string.IsNullOrWhiteSpace(mirror?.PlatformUrl)) + { + ProcessRunner.OpenUrl(mirror.PlatformUrl); + } + } + + [RelayCommand] + private async Task CopySha256(CivArchiveModelFile? file) + { + if (!string.IsNullOrWhiteSpace(file?.Sha256) && App.Clipboard is not null) + { + await App.Clipboard.SetTextAsync(file.Sha256); + } + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs index bdcd4bcf8..bf5460efa 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitAiBrowserViewModel.cs @@ -11,6 +11,7 @@ using CommunityToolkit.Mvvm.Input; using DynamicData; using DynamicData.Binding; +using FluentAvalonia.UI.Media.Animation; using Injectio.Attributes; using NLog; using Refit; @@ -20,6 +21,7 @@ using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.CheckpointManager; +using StabilityMatrix.Avalonia.ViewModels.Settings; using StabilityMatrix.Avalonia.Views; using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Attributes; @@ -48,6 +50,7 @@ public sealed partial class CivitAiBrowserViewModel : TabViewModelBase, IInfinit private readonly INotificationService notificationService; private readonly ICivitBaseModelTypeService baseModelTypeService; private readonly INavigationService navigationService; + private readonly INavigationService settingsNavigationService; private bool dontSearch = false; private readonly SourceCache, int> modelCache = new(static ov => ov.Value.Id); @@ -147,7 +150,8 @@ public CivitAiBrowserViewModel( IConnectedServiceManager connectedServiceManager, INotificationService notificationService, ICivitBaseModelTypeService baseModelTypeService, - INavigationService navigationService + INavigationService navigationService, + INavigationService settingsNavigationService ) { this.civitApi = civitApi; @@ -158,6 +162,7 @@ INavigationService navigationService this.notificationService = notificationService; this.baseModelTypeService = baseModelTypeService; this.navigationService = navigationService; + this.settingsNavigationService = settingsNavigationService; EventManager.Instance.NavigateAndFindCivitModelRequested += OnNavigateAndFindCivitModelRequested; @@ -453,7 +458,41 @@ private async Task CivitModelQuery(CivitModelsRequest request, bool isInfiniteSc else { modelsResponse = await civitApi.GetModels(request); - models = modelsResponse.Items; + if (modelsResponse.Items != null) + { + models.AddRange(modelsResponse.Items); + } + } + + // CivitAI's list endpoint (/api/v1/models?ids=...) sometimes returns zero items + // for models that the single-model endpoint (/api/v1/models/{id}) can find just + // fine β€” another server-side cache desync. For each requested ID that didn't + // come back, fall back to a per-ID lookup (which itself goes through the tRPC + // fallback in CivitCompatApiManager when needed). + var returnedIds = models.Select(m => m.Id).ToHashSet(); + foreach (var idStr in ids) + { + if (!int.TryParse(idStr.Trim(), out var idValue) || returnedIds.Contains(idValue)) + continue; + + try + { + var single = await civitApi.GetModelById(idValue); + // GetModelById can return a default-id object on 404 with some implementations; + // only accept it if it actually looks valid. + if (single.Id == idValue) + { + models.Add(single); + Logger.Info( + "Recovered model {Id} via per-ID fallback after list endpoint missed it", + idValue + ); + } + } + catch (Exception ex) + { + Logger.Warn(ex, "Per-ID fallback failed for model {Id}; skipping", idValue); + } } } else @@ -553,14 +592,32 @@ private async Task CivitModelQuery(CivitModelsRequest request, bool isInfiniteSc if (cacheNew) { + // ID-targeted searches ($#1234 / civitai.com URLs / Installed/Favorites sorts) + // explicitly bypass the type+base-model filters when building the request, so the + // post-response sanity check would otherwise reject perfectly valid results + // (e.g. searching $#someLoraId while SelectedModelType=Checkpoint). + var isIdSearch = !string.IsNullOrEmpty(request.CommaSeparatedModelIds); + + // "No filter" is sent when either zero or all base models are selected β€” mirror that + // when checking the response matches the current filter state. + var isNoBaseModelFilter = + SelectedBaseModels.Count == 0 || SelectedBaseModels.Count == AllBaseModels.Count; var doesBaseModelTypeMatch = - SelectedBaseModels.Count == 0 - ? request.BaseModels == null || request.BaseModels.Length == 0 - : SelectedBaseModels.SequenceEqual(request.BaseModels ?? []); + isIdSearch + || ( + isNoBaseModelFilter + ? request.BaseModels == null || request.BaseModels.Length == 0 + : SelectedBaseModels + .OrderBy(x => x) + .SequenceEqual((request.BaseModels ?? []).OrderBy(x => x)) + ); var doesModelTypeMatch = - SelectedModelType == CivitModelType.All - ? request.Types == null || request.Types.Length == 0 - : SelectedModelType == request.Types?.FirstOrDefault(); + isIdSearch + || ( + SelectedModelType == CivitModelType.All + ? request.Types == null || request.Types.Length == 0 + : SelectedModelType == request.Types?.FirstOrDefault() + ); if (doesBaseModelTypeMatch && doesModelTypeMatch) { @@ -839,9 +896,60 @@ private void ClearOrSelectAllBaseModels() } [RelayCommand] - private void ShowVersionDialog(CivitModel model) + private async Task NavigateToBaseModelSettings() + { + navigationService.NavigateTo(new SuppressNavigationTransitionInfo()); + await Task.Delay(100); + settingsNavigationService.NavigateTo(new SuppressNavigationTransitionInfo()); + } + + [RelayCommand] + private async Task ShowVersionDialog(CivitModel model) { var versions = model.ModelVersions; + + // The CivitAI public REST API sometimes returns models with an empty modelVersions list + // even when versions exist on the website (server-side cache desync). Re-fetch via + // GetModelById β€” CivitCompatApiManager will transparently fall back to the tRPC endpoint + // to recover the missing versions when the REST response is empty. + if (versions is null || versions.Count == 0) + { + // Surface a loading state on the card the user clicked so they get feedback instead + // of staring at a frozen-looking UI for the ~1-2s round-trip. + var card = ModelCards.FirstOrDefault(c => c.CivitModel.Id == model.Id); + var previousIsLoading = card?.IsLoading ?? false; + var previousText = card?.Text; + if (card is not null) + { + card.IsLoading = true; + card.Text = "Loading..."; + } + + try + { + var refreshed = await civitApi.GetModelById(model.Id); + if (refreshed.ModelVersions is { Count: > 0 }) + { + // Mutate in place so subsequent clicks on the same card don't re-fetch β€” + // model is the live CivitModel that the card holds via CommandParameter. + model.ModelVersions = refreshed.ModelVersions; + versions = refreshed.ModelVersions; + } + } + catch (Exception e) + { + Logger.Warn(e, "Failed to refresh CivitModel {Id} when versions list was empty", model.Id); + } + finally + { + if (card is not null) + { + card.IsLoading = previousIsLoading; + card.Text = previousText; + } + } + } + if (versions is null || versions.Count == 0) { notificationService.Show( diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitDetailsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitDetailsPageViewModel.cs index e8143d9cc..1f9152ac4 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitDetailsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowser/CivitDetailsPageViewModel.cs @@ -65,6 +65,14 @@ IModelImportService modelImportService [NotifyPropertyChangedFor(nameof(CanGoNext), nameof(CanGoPrevious))] public required partial int CurrentIndex { get; set; } + /// + /// True when there's at least one preview image to render. Drives layout β€” when false, + /// the page collapses the carousel row so the description fills the space (e.g. when a + /// model was recovered via the tRPC fallback, which doesn't return per-version images). + /// + [ObservableProperty] + public partial bool HasImages { get; set; } + private List ignoredFileNameFormatVars = [ "seed", @@ -327,6 +335,18 @@ protected override async Task OnInitialLoadedAsync() .Subscribe() ); + // Mirror the same filter chain to drive HasImages β€” used by the view to collapse + // the carousel row when nothing's there to show. + AddDisposable( + imageCache + .Connect() + .Filter(showNsfwPredicate) + .Filter(img => img.Type == "image") + .ToCollection() + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(c => HasImages = c.Count > 0) + ); + var includeTrainingDataPredicate = Observable .FromEventPattern(this, nameof(PropertyChanged)) .Where(x => x.EventArgs.PropertyName is nameof(ShowTrainingData)) @@ -415,7 +435,8 @@ out var preference settingsManager.ModelsDirectory, viewModel.CivitFile.Type, CivitModel.Type, - CivitModel.BaseModelType + CivitModel.BaseModelType, + viewModel.CivitFile.Name ); effectiveLocationKeyForPreference = viewModel.InstallLocations.FirstOrDefault(loc => @@ -492,7 +513,12 @@ await modelImportService.DoImport( SelectedVersion?.ModelVersion, viewModel.CivitFile, fileNameOverride, - inferenceDefaults: IsInferenceDefaultsEnabled ? SamplerCardViewModel : null + inferenceDefaults: IsInferenceDefaultsEnabled ? SamplerCardViewModel : null, + onImportComplete: () => + TryMoveDownloadedCheckpointToDiffusionModelsIfNeededAsync( + viewModel.CivitFile, + finalDestinationDir + ) ); notificationService.Show( @@ -538,7 +564,8 @@ private async Task ShowBulkDownloadDialogAsync() new DirectoryPath(settingsManager.ModelsDirectory), file.FileViewModel.CivitFile.Type, CivitModel.Type, - CivitModel.BaseModelType + CivitModel.BaseModelType, + file.FileViewModel.CivitFile.Name ); var folderName = Path.GetInvalidFileNameChars() @@ -558,7 +585,12 @@ await modelImportService.DoImport( destinationDir, file.ModelVersion, file.FileViewModel.CivitFile, - fileNameOverride + fileNameOverride, + onImportComplete: () => + TryMoveDownloadedCheckpointToDiffusionModelsIfNeededAsync( + file.FileViewModel.CivitFile, + destinationDir + ) ); } @@ -865,7 +897,8 @@ private ObservableCollection LoadInstallLocations(CivitFile selectedFile rootModelsDirectory, selectedFile.Type, CivitModel.Type, - CivitModel.BaseModelType + CivitModel.BaseModelType, + selectedFile.Name ); if (!downloadDirectory.ToString().EndsWith("Unknown")) @@ -886,9 +919,15 @@ var directory in downloadDirectory.EnumerateDirectories( } } - if (downloadDirectory.ToString().EndsWith(SharedFolderType.DiffusionModels.GetStringValue())) + var isGguf = Path.GetExtension(selectedFile.Name).Equals(".gguf", StringComparison.OrdinalIgnoreCase); + + if ( + downloadDirectory.ToString().EndsWith(SharedFolderType.DiffusionModels.GetStringValue()) + && !isGguf + ) { // also add StableDiffusion in case we have an AIO version + // (not for GGUFs, which are always UNet-only) var stableDiffusionDirectory = rootModelsDirectory.JoinDir( SharedFolderType.StableDiffusion.GetStringValue() ); @@ -907,7 +946,8 @@ private static DirectoryPath GetSharedFolderPath( DirectoryPath rootModelsDirectory, CivitFileType? fileType, CivitModelType modelType, - string? baseModelType + string? baseModelType, + string? fileName = null ) { if (fileType is CivitFileType.VAE) @@ -930,9 +970,217 @@ modelType is CivitModelType.Checkpoint return rootModelsDirectory.JoinDir(SharedFolderType.DiffusionModels.GetStringValue()); } + // GGUF checkpoints are always UNet-only, route directly to DiffusionModels + if ( + modelType is CivitModelType.Checkpoint + && fileName is not null + && Path.GetExtension(fileName).Equals(".gguf", StringComparison.OrdinalIgnoreCase) + ) + { + return rootModelsDirectory.JoinDir(SharedFolderType.DiffusionModels.GetStringValue()); + } + return rootModelsDirectory.JoinDir(modelType.ConvertTo().GetStringValue()); } + private async Task TryMoveDownloadedCheckpointToDiffusionModelsIfNeededAsync( + CivitFile civitFile, + DirectoryPath requestedDestinationDir + ) + { + if ( + civitFile.Type is not (CivitFileType.Model or CivitFileType.PrunedModel) + || CivitModel.Type is not CivitModelType.Checkpoint + ) + { + return; + } + + if (!settingsManager.IsLibraryDirSet) + { + return; + } + + if (!Path.GetExtension(civitFile.Name).Equals(".safetensors", StringComparison.OrdinalIgnoreCase)) + { + return; + } + + if (string.IsNullOrWhiteSpace(civitFile.Hashes.BLAKE3)) + { + return; + } + + try + { + await modelIndexService.RefreshIndex(); + + var matchingModels = (await modelIndexService.FindByHashAsync(civitFile.Hashes.BLAKE3)).ToList(); + if (matchingModels.Count == 0) + { + return; + } + + var modelsRoot = new DirectoryPath(settingsManager.ModelsDirectory); + if (!IsPathWithinDirectory(requestedDestinationDir, modelsRoot)) + { + return; + } + + var sourceModel = PickMostLikelyDownloadedModel( + matchingModels, + modelsRoot, + requestedDestinationDir + ); + if (sourceModel is null) + { + return; + } + + var sourceModelPath = new FilePath(sourceModel.GetFullPath(modelsRoot)); + if (!sourceModelPath.Exists) + { + return; + } + + var checkpointKind = await SafetensorClassifier.ClassifyAsync(sourceModelPath); + if (checkpointKind is not SafetensorCheckpointKind.UnetOnly) + { + return; + } + + var stableDiffusionRoot = modelsRoot.JoinDir(SharedFolderType.StableDiffusion.GetStringValue()); + var diffusionModelsRoot = modelsRoot.JoinDir(SharedFolderType.DiffusionModels.GetStringValue()); + + if (!IsPathWithinDirectory(sourceModelPath, stableDiffusionRoot)) + { + return; + } + + var sourceDirectory = sourceModelPath.Directory; + if (sourceDirectory is null) + { + return; + } + + var relativeSubDir = Path.GetRelativePath(stableDiffusionRoot, sourceDirectory); + var destinationDirectory = + relativeSubDir == "." + ? diffusionModelsRoot + : new DirectoryPath(Path.Combine(diffusionModelsRoot, relativeSubDir)); + + destinationDirectory.Create(); + + var originalModelName = sourceModelPath.Name; + var destinationModelPath = destinationDirectory.JoinFile(sourceModelPath.Name); + var movedModelPath = await sourceModelPath.MoveToWithIncrementAsync(destinationModelPath); + var wasRenamedForCollision = !movedModelPath.Name.Equals( + originalModelName, + StringComparison.OrdinalIgnoreCase + ); + + var cmInfoPath = sourceDirectory.JoinFile( + $"{Path.GetFileNameWithoutExtension(originalModelName)}{ConnectedModelInfo.FileExtension}" + ); + if (cmInfoPath.Exists) + { + await FileTransfers.MoveFileAsync( + cmInfoPath, + destinationDirectory.JoinFile( + $"{movedModelPath.NameWithoutExtension}{ConnectedModelInfo.FileExtension}" + ), + overwrite: true + ); + } + + foreach ( + var previewFile in sourceDirectory.EnumerateFiles( + $"{Path.GetFileNameWithoutExtension(originalModelName)}.preview.*", + SearchOption.TopDirectoryOnly + ) + ) + { + await FileTransfers.MoveFileAsync( + previewFile, + destinationDirectory.JoinFile( + $"{movedModelPath.NameWithoutExtension}.preview{previewFile.Extension}" + ), + overwrite: true + ); + } + + await modelIndexService.RefreshIndex(); + + Dispatcher.UIThread.Post(() => + { + var movedRelativePath = Path.GetRelativePath(modelsRoot, movedModelPath); + notificationService.Show( + "Model moved", + wasRenamedForCollision + ? $"Detected UNet-only checkpoint and moved it to \"Models/{movedRelativePath}\" (renamed from \"{originalModelName}\" to \"{movedModelPath.Name}\" because that filename already existed)." + : $"Detected UNet-only checkpoint and moved it to \"Models/{movedRelativePath}\"." + ); + }); + } + catch (Exception ex) + { + logger.LogWarning( + ex, + "Failed to evaluate or move downloaded checkpoint {FileName}", + civitFile.Name + ); + } + } + + private static LocalModelFile? PickMostLikelyDownloadedModel( + IEnumerable candidates, + DirectoryPath modelsRoot, + DirectoryPath requestedDestinationDir + ) + { + var existingCandidates = candidates + .Select(model => new { Model = model, FullPath = model.GetFullPath(modelsRoot) }) + .Where(x => File.Exists(x.FullPath)) + .ToList(); + + if (existingCandidates.Count == 0) + { + return null; + } + + var preferredCandidates = existingCandidates + .Where(x => IsPathWithinDirectory(x.FullPath, requestedDestinationDir)) + .ToList(); + + if (preferredCandidates.Count == 0) + { + return null; + } + + return preferredCandidates + .OrderByDescending(x => File.GetLastWriteTimeUtc(x.FullPath)) + .Select(x => x.Model) + .FirstOrDefault(); + } + + private static bool IsPathWithinDirectory(string candidatePath, string directoryPath) + { + var normalizedCandidate = Path.GetFullPath(candidatePath) + .TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar); + var normalizedDirectory = Path.GetFullPath(directoryPath) + .TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar); + + if (string.Equals(normalizedCandidate, normalizedDirectory, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + return normalizedCandidate.StartsWith( + normalizedDirectory + Path.DirectorySeparatorChar, + StringComparison.OrdinalIgnoreCase + ); + } + private IReadOnlyDictionary GetOtherMetadata(CivitImageGenerationDataResponse value) { var metaDict = new Dictionary(); diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs index 589f4d255..5e27d0c21 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointBrowserViewModel.cs @@ -32,13 +32,19 @@ public partial class CheckpointBrowserViewModel : PageViewModelBase /// public CheckpointBrowserViewModel( CivitAiBrowserViewModel civitAiBrowserViewModel, + CivArchiveBrowserViewModel civArchiveBrowserViewModel, HuggingFacePageViewModel huggingFaceViewModel, OpenModelDbBrowserViewModel openModelDbBrowserViewModel ) { Pages = new List( new List( - [civitAiBrowserViewModel, huggingFaceViewModel, openModelDbBrowserViewModel] + [ + civitAiBrowserViewModel, + civArchiveBrowserViewModel, + huggingFaceViewModel, + openModelDbBrowserViewModel, + ] ).Select(vm => new TabItem { Header = vm.Header, Content = vm }) ); SelectedPage = Pages.FirstOrDefault(); diff --git a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs index 62bea1c0f..16f5764d3 100644 --- a/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/CheckpointsPageViewModel.cs @@ -20,6 +20,7 @@ using StabilityMatrix.Avalonia.Helpers; using StabilityMatrix.Avalonia.Languages; using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.CheckpointOrganizer; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; @@ -56,6 +57,7 @@ public partial class CheckpointsPageViewModel( INotificationService notificationService, IMetadataImportService metadataImportService, IModelImportService modelImportService, + ModelOrganizationService modelOrganizationService, OpenModelDbManager openModelDbManager, IServiceManager dialogFactory, ICivitBaseModelTypeService baseModelTypeService, @@ -590,6 +592,93 @@ private async Task ScanMetadata(bool updateExistingMetadata) notificationService.Show("Scan Complete", message, NotificationType.Success); } + [RelayCommand] + private async Task OrganizeModelsAsync() + { + if (SelectedCategory == null) + { + notificationService.Show( + "No Category Selected", + "Please select a category to organize.", + NotificationType.Error + ); + return; + } + + var organizeDialogVm = dialogFactory.Get(); + organizeDialogVm.Initialize( + modelIndexService.ModelIndex.Values.SelectMany(x => x), + settingsManager.ModelsDirectory, + SelectedCategory.Path, + ShowModelsInSubfolders, + settingsManager.Settings.ModelOrganizationFileNamePattern + ); + + if (organizeDialogVm.Plan?.Items.Count == 0) + { + notificationService.Show( + "Nothing To Organize", + "No indexed models matched the selected category.", + NotificationType.Information + ); + return; + } + + var dialogResult = await organizeDialogVm.GetDialog().ShowAsync(); + + if (dialogResult == ContentDialogResult.Secondary) + { + switch (organizeDialogVm.RequestedMetadataAction) + { + case ModelOrganizationMetadataAction.ScanMissing: + await ScanMetadata(false); + break; + case ModelOrganizationMetadataAction.UpdateExisting: + await ScanMetadata(true); + break; + } + + return; + } + + if (dialogResult != ContentDialogResult.Primary) + return; + + var plan = organizeDialogVm.Plan!; + + IsLoading = true; + Progress.Text = "Organizing models..."; + Progress.IsIndeterminate = true; + + try + { + var result = await modelOrganizationService.ApplyPlan(plan); + await modelIndexService.RefreshIndex(); + + var summary = + $"{result.MovedCount} moved, {result.ConflictCount} conflicts, {result.SkippedCount} skipped."; + notificationService.Show( + "Organization Complete", + summary, + result.Errors.Count == 0 ? NotificationType.Success : NotificationType.Warning + ); + + if (result.Errors.Count > 0) + { + notificationService.ShowPersistent( + "Organization encountered errors", + string.Join(Environment.NewLine, result.Errors.Take(5)), + NotificationType.Warning + ); + } + } + finally + { + IsLoading = false; + Progress.ClearProgress(); + } + } + [RelayCommand] private Task OnItemClick(CheckpointFileViewModel item) { @@ -607,7 +696,10 @@ private Task OnItemClick(CheckpointFileViewModel item) [RelayCommand] private async Task ShowVersionDialog(CheckpointFileViewModel item) { - if (item.CheckpointFile is { HasCivitMetadata: false, HasOpenModelDbMetadata: false }) + if ( + item.CheckpointFile is + { HasCivitMetadata: false, HasOpenModelDbMetadata: false, HasCivArchiveMetadata: false } + ) { notificationService.Show( "Cannot show version dialog", @@ -625,6 +717,32 @@ private async Task ShowVersionDialog(CheckpointFileViewModel item) { await ShowOpenModelDbDialog(item); } + else if (item.CheckpointFile.HasCivArchiveMetadata) + { + ShowCivArchiveDialog(item); + } + } + + private void ShowCivArchiveDialog(CheckpointFileViewModel item) + { + var sourceUrl = item.CheckpointFile.ConnectedModelInfo?.SourceUrl; + if (string.IsNullOrWhiteSpace(sourceUrl)) + { + notificationService.Show( + "CivArchive link unavailable", + "This model was downloaded before navigation back to CivArchive was supported. Re-download from CivArchive to enable this.", + NotificationType.Warning + ); + return; + } + + var newVm = dialogFactory.Get(vm => + { + vm.RelativeUrl = sourceUrl; + return vm; + }); + + navigationService.NavigateTo(newVm, BetterSlideNavigationTransition.PageSlideFromRight); } private void ShowCivitVersionDialog(CheckpointFileViewModel item) diff --git a/StabilityMatrix.Avalonia/ViewModels/Controls/PaintCanvasViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Controls/PaintCanvasViewModel.cs index bc655863e..99d1444e8 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Controls/PaintCanvasViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Controls/PaintCanvasViewModel.cs @@ -24,14 +24,23 @@ namespace StabilityMatrix.Avalonia.ViewModels.Controls; [RegisterTransient] [ManagedService] -public partial class PaintCanvasViewModel(ILogger logger) : LoadableViewModelBase +public partial class PaintCanvasViewModel(ILogger logger) + : LoadableViewModelBase, + IDisposable { + private bool _disposed; public ConcurrentDictionary TemporaryPaths { get; set; } = new(); [ObservableProperty] [NotifyCanExecuteChangedFor(nameof(UndoCommand))] private ImmutableList paths = []; + /// + /// Stack of undone paths for redo functionality. + /// + [JsonIgnore] + private readonly Stack redoStack = new(); + [ObservableProperty] private Color? paintBrushColor = Colors.White; @@ -43,6 +52,13 @@ public partial class PaintCanvasViewModel(ILogger logger) [ObservableProperty] private double paintBrushAlpha = 1; + /// + /// Feathering amount for soft brush edges. 0 = hard edge, 1 = fully soft/blurred. + /// UI typically shows this inverted as "Hardness" (100% = no feathering). + /// + [ObservableProperty] + private double paintBrushFeathering = 0; + [ObservableProperty] private double currentPenPressure; @@ -58,6 +74,18 @@ public partial class PaintCanvasViewModel(ILogger logger) [ObservableProperty] private Size canvasSize = Size.Empty; + /// + /// Whether drawing is enabled. Set to false to disable brush strokes (e.g., for image reference layers). + /// + [ObservableProperty] + private bool isDrawingEnabled = true; + + /// + /// Whether to draw shapes (Rectangle/Ellipse) as strokes only instead of filled. + /// + [ObservableProperty] + private bool isShapeStrokeOnly; + [JsonIgnore] private SKCanvas? SourceCanvas { set; get; } @@ -67,8 +95,9 @@ public partial class PaintCanvasViewModel(ILogger logger) new() { ["Background"] = new SKLayer(), - ["Images"] = new SKLayer(), - ["Brush"] = new SKLayer(), + ["Images"] = new SKLayer(), // Layers BELOW the selected layer + ["Brush"] = new SKLayer(), // The currently selected/active layer + ["Overlay"] = new SKLayer(), // Layers ABOVE the selected layer }; [JsonIgnore] @@ -77,9 +106,109 @@ public partial class PaintCanvasViewModel(ILogger logger) [JsonIgnore] private SKLayer ImagesLayer => Layers["Images"]; + [JsonIgnore] + private SKLayer OverlayLayer => Layers["Overlay"]; + [JsonIgnore] private SKLayer BackgroundLayer => Layers["Background"]; + /// + /// Cached bitmap of all finalized paths. Cleared when paths change. + /// + [JsonIgnore] + private SKImage? cachedPathsImage; + + /// + /// Number of paths that were rendered into the cached image. + /// Used to determine if cache needs to be updated. + /// + [JsonIgnore] + private int cachedPathsCount; + + /// + /// Cached surface for temporary paths during active drawing. + /// Allows incremental rendering of long strokes. + /// + [JsonIgnore] + private SKSurface? tempPathSurface; + + /// + /// Tracks how many points have been rendered to the temp path surface per pointer ID. + /// + [JsonIgnore] + private readonly ConcurrentDictionary tempPathRenderedPoints = new(); + + /// + /// Whether to use GPU-accelerated surfaces when available. + /// + [JsonIgnore] + public bool UseGpuAcceleration { get; set; } = true; + + /// + /// Indicates whether GPU acceleration is currently active. + /// + [JsonIgnore] + public bool IsUsingGpu { get; private set; } + + /// + /// Debug flag: Set to true to log GPU/CPU surface creation. + /// + [JsonIgnore] + public static bool LogRenderingMode { get; set; } +#if DEBUG + = true; +#endif + + /// + /// Whether to show a checkerboard pattern for transparent areas. + /// + [JsonIgnore] + public bool ShowCheckerboardBackground { get; set; } = true; + + /// + /// Size of each checkerboard square in pixels. + /// + private const int CheckerboardSquareSize = 16; + + /// + /// Light color for the checkerboard pattern. + /// + private static readonly SKColor CheckerboardLight = new(220, 220, 220); + + /// + /// Dark color for the checkerboard pattern. + /// + private static readonly SKColor CheckerboardDark = new(180, 180, 180); + + /// + /// Cached checkerboard shader for efficient rendering. + /// + [JsonIgnore] + private SKShader? cachedCheckerboardShader; + + /// + /// The canvas size that the cached checkerboard shader was created for. + /// + [JsonIgnore] + private Size cachedCheckerboardSize; + + /// + /// Whether to show a grid overlay for alignment assistance. + /// + [ObservableProperty] + private bool showGridOverlay; + + /// + /// Number of grid divisions (e.g., 3 for rule of thirds). + /// + [ObservableProperty] + private int gridDivisions = 3; + + /// + /// Color for the grid overlay lines. + /// + private static readonly SKColor GridLineColor = new(128, 128, 128, 180); + [JsonIgnore] public SKBitmap? BackgroundImage { @@ -106,16 +235,64 @@ public SKBitmap? BackgroundImage [JsonIgnore] public Action? RefreshCanvas { get; set; } + /// + /// Sets or clears a bitmap for a compositing layer. + /// Used for displaying other layers when compositing in a layered editor. + /// + /// + /// Layer name: "Images" for layers below the selected layer, + /// "Overlay" for layers above the selected layer, + /// or legacy "OtherLayers" which maps to "Images" for backwards compatibility. + /// + /// The bitmap to set, or null to clear + public void SetLayerBitmap(string name, SKBitmap? bitmap) + { + // Map legacy name to new name for backwards compatibility + var layerName = name switch + { + "OtherLayers" => "Images", // Legacy: all other layers went to Images + "LayersBelow" => "Images", + "LayersAbove" => "Overlay", + "CurrentImage" => "Brush", // Selected image layer bitmap goes to Brush layer + _ => name, + }; + + if (!Layers.TryGetValue(layerName, out var layer)) + { + return; + } + + // Dispose old bitmaps before replacing to prevent memory leaks + lock (layer) + { + foreach (var oldBitmap in layer.Bitmaps) + { + oldBitmap.Dispose(); + } + + layer.Bitmaps = bitmap is not null ? [bitmap] : []; + } + } + public void SetSourceCanvas(SKCanvas canvas) { - ArgumentNullException.ThrowIfNull(canvas, nameof(canvas)); + ArgumentNullException.ThrowIfNull(canvas); SourceCanvas = canvas; } public void LoadCanvasFromBitmap(SKBitmap bitmap) { - ImagesLayer.Bitmaps = [bitmap]; + // Dispose old bitmaps and invalidate cache + lock (ImagesLayer) + { + foreach (var oldBitmap in ImagesLayer.Bitmaps) + { + oldBitmap.Dispose(); + } + ImagesLayer.Bitmaps = [bitmap]; + } + InvalidatePathCache(); RefreshCanvas?.Invoke(); } @@ -130,175 +307,1243 @@ public void Undo() return; } + // Push the removed path to redo stack + var removedPath = currentPaths[^1]; + redoStack.Push(removedPath); + RedoCommand.NotifyCanExecuteChanged(); + Paths = currentPaths.RemoveAt(currentPaths.Count - 1); - RefreshCanvas?.Invoke(); - } + // Invalidate cache since paths changed + InvalidatePathCache(); - private bool CanExecuteUndo() - { - return Paths.Count > 0; + RefreshCanvas?.Invoke(); } - public SKImage? RenderToWhiteChannelImage() + [RelayCommand(CanExecute = nameof(CanExecuteRedo))] + public void Redo() { - using var _ = CodeTimer.StartDebug(); - - if (CanvasSize == Size.Empty) + if (redoStack.Count == 0) { - logger.LogWarning($"RenderToImage: {nameof(CanvasSize)} is not set, returning null."); - return null; + return; } - using var surface = SKSurface.Create(new SKImageInfo(CanvasSize.Width, CanvasSize.Height)); + var pathToRestore = redoStack.Pop(); + Paths = Paths.Add(pathToRestore); + RedoCommand.NotifyCanExecuteChanged(); - RenderToSurface(surface); + // Invalidate cache since paths changed + InvalidatePathCache(); - using var originalImage = surface.Snapshot(); - // Replace all colors to white (255, 255, 255), keep original alpha - // csharpier-ignore - using var colorFilter = SKColorFilter.CreateColorMatrix( - [ - // R, G, B, A, Bias - -1, 0, 0, 0, 255, - 0, -1, 0, 0, 255, - 0, 0, -1, 0, 255, - 0, 0, 0, 1, 0 - ] - ); + RefreshCanvas?.Invoke(); + } - using var paint = new SKPaint(); - paint.ColorFilter = colorFilter; + /// + /// Invalidates the cached paths image. Call when paths are modified externally. + /// + public void InvalidatePathCache() + { + cachedPathsImage?.Dispose(); + cachedPathsImage = null; + cachedPathsCount = 0; + } - surface.Canvas.Clear(SKColors.Transparent); - surface.Canvas.DrawImage(originalImage, originalImage.Info.Rect, paint); + /// + /// Called when the Paths property changes. + /// Invalidates the cache since we have a completely new set of paths. + /// + partial void OnPathsChanged(ImmutableList value) + { + // When paths change (e.g., layer switch), invalidate the cache + // since the cached image is from the old paths + InvalidatePathCache(); + } - return surface.Snapshot(); + private bool CanExecuteUndo() + { + return Paths.Count > 0; } - public SKImage? RenderToImage() + private bool CanExecuteRedo() { - using var _ = CodeTimer.StartDebug(); + return redoStack.Count > 0; + } - if (CanvasSize == Size.Empty) + /// + /// Clears the redo stack. Call when new paths are added (not via redo). + /// + public void ClearRedoStack() + { + if (redoStack.Count > 0) { - logger.LogWarning($"RenderToImage: {nameof(CanvasSize)} is not set, returning null."); - return null; + redoStack.Clear(); + RedoCommand.NotifyCanExecuteChanged(); } + } - using var surface = SKSurface.Create(new SKImageInfo(CanvasSize.Width, CanvasSize.Height)); + #region Shape Tool State - RenderToSurface(surface); + /// + /// Starting point for shape drawing (Rectangle/Ellipse tools). + /// + [ObservableProperty] + [property: JsonIgnore] + private SKPoint? shapeStartPoint; - return surface.Snapshot(); - } + /// + /// Pointer ID for the current shape drawing operation. + /// + [ObservableProperty] + [property: JsonIgnore] + private long shapePointerId; - public void RenderToSurface( - SKSurface surface, - bool renderBackgroundFill = false, - bool renderBackgroundImage = false - ) - { - // Initialize canvas layers - foreach (var layer in Layers.Values) - { - lock (layer) - { - if (layer.Surface is null) - { - layer.Surface = SKSurface.Create(new SKImageInfo(CanvasSize.Width, CanvasSize.Height)); - /*layer.Surface = SKSurface.Create( - surface.Context, - true, - new SKImageInfo(CanvasSize.Width, CanvasSize.Height) - );*/ - } - else - { - // If we need to resize: - var currentInfo = layer.Surface.Canvas.DeviceClipBounds; - if (currentInfo.Width != CanvasSize.Width || currentInfo.Height != CanvasSize.Height) - { - // Dispose the old surface - layer.Surface.Dispose(); + /// + /// Returns true if the currently selected tool is a shape tool. + /// + [JsonIgnore] + public bool IsShapeTool => SelectedTool is PaintCanvasTool.Rectangle or PaintCanvasTool.Ellipse; - // Create a brand-new SKSurface with the new size - layer.Surface = SKSurface.Create( - new SKImageInfo(CanvasSize.Width, CanvasSize.Height) - ); - } - else - { - // No resize needed, just clear - layer.Surface.Canvas.Clear(SKColors.Transparent); - } - } - } - } + #endregion - // Render all layer images in order - foreach (var (layerName, layer) in Layers) - { - // Skip background image if not requested - if (!renderBackgroundImage && layerName == "Background") - { - continue; - } + #region Move Tool State - lock (layer) - { - var layerCanvas = layer.Surface!.Canvas; - foreach (var bitmap in layer.Bitmaps) - { - layerCanvas.DrawBitmap(bitmap, new SKPoint(0, 0)); - } - } - } + /// + /// Starting point for move operations. + /// + [ObservableProperty] + [property: JsonIgnore] + private SKPoint? moveStartPoint; - // Render paint layer - var paintLayerCanvas = BrushLayer.Surface!.Canvas; + /// + /// Layer offset at the start of a move operation. + /// + [ObservableProperty] + [property: JsonIgnore] + private SKPoint moveStartOffset; - using var paint = new SKPaint(); + /// + /// Returns true if the currently selected tool is the move tool. + /// + [JsonIgnore] + public bool IsMoveTool => SelectedTool == PaintCanvasTool.Move; - // Draw the paths - foreach (var penPath in Paths) - { - RenderPenPath(paintLayerCanvas, penPath, paint); - } + /// + /// Callback invoked when the move tool adjusts the image layer position. + /// Parameters: (newOffsetX, newOffsetY) - the new absolute offset position. + /// + [JsonIgnore] + public Action? OnMoveToolDrag { get; set; } - foreach (var penPath in TemporaryPaths.Values) - { - RenderPenPath(paintLayerCanvas, penPath, paint); - } + /// + /// Callback to get the current image layer offset when starting a move. + /// Returns (currentOffsetX, currentOffsetY). + /// + [JsonIgnore] + public Func<(double X, double Y)>? GetCurrentMoveOffset { get; set; } - // Draw background color - surface.Canvas.Clear(SKColors.Transparent); + /// + /// Starts a move operation at the given position. + /// + public void StartMove(SKPoint position, double currentOffsetX, double currentOffsetY) + { + MoveStartPoint = position; + MoveStartOffset = new SKPoint((float)currentOffsetX, (float)currentOffsetY); + } - // Draw the layers to the main surface - foreach (var layer in Layers.Values) - { - lock (layer) - { - layer.Surface!.Canvas.Flush(); + /// + /// Updates the move during drag, calculating delta from start position. + /// + public void UpdateMove(SKPoint currentPoint) + { + if (!MoveStartPoint.HasValue) + return; - surface.Canvas.DrawSurface(layer.Surface!, new SKPoint(0, 0)); - } - } + var deltaX = currentPoint.X - MoveStartPoint.Value.X; + var deltaY = currentPoint.Y - MoveStartPoint.Value.Y; - surface.Canvas!.Flush(); + // Invoke callback with new absolute offset + OnMoveToolDrag?.Invoke(MoveStartOffset.X + deltaX, MoveStartOffset.Y + deltaY); } - private static void RenderPenPath(SKCanvas canvas, PenPath penPath, SKPaint paint) + /// + /// Ends the current move operation. + /// + public void EndMove() { - if (penPath.Points.Count == 0) - { - return; - } + MoveStartPoint = null; + } + + #endregion + + #region Canvas Commands + + /// + /// Clears all paths from the canvas. + /// + [RelayCommand] + public void ClearCanvas() + { + Paths = ImmutableList.Empty; + TemporaryPaths.Clear(); + redoStack.Clear(); + RedoCommand.NotifyCanExecuteChanged(); + InvalidatePathCache(); + RefreshCanvas?.Invoke(); + } + + #endregion + + #region Tool Selection Commands + + [RelayCommand] + public void SelectBrushTool() => SelectedTool = PaintCanvasTool.PaintBrush; + + [RelayCommand] + public void SelectEraserTool() => SelectedTool = PaintCanvasTool.Eraser; + + [RelayCommand] + public void SelectRectangleTool() => SelectedTool = PaintCanvasTool.Rectangle; + + [RelayCommand] + public void SelectEllipseTool() => SelectedTool = PaintCanvasTool.Ellipse; + + [RelayCommand] + public void SelectMoveTool() => SelectedTool = PaintCanvasTool.Move; + + #endregion + + #region Brush Size Commands + + [RelayCommand] + public void IncreaseBrushSize() + { + PaintBrushSize = Math.Min(100, PaintBrushSize + 5); + } + + [RelayCommand] + public void DecreaseBrushSize() + { + PaintBrushSize = Math.Max(1, PaintBrushSize - 5); + } + + #endregion + + #region Shape Drawing Helpers + + /// + /// Starts shape drawing at the given position. + /// + public void StartShapeDrawing(SKPoint position, long pointerId) + { + ShapeStartPoint = position; + ShapePointerId = pointerId; + } + + /// + /// Updates the shape preview during drag. + /// + public void UpdateShapePreview(SKPoint currentPoint) + { + if (!ShapeStartPoint.HasValue) + return; + + var bounds = CreateBoundsFromPoints(ShapeStartPoint.Value, currentPoint); + var previewPath = new PenPath + { + FillColor = PaintBrushSKColor.WithAlpha((byte)(PaintBrushAlpha * 255)), + PathType = + SelectedTool == PaintCanvasTool.Rectangle ? PenPathType.Rectangle : PenPathType.Ellipse, + Bounds = bounds, + IsStrokeOnly = IsShapeStrokeOnly, + StrokeWidth = (float)PaintBrushSize, + }; + TemporaryPaths[ShapePointerId] = previewPath; + } + + /// + /// Finalizes the shape drawing and adds it to paths. + /// + /// The created shape path, or null if shape was too small. + public PenPath? FinalizeShape(SKPoint endPoint) + { + if (!ShapeStartPoint.HasValue) + return null; + + var bounds = CreateBoundsFromPoints(ShapeStartPoint.Value, endPoint); + + // Only create shape if it has meaningful size + if (bounds.Width <= 2 || bounds.Height <= 2) + { + ShapeStartPoint = null; + TemporaryPaths.TryRemove(ShapePointerId, out _); + return null; + } + + var shapePath = new PenPath + { + FillColor = PaintBrushSKColor.WithAlpha((byte)(PaintBrushAlpha * 255)), + IsErase = SelectedTool == PaintCanvasTool.Eraser, + PathType = + SelectedTool == PaintCanvasTool.Rectangle ? PenPathType.Rectangle : PenPathType.Ellipse, + Bounds = bounds, + IsStrokeOnly = IsShapeStrokeOnly, + StrokeWidth = (float)PaintBrushSize, + }; + + Paths = Paths.Add(shapePath); + ClearRedoStack(); // New path added, clear redo history + ShapeStartPoint = null; + TemporaryPaths.TryRemove(ShapePointerId, out _); + + return shapePath; + } + + /// + /// Cancels the current shape drawing operation. + /// + public void CancelShapeDrawing() + { + ShapeStartPoint = null; + TemporaryPaths.TryRemove(ShapePointerId, out _); + } + + private static SKRect CreateBoundsFromPoints(SKPoint start, SKPoint end) + { + return new SKRect( + Math.Min(start.X, end.X), + Math.Min(start.Y, end.Y), + Math.Max(start.X, end.X), + Math.Max(start.Y, end.Y) + ); + } + + #endregion + + #region Paint Bucket / Flood Fill + + [RelayCommand] + public void SelectPaintBucketTool() => SelectedTool = PaintCanvasTool.PaintBucket; + + /// + /// Performs a flood fill at the specified point. + /// Returns the created path, or null if fill wasn't possible. + /// + public PenPath? FloodFillAt(SKPoint clickPoint, SKColor fillColor) + { + if (CanvasSize == Size.Empty) + return null; + + var x = (int)clickPoint.X; + var y = (int)clickPoint.Y; + + // Bounds check + if (x < 0 || x >= CanvasSize.Width || y < 0 || y >= CanvasSize.Height) + return null; + + // Get the current state of the canvas on CPU to avoid GPU context threading issues ("Could not allocate vertices") + // and to ensure we don't accidentally fill the checkerboard pattern. + using var sourceBitmap = GetFlattenedContentBitmap(); + var targetColor = sourceBitmap.GetPixel(x, y); + + // Don't fill if clicking on the same color (with some tolerance for anti-aliasing) + if (ColorsAreSimilar(targetColor, fillColor, tolerance: 30)) + return null; + + // Create a surface for drawing the fill result + using var surface = SKSurface.Create( + new SKImageInfo(CanvasSize.Width, CanvasSize.Height, SKColorType.Rgba8888, SKAlphaType.Premul) + ); + var canvas = surface.Canvas; + canvas.Clear(SKColors.Transparent); + + // Perform flood fill and draw horizontal spans + var hasContent = ScanlineFillWithCanvas(sourceBitmap, canvas, x, y, targetColor, fillColor); + + if (!hasContent) + { + return null; + } + + // Copy the surface to the bitmap + canvas.Flush(); + using var filledImage = surface.Snapshot(); + + // Create a new bitmap with the filled content + var resultBitmap = new SKBitmap( + CanvasSize.Width, + CanvasSize.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + using var resultCanvas = new SKCanvas(resultBitmap); + resultCanvas.DrawImage(filledImage, 0, 0); + resultCanvas.Flush(); + + // Create a bitmap path with the fill result + var fillPath = new PenPath + { + PathType = PenPathType.Bitmap, + FillColor = fillColor, + BitmapData = resultBitmap, + Bounds = new SKRect(0, 0, CanvasSize.Width, CanvasSize.Height), + }; + + Paths = Paths.Add(fillPath); + ClearRedoStack(); // New path added, clear redo history + InvalidatePathCache(); + RefreshCanvas?.Invoke(); + + return fillPath; + } + + /// + /// Generates a flattened bitmap of the current canvas content (Layers + Paths). + /// Runs strictly on CPU to avoid GPU threading/context issues during Flood Fill. + /// Ignores checkerboard background to ensure correct filling of transparent areas. + /// + private SKBitmap GetFlattenedContentBitmap() + { + var width = CanvasSize.Width; + var height = CanvasSize.Height; + var bitmap = new SKBitmap(width, height, SKColorType.Rgba8888, SKAlphaType.Premul); + + using var canvas = new SKCanvas(bitmap); + canvas.Clear(SKColors.Transparent); + + // Draw all layers in order + foreach (var (name, layer) in Layers) + { + lock (layer) + { + foreach (var layerBitmap in layer.Bitmaps) + { + canvas.DrawBitmap(layerBitmap, 0, 0); + } + + // If this is the active brush layer, also render the active vector paths + // We render them freshly here on CPU to avoid using the GPU-backed cache from a different thread + if (name == "Brush") + { + using var paint = new SKPaint(); + foreach (var path in Paths) + { + RenderPenPath(canvas, path, paint); + } + } + } + } + + canvas.Flush(); + return bitmap; + } + + /// + /// Scanline flood fill that draws horizontal spans to an SKCanvas. + /// Returns true if any pixels were filled. + /// + private static bool ScanlineFillWithCanvas( + SKBitmap source, + SKCanvas canvas, + int startX, + int startY, + SKColor targetColor, + SKColor fillColor + ) + { + var width = source.Width; + var height = source.Height; + + // Use SKBitmap.Pixels which is platform-agnostic (returns SKColor[]) + var sourcePixels = source.Pixels; + + var visited = new bool[width * height]; + var queue = new Queue<(int x, int y)>(); + queue.Enqueue((startX, startY)); + + // Collect horizontal spans to draw + var spans = new List<(int y, int left, int right)>(); + + // Increased tolerance to better catch anti-aliased edges + const int Tolerance = 50; + // Increased expansion to ensuring we fully cover the semi-transparent border pixels + const float Expand = 1.5f; + + using var paint = new SKPaint + { + Color = fillColor, + Style = SKPaintStyle.Fill, + IsAntialias = true, // Smooth edges for the dilated rects + BlendMode = SKBlendMode.Src, // Replace mode prevents alpha buildup on overlapping dilated scanlines + }; + + while (queue.Count > 0) + { + var (x, y) = queue.Dequeue(); + + // Bounds check + if (x < 0 || x >= width || y < 0 || y >= height) + continue; + + var index = y * width + x; + if (visited[index]) + continue; + + var pixel = sourcePixels[index]; + if (!ColorsAreSimilar(pixel, targetColor, tolerance: Tolerance)) + continue; + + // Mark as visited + visited[index] = true; + + // Scanline approach: find the entire horizontal span + var left = x; + var right = x; + + // Extend left + while (left > 0) + { + var leftIndex = y * width + (left - 1); + if (visited[leftIndex]) + break; + var leftPixel = sourcePixels[leftIndex]; + if (!ColorsAreSimilar(leftPixel, targetColor, tolerance: Tolerance)) + break; + left--; + visited[leftIndex] = true; + } + + // Extend right + while (right < width - 1) + { + var rightIndex = y * width + (right + 1); + if (visited[rightIndex]) + break; + var rightPixel = sourcePixels[rightIndex]; + if (!ColorsAreSimilar(rightPixel, targetColor, tolerance: Tolerance)) + break; + right++; + visited[rightIndex] = true; + } + + // Draw this span as a filled rectangle with slight expansion + // Using DrawRect with float coordinates allows sub-pixel expansion + canvas.DrawRect( + left - Expand, + y - Expand, + (right - left + 1) + (Expand * 2), + 1 + (Expand * 2), + paint + ); + + // Queue pixels above and below the span + for (var i = left; i <= right; i++) + { + if (y > 0) + { + var aboveIndex = (y - 1) * width + i; + if (!visited[aboveIndex]) + { + var abovePixel = sourcePixels[aboveIndex]; + if (ColorsAreSimilar(abovePixel, targetColor, tolerance: Tolerance)) + queue.Enqueue((i, y - 1)); + } + } + + if (y < height - 1) + { + var belowIndex = (y + 1) * width + i; + if (!visited[belowIndex]) + { + var belowPixel = sourcePixels[belowIndex]; + if (ColorsAreSimilar(belowPixel, targetColor, tolerance: Tolerance)) + queue.Enqueue((i, y + 1)); + } + } + } + } + + // Check if anything was filled (at least one visited pixel) + foreach (var v in visited) + { + if (v) + return true; + } + + return false; + } + + private static bool ColorsAreSimilar(SKColor a, SKColor b, int tolerance) + { + // Handle transparent pixels specially + if (a.Alpha < 10 && b.Alpha < 10) + return true; + if (a.Alpha < 10 || b.Alpha < 10) + return Math.Abs(a.Alpha - b.Alpha) <= tolerance; + + return Math.Abs(a.Red - b.Red) <= tolerance + && Math.Abs(a.Green - b.Green) <= tolerance + && Math.Abs(a.Blue - b.Blue) <= tolerance + && Math.Abs(a.Alpha - b.Alpha) <= tolerance; + } + + #endregion + + public SKImage? RenderToWhiteChannelImage() + { + using var _ = CodeTimer.StartDebug(); + + if (CanvasSize == Size.Empty) + { + logger.LogWarning($"RenderToImage: {nameof(CanvasSize)} is not set, returning null."); + return null; + } + + using var surface = SKSurface.Create(new SKImageInfo(CanvasSize.Width, CanvasSize.Height)); + + RenderToSurface(surface); + + using var originalImage = surface.Snapshot(); + // Replace all colors to white (255, 255, 255), keep original alpha + // csharpier-ignore + using var colorFilter = SKColorFilter.CreateColorMatrix( + [ + // R, G, B, A, Bias + -1, 0, 0, 0, 255, + 0, -1, 0, 0, 255, + 0, 0, -1, 0, 255, + 0, 0, 0, 1, 0 + ] + ); + + using var paint = new SKPaint(); + paint.ColorFilter = colorFilter; + + surface.Canvas.Clear(SKColors.Transparent); + surface.Canvas.DrawImage(originalImage, originalImage.Info.Rect, paint); + + return surface.Snapshot(); + } + + public SKImage? RenderToImage() + { + using var _ = CodeTimer.StartDebug(); + + if (CanvasSize == Size.Empty) + { + logger.LogWarning($"RenderToImage: {nameof(CanvasSize)} is not set, returning null."); + return null; + } + + using var surface = SKSurface.Create(new SKImageInfo(CanvasSize.Width, CanvasSize.Height)); + + RenderToSurface(surface); + + return surface.Snapshot(); + } + + /// + /// Extracts masks for multiple colors in a single render pass. + /// More efficient than calling ExtractMaskByColor multiple times. + /// + /// The colors to extract masks for. + /// RGB tolerance for color matching (0-255). Default 10. + /// A dictionary mapping each color to its mask image. + public Dictionary ExtractMasksByColors( + IReadOnlyList targetColors, + int tolerance = 10 + ) + { + using var _ = CodeTimer.StartDebug(); + + var results = new Dictionary(); + + if (CanvasSize == Size.Empty || targetColors.Count == 0) + return results; + + // Render canvas once + using var renderedImage = RenderToImage(); + if (renderedImage is null) + return results; + + using var sourceBitmap = SKBitmap.FromImage(renderedImage); + var srcPixels = sourceBitmap.Pixels; // SKColor[] array - fast direct access + var pixelCount = srcPixels.Length; + + // Create result bitmaps for each color + var resultBitmaps = new Dictionary(); + var resultPixels = new Dictionary(); + foreach (var color in targetColors) + { + var bitmap = new SKBitmap( + sourceBitmap.Width, + sourceBitmap.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + resultBitmaps[color] = bitmap; + resultPixels[color] = new SKColor[pixelCount]; + } + + // Single pass through pixels, check all colors + for (var i = 0; i < pixelCount; i++) + { + var pixel = srcPixels[i]; + + foreach (var targetColor in targetColors) + { + var matches = + Math.Abs(pixel.Red - targetColor.Red) <= tolerance + && Math.Abs(pixel.Green - targetColor.Green) <= tolerance + && Math.Abs(pixel.Blue - targetColor.Blue) <= tolerance + && pixel.Alpha > 0; + + resultPixels[targetColor][i] = matches ? SKColors.White : SKColors.Transparent; + } + } + + // Set pixels and convert bitmaps to images + foreach (var (color, bitmap) in resultBitmaps) + { + bitmap.Pixels = resultPixels[color]; + results[color] = SKImage.FromBitmap(bitmap); + bitmap.Dispose(); + } + + return results; + } + + /// + /// Extracts a mask from the canvas where pixels match the target color. + /// Returns a grayscale mask where white = match, transparent = no match. + /// Used for regional prompting to separate painted regions by color. + /// + /// The color to extract. + /// RGB tolerance for color matching (0-255). Default 10. + /// A mask image, or null if canvas is empty. + public SKImage? ExtractMaskByColor(SKColor targetColor, int tolerance = 10) + { + using var _ = CodeTimer.StartDebug(); + + if (CanvasSize == Size.Empty) + { + logger.LogWarning($"ExtractMaskByColor: {nameof(CanvasSize)} is not set, returning null."); + return null; + } + + // First render the canvas to get the painted image + using var renderedImage = RenderToImage(); + if (renderedImage is null) + return null; + + using var bitmap = SKBitmap.FromImage(renderedImage); + var resultBitmap = new SKBitmap( + bitmap.Width, + bitmap.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + + // Use Pixels array for fast direct access + var srcPixels = bitmap.Pixels; + var dstPixels = new SKColor[srcPixels.Length]; + + for (var i = 0; i < srcPixels.Length; i++) + { + var pixel = srcPixels[i]; + + // Check if pixel matches target color within tolerance + var matches = + Math.Abs(pixel.Red - targetColor.Red) <= tolerance + && Math.Abs(pixel.Green - targetColor.Green) <= tolerance + && Math.Abs(pixel.Blue - targetColor.Blue) <= tolerance + && pixel.Alpha > 0; + + dstPixels[i] = matches ? SKColors.White : SKColors.Transparent; + } + + resultBitmap.Pixels = dstPixels; + return SKImage.FromBitmap(resultBitmap); + } + + /// + /// Gets all unique colors present in the painted canvas (excluding transparent). + /// Used for regional prompting to detect which colors the user has painted. + /// + /// A list of unique colors found in the canvas. + public IReadOnlyList GetPaintedColors() + { + // Default palette colors to match against + return GetPaintedColors( + [ + new SKColor(255, 0, 0), // Red + new SKColor(255, 128, 0), // Orange + new SKColor(255, 255, 0), // Yellow + new SKColor(0, 255, 0), // Green + new SKColor(0, 128, 255), // Blue + new SKColor(128, 0, 255), // Purple + ] + ); + } + + /// + /// Gets a list of palette colors that have been painted on the canvas. + /// Uses tolerance matching to handle anti-aliased edges. + /// + /// The palette colors to match against. + /// RGB tolerance for color matching (default 40 to handle anti-aliasing). + /// A list of palette colors that were found in the canvas. + public IReadOnlyList GetPaintedColors(IReadOnlyList paletteColors, int tolerance = 40) + { + if (CanvasSize == Size.Empty) + return []; + + using var renderedImage = RenderToImage(); + if (renderedImage is null) + return []; + + using var bitmap = SKBitmap.FromImage(renderedImage); + var foundPaletteColors = new HashSet(); + + // Use Pixels array for fast direct access + var pixels = bitmap.Pixels; + var paletteCount = paletteColors.Count; + + foreach (var pixel in pixels) + { + if (pixel.Alpha < 128) // Skip mostly transparent pixels + continue; + + // Find the closest palette color + for (var p = 0; p < paletteCount; p++) + { + var paletteColor = paletteColors[p]; + if (!ColorMatchesWithTolerance(pixel, paletteColor, tolerance)) + continue; + + foundPaletteColors.Add(paletteColor); + + // Early exit if we've found all palette colors + if (foundPaletteColors.Count == paletteCount) + return foundPaletteColors.ToList(); + + break; + } + } + + return foundPaletteColors.ToList(); + } + + /// + /// Checks if two colors match within the specified RGB tolerance. + /// + private static bool ColorMatchesWithTolerance(SKColor a, SKColor b, int tolerance) + { + return Math.Abs(a.Red - b.Red) <= tolerance + && Math.Abs(a.Green - b.Green) <= tolerance + && Math.Abs(a.Blue - b.Blue) <= tolerance; + } + + public void RenderToSurface( + SKSurface surface, + bool renderBackgroundFill = false, + bool renderBackgroundImage = false + ) + { + var grContext = surface.Context; + var useGpu = UseGpuAcceleration && grContext != null; + IsUsingGpu = useGpu; + + // Initialize canvas layers + foreach (var layer in Layers.Values) + { + lock (layer) + { + var needsNewSurface = layer.Surface is null; + if (!needsNewSurface) + { + // Check if we need to resize + var currentInfo = layer.Surface!.Canvas.DeviceClipBounds; + needsNewSurface = + currentInfo.Width != CanvasSize.Width || currentInfo.Height != CanvasSize.Height; + } + + if (needsNewSurface) + { + // Dispose old surface if exists + layer.Surface?.Dispose(); + + var imageInfo = new SKImageInfo(CanvasSize.Width, CanvasSize.Height); + + // Try GPU surface first if available + if (useGpu) + { + layer.Surface = SKSurface.Create(grContext!, budgeted: true, imageInfo); + + // Fallback to CPU if GPU surface creation failed + if (layer.Surface is null) + { + if (LogRenderingMode) + { + logger.LogWarning( + "GPU surface creation failed, falling back to CPU for layer" + ); + } + layer.Surface = SKSurface.Create(imageInfo); + } + else if (LogRenderingMode) + { + logger.LogDebug("Created GPU-accelerated surface for layer"); + } + } + else + { + layer.Surface = SKSurface.Create(imageInfo); + if (LogRenderingMode) + { + logger.LogDebug("Created CPU surface for layer (GPU not available or disabled)"); + } + } + } + else + { + // No resize needed, just clear + layer.Surface!.Canvas.Clear(SKColors.Transparent); + } + } + } + + // Render all layer images in order + foreach (var (layerName, layer) in Layers) + { + // Skip background image if not requested + if (!renderBackgroundImage && layerName == "Background") + { + continue; + } + + lock (layer) + { + var layerCanvas = layer.Surface!.Canvas; + foreach (var bitmap in layer.Bitmaps) + { + layerCanvas.DrawBitmap(bitmap, new SKPoint(0, 0)); + } + } + } + + // Render paint layer with caching optimization + RenderPathsWithCaching(BrushLayer.Surface!.Canvas); + + // Draw background - either checkerboard for transparency or clear + // Draw background - either checkerboard for transparency or clear + // Include check for renderBackgroundFill so snapshots (like FloodFill analysis) can skip the checkerboard pattern + if (ShowCheckerboardBackground && renderBackgroundFill) + { + RenderCheckerboardBackground(surface.Canvas); + } + else + { + surface.Canvas.Clear(SKColors.Transparent); + } + + // Draw the layers to the main surface + foreach (var layer in Layers.Values) + { + lock (layer) + { + layer.Surface!.Canvas.Flush(); + surface.Canvas.DrawSurface(layer.Surface!, new SKPoint(0, 0)); + } + } + + // Draw grid overlay if enabled + if (ShowGridOverlay) + { + RenderGridOverlay(surface.Canvas); + } + + surface.Canvas.Flush(); + } + + /// + /// Renders a checkerboard pattern to indicate transparent areas. + /// Uses a cached shader for efficient repeated rendering. + /// + private void RenderCheckerboardBackground(SKCanvas canvas) + { + // Check if we need to create or recreate the shader + if (cachedCheckerboardShader is null || cachedCheckerboardSize != CanvasSize) + { + cachedCheckerboardShader?.Dispose(); + cachedCheckerboardShader = CreateCheckerboardShader(); + cachedCheckerboardSize = CanvasSize; + } + + using var paint = new SKPaint(); + paint.Shader = cachedCheckerboardShader; + paint.IsAntialias = false; + + canvas.DrawRect(0, 0, CanvasSize.Width, CanvasSize.Height, paint); + } + + /// + /// Creates a checkerboard pattern shader using a small tiled bitmap. + /// + private static SKShader CreateCheckerboardShader() + { + // Create a small 2x2 checker bitmap (in units of square size) + var tileSize = CheckerboardSquareSize * 2; + using var tileBitmap = new SKBitmap(tileSize, tileSize); + using var tileCanvas = new SKCanvas(tileBitmap); + + // Draw the four squares + using var lightPaint = new SKPaint { Color = CheckerboardLight }; + using var darkPaint = new SKPaint { Color = CheckerboardDark }; + + // Top-left and bottom-right are light + tileCanvas.DrawRect(0, 0, CheckerboardSquareSize, CheckerboardSquareSize, lightPaint); + tileCanvas.DrawRect( + CheckerboardSquareSize, + CheckerboardSquareSize, + CheckerboardSquareSize, + CheckerboardSquareSize, + lightPaint + ); + + // Top-right and bottom-left are dark + tileCanvas.DrawRect( + CheckerboardSquareSize, + 0, + CheckerboardSquareSize, + CheckerboardSquareSize, + darkPaint + ); + tileCanvas.DrawRect( + 0, + CheckerboardSquareSize, + CheckerboardSquareSize, + CheckerboardSquareSize, + darkPaint + ); + + tileCanvas.Flush(); + + // Create a shader that tiles this bitmap + return SKShader.CreateBitmap(tileBitmap, SKShaderTileMode.Repeat, SKShaderTileMode.Repeat); + } + + /// + /// Renders a grid overlay for alignment assistance (e.g., rule of thirds). + /// + private void RenderGridOverlay(SKCanvas canvas) + { + if (GridDivisions <= 1 || CanvasSize == Size.Empty) + return; + + using var paint = new SKPaint + { + Color = GridLineColor, + IsAntialias = true, + Style = SKPaintStyle.Stroke, + StrokeWidth = 1f, + }; + + var width = CanvasSize.Width; + var height = CanvasSize.Height; + + // Draw vertical lines + for (var i = 1; i < GridDivisions; i++) + { + var x = (float)(width * i) / GridDivisions; + canvas.DrawLine(x, 0, x, height, paint); + } + + // Draw horizontal lines + for (var i = 1; i < GridDivisions; i++) + { + var y = (float)(height * i) / GridDivisions; + canvas.DrawLine(0, y, width, y, paint); + } + } + + /// + /// Renders paths with caching optimization. Completed paths are cached + /// to avoid re-rendering them every frame. + /// + private void RenderPathsWithCaching(SKCanvas paintLayerCanvas) + { + var currentPathCount = Paths.Count; + var hasTemporaryPaths = !TemporaryPaths.IsEmpty; + + // Check if we can use the cached image + if (cachedPathsImage != null && cachedPathsCount == currentPathCount && !hasTemporaryPaths) + { + // All paths are cached and no temporary paths - just draw the cached image + paintLayerCanvas.DrawImage(cachedPathsImage, new SKPoint(0, 0)); + return; + } + + // Check if we need to update the cache (new completed paths) + if (cachedPathsCount < currentPathCount && !hasTemporaryPaths) + { + // Render all completed paths to a new cached image + UpdatePathCache(); + + if (cachedPathsImage != null) + { + paintLayerCanvas.DrawImage(cachedPathsImage, new SKPoint(0, 0)); + return; + } + } + + // Fallback: render with partial caching + using var paint = new SKPaint(); + + // If we have a cache, draw it first + if (cachedPathsImage != null && cachedPathsCount > 0) + { + paintLayerCanvas.DrawImage(cachedPathsImage, new SKPoint(0, 0)); + + // Only render paths that aren't in the cache + for (var i = cachedPathsCount; i < currentPathCount; i++) + { + RenderPenPath(paintLayerCanvas, Paths[i], paint); + } + } + else + { + // No cache, render all paths + foreach (var penPath in Paths) + { + RenderPenPath(paintLayerCanvas, penPath, paint); + } + } + + // Render temporary paths directly (the batched RenderPenPath is already optimized) + foreach (var penPath in TemporaryPaths.Values) + { + RenderPenPath(paintLayerCanvas, penPath, paint); + } + } + + /// + /// Renders temporary paths with incremental caching for long strokes. + /// Only new points since last render are drawn, dramatically improving + /// performance for continuous drawing. + /// + private void RenderTemporaryPathsIncremental(SKCanvas targetCanvas, SKPaint paint) + { + if (TemporaryPaths.IsEmpty) + { + // No temporary paths - dispose surface if exists + if (tempPathSurface != null) + { + tempPathSurface.Dispose(); + tempPathSurface = null; + tempPathRenderedPoints.Clear(); + } + return; + } + + // For simplicity and reliability, use a hybrid approach: + // - Keep a cached surface for the "already rendered" portions + // - Render new points directly to target canvas (which gets composited) + + // Ensure we have a temp surface + var needNewSurface = tempPathSurface == null; + if (!needNewSurface) + { + var bounds = tempPathSurface!.Canvas.DeviceClipBounds; + needNewSurface = bounds.Width != CanvasSize.Width || bounds.Height != CanvasSize.Height; + } + + if (needNewSurface) + { + tempPathSurface?.Dispose(); + var imageInfo = new SKImageInfo(CanvasSize.Width, CanvasSize.Height); + + // Use CPU surface for temp paths to avoid GPU context threading issues + tempPathSurface = SKSurface.Create(imageInfo); + tempPathSurface?.Canvas.Clear(SKColors.Transparent); + tempPathRenderedPoints.Clear(); + } + + if (tempPathSurface == null) + { + // Fallback: render all temp paths directly + foreach (var penPath in TemporaryPaths.Values) + { + RenderPenPath(targetCanvas, penPath, paint); + } + return; + } + + var tempCanvas = tempPathSurface.Canvas; + + // Check if any paths were removed (stroke finalized) - need to clear and rebuild + var pathsRemoved = false; + foreach (var pointerId in tempPathRenderedPoints.Keys.ToArray()) + { + if (!TemporaryPaths.ContainsKey(pointerId)) + { + pathsRemoved = true; + tempPathRenderedPoints.TryRemove(pointerId, out _); + } + } + + if (pathsRemoved) + { + // A stroke was finalized - clear the temp surface + tempCanvas.Clear(SKColors.Transparent); + tempPathRenderedPoints.Clear(); + } + + // Render each temporary path + foreach (var (pointerId, penPath) in TemporaryPaths) + { + var renderedCount = tempPathRenderedPoints.GetValueOrDefault(pointerId, 0); + var totalPoints = penPath.Points.Count; + + if (totalPoints > renderedCount) + { + if (renderedCount == 0) + { + // New path - render everything to the temp surface + RenderPenPath(tempCanvas, penPath, paint); + } + else + { + // Continuing path - render new segment to temp surface + RenderPenPathSegment(tempCanvas, penPath, renderedCount, totalPoints, paint); + } + tempPathRenderedPoints[pointerId] = totalPoints; + } + } + + // Draw the temp surface to target + tempCanvas.Flush(); + using var tempImage = tempPathSurface.Snapshot(); + targetCanvas.DrawImage(tempImage, new SKPoint(0, 0)); + } + + /// + /// Renders a segment of a pen path (from startIndex to endIndex). + /// Used for incremental rendering of temporary paths. + /// + private static void RenderPenPathSegment( + SKCanvas canvas, + PenPath penPath, + int startIndex, + int endIndex, + SKPaint paint + ) + { + if (startIndex >= endIndex || penPath.Points.Count == 0) + return; // Apply Color if (penPath.IsErase) { - // paint.BlendMode = SKBlendMode.SrcIn; paint.BlendMode = SKBlendMode.Clear; paint.Color = SKColors.Transparent; } @@ -308,61 +1553,467 @@ private static void RenderPenPath(SKCanvas canvas, PenPath penPath, SKPaint pain paint.Color = penPath.FillColor; } - // Defaults paint.IsDither = true; paint.IsAntialias = true; + paint.Style = SKPaintStyle.Stroke; + paint.StrokeCap = SKStrokeCap.Round; + paint.StrokeJoin = SKStrokeJoin.Round; + + // Apply feathering (soft brush edge) using blur mask filter + if (penPath.Feathering > 0) + { + var effectiveRadiusForBlur = penPath.GetEffectiveRadius(); + var blurSigma = effectiveRadiusForBlur * penPath.Feathering * 0.5f; + if (blurSigma > 0.1f) + { + paint.MaskFilter = SKMaskFilter.CreateBlur(SKBlurStyle.Normal, blurSigma); + } + } + else + { + paint.MaskFilter = null; + } - // Track if we have any pen points - var hasPenPoints = false; + using var path = new SKPath(); + var started = false; + var currentThickness = 0f; - // Can't use foreach since this list may be modified during iteration - // ReSharper disable once ForCanBeConvertedToForeach - for (var i = 0; i < penPath.Points.Count; i++) + // Start from one point before to ensure continuity + var actualStart = Math.Max(0, startIndex - 1); + + var effectiveRadius = penPath.GetEffectiveRadius(); + + for (var i = actualStart; i < endIndex && i < penPath.Points.Count; i++) { - var penPoint = penPath.Points[i]; + var point = penPath.Points[i]; + if (!point.IsPen) + continue; - // Skip non-pen points - if (!penPoint.IsPen) + var thickness = (float)((point.Pressure ?? 1) * effectiveRadius * 2.5); + + if (!started) { - continue; + path.MoveTo(point.X, point.Y); + currentThickness = thickness; + started = true; + } + else + { + path.LineTo(point.X, point.Y); + currentThickness = (currentThickness + thickness) / 2; } + } + + if (started) + { + paint.StrokeWidth = currentThickness; + canvas.DrawPath(path, paint); + } + } + + /// + /// Clears the temporary path cache. Call when a stroke is finalized. + /// + public void ClearTempPathCache() + { + tempPathSurface?.Dispose(); + tempPathSurface = null; + tempPathRenderedPoints.Clear(); + } + + /// + /// Updates the path cache with all current completed paths. + /// Uses CPU-only surfaces to avoid GPU context threading issues. + /// + private void UpdatePathCache() + { + if (CanvasSize == Size.Empty || Paths.Count == 0) + { + cachedPathsImage?.Dispose(); + cachedPathsImage = null; + cachedPathsCount = 0; + return; + } + + var imageInfo = new SKImageInfo(CanvasSize.Width, CanvasSize.Height); + + // Always use CPU surface for cache to avoid GPU context threading issues + // The cache is created once per set of completed paths, so CPU performance is acceptable + var cacheSurface = SKSurface.Create(imageInfo); - hasPenPoints = true; + if (cacheSurface == null) + { + logger.LogWarning("Failed to create cache surface"); + return; + } + + using (cacheSurface) + { + var cacheCanvas = cacheSurface.Canvas; + cacheCanvas.Clear(SKColors.Transparent); - var radius = penPoint.Radius; - var pressure = penPoint.Pressure ?? 1; - var thickness = pressure * radius * 2.5; + using var paint = new SKPaint(); - // Draw path - if (i < penPath.Points.Count - 1) + // Render all completed paths + foreach (var penPath in Paths) { - paint.Style = SKPaintStyle.Stroke; - paint.StrokeWidth = (float)thickness; - paint.StrokeCap = SKStrokeCap.Round; - paint.StrokeJoin = SKStrokeJoin.Round; + RenderPenPath(cacheCanvas, penPath, paint); + } - var nextPoint = penPath.Points[i + 1]; - canvas.DrawLine(penPoint.X, penPoint.Y, nextPoint.X, nextPoint.Y, paint); + // Save the cached image + cachedPathsImage?.Dispose(); + cachedPathsImage = cacheSurface.Snapshot(); + cachedPathsCount = Paths.Count; + + if (LogRenderingMode) + { + logger.LogDebug("Updated path cache with {Count} paths (CPU surface)", cachedPathsCount); } + } + } - // Draw circles for pens - paint.Style = SKPaintStyle.Fill; - canvas.DrawCircle(penPoint.X, penPoint.Y, (float)thickness / 2, paint); + /// + /// Renders a pen path to a canvas. This method is public so it can be shared + /// with other ViewModels like LayeredMaskEditorViewModel. + /// Optimized to batch draw calls into a single SKPath for performance. + /// + /// If provided, uses this color instead of the path's FillColor. Useful for mask export. + public static void RenderPenPath( + SKCanvas canvas, + PenPath penPath, + SKPaint paint, + SKColor? overrideColor = null + ) + { + // Handle shape path types (Rectangle, Ellipse, Bitmap) + switch (penPath.PathType) + { + case PenPathType.Rectangle: + case PenPathType.Ellipse: + RenderShapePath(canvas, penPath, paint, overrideColor); + return; + + case PenPathType.Bitmap: + RenderBitmapPath(canvas, penPath, paint, overrideColor); + return; + + case PenPathType.Freehand: + default: + // Continue with freehand rendering below + RenderFreehandPath(canvas, penPath, paint, overrideColor); + return; } + } - // Draw paths directly if we didn't have any pen points - if (!hasPenPoints) + /// + /// Renders shape paths (Rectangle and Ellipse) to the canvas. + /// + private static void RenderShapePath( + SKCanvas canvas, + PenPath penPath, + SKPaint paint, + SKColor? overrideColor + ) + { + // Apply color and blend mode + if (penPath.IsErase) + { + paint.BlendMode = SKBlendMode.Clear; + paint.Color = SKColors.Transparent; + } + else { - var point = penPath.Points[0]; - var thickness = point.Radius * 2; + paint.BlendMode = SKBlendMode.SrcOver; + paint.Color = overrideColor ?? penPath.FillColor; + } + + paint.IsDither = true; + paint.IsAntialias = true; + if (penPath.IsStrokeOnly) + { paint.Style = SKPaintStyle.Stroke; - paint.StrokeWidth = (float)thickness; - paint.StrokeCap = SKStrokeCap.Round; - paint.StrokeJoin = SKStrokeJoin.Round; + paint.StrokeWidth = penPath.StrokeWidth; + } + else + { + paint.Style = SKPaintStyle.Fill; + } + + if (penPath.PathType == PenPathType.Rectangle) + { + canvas.DrawRect(penPath.Bounds, paint); + } + else // Ellipse + { + canvas.DrawOval(penPath.Bounds, paint); + } + } + + /// + /// Renders bitmap paths to the canvas with optional color override. + /// + private static void RenderBitmapPath( + SKCanvas canvas, + PenPath penPath, + SKPaint paint, + SKColor? overrideColor + ) + { + if (penPath.BitmapData == null) + return; + + if (overrideColor.HasValue) + { + // Apply color filter to replace colors with override while keeping alpha + var color = overrideColor.Value; + using var colorPaint = new SKPaint(); + // Color matrix that replaces RGB with override color, preserves alpha + // csharpier-ignore + colorPaint.ColorFilter = SKColorFilter.CreateColorMatrix( + [ + 0, 0, 0, 0, color.Red / 255f, + 0, 0, 0, 0, color.Green / 255f, + 0, 0, 0, 0, color.Blue / 255f, + 0, 0, 0, 1, 0 + ]); + canvas.DrawBitmap(penPath.BitmapData, penPath.Bounds.Left, penPath.Bounds.Top, colorPaint); + } + else + { + canvas.DrawBitmap(penPath.BitmapData, penPath.Bounds.Left, penPath.Bounds.Top); + } + } + + /// + /// Renders freehand paths with pressure-sensitive strokes to the canvas. + /// + private static void RenderFreehandPath( + SKCanvas canvas, + PenPath penPath, + SKPaint paint, + SKColor? overrideColor = null + ) + { + // Freehand path rendering + if (penPath.Points.Count == 0) + { + return; + } + + // Apply Color + if (penPath.IsErase) + { + paint.BlendMode = SKBlendMode.Clear; + paint.Color = SKColors.Transparent; + } + else + { + paint.BlendMode = SKBlendMode.SrcOver; + paint.Color = overrideColor ?? penPath.FillColor; + } + + // Setup paint for strokes + paint.IsDither = true; + paint.IsAntialias = true; + paint.Style = SKPaintStyle.Stroke; + paint.StrokeCap = SKStrokeCap.Round; // Round caps handle endpoints + paint.StrokeJoin = SKStrokeJoin.Round; + + // Apply feathering (soft brush edge) using blur mask filter + if (penPath.Feathering > 0) + { + // Calculate blur sigma based on the effective radius and feathering amount + var effectiveRadiusForBlur = penPath.GetEffectiveRadius(); + var blurSigma = effectiveRadiusForBlur * penPath.Feathering * 0.5f; + if (blurSigma > 0.1f) + { + paint.MaskFilter = SKMaskFilter.CreateBlur(SKBlurStyle.Normal, blurSigma); + } + } + else + { + paint.MaskFilter = null; + } + + // Count pen points and check pressure uniformity in a single pass (avoids LINQ allocations) + var penPointCount = 0; + var uniformPressure = true; + var firstPressure = 0.0; + var totalThickness = 0.0; + var firstPenPointIndex = -1; + + // Get effective radius (path-level or backward-compat from first point) + var effectiveRadius = penPath.GetEffectiveRadius(); + + for (var i = 0; i < penPath.Points.Count; i++) + { + var p = penPath.Points[i]; + if (!p.IsPen) + continue; + + var pressure = p.Pressure ?? 1; + var thickness = pressure * effectiveRadius * 2.5; + + if (penPointCount == 0) + { + firstPressure = pressure; + firstPenPointIndex = i; + } + else if (uniformPressure && Math.Abs(pressure - firstPressure) >= 0.1) + { + uniformPressure = false; + } + + totalThickness += thickness; + penPointCount++; + } + if (penPointCount == 0) + { + // No pen points - use the ToSKPath method for mouse-based paths + paint.StrokeWidth = effectiveRadius * 2; var skPath = penPath.ToSKPath(); canvas.DrawPath(skPath, paint); + return; + } + + // For pressure-sensitive drawing, we need to handle variable thickness + if (penPointCount == 1) + { + // Single point - draw a circle + var point = penPath.Points[firstPenPointIndex]; + var thickness = (point.Pressure ?? 1) * effectiveRadius * 2.5; + paint.Style = SKPaintStyle.Fill; + canvas.DrawCircle(point.X, point.Y, (float)(thickness / 2), paint); + return; + } + + if (uniformPressure) + { + // All points have similar pressure - batch into single path + var avgThickness = totalThickness / penPointCount; + paint.StrokeWidth = (float)avgThickness; + + using var path = new SKPath(); + var started = false; + + // Use plain loop instead of LINQ to avoid iterator allocation in hot path + foreach (var p in penPath.Points) + { + if (!p.IsPen) + continue; + + if (!started) + { + path.MoveTo(p.X, p.Y); + started = true; + } + else + { + path.LineTo(p.X, p.Y); + } + } + + canvas.DrawPath(path, paint); + } + else + { + // Variable pressure - draw segments with varying thickness + // Batch into groups of similar thickness for fewer draw calls + using var path = new SKPath(); + var currentThickness = 0f; + var pathStarted = false; + var lastPenX = 0f; + var lastPenY = 0f; + + foreach (var point in penPath.Points) + { + if (!point.IsPen) + continue; + + var thickness = (float)((point.Pressure ?? 1) * effectiveRadius * 2.5); + + // If thickness changed significantly, draw current path and start new one + if (pathStarted && Math.Abs(thickness - currentThickness) > currentThickness * 0.2f) + { + paint.StrokeWidth = currentThickness; + canvas.DrawPath(path, paint); + path.Reset(); + + // Start new path from previous point for continuity + path.MoveTo(lastPenX, lastPenY); + pathStarted = false; + } + + if (!pathStarted) + { + path.MoveTo(point.X, point.Y); + currentThickness = thickness; + pathStarted = true; + } + else + { + path.LineTo(point.X, point.Y); + // Smoothly blend thickness + currentThickness = (currentThickness + thickness) / 2; + } + + lastPenX = point.X; + lastPenY = point.Y; + } + + // Draw remaining path + if (pathStarted) + { + paint.StrokeWidth = currentThickness; + canvas.DrawPath(path, paint); + } + } + } + + /// + /// Disposes all cached resources to free memory. + /// + public void Dispose() + { + if (_disposed) + return; + + _disposed = true; + + // Dispose cached path image + cachedPathsImage?.Dispose(); + cachedPathsImage = null; + + // Dispose temporary path surface + tempPathSurface?.Dispose(); + tempPathSurface = null; + tempPathRenderedPoints.Clear(); + + // Dispose checkerboard shader + cachedCheckerboardShader?.Dispose(); + cachedCheckerboardShader = null; + + // Dispose layer surfaces and bitmaps + foreach (var layer in Layers.Values) + { + lock (layer) + { + layer.Surface?.Dispose(); + layer.Surface = null; + + foreach (var bitmap in layer.Bitmaps) + { + bitmap.Dispose(); + } + layer.Bitmaps = []; + } } + + // Clear paths + TemporaryPaths.Clear(); + + GC.SuppressFinalize(this); } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadMissingModelsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadMissingModelsViewModel.cs new file mode 100644 index 000000000..8e1c9fafd --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadMissingModelsViewModel.cs @@ -0,0 +1,271 @@ +using System.Collections.ObjectModel; +using Avalonia.Controls; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; +using Injectio.Attributes; +using Microsoft.Extensions.Logging; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +/// +/// Reusable dialog view model for downloading missing models. +/// Can be configured for any provider that needs model downloads. +/// +[View(typeof(DownloadMissingModelsDialog))] +[ManagedService] +[RegisterTransient] +public partial class DownloadMissingModelsViewModel( + ILogger logger, + ISettingsManager settingsManager, + ITrackedDownloadService trackedDownloadService, + IDownloadService downloadService +) : ContentDialogViewModelBase +{ + /// + /// Dialog title (e.g., "Flux Kontext Setup") + /// + [ObservableProperty] + public partial string DialogTitle { get; set; } = "Download Required Models"; + + /// + /// Friendly description message + /// + [ObservableProperty] + public partial string Description { get; set; } = + "The following models are required. Select the ones you'd like to download."; + + /// + /// Collection of downloadable model items + /// + public ObservableCollection Models { get; } = []; + + /// + /// Whether file sizes are being loaded + /// + [ObservableProperty] + public partial bool IsLoadingSizes { get; set; } + + /// + /// Number of selected items + /// + public int SelectedCount => Models.Count(m => m.IsSelected); + + /// + /// Total size of selected items + /// + public string TotalSelectedSizeText + { + get + { + var totalBytes = Models.Where(m => m.IsSelected).Sum(m => m.FileSize); + return totalBytes > 0 ? Size.FormatBase10Bytes(totalBytes) : "Calculating..."; + } + } + + /// + /// Whether download can be started + /// + public bool CanStartDownload => SelectedCount > 0; + + /// + /// The downloads that were started (populated after StartDownloadsAsync is called) + /// + public List StartedDownloads { get; } = []; + + /// + /// Set the models to display in the dialog + /// + public void SetModels(IEnumerable resources) + { + Models.Clear(); + + foreach (var resource in resources) + { + var item = new DownloadableModelItemViewModel(resource); + item.PropertyChanged += (s, e) => + { + if (e.PropertyName == nameof(DownloadableModelItemViewModel.IsSelected)) + { + OnPropertyChanged(nameof(SelectedCount)); + OnPropertyChanged(nameof(TotalSelectedSizeText)); + OnPropertyChanged(nameof(CanStartDownload)); + } + }; + Models.Add(item); + } + + // Load file sizes asynchronously + _ = LoadFileSizesAsync(); + } + + private async Task LoadFileSizesAsync() + { + if (Design.IsDesignMode) + return; + + IsLoadingSizes = true; + + try + { + var tasks = Models.Select(async model => + { + try + { + if (model.Resource.Url is { } url) + { + var size = await downloadService.GetFileSizeAsync(url.ToString()); + await Dispatcher.UIThread.InvokeAsync(() => + { + model.FileSize = size; + }); + } + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to get file size for {FileName}", model.FileName); + } + }); + + await Task.WhenAll(tasks); + } + finally + { + IsLoadingSizes = false; + OnPropertyChanged(nameof(TotalSelectedSizeText)); + } + } + + [RelayCommand] + private void SelectAll() + { + foreach (var model in Models) + { + model.IsSelected = true; + } + } + + [RelayCommand] + private void DeselectAll() + { + foreach (var model in Models) + { + model.IsSelected = false; + } + } + + /// + /// Queue downloads for all selected models. Returns the list of started downloads. + /// Call this after dialog closes with Primary result. + /// + public async Task> StartDownloadsAsync() + { + var selectedModels = Models.Where(m => m.IsSelected).ToList(); + StartedDownloads.Clear(); + + if (selectedModels.Count == 0) + { + return StartedDownloads; + } + + logger.LogInformation("Queueing download of {Count} models", selectedModels.Count); + + foreach (var model in selectedModels) + { + try + { + var download = await QueueDownloadAsync(model); + if (download != null) + { + StartedDownloads.Add(download); + } + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to queue download for {FileName}", model.FileName); + } + } + + // Show progress flyout + if (StartedDownloads.Count > 0) + { + EventManager.Instance.OnToggleProgressFlyout(); + } + + return StartedDownloads; + } + + private async Task QueueDownloadAsync(DownloadableModelItemViewModel model) + { + var resource = model.Resource; + + var sharedFolderType = + resource.ContextType as SharedFolderType? + ?? throw new InvalidOperationException( + $"ContextType is not SharedFolderType for {resource.FileName}" + ); + + var modelsDir = new DirectoryPath(settingsManager.ModelsDirectory).JoinDir( + sharedFolderType.GetStringValue() + ); + + if (resource.RelativeDirectory is not null) + { + modelsDir = modelsDir.JoinDir(resource.RelativeDirectory); + } + + // Ensure directory exists + modelsDir.Create(); + + var downloadPath = modelsDir.JoinFile(resource.FileName); + + logger.LogInformation("Queueing download: {FileName} to {Path}", resource.FileName, downloadPath); + + var download = trackedDownloadService.NewDownload(resource.Url, downloadPath); + + // Set hash for verification if available + if (resource.HashSha256 is not null) + { + download.ExpectedHashSha256 = resource.HashSha256; + } + + // Set extraction properties + download.AutoExtractArchive = resource.AutoExtractArchive; + download.ExtractRelativePath = resource.ExtractRelativePath; + + // Set context action for post-download processing + download.ContextAction = new ModelPostDownloadContextAction(); + + // Start the download + await trackedDownloadService.TryStartDownload(download); + + return download; + } + + public override BetterContentDialog GetDialog() + { + var dialog = base.GetDialog(); + + dialog.Title = DialogTitle; + dialog.Content = new DownloadMissingModelsDialog { DataContext = this }; + dialog.PrimaryButtonText = Resources.Action_Download; + dialog.CloseButtonText = "Skip for Now"; + dialog.DefaultButton = ContentDialogButton.Primary; + dialog.IsPrimaryButtonEnabled = CanStartDownload; + dialog.MinDialogWidth = 550; + + return dialog; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadableModelItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadableModelItemViewModel.cs new file mode 100644 index 000000000..265806722 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/DownloadableModelItemViewModel.cs @@ -0,0 +1,140 @@ +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Progress; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +/// +/// ViewModel for a single downloadable model item in the missing models dialog. +/// Wraps a RemoteResource with selection and progress state. +/// +public partial class DownloadableModelItemViewModel(RemoteResource resource, string? displayName = null) + : ViewModelBase +{ + /// + /// The underlying remote resource + /// + public RemoteResource Resource { get; } = resource; + + /// + /// Whether this item is selected for download + /// + [ObservableProperty] + public partial bool IsSelected { get; set; } = true; + + /// + /// Whether this item is currently downloading + /// + [ObservableProperty] + public partial bool IsDownloading { get; set; } + + /// + /// Whether this item has completed downloading + /// + [ObservableProperty] + public partial bool IsCompleted { get; set; } + + /// + /// Whether this item failed to download + /// + [ObservableProperty] + public partial bool IsFailed { get; set; } + + /// + /// Current download progress (0-100) + /// + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(ProgressText))] + public partial double Progress { get; set; } + + /// + /// File size in bytes (fetched asynchronously) + /// + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(FileSizeText))] + public partial long FileSize { get; set; } + + /// + /// Status message for the download + /// + [ObservableProperty] + public partial string? StatusMessage { get; set; } + + /// + /// Display name for the model + /// + public string DisplayName { get; } = displayName ?? GetDefaultDisplayName(resource); + + /// + /// Type badge text (e.g., "UNET", "VAE", "CLIP") + /// + public string TypeBadge { get; } = GetTypeBadge(resource); + + /// + /// File name + /// + public string FileName => Resource.FileName; + + /// + /// Formatted file size text + /// + public string? FileSizeText => FileSize > 0 ? Size.FormatBase10Bytes(FileSize) : null; + + /// + /// Progress text for display + /// + public string ProgressText => IsDownloading ? $"{Progress:F0}%" : string.Empty; + + /// + /// Author of the model + /// + public string? Author => Resource.Author; + + /// + /// License type + /// + public string? LicenseType => Resource.LicenseType; + + // Determine display name based on context type or filename + + private static string GetDefaultDisplayName(RemoteResource resource) + { + // Try to get a friendly name based on the file and context + var fileName = resource.FileName; + + return resource.ContextType switch + { + SharedFolderType.DiffusionModels + when fileName.Contains("kontext", StringComparison.OrdinalIgnoreCase) => "Flux Kontext UNET", + SharedFolderType.VAE when fileName.Equals("ae.safetensors", StringComparison.OrdinalIgnoreCase) => + "Flux VAE", + SharedFolderType.TextEncoders + when fileName.Contains("clip_l", StringComparison.OrdinalIgnoreCase) => "CLIP-L Text Encoder", + SharedFolderType.TextEncoders + when fileName.Contains("t5xxl", StringComparison.OrdinalIgnoreCase) => "T5-XXL Text Encoder", + _ => Path.GetFileNameWithoutExtension(fileName), + }; + } + + private static string GetTypeBadge(RemoteResource resource) + { + return resource.ContextType switch + { + SharedFolderType.DiffusionModels => "UNET", + SharedFolderType.VAE => "VAE", + SharedFolderType.TextEncoders => "CLIP", + SharedFolderType.ControlNet => "ControlNet", + SharedFolderType.Lora or SharedFolderType.LyCORIS => "LoRA", + _ => resource.ContextType?.ToString() ?? "Model", + }; + } + + [RelayCommand] + private void ToggleSelection() + { + IsSelected = !IsSelected; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageAnnotationEditorViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageAnnotationEditorViewModel.cs new file mode 100644 index 000000000..e4b99435d --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ImageAnnotationEditorViewModel.cs @@ -0,0 +1,222 @@ +using System.Runtime.CompilerServices; +using System.Text.Json.Serialization; +using Avalonia; +using Avalonia.Controls.Primitives; +using Avalonia.Media.Imaging; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using Injectio.Attributes; +using SkiaSharp; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Controls; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using ContentDialogButton = FluentAvalonia.UI.Controls.ContentDialogButton; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +/// +/// ViewModel for the image annotation editor dialog. +/// Allows users to draw/annotate on images before sending to AI providers. +/// +[RegisterTransient] +[ManagedService] +[View(typeof(ImageAnnotationEditorDialog))] +public partial class ImageAnnotationEditorViewModel(IServiceManager vmFactory) + : LoadableViewModelBase, + IDisposable +{ + [JsonIgnore] + private SKBitmap? originalBitmap; + + [JsonIgnore] + private ImageSource? cachedAnnotatedImage; + + /// + /// The source image file path being edited + /// + [ObservableProperty] + private string? sourceFilePath; + + /// + /// The paint canvas view model for drawing annotations + /// + [JsonInclude] + public PaintCanvasViewModel PaintCanvasViewModel { get; } = vmFactory.Get(); + + /// + /// Whether there are any annotations on the canvas + /// + public bool HasAnnotations => PaintCanvasViewModel.Paths.Count > 0; + + /// + /// Load an image from file path for editing + /// + public void LoadImage(string filePath) + { + SourceFilePath = filePath; + originalBitmap?.Dispose(); + originalBitmap = SKBitmap.Decode(filePath); + + if (originalBitmap != null) + { + PaintCanvasViewModel.BackgroundImage = originalBitmap; + PaintCanvasViewModel.RefreshCanvas?.Invoke(); + } + } + + /// + /// Load an image from bitmap for editing + /// + public void LoadImage(Bitmap bitmap, string? sourcePath = null) + { + SourceFilePath = sourcePath; + originalBitmap?.Dispose(); + + // Convert Avalonia Bitmap to SKBitmap + using var stream = new MemoryStream(); + bitmap.Save(stream); + stream.Position = 0; + originalBitmap = SKBitmap.Decode(stream); + + if (originalBitmap != null) + { + PaintCanvasViewModel.BackgroundImage = originalBitmap; + PaintCanvasViewModel.RefreshCanvas?.Invoke(); + } + } + + /// + /// Get the annotated image with drawings overlaid on the original + /// + [MethodImpl(MethodImplOptions.Synchronized)] + public ImageSource? GetAnnotatedImage() + { + if (cachedAnnotatedImage != null) + { + return cachedAnnotatedImage; + } + + using var skImage = RenderAnnotatedImage(); + if (skImage == null) + { + return null; + } + + cachedAnnotatedImage = new ImageSource(skImage.ToAvaloniaBitmap()); + return cachedAnnotatedImage; + } + + /// + /// Render the annotated image to an SKImage + /// + public SKImage? RenderAnnotatedImage() + { + var canvasSize = PaintCanvasViewModel.CanvasSize; + if (canvasSize.IsEmpty) + { + return null; + } + + using var surface = SKSurface.Create(new SKImageInfo(canvasSize.Width, canvasSize.Height)); + PaintCanvasViewModel.RenderToSurface( + surface, + renderBackgroundFill: false, + renderBackgroundImage: true + ); + + return surface.Snapshot(); + } + + /// + /// Save the annotated image to a file + /// + public async Task SaveAnnotatedImageAsync(string? targetPath = null) + { + using var image = RenderAnnotatedImage(); + if (image == null) + { + return null; + } + + // Generate target path if not provided + targetPath ??= Path.Combine(Path.GetTempPath(), $"annotated_{Guid.NewGuid():N}.png"); + + using var data = image.Encode(SKEncodedImageFormat.Png, 100); + await using var fileStream = File.OpenWrite(targetPath); + data.SaveTo(fileStream); + + return targetPath; + } + + /// + /// Get the annotated image as a byte array (PNG format) + /// + public byte[]? GetAnnotatedImageBytes() + { + using var image = RenderAnnotatedImage(); + if (image == null) + { + return null; + } + + using var data = image.Encode(SKEncodedImageFormat.Png, 100); + return data.ToArray(); + } + + /// + /// Invalidate the cached annotated image + /// + public void InvalidateCache() + { + cachedAnnotatedImage?.Dispose(); + cachedAnnotatedImage = null; + } + + /// + /// Clear all annotations from the canvas + /// + [RelayCommand] + public void ClearAnnotations() + { + PaintCanvasViewModel.Paths = []; + PaintCanvasViewModel.RefreshCanvas?.Invoke(); + InvalidateCache(); + } + + /// + /// Create and show the editor dialog + /// + public BetterContentDialog GetDialog() + { + Dispatcher.UIThread.VerifyAccess(); + + var dialog = new BetterContentDialog + { + Content = this, + ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled, + MaxDialogHeight = 900, + MaxDialogWidth = 1200, + ContentMargin = new Thickness(16), + FullSizeDesired = true, + PrimaryButtonText = Resources.Action_Save, + CloseButtonText = Resources.Action_Cancel, + DefaultButton = ContentDialogButton.Primary, + }; + + return dialog; + } + + public void Dispose() + { + originalBitmap?.Dispose(); + cachedAnnotatedImage?.Dispose(); + GC.SuppressFinalize(this); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/LayeredMaskEditorViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LayeredMaskEditorViewModel.cs new file mode 100644 index 000000000..b62a1f104 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/LayeredMaskEditorViewModel.cs @@ -0,0 +1,1907 @@ +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.ComponentModel; +using System.Text.Json.Nodes; +using Avalonia; +using Avalonia.Controls.Primitives; +using Avalonia.Platform.Storage; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using DynamicData; +using DynamicData.Binding; +using Injectio.Attributes; +using Microsoft.Extensions.Logging; +using SkiaSharp; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Controls.Models; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Controls; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Services; +using ContentDialogButton = FluentAvalonia.UI.Controls.ContentDialogButton; +using Size = System.Drawing.Size; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +/// +/// ViewModel for the layered mask editor dialog. +/// Manages multiple layers with independent masks, prompts, and opacity settings. +/// +[RegisterTransient] +[ManagedService] +[View(typeof(LayeredMaskEditorDialog))] +public partial class LayeredMaskEditorViewModel : LoadableViewModelBase, IDisposable +{ + private readonly IImageIndexService imageIndexService; + private readonly ILogger logger; + private readonly IServiceManager vmFactory; + + /// + /// Canvas size for all layers. + /// + [ObservableProperty] + private Size canvasSize = new(1024, 1024); + + /// + /// Previous canvas size, used for rescaling layers when dimensions change. + /// + private Size _previousCanvasSize = new(1024, 1024); + + private int imageLayerCounter; + + /// + /// Stack of layer snapshots for undo support. + /// Each entry captures the full layer state before a destructive operation. + /// + private readonly Stack layerUndoStack = new(); + + /// + /// Maximum number of undo snapshots to keep. + /// + private const int MaxUndoSnapshots = 20; + + /// + /// Whether the recent images panel is expanded. + /// + [ObservableProperty] + private bool isRecentImagesPanelExpanded; + + /// + /// Counter to suppress layer index change callbacks from programmatic list updates. + /// This keeps drag-drop callbacks from fighting keyboard and button reorders. + /// + private int layerIndexChangeSuppressionCount; + + private int layerCounter; + + /// + /// Cached bitmap for the currently selected image layer. + /// Invalidated when source image, scale, opacity, offset, flip, or canvas size changes. + /// + private SKBitmap? _cachedImageLayerBitmap; + private MaskLayer? _cachedImageLayerSource; + private SKBitmap? _cachedImageLayerSourceImage; + private double _cachedImageLayerScale; + private double _cachedImageLayerOpacity; + private double _cachedImageLayerOffsetX; + private double _cachedImageLayerOffsetY; + private bool _cachedImageLayerFlipH; + private bool _cachedImageLayerFlipV; + private Size _cachedImageLayerCanvasSize; + + /// + /// The currently selected layer for editing. + /// + [ObservableProperty] + [NotifyCanExecuteChangedFor(nameof(DeleteLayerCommand))] + private MaskLayer? selectedLayer; + + /// + /// When true, shows all layers composited on the canvas. + /// When false, shows only the selected layer. + /// + [ObservableProperty] + private bool showAllLayers = true; + + public LayeredMaskEditorViewModel( + IServiceManager vmFactory, + IImageIndexService imageIndexService, + ILogger logger + ) + { + this.vmFactory = vmFactory; + this.imageIndexService = imageIndexService; + this.logger = logger; + PaintCanvasViewModel = vmFactory.Get(); + + // Set up Move tool callback to update image layer offsets + PaintCanvasViewModel.OnMoveToolDrag = (newOffsetX, newOffsetY) => + { + if (SelectedLayer is { LayerType: MaskLayerType.Image }) + { + SelectedLayer.ImageOffsetX = newOffsetX; + SelectedLayer.ImageOffsetY = newOffsetY; + SyncSelectedLayerToCanvas(); + } + }; + + // Provide current offset when starting a move + PaintCanvasViewModel.GetCurrentMoveOffset = () => + { + if (SelectedLayer is { LayerType: MaskLayerType.Image }) + { + return (SelectedLayer.ImageOffsetX, SelectedLayer.ImageOffsetY); + } + return (0, 0); + }; + + // Subscribe to recent images + imageIndexService + .InferenceImages.ItemsSource.Connect() + .DeferUntilLoaded() + .SortBy(file => file.LastModifiedAt, SortDirection.Descending) + .Top(50) // Limit to 50 most recent + .Bind(LocalImages) + .Subscribe(); + + // Initialize with one layer + AddLayer(); + } + + /// + /// The collection of layers in the editor (ordered from bottom to top). + /// + public ObservableCollection Layers { get; } = []; + + /// + /// The paint canvas view model for the currently selected layer. + /// + public PaintCanvasViewModel PaintCanvasViewModel { get; } + + /// + /// Collection of recent inference images for quick selection. + /// + public IObservableCollection LocalImages { get; } = + new ObservableCollectionExtended(); + + /// + public void Dispose() + { + // Clean up all layers + foreach (var layer in Layers) + CleanupLayer(layer); + Layers.Clear(); + + // Dispose cached image layer bitmap + _cachedImageLayerBitmap?.Dispose(); + _cachedImageLayerBitmap = null; + _cachedImageLayerSource = null; + _cachedImageLayerSourceImage = null; + + // Dispose the paint canvas view model + PaintCanvasViewModel.Dispose(); + + GC.SuppressFinalize(this); + } + + /// + /// Captures a snapshot of the current layer state for undo. + /// Call this before any destructive layer operation. + /// + private void PushLayerUndoSnapshot() + { + // Save current canvas paths to the layer first + SaveCurrentLayerPaths(); + + // Serialize each layer's state + var layerStates = Layers.Select(l => l.SaveStateToJsonObject()).ToList(); + var selectedIndex = SelectedLayer is not null ? Layers.IndexOf(SelectedLayer) : -1; + + layerUndoStack.Push(new LayerSnapshot(layerStates, selectedIndex, layerCounter, imageLayerCounter)); + + // Trim to max size + if (layerUndoStack.Count > MaxUndoSnapshots) + { + // Rebuild stack without the oldest entry + var items = layerUndoStack.ToArray(); + layerUndoStack.Clear(); + for (var i = items.Length - 2; i >= 0; i--) + layerUndoStack.Push(items[i]); + } + + UndoLayerOperationCommand.NotifyCanExecuteChanged(); + } + + /// + /// Undoes the last destructive layer operation by restoring a snapshot. + /// + [RelayCommand(CanExecute = nameof(CanUndoLayerOperation))] + private void UndoLayerOperation() + { + if (layerUndoStack.Count == 0) + return; + + var snapshot = layerUndoStack.Pop(); + + RunWithLayerIndexChangeSuppressed(() => + { + // Clear current layers + SelectedLayer = null; + foreach (var layer in Layers) + CleanupLayer(layer); + Layers.Clear(); + + // Restore counters + layerCounter = snapshot.LayerCounter; + imageLayerCounter = snapshot.ImageLayerCounter; + + // Restore layers from snapshot + foreach (var layerState in snapshot.LayerStates) + { + var layer = new MaskLayer(); + layer.LoadStateFromJsonObject(layerState); + layer.PropertyChanged += Layer_PropertyChanged; + Layers.Add(layer); + } + + // Restore selection + if (snapshot.SelectedLayerIndex >= 0 && snapshot.SelectedLayerIndex < Layers.Count) + SelectedLayer = Layers[snapshot.SelectedLayerIndex]; + else if (Layers.Count > 0) + SelectedLayer = Layers[0]; + + // Reload image layer bitmaps from paths + ReloadImageLayersFromPaths(); + + SyncSelectedLayerToCanvas(); + }); + + UndoLayerOperationCommand.NotifyCanExecuteChanged(); + DeleteLayerCommand.NotifyCanExecuteChanged(); + } + + private bool CanUndoLayerOperation() => layerUndoStack.Count > 0; + + /// + /// Snapshot of the full layer state for undo support. + /// + private sealed record LayerSnapshot( + List LayerStates, + int SelectedLayerIndex, + int LayerCounter, + int ImageLayerCounter + ); + + /// + public override async Task OnLoadedAsync() + { + await base.OnLoadedAsync(); + + // Refresh the image index to populate recent images + await imageIndexService.RefreshIndexForAllCollections(); + } + + /// + /// Adds a new layer on top of the stack. + /// + [RelayCommand] + private void AddLayer() + { + layerCounter++; + var layer = new MaskLayer + { + Name = $"Layer {layerCounter}", + DisplayColor = MaskLayerColors.GetByIndex(layerCounter - 1), + }; + + // Subscribe to layer property changes to refresh canvas + layer.PropertyChanged += Layer_PropertyChanged; + + Layers.Add(layer); + SelectedLayer = layer; + SyncSelectedLayerToCanvas(); + } + + /// + /// Adds a new image layer on top of the stack. + /// + [RelayCommand] + private void AddImageLayer() + { + imageLayerCounter++; + var layer = new MaskLayer + { + Name = $"Image {imageLayerCounter}", + LayerType = MaskLayerType.Image, + DisplayColor = new SKColor(128, 128, 128), // Gray for image layers + }; + + // Subscribe to layer property changes to refresh canvas + layer.PropertyChanged += Layer_PropertyChanged; + + Layers.Add(layer); + SelectedLayer = layer; + + // Expand the recent images panel when adding an image layer + IsRecentImagesPanelExpanded = true; + + SyncSelectedLayerToCanvas(); + } + + /// + /// Centers the selected image layer by resetting its offset to (0, 0). + /// + [RelayCommand] + private void CenterImageLayer(MaskLayer? target = null) + { + var layer = target ?? SelectedLayer; + if (layer is not { LayerType: MaskLayerType.Image }) + return; + + layer.ImageOffsetX = 0; + layer.ImageOffsetY = 0; + SyncSelectedLayerToCanvas(); + } + + /// + /// Toggles horizontal flip for the selected image layer. + /// + [RelayCommand] + private void FlipImageHorizontally(MaskLayer? target = null) + { + var layer = target ?? SelectedLayer; + if (layer is not { LayerType: MaskLayerType.Image }) + return; + + layer.IsFlippedHorizontally = !layer.IsFlippedHorizontally; + SyncSelectedLayerToCanvas(); + } + + /// + /// Toggles vertical flip for the selected image layer. + /// + [RelayCommand] + private void FlipImageVertically(MaskLayer? target = null) + { + var layer = target ?? SelectedLayer; + if (layer is not { LayerType: MaskLayerType.Image }) + return; + + layer.IsFlippedVertically = !layer.IsFlippedVertically; + SyncSelectedLayerToCanvas(); + } + + /// + /// Scales the selected image layer to fit within the canvas bounds while maintaining aspect ratio. + /// + [RelayCommand] + private void FitImageToCanvas(MaskLayer? target = null) + { + var layer = target ?? SelectedLayer; + if (layer is not { LayerType: MaskLayerType.Image, SourceImage: { } sourceImage }) + return; + + if (CanvasSize == Size.Empty) + return; + + // Calculate scale to fit image within canvas (maintaining aspect ratio) + var scaleX = (double)CanvasSize.Width / sourceImage.Width; + var scaleY = (double)CanvasSize.Height / sourceImage.Height; + var fitScale = Math.Min(scaleX, scaleY); + + // Clamp to valid range (0.1 to 3.0) + fitScale = Math.Clamp(fitScale, 0.1, 3.0); + + layer.ImageScale = fitScale; + layer.ImageOffsetX = 0; + layer.ImageOffsetY = 0; + SyncSelectedLayerToCanvas(); + } + + /// + /// Selects an image from the recent images panel for the current image layer. + /// + [RelayCommand] + private async Task SelectImageFromRecent(LocalImageFile? imageFile) + { + if (imageFile is null || SelectedLayer is null) + return; + + // If selected layer is not an image layer, create a new one + if (SelectedLayer.LayerType != MaskLayerType.Image) + AddImageLayer(); + + await LoadImageIntoLayerAsync(SelectedLayer!, imageFile.AbsolutePath); + } + + /// + /// Opens a file picker to select an image for the current image layer. + /// + [RelayCommand] + private async Task BrowseImageForLayer() + { + var files = await App.StorageProvider.OpenFilePickerAsync( + new FilePickerOpenOptions + { + Title = "Select Reference Image", + AllowMultiple = false, + FileTypeFilter = + [ + new FilePickerFileType("Images") + { + Patterns = ["*.png", "*.jpg", "*.jpeg", "*.webp", "*.bmp"], + }, + ], + } + ); + + if (files.Count == 0 || files[0].TryGetLocalPath() is not { } path) + return; + + // If no layer selected or current layer is paint, create a new image layer + if (SelectedLayer is null || SelectedLayer.LayerType != MaskLayerType.Image) + AddImageLayer(); + + await LoadImageIntoLayerAsync(SelectedLayer!, path); + } + + /// + /// Loads an image from the given path into the specified layer. + /// + private async Task LoadImageIntoLayerAsync(MaskLayer layer, string imagePath) + { + if (layer.LayerType != MaskLayerType.Image) + return; + + try + { + // Load bitmap on background thread + var bitmap = await Task.Run(() => + { + using var stream = File.OpenRead(imagePath); + return SKBitmap.Decode(stream); + }); + + if (bitmap is null) + return; + + // Dispose old bitmap + layer.SourceImage?.Dispose(); + + // Store the path and bitmap + layer.SourceImagePath = imagePath; + layer.SourceImage = bitmap; + + // Refresh canvas (SourceImage property change also triggers Layer_PropertyChanged) + SyncSelectedLayerToCanvas(); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to load image into layer from path {ImagePath}", imagePath); + } + } + + private void Layer_PropertyChanged(object? sender, PropertyChangedEventArgs e) + { + var changedLayer = sender as MaskLayer; + + // Handle color change: save canvas paths first, recolor them, update brush + if (e.PropertyName == nameof(MaskLayer.DisplayColor) && changedLayer == SelectedLayer) + { + // Save current canvas paths to layer first (so we don't lose any strokes) + SaveCurrentLayerPaths(); + + // Recolor the saved paths with the new color + if (changedLayer?.Paths.Count > 0) + { + var newColor = changedLayer.DisplayColor; + var recoloredPaths = changedLayer + .Paths.Select(p => + p.IsErase ? p : p with { FillColor = newColor.WithAlpha(p.FillColor.Alpha) } + ) + .ToImmutableList(); + changedLayer.Paths = recoloredPaths; + } + + // Update brush color for new strokes + PaintCanvasViewModel.PaintBrushColor = changedLayer?.AvaloniaDisplayColor; + + // Sync the recolored paths back to canvas + SyncSelectedLayerToCanvas(); + return; + } + + // Refresh canvas when visibility, opacity, paths, lock, or image scale changes + if ( + e.PropertyName + is nameof(MaskLayer.IsVisible) + or nameof(MaskLayer.Opacity) + or nameof(MaskLayer.ImageScale) + or nameof(MaskLayer.SourceImage) + or nameof(MaskLayer.Paths) + or nameof(MaskLayer.IsLocked) + ) + { + // Save paths before sync, but handle visibility toggle specially: + // - When toggling OFF (IsVisible is now false): canvas has paths, SAVE them + // - When toggling ON (IsVisible is now true): canvas was empty, DON'T save + // - For color changes: skip save since MaskLayer itself updates Paths + if ( + changedLayer == SelectedLayer + && e.PropertyName == nameof(MaskLayer.IsVisible) + && changedLayer!.IsVisible + ) + { + // Toggling ON - skip save (canvas was empty while hidden) + } + else if (e.PropertyName == nameof(MaskLayer.Paths)) + { + // Paths change from layer update (e.g., other layer's color change), just refresh + } + else + { + // All other cases: save paths + // Force save if we are toggling off the selected layer (it's hidden now, but canvas has valid paths) + var force = changedLayer == SelectedLayer && e.PropertyName == nameof(MaskLayer.IsVisible); + SaveCurrentLayerPaths(force); + } + + SyncSelectedLayerToCanvas(); + } + } + + /// + /// Refreshes the canvas composite. Call after drawing to update layer order. + /// + public void RefreshComposite() + { + SyncSelectedLayerToCanvas(); + } + + /// + /// Clears the content (paths) of the specified layer without deleting it. + /// + [RelayCommand] + private void ClearLayerContent(MaskLayer? layer) + { + layer ??= SelectedLayer; + if (layer is null || layer.LayerType == MaskLayerType.Image) + return; + + PushLayerUndoSnapshot(); + + // If this is the selected layer, clear the canvas paths too + if (layer == SelectedLayer) + PaintCanvasViewModel.Paths = []; + + layer.Paths = []; + SyncSelectedLayerToCanvas(); + } + + /// + /// Deletes the specified layer, or the selected layer if none specified. + /// + [RelayCommand(CanExecute = nameof(CanDeleteLayer))] + private void DeleteLayer(MaskLayer? target = null) + { + var layerToRemove = target ?? SelectedLayer; + if (layerToRemove is null || Layers.Count <= 1) + return; + + PushLayerUndoSnapshot(); + + RunWithLayerIndexChangeSuppressed(() => + { + var index = Layers.IndexOf(layerToRemove); + + // Unsubscribe and dispose before removing + CleanupLayer(layerToRemove); + Layers.Remove(layerToRemove); + + // Select adjacent layer if we removed the selected one + if (SelectedLayer is null || !Layers.Contains(SelectedLayer)) + { + if (Layers.Count > 0) + { + SelectedLayer = Layers[Math.Min(index, Layers.Count - 1)]; + } + else + { + SelectedLayer = null; + PaintCanvasViewModel.Paths = []; + } + } + + SyncSelectedLayerToCanvas(); + PaintCanvasViewModel.RefreshCanvas?.Invoke(); + }); + } + + /// + /// Unsubscribes event handlers and disposes resources for a layer. + /// + private void CleanupLayer(MaskLayer layer) + { + layer.PropertyChanged -= Layer_PropertyChanged; + layer.SourceImage?.Dispose(); + } + + private bool CanDeleteLayer() + { + return SelectedLayer is not null && Layers.Count > 1; + } + + /// + /// Moves the specified layer (or selected layer if null) up in the list (toward top of list = drawn ON TOP of + /// others). + /// + [RelayCommand] + private void MoveLayerUp(MaskLayer? layer) + { + layer ??= SelectedLayer; + if (layer is null || Layers.Count <= 1) + return; + + var index = Layers.IndexOf(layer); + if (index > 0) + { + SaveCurrentLayerPaths(); + RunWithLayerIndexChangeSuppressed(() => + { + Layers.Move(index, index - 1); + SelectedLayer = layer; + SyncSelectedLayerToCanvas(); + }); + } + } + + /// + /// Moves the specified layer (or selected layer if null) down in the list (toward bottom of list = drawn UNDER + /// others). + /// + [RelayCommand] + private void MoveLayerDown(MaskLayer? layer) + { + layer ??= SelectedLayer; + if (layer is null || Layers.Count <= 1) + return; + + var index = Layers.IndexOf(layer); + if (index >= 0 && index < Layers.Count - 1) + { + SaveCurrentLayerPaths(); + RunWithLayerIndexChangeSuppressed(() => + { + Layers.Move(index, index + 1); + SelectedLayer = layer; + SyncSelectedLayerToCanvas(); + }); + } + } + + private void RunWithLayerIndexChangeSuppressed(Action action) + { + layerIndexChangeSuppressionCount++; + try + { + action(); + } + finally + { + layerIndexChangeSuppressionCount = Math.Max(0, layerIndexChangeSuppressionCount - 1); + } + } + + /// + /// Handles layer index changes from drag-drop reordering in the UI. + /// Called by the View when a layer is dropped at a new position. + /// + /// The layer that was moved. + /// The new index where the layer was dropped. + public void OnLayerIndexChanged(MaskLayer layer, int newIndex) + { + if (layerIndexChangeSuppressionCount > 0) + return; + + var currentIndex = Layers.IndexOf(layer); + if (currentIndex < 0 || currentIndex == newIndex) + return; + + if (newIndex < 0 || newIndex >= Layers.Count) + return; + + // Save current layer paths before moving + SaveCurrentLayerPaths(); + + RunWithLayerIndexChangeSuppressed(() => + { + // Move the layer to the new position + Layers.Move(currentIndex, newIndex); + + // Refresh the canvas to reflect the new layer order + SyncSelectedLayerToCanvas(); + }); + } + + /// + /// Fills the selected layer with a rectangle covering the entire canvas. + /// + [RelayCommand] + private void FillLayer(MaskLayer? target = null) + { + var layer = target ?? SelectedLayer; + if (layer is null || layer.LayerType != MaskLayerType.Paint) + return; + + // Create a rectangle path covering the entire canvas + var fillPath = new PenPath + { + PathType = PenPathType.Rectangle, + FillColor = layer.DisplayColor, + Bounds = new SKRect(0, 0, CanvasSize.Width, CanvasSize.Height), + }; + + // Add to current paths + layer.Paths = layer.Paths.Add(fillPath); + SyncSelectedLayerToCanvas(); + } + + /// + /// Clears all layers and creates a fresh empty layer. + /// + [RelayCommand] + private void ClearAllLayers() + { + PushLayerUndoSnapshot(); + + RunWithLayerIndexChangeSuppressed(() => + { + // Clear all layers + SelectedLayer = null; + foreach (var layer in Layers) + CleanupLayer(layer); + Layers.Clear(); + layerCounter = 0; + + // Add a fresh layer + AddLayer(); + }); + } + + /// + /// Inverts the selected layer's mask by creating a full-canvas fill + /// and setting the existing paths to erase mode. + /// + [RelayCommand] + private void InvertLayer(MaskLayer? target = null) + { + var layer = target ?? SelectedLayer; + if (layer is null || layer.LayerType != MaskLayerType.Paint) + return; + + PushLayerUndoSnapshot(); + + var currentPaths = layer.Paths; + + // If no paths, fill the entire canvas + if (currentPaths.Count == 0) + { + FillLayer(layer); + return; + } + + // Create a full canvas fill as the base + var fullFill = new PenPath + { + PathType = PenPathType.Rectangle, + FillColor = layer.DisplayColor, + Bounds = new SKRect(0, 0, CanvasSize.Width, CanvasSize.Height), + }; + + // Convert existing paths to erase mode + var erasePaths = currentPaths.Select(p => p with { IsErase = !p.IsErase }).ToList(); + + // Rebuild paths: full fill first, then inverted paths + var newPaths = ImmutableList.Create(fullFill).AddRange(erasePaths); + layer.Paths = newPaths; + + SyncSelectedLayerToCanvas(); + } + + /// + /// Expands all layer details. + /// + [RelayCommand] + private void ExpandAllLayers() + { + foreach (var layer in Layers) + layer.IsExpanded = true; + } + + /// + /// Collapses all layer details. + /// + [RelayCommand] + private void CollapseAllLayers() + { + foreach (var layer in Layers) + layer.IsExpanded = false; + } + + #region Quick Division Presets + + /// + /// Creates layers for a quick division preset. + /// Preserves prompts and settings from existing layers where possible. + /// + /// Array of (left, top, right, bottom) fractions (0.0-1.0) for each region. + /// Optional names for each region layer. + private void CreateQuickDivisionLayers(SKRect[] divisions, string[]? names = null) + { + if (CanvasSize == Size.Empty) + return; + + PushLayerUndoSnapshot(); + + // Save current layer paths before modifying + SaveCurrentLayerPaths(); + + // Always capture existing layer settings so we can preserve them + // This preserves prompts when going to more layers, equal layers, or fewer layers + var existingSettings = Layers + .Select(l => (l.Prompt, l.NegativePrompt, l.Strength, l.ConditioningArea, l.Opacity, l.IsEnabled)) + .ToList(); + + RunWithLayerIndexChangeSuppressed(() => + { + // Clear existing layers + SelectedLayer = null; + foreach (var layer in Layers) + CleanupLayer(layer); + Layers.Clear(); + layerCounter = 0; + + // Create new layers for each division + for (var i = 0; i < divisions.Length; i++) + { + layerCounter++; + var layer = new MaskLayer + { + Name = names != null && i < names.Length ? names[i] : $"Region {layerCounter}", + DisplayColor = MaskLayerColors.GetByIndex(layerCounter - 1), + }; + + // Restore settings from existing layers if available + // This preserves prompts from the first N existing layers + if (i < existingSettings.Count) + { + var settings = existingSettings[i]; + layer.Prompt = settings.Prompt; + layer.NegativePrompt = settings.NegativePrompt; + layer.Strength = settings.Strength; + layer.ConditioningArea = settings.ConditioningArea; + layer.Opacity = settings.Opacity; + layer.IsEnabled = settings.IsEnabled; + } + + // Calculate the actual pixel bounds from fractions + var fractionalRect = divisions[i]; + var pixelRect = new SKRect( + fractionalRect.Left * CanvasSize.Width, + fractionalRect.Top * CanvasSize.Height, + fractionalRect.Right * CanvasSize.Width, + fractionalRect.Bottom * CanvasSize.Height + ); + + // Create a filled rectangle for this region + var fillPath = new PenPath + { + PathType = PenPathType.Rectangle, + FillColor = layer.DisplayColor, + Bounds = pixelRect, + }; + layer.Paths = layer.Paths.Add(fillPath); + + // Subscribe to layer property changes + layer.PropertyChanged += Layer_PropertyChanged; + + Layers.Add(layer); + } + + // Select the first layer + if (Layers.Count > 0) + SelectedLayer = Layers[0]; + + SyncSelectedLayerToCanvas(); + }); + } + + /// + /// Rescales all layer paths from oldSize to newSize coordinates. + /// + private void RescaleAllLayersInternal(Size oldSize, Size newSize) + { + if (oldSize == newSize || oldSize.Width <= 0 || oldSize.Height <= 0) + return; + + // Save current layer before rescaling + SaveCurrentLayerPaths(); + + var scaleX = (float)newSize.Width / oldSize.Width; + var scaleY = (float)newSize.Height / oldSize.Height; + + foreach (var layer in Layers) + { + if (layer.LayerType != MaskLayerType.Paint || layer.Paths.Count == 0) + continue; + + var scaledPaths = layer + .Paths.Select(path => ScalePenPath(path, scaleX, scaleY)) + .ToImmutableList(); + layer.Paths = scaledPaths; + } + + // Refresh the canvas to show rescaled paths + SyncSelectedLayerToCanvas(); + } + + /// + /// Scales a PenPath by the given factors. + /// + private static PenPath ScalePenPath(PenPath path, float scaleX, float scaleY) + { + // Scale the bounds + var scaledBounds = new SKRect( + path.Bounds.Left * scaleX, + path.Bounds.Top * scaleY, + path.Bounds.Right * scaleX, + path.Bounds.Bottom * scaleY + ); + + // Scale points if present (for freehand paths) + var scaledPoints = + path.Points.Count > 0 + ? path + .Points.Select(p => new PenPoint(p.X * scaleX, p.Y * scaleY) + { + Pressure = p.Pressure, + IsPen = p.IsPen, + Radius = p.Radius * Math.Max(scaleX, scaleY), // Scale radius too + }) + .ToList() + : path.Points; + + // Scale the stroke radius proportionally + var scaledRadius = path.Radius * Math.Max(scaleX, scaleY); + var scaledStrokeWidth = path.StrokeWidth * Math.Max(scaleX, scaleY); + + return path with + { + Bounds = scaledBounds, + Points = scaledPoints, + Radius = (float)scaledRadius, + StrokeWidth = (float)scaledStrokeWidth, + }; + } + + /// + /// Manually rescales all layers to fit the current canvas size. + /// Useful when layers were created at a different resolution. + /// + [RelayCommand] + private void RescaleAllLayers() + { + if ( + _previousCanvasSize != CanvasSize + && _previousCanvasSize.Width > 0 + && _previousCanvasSize.Height > 0 + ) + { + RescaleAllLayersInternal(_previousCanvasSize, CanvasSize); + _previousCanvasSize = CanvasSize; + } + } + + /// + /// Quick preset: 50/50 horizontal split (left and right halves). + /// + [RelayCommand] + private void QuickDivisionHorizontal5050() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 0.5f, 1f), // Left half + new SKRect(0.5f, 0f, 1f, 1f), // Right half + ], + ["Left", "Right"] + ); + } + + /// + /// Quick preset: 33/33/33 horizontal split (thirds). + /// + [RelayCommand] + private void QuickDivisionHorizontal333333() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 0.333f, 1f), // Left third + new SKRect(0.333f, 0f, 0.666f, 1f), // Middle third + new SKRect(0.666f, 0f, 1f, 1f), // Right third + ], + ["Left", "Center", "Right"] + ); + } + + /// + /// Quick preset: 50/50 vertical split (top and bottom halves). + /// + [RelayCommand] + private void QuickDivisionVertical5050() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 1f, 0.5f), // Top half + new SKRect(0f, 0.5f, 1f, 1f), // Bottom half + ], + ["Top", "Bottom"] + ); + } + + /// + /// Quick preset: 33/33/33 vertical split (thirds). + /// + [RelayCommand] + private void QuickDivisionVertical333333() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 1f, 0.333f), // Top third + new SKRect(0f, 0.333f, 1f, 0.666f), // Middle third + new SKRect(0f, 0.666f, 1f, 1f), // Bottom third + ], + ["Top", "Middle", "Bottom"] + ); + } + + /// + /// Quick preset: 2x2 quadrants. + /// + [RelayCommand] + private void QuickDivisionQuadrants() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 0.5f, 0.5f), // Top-left + new SKRect(0.5f, 0f, 1f, 0.5f), // Top-right + new SKRect(0f, 0.5f, 0.5f, 1f), // Bottom-left + new SKRect(0.5f, 0.5f, 1f, 1f), // Bottom-right + ], + ["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right"] + ); + } + + /// + /// Quick preset: 3x3 grid (9 regions). + /// + [RelayCommand] + private void QuickDivision3x3Grid() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 0.333f, 0.333f), // Top-left + new SKRect(0.333f, 0f, 0.666f, 0.333f), // Top-center + new SKRect(0.666f, 0f, 1f, 0.333f), // Top-right + new SKRect(0f, 0.333f, 0.333f, 0.666f), // Middle-left + new SKRect(0.333f, 0.333f, 0.666f, 0.666f), // Center + new SKRect(0.666f, 0.333f, 1f, 0.666f), // Middle-right + new SKRect(0f, 0.666f, 0.333f, 1f), // Bottom-left + new SKRect(0.333f, 0.666f, 0.666f, 1f), // Bottom-center + new SKRect(0.666f, 0.666f, 1f, 1f), // Bottom-right + ], + [ + "Top-Left", + "Top-Center", + "Top-Right", + "Middle-Left", + "Center", + "Middle-Right", + "Bottom-Left", + "Bottom-Center", + "Bottom-Right", + ] + ); + } + + /// + /// Quick preset: Center focus (center region with surrounding frame). + /// + [RelayCommand] + private void QuickDivisionCenterFocus() + { + CreateQuickDivisionLayers( + [ + new SKRect(0.25f, 0.25f, 0.75f, 0.75f), // Center (50% of canvas) + new SKRect(0f, 0f, 1f, 0.25f), // Top strip + new SKRect(0f, 0.75f, 1f, 1f), // Bottom strip + new SKRect(0f, 0.25f, 0.25f, 0.75f), // Left strip + new SKRect(0.75f, 0.25f, 1f, 0.75f), // Right strip + ], + ["Center", "Top", "Bottom", "Left", "Right"] + ); + } + + /// + /// Quick preset: Portrait mode (foreground subject with background). + /// Creates a large center oval-ish region and a background region. + /// + [RelayCommand] + private void QuickDivisionPortrait() + { + // For portrait, we create a center region (roughly where a person would be) + // and a background region + CreateQuickDivisionLayers( + [ + new SKRect(0.15f, 0.05f, 0.85f, 0.95f), // Foreground (subject area) + new SKRect(0f, 0f, 1f, 1f), // Background (full canvas, will be behind) + ], + ["Subject", "Background"] + ); + } + + /// + /// Quick preset: Landscape scene (sky, horizon, ground). + /// + [RelayCommand] + private void QuickDivisionLandscape() + { + CreateQuickDivisionLayers( + [ + new SKRect(0f, 0f, 1f, 0.35f), // Sky + new SKRect(0f, 0.35f, 1f, 0.65f), // Horizon/middle ground + new SKRect(0f, 0.65f, 1f, 1f), // Foreground + ], + ["Sky", "Horizon", "Foreground"] + ); + } + + #endregion + + /// + /// Duplicates the selected layer with all its content and settings. + /// + [RelayCommand] + private void DuplicateLayer(MaskLayer? target = null) + { + var source = target ?? SelectedLayer; + if (source is null) + return; + + // Save current layer paths first + SaveCurrentLayerPaths(); + + if (source.LayerType == MaskLayerType.Image) + imageLayerCounter++; + else + layerCounter++; + + var clone = new MaskLayer + { + Name = $"{source.Name} Copy", + LayerType = source.LayerType, + DisplayColor = source.DisplayColor, + Prompt = source.Prompt, + Strength = source.Strength, + Opacity = source.Opacity, + IsVisible = source.IsVisible, + IsEnabled = source.IsEnabled, + Paths = source.Paths, // ImmutableList, safe to share + SourceImagePath = source.SourceImagePath, + ImageScale = source.ImageScale, + ImageOffsetX = source.ImageOffsetX, + ImageOffsetY = source.ImageOffsetY, + IsFlippedHorizontally = source.IsFlippedHorizontally, + IsFlippedVertically = source.IsFlippedVertically, + }; + + // Subscribe to layer property changes + clone.PropertyChanged += Layer_PropertyChanged; + + // Insert after source layer + var index = Layers.IndexOf(source); + RunWithLayerIndexChangeSuppressed(() => + { + Layers.Insert(index + 1, clone); + SelectedLayer = clone; + SyncSelectedLayerToCanvas(); + }); + } + + /// + /// Exports the selected layer as a white-on-black mask PNG. + /// + [RelayCommand] + private async Task ExportLayerAsMaskAsync(MaskLayer? layer) + { + layer ??= SelectedLayer; + if (layer is null || layer.LayerType != MaskLayerType.Paint || CanvasSize == Size.Empty) + return; + + // Save current layer paths before rendering + if (layer == SelectedLayer) + SaveCurrentLayerPaths(); + + var storageProvider = App.StorageProvider; + + var file = await storageProvider.SaveFilePickerAsync( + new FilePickerSaveOptions + { + Title = "Export Mask as PNG", + SuggestedFileName = $"{layer.Name}_mask.png", + FileTypeChoices = [new FilePickerFileType("PNG Image") { Patterns = ["*.png"] }], + } + ); + + if (file is null) + return; + + // Render layer to white-on-black mask + using var bitmap = new SKBitmap( + CanvasSize.Width, + CanvasSize.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + using var canvas = new SKCanvas(bitmap); + canvas.Clear(SKColors.Black); + + // Render paths as white + using var paint = new SKPaint + { + Color = SKColors.White, + IsAntialias = true, + Style = SKPaintStyle.Fill, + }; + + foreach (var penPath in layer.Paths) + RenderPenPathToCanvas(canvas, penPath, paint, SKColors.White); + + // Save to file + await using var stream = await file.OpenWriteAsync(); + using var image = SKImage.FromBitmap(bitmap); + using var data = image.Encode(SKEncodedImageFormat.Png, 100); + data.SaveTo(stream); + } + + /// + /// Imports a mask image as a new layer, converting white areas to the new layer's color. + /// + [RelayCommand] + private async Task ImportMaskAsLayerAsync() + { + if (CanvasSize == Size.Empty) + return; + + var storageProvider = App.StorageProvider; + + var files = await storageProvider.OpenFilePickerAsync( + new FilePickerOpenOptions + { + Title = "Import Mask Image", + AllowMultiple = false, + FileTypeFilter = + [ + new FilePickerFileType("Image Files") + { + Patterns = ["*.png", "*.jpg", "*.jpeg", "*.bmp"], + }, + ], + } + ); + + if (files.Count == 0) + return; + + var file = files[0]; + await using var stream = await file.OpenReadAsync(); + using var bitmap = SKBitmap.Decode(stream); + if (bitmap is null) + return; + + // Create new paint layer + var newLayer = new MaskLayer + { + Name = $"Imported Mask {Layers.Count + 1}", + LayerType = MaskLayerType.Paint, + DisplayColor = MaskLayerColors.GetByIndex(Layers.Count), + }; + newLayer.PropertyChanged += Layer_PropertyChanged; + + // Scale bitmap to canvas size and create a fill path + // For mask import, we create a bitmap path that covers the canvas + var scaledBitmap = new SKBitmap( + CanvasSize.Width, + CanvasSize.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + using var scaleCanvas = new SKCanvas(scaledBitmap); + scaleCanvas.Clear(SKColors.Transparent); + + var srcRect = new SKRect(0, 0, bitmap.Width, bitmap.Height); + var destRect = new SKRect(0, 0, CanvasSize.Width, CanvasSize.Height); + scaleCanvas.DrawBitmap(bitmap, srcRect, destRect); + + // Convert white areas to layer color using the mask as a bitmap path + var maskPath = new PenPath + { + PathType = PenPathType.Bitmap, + FillColor = newLayer.DisplayColor, + Bounds = destRect, + BitmapData = scaledBitmap, + }; + + newLayer.Paths = [maskPath]; + RunWithLayerIndexChangeSuppressed(() => + { + Layers.Insert(0, newLayer); + SelectedLayer = newLayer; + SyncSelectedLayerToCanvas(); + }); + } + + /// + /// Called when the selected layer changes. + /// Saves the current layer's paths and loads the new layer's paths. + /// + partial void OnSelectedLayerChanging(MaskLayer? oldValue, MaskLayer? newValue) + { + // Save paths from old layer before switching + SaveCurrentLayerPaths(); + } + + partial void OnSelectedLayerChanged(MaskLayer? value) + { + SyncSelectedLayerToCanvas(); + } + + partial void OnCanvasSizeChanged(Size oldValue, Size newValue) + { + PaintCanvasViewModel.CanvasSize = newValue; + + // Invalidate cached image layer bitmap since canvas size changed + _cachedImageLayerBitmap?.Dispose(); + _cachedImageLayerBitmap = null; + + // Rescale all layers if we have a valid previous size and new size + if (oldValue.Width > 0 && oldValue.Height > 0 && newValue.Width > 0 && newValue.Height > 0) + { + RescaleAllLayersInternal(oldValue, newValue); + } + + _previousCanvasSize = newValue; + } + + partial void OnShowAllLayersChanged(bool value) + { + SyncSelectedLayerToCanvas(); + } + + /// + /// Syncs the selected layer's paths and brush color to the paint canvas. + /// Other visible layers are rendered to their correct z-order positions. + /// + private void SyncSelectedLayerToCanvas() + { + if (SelectedLayer is null) + { + PaintCanvasViewModel.Paths = []; + PaintCanvasViewModel.SetLayerBitmap("LayersBelow", null); + PaintCanvasViewModel.SetLayerBitmap("LayersAbove", null); + PaintCanvasViewModel.RefreshCanvas?.Invoke(); + return; + } + + // Set brush color to layer's display color for visual feedback + PaintCanvasViewModel.PaintBrushColor = SelectedLayer.AvaloniaDisplayColor; + PaintCanvasViewModel.CanvasSize = CanvasSize; + + // Handle different layer types + if (SelectedLayer.LayerType == MaskLayerType.Image) + { + // Image layers are reference-only, disable drawing + PaintCanvasViewModel.IsDrawingEnabled = false; + PaintCanvasViewModel.Paths = []; + + // Render the selected image layer's bitmap directly if visible and has content + if (SelectedLayer.IsVisible && SelectedLayer.SourceImage != null && CanvasSize != Size.Empty) + { + var selectedImageBitmap = RenderSingleImageLayer(SelectedLayer); + // Clone the cached bitmap since SetLayerBitmap will dispose it later + var bitmapToSet = selectedImageBitmap?.Copy(); + PaintCanvasViewModel.SetLayerBitmap("CurrentImage", bitmapToSet); + } + else + { + PaintCanvasViewModel.SetLayerBitmap("CurrentImage", null); + } + } + else if (SelectedLayer.IsVisible) + { + // Paint layer - enable drawing if not locked, show paths if visible + PaintCanvasViewModel.IsDrawingEnabled = !SelectedLayer.IsLocked; + PaintCanvasViewModel.Paths = SelectedLayer.Paths; + PaintCanvasViewModel.SetLayerBitmap("CurrentImage", null); + } + else + { + // Layer is hidden - still allow drawing if not locked but don't render its paths until shown + PaintCanvasViewModel.IsDrawingEnabled = !SelectedLayer.IsLocked; + PaintCanvasViewModel.Paths = []; + PaintCanvasViewModel.SetLayerBitmap("CurrentImage", null); + } + + if (ShowAllLayers && CanvasSize != Size.Empty) + { + // Render layers to their correct z-order positions + var (belowBitmap, aboveBitmap) = RenderLayersByPosition(); + PaintCanvasViewModel.SetLayerBitmap("LayersBelow", belowBitmap); + PaintCanvasViewModel.SetLayerBitmap("LayersAbove", aboveBitmap); + } + else + { + // Clear other layer bitmaps + PaintCanvasViewModel.SetLayerBitmap("LayersBelow", null); + PaintCanvasViewModel.SetLayerBitmap("LayersAbove", null); + } + + PaintCanvasViewModel.RefreshCanvas?.Invoke(); + } + + /// + /// Renders a single image layer's bitmap at the canvas size with scaling. + /// Uses caching to avoid re-rendering on every sync when the image hasn't changed. + /// + /// + /// The cached or newly rendered bitmap. Note: The caller should NOT dispose this bitmap + /// as it is managed by the cache. Returns null if no image is available. + /// + private SKBitmap? RenderSingleImageLayer(MaskLayer layer) + { + if (layer.SourceImage is null || CanvasSize == Size.Empty) + return null; + + // Check if we can use the cached bitmap + if ( + _cachedImageLayerBitmap is not null + && _cachedImageLayerSource == layer + && _cachedImageLayerSourceImage == layer.SourceImage + && Math.Abs(_cachedImageLayerScale - layer.ImageScale) < 0.001 + && Math.Abs(_cachedImageLayerOpacity - layer.Opacity) < 0.001 + && Math.Abs(_cachedImageLayerOffsetX - layer.ImageOffsetX) < 0.001 + && Math.Abs(_cachedImageLayerOffsetY - layer.ImageOffsetY) < 0.001 + && _cachedImageLayerFlipH == layer.IsFlippedHorizontally + && _cachedImageLayerFlipV == layer.IsFlippedVertically + && _cachedImageLayerCanvasSize == CanvasSize + ) + { + return _cachedImageLayerBitmap; + } + + // Dispose old cached bitmap + _cachedImageLayerBitmap?.Dispose(); + + // Create new bitmap + var bitmap = new SKBitmap( + CanvasSize.Width, + CanvasSize.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + using var canvas = new SKCanvas(bitmap); + canvas.Clear(SKColors.Transparent); + + var alpha = (byte)(layer.Opacity * 255); + RenderImageLayer(canvas, layer, alpha); + + // Update cache + _cachedImageLayerBitmap = bitmap; + _cachedImageLayerSource = layer; + _cachedImageLayerSourceImage = layer.SourceImage; + _cachedImageLayerScale = layer.ImageScale; + _cachedImageLayerOpacity = layer.Opacity; + _cachedImageLayerOffsetX = layer.ImageOffsetX; + _cachedImageLayerOffsetY = layer.ImageOffsetY; + _cachedImageLayerFlipH = layer.IsFlippedHorizontally; + _cachedImageLayerFlipV = layer.IsFlippedVertically; + _cachedImageLayerCanvasSize = CanvasSize; + + return bitmap; + } + + /// + /// Renders visible layers split into two bitmaps: layers below and above the selected layer. + /// This enables proper z-ordering where the selected layer maintains its correct position. + /// Note: In this layer system, LOWER index = drawn on TOP (like Photoshop's layer panel). + /// + /// A tuple of (layersBelow, layersAbove) bitmaps. Either may be null if empty. + private (SKBitmap? LayersBelow, SKBitmap? LayersAbove) RenderLayersByPosition() + { + if (CanvasSize == Size.Empty || SelectedLayer is null) + return (null, null); + + var selectedIndex = Layers.IndexOf(SelectedLayer); + if (selectedIndex < 0) + return (null, null); + + SKBitmap? belowBitmap = null; + SKBitmap? aboveBitmap = null; + + // Layers with LOWER index than selected = drawn on TOP (rendered to Overlay layer) + // These are visually "above" the selected layer + var hasLayersAbove = false; + for (var i = 0; i < selectedIndex; i++) + { + var layer = Layers[i]; + if (layer.IsVisible && LayerHasContent(layer)) + { + hasLayersAbove = true; + break; + } + } + + if (hasLayersAbove) + { + aboveBitmap = new SKBitmap( + CanvasSize.Width, + CanvasSize.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + using var aboveCanvas = new SKCanvas(aboveBitmap); + aboveCanvas.Clear(SKColors.Transparent); + + // Render in reverse order so that lower index (top layer) is drawn last (on top) + for (var i = selectedIndex - 1; i >= 0; i--) + { + var layer = Layers[i]; + if (!layer.IsVisible || !LayerHasContent(layer)) + continue; + + RenderLayerToCanvas(aboveCanvas, layer); + } + } + + // Layers with HIGHER index than selected = drawn BELOW (rendered to Images layer) + // These are visually "behind" the selected layer + var hasLayersBelow = false; + for (var i = selectedIndex + 1; i < Layers.Count; i++) + { + var layer = Layers[i]; + if (layer.IsVisible && LayerHasContent(layer)) + { + hasLayersBelow = true; + break; + } + } + + if (hasLayersBelow) + { + belowBitmap = new SKBitmap( + CanvasSize.Width, + CanvasSize.Height, + SKColorType.Rgba8888, + SKAlphaType.Premul + ); + using var belowCanvas = new SKCanvas(belowBitmap); + belowCanvas.Clear(SKColors.Transparent); + + // Render from bottom to top (highest index first, as it's the bottom-most) + for (var i = Layers.Count - 1; i > selectedIndex; i--) + { + var layer = Layers[i]; + if (!layer.IsVisible || !LayerHasContent(layer)) + continue; + + RenderLayerToCanvas(belowCanvas, layer); + } + } + + return (belowBitmap, aboveBitmap); + } + + /// + /// Checks if a layer has any renderable content. + /// + private static bool LayerHasContent(MaskLayer layer) + { + return layer.LayerType == MaskLayerType.Image ? layer.SourceImage != null : layer.Paths.Count > 0; + } + + /// + /// Renders a single layer to a canvas with the layer's settings and opacity. + /// Handles both paint layers (paths) and image layers (bitmaps). + /// + private void RenderLayerToCanvas(SKCanvas canvas, MaskLayer layer) + { + var alpha = (byte)(layer.Opacity * 255); + + if (layer.LayerType == MaskLayerType.Image && layer.SourceImage != null) + { + // Render image layer with scaling + RenderImageLayer(canvas, layer, alpha); + } + else + { + // Render paint layer paths + using var paint = new SKPaint + { + Color = new SKColor( + layer.DisplayColor.Red, + layer.DisplayColor.Green, + layer.DisplayColor.Blue, + alpha + ), + IsAntialias = true, + Style = SKPaintStyle.Fill, + }; + + foreach (var penPath in layer.Paths) + RenderPenPathToCanvas(canvas, penPath, paint); + } + } + + /// + /// Renders an image layer with scaling, positioning, and optional flipping. + /// + private void RenderImageLayer(SKCanvas canvas, MaskLayer layer, byte alpha) + { + if (layer.SourceImage is null) + return; + + var bitmap = layer.SourceImage; + var scale = (float)layer.ImageScale; + + // Calculate scaled dimensions + var scaledWidth = bitmap.Width * scale; + var scaledHeight = bitmap.Height * scale; + + // Center the image on the canvas, then apply user offset + var centerOffsetX = (CanvasSize.Width - scaledWidth) / 2f; + var centerOffsetY = (CanvasSize.Height - scaledHeight) / 2f; + var offsetX = centerOffsetX + (float)layer.ImageOffsetX; + var offsetY = centerOffsetY + (float)layer.ImageOffsetY; + + var destRect = new SKRect(offsetX, offsetY, offsetX + scaledWidth, offsetY + scaledHeight); + + using var paint = new SKPaint(); + paint.Color = new SKColor(255, 255, 255, alpha); + paint.IsAntialias = true; + paint.FilterQuality = SKFilterQuality.High; + + // Apply flip transforms if needed + if (layer.IsFlippedHorizontally || layer.IsFlippedVertically) + { + canvas.Save(); + + // Calculate center of the image for flip transformation + var centerX = offsetX + scaledWidth / 2f; + var centerY = offsetY + scaledHeight / 2f; + + // Translate to center, scale (flip), translate back + canvas.Translate(centerX, centerY); + canvas.Scale(layer.IsFlippedHorizontally ? -1 : 1, layer.IsFlippedVertically ? -1 : 1); + canvas.Translate(-centerX, -centerY); + + canvas.DrawBitmap(bitmap, destRect, paint); + canvas.Restore(); + } + else + { + canvas.DrawBitmap(bitmap, destRect, paint); + } + } + + /// + /// Saves the current canvas paths back to the selected layer. + /// Only saves for paint layers that could have been edited. + /// + public void SaveCurrentLayerPaths(bool force = false) + { + // Only save for paint layers (image layers don't have editable paths) + if (SelectedLayer is null || SelectedLayer.LayerType != MaskLayerType.Paint) + return; + + // If the layer is hidden, PaintCanvasViewModel.Paths is cleared (visually hidden) + // by SyncSelectedLayerToCanvas. We should not overwrite the layer's actual paths + // with this empty list. This prevents data loss when moving/updating hidden layers. + if (!force && !SelectedLayer.IsVisible) + return; + + SelectedLayer.Paths = PaintCanvasViewModel.Paths; + } + + /// + /// Gets enabled layers with content (for generation). + /// + public IReadOnlyList GetEnabledLayersWithContent() + { + // Save current layer first + SaveCurrentLayerPaths(); + + return Layers + .Where(l => l.IsEnabled && l.HasContent && !string.IsNullOrWhiteSpace(l.Prompt)) + .ToList(); + } + + /// + /// Renders a specific layer's paths to a white mask image. + /// + public SKImage? RenderLayerToMask(MaskLayer layer) + { + if (layer.Paths.Count == 0 || CanvasSize == Size.Empty) + return null; + + // Create a temporary surface + using var surface = SKSurface.Create(new SKImageInfo(CanvasSize.Width, CanvasSize.Height)); + var canvas = surface.Canvas; + canvas.Clear(SKColors.Transparent); + + // Draw paths in white directly + // We pass White as overrideColor, which RenderPenPath uses for non-erase paths + using var paint = new SKPaint(); + paint.IsAntialias = true; + paint.Style = SKPaintStyle.Fill; + + foreach (var penPath in layer.Paths) + RenderPenPathToCanvas(canvas, penPath, paint, SKColors.White); + + return surface.Snapshot(); + } + + /// + /// Renders a pen path to a canvas. Delegates to PaintCanvasViewModel's shared implementation. + /// + /// If provided, uses this color instead of the path's color. + private static void RenderPenPathToCanvas( + SKCanvas canvas, + PenPath penPath, + SKPaint paint, + SKColor? overrideColor = null + ) + { + PaintCanvasViewModel.RenderPenPath(canvas, penPath, paint, overrideColor); + } + + /// + /// Gets the dialog for this view model. + /// + public BetterContentDialog GetDialog() + { + Dispatcher.UIThread.VerifyAccess(); + + var dialog = new BetterContentDialog + { + Content = this, + ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled, + MaxDialogHeight = 2000, + MaxDialogWidth = 2500, + ContentMargin = new Thickness(16), + FullSizeDesired = true, + CloseButtonText = Resources.Action_Close, + DefaultButton = ContentDialogButton.Close, + }; + + return dialog; + } + + /// + public override void LoadStateFromJsonObject(JsonObject state) + { + base.LoadStateFromJsonObject(state); + + // Load canvas size + if ( + state.TryGetPropertyValue("canvasWidth", out var widthNode) + && state.TryGetPropertyValue("canvasHeight", out var heightNode) + ) + { + var width = widthNode?.GetValue() ?? 1024; + var height = heightNode?.GetValue() ?? 1024; + CanvasSize = new Size(width, height); + } + + // Load layers + if (state.TryGetPropertyValue("layers", out var layersNode) && layersNode is JsonArray layersArray) + { + // Clear existing layers and selection + SelectedLayer = null; + foreach (var layer in Layers) + CleanupLayer(layer); + Layers.Clear(); + layerCounter = 0; + imageLayerCounter = 0; + + foreach (var layerNode in layersArray) + if (layerNode is JsonObject layerObj) + { + var layer = new MaskLayer(); + layer.LoadStateFromJsonObject(layerObj); + + // Subscribe to layer property changes (same as AddLayer) + layer.PropertyChanged += Layer_PropertyChanged; + + Layers.Add(layer); + + // Update counters based on layer type + if (layer.LayerType == MaskLayerType.Image) + imageLayerCounter++; + else + layerCounter++; + } + + // Select first layer + if (Layers.Count > 0) + SelectedLayer = Layers[0]; + } + + // Ensure at least one layer exists + if (Layers.Count == 0) + AddLayer(); + + // Reload image layer bitmaps from their saved paths + // This must happen after layers are loaded since SourceImage is not serialized + ReloadImageLayersFromPaths(); + + // Always sync to canvas after loading to ensure paths are displayed + SyncSelectedLayerToCanvas(); + } + + /// + /// Reloads image layer bitmaps from their saved SourceImagePath. + /// Called after loading state since the actual SKBitmap is not serialized. + /// + private void ReloadImageLayersFromPaths() + { + foreach (var layer in Layers) + { + if ( + layer.LayerType == MaskLayerType.Image + && !string.IsNullOrEmpty(layer.SourceImagePath) + && layer.SourceImage == null + && File.Exists(layer.SourceImagePath) + ) + { + // Fire and forget - LoadImageIntoLayerAsync will update UI when done + _ = LoadImageIntoLayerAsync(layer, layer.SourceImagePath); + } + } + } + + /// + public override JsonObject SaveStateToJsonObject() + { + // Save current layer paths first + SaveCurrentLayerPaths(); + + var state = base.SaveStateToJsonObject(); + + // Save canvas size + state["canvasWidth"] = CanvasSize.Width; + state["canvasHeight"] = CanvasSize.Height; + + // Save layers + var layersArray = new JsonArray(); + foreach (var layer in Layers) + layersArray.Add(layer.SaveStateToJsonObject()); + state["layers"] = layersArray; + + return state; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/MaskEditorViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/MaskEditorViewModel.cs index c409c78f7..e9fa86cb5 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Dialogs/MaskEditorViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/MaskEditorViewModel.cs @@ -38,7 +38,7 @@ public partial class MaskEditorViewModel(IServiceManager vmFactor { Patterns = new[] { "*.png", "*.jpg", "*.jpeg", "*.webp", "*.json" }, AppleUniformTypeIdentifiers = new[] { "public.image", "public.json" }, - MimeTypes = new[] { "image/*", "application/json" } + MimeTypes = new[] { "image/*", "application/json" }, }; [JsonIgnore] @@ -121,11 +121,12 @@ public BetterContentDialog GetDialog() ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled, MaxDialogHeight = 2000, MaxDialogWidth = 2000, + MinDialogWidth = 900, ContentMargin = new Thickness(16), FullSizeDesired = true, PrimaryButtonText = Resources.Action_Save, CloseButtonText = Resources.Action_Cancel, - DefaultButton = ContentDialogButton.Primary + DefaultButton = ContentDialogButton.Primary, }; return dialog; @@ -193,7 +194,7 @@ await image! ".png" => SKEncodedImageFormat.Png, ".jpg" or ".jpeg" => SKEncodedImageFormat.Jpeg, ".webp" => SKEncodedImageFormat.Webp, - _ => throw new NotSupportedException("Unsupported image format") + _ => throw new NotSupportedException("Unsupported image format"), }, 100 ) diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/ModelPickerDialogViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ModelPickerDialogViewModel.cs new file mode 100644 index 000000000..e11bb6882 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/ModelPickerDialogViewModel.cs @@ -0,0 +1,633 @@ +using System.Collections.Immutable; +using System.Collections.ObjectModel; +using System.ComponentModel; +using System.Reactive.Disposables; +using System.Reactive.Linq; +using AsyncAwaitBestPractices; +using Avalonia.Threading; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using DynamicData; +using DynamicData.Binding; +using FuzzySharp; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.CheckpointManager; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Settings; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +/// +/// Specifies which model source collection to use in the picker. +/// +public enum ModelPickerSource +{ + /// Checkpoints + UNet models (default for model selection) + CheckpointAndUnet, + + /// LoRA models only + Lora, + + /// VAE models only + Vae, + + /// CLIP/Text encoder models only + Clip, + + /// CLIP Vision models only + ClipVision, +} + +[View(typeof(ModelPickerDialog))] +[RegisterTransient] +[ManagedService] +public partial class ModelPickerDialogViewModel : ContentDialogViewModelBase +{ + private readonly IInferenceClientManager clientManager; + private readonly ISettingsManager settingsManager; + private readonly CompositeDisposable propertySubscriptions = new(); + private LRUCache> searchCache = new(100); + private IDisposable? modelSubscription; + private ImmutableList allModels = []; + private int refreshRequestId; + private HashSet pendingSelectedBaseModels = []; + private bool isApplyingSavedFilterState; + private bool isDialogActive; + + /// + /// Gets or sets the model source to use. Set before showing the dialog. + /// + public ModelPickerSource Source { get; set; } = ModelPickerSource.CheckpointAndUnet; + + /// + /// Optional workflow hint from Inference model cards. Used to open near compatible models. + /// + public InferenceWorkflowProfile PreferredWorkflowProfile { get; set; } = InferenceWorkflowProfile.Auto; + + [ObservableProperty] + private string title = "Select Model"; + + [ObservableProperty] + private string searchText = string.Empty; + + [ObservableProperty] + private HybridModelFile? selectedModel; + + [ObservableProperty] + private IReadOnlyList filteredModels = []; + + [ObservableProperty] + private bool showCheckpointsOnly; + + [ObservableProperty] + private bool showUnetsOnly; + + [ObservableProperty] + private bool isGridView; + + [ObservableProperty] + private bool showNsfwContent; + + public ObservableCollection BaseModelOptions { get; } = []; + + public IEnumerable SelectedBaseModelOptions => + BaseModelOptions.Where(x => x.IsSelected); + + public int ActiveFilterCount => + SelectedBaseModelOptions.Count() + (ShowCheckpointsOnly ? 1 : 0) + (ShowUnetsOnly ? 1 : 0); + + public bool HasActiveFilters => ActiveFilterCount > 0; + public bool HasFilteredModels => FilteredModels.Count > 0; + + public string FilterButtonText => HasActiveFilters ? $"Filter ({ActiveFilterCount})" : "Filter"; + + /// + /// Whether to show the folder type filter buttons (Checkpoints/Diffusion Models). + /// Only relevant for CheckpointAndUnet source. + /// + public bool ShowFolderTypeFilters => Source == ModelPickerSource.CheckpointAndUnet; + + public ModelPickerDialogViewModel(IInferenceClientManager clientManager, ISettingsManager settingsManager) + { + this.clientManager = clientManager; + this.settingsManager = settingsManager; + isGridView = settingsManager.Settings.ModelPickerIsGridView; + showNsfwContent = settingsManager.Settings.ModelBrowserNsfwEnabled; + + // Subscribe to search text and filter changes + propertySubscriptions.Add( + Observable + .FromEventPattern(this, nameof(PropertyChanged)) + .Where(x => + x.EventArgs.PropertyName + is nameof(SearchText) + or nameof(ShowCheckpointsOnly) + or nameof(ShowUnetsOnly) + ) + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => + { + if (!isDialogActive) + return; + + UpdateFilteredModels(); + OnPropertyChanged(nameof(ActiveFilterCount)); + OnPropertyChanged(nameof(HasActiveFilters)); + OnPropertyChanged(nameof(FilterButtonText)); + SaveFilterStateForCurrentSource(); + }) + ); + + // Subscribe to base model option changes + propertySubscriptions.Add( + BaseModelOptions + .ToObservableChangeSet() + .AutoRefresh(x => x.IsSelected) + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => + { + if (!isDialogActive) + return; + + UpdateFilteredModels(); + OnPropertyChanged(nameof(SelectedBaseModelOptions)); + OnPropertyChanged(nameof(ActiveFilterCount)); + OnPropertyChanged(nameof(HasActiveFilters)); + OnPropertyChanged(nameof(FilterButtonText)); + SaveFilterStateForCurrentSource(); + }) + ); + + AddDisposable(propertySubscriptions); + } + + partial void OnIsGridViewChanged(bool value) + { + settingsManager.Transaction(s => s.ModelPickerIsGridView = value); + } + + partial void OnShowNsfwContentChanged(bool value) + { + settingsManager.Transaction(s => s.ModelBrowserNsfwEnabled = value); + } + + partial void OnFilteredModelsChanged(IReadOnlyList value) + { + OnPropertyChanged(nameof(HasFilteredModels)); + } + + partial void OnShowCheckpointsOnlyChanged(bool value) + { + if (value && ShowUnetsOnly) + { + ShowUnetsOnly = false; + } + } + + partial void OnShowUnetsOnlyChanged(bool value) + { + if (value && ShowCheckpointsOnly) + { + ShowCheckpointsOnly = false; + } + } + + public override void OnLoaded() + { + base.OnLoaded(); + isDialogActive = true; + + // Save caller-specified type filters before loading persisted state + // (e.g., WanModelCardViewModel sets ShowUnetsOnly = true before opening) + var preShowUnets = ShowUnetsOnly; + var preShowCheckpoints = ShowCheckpointsOnly; + + LoadFilterStateForCurrentSource(); + + // Re-apply caller-specified type filters (they take priority over saved state) + if (preShowUnets) + ShowUnetsOnly = true; + if (preShowCheckpoints) + ShowCheckpointsOnly = true; + + ApplyPreferredWorkflowProfileFilters(); + + // Populate models in background after dialog appears to reduce opening hitch. + Dispatcher.UIThread.Post(RefreshAllModels, DispatcherPriority.Background); + + // Subscribe to changes in the relevant collections + var subscriptions = new List(); + + switch (Source) + { + case ModelPickerSource.CheckpointAndUnet: + subscriptions.Add( + clientManager + .Models.ToObservableChangeSet< + IObservableCollection, + HybridModelFile + >() + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RefreshAllModels()) + ); + subscriptions.Add( + clientManager + .UnetModels.ToObservableChangeSet< + IObservableCollection, + HybridModelFile + >() + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RefreshAllModels()) + ); + break; + + case ModelPickerSource.Lora: + subscriptions.Add( + clientManager + .LoraModels.ToObservableChangeSet< + IObservableCollection, + HybridModelFile + >() + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RefreshAllModels()) + ); + break; + + case ModelPickerSource.Vae: + subscriptions.Add( + clientManager + .VaeModels.ToObservableChangeSet< + IObservableCollection, + HybridModelFile + >() + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RefreshAllModels()) + ); + break; + + case ModelPickerSource.Clip: + subscriptions.Add( + clientManager + .ClipModels.ToObservableChangeSet< + IObservableCollection, + HybridModelFile + >() + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RefreshAllModels()) + ); + break; + + case ModelPickerSource.ClipVision: + subscriptions.Add( + clientManager + .ClipVisionModels.ToObservableChangeSet< + IObservableCollection, + HybridModelFile + >() + .Throttle(TimeSpan.FromMilliseconds(100)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RefreshAllModels()) + ); + break; + } + + modelSubscription?.Dispose(); + modelSubscription = new CompositeDisposable(subscriptions); + } + + private void RefreshAllModels() + { + if (!isDialogActive) + return; + + RefreshAllModelsAsync().SafeFireAndForget(); + } + + private async Task RefreshAllModelsAsync() + { + var requestId = Interlocked.Increment(ref refreshRequestId); + + var sortedModels = await Task.Run(() => + { + IEnumerable models = Source switch + { + ModelPickerSource.CheckpointAndUnet => clientManager.Models.Concat(clientManager.UnetModels), + ModelPickerSource.Lora => clientManager.LoraModels, + ModelPickerSource.Vae => clientManager.VaeModels, + ModelPickerSource.Clip => clientManager.ClipModels, + ModelPickerSource.ClipVision => clientManager.ClipVisionModels, + _ => [], + }; + + return models + .OrderBy(m => m.ShortDisplayName, StringComparer.OrdinalIgnoreCase) + .ToImmutableList(); + }); + + // Ignore stale refreshes if a newer one was queued while this was running. + if (requestId != refreshRequestId) + { + return; + } + + if (!isDialogActive) + { + return; + } + + allModels = sortedModels; + UpdateAvailableBaseModels(); + UpdateFilteredModels(); + } + + public override void OnUnloaded() + { + base.OnUnloaded(); + isDialogActive = false; + Interlocked.Increment(ref refreshRequestId); + SaveFilterStateForCurrentSource(); + modelSubscription?.Dispose(); + modelSubscription = null; + + // Release only internal heavy references. Avoid touching bound collections here, + // because clearing ItemsSource during close can null out SelectedModel before caller reads it. + allModels = []; + searchCache = new LRUCache>(100); + } + + private void UpdateAvailableBaseModels() + { + var baseModels = allModels + .Where(m => m.Local?.ConnectedModelInfo?.BaseModel != null) + .Select(m => m.Local!.ConnectedModelInfo!.BaseModel!) + .Distinct() + .OrderBy(b => b) + .ToList(); + + // Add "Unknown" for models without metadata + if (allModels.Any(m => m.Local?.ConnectedModelInfo?.BaseModel == null)) + { + baseModels.Add("Unknown"); + } + + // Update BaseModelOptions collection, preserving selection state + var existingSelections = BaseModelOptions + .Where(x => x.IsSelected) + .Select(x => x.ModelType) + .ToHashSet(StringComparer.OrdinalIgnoreCase); + + if (existingSelections.Count == 0 && pendingSelectedBaseModels.Count > 0) + { + existingSelections = new HashSet( + pendingSelectedBaseModels, + StringComparer.OrdinalIgnoreCase + ); + } + + BaseModelOptions.Clear(); + foreach (var baseModel in baseModels) + { + BaseModelOptions.Add( + new BaseModelOptionViewModel + { + ModelType = baseModel, + IsSelected = existingSelections.Contains(baseModel), + } + ); + } + } + + private string GetSourceKey() => Source.ToString(); + + private void LoadFilterStateForCurrentSource() + { + var states = settingsManager.Settings.ModelPickerFilterStates; + if (states is null || !states.TryGetValue(GetSourceKey(), out var state) || state is null) + { + pendingSelectedBaseModels = []; + return; + } + + isApplyingSavedFilterState = true; + try + { + SearchText = state.SearchText ?? string.Empty; + ShowCheckpointsOnly = state.ShowCheckpointsOnly; + ShowUnetsOnly = state.ShowUnetsOnly; + pendingSelectedBaseModels = (state.SelectedBaseModels ?? []).ToHashSet( + StringComparer.OrdinalIgnoreCase + ); + } + finally + { + isApplyingSavedFilterState = false; + } + } + + private void SaveFilterStateForCurrentSource() + { + if (isApplyingSavedFilterState || PreferredWorkflowProfile is not InferenceWorkflowProfile.Auto) + return; + + var selectedBaseModels = + BaseModelOptions.Count > 0 + ? SelectedBaseModelOptions.Select(x => x.ModelType).ToList() + : pendingSelectedBaseModels.ToList(); + + var state = new ModelPickerFilterState + { + SearchText = SearchText.Trim(), + ShowCheckpointsOnly = ShowCheckpointsOnly, + ShowUnetsOnly = ShowUnetsOnly, + SelectedBaseModels = selectedBaseModels, + }; + + settingsManager.Transaction(s => + { + s.ModelPickerFilterStates ??= []; + s.ModelPickerFilterStates[GetSourceKey()] = state; + }); + } + + private void ApplyPreferredWorkflowProfileFilters() + { + if ( + Source is not ModelPickerSource.CheckpointAndUnet + || PreferredWorkflowProfile is InferenceWorkflowProfile.Auto or InferenceWorkflowProfile.Custom + ) + { + return; + } + + if (PreferredWorkflowProfile is InferenceWorkflowProfile.DefaultCheckpoint) + { + ShowCheckpointsOnly = true; + return; + } + + ShowUnetsOnly = true; + pendingSelectedBaseModels = GetPreferredBaseModels(PreferredWorkflowProfile) + .ToHashSet(StringComparer.OrdinalIgnoreCase); + } + + private static IEnumerable GetPreferredBaseModels(InferenceWorkflowProfile profile) + { + return profile switch + { + InferenceWorkflowProfile.Flux => ["Flux.1"], + InferenceWorkflowProfile.Flux2 => ["Flux.2"], + InferenceWorkflowProfile.ZImageBase => ["ZImageBase"], + InferenceWorkflowProfile.ZImageTurbo => ["ZImageTurbo"], + InferenceWorkflowProfile.Anima => ["Anima"], + InferenceWorkflowProfile.HiDream => ["HiDream"], + _ => [], + }; + } + + private void UpdateFilteredModels() + { + var models = allModels.AsEnumerable(); + + // Apply base model filter + var selectedBaseModels = SelectedBaseModelOptions.Select(x => x.ModelType).ToList(); + if (selectedBaseModels.Count > 0) + { + models = models.Where(m => + { + var baseModel = m.Local?.ConnectedModelInfo?.BaseModel; + if (baseModel == null) + { + return selectedBaseModels.Contains("Unknown"); + } + return selectedBaseModels.Contains(baseModel); + }); + } + + // Apply folder type filter + if (ShowCheckpointsOnly) + { + models = models.Where(m => m.Local?.SharedFolderType == SharedFolderType.StableDiffusion); + } + else if (ShowUnetsOnly) + { + models = models.Where(m => m.Local?.SharedFolderType == SharedFolderType.DiffusionModels); + } + + var modelList = models.ToList(); + + // Apply search filter + var query = SearchText.Trim(); + if (!string.IsNullOrWhiteSpace(query)) + { + // Check cache + var selectedBaseModelsKey = string.Join( + ",", + selectedBaseModels.OrderBy(x => x, StringComparer.Ordinal) + ); + var cacheKey = + $"{refreshRequestId}|{query}|{selectedBaseModelsKey}|{ShowCheckpointsOnly}|{ShowUnetsOnly}"; + if (searchCache.Get(cacheKey, out var cachedResults)) + { + FilteredModels = cachedResults!; + return; + } + + var results = modelList + .Select(m => + { + var modelSearchText = m.DetailedSearchText; + var weightedScore = Fuzz.WeightedRatio(query, modelSearchText); + var partialScore = Fuzz.PartialRatio(query, modelSearchText); + var score = Math.Max(weightedScore, partialScore); + var contains = modelSearchText.Contains(query, StringComparison.OrdinalIgnoreCase); + return (Model: m, Score: score, Contains: contains); + }) + .Where(x => x.Contains || x.Score >= 70) + .OrderByDescending(x => x.Contains) + .ThenByDescending(x => x.Score) + .Select(x => x.Model) + .ToImmutableList(); + + searchCache.Add(cacheKey, results); + FilteredModels = results; + } + else + { + FilteredModels = modelList.ToImmutableList(); + } + } + + [RelayCommand] + private void ClearOrSelectAllBaseModels() + { + var anySelected = BaseModelOptions.Any(x => x.IsSelected); + foreach (var option in BaseModelOptions) + { + option.IsSelected = !anySelected; + } + } + + [RelayCommand] + private void ClearFilters() + { + foreach (var option in BaseModelOptions) + { + option.IsSelected = false; + } + ShowCheckpointsOnly = false; + ShowUnetsOnly = false; + SearchText = string.Empty; + } + + [RelayCommand] + private void SelectModel(HybridModelFile? model) + { + if (model != null) + { + SelectedModel = model; + OnPrimaryButtonClick(); + } + } + + [RelayCommand] + private void SetSelectedModel(HybridModelFile? model) + { + SelectedModel = model; + } + + public override BetterContentDialog GetDialog() + { + var dialog = base.GetDialog(); + + dialog.MinDialogWidth = 700; + dialog.MaxDialogWidth = 900; + dialog.MinDialogHeight = 500; + dialog.MaxDialogHeight = 700; + dialog.IsFooterVisible = false; + dialog.CloseOnClickOutside = true; + // Disable dialog's internal scrolling - let the ListBox handle it + dialog.ContentVerticalScrollBarVisibility = global::Avalonia + .Controls + .Primitives + .ScrollBarVisibility + .Disabled; + + return dialog; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Dialogs/OrganizeModelsDialogViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OrganizeModelsDialogViewModel.cs new file mode 100644 index 000000000..47b49ae11 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Dialogs/OrganizeModelsDialogViewModel.cs @@ -0,0 +1,261 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.ComponentModel; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Reactive.Linq; +using Avalonia.Controls.Primitives; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using DynamicData.Binding; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Models.CheckpointOrganizer; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.Views.Dialogs; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Avalonia.ViewModels.Dialogs; + +[View(typeof(OrganizeModelsDialog))] +[ManagedService] +[RegisterTransient] +public partial class OrganizeModelsDialogViewModel( + ISettingsManager settingsManager, + ModelOrganizationService modelOrganizationService +) : ContentDialogViewModelBase +{ + private IReadOnlyList models = []; + private string modelsRoot = string.Empty; + private string scopePath = string.Empty; + private bool includeNested; + private IReadOnlyList allSortedItems = []; + + public ModelOrganizationMetadataAction RequestedMetadataAction { get; private set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(CanOrganize))] + [NotifyPropertyChangedFor(nameof(ReadySummary))] + public partial ModelOrganizationPlan? Plan { get; set; } + + [ObservableProperty] + public partial string OrganizePattern { get; set; } = FileNameFormat.DefaultOrganizationTemplate; + + [ObservableProperty] + public partial string PatternPreviewSample { get; set; } = string.Empty; + + [ObservableProperty] + public partial string? PatternValidationError { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(ShowMissingMetadataWarning))] + [NotifyPropertyChangedFor(nameof(ShowMetadataActions))] + [NotifyPropertyChangedFor(nameof(MissingMetadataText))] + public partial int MissingMetadataCount { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(ShowIncompleteMetadataWarning))] + [NotifyPropertyChangedFor(nameof(ShowMetadataActions))] + [NotifyPropertyChangedFor(nameof(IncompleteMetadataText))] + public partial int IncompleteMetadataCount { get; set; } + + [ObservableProperty] + public partial bool IsVariablesTipOpen { get; set; } + + [ObservableProperty] + public partial bool ShowReadyItems { get; set; } = true; + + [ObservableProperty] + public partial bool ShowConflictItems { get; set; } = true; + + [ObservableProperty] + public partial bool ShowSkippedItems { get; set; } = true; + + [ObservableProperty] + public partial bool ShowUnchangedItems { get; set; } + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(ShowUnchangedLabel))] + public partial int UnchangedCount { get; set; } + + public ObservableCollection Items { get; } = []; + + public bool CanOrganize => Plan?.ReadyCount > 0; + + public bool ShowMissingMetadataWarning => MissingMetadataCount > 0; + + public bool ShowIncompleteMetadataWarning => IncompleteMetadataCount > 0; + + public bool ShowMetadataActions => ShowMissingMetadataWarning || ShowIncompleteMetadataWarning; + + public string ReadySummary => + Plan == null + ? string.Empty + : $"{Plan.ReadyCount} ready, {Plan.ConflictCount} conflicts, {Plan.SkippedCount} skipped"; + + public string MissingMetadataText => + string.Format(Resources.TextTemplate_FilesNeedMetadata, MissingMetadataCount); + + public string IncompleteMetadataText => + string.Format(Resources.TextTemplate_FilesIncompleteMetadata, IncompleteMetadataCount); + + public string ShowUnchangedLabel => + string.Format(Resources.TextTemplate_ShowUnchangedCount, UnchangedCount); + + public IEnumerable OrganizationFormatVars => + FileNameFormatProvider + .GetSampleForOrganization() + .Substitutions.Where(kv => FileNameFormatProvider.LocalOrganizationVariables.Contains(kv.Key)) + .Select(kv => new FileNameFormatVar { Variable = $"{{{kv.Key}}}", Example = kv.Value.Invoke() }); + + public void Initialize( + IEnumerable allModels, + string rootPath, + string scope, + bool nested, + string? initialPattern + ) + { + models = allModels.ToList(); + modelsRoot = rootPath; + scopePath = scope; + includeNested = nested; + RequestedMetadataAction = ModelOrganizationMetadataAction.None; + + OrganizePattern = string.IsNullOrWhiteSpace(initialPattern) + ? FileNameFormat.DefaultOrganizationTemplate + : initialPattern; + + RebuildPlan(); + + AddDisposable( + this.WhenPropertyChanged(vm => vm.OrganizePattern) + .Throttle(TimeSpan.FromMilliseconds(300)) + .ObserveOn(SynchronizationContext.Current!) + .Subscribe(_ => RebuildPlan()) + ); + } + + private void RebuildPlan() + { + var plan = modelOrganizationService.BuildPlan( + models, + modelsRoot, + scopePath, + includeNested, + OrganizePattern + ); + + Plan = plan; + PatternValidationError = plan.ValidationError; + + // Update preview sample + UpdatePreviewSample(); + + // Store all sorted items and update visible items + allSortedItems = plan.Items.OrderBy(i => i.SortOrder).ToList(); + UnchangedCount = plan.Items.Count(i => i.IsUnchanged); + RefreshVisibleItems(); + + // Count models with no connected metadata at all + MissingMetadataCount = plan.Items.Count(i => + i.Status == ModelOrganizationPreviewStatus.Skipped && !i.Model.HasConnectedModel + ); + + // Count models with connected metadata but missing fields needed by the template + IncompleteMetadataCount = plan.Items.Count(i => + i.Status == ModelOrganizationPreviewStatus.Skipped + && i.Model.HasConnectedModel + && i.Reason?.Contains("not available", StringComparison.OrdinalIgnoreCase) == true + ); + } + + partial void OnShowReadyItemsChanged(bool value) => RefreshVisibleItems(); + + partial void OnShowConflictItemsChanged(bool value) => RefreshVisibleItems(); + + partial void OnShowSkippedItemsChanged(bool value) => RefreshVisibleItems(); + + partial void OnShowUnchangedItemsChanged(bool value) => RefreshVisibleItems(); + + private void RefreshVisibleItems() + { + Items.Clear(); + foreach (var item in allSortedItems) + { + var visible = item.Status switch + { + ModelOrganizationPreviewStatus.Ready => ShowReadyItems, + ModelOrganizationPreviewStatus.Conflict => ShowConflictItems, + ModelOrganizationPreviewStatus.Skipped => ShowSkippedItems, + ModelOrganizationPreviewStatus.Unchanged => ShowUnchangedItems, + _ => true, + }; + + if (visible) + { + Items.Add(item); + } + } + } + + private void UpdatePreviewSample() + { + var provider = FileNameFormatProvider.GetSampleForOrganization(); + var template = OrganizePattern; + + var format = + !string.IsNullOrEmpty(template) && provider.Validate(template) == ValidationResult.Success + ? FileNameFormat.Parse(template, provider) + : FileNameFormat.Parse(FileNameFormat.DefaultOrganizationTemplate, provider); + + PatternPreviewSample = string.Format( + Resources.TextTemplate_PatternPreviewExample, + format.GetFileName() + ".safetensors" + ); + } + + public override BetterContentDialog GetDialog() + { + var dialog = base.GetDialog(); + dialog.MinDialogWidth = 1120; + dialog.MaxDialogHeight = 900; + dialog.IsFooterVisible = false; + dialog.CloseOnClickOutside = true; + dialog.ContentVerticalScrollBarVisibility = ScrollBarVisibility.Disabled; + return dialog; + } + + [RelayCommand] + private void OpenVariablesTip() => IsVariablesTipOpen = true; + + [RelayCommand] + private void ToggleVariablesTip() => IsVariablesTipOpen = !IsVariablesTipOpen; + + [RelayCommand(CanExecute = nameof(CanOrganize))] + private void ConfirmOrganize() + { + settingsManager.Transaction(s => s.ModelOrganizationFileNamePattern = OrganizePattern); + OnPrimaryButtonClick(); + } + + [RelayCommand] + private void ScanForMetadata() + { + RequestedMetadataAction = ModelOrganizationMetadataAction.ScanMissing; + OnSecondaryButtonClick(); + } + + [RelayCommand] + private void UpdateMetadata() + { + RequestedMetadataAction = ModelOrganizationMetadataAction.UpdateExisting; + OnSecondaryButtonClick(); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ExtraNetworkCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ExtraNetworkCardViewModel.cs index 47532dfe2..864c49eb6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ExtraNetworkCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ExtraNetworkCardViewModel.cs @@ -7,10 +7,12 @@ using CommunityToolkit.Mvvm.Input; using DynamicData; using DynamicData.Binding; +using FluentAvalonia.UI.Controls; using Injectio.Attributes; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; @@ -26,6 +28,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference; public partial class ExtraNetworkCardViewModel : DisposableLoadableViewModelBase { private readonly ISettingsManager settingsManager; + private readonly IServiceManager vmFactory; private readonly ModelCompatChecker modelCompatChecker = new(); public const string ModuleKey = "ExtraNetwork"; @@ -64,9 +67,14 @@ public partial class ExtraNetworkCardViewModel : DisposableLoadableViewModelBase private HybridModelFile? selectedBaseModel; /// - public ExtraNetworkCardViewModel(IInferenceClientManager clientManager, ISettingsManager settingsManager) + public ExtraNetworkCardViewModel( + IInferenceClientManager clientManager, + ISettingsManager settingsManager, + IServiceManager vmFactory + ) { this.settingsManager = settingsManager; + this.vmFactory = vmFactory; ClientManager = clientManager; // Observable signal when SelectedBaseModel changes @@ -158,6 +166,23 @@ private void CopyTriggerWords() App.Clipboard.SetTextAsync(TriggerWords); } + [RelayCommand] + private async Task OpenLoraPickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select LoRA"; + pickerVm.Source = ModelPickerSource.Lora; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedModel = selected; + } + } + } + private bool FilterCompatibleLoras(HybridModelFile? lora) { if (!settingsManager.Settings.FilterExtraNetworksByBaseModel) diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs index 5de2d6d06..399b1e52c 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ImageGalleryCardViewModel.cs @@ -50,6 +50,36 @@ public partial class ImageGalleryCardViewModel : ViewModelBase [ObservableProperty] private bool isPixelGridEnabled; + /// + /// Synchronize SelectedImage when SelectedImageIndex changes. + /// + partial void OnSelectedImageIndexChanged(int value) + { + if (value >= 0 && value < ImageSources.Count) + { + SelectedImage = ImageSources[value]; + } + else + { + SelectedImage = null; + } + } + + /// + /// Synchronize SelectedImageIndex when SelectedImage changes (e.g., from thumbnail click). + /// + partial void OnSelectedImageChanged(ImageSource? value) + { + if (value is not null) + { + var index = ImageSources.IndexOf(value); + if (index >= 0 && index != SelectedImageIndex) + { + SelectedImageIndex = index; + } + } + } + public bool HasMultipleImages => ImageSources.Count > 1; public bool CanNavigateBack => SelectedImageIndex > 0; diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToImageViewModel.cs index 3450faa37..1a18da444 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceImageToImageViewModel.cs @@ -46,6 +46,7 @@ TabContext tabContext SelectImageCardViewModel = vmFactory.Get(vm => { vm.IsMaskEditorEnabled = true; + vm.SyncBitmapSizeToTabContext = true; }); SamplerCardViewModel.IsDenoiseStrengthEnabled = true; @@ -60,7 +61,7 @@ protected override void BuildPrompt(BuildPromptEventArgs args) builder.Connections.Seed = args.SeedOverride switch { { } seed => Convert.ToUInt64(seed), - _ => Convert.ToUInt64(SeedCardViewModel.Seed) + _ => Convert.ToUInt64(SeedCardViewModel.Seed), }; var applyArgs = args.ToModuleApplyStepEventArgs(); @@ -76,21 +77,10 @@ protected override void BuildPrompt(BuildPromptEventArgs args) // Prompts and loras PromptCardViewModel.ApplyStep(applyArgs); + ApplyModelSamplingForCurrentWorkflow(applyArgs); + // Setup Sampler and Refiner if enabled - var isUnetLoader = - ModelCardViewModel.SelectedModelLoader is ModelLoader.Unet || ModelCardViewModel.IsGguf; - if (isUnetLoader) - { - SamplerCardViewModel.ApplyStepsInitialCustomSampler(applyArgs, true); - } - else if (SamplerCardViewModel.SelectedScheduler?.Name is "align_your_steps") - { - SamplerCardViewModel.ApplyStepsInitialCustomSampler(applyArgs, false); - } - else - { - SamplerCardViewModel.ApplyStep(applyArgs); - } + ApplySamplerForCurrentWorkflow(applyArgs, includeGgufAsFluxGuidance: true); // Apply module steps foreach (var module in ModulesCardViewModel.Cards.OfType()) diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs index 39ff2eb33..d081f3717 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceTextToImageViewModel.cs @@ -1,11 +1,7 @@ -ο»Ώusing System; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; +ο»Ώusing System.Reactive.Linq; using System.Text.Json.Nodes; using System.Text.Json.Serialization; -using System.Threading; -using System.Threading.Tasks; +using DesktopNotifications; using DynamicData.Binding; using Injectio.Attributes; using NLog; @@ -19,7 +15,9 @@ using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy; +using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Inference; +using StabilityMatrix.Core.Models.Settings; using StabilityMatrix.Core.Services; using InferenceTextToImageView = StabilityMatrix.Avalonia.Views.Inference.InferenceTextToImageView; @@ -80,6 +78,7 @@ TabContext tabContext SeedCardViewModel.GenerateNewSeed(); ModelCardViewModel = vmFactory.Get(); + ModelCardViewModel.RecommendedDefaultsRequested += ApplyRecommendedDefaults; // When the model changes in the ModelCardViewModel, we'll have access to it in the TabContext @@ -105,6 +104,7 @@ TabContext tabContext typeof(HiresFixModule), typeof(SaveImageModule), typeof(UpscalerModule), + typeof(FaceDetailerModule), }; modulesCard.DefaultModules = new[] { typeof(HiresFixModule), typeof(UpscalerModule) }; modulesCard.InitializeDefaults(); @@ -169,12 +169,22 @@ protected override void BuildPrompt(BuildPromptEventArgs args) // Load models ModelCardViewModel.ApplyStep(applyArgs); - var isUnetLoader = ModelCardViewModel.SelectedModelLoader is ModelLoader.Unet; var useSd3Latent = - SamplerCardViewModel.ModulesCardViewModel.IsModuleEnabled() || isUnetLoader; + SamplerCardViewModel.ModulesCardViewModel.IsModuleEnabled() + || (IsUnetLoader && UsesSd3Latent); var usePlasmaNoise = SamplerCardViewModel.ModulesCardViewModel.IsModuleEnabled(); - if (useSd3Latent) + if (IsFlux2Unet) + { + builder.SetupEmptyLatentSource( + SamplerCardViewModel.Width, + SamplerCardViewModel.Height, + BatchSizeCardViewModel.BatchSize, + BatchSizeCardViewModel.IsBatchIndexEnabled ? BatchSizeCardViewModel.BatchIndex : null, + latentType: LatentType.Flux2 + ); + } + else if (useSd3Latent) { builder.SetupEmptyLatentSource( SamplerCardViewModel.Width, @@ -219,8 +229,108 @@ protected override void BuildPrompt(BuildPromptEventArgs args) // Prompts and loras PromptCardViewModel.ApplyStep(applyArgs); + ApplyModelSamplingForCurrentWorkflow(applyArgs); + // Setup Sampler and Refiner if enabled - if (isUnetLoader) + ApplySamplerForCurrentWorkflow(applyArgs); + + // Hires fix if enabled + foreach (var module in ModulesCardViewModel.Cards.OfType()) + { + module.ApplyStep(applyArgs); + } + + applyArgs.InvokeAllPreOutputActions(); + + builder.SetupOutputImage(); + } + + protected bool IsUnetLoader => ModelCardViewModel.SelectedModelLoader is ModelLoader.Unet; + + protected InferenceWorkflowProfile ResolvedWorkflowProfile => ModelCardViewModel.ResolvedWorkflowProfile; + + protected bool IsAnimaUnet => + IsUnetLoader + && ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Anima + || ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom + && ModelCardViewModel.SelectedClipType is "stable_diffusion" + ) + ); + + protected bool IsFlux2Unet => + IsUnetLoader + && ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Flux2 + || ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom + && ModelCardViewModel.SelectedClipType is "flux2" + ) + ); + + protected bool IsZImageUnet => + IsUnetLoader + && ( + ResolvedWorkflowProfile + is InferenceWorkflowProfile.ZImageBase + or InferenceWorkflowProfile.ZImageTurbo + || ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom + && ModelCardViewModel.SelectedClipType is "lumina2" + ) + ); + + protected bool UsesSd3Latent => + ResolvedWorkflowProfile + is InferenceWorkflowProfile.Flux + or InferenceWorkflowProfile.ZImageBase + or InferenceWorkflowProfile.ZImageTurbo + or InferenceWorkflowProfile.HiDream + || ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom + && ModelCardViewModel.SelectedClipType is not "stable_diffusion" and not "flux2" + ); + + protected bool UsesFluxGuidanceSampler => + ResolvedWorkflowProfile is InferenceWorkflowProfile.Flux or InferenceWorkflowProfile.HiDream + || ( + ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom + && IsUnetLoader + && ModelCardViewModel.SelectedClipType is not "stable_diffusion" and not "lumina2" and not "flux2" + ); + + protected void ApplyModelSamplingForCurrentWorkflow(ModuleApplyStepEventArgs applyArgs) + { + if (!IsZImageUnet) + return; + + var builder = applyArgs.Builder; + var modelSampling = builder.Nodes.AddTypedNode( + new ComfyNodeBuilder.ModelSamplingAuraFlow + { + Name = builder.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.ModelSamplingAuraFlow)), + Model = builder.Connections.Base.Model.Unwrap(), + Shift = ModelCardViewModel.Shift, + } + ); + + builder.Connections.Base.Model = modelSampling.Output; + } + + protected void ApplySamplerForCurrentWorkflow( + ModuleApplyStepEventArgs applyArgs, + bool includeGgufAsFluxGuidance = false + ) + { + if (IsFlux2Unet) + { + SamplerCardViewModel.ApplyStepsInitialCustomSampler(applyArgs, false, useFlux2Scheduler: true); + } + else if ( + (includeGgufAsFluxGuidance && ModelCardViewModel.IsGguf) + || (IsUnetLoader && UsesFluxGuidanceSampler) + ) { SamplerCardViewModel.ApplyStepsInitialCustomSampler(applyArgs, true); } @@ -232,16 +342,49 @@ protected override void BuildPrompt(BuildPromptEventArgs args) { SamplerCardViewModel.ApplyStep(applyArgs); } + } - // Hires fix if enabled - foreach (var module in ModulesCardViewModel.Cards.OfType()) + private void ApplyRecommendedDefaults(InferenceWorkflowProfile profile) + { + switch (profile) { - module.ApplyStep(applyArgs); + case InferenceWorkflowProfile.DefaultCheckpoint: + SamplerCardViewModel.SelectedSampler = ComfySampler.EulerAncestral; + SamplerCardViewModel.SelectedScheduler = ComfyScheduler.Normal; + SamplerCardViewModel.Steps = 30; + SamplerCardViewModel.CfgScale = 5.0d; + break; + case InferenceWorkflowProfile.Flux: + SamplerCardViewModel.SelectedSampler = ComfySampler.Euler; + SamplerCardViewModel.SelectedScheduler = ComfyScheduler.Simple; + SamplerCardViewModel.Steps = 20; + SamplerCardViewModel.CfgScale = 3.5d; + break; + case InferenceWorkflowProfile.Flux2: + SamplerCardViewModel.SelectedSampler = ComfySampler.Euler; + SamplerCardViewModel.SelectedScheduler = ComfyScheduler.Normal; + SamplerCardViewModel.Steps = 20; + SamplerCardViewModel.CfgScale = 5.0d; + break; + case InferenceWorkflowProfile.ZImageTurbo: + SamplerCardViewModel.SelectedSampler = ComfySampler.ResMultistep; + SamplerCardViewModel.SelectedScheduler = ComfyScheduler.Simple; + SamplerCardViewModel.Steps = 8; + SamplerCardViewModel.CfgScale = 1.0d; + break; + case InferenceWorkflowProfile.ZImageBase: + SamplerCardViewModel.SelectedSampler = ComfySampler.ResMultistep; + SamplerCardViewModel.SelectedScheduler = ComfyScheduler.Simple; + SamplerCardViewModel.Steps = 30; + SamplerCardViewModel.CfgScale = 4.0d; + break; + case InferenceWorkflowProfile.Anima: + SamplerCardViewModel.SelectedSampler = ComfySampler.ErSde; + SamplerCardViewModel.SelectedScheduler = ComfyScheduler.Simple; + SamplerCardViewModel.Steps = 30; + SamplerCardViewModel.CfgScale = 4.0d; + break; } - - applyArgs.InvokeAllPreOutputActions(); - - builder.SetupOutputImage(); } /// @@ -337,6 +480,24 @@ CancellationToken cancellationToken { await RunGeneration(args, cancellationToken); } + + // Only show batch notification when there's more than one item + // (single items already get a "Prompt Completed" notification) + if (batches > 1) + { + await notificationService.ShowAsync( + NotificationKey.Inference_BatchCompleted, + new Notification + { + Title = "Batch Completed", + Body = + $"Batch of {batches} items [{Guid.NewGuid().ToString()[..7].ToLower()}] completed successfully", + BodyImagePath = ImageGalleryCardViewModel + .ImageSources.LastOrDefault() + ?.LocalFile?.FullPath, + } + ); + } } /// diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceWanTextToVideoViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceWanTextToVideoViewModel.cs index 75f88189c..d6ba8743d 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceWanTextToVideoViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/InferenceWanTextToVideoViewModel.cs @@ -1,4 +1,5 @@ ο»Ώusing System.Text.Json.Serialization; +using DesktopNotifications; using Injectio.Attributes; using StabilityMatrix.Avalonia.Extensions; using StabilityMatrix.Avalonia.Models; @@ -9,6 +10,7 @@ using StabilityMatrix.Avalonia.Views.Inference; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Settings; using StabilityMatrix.Core.Services; namespace StabilityMatrix.Avalonia.ViewModels.Inference; @@ -17,6 +19,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference; [RegisterScoped, ManagedService] public class InferenceWanTextToVideoViewModel : InferenceGenerationViewModelBase, IParametersLoadableState { + private readonly INotificationService notificationService; + [JsonIgnore] public StackCardViewModel StackCardViewModel { get; } @@ -47,6 +51,7 @@ RunningPackageService runningPackageService ) : base(vmFactory, inferenceClientManager, notificationService, settingsManager, runningPackageService) { + this.notificationService = notificationService; SeedCardViewModel = vmFactory.Get(); SeedCardViewModel.GenerateNewSeed(); @@ -185,6 +190,24 @@ CancellationToken cancellationToken { await RunGeneration(args, cancellationToken); } + + // Only show batch notification when there's more than one item + // (single items already get a "Prompt Completed" notification) + if (batches > 1) + { + await notificationService.ShowAsync( + NotificationKey.Inference_BatchCompleted, + new Notification + { + Title = "Batch Completed", + Body = + $"Batch of {batches} items [{Guid.NewGuid().ToString()[..7].ToLower()}] completed successfully", + BodyImagePath = ImageGalleryCardViewModel + .ImageSources.LastOrDefault() + ?.LocalFile?.FullPath, + } + ); + } } /// diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs index 3daa6618f..845ad33b6 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/ModelCardViewModel.cs @@ -1,7 +1,9 @@ -ο»Ώusing System.ComponentModel.DataAnnotations; +ο»Ώusing System.Collections.ObjectModel; +using System.ComponentModel.DataAnnotations; using System.Text.Json.Nodes; using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; using Injectio.Attributes; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Languages; @@ -9,8 +11,10 @@ using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Models.Api.Comfy.Nodes; using StabilityMatrix.Core.Models.Api.Comfy.NodeTypes; @@ -28,12 +32,58 @@ TabContext tabContext ) : LoadableViewModelBase, IParametersLoadableState, IComfyStep { [ObservableProperty] + [NotifyPropertyChangedFor(nameof(SelectedUnifiedModel))] + [NotifyPropertyChangedFor(nameof(WorkflowProfileStatusText))] + [NotifyPropertyChangedFor(nameof(ShowWorkflowProfileStatus))] + [NotifyPropertyChangedFor(nameof(RecommendedDefaultsToolTip))] private HybridModelFile? selectedModel; [ObservableProperty] - [NotifyPropertyChangedFor(nameof(IsGguf), nameof(ShowPrecisionSelection))] + [NotifyPropertyChangedFor( + nameof(IsGguf), + nameof(ShowPrecisionSelection), + nameof(SelectedUnifiedModel), + nameof(HasActiveAdvancedOptions), + nameof(AdvancedOptionsHeader), + nameof(WorkflowProfileStatusText), + nameof(ShowWorkflowProfileStatus), + nameof(RecommendedDefaultsToolTip) + )] private HybridModelFile? selectedUnetModel; + /// + /// Unified model property that auto-detects the loader type based on the model's SharedFolderType. + /// Getter returns the currently active model based on IsStandaloneModelLoader. + /// Setter auto-detects whether it's a checkpoint or UNet model and sets the appropriate properties. + /// + public HybridModelFile? SelectedUnifiedModel + { + get => IsStandaloneModelLoader ? SelectedUnetModel : SelectedModel; + set + { + if (value is null) + { + // ComboBox selection can briefly report null while the model list refreshes. + // Keep the active model so UNet-only encoder slots do not disappear transiently. + return; + } + + // Auto-detect model type based on folder + if (value.Local?.SharedFolderType == SharedFolderType.DiffusionModels) + { + // It's a UNet model from diffusion_models folder + SelectedModelLoader = ModelLoader.Unet; + SelectedUnetModel = value; + } + else + { + // It's a checkpoint model + SelectedModelLoader = ModelLoader.Default; + SelectedModel = value; + } + } + } + [ObservableProperty] private bool isRefinerSelectionEnabled; @@ -47,12 +97,14 @@ TabContext tabContext private HybridModelFile? selectedVae = HybridModelFile.Default; [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasActiveAdvancedOptions), nameof(AdvancedOptionsHeader))] private bool isVaeSelectionEnabled; [ObservableProperty] private bool disableSettings; [ObservableProperty] + [NotifyPropertyChangedFor(nameof(HasActiveAdvancedOptions), nameof(AdvancedOptionsHeader))] private bool isClipSkipEnabled; [NotifyDataErrorInfo] @@ -67,25 +119,79 @@ TabContext tabContext private bool isModelLoaderSelectionEnabled; [ObservableProperty] - [NotifyPropertyChangedFor(nameof(IsStandaloneModelLoader))] + [NotifyPropertyChangedFor( + nameof(IsStandaloneModelLoader), + nameof(SelectedUnifiedModel), + nameof(ShowPrecisionSelection), + nameof(ShowEncoderSection), + nameof(HasActiveAdvancedOptions), + nameof(AdvancedOptionsHeader), + nameof(WorkflowProfileStatusText), + nameof(ShowWorkflowProfileStatus), + nameof(RecommendedDefaultsToolTip) + )] private ModelLoader selectedModelLoader; - [ObservableProperty] - private HybridModelFile? selectedClip1; + /// + /// Dynamic collection of text encoder slots. + /// + public ObservableCollection TextEncoders { get; } = []; - [ObservableProperty] - private HybridModelFile? selectedClip2; + /// + /// Whether the remove encoder button should be enabled. + /// + public bool CanRemoveEncoder => TextEncoders.Count > 1; - [ObservableProperty] - private HybridModelFile? selectedClip3; + /// + /// Gets the selected model for encoder slot 1 (for backward compatibility with SetupClipLoaders). + /// + public HybridModelFile? SelectedClip1 => TextEncoders.Count > 0 ? TextEncoders[0].SelectedModel : null; - [ObservableProperty] - private HybridModelFile? selectedClip4; + /// + /// Gets the selected model for encoder slot 2 (for backward compatibility with SetupClipLoaders). + /// + public HybridModelFile? SelectedClip2 => TextEncoders.Count > 1 ? TextEncoders[1].SelectedModel : null; + + /// + /// Gets the selected model for encoder slot 3 (for backward compatibility with SetupClipLoaders). + /// + public HybridModelFile? SelectedClip3 => TextEncoders.Count > 2 ? TextEncoders[2].SelectedModel : null; + + /// + /// Gets the selected model for encoder slot 4 (for backward compatibility with SetupClipLoaders). + /// + public HybridModelFile? SelectedClip4 => TextEncoders.Count > 3 ? TextEncoders[3].SelectedModel : null; [ObservableProperty] - [NotifyPropertyChangedFor(nameof(IsSd3Clip), nameof(IsHiDreamClip))] + [NotifyPropertyChangedFor( + nameof(IsSd3Clip), + nameof(IsHiDreamClip), + nameof(ResolvedWorkflowProfile), + nameof(IsHiDreamWorkflow), + nameof(IsZImageWorkflow), + nameof(ShowShift), + nameof(ShowEncoderTypeSelection), + nameof(HasRecommendedDefaults), + nameof(WorkflowProfileStatusText), + nameof(ShowWorkflowProfileStatus), + nameof(RecommendedDefaultsToolTip) + )] private string? selectedClipType; + [ObservableProperty] + [NotifyPropertyChangedFor( + nameof(ResolvedWorkflowProfile), + nameof(IsHiDreamWorkflow), + nameof(IsZImageWorkflow), + nameof(ShowShift), + nameof(ShowEncoderTypeSelection), + nameof(HasRecommendedDefaults), + nameof(WorkflowProfileStatusText), + nameof(ShowWorkflowProfileStatus), + nameof(RecommendedDefaultsToolTip) + )] + private InferenceWorkflowProfile selectedWorkflowProfile = InferenceWorkflowProfile.Auto; + [ObservableProperty] private string? selectedDType; @@ -93,13 +199,29 @@ TabContext tabContext private bool enableModelLoaderSelection = true; [ObservableProperty] + [NotifyPropertyChangedFor(nameof(ShowEncoderSection))] private bool isClipModelSelectionEnabled; [ObservableProperty] private double shift = 3.0d; + /// + /// Whether the Advanced Options expander is expanded. + /// + [ObservableProperty] + private bool isAdvancedOptionsExpanded; + + /// + /// Whether the Text Encoders expander is expanded. + /// + [ObservableProperty] + private bool isTextEncodersExpanded = true; + public List WeightDTypes { get; set; } = ["default", "fp8_e4m3fn", "fp8_e5m2"]; - public List ClipTypes { get; set; } = ["flux", "sd3", "HiDream"]; + public List ClipTypes { get; set; } = + ["flux", "flux2", "lumina2", "stable_diffusion", "sd3", "HiDream"]; + public List WorkflowProfiles { get; set; } = + Enum.GetValues().ToList(); public StackEditableCardViewModel ExtraNetworksStackCardViewModel { get; } = new(vmFactory) { Title = Resources.Label_ExtraNetworks, AvailableModules = [typeof(LoraModule)] }; @@ -111,14 +233,103 @@ TabContext tabContext public bool IsStandaloneModelLoader => SelectedModelLoader is ModelLoader.Unet; public bool ShowPrecisionSelection => SelectedModelLoader is ModelLoader.Unet && !IsGguf; + + /// + /// Whether to show the encoder section (only for UNet models when encoder selection is enabled). + /// + public bool ShowEncoderSection => IsClipModelSelectionEnabled && IsStandaloneModelLoader; + public bool IsSd3Clip => SelectedClipType == "sd3"; public bool IsHiDreamClip => SelectedClipType == "HiDream"; + public bool IsHiDreamWorkflow => + ResolvedWorkflowProfile is InferenceWorkflowProfile.HiDream + || (ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom && SelectedClipType == "HiDream"); + public bool IsZImageWorkflow => + ResolvedWorkflowProfile is InferenceWorkflowProfile.ZImageBase or InferenceWorkflowProfile.ZImageTurbo + || (ResolvedWorkflowProfile is InferenceWorkflowProfile.Custom && SelectedClipType == "lumina2"); public bool IsGguf => SelectedUnetModel?.RelativePath.EndsWith("gguf") ?? false; + /// + /// Whether any advanced options are currently visible (for expander header indication). + /// Includes: Precision (UNet only), VAE, CLIP Skip. + /// + public bool HasActiveAdvancedOptions => + ShowPrecisionSelection || IsVaeSelectionEnabled || IsClipSkipEnabled; + + /// + /// Header text for the Advanced Options expander, showing count of active options. + /// + public string AdvancedOptionsHeader + { + get + { + var count = + (ShowPrecisionSelection ? 1 : 0) + + (IsVaeSelectionEnabled ? 1 : 0) + + (IsClipSkipEnabled ? 1 : 0); + return count > 0 ? $"Advanced Options ({count})" : "Advanced Options"; + } + } + + /// + /// Header text for the Text Encoders expander, showing count of encoders. + /// + public string TextEncodersHeader => $"Text Encoders ({TextEncoders.Count})"; + + /// + /// Whether to show the Shift control (for HiDream and Z-Image workflows, only when in UNet mode). + /// + public bool ShowShift => ShowEncoderSection && (IsHiDreamWorkflow || IsZImageWorkflow); + public bool ShowEncoderTypeSelection => + SelectedWorkflowProfile is InferenceWorkflowProfile.Custom + || (ShowEncoderSection && SelectedWorkflowProfile is not InferenceWorkflowProfile.Auto); + public InferenceWorkflowProfile ResolvedWorkflowProfile => + SelectedWorkflowProfile is InferenceWorkflowProfile.Auto + ? InferWorkflowProfile() + : SelectedWorkflowProfile; + public bool HasRecommendedDefaults => + ResolvedWorkflowProfile + is InferenceWorkflowProfile.DefaultCheckpoint + or InferenceWorkflowProfile.Flux + or InferenceWorkflowProfile.Flux2 + or InferenceWorkflowProfile.ZImageBase + or InferenceWorkflowProfile.ZImageTurbo + or InferenceWorkflowProfile.Anima; + public bool ShowWorkflowProfileStatus => + SelectedWorkflowProfile is InferenceWorkflowProfile.Auto + && SelectedUnifiedModel is not null + && ResolvedWorkflowProfile is not InferenceWorkflowProfile.Custom; + public string WorkflowProfileStatusText => $"Detected: {ResolvedWorkflowProfile.GetStringValue()}"; + public string RecommendedDefaultsToolTip => + ResolvedWorkflowProfile switch + { + InferenceWorkflowProfile.DefaultCheckpoint => + "Apply recommended sampler defaults: Euler Ancestral / Normal / 30 steps / CFG 5", + InferenceWorkflowProfile.Flux => + "Apply recommended sampler defaults: Euler / Simple / 20 steps / CFG 3.5", + InferenceWorkflowProfile.Flux2 => + "Apply recommended sampler defaults: Euler / Flux2Scheduler / 20 steps / CFG 5", + InferenceWorkflowProfile.ZImageBase => + "Apply recommended sampler defaults: Res Multistep / Simple / 30 steps / CFG 4", + InferenceWorkflowProfile.ZImageTurbo => + "Apply recommended sampler defaults: Res Multistep / Simple / 8 steps / CFG 1", + InferenceWorkflowProfile.Anima => + "Apply recommended sampler defaults: ER SDE / Simple / 30 steps / CFG 4", + _ => "No recommended sampler defaults for this workflow", + }; + + public event Action? RecommendedDefaultsRequested; + protected override void OnInitialLoaded() { base.OnInitialLoaded(); ExtraNetworksStackCardViewModel.CardAdded += ExtraNetworksStackCardViewModelOnCardAdded; + + // Initialize default encoders if empty + if (TextEncoders.Count == 0) + { + SetDefaultEncoderCount(); + } } public override void OnUnloaded() @@ -154,6 +365,87 @@ You can use a config (.yaml) file to load a model with specific settings. .ShowAsync(); } + [RelayCommand] + private async Task OpenModelPickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select Model"; + pickerVm.PreferredWorkflowProfile = SelectedWorkflowProfile; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + // Auto-detect model type based on folder + if (selected.Local?.SharedFolderType == SharedFolderType.DiffusionModels) + { + // It's a UNet model from diffusion_models folder + SelectedModelLoader = ModelLoader.Unet; + SelectedUnetModel = selected; + } + else + { + // It's a checkpoint model + SelectedModelLoader = ModelLoader.Default; + SelectedModel = selected; + } + } + } + } + + [RelayCommand] + private async Task OpenRefinerPickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select Refiner"; + pickerVm.Source = ModelPickerSource.CheckpointAndUnet; + pickerVm.ShowCheckpointsOnly = true; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedRefiner = selected; + } + } + } + + [RelayCommand] + private async Task OpenVaePickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select VAE"; + pickerVm.Source = ModelPickerSource.Vae; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedVae = selected; + } + } + } + + [RelayCommand] + private async Task OpenClipPickerAsync(TextEncoderSlotViewModel encoderSlot) + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = $"Select Text Encoder ({encoderSlot.Label})"; + pickerVm.Source = ModelPickerSource.Clip; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + encoderSlot.SelectedModel = selected; + } + } + } + public async Task ValidateModel() { if (IsStandaloneModelLoader && SelectedUnetModel != null) @@ -239,6 +531,9 @@ public virtual void ApplyStep(ModuleApplyStepEventArgs e) /// public override JsonObject SaveStateToJsonObject() { + // Build encoder names list from dynamic collection + var encoderNames = TextEncoders.Select(e => e.SelectedModel?.RelativePath).ToList(); + return SerializeModel( new ModelCardModel { @@ -253,14 +548,22 @@ public override JsonObject SaveStateToJsonObject() IsClipSkipEnabled = IsClipSkipEnabled, IsExtraNetworksEnabled = IsExtraNetworksEnabled, IsModelLoaderSelectionEnabled = IsModelLoaderSelectionEnabled, - SelectedClip1Name = SelectedClip1?.RelativePath, - SelectedClip2Name = SelectedClip2?.RelativePath, - SelectedClip3Name = SelectedClip3?.RelativePath, - SelectedClip4Name = SelectedClip4?.RelativePath, + // For backward compatibility, also save to legacy fields + SelectedClip1Name = encoderNames.ElementAtOrDefault(0), + SelectedClip2Name = encoderNames.ElementAtOrDefault(1), + SelectedClip3Name = encoderNames.ElementAtOrDefault(2), + SelectedClip4Name = encoderNames.ElementAtOrDefault(3), + // New field for dynamic encoders (for future proofing if > 4 encoders) + TextEncoderNames = encoderNames, SelectedClipType = SelectedClipType, + SelectedWorkflowProfile = SelectedWorkflowProfile, + SelectedDType = SelectedDType, + Shift = Shift, IsClipModelSelectionEnabled = IsClipModelSelectionEnabled, ModelLoader = SelectedModelLoader, ShowRefinerOption = ShowRefinerOption, + IsAdvancedOptionsExpanded = IsAdvancedOptionsExpanded, + IsTextEncodersExpanded = IsTextEncodersExpanded, ExtraNetworks = ExtraNetworksStackCardViewModel.SaveStateToJsonObject(), } ); @@ -271,64 +574,204 @@ public override void LoadStateFromJsonObject(JsonObject state) { var model = DeserializeModel(state); - // uwu 123 - // :thinknom: - // :thinkcode: - SelectedModelLoader = model.ModelLoader is ModelLoader.Gguf ? ModelLoader.Unet : model.ModelLoader; + // Set loading flag to prevent auto-adjustment of encoder count + isLoadingState = true; - if (SelectedModelLoader is ModelLoader.Unet) + try { - SelectedUnetModel = model.SelectedModelName is null - ? null - : ClientManager.UnetModels.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); + // uwu 123 + // :thinknom: + // :thinkcode: + SelectedModelLoader = + model.ModelLoader is ModelLoader.Gguf ? ModelLoader.Unet : model.ModelLoader; + + if (SelectedModelLoader is ModelLoader.Unet) + { + SelectedUnetModel = model.SelectedModelName is null + ? null + : ClientManager.UnetModels.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); + } + else + { + SelectedModel = model.SelectedModelName is null + ? null + : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); + } + + SelectedVae = model.SelectedVaeName is null + ? HybridModelFile.Default + : ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName); + + SelectedRefiner = model.SelectedRefinerName is null + ? HybridModelFile.None + : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedRefinerName); + + // Load encoder type first (needed for default encoder count) + SelectedClipType = model.SelectedClipType; + SelectedWorkflowProfile = model.SelectedWorkflowProfile; + + // Load text encoders from saved state + LoadTextEncodersFromModel(model); + + SelectedDType = model.SelectedDType; + Shift = model.Shift; + ClipSkip = model.ClipSkip; + + IsVaeSelectionEnabled = model.IsVaeSelectionEnabled; + IsRefinerSelectionEnabled = model.IsRefinerSelectionEnabled; + ShowRefinerOption = model.ShowRefinerOption; + IsClipSkipEnabled = model.IsClipSkipEnabled; + IsExtraNetworksEnabled = model.IsExtraNetworksEnabled; + IsModelLoaderSelectionEnabled = model.IsModelLoaderSelectionEnabled; + IsClipModelSelectionEnabled = model.IsClipModelSelectionEnabled; + IsAdvancedOptionsExpanded = model.IsAdvancedOptionsExpanded; + IsTextEncodersExpanded = model.IsTextEncodersExpanded; + + if (model.ExtraNetworks is not null) + { + ExtraNetworksStackCardViewModel.LoadStateFromJsonObject(model.ExtraNetworks); + } } - else + finally { - SelectedModel = model.SelectedModelName is null - ? null - : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedModelName); + isLoadingState = false; + NotifyWorkflowProfileStateChanged(); } + } - SelectedVae = model.SelectedVaeName is null - ? HybridModelFile.Default - : ClientManager.VaeModels.FirstOrDefault(x => x.RelativePath == model.SelectedVaeName); + private InferenceWorkflowProfile InferWorkflowProfile() + { + return InferWorkflowProfile( + SelectedUnifiedModel, + SelectedModelLoader is ModelLoader.Unet + || SelectedUnifiedModel?.Local?.SharedFolderType is SharedFolderType.DiffusionModels + ); + } - SelectedRefiner = model.SelectedRefinerName is null - ? HybridModelFile.None - : ClientManager.Models.FirstOrDefault(x => x.RelativePath == model.SelectedRefinerName); + private static InferenceWorkflowProfile InferWorkflowProfile(HybridModelFile? model, bool isUnetModel) + { + if (!isUnetModel) + return InferenceWorkflowProfile.DefaultCheckpoint; - SelectedClip1 = model.SelectedClip1Name is null - ? HybridModelFile.None - : ClientManager.ClipModels.FirstOrDefault(x => x.RelativePath == model.SelectedClip1Name); + var baseModel = model?.Local?.ConnectedModelInfo?.BaseModel; - SelectedClip2 = model.SelectedClip2Name is null - ? HybridModelFile.None - : ClientManager.ClipModels.FirstOrDefault(x => x.RelativePath == model.SelectedClip2Name); + if (!string.IsNullOrWhiteSpace(baseModel)) + { + if (baseModel.Equals("ZImageTurbo", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.ZImageTurbo; - SelectedClip3 = model.SelectedClip3Name is null - ? HybridModelFile.None - : ClientManager.ClipModels.FirstOrDefault(x => x.RelativePath == model.SelectedClip3Name); + if (baseModel.Equals("ZImageBase", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.ZImageBase; - SelectedClip4 = model.SelectedClip4Name is null - ? HybridModelFile.None - : ClientManager.ClipModels.FirstOrDefault(x => x.RelativePath == model.SelectedClip4Name); + if (baseModel.Equals("Anima", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.Anima; - SelectedClipType = model.SelectedClipType; + if (baseModel.StartsWith("Flux.2", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.Flux2; - ClipSkip = model.ClipSkip; + if (baseModel.StartsWith("Flux.1", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.Flux; - IsVaeSelectionEnabled = model.IsVaeSelectionEnabled; - IsRefinerSelectionEnabled = model.IsRefinerSelectionEnabled; - ShowRefinerOption = model.ShowRefinerOption; - IsClipSkipEnabled = model.IsClipSkipEnabled; - IsExtraNetworksEnabled = model.IsExtraNetworksEnabled; - IsModelLoaderSelectionEnabled = model.IsModelLoaderSelectionEnabled; - IsClipModelSelectionEnabled = model.IsClipModelSelectionEnabled; + if (baseModel.Equals("HiDream", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.HiDream; + } - if (model.ExtraNetworks is not null) + var name = model?.RelativePath ?? string.Empty; + + if ( + name.Contains("z_image", StringComparison.OrdinalIgnoreCase) + || name.Contains("z-image", StringComparison.OrdinalIgnoreCase) + || name.Contains("zimage", StringComparison.OrdinalIgnoreCase) + ) { - ExtraNetworksStackCardViewModel.LoadStateFromJsonObject(model.ExtraNetworks); + return name.Contains("turbo", StringComparison.OrdinalIgnoreCase) + ? InferenceWorkflowProfile.ZImageTurbo + : InferenceWorkflowProfile.ZImageBase; } + + if ( + name.Contains("flux2", StringComparison.OrdinalIgnoreCase) + || name.Contains("flux-2", StringComparison.OrdinalIgnoreCase) + || name.Contains("flux_2", StringComparison.OrdinalIgnoreCase) + ) + return InferenceWorkflowProfile.Flux2; + + if (name.Contains("flux", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.Flux; + + if (name.Contains("anima", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.Anima; + + if (name.Contains("hidream", StringComparison.OrdinalIgnoreCase)) + return InferenceWorkflowProfile.HiDream; + + return InferenceWorkflowProfile.DefaultCheckpoint; + } + + /// + /// Loads text encoders from the saved model state, supporting both new and legacy formats. + /// + private void LoadTextEncodersFromModel(ModelCardModel model) + { + TextEncoders.Clear(); + + // Try new format first (TextEncoderNames list) + if (model.TextEncoderNames is { Count: > 0 }) + { + for (var i = 0; i < model.TextEncoderNames.Count; i++) + { + var slot = new TextEncoderSlotViewModel(i + 1); + var encoderName = model.TextEncoderNames[i]; + if (encoderName is not null) + { + slot.SelectedModel = ClientManager.ClipModels.FirstOrDefault(x => + x.RelativePath == encoderName + ); + } + TextEncoders.Add(slot); + } + } + else + { + // Fall back to legacy format (SelectedClip1-4) + var legacyNames = new[] + { + model.SelectedClip1Name, + model.SelectedClip2Name, + model.SelectedClip3Name, + model.SelectedClip4Name, + }; + + // Count how many legacy encoders were set (non-null) + var encoderCount = legacyNames.TakeWhile(n => n is not null).Count(); + + // Use at least the default count for the encoder type + var defaultCount = SelectedClipType switch + { + "flux" => 2, + "flux2" or "lumina2" or "stable_diffusion" => 1, + "sd3" => 3, + "HiDream" => 4, + _ => 2, + }; + encoderCount = Math.Max(encoderCount, defaultCount); + + for (var i = 0; i < encoderCount; i++) + { + var slot = new TextEncoderSlotViewModel(i + 1); + var encoderName = legacyNames.ElementAtOrDefault(i); + if (encoderName is not null) + { + slot.SelectedModel = ClientManager.ClipModels.FirstOrDefault(x => + x.RelativePath == encoderName + ); + } + TextEncoders.Add(slot); + } + } + + OnPropertyChanged(nameof(CanRemoveEncoder)); + OnPropertyChanged(nameof(TextEncodersHeader)); } /// @@ -385,10 +828,12 @@ public void LoadStateFromParameters(GenerationParameters parameters) if (model.Local?.SharedFolderType is SharedFolderType.DiffusionModels) { + SelectedModelLoader = ModelLoader.Unet; SelectedUnetModel = model; } else { + SelectedModelLoader = ModelLoader.Default; SelectedModel = model; } } @@ -421,7 +866,12 @@ partial void OnSelectedModelLoaderChanged(ModelLoader value) if (!IsClipModelSelectionEnabled) IsClipModelSelectionEnabled = true; + + if (TextEncoders.Count == 0) + SetDefaultEncoderCount(); } + + RefreshWorkflowProfileState(); } partial void OnSelectedModelChanged(HybridModelFile? value) @@ -444,7 +894,151 @@ partial void OnSelectedModelChanged(HybridModelFile? value) } } - partial void OnSelectedUnetModelChanged(HybridModelFile? value) => OnSelectedModelChanged(value); + partial void OnSelectedUnetModelChanged(HybridModelFile? value) + { + OnSelectedModelChanged(value); + RefreshWorkflowProfileState(); + } + + partial void OnSelectedClipTypeChanged(string? value) + { + // When encoder type changes, set the default encoder count for that type + // But only if we're not loading state (to preserve user's custom encoder count) + if (!isLoadingState) + { + SetDefaultEncoderCount(preserveUserSelections: true); + } + } + + partial void OnSelectedWorkflowProfileChanged(InferenceWorkflowProfile value) + { + if (!isLoadingState) + { + ApplyDefaultClipTypeForResolvedProfile(preserveUserSelections: true); + } + + RefreshWorkflowProfileState(); + } + + private void NotifyWorkflowProfileStateChanged() + { + OnPropertyChanged(nameof(ResolvedWorkflowProfile)); + OnPropertyChanged(nameof(IsHiDreamWorkflow)); + OnPropertyChanged(nameof(IsZImageWorkflow)); + OnPropertyChanged(nameof(ShowShift)); + OnPropertyChanged(nameof(ShowEncoderTypeSelection)); + OnPropertyChanged(nameof(HasRecommendedDefaults)); + OnPropertyChanged(nameof(WorkflowProfileStatusText)); + OnPropertyChanged(nameof(ShowWorkflowProfileStatus)); + OnPropertyChanged(nameof(RecommendedDefaultsToolTip)); + } + + private void RefreshWorkflowProfileState() + { + NotifyWorkflowProfileStateChanged(); + + if (!isLoadingState) + { + ApplyDefaultClipTypeForResolvedProfile(preserveUserSelections: true); + } + } + + private void ApplyDefaultClipTypeForResolvedProfile(bool preserveUserSelections) + { + if (SelectedWorkflowProfile is InferenceWorkflowProfile.Custom) + return; + + var clipType = ResolvedWorkflowProfile switch + { + InferenceWorkflowProfile.Flux => "flux", + InferenceWorkflowProfile.Flux2 => "flux2", + InferenceWorkflowProfile.ZImageBase or InferenceWorkflowProfile.ZImageTurbo => "lumina2", + InferenceWorkflowProfile.Anima => "stable_diffusion", + InferenceWorkflowProfile.HiDream => "HiDream", + _ => SelectedClipType, + }; + + if (string.IsNullOrWhiteSpace(clipType) || SelectedClipType == clipType) + return; + + SelectedClipType = clipType; + SetDefaultEncoderCount(preserveUserSelections); + } + + /// + /// Flag to prevent auto-adjustment during state loading. + /// + private bool isLoadingState; + + /// + /// Sets the default number of encoder slots based on the selected clip type. + /// + /// If true, only adjust if no encoders have been configured yet. + private void SetDefaultEncoderCount(bool preserveUserSelections = false) + { + // If preserving user selections and any encoder has a model selected, skip adjustment + if (preserveUserSelections && TextEncoders.Any(e => e.SelectedModel is { IsNone: false })) + { + return; + } + + var targetCount = SelectedClipType switch + { + "flux" => 2, + "flux2" or "lumina2" or "stable_diffusion" => 1, + "sd3" => 3, + "HiDream" => 4, + _ => 2, // Default to 2 for unknown types + }; + + // Add or remove encoders to match target count + while (TextEncoders.Count < targetCount) + { + TextEncoders.Add(new TextEncoderSlotViewModel(TextEncoders.Count + 1)); + } + + while (TextEncoders.Count > targetCount) + { + TextEncoders.RemoveAt(TextEncoders.Count - 1); + } + + OnPropertyChanged(nameof(CanRemoveEncoder)); + OnPropertyChanged(nameof(TextEncodersHeader)); + } + + /// + /// Adds a new text encoder slot. + /// + [RelayCommand] + private void AddEncoder() + { + TextEncoders.Add(new TextEncoderSlotViewModel(TextEncoders.Count + 1)); + OnPropertyChanged(nameof(CanRemoveEncoder)); + OnPropertyChanged(nameof(TextEncodersHeader)); + } + + /// + /// Removes the last text encoder slot. + /// + [RelayCommand] + private void RemoveEncoder() + { + if (TextEncoders.Count > 1) + { + TextEncoders.RemoveAt(TextEncoders.Count - 1); + OnPropertyChanged(nameof(CanRemoveEncoder)); + OnPropertyChanged(nameof(TextEncodersHeader)); + } + } + + [RelayCommand] + private void ApplyRecommendedDefaults() + { + if (!HasRecommendedDefaults) + return; + + RecommendedDefaultsRequested?.Invoke(ResolvedWorkflowProfile); + } private void SetupStandaloneModelLoader(ModuleApplyStepEventArgs e) { @@ -476,7 +1070,7 @@ private void SetupStandaloneModelLoader(ModuleApplyStepEventArgs e) e.Builder.Connections.Base.Model = checkpointLoader.Output; } - if (SelectedModelLoader is ModelLoader.Unet && IsHiDreamClip) + if (SelectedModelLoader is ModelLoader.Unet && IsHiDreamWorkflow) { var modelSamplingSd3 = e.Nodes.AddTypedNode( new ComfyNodeBuilder.ModelSamplingSD3 @@ -658,6 +1252,13 @@ private void SetupClipLoaders(ModuleApplyStepEventArgs e) ); e.Builder.Connections.Base.Clip = clipLoader.Output; } + else + { + // No valid encoders configured + throw new ValidationException( + "No text encoders configured. Please select at least one encoder model." + ); + } } internal class ModelCardModel @@ -665,11 +1266,21 @@ internal class ModelCardModel public string? SelectedModelName { get; init; } public string? SelectedRefinerName { get; init; } public string? SelectedVaeName { get; init; } + + // Legacy encoder fields (for backward compatibility) public string? SelectedClip1Name { get; init; } public string? SelectedClip2Name { get; init; } public string? SelectedClip3Name { get; init; } public string? SelectedClip4Name { get; init; } + + // New dynamic encoder list (supports any number of encoders) + public List? TextEncoderNames { get; init; } + public string? SelectedClipType { get; init; } + public InferenceWorkflowProfile SelectedWorkflowProfile { get; init; } = + InferenceWorkflowProfile.Auto; + public string? SelectedDType { get; init; } + public double Shift { get; init; } = 3.0; public ModelLoader ModelLoader { get; init; } public int ClipSkip { get; init; } = 1; @@ -681,6 +1292,9 @@ internal class ModelCardModel public bool IsClipModelSelectionEnabled { get; init; } public bool ShowRefinerOption { get; init; } + public bool IsAdvancedOptionsExpanded { get; init; } + public bool IsTextEncodersExpanded { get; init; } = true; + public JsonObject? ExtraNetworks { get; init; } } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/RegionalPromptModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/RegionalPromptModule.cs new file mode 100644 index 000000000..123c2dd44 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/RegionalPromptModule.cs @@ -0,0 +1,294 @@ +using System.Collections.Generic; +using System.IO; +using Injectio.Attributes; +using SkiaSharp; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.Api.Comfy.Nodes; + +namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; + +/// +/// Module for regional prompting - apply different prompts to different regions of the image. +/// Uses layers with painted masks to define regions. +/// +[ManagedService] +[RegisterTransient] +public class RegionalPromptModule : ModuleBase +{ + /// + public RegionalPromptModule(IServiceManager vmFactory) + : base(vmFactory) + { + Title = "Regional Prompting"; + AddCards(vmFactory.Get()); + } + + /// + protected override IEnumerable GetInputImages() + { + // Regional prompting masks are transferred via FilesToTransfer + yield break; + } + + /// + protected override void OnApplyStep(ModuleApplyStepEventArgs e) + { + var card = GetCard(); + var maskCounter = 0; + var maskFileNames = new List(); + + // Clean up old mask files from previous generations + CleanupOldMaskFiles(); + + // Sync canvas size from the generation resolution + // This ensures masks are rendered at the correct size even if user changed dimensions + var primarySize = e.Builder.Connections.PrimarySize; + if (primarySize is { Width: > 0, Height: > 0 }) + { + card.SetCanvasSize(primarySize.Width, primarySize.Height); + } + + // Get enabled layers with content + var enabledLayers = card.GetEnabledLayersWithContent(); + + if (enabledLayers.Count == 0) + { + // No layers defined, nothing to do + return; + } + + // Start with the base positive and negative conditioning + var currentPositive = e.Temp.Base.Conditioning!.Unwrap().Positive; + var currentNegative = e.Temp.Base.Conditioning.Negative; + + // Process each layer + foreach (var layer in enabledLayers) + { + // Render layer to mask + using var maskImage = card.RenderLayerToMask(layer); + if (maskImage is null) + continue; + + // Save mask to temp file and add to file transfers + var maskFileName = GetMaskFileName(layer, maskCounter); + maskFileNames.Add(maskFileName); + var tempPath = SaveMaskToTempFile(maskImage, maskFileName); + + // Add to file transfers so it gets uploaded to ComfyUI's input/Inference folder + e.AddFileTransfer(tempPath, $"input/Inference/{maskFileName}"); + + // Load the mask in the workflow + var loadedMask = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.LoadImageMask + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_LoadMask_{maskCounter}"), + Image = $"Inference/{maskFileName}", + Channel = "red", + } + ); + + // Encode the layer's prompt + var layerClip = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.CLIPTextEncode + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_CLIP_{maskCounter}"), + Clip = e.Builder.Connections.Base.Clip!, + Text = layer.Prompt, + } + ); + + // Apply the mask to the conditioning + var maskedConditioning = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ConditioningSetMask + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_SetMask_{maskCounter}"), + Conditioning = layerClip.Output, + Mask = loadedMask.Output, + Strength = layer.Strength, + SetCondArea = layer.ConditioningAreaValue, + } + ); + + // Combine with the current positive conditioning + var combined = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ConditioningCombine + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_Combine_{maskCounter}"), + Conditioning1 = currentPositive, + Conditioning2 = maskedConditioning.Output, + } + ); + + currentPositive = combined.Output; + + // Handle per-layer negative prompt if specified + if (!string.IsNullOrWhiteSpace(layer.NegativePrompt)) + { + var layerNegClip = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.CLIPTextEncode + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_NegCLIP_{maskCounter}"), + Clip = e.Builder.Connections.Base.Clip!, + Text = layer.NegativePrompt, + } + ); + + var maskedNegConditioning = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ConditioningSetMask + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_NegSetMask_{maskCounter}"), + Conditioning = layerNegClip.Output, + Mask = loadedMask.Output, + Strength = layer.Strength, + SetCondArea = layer.ConditioningAreaValue, + } + ); + + var combinedNeg = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ConditioningCombine + { + Name = e.Nodes.GetUniqueName($"RegionalPrompt_NegCombine_{maskCounter}"), + Conditioning1 = currentNegative, + Conditioning2 = maskedNegConditioning.Output, + } + ); + + currentNegative = combinedNeg.Output; + } + + maskCounter++; + } + + // Update the base conditioning with our combined regional conditioning + e.Temp.Base.Conditioning = (currentPositive, currentNegative); + + // Apply to refiner if available + if (e.Temp.Refiner.Conditioning is not null) + { + ApplyToRefiner(e, enabledLayers, maskFileNames, card); + } + } + + private void ApplyToRefiner( + ModuleApplyStepEventArgs e, + IReadOnlyList enabledLayers, + IReadOnlyList maskFileNames, + RegionalPromptCardViewModel card + ) + { + var refinerPositive = e.Temp.Refiner.Conditioning!.Positive; + var refinerNegative = e.Temp.Refiner.Conditioning.Negative; + var refinerMaskCounter = 0; + + foreach (var layer in enabledLayers) + { + // Reuse the same mask filename from base pass (already uploaded) + var maskFileName = maskFileNames[refinerMaskCounter]; + + var loadedMask = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.LoadImageMask + { + Name = e.Nodes.GetUniqueName($"Refiner_RegionalPrompt_LoadMask_{refinerMaskCounter}"), + Image = $"Inference/{maskFileName}", + Channel = "red", + } + ); + + var layerClip = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.CLIPTextEncode + { + Name = e.Nodes.GetUniqueName($"Refiner_RegionalPrompt_CLIP_{refinerMaskCounter}"), + Clip = e.Builder.Connections.Refiner.Clip ?? e.Builder.Connections.Base.Clip!, + Text = layer.Prompt, + } + ); + + var maskedConditioning = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ConditioningSetMask + { + Name = e.Nodes.GetUniqueName($"Refiner_RegionalPrompt_SetMask_{refinerMaskCounter}"), + Conditioning = layerClip.Output, + Mask = loadedMask.Output, + Strength = layer.Strength, + SetCondArea = layer.ConditioningAreaValue, + } + ); + + var combined = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.ConditioningCombine + { + Name = e.Nodes.GetUniqueName($"Refiner_RegionalPrompt_Combine_{refinerMaskCounter}"), + Conditioning1 = refinerPositive, + Conditioning2 = maskedConditioning.Output, + } + ); + + refinerPositive = combined.Output; + refinerMaskCounter++; + } + + e.Temp.Refiner.Conditioning = (refinerPositive, refinerNegative); + } + + /// + /// Cleans up old mask files from previous generations to prevent temp directory accumulation. + /// + private static void CleanupOldMaskFiles() + { + var tempPath = Path.Combine(Path.GetTempPath(), "StabilityMatrix", "RegionalPrompts"); + if (!Directory.Exists(tempPath)) + return; + + try + { + // Delete all regional mask files from previous sessions + foreach (var file in Directory.GetFiles(tempPath, "regional_mask_*.png")) + { + try + { + File.Delete(file); + } + catch + { + // Ignore individual file deletion errors - file may be in use + } + } + } + catch + { + // Ignore cleanup errors - not critical to generation + } + } + + /// + /// Generates a unique filename for a layer mask. + /// + private static string GetMaskFileName(MaskLayer layer, int index) + { + // Use layer name sanitized + index for uniqueness + var safeName = layer.Name.Replace(" ", "_").Replace("/", "_").Replace("\\", "_"); + return $"regional_mask_{safeName}_{index}.png"; + } + + /// + /// Saves a mask image to a temporary file. + /// + private static string SaveMaskToTempFile(SKImage maskImage, string fileName) + { + var tempPath = Path.Combine(Path.GetTempPath(), "StabilityMatrix", "RegionalPrompts"); + Directory.CreateDirectory(tempPath); + + var filePath = Path.Combine(tempPath, fileName); + + using var data = maskImage.Encode(SKEncodedImageFormat.Png, 100); + using var fileStream = File.Create(filePath); + data.SaveTo(fileStream); + + return filePath; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs index 0bfd10c94..12c106926 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/PromptCardViewModel.cs @@ -280,7 +280,7 @@ private void SetTokenThreshold() if (accountsService.LykosStatus is not { User: not null } status) return; - if (status.User.Roles.Count is 1 && status.User.Roles.Contains(LykosRole.Basic.ToString())) + if (status.User.Roles.Count is 1 && status.User.Roles.Contains(nameof(LykosRole.Basic))) { LowTokenThreshold = 25; } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/RegionalPromptCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/RegionalPromptCardViewModel.cs new file mode 100644 index 000000000..bca8c2f0c --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/RegionalPromptCardViewModel.cs @@ -0,0 +1,121 @@ +using System.Collections.ObjectModel; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using Injectio.Attributes; +using SkiaSharp; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Core.Attributes; + +namespace StabilityMatrix.Avalonia.ViewModels.Inference; + +[View(typeof(RegionalPromptCard))] +[ManagedService] +[RegisterTransient] +public partial class RegionalPromptCardViewModel( + IServiceManager vmFactory, + TabContext tabContext +) : LoadableViewModelBase +{ + public const string ModuleKey = "RegionalPrompt"; + + private readonly IServiceManager vmFactory = vmFactory; + + /// + /// The layered mask editor for painting regions. + /// Each layer = one prompt with its own mask. + /// + [JsonIgnore] + public LayeredMaskEditorViewModel LayeredMaskEditor { get; } = + vmFactory.Get(); + + /// + /// Convenience accessor for layers (for UI binding). + /// + [JsonIgnore] + public ObservableCollection Layers => LayeredMaskEditor.Layers; + + /// + /// Sets the canvas size for the mask editor. + /// Should be called when the sampler dimensions change. + /// + public void SetCanvasSize(int width, int height) + { + LayeredMaskEditor.CanvasSize = new System.Drawing.Size(width, height); + } + + /// + /// Opens the layered mask editor dialog. + /// + [RelayCommand] + private async Task OpenMaskEditorAsync() + { + // Get canvas size from TabContext (synced from SamplerCardViewModel) + var width = tabContext.SamplerWidth; + var height = tabContext.SamplerHeight; + + // Use the sampler dimensions, or fallback to 1024x1024 + if (width > 0 && height > 0) + { + LayeredMaskEditor.CanvasSize = new System.Drawing.Size(width, height); + } + else if (LayeredMaskEditor.CanvasSize == System.Drawing.Size.Empty) + { + LayeredMaskEditor.CanvasSize = new System.Drawing.Size(1024, 1024); + } + + var dialog = LayeredMaskEditor.GetDialog(); + await dialog.ShowAsync(); + + // Save current layer paths after dialog closes + LayeredMaskEditor.SaveCurrentLayerPaths(); + } + + /// + /// Gets enabled layers with content for generation. + /// + public IReadOnlyList GetEnabledLayersWithContent() + { + return LayeredMaskEditor.GetEnabledLayersWithContent(); + } + + /// + /// Renders a layer to a mask image for ComfyUI. + /// + public SKImage? RenderLayerToMask(MaskLayer layer) + { + return LayeredMaskEditor.RenderLayerToMask(layer); + } + + /// + public override void LoadStateFromJsonObject(JsonObject state) + { + base.LoadStateFromJsonObject(state); + + // Load layered mask editor state + if ( + state.TryGetPropertyValue("layeredMaskEditor", out var editorNode) + && editorNode is JsonObject editorObj + ) + { + LayeredMaskEditor.LoadStateFromJsonObject(editorObj); + } + } + + /// + public override JsonObject SaveStateToJsonObject() + { + var state = base.SaveStateToJsonObject(); + + // Save layered mask editor state + state["layeredMaskEditor"] = LayeredMaskEditor.SaveStateToJsonObject(); + + return state; + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index 8cdb313a1..e7ad9797a 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -130,6 +130,11 @@ public partial class SamplerCardViewModel : LoadableViewModelBase, IParametersLo [JsonIgnore] public IInferenceClientManager ClientManager { get; } + [JsonIgnore] + public TabContext TabContext => tabContext; + + public bool HasSourceImageDimensions => tabContext.HasSourceImageDimensions; + private int TotalSteps => Steps + RefinerSteps; public SamplerCardViewModel( @@ -149,6 +154,7 @@ TabContext tabContext [ typeof(FreeUModule), typeof(ControlNetModule), + typeof(RegionalPromptModule), typeof(LayerDiffuseModule), typeof(FluxGuidanceModule), typeof(DiscreteModelSamplingModule), @@ -178,6 +184,13 @@ public override void OnUnloaded() private void TabContextOnStateChanged(object? sender, TabContext.TabStateChangedEventArgs e) { + if (e.PropertyName is nameof(tabContext.SourceImageWidth) or nameof(tabContext.SourceImageHeight)) + { + OnPropertyChanged(nameof(HasSourceImageDimensions)); + ApplySourceImageDimensionsCommand.NotifyCanExecuteChanged(); + return; + } + if (e.PropertyName != nameof(tabContext.SelectedModel)) return; @@ -192,12 +205,36 @@ private void TabContextOnStateChanged(object? sender, TabContext.TabStateChanged SelectedScheduler = defaults.Scheduler; } + partial void OnWidthChanged(int value) + { + // Sync width to TabContext so other components can access it + tabContext.SamplerWidth = value; + } + + partial void OnHeightChanged(int value) + { + // Sync height to TabContext so other components can access it + tabContext.SamplerHeight = value; + } + [RelayCommand] private void SwapDimensions() { (Width, Height) = (Height, Width); } + [RelayCommand(CanExecute = nameof(CanApplySourceImageDimensions))] + private void ApplySourceImageDimensions() + { + Width = tabContext.SourceImageWidth; + Height = tabContext.SourceImageHeight; + } + + private bool CanApplySourceImageDimensions() + { + return tabContext.HasSourceImageDimensions; + } + [RelayCommand] private void SetResolution(string resolution) { @@ -291,7 +328,11 @@ public virtual void ApplyStep(ModuleApplyStepEventArgs e) } } - public void ApplyStepsInitialCustomSampler(ModuleApplyStepEventArgs e, bool useFluxGuidance) + public void ApplyStepsInitialCustomSampler( + ModuleApplyStepEventArgs e, + bool useFluxGuidance, + bool useFlux2Scheduler = false + ) { // Provide temp values e.Temp = e.CreateTempFromBuilder(); @@ -329,7 +370,21 @@ public void ApplyStepsInitialCustomSampler(ModuleApplyStepEventArgs e, bool useF e.Builder.Connections.PrimarySamplerNode = kSamplerSelect.Output; // Scheduler/Sigmas - if (e.Builder.Connections.PrimaryScheduler?.Name is "align_your_steps") + if (useFlux2Scheduler) + { + var flux2Scheduler = e.Nodes.AddTypedNode( + new ComfyNodeBuilder.Flux2Scheduler + { + Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.Flux2Scheduler)), + Steps = Steps, + Width = Width, + Height = Height, + } + ); + + e.Builder.Connections.PrimarySigmas = flux2Scheduler.Output; + } + else if (e.Builder.Connections.PrimaryScheduler?.Name is "align_your_steps") { var alignYourSteps = e.Nodes.AddTypedNode( new ComfyNodeBuilder.AlignYourStepsScheduler @@ -377,14 +432,14 @@ public void ApplyStepsInitialCustomSampler(ModuleApplyStepEventArgs e, bool useF new ComfyNodeBuilder.FluxGuidance { Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.FluxGuidance)), - Conditioning = e.Builder.Connections.GetRefinerOrBaseConditioning().Positive, + Conditioning = e.Temp.GetRefinerOrBaseConditioning().Positive, Guidance = CfgScale, } ); e.Builder.Connections.Base.Conditioning = new ConditioningConnections( fluxGuidance.Output, - e.Builder.Connections.GetRefinerOrBaseConditioning().Negative + e.Temp.GetRefinerOrBaseConditioning().Negative ); // Guider @@ -401,7 +456,7 @@ public void ApplyStepsInitialCustomSampler(ModuleApplyStepEventArgs e, bool useF } else { - e.Builder.Connections.Base.Conditioning = e.Builder.Connections.GetRefinerOrBaseConditioning(); + e.Builder.Connections.Base.Conditioning = e.Temp.GetRefinerOrBaseConditioning(); var cfgGuider = e.Nodes.AddTypedNode( new ComfyNodeBuilder.CFGGuider diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs index 2bc1ce201..c2682bc77 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SelectImageCardViewModel.cs @@ -32,7 +32,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference; [RegisterTransient] public partial class SelectImageCardViewModel( INotificationService notificationService, - IServiceManager vmFactory + IServiceManager vmFactory, + TabContext tabContext ) : LoadableViewModelBase, IDropTarget, IComfyStep, IInputImageProvider { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); @@ -42,11 +43,12 @@ IServiceManager vmFactory { Patterns = new[] { "*.png", "*.jpg", "*.jpeg" }, AppleUniformTypeIdentifiers = new[] { "public.jpeg", "public.png" }, - MimeTypes = new[] { "image/jpeg", "image/png" } + MimeTypes = new[] { "image/jpeg", "image/png" }, }; - private readonly Lazy _lazyMaskEditorViewModel = - new(vmFactory.Get); + private readonly Lazy _lazyMaskEditorViewModel = new( + vmFactory.Get + ); /// /// When true, enables a button to open a mask editor for the image. @@ -79,6 +81,10 @@ IServiceManager vmFactory [ObservableProperty] private Size currentBitmapSize = Size.Empty; + [ObservableProperty] + [property: JsonIgnore] + private bool syncBitmapSizeToTabContext; + /// /// True if the image file is set but the local file does not exist. /// @@ -99,6 +105,27 @@ IServiceManager vmFactory [JsonIgnore] public ImageSource? LastMaskImage { get; private set; } + partial void OnCurrentBitmapSizeChanged(Size value) + { + PublishCurrentBitmapSizeToTabContext(); + } + + partial void OnSyncBitmapSizeToTabContextChanged(bool value) + { + PublishCurrentBitmapSizeToTabContext(); + } + + private void PublishCurrentBitmapSizeToTabContext() + { + if (!SyncBitmapSizeToTabContext) + { + return; + } + + tabContext.SourceImageWidth = CurrentBitmapSize.Width > 0 ? CurrentBitmapSize.Width : 0; + tabContext.SourceImageHeight = CurrentBitmapSize.Height > 0 ? CurrentBitmapSize.Height : 0; + } + /// public void ApplyStep(ModuleApplyStepEventArgs e) { @@ -170,6 +197,11 @@ partial void OnImageSourceChanged(ImageSource? value) ); }); } + + if (value is null) + { + CurrentBitmapSize = Size.Empty; + } } [RelayCommand] @@ -178,7 +210,12 @@ private async Task SelectImageFromFilePickerAsync() var files = await App.StorageProvider.OpenFilePickerAsync( new FilePickerOpenOptions { - FileTypeFilter = [FilePickerFileTypes.ImagePng, FilePickerFileTypes.ImageJpg, SupportedImages] + FileTypeFilter = + [ + FilePickerFileTypes.ImagePng, + FilePickerFileTypes.ImageJpg, + SupportedImages, + ], } ); @@ -289,6 +326,7 @@ private void LoadUserImage(ImageSource image) { var current = ImageSource; + CurrentBitmapSize = Size.Empty; ImageSource = image; // current?.Dispose(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/TextEncoderSlotViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/TextEncoderSlotViewModel.cs new file mode 100644 index 000000000..b9c878083 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/TextEncoderSlotViewModel.cs @@ -0,0 +1,56 @@ +using CommunityToolkit.Mvvm.ComponentModel; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Models; + +namespace StabilityMatrix.Avalonia.ViewModels.Inference; + +/// +/// Represents a single text encoder slot in the dynamic encoder list. +/// +public partial class TextEncoderSlotViewModel : ViewModelBase +{ + /// + /// 1-based index for display (Encoder 1, Encoder 2, etc.) + /// + [ObservableProperty] + private int index; + + /// + /// Display label for this encoder slot. + /// + public string Label => $"Encoder {Index}"; + + private HybridModelFile? selectedModel; + + /// + /// The selected CLIP/text encoder model. + /// + public HybridModelFile? SelectedModel + { + get => selectedModel; + set + { + // The bound ComboBox can briefly report null while the model list refreshes + // (e.g. when navigating away and back to the Inference tab). Ignore the + // transient null so the encoder selection isn't cleared out from under the user. + if (value is null && selectedModel is not null) + { + return; + } + + SetProperty(ref selectedModel, value); + } + } + + public TextEncoderSlotViewModel() { } + + public TextEncoderSlotViewModel(int index) + { + Index = index; + } + + partial void OnIndexChanged(int value) + { + OnPropertyChanged(nameof(Label)); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/WanModelCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/WanModelCardViewModel.cs index dfdf885f3..f7e33939d 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/WanModelCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/WanModelCardViewModel.cs @@ -1,5 +1,7 @@ ο»Ώusing System.ComponentModel.DataAnnotations; using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using FluentAvalonia.UI.Controls; using Injectio.Attributes; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Languages; @@ -7,6 +9,7 @@ using StabilityMatrix.Avalonia.Models.Inference; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; using StabilityMatrix.Avalonia.ViewModels.Inference.Modules; using StabilityMatrix.Core.Attributes; using StabilityMatrix.Core.Models; @@ -51,6 +54,76 @@ IServiceManager vmFactory public List WeightDTypes { get; set; } = ["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"]; + [RelayCommand] + private async Task OpenModelPickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select Model"; + // WanModelCard only uses UNet models + pickerVm.Source = ModelPickerSource.CheckpointAndUnet; + pickerVm.ShowUnetsOnly = true; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedModel = selected; + } + } + } + + [RelayCommand] + private async Task OpenVaePickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select VAE"; + pickerVm.Source = ModelPickerSource.Vae; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedVae = selected; + } + } + } + + [RelayCommand] + private async Task OpenClipPickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select Text Encoder"; + pickerVm.Source = ModelPickerSource.Clip; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedClipModel = selected; + } + } + } + + [RelayCommand] + private async Task OpenClipVisionPickerAsync() + { + using var pickerScope = vmFactory.CreateScope(); + var pickerVm = pickerScope.ServiceManager.Get(); + pickerVm.Title = "Select CLIP Vision"; + pickerVm.Source = ModelPickerSource.ClipVision; + + if (await pickerVm.GetDialog().ShowAsync() == ContentDialogResult.Primary) + { + if (pickerVm.SelectedModel is { } selected) + { + SelectedClipVisionModel = selected; + } + } + } + public async Task ValidateModel() { if (SelectedModel == null) @@ -106,7 +179,7 @@ public void ApplyStep(ModuleApplyStepEventArgs e) { Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.UnetLoaderGGUF)), UnetName = - SelectedModel?.RelativePath ?? throw new ValidationException("Model not selected") + SelectedModel?.RelativePath ?? throw new ValidationException("Model not selected"), } ); } @@ -118,7 +191,7 @@ public void ApplyStep(ModuleApplyStepEventArgs e) Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.UNETLoader)), UnetName = SelectedModel?.RelativePath ?? throw new ValidationException("Model not selected"), - WeightDtype = SelectedDType ?? "fp8_e4m3fn_fast" + WeightDtype = SelectedDType ?? "fp8_e4m3fn_fast", } ); } @@ -128,7 +201,7 @@ public void ApplyStep(ModuleApplyStepEventArgs e) { Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.ModelSamplingSD3)), Model = modelLoader.Output, - Shift = Shift + Shift = Shift, } ); @@ -141,7 +214,7 @@ public void ApplyStep(ModuleApplyStepEventArgs e) ClipName = SelectedClipModel?.RelativePath ?? throw new ValidationException("No Clip Model Selected"), - Type = "wan" + Type = "wan", } ); @@ -151,7 +224,7 @@ public void ApplyStep(ModuleApplyStepEventArgs e) new ComfyNodeBuilder.VAELoader { Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.VAELoader)), - VaeName = SelectedVae?.RelativePath ?? throw new ValidationException("No VAE Selected") + VaeName = SelectedVae?.RelativePath ?? throw new ValidationException("No VAE Selected"), } ); e.Builder.Connections.Base.VAE = vaeLoader.Output; @@ -171,7 +244,7 @@ public void ApplyStep(ModuleApplyStepEventArgs e) Name = e.Nodes.GetUniqueName(nameof(ComfyNodeBuilder.CLIPVisionLoader)), ClipName = SelectedClipVisionModel?.RelativePath - ?? throw new ValidationException("No Clip Vision Model Selected") + ?? throw new ValidationException("No Clip Vision Model Selected"), } ); @@ -191,10 +264,9 @@ public void LoadStateFromParameters(GenerationParameters parameters) // First try hash match if (parameters.ModelHash is not null) { - model = currentModels.FirstOrDefault( - m => - m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256 - && sha256.StartsWith(parameters.ModelHash, StringComparison.InvariantCultureIgnoreCase) + model = currentModels.FirstOrDefault(m => + m.Local?.ConnectedModelInfo?.Hashes.SHA256 is { } sha256 + && sha256.StartsWith(parameters.ModelHash, StringComparison.InvariantCultureIgnoreCase) ); } else @@ -215,7 +287,7 @@ public GenerationParameters SaveStateToParameters(GenerationParameters parameter return parameters with { ModelName = SelectedModel?.FileName, - ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256 + ModelHash = SelectedModel?.Local?.ConnectedModelInfo?.Hashes.SHA256, }; } } diff --git a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs index 9751c923e..cd337ef0b 100644 --- a/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/LaunchPageViewModel.cs @@ -20,6 +20,7 @@ using Microsoft.Extensions.Logging; using StabilityMatrix.Avalonia.Controls; using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Models.PackageSteps; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Avalonia.ViewModels.Dialogs; @@ -340,13 +341,8 @@ await basePackage.UpdateModelFolders( } // Unpacks sitecustomize.py to the target venv - private static async Task UnpackSiteCustomize(DirectoryPath venvPath) - { - var sitePackages = venvPath.JoinDir(PyVenvRunner.RelativeSitePackagesPath); - var file = sitePackages.JoinFile("sitecustomize.py"); - file.Directory?.Create(); - await Assets.PyScriptSiteCustomize.ExtractTo(file, true); - } + private static Task UnpackSiteCustomize(DirectoryPath venvPath) => + new UnpackSiteCustomizeStep(venvPath).ExecuteAsync(); [RelayCommand] private async Task Config() diff --git a/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs index d57b8d420..e317d78fb 100644 --- a/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/MainWindowViewModel.cs @@ -266,7 +266,6 @@ settingsManager.Settings.Analytics.LastSeenConsentVersion is null await dialog.ShowAsync(App.TopLevel); EventManager.Instance.OnRecommendedModelsDialogClosed(); - EventManager.Instance.OnDownloadsTeachingTipRequested(); var installedPackageNameMaybe = settingsManager.PackageInstallsInProgress.FirstOrDefault() diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs index f4b98a0c5..ef53dada3 100644 --- a/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPage/OutputImageViewModel.cs @@ -1,14 +1,23 @@ -ο»Ώusing StabilityMatrix.Avalonia.ViewModels.Base; +ο»Ώusing CommunityToolkit.Mvvm.ComponentModel; +using StabilityMatrix.Avalonia.ViewModels.Base; using StabilityMatrix.Core.Models.Database; namespace StabilityMatrix.Avalonia.ViewModels.OutputsPage; -public class OutputImageViewModel : SelectableViewModelBase +public partial class OutputImageViewModel(LocalImageFile imageFile) : SelectableViewModelBase { - public OutputImageViewModel(LocalImageFile imageFile) - { - ImageFile = imageFile; - } + public LocalImageFile ImageFile { get; } = imageFile; - public LocalImageFile ImageFile { get; } + /// + /// Thumbnail path for video files. Set asynchronously after thumbnail generation. + /// + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(DisplayPath))] + public partial string? ThumbnailPath { get; set; } + + /// + /// Path to display - uses thumbnail for videos, original path for images. + /// + public string DisplayPath => + ImageFile.IsVideo && ThumbnailPath != null ? ThumbnailPath : ImageFile.AbsolutePath; } diff --git a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs index 57a0c7f75..905ae2783 100644 --- a/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/OutputsPageViewModel.cs @@ -1,4 +1,4 @@ -ο»Ώusing System; +using System; using System.Collections.Generic; using System.Collections.ObjectModel; using System.IO; @@ -57,6 +57,9 @@ public partial class OutputsPageViewModel : PageViewModelBase private readonly ILogger logger; private readonly List cancellationTokenSources = []; private readonly IServiceManager vmFactory; + private readonly IVideoThumbnailService videoThumbnailService; + + private int pendingThumbnails = 0; public override string Title => Resources.Label_OutputsPageTitle; @@ -109,7 +112,19 @@ public partial class OutputsPageViewModel : PageViewModelBase ? Resources.Label_OneImageSelected : string.Format(Resources.Label_NumImagesSelected, NumItemsSelected); - private string[] allowedExtensions = [".png", ".webp", ".jpg", ".jpeg", ".gif"]; + private string[] allowedExtensions = + [ + ".png", + ".webp", + ".jpg", + ".jpeg", + ".gif", + ".mp4", + ".webm", + ".mov", + ".avi", + ".mkv", + ]; private TreeViewDirectory? lastOutputCategory; @@ -119,7 +134,8 @@ public OutputsPageViewModel( INotificationService notificationService, INavigationService navigationService, ILogger logger, - IServiceManager vmFactory + IServiceManager vmFactory, + IVideoThumbnailService videoThumbnailService ) { this.settingsManager = settingsManager; @@ -128,6 +144,7 @@ IServiceManager vmFactory this.navigationService = navigationService; this.logger = logger; this.vmFactory = vmFactory; + this.videoThumbnailService = videoThumbnailService; var searcher = new ImageSearcher(); @@ -142,7 +159,28 @@ IServiceManager vmFactory .Connect() .DeferUntilLoaded() .Filter(searchPredicate) - .Transform(file => new OutputImageViewModel(file)) + .Transform(file => + { + var vm = new OutputImageViewModel(file); + // For video files, check for existing thumbnail immediately (sync) + // Then kick off async generation if needed + if (file.IsVideo) + { + // Check if thumbnail already exists (sync - for immediate display) + var existingThumb = videoThumbnailService.GetExistingThumbnailPath(file.AbsolutePath); + if (existingThumb != null) + { + vm.ThumbnailPath = existingThumb; + } + else + { + // Kick off async thumbnail generation + GenerateVideoThumbnailAsync(vm) + .SafeFireAndForget(ex => logger.LogError(ex, "Error generating video thumbnail")); + } + } + return vm; + }) .Sort( SortExpressionComparer .Descending(vm => vm.ImageFile.CreatedAt) @@ -226,7 +264,7 @@ partial void OnSelectedCategoryChanged(TreeViewDirectory? oldValue, TreeViewDire var path = CanShowOutputTypes && SelectedOutputType != SharedOutputType.All ? Path.Combine(newValue.Path, SelectedOutputType.ToString()) - : SelectedCategory.Path; + : newValue.Path; GetOutputs(path); lastOutputCategory = newValue; } @@ -260,6 +298,13 @@ public Task OnImageClick(OutputImageViewModel item) public async Task ShowImageDialog(OutputImageViewModel item) { + // If it's a video file, open with system player + if (item.ImageFile.IsVideo) + { + ProcessRunner.OpenUrl(item.ImageFile.AbsolutePath); + return; + } + var currentIndex = Outputs.IndexOf(item); var image = new ImageSource(new FilePath(item.ImageFile.AbsolutePath)); @@ -568,6 +613,12 @@ private void GetOutputs(string directory) if (!settingsManager.IsLibraryDirSet) return; + var imageLabInputsRoot = Path.Combine(settingsManager.ImagesDirectory, "ImageLab", "Inputs"); + var imageLabInputsRootFull = + Path.GetFullPath(imageLabInputsRoot) + .TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar) + + Path.DirectorySeparatorChar; + if ( !Directory.Exists(directory) && ( @@ -605,11 +656,18 @@ private void GetOutputs(string directory) .EnumerateFiles(directory, "*", EnumerationOptionConstants.AllDirectories) .Where(file => allowedExtensions.Contains(new FilePath(file).Extension) + && !Path.GetFullPath(file) + .StartsWith(imageLabInputsRootFull, StringComparison.OrdinalIgnoreCase) && new FilePath(file).Info.DirectoryName?.EndsWith( "thumbnails", StringComparison.OrdinalIgnoreCase ) is false + && new FilePath(file).Info.DirectoryName?.EndsWith( + ".sm-thumbs", + StringComparison.OrdinalIgnoreCase + ) + is false ) .Select(file => LocalImageFile.FromPath(file)) .ToList(); @@ -706,7 +764,14 @@ private ObservableCollection GetSubfolders(string strPath) foreach (var dir in directories) { - var category = new TreeViewDirectory { Name = Path.GetFileName(dir), Path = dir }; + // Skip thumbnail directories + var dirName = Path.GetFileName(dir); + if (dirName.Equals(".sm-thumbs", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + var category = new TreeViewDirectory { Name = dirName, Path = dir }; if (Directory.GetDirectories(dir, "*", EnumerationOptionConstants.TopLevelOnly).Length > 0) { @@ -791,4 +856,60 @@ private bool FindAndExpandPathToCategory(IEnumerable nodes, s return false; // Target was not found in this collection of nodes. } + + /// + /// Generates a thumbnail for a video file and updates the ViewModel. + /// + private async Task GenerateVideoThumbnailAsync(OutputImageViewModel vm) + { + logger.LogInformation("Starting thumbnail generation for video: {Path}", vm.ImageFile.FileName); + + // Track pending thumbnails and show notification + var count = Interlocked.Increment(ref pendingThumbnails); + if (count == 1) + { + notificationService.Show( + "Generating Video Thumbnails", + "Creating thumbnails for video files...", + NotificationType.Information + ); + } + + try + { + var thumbnailPath = await videoThumbnailService + .GetOrCreateThumbnailAsync(vm.ImageFile.AbsolutePath) + .ConfigureAwait(false); + + logger.LogInformation( + "Thumbnail result for {Video}: {ThumbnailPath}", + vm.ImageFile.FileName, + thumbnailPath ?? "(null)" + ); + + if (thumbnailPath != null) + { + // Update on UI thread + await Dispatcher.UIThread.InvokeAsync(() => + { + logger.LogInformation( + "Setting ThumbnailPath for {Video} to {Path}", + vm.ImageFile.FileName, + thumbnailPath + ); + vm.ThumbnailPath = thumbnailPath; + logger.LogInformation("DisplayPath is now: {DisplayPath}", vm.DisplayPath); + }); + } + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to generate thumbnail for {Path}", vm.ImageFile.AbsolutePath); + } + finally + { + // Decrement counter + Interlocked.Decrement(ref pendingThumbnails); + } + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs index 2fbc50c6d..55e092056 100644 --- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs @@ -43,7 +43,9 @@ public partial class PackageCardViewModel( ISettingsManager settingsManager, INavigationService navigationService, IServiceManager vmFactory, - RunningPackageService runningPackageService + IPyInstallationManager pyInstallationManager, + RunningPackageService runningPackageService, + IPrerequisiteHelper prerequisiteHelper ) : ProgressViewModel { private string webUiUrl = string.Empty; @@ -481,6 +483,15 @@ public async Task Update() versionOptions.CommitHash = latest.Sha; } + PyVersion? desiredPythonVersion = PyVersion.TryParse(Package.PythonVersion, out var pv) + ? pv + : null; + var ensurePrereqStep = new SetupPrerequisitesStep( + prerequisiteHelper, + basePackage, + desiredPythonVersion + ); + var updatePackageStep = new UpdatePackageStep( settingsManager, basePackage, @@ -492,11 +503,11 @@ public async Task Update() PythonOptions = { TorchIndex = Package.PreferredTorchIndex, - PythonVersion = PyVersion.TryParse(Package.PythonVersion, out var pv) ? pv : null, + PythonVersion = desiredPythonVersion, }, } ); - var steps = new List { updatePackageStep }; + var steps = new List { ensurePrereqStep, updatePackageStep }; EventManager.Instance.OnPackageInstallProgressAdded(runner); await runner.ExecuteSteps(steps); @@ -657,6 +668,15 @@ private async Task ChangeVersion() versionOptions.CommitHash = viewModel.SelectedCommit?.Sha; } + PyVersion? desiredPythonVersion = PyVersion.TryParse(Package.PythonVersion, out var pyVer) + ? pyVer + : null; + var ensurePrereqStep = new SetupPrerequisitesStep( + prerequisiteHelper, + basePackage, + desiredPythonVersion + ); + var updatePackageStep = new UpdatePackageStep( settingsManager, basePackage, @@ -668,13 +688,11 @@ private async Task ChangeVersion() PythonOptions = { TorchIndex = Package.PreferredTorchIndex, - PythonVersion = PyVersion.TryParse(Package.PythonVersion, out var pyVer) - ? pyVer - : null, + PythonVersion = desiredPythonVersion, }, } ); - var steps = new List { updatePackageStep }; + var steps = new List { ensurePrereqStep, updatePackageStep }; EventManager.Instance.OnPackageInstallProgressAdded(runner); await runner.ExecuteSteps(steps); @@ -961,6 +979,41 @@ private async Task ExecuteExtraCommand(string commandName) } } + [RelayCommand] + private async Task RunPythonCommand() + { + if (Package is null || IsUnknownPackage) + return; + + var field = new TextBoxField + { + Label = "Arguments", + InnerLeftText = "python.exe", + Watermark = "-c \"print('Hello World')\"", + }; + + var result = await DialogHelper.GetTextEntryDialogResultAsync(field, "Run Python Command"); + + if (result.Result == ContentDialogResult.Primary) + { + var runCommandStep = new RunPythonCommandStep(pyInstallationManager, settingsManager) + { + Arguments = field.Text, + InstalledPackage = Package, + WorkingDirectory = Package.FullPath, + }; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + CloseWhenFinished = false, + ModificationCompleteMessage = "Python command executed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + await runner.ExecuteSteps([runCommandStep]).ConfigureAwait(false); + } + } + private async Task HasUpdate() { if (Package == null || IsUnknownPackage || Design.IsDesignMode || Package.DontCheckForUpdates) diff --git a/StabilityMatrix.Avalonia/ViewModels/Progress/NotificationItemViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Progress/NotificationItemViewModel.cs new file mode 100644 index 000000000..40803ef43 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Progress/NotificationItemViewModel.cs @@ -0,0 +1,131 @@ +using System; +using System.Threading.Tasks; +using Avalonia.Controls.Notifications; +using Avalonia.Media; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using StabilityMatrix.Avalonia.Extensions; +using StabilityMatrix.Avalonia.Languages; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Models.Notifications; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Avalonia.ViewModels.Progress; + +public partial class NotificationItemViewModel : ViewModelBase +{ + private readonly INotificationHistoryService historyService; + private readonly INotificationActionDispatcher dispatcher; + + public NotificationHistoryEntry Entry { get; } + + public Guid Id => Entry.Id; + public string Title => Entry.Title; + public string? Body => Entry.Body; + public string? BodyImagePath => Entry.BodyImagePath; + public bool HasBodyImage => !string.IsNullOrEmpty(Entry.BodyImagePath); + public bool HasBody => !string.IsNullOrEmpty(Entry.Body); + public DateTimeOffset Timestamp => Entry.Timestamp; + public NotificationType SeverityType => Entry.Level.ToNotificationType(); + + public IBrush SeverityBrush => + Entry.Level switch + { + Core.Models.Settings.NotificationLevel.Success => Brushes.MediumSeaGreen, + Core.Models.Settings.NotificationLevel.Warning => Brushes.Orange, + Core.Models.Settings.NotificationLevel.Error => Brushes.IndianRed, + _ => Brushes.SteelBlue, + }; + + public bool HasAction => Entry.Action is not null; + + public string ActionLabel => + Entry.Action switch + { + OpenFolderAction => Resources.Action_OpenFolder, + NavigateToPageAction => Resources.Action_Open, + ToggleProgressFlyoutAction => Resources.Action_ShowActivity, + _ => Resources.Action_Open, + }; + + public string FormattedTimestamp => FormatRelative(Entry.Timestamp); + + /// True only when there is a body to show AND the row is currently collapsed β€” + /// keeps the preview from duplicating the body shown in the expanded section. + public bool IsPreviewBodyVisible => HasBody && !IsExpanded; + + /// Drives the leading unread-state dot. Disappears once the entry is read. + public bool IsUnreadIndicatorVisible => !IsRead; + + /// Read entries fade slightly so the active ones pop visually. + public double ReadOpacity => IsRead ? 0.55 : 1.0; + + [ObservableProperty] + private bool isExpanded; + + partial void OnIsExpandedChanged(bool value) => OnPropertyChanged(nameof(IsPreviewBodyVisible)); + + partial void OnIsReadChanged(bool value) + { + OnPropertyChanged(nameof(IsUnreadIndicatorVisible)); + OnPropertyChanged(nameof(ReadOpacity)); + } + + [ObservableProperty] + private bool isRead; + + public NotificationItemViewModel( + NotificationHistoryEntry entry, + INotificationHistoryService historyService, + INotificationActionDispatcher dispatcher + ) + { + Entry = entry; + this.historyService = historyService; + this.dispatcher = dispatcher; + isRead = entry.IsRead; + } + + [RelayCommand] + private async Task InvokeActionAsync() + { + MarkRead(); + if (Entry.Action is { } action) + { + await dispatcher.DispatchAsync(action); + } + } + + [RelayCommand] + private void Dismiss() => historyService.Remove(Entry.Id); + + [RelayCommand] + private void ToggleDetails() + { + IsExpanded = !IsExpanded; + MarkRead(); + } + + public void MarkRead() + { + if (IsRead) + return; + historyService.MarkRead(Entry.Id); + IsRead = true; + } + + public void RefreshReadState() => IsRead = Entry.IsRead; + + private static string FormatRelative(DateTimeOffset ts) + { + var delta = DateTimeOffset.Now - ts; + if (delta < TimeSpan.FromSeconds(45)) + return Resources.Label_RelativeTime_JustNow; + if (delta < TimeSpan.FromMinutes(60)) + return string.Format(Resources.Label_RelativeTime_MinutesAgo, (int)delta.TotalMinutes); + if (delta < TimeSpan.FromHours(24) && ts.Date == DateTimeOffset.Now.Date) + return ts.ToLocalTime().ToString("t"); + return ts.ToLocalTime().ToString("g"); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs index bd932933f..883d8f366 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Progress/ProgressManagerViewModel.cs @@ -9,6 +9,7 @@ using Avalonia.Controls.Notifications; using Avalonia.Threading; using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; using FluentAvalonia.UI.Controls; using FluentAvalonia.UI.Media.Animation; using FluentIcons.Common; @@ -23,6 +24,7 @@ using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Notifications; using StabilityMatrix.Core.Models.PackageModification; using StabilityMatrix.Core.Models.Progress; using StabilityMatrix.Core.Models.Settings; @@ -42,36 +44,151 @@ public partial class ProgressManagerViewModel : PageViewModelBase private readonly INotificationService notificationService; private readonly INavigationService navigationService; private readonly INavigationService settingsNavService; + private readonly INotificationHistoryService notificationHistory; + private readonly INotificationActionDispatcher actionDispatcher; - public override string Title => "Download Manager"; + public override string Title => Resources.Label_Activity; public override IconSource IconSource => - new SymbolIconSource { Symbol = Symbol.ArrowCircleDown, IconVariant = IconVariant.Filled }; + new SymbolIconSource { Symbol = Symbol.History, IconVariant = IconVariant.Filled }; public AvaloniaList ProgressItems { get; } = new(); + public AvaloniaList NotificationItems { get; } = new(); + [ObservableProperty] private bool isOpen; + [ObservableProperty] + private int unreadNotificationCount; + + [ObservableProperty] + private int selectedTabIndex; + + /// + /// Pick the most-useful tab to show on flyout open: stay on In Progress if there's any active + /// download/install, otherwise jump to Notifications when only history is present. + /// + public void RecomputePreferredTab() + { + if (ProgressItems.Count == 0 && NotificationItems.Count > 0) + { + SelectedTabIndex = 1; + } + else + { + SelectedTabIndex = 0; + } + } + + /// True when either tab has any content β€” used to decide whether the footer flyout is reachable. + public bool HasAnyContent => ProgressItems.Count > 0 || NotificationItems.Count > 0; + + /// Combined counter rendered in the footer InfoBadge. + public int TotalBadgeCount => ProgressItems.Count + UnreadNotificationCount; + + /// Explicit bool so the InfoBadge can hide cleanly when nothing is pending. + public bool IsBadgeVisible => TotalBadgeCount > 0; + public ProgressManagerViewModel( ITrackedDownloadService trackedDownloadService, INotificationService notificationService, INavigationService navigationService, - INavigationService settingsNavService + INavigationService settingsNavService, + INotificationHistoryService notificationHistory, + INotificationActionDispatcher actionDispatcher ) { this.trackedDownloadService = trackedDownloadService; this.notificationService = notificationService; this.navigationService = navigationService; this.settingsNavService = settingsNavService; + this.notificationHistory = notificationHistory; + this.actionDispatcher = actionDispatcher; // Attach to the event trackedDownloadService.DownloadAdded += TrackedDownloadService_OnDownloadAdded; EventManager.Instance.ToggleProgressFlyout += (_, _) => IsOpen = !IsOpen; EventManager.Instance.PackageInstallProgressAdded += InstanceOnPackageInstallProgressAdded; EventManager.Instance.RecommendedModelsDialogClosed += InstanceOnRecommendedModelsDialogClosed; + + // Hydrate notifications (entries are stored newest-first) + foreach (var entry in notificationHistory.Entries) + { + NotificationItems.Add( + new NotificationItemViewModel(entry, notificationHistory, actionDispatcher) + ); + } + notificationHistory.EntryAdded += OnHistoryEntryAdded; + notificationHistory.EntriesChanged += OnHistoryChanged; + ProgressItems.CollectionChanged += (_, _) => + { + OnPropertyChanged(nameof(HasAnyContent)); + OnPropertyChanged(nameof(TotalBadgeCount)); + OnPropertyChanged(nameof(IsBadgeVisible)); + }; + NotificationItems.CollectionChanged += (_, _) => + { + OnPropertyChanged(nameof(HasAnyContent)); + }; + UnreadNotificationCount = notificationHistory.UnreadCount; + } + + private void OnHistoryEntryAdded(object? sender, NotificationHistoryEntry entry) + { + Dispatcher.UIThread.Post(() => + { + NotificationItems.Insert( + 0, + new NotificationItemViewModel(entry, notificationHistory, actionDispatcher) + ); + + // The service evicts the oldest entry (from the tail) once it hits its cap; mirror + // that here so the UI list stays in sync instead of growing unbounded. + while (NotificationItems.Count > notificationHistory.Count) + { + NotificationItems.RemoveAt(NotificationItems.Count - 1); + } + + UnreadNotificationCount = notificationHistory.UnreadCount; + }); + } + + private void OnHistoryChanged(object? sender, EventArgs e) + { + Dispatcher.UIThread.Post(() => + { + // Drop any entries that were evicted from the underlying service + var liveIds = notificationHistory.Entries.Select(x => x.Id).ToHashSet(); + for (var i = NotificationItems.Count - 1; i >= 0; i--) + { + if (!liveIds.Contains(NotificationItems[i].Id)) + { + NotificationItems.RemoveAt(i); + } + } + + foreach (var item in NotificationItems) + { + item.RefreshReadState(); + } + + UnreadNotificationCount = notificationHistory.UnreadCount; + }); + } + + partial void OnUnreadNotificationCountChanged(int value) + { + OnPropertyChanged(nameof(TotalBadgeCount)); + OnPropertyChanged(nameof(IsBadgeVisible)); } + [RelayCommand] + private void ClearNotifications() => notificationHistory.Clear(); + + [RelayCommand] + private void MarkAllNotificationsRead() => notificationHistory.MarkAllRead(); + private void InstanceOnRecommendedModelsDialogClosed(object? sender, EventArgs e) { var vm = ProgressItems.OfType().FirstOrDefault(); @@ -109,7 +226,8 @@ private void TrackedDownloadService_OnDownloadAdded(object? sender, TrackedDownl Title = "Download Completed", Body = $"Download of {e.FileName} completed successfully.", BodyImagePath = imageFile?.FullPath, - } + }, + action: new OpenFolderAction(e.DownloadDirectory.FullPath) ) .SafeFireAndForget(); @@ -158,7 +276,8 @@ await notificationService.ShowPersistentAsync( Title = "Download Disabled", Body = $"The creator of {e.FileName} has disabled downloads on this file", - } + }, + action: new ToggleProgressFlyoutAction() ) ); return; @@ -172,7 +291,8 @@ await notificationService.ShowPersistentAsync( { Title = "Download Failed", Body = $"Download of {e.FileName} failed: {msg}", - } + }, + action: new ToggleProgressFlyoutAction() ) ); diff --git a/StabilityMatrix.Avalonia/ViewModels/Settings/AccountSettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Settings/AccountSettingsViewModel.cs index 0b0c0b6b5..c4dc14cd7 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Settings/AccountSettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Settings/AccountSettingsViewModel.cs @@ -137,10 +137,6 @@ public record RoleBadge(string DisplayName, string ColorClass); [ObservableProperty] private CivitAccountStatusUpdateEventArgs civitStatus = CivitAccountStatusUpdateEventArgs.Disconnected; - // Assume HuggingFaceAccountStatusUpdateEventArgs will be created with at least these properties - // For now, using a placeholder or assuming a structure like: - // public record HuggingFaceAccountStatusUpdateEventArgs(bool IsConnected, string? Username); - // Initialize with a disconnected state. [ObservableProperty] private HuggingFaceAccountStatusUpdateEventArgs huggingFaceStatus = new(false, null); @@ -150,6 +146,12 @@ public record RoleBadge(string DisplayName, string ColorClass); [ObservableProperty] private string huggingFaceUsernameWithParentheses = string.Empty; + [ObservableProperty] + private GeminiAccountStatusUpdateEventArgs geminiStatus = GeminiAccountStatusUpdateEventArgs.Disconnected; + + [ObservableProperty] + private bool hasGeminiApiKey; + public string LykosAccountManageUrl => apiOptions.Value.LykosAccountApiBaseUrl.Append("/manage").ToString(); @@ -197,6 +199,16 @@ IOptions apiOptions // IsHuggingFaceConnected and HuggingFaceUsernameWithParentheses will be updated by OnHuggingFaceStatusChanged }); }; + + accountsService.GeminiAccountStatusUpdate += (_, args) => + { + Dispatcher.UIThread.Post(() => + { + IsInitialUpdateFinished = true; + GeminiStatus = args; + // HasGeminiApiKey will be updated by OnGeminiStatusChanged + }); + }; } /// @@ -366,7 +378,7 @@ private async Task ConnectHuggingFace() if (!await BeforeConnectCheck()) return; - var field = new TextBoxField + var field = new TextBoxField { Label = "Hugging Face Token", // Assuming Label is for the prompt IsPassword = true, // Assuming TextBoxField has an IsPassword property @@ -382,14 +394,14 @@ private async Task ConnectHuggingFace() var dialog = DialogHelper.CreateTextEntryDialog( "Connect Hugging Face Account", "Go to [Hugging Face settings](https://huggingface.co/settings/tokens) to create a new Access Token. Ensure it has read permissions. Paste the token below.", - [field] + [field] ); var result = await dialog.ShowAsync(); if (result == ContentDialogResult.Primary && !string.IsNullOrWhiteSpace(field.Text)) { - await accountsService.HuggingFaceLoginAsync(field.Text); + await accountsService.HuggingFaceLoginAsync(field.Text); await accountsService.RefreshAsync(); } } @@ -617,4 +629,70 @@ partial void OnHuggingFaceStatusChanged(HuggingFaceAccountStatusUpdateEventArgs } } } + + partial void OnGeminiStatusChanged(GeminiAccountStatusUpdateEventArgs value) + { + HasGeminiApiKey = value.IsConnected; + } + + [RelayCommand] + private async Task SetGeminiApiKey() + { + if (!await BeforeConnectCheck()) + return; + + var field = new TextBoxField + { + Label = "Gemini API Key", + IsPassword = true, + Validator = s => + { + if (string.IsNullOrWhiteSpace(s)) + { + throw new ValidationException("API key is required"); + } + }, + }; + + var dialog = DialogHelper.CreateTextEntryDialog( + "Set Gemini API Key", + """ + Get your Gemini API key from [Google AI Studio](https://ai.google.dev/) + + This key will be used for Image Lab image generation. + """, + null, + [field] + ); + dialog.PrimaryButtonText = "Save"; + + if (await dialog.ShowAsync() != ContentDialogResult.Primary || field.Text is not { } apiKey) + { + return; + } + + await accountsService.GeminiLoginAsync(apiKey); + notificationService.Show("Success", "Gemini API key saved", NotificationType.Success); + } + + [RelayCommand] + private async Task RemoveGeminiApiKey() + { + var dialog = new BetterContentDialog + { + Title = "Remove Gemini API Key", + Content = "Are you sure you want to remove your Gemini API key?", + PrimaryButtonText = "Remove", + CloseButtonText = "Cancel", + DefaultButton = ContentDialogButton.Close, + }; + + if (await dialog.ShowAsync() != ContentDialogResult.Primary) + { + return; + } + + await accountsService.GeminiLogoutAsync(); + notificationService.Show("Success", "Gemini API key removed", NotificationType.Success); + } } diff --git a/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs index 8cc7e19a4..2c3cc91b3 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Settings/InferenceSettingsViewModel.cs @@ -69,6 +69,9 @@ public partial class InferenceSettingsViewModel : PageViewModelBase [ObservableProperty] private bool filterExtraNetworksByBaseModel; + [ObservableProperty] + private bool useLegacySearch; + private List ignoredFileNameFormatVars = [ "author", @@ -147,6 +150,13 @@ ISettingsManager settingsManager true ); + settingsManager.RelayPropertyFor( + this, + vm => vm.UseLegacySearch, + settings => settings.UseLegacySearch, + true + ); + this.WhenPropertyChanged(vm => vm.OutputImageFileNameFormat) .Throttle(TimeSpan.FromMilliseconds(50)) .ObserveOn(SynchronizationContext.Current) diff --git a/StabilityMatrix.Avalonia/ViewModels/Settings/MainSettingsViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Settings/MainSettingsViewModel.cs index 83b33dbd7..ffe5beaa3 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Settings/MainSettingsViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Settings/MainSettingsViewModel.cs @@ -44,6 +44,7 @@ using StabilityMatrix.Avalonia.Models.TagCompletion; using StabilityMatrix.Avalonia.Services; using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; using StabilityMatrix.Avalonia.ViewModels.CheckpointManager; using StabilityMatrix.Avalonia.ViewModels.Controls; using StabilityMatrix.Avalonia.ViewModels.Dialogs; @@ -134,6 +135,10 @@ public partial class MainSettingsViewModel : PageViewModelBase [ObservableProperty] private bool isDiscordRichPresenceEnabled; + // Appearance section + [ObservableProperty] + private bool scrollBarsAlwaysVisible = true; + // Console section [ObservableProperty] private int consoleLogHistorySize; @@ -256,6 +261,7 @@ ICivitBaseModelTypeService baseModelTypeService RemoveSymlinksOnShutdown = settingsManager.Settings.RemoveFolderLinksOnShutdown; SelectedAnimationScale = settingsManager.Settings.AnimationScale; HolidayModeSetting = settingsManager.Settings.HolidayModeSetting; + ScrollBarsAlwaysVisible = settingsManager.Settings.ScrollBarsAlwaysVisible; settingsManager.RelayPropertyFor(this, vm => vm.SelectedTheme, settings => settings.Theme); @@ -266,6 +272,17 @@ ICivitBaseModelTypeService baseModelTypeService true ); + settingsManager.RelayPropertyFor( + this, + vm => vm.ScrollBarsAlwaysVisible, + settings => settings.ScrollBarsAlwaysVisible, + true + ); + + // Push the initial value into the global DynamicResources that the + // ScrollBarStyles read from, so the loaded setting actually takes effect. + ApplyScrollBarSetting(ScrollBarsAlwaysVisible); + settingsManager.RelayPropertyFor( this, vm => vm.SelectedAnimationScale, @@ -523,6 +540,28 @@ partial void OnMaxConcurrentDownloadsChanged(int value) trackedDownloadService.UpdateMaxConcurrentDownloads(value); } + partial void OnScrollBarsAlwaysVisibleChanged(bool value) + { + ApplyScrollBarSetting(value); + } + + /// + /// Pushes the scrollbar-visibility setting into the global DynamicResources that + /// the ScrollBarStyles consume. Called both on initial settings load and whenever + /// the toggle changes, so changes take effect immediately without restart. + /// + private static void ApplyScrollBarSetting(bool alwaysVisible) + { + if (Application.Current is not { } app) + return; + + // AllowAutoHide is the inverse of "always visible". MinWidth=14 gives the + // thumb a comfortable hit-target when permanent; 0 lets Avalonia's default + // thin/expanded transition work normally in legacy mode. + app.Resources["ScrollBarAllowAutoHide"] = !alwaysVisible; + app.Resources["ScrollBarMinWidthValue"] = alwaysVisible ? 14d : 0d; + } + public async Task ResetCheckpointCache() { await notificationService.TryAsync(modelIndexService.RefreshIndex()); @@ -554,22 +593,32 @@ private async Task OpenEnvVarsDialog() { var viewModel = dialogFactory.Get(); - // Load current settings - var current = settingsManager.Settings.UserEnvironmentVariables ?? new Dictionary(); - viewModel.EnvVars = new ObservableCollection( - current.Select(kvp => new EnvVarKeyPair(kvp.Key, kvp.Value)) - ); + // Load current settings β€” prefer new list format, fall back to legacy dict + var currentList = settingsManager.Settings.UserEnvironmentVariablesList; + if (currentList is { Count: > 0 }) + { + viewModel.EnvVars = new ObservableCollection(currentList); + } + else + { + var current = + settingsManager.Settings.UserEnvironmentVariables ?? new Dictionary(); + viewModel.EnvVars = new ObservableCollection( + current.Select(kvp => new EnvVarKeyPair(kvp.Key, kvp.Value)) + ); + } var dialog = viewModel.GetDialog(); if (await dialog.ShowAsync() == ContentDialogResult.Primary) { - // Save settings - var newEnvVars = viewModel - .EnvVars.Where(kvp => !string.IsNullOrWhiteSpace(kvp.Key)) - .GroupBy(kvp => kvp.Key, StringComparer.Ordinal) - .ToDictionary(g => g.Key, g => g.First().Value, StringComparer.Ordinal); - settingsManager.Transaction(s => s.UserEnvironmentVariables = newEnvVars); + // Save in new list format, clear legacy dict + var newEnvVars = viewModel.EnvVars.Where(kvp => !string.IsNullOrWhiteSpace(kvp.Key)).ToList(); + settingsManager.Transaction(s => + { + s.UserEnvironmentVariablesList = newEnvVars; + s.UserEnvironmentVariables = null; + }); } } @@ -1270,6 +1319,7 @@ private async Task DebugRunUv() new CommandItem(DebugRobocopyCommand), new CommandItem(DebugInstallUvCommand), new CommandItem(DebugRunUvCommand), + new CommandItem(DebugClassifySafetensorCommand), ]; [RelayCommand] @@ -1533,6 +1583,42 @@ private void DebugNvidiaSmi() HardwareHelper.IterGpuInfoNvidiaSmi(); } + [RelayCommand] + private async Task DebugClassifySafetensor() + { + var files = await App.StorageProvider.OpenFilePickerAsync( + new FilePickerOpenOptions + { + Title = "Select a .safetensors file", + FileTypeFilter = [new FilePickerFileType("Safetensors") { Patterns = ["*.safetensors"] }], + } + ); + + if (files.Count == 0) + return; + + var filePath = files[0].TryGetLocalPath(); + if (filePath is null) + return; + + try + { + var kind = await SafetensorClassifier.ClassifyAsync(new FilePath(filePath)); + + var fileName = Path.GetFileName(filePath); + await DialogHelper + .CreateMarkdownDialog( + $"**File:** `{fileName}`\n\n**Classification:** `{kind}`", + "Safetensor Classification" + ) + .ShowAsync(); + } + catch (Exception e) + { + notificationService.Show("Classification failed", e.Message, NotificationType.Error); + } + } + #endregion #region Systems Setting Section diff --git a/StabilityMatrix.Avalonia/Views/BananaVisionPage.axaml b/StabilityMatrix.Avalonia/Views/BananaVisionPage.axaml new file mode 100644 index 000000000..7651684a0 --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/BananaVisionPage.axaml @@ -0,0 +1,2291 @@ +ο»Ώ + + + + + + + + + M2.01 21L23 12 2.01 3 2 10l15 2-15 2z + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/StabilityMatrix.Avalonia/Views/BananaVisionPage.axaml.cs b/StabilityMatrix.Avalonia/Views/BananaVisionPage.axaml.cs new file mode 100644 index 000000000..1f8f0cac4 --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/BananaVisionPage.axaml.cs @@ -0,0 +1,303 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Avalonia; +using Avalonia.Controls; +using Avalonia.Input; +using Avalonia.Markup.Xaml; +using Avalonia.Platform.Storage; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.ViewModels; + +namespace StabilityMatrix.Avalonia.Views; + +[RegisterSingleton] +public partial class BananaVisionPage : UserControlBase +{ + /// + /// Threshold in pixels from the bottom to consider the scroll position "near bottom" + /// + private const double ScrollBottomThreshold = 120; + + private ScrollViewer? messageScrollViewer; + private Button? scrollToBottomButton; + private bool isNearBottom = true; + + private static bool IsScrollNearBottom( + ScrollViewer scrollViewer, + double thresholdPixels = ScrollBottomThreshold + ) + { + var extentHeight = scrollViewer.Extent.Height; + var viewportHeight = scrollViewer.Viewport.Height; + var offsetY = scrollViewer.Offset.Y; + + if (extentHeight <= 0 || viewportHeight <= 0) + return true; + + return offsetY + viewportHeight >= extentHeight - thresholdPixels; + } + + /// + /// Supported image extensions for drag and drop + /// + private static readonly HashSet SupportedImageExtensions = new(StringComparer.OrdinalIgnoreCase) + { + ".png", + ".jpg", + ".jpeg", + ".webp", + ".gif", + }; + + public BananaVisionPage() + { + InitializeComponent(); + + // Enable drag and drop + DragDrop.SetAllowDrop(this, true); + AddHandler(DragDrop.DragEnterEvent, OnDragEnter); + AddHandler(DragDrop.DragLeaveEvent, OnDragLeave); + AddHandler(DragDrop.DragOverEvent, OnDragOver); + AddHandler(DragDrop.DropEvent, OnDrop); + + // Handle keyboard events for paste + AddHandler(KeyDownEvent, OnKeyDown, handledEventsToo: true); + } + + private void InitializeComponent() + { + AvaloniaXamlLoader.Load(this); + } + + protected override void OnAttachedToVisualTree(VisualTreeAttachmentEventArgs e) + { + base.OnAttachedToVisualTree(e); + + // Set the StorageProvider on the ViewModel + if (DataContext is BananaVisionPageViewModel viewModel) + { + var topLevel = TopLevel.GetTopLevel(this); + if (topLevel != null) + { + viewModel.StorageProvider = topLevel.StorageProvider; + } + + // Subscribe to scroll request events from ViewModel + viewModel.ScrollToEndRequested += OnScrollToEndRequested; + viewModel.ScrollToEndForcedRequested += OnScrollToEndForcedRequested; + } + + // Find the message scroll viewer + messageScrollViewer = this.FindControl("MessageScrollViewer"); + scrollToBottomButton = this.FindControl + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Views/CivArchiveBrowserPage.axaml.cs b/StabilityMatrix.Avalonia/Views/CivArchiveBrowserPage.axaml.cs new file mode 100644 index 000000000..0e39bc39f --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/CivArchiveBrowserPage.axaml.cs @@ -0,0 +1,43 @@ +using System; +using AsyncAwaitBestPractices; +using Avalonia.Controls; +using Avalonia.Input; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models; +using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; + +namespace StabilityMatrix.Avalonia.Views; + +[RegisterSingleton] +public partial class CivArchiveBrowserPage : UserControlBase +{ + public CivArchiveBrowserPage() + { + InitializeComponent(); + } + + private void ScrollViewer_OnScrollChanged(object? sender, ScrollChangedEventArgs e) + { + if (sender is not ScrollViewer scrollViewer) + return; + + if (scrollViewer.Offset.Y == 0) + return; + + var isAtEnd = Math.Abs(scrollViewer.Offset.Y - scrollViewer.ScrollBarMaximum.Y) < 1f; + + if (isAtEnd && DataContext is IInfinitelyScroll scroll) + { + scroll.LoadNextPageAsync().SafeFireAndForget(); + } + } + + private void InputElement_OnKeyDown(object? sender, KeyEventArgs e) + { + if (e.Key == Key.Escape && DataContext is CivArchiveBrowserViewModel viewModel) + { + viewModel.ClearSearchQueryCommand.Execute(null); + } + } +} diff --git a/StabilityMatrix.Avalonia/Views/CivArchiveDetailsPage.axaml b/StabilityMatrix.Avalonia/Views/CivArchiveDetailsPage.axaml new file mode 100644 index 000000000..b19f62bf5 --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/CivArchiveDetailsPage.axaml @@ -0,0 +1,963 @@ +ο»Ώ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Views/CivArchiveDetailsPage.axaml.cs b/StabilityMatrix.Avalonia/Views/CivArchiveDetailsPage.axaml.cs new file mode 100644 index 000000000..8505af7b4 --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/CivArchiveDetailsPage.axaml.cs @@ -0,0 +1,26 @@ +using Avalonia; +using Avalonia.Controls; +using Avalonia.Input; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; + +namespace StabilityMatrix.Avalonia.Views; + +[RegisterTransient] +public partial class CivArchiveDetailsPage : UserControlBase +{ + public CivArchiveDetailsPage() + { + InitializeComponent(); + } + + private void ImageScroller_OnPointerWheelChanged(object? sender, PointerWheelEventArgs e) + { + if (sender is not ScrollViewer sv) + return; + + var scrollAmount = e.Delta.Y * 75; + sv.Offset = new Vector(sv.Offset.X - scrollAmount, sv.Offset.Y); + e.Handled = true; + } +} diff --git a/StabilityMatrix.Avalonia/Views/CivitAiBrowserPage.axaml b/StabilityMatrix.Avalonia/Views/CivitAiBrowserPage.axaml index ff34486d8..fa75c157a 100644 --- a/StabilityMatrix.Avalonia/Views/CivitAiBrowserPage.axaml +++ b/StabilityMatrix.Avalonia/Views/CivitAiBrowserPage.axaml @@ -597,12 +597,24 @@ - + + ColumnDefinitions="2*, *"> + + + + + + + + + @@ -338,8 +356,22 @@ Background="White" BorderThickness="1" CornerRadius="4" + IsVisible="{Binding HasImages}" Opacity="0.5" /> + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/ImageAnnotationEditorDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/ImageAnnotationEditorDialog.axaml.cs new file mode 100644 index 000000000..973ed5cbf --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/Dialogs/ImageAnnotationEditorDialog.axaml.cs @@ -0,0 +1,13 @@ +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; + +namespace StabilityMatrix.Avalonia.Views.Dialogs; + +[RegisterTransient] +public partial class ImageAnnotationEditorDialog : UserControlBase +{ + public ImageAnnotationEditorDialog() + { + InitializeComponent(); + } +} diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/LayeredMaskEditorDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/LayeredMaskEditorDialog.axaml new file mode 100644 index 000000000..9b4a81e4d --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/Dialogs/LayeredMaskEditorDialog.axaml @@ -0,0 +1,880 @@ +ο»Ώ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/LayeredMaskEditorDialog.axaml.cs b/StabilityMatrix.Avalonia/Views/Dialogs/LayeredMaskEditorDialog.axaml.cs new file mode 100644 index 000000000..6055456c4 --- /dev/null +++ b/StabilityMatrix.Avalonia/Views/Dialogs/LayeredMaskEditorDialog.axaml.cs @@ -0,0 +1,314 @@ +using Avalonia; +using Avalonia.Controls; +using Avalonia.Input; +using Avalonia.Interactivity; +using Avalonia.LogicalTree; +using Avalonia.Threading; +using Avalonia.VisualTree; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; + +namespace StabilityMatrix.Avalonia.Views.Dialogs; + +[RegisterTransient] +public partial class LayeredMaskEditorDialog : UserControlBase +{ + private ListBox? layerListBox; + private ScrollViewer? layerScrollViewer; + private DispatcherTimer? autoScrollTimer; + private double autoScrollSpeed; + private const double AutoScrollEdgeThreshold = 50; // Pixels from edge to trigger auto-scroll + private const double AutoScrollBaseSpeed = 5; // Base scroll speed in pixels per tick + private bool isDragging; + + public LayeredMaskEditorDialog() + { + InitializeComponent(); + } + + /// + protected override void OnLoaded(RoutedEventArgs e) + { + base.OnLoaded(e); + + // Find the ListBox and subscribe to child index changes for drag reordering + layerListBox = this.FindControl("LayerItemsControl"); + if (layerListBox != null) + { + ((IChildIndexProvider)layerListBox).ChildIndexChanged += OnChildIndexChanged; + + // Find the parent ScrollViewer + layerScrollViewer = layerListBox.FindAncestorOfType(); + } + + // Set up auto-scroll timer + autoScrollTimer = new DispatcherTimer { Interval = TimeSpan.FromMilliseconds(16) }; // ~60fps + autoScrollTimer.Tick += AutoScrollTimer_Tick; + + // Subscribe to pointer events for drag detection + AddHandler(PointerMovedEvent, OnPointerMoved, RoutingStrategies.Tunnel); + AddHandler(PointerReleasedEvent, OnPointerReleased, RoutingStrategies.Tunnel); + AddHandler(PointerCaptureLostEvent, OnPointerCaptureLost, RoutingStrategies.Tunnel); + + // Subscribe to keyboard events for shortcuts + // Use Tunnel strategy to intercept before ListBox handles navigation keys + // Also use handledEventsToo to receive events even if child controls handle them + AddHandler(KeyDownEvent, OnKeyDown, RoutingStrategies.Tunnel, handledEventsToo: true); + } + + /// + protected override void OnUnloaded(RoutedEventArgs e) + { + base.OnUnloaded(e); + + // Unsubscribe from events + if (layerListBox != null) + { + ((IChildIndexProvider)layerListBox).ChildIndexChanged -= OnChildIndexChanged; + } + + // Clean up timer + autoScrollTimer?.Stop(); + autoScrollTimer = null; + + RemoveHandler(PointerMovedEvent, OnPointerMoved); + RemoveHandler(PointerReleasedEvent, OnPointerReleased); + RemoveHandler(PointerCaptureLostEvent, OnPointerCaptureLost); + RemoveHandler(KeyDownEvent, OnKeyDown); + } + + /// + /// Handles the child index changed event from the ListBox. + /// This is fired when a drag reorder operation completes. + /// + private void OnChildIndexChanged(object? sender, ChildIndexChangedEventArgs e) + { + if ( + e.Child is Control { DataContext: MaskLayer layer } + && DataContext is LayeredMaskEditorViewModel vm + ) + { + vm.OnLayerIndexChanged(layer, e.Index); + } + } + + private DateTime _lastShortcutTime = DateTime.MinValue; + private static readonly TimeSpan ShortcutThrottle = TimeSpan.FromMilliseconds(100); + + /// + /// Handles keyboard shortcuts for layer operations. + /// + private void OnKeyDown(object? sender, KeyEventArgs e) + { + if (DataContext is not LayeredMaskEditorViewModel vm) + return; + + // Don't handle shortcuts when typing in a TextBox + if (e.Source is TextBox) + return; + + var ctrl = e.KeyModifiers.HasFlag(KeyModifiers.Control); + var shift = e.KeyModifiers.HasFlag(KeyModifiers.Shift); + + // Throttle rapid-fire events for navigation keys + if ((e.Key == Key.Up || e.Key == Key.Down) && DateTime.UtcNow - _lastShortcutTime < ShortcutThrottle) + { + e.Handled = true; + return; + } + + switch (e.Key) + { + // Delete - Delete selected layer + case Key.Delete when !ctrl && !shift: + if (vm.DeleteLayerCommand.CanExecute(null)) + { + vm.DeleteLayerCommand.Execute(null); + e.Handled = true; + } + break; + + // Ctrl+Shift+Delete - Clear all layers + case Key.Delete when ctrl && shift: + vm.ClearAllLayersCommand.Execute(null); + e.Handled = true; + break; + + // Ctrl+D - Duplicate layer + case Key.D when ctrl: + vm.DuplicateLayerCommand.Execute(null); + e.Handled = true; + break; + + // Ctrl+N - New layer + case Key.N when ctrl: + vm.AddLayerCommand.Execute(null); + e.Handled = true; + break; + + // Ctrl+F - Fill layer + case Key.F when ctrl: + vm.FillLayerCommand.Execute(null); + e.Handled = true; + break; + + // Ctrl+I - Invert layer + case Key.I when ctrl: + vm.InvertLayerCommand.Execute(null); + e.Handled = true; + break; + + // Ctrl+Shift+Z - Undo layer operation + case Key.Z when ctrl && shift: + if (vm.UndoLayerOperationCommand.CanExecute(null)) + { + vm.UndoLayerOperationCommand.Execute(null); + e.Handled = true; + } + break; + + // Ctrl+Up - Move layer up + case Key.Up when ctrl: + if (vm.MoveLayerUpCommand.CanExecute(null)) + { + vm.MoveLayerUpCommand.Execute(null); + _lastShortcutTime = DateTime.UtcNow; + e.Handled = true; + FocusSelectedLayer(); + } + break; + + // Ctrl+Down - Move layer down + case Key.Down when ctrl: + if (vm.MoveLayerDownCommand.CanExecute(null)) + { + vm.MoveLayerDownCommand.Execute(null); + _lastShortcutTime = DateTime.UtcNow; + e.Handled = true; + FocusSelectedLayer(); + } + break; + } + } + + /// + /// Handles pointer move to detect dragging and trigger auto-scroll near edges. + /// + private void OnPointerMoved(object? sender, PointerEventArgs e) + { + if (layerScrollViewer == null || layerListBox == null) + return; + + // Check if we're likely dragging (pointer is captured) + var pointer = e.Pointer; + if (pointer.Captured == null) + { + StopAutoScroll(); + return; + } + + // Check if pointer is over/near the layer list area + var scrollViewerBounds = layerScrollViewer.Bounds; + var pointerPos = e.GetPosition(layerScrollViewer); + + // Only process if pointer is within the horizontal bounds of the ScrollViewer + if (pointerPos.X < 0 || pointerPos.X > scrollViewerBounds.Width) + { + StopAutoScroll(); + return; + } + + isDragging = true; + + // Check if near top edge + if (pointerPos.Y < AutoScrollEdgeThreshold && pointerPos.Y >= -AutoScrollEdgeThreshold) + { + // Calculate speed based on proximity to edge (closer = faster) + var proximity = 1 - (pointerPos.Y / AutoScrollEdgeThreshold); + autoScrollSpeed = -AutoScrollBaseSpeed * Math.Max(1, proximity * 3); + StartAutoScroll(); + } + // Check if near bottom edge + else if ( + pointerPos.Y > scrollViewerBounds.Height - AutoScrollEdgeThreshold + && pointerPos.Y <= scrollViewerBounds.Height + AutoScrollEdgeThreshold + ) + { + var distanceFromBottom = scrollViewerBounds.Height - pointerPos.Y; + var proximity = 1 - (distanceFromBottom / AutoScrollEdgeThreshold); + autoScrollSpeed = AutoScrollBaseSpeed * Math.Max(1, proximity * 3); + StartAutoScroll(); + } + else + { + StopAutoScroll(); + } + } + + private void OnPointerReleased(object? sender, PointerReleasedEventArgs e) + { + isDragging = false; + StopAutoScroll(); + } + + private void OnPointerCaptureLost(object? sender, PointerCaptureLostEventArgs e) + { + isDragging = false; + StopAutoScroll(); + } + + private void StartAutoScroll() + { + if (autoScrollTimer != null && !autoScrollTimer.IsEnabled) + { + autoScrollTimer.Start(); + } + } + + private void StopAutoScroll() + { + autoScrollTimer?.Stop(); + autoScrollSpeed = 0; + } + + private void AutoScrollTimer_Tick(object? sender, EventArgs e) + { + if (layerScrollViewer == null || !isDragging || autoScrollSpeed == 0) + { + StopAutoScroll(); + return; + } + + var newOffset = layerScrollViewer.Offset.Y + autoScrollSpeed; + newOffset = Math.Max(0, Math.Min(newOffset, layerScrollViewer.ScrollBarMaximum.Y)); + layerScrollViewer.Offset = new Vector(layerScrollViewer.Offset.X, newOffset); + } + + private void FocusSelectedLayer() + { + if ( + DataContext is not LayeredMaskEditorViewModel vm + || vm.SelectedLayer == null + || layerListBox == null + ) + return; + + var layer = vm.SelectedLayer; + layerListBox.ScrollIntoView(layer); + + // Post to UI thread to allow layout updates to happen first + Dispatcher.UIThread.Post( + () => + { + var container = layerListBox.ContainerFromItem(layer); + if (container is Control control) + { + control.Focus(); + } + }, + DispatcherPriority.Input + ); + } +} diff --git a/StabilityMatrix.Avalonia/Views/Dialogs/MaskEditorDialog.axaml b/StabilityMatrix.Avalonia/Views/Dialogs/MaskEditorDialog.axaml index 1229a6de7..c2369953e 100644 --- a/StabilityMatrix.Avalonia/Views/Dialogs/MaskEditorDialog.axaml +++ b/StabilityMatrix.Avalonia/Views/Dialogs/MaskEditorDialog.axaml @@ -1,35 +1,50 @@ -ο»Ώ - +ο»Ώ + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Stretch + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - - - + + + - - - - - - - - - - - - - - + + + + + + + + + + + + + + - - - - - + + + + + - - - - - + + + + + - - - + + + - - - - - - - - - - - - - - + + + + + + + + + + + + + + - - - + + + - - + + - - - - - + + + + + - - - + + + - - + + - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + public string FileNameWithoutExtension => Path.GetFileNameWithoutExtension(AbsolutePath); + /// + /// Whether this file is a video file based on extension. + /// + public bool IsVideo => SupportedVideoExtensions.Contains(Path.GetExtension(AbsolutePath)); + + /// + /// Optional path to a thumbnail image for video files. + /// + public string? ThumbnailPath { get; init; } + public ( string? Parameters, string? ParametersJson, @@ -164,6 +174,20 @@ public static LocalImageFile FromPath(FilePath filePath) }; } + // Handle video files - no image metadata, just file info + if (SupportedVideoExtensions.Contains(filePath.Extension)) + { + filePath.Info.Refresh(); + + return new LocalImageFile + { + AbsolutePath = filePath, + ImageType = imageType, + CreatedAt = filePath.Info.CreationTimeUtc, + LastModifiedAt = filePath.Info.LastWriteTimeUtc, + }; + } + using var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); using var ms = new SKManagedStream(fs); var codec = SKCodec.Create(ms); @@ -186,4 +210,13 @@ public static LocalImageFile FromPath(FilePath filePath) ".gif", ".webp", ]; + + public static readonly HashSet SupportedVideoExtensions = new(StringComparer.OrdinalIgnoreCase) + { + ".mp4", + ".webm", + ".mov", + ".avi", + ".mkv", + }; } diff --git a/StabilityMatrix.Core/Models/Database/LocalModelFile.cs b/StabilityMatrix.Core/Models/Database/LocalModelFile.cs index 23ee347a4..b6d4ed58f 100644 --- a/StabilityMatrix.Core/Models/Database/LocalModelFile.cs +++ b/StabilityMatrix.Core/Models/Database/LocalModelFile.cs @@ -176,6 +176,11 @@ public override int GetHashCode() public bool HasOpenModelDbMetadata => HasConnectedModel && ConnectedModelInfo.Source == ConnectedModelSource.OpenModelDb; + [BsonIgnore] + [MemberNotNullWhen(true, nameof(ConnectedModelInfo))] + public bool HasCivArchiveMetadata => + HasConnectedModel && ConnectedModelInfo.Source == ConnectedModelSource.CivArchive; + public string GetFullPath(string rootModelDirectory) { return Path.Combine(rootModelDirectory, RelativePath); diff --git a/StabilityMatrix.Core/Models/EnvVarKeyPair.cs b/StabilityMatrix.Core/Models/EnvVarKeyPair.cs index 3a093837a..287ed2014 100644 --- a/StabilityMatrix.Core/Models/EnvVarKeyPair.cs +++ b/StabilityMatrix.Core/Models/EnvVarKeyPair.cs @@ -4,10 +4,12 @@ public class EnvVarKeyPair { public string Key { get; set; } public string Value { get; set; } + public bool IsEnabled { get; set; } - public EnvVarKeyPair(string key = "", string value = "") + public EnvVarKeyPair(string key = "", string value = "", bool isEnabled = true) { Key = key; Value = value; + IsEnabled = isEnabled; } } diff --git a/StabilityMatrix.Core/Models/HybridModelFile.cs b/StabilityMatrix.Core/Models/HybridModelFile.cs index 9348a6d15..03464db8d 100644 --- a/StabilityMatrix.Core/Models/HybridModelFile.cs +++ b/StabilityMatrix.Core/Models/HybridModelFile.cs @@ -230,4 +230,47 @@ public int GetHashCode(HybridModelFile obj) [JsonIgnore] public string SearchText => SortKey; + + /// + /// Gets a detailed search text string that includes model name, version, filename, tags, + /// trained words, and other metadata for comprehensive search/filtering. + /// + [JsonIgnore] + public string DetailedSearchText + { + get + { + var terms = new List + { + SearchText, + ShortDisplayName, + FileName, + RemoteName ?? string.Empty, + DownloadableResource?.FileName ?? string.Empty, + DownloadableResource?.RelativeDirectory ?? string.Empty, + }; + + if (Local is { } localModel) + { + terms.Add(localModel.DisplayModelName); + terms.Add(localModel.DisplayModelVersion); + terms.Add(localModel.DisplayModelFileName); + terms.Add(localModel.DisplayConfigFileName); + + if (localModel.ConnectedModelInfo is { } connectedModel) + { + terms.Add(connectedModel.BaseModel ?? string.Empty); + terms.Add(connectedModel.ModelType.GetStringValue()); + + if (connectedModel.Tags.Length > 0) + terms.Add(string.Join(' ', connectedModel.Tags)); + + if (connectedModel.TrainedWords is { Length: > 0 }) + terms.Add(string.Join(' ', connectedModel.TrainedWords)); + } + } + + return string.Join(' ', terms.Where(term => !string.IsNullOrWhiteSpace(term))); + } + } } diff --git a/StabilityMatrix.Core/Models/Notifications/NotificationAction.cs b/StabilityMatrix.Core/Models/Notifications/NotificationAction.cs new file mode 100644 index 000000000..0aa701c83 --- /dev/null +++ b/StabilityMatrix.Core/Models/Notifications/NotificationAction.cs @@ -0,0 +1,25 @@ +using System.Text.Json.Serialization; + +namespace StabilityMatrix.Core.Models.Notifications; + +/// +/// Discriminated record describing what should happen when a notification (toast or history entry) is invoked. +/// New variants must be registered with here so the dispatcher can switch on them. +/// +[JsonPolymorphic(TypeDiscriminatorPropertyName = "$kind")] +[JsonDerivedType(typeof(OpenFolderAction), "OpenFolder")] +[JsonDerivedType(typeof(NavigateToPageAction), "NavigateToPage")] +[JsonDerivedType(typeof(ToggleProgressFlyoutAction), "ToggleProgressFlyout")] +public abstract record NotificationAction; + +/// Open the given filesystem path in the OS file manager. +public sealed record OpenFolderAction(string Path) : NotificationAction; + +/// +/// Navigate the main shell to the page identified by the given ViewModel type. +/// Stored as the assembly-qualified type name so the action remains serializable. +/// +public sealed record NavigateToPageAction(string PageTypeName) : NotificationAction; + +/// Open the activity flyout in the sidebar footer. +public sealed record ToggleProgressFlyoutAction : NotificationAction; diff --git a/StabilityMatrix.Core/Models/Notifications/NotificationHistoryEntry.cs b/StabilityMatrix.Core/Models/Notifications/NotificationHistoryEntry.cs new file mode 100644 index 000000000..808f746d7 --- /dev/null +++ b/StabilityMatrix.Core/Models/Notifications/NotificationHistoryEntry.cs @@ -0,0 +1,29 @@ +using StabilityMatrix.Core.Models.Settings; + +namespace StabilityMatrix.Core.Models.Notifications; + +/// +/// A persisted record of a single notification shown (or suppressed) during this session. +/// Exposed through . +/// +public sealed record NotificationHistoryEntry +{ + public Guid Id { get; init; } = Guid.NewGuid(); + + public DateTimeOffset Timestamp { get; init; } = DateTimeOffset.Now; + + public NotificationKey? Key { get; init; } + + public string Title { get; init; } = string.Empty; + + public string? Body { get; init; } + + public string? BodyImagePath { get; init; } + + public NotificationLevel Level { get; init; } = NotificationLevel.Information; + + public NotificationAction? Action { get; init; } + + /// Mutable so callers can flip read-state without having to replace the entry. + public bool IsRead { get; set; } +} diff --git a/StabilityMatrix.Core/Models/PackageModification/InstallSageAttentionStep.cs b/StabilityMatrix.Core/Models/PackageModification/InstallSageAttentionStep.cs index dbaf7cb0a..bf0137de1 100644 --- a/StabilityMatrix.Core/Models/PackageModification/InstallSageAttentionStep.cs +++ b/StabilityMatrix.Core/Models/PackageModification/InstallSageAttentionStep.cs @@ -1,4 +1,5 @@ -ο»Ώusing StabilityMatrix.Core.Exceptions; +ο»Ώusing System; +using StabilityMatrix.Core.Exceptions; using StabilityMatrix.Core.Extensions; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models.FileInterfaces; diff --git a/StabilityMatrix.Core/Models/PackageModification/RunPythonCommandStep.cs b/StabilityMatrix.Core/Models/PackageModification/RunPythonCommandStep.cs new file mode 100644 index 000000000..b8e009a72 --- /dev/null +++ b/StabilityMatrix.Core/Models/PackageModification/RunPythonCommandStep.cs @@ -0,0 +1,52 @@ +ο»Ώusing StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Models.PackageModification; + +public class RunPythonCommandStep( + IPyInstallationManager pyInstallationManager, + ISettingsManager settingsManager +) : IPackageStep +{ + public required InstalledPackage InstalledPackage { get; init; } + public required DirectoryPath WorkingDirectory { get; init; } + public required ProcessArgs Arguments { get; init; } + + public async Task ExecuteAsync(IProgress? progress = null) + { + progress?.Report( + new ProgressReport(-1f, message: "Setting up virtual environment...", isIndeterminate: true) + ); + + var venvDir = WorkingDirectory.JoinDir("venv"); + var pyVersion = PyVersion.Parse(InstalledPackage.PythonVersion); + if (pyVersion.StringValue == "0.0.0") + { + pyVersion = PyInstallationManager.Python_3_10_11; + } + + var baseInstall = !string.IsNullOrWhiteSpace(InstalledPackage.PythonVersion) + ? new PyBaseInstall( + await pyInstallationManager.GetInstallationAsync(pyVersion).ConfigureAwait(false) + ) + : PyBaseInstall.Default; + + await using var venvRunner = baseInstall.CreateVenvRunner( + venvDir, + workingDirectory: WorkingDirectory, + environmentVariables: settingsManager.Settings.EnvironmentVariables + ); + + venvRunner.RunDetached(Arguments, progress.AsProcessOutputHandler()); + if (venvRunner.Process != null) + { + await venvRunner.Process.WaitForExitAsync().ConfigureAwait(false); + } + } + + public string ProgressTitle => "Running Python Command"; +} diff --git a/StabilityMatrix.Core/Models/Packages/AiToolkit.cs b/StabilityMatrix.Core/Models/Packages/AiToolkit.cs index b732dad75..d40bb083a 100644 --- a/StabilityMatrix.Core/Models/Packages/AiToolkit.cs +++ b/StabilityMatrix.Core/Models/Packages/AiToolkit.cs @@ -82,17 +82,23 @@ public override async Task InstallPackage( .ConfigureAwait(false); venvRunner.UpdateEnvironmentVariables(GetEnvVars); - var isBlackwell = - SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() ?? HardwareHelper.HasBlackwellGpu(); + var isLegacyNvidia = + SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() ?? HardwareHelper.HasLegacyNvidiaGpu(); var config = new PipInstallConfig { RequirementsFilePaths = ["requirements.txt"], - TorchVersion = "==2.7.0", - TorchvisionVersion = "==0.22.0", - TorchaudioVersion = "==2.7.0", - CudaIndex = isBlackwell ? "cu128" : "cu126", + // Upstream (ostris/ai-toolkit README) installs torch 2.9.1 / cu128. + TorchVersion = "==2.9.1", + TorchvisionVersion = "==0.24.1", + TorchaudioVersion = "==2.9.1", + // cu128 by default; keep cu126 for legacy NVIDIA GPUs without cu128 support. + CudaIndex = isLegacyNvidia ? "cu126" : "cu128", ExtraPipArgs = [Compat.IsWindows ? "triton-windows" : "triton"], + // ai-toolkit doesn't pin numpy, so it floats to 2.x and breaks the scipy/diffusers + // C-extensions (built for numpy 1.x): "numpy.dtype size changed... binary incompatibility". + // Pin to the last 1.x release to keep them ABI-compatible. + PostInstallPipArgs = ["numpy==1.26.4"], UpgradePackages = true, }; diff --git a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs index 82b06f817..2d16f3c2c 100644 --- a/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs +++ b/StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs @@ -886,7 +886,8 @@ protected PipInstallArgs GetTorchPipArgs( string torchaudioVersion = "", string xformersVersion = "", string cudaIndex = "cu130", - string rocmIndex = "rocm6.4" + string rocmIndex = "rocm6.4", + string xpuIndex = "xpu" ) { var pipArgs = new PipInstallArgs(); @@ -910,6 +911,7 @@ protected PipInstallArgs GetTorchPipArgs( TorchIndex.Rocm => rocmIndex, TorchIndex.Mps => "cpu", TorchIndex.Zluda => cudaIndex, + TorchIndex.Ipex => xpuIndex, _ => "cpu", }; diff --git a/StabilityMatrix.Core/Models/Packages/Cogstudio.cs b/StabilityMatrix.Core/Models/Packages/Cogstudio.cs index a030dbb79..721c936d3 100644 --- a/StabilityMatrix.Core/Models/Packages/Cogstudio.cs +++ b/StabilityMatrix.Core/Models/Packages/Cogstudio.cs @@ -73,9 +73,9 @@ public override async Task InstallPackage( .ConfigureAwait(false); progress?.Report(new ProgressReport(-1f, "Setting up Cogstudio files", isIndeterminate: true)); - var gradioCompositeDemo = new FilePath(installLocation, "inference/gradio_composite_demo"); - var cogstudioFile = new FilePath(gradioCompositeDemo, "cogstudio.py"); - gradioCompositeDemo.Directory?.Create(); + var gradioCompositeDemo = new DirectoryPath(installLocation, "inference", "gradio_composite_demo"); + gradioCompositeDemo.Create(); + var cogstudioFile = gradioCompositeDemo.JoinFile("cogstudio.py"); await DownloadService .DownloadToFileAsync(cogstudioUrl, cogstudioFile, cancellationToken: cancellationToken) .ConfigureAwait(false); @@ -121,12 +121,25 @@ await requirements.ReadAllTextAsync(cancellationToken).ConfigureAwait(false), // SwissArmyTransformer is not available on Windows and DeepSpeed needs prebuilt wheels if (Compat.IsWindows) { - await venvRunner - .PipInstall( - " https://github.com/daswer123/deepspeed-windows/releases/download/11.2/deepspeed-0.11.2+cuda121-cp310-cp310-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + // This prebuilt deepspeed wheel reports an internal metadata version of "0.11.2+unknown", + // which doesn't match its filename's "0.11.2+cuda121". uv treats that mismatch as a + // malformed wheel and refuses to install it; opt out of the filename/version consistency + // check just for this wheel, then restore strict checking for the rest of the install. + venvRunner.UpdateEnvironmentVariables(env => env.SetItem("UV_SKIP_WHEEL_FILENAME_CHECK", "1")); + try + { + await venvRunner + .PipInstall( + " https://github.com/daswer123/deepspeed-windows/releases/download/11.2/deepspeed-0.11.2+cuda121-cp310-cp310-win_amd64.whl", + onConsoleOutput + ) + .ConfigureAwait(false); + } + finally + { + venvRunner.UpdateEnvironmentVariables(env => env.Remove("UV_SKIP_WHEEL_FILENAME_CHECK")); + } + await venvRunner .PipInstall("spandrel opencv-python scikit-video", onConsoleOutput) .ConfigureAwait(false); diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index b2f3c973b..0b4c530fa 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -196,6 +196,30 @@ IRocmPackageHelper rocmPackageHelper TargetRelativePaths = ["models/diffusion_models"], ConfigDocumentPaths = ["diffusion_models"], }, + new SharedFolderLayoutRule // Style Models (e.g. Flux Redux, B-Lora) + { + SourceTypes = [SharedFolderType.StyleModels], + TargetRelativePaths = ["models/style_models"], + ConfigDocumentPaths = ["style_models"], + }, + new SharedFolderLayoutRule // Audio Encoders + { + SourceTypes = [SharedFolderType.AudioEncoders], + TargetRelativePaths = ["models/audio_encoders"], + ConfigDocumentPaths = ["audio_encoders"], + }, + new SharedFolderLayoutRule // Model Patches + { + SourceTypes = [SharedFolderType.ModelPatches], + TargetRelativePaths = ["models/model_patches"], + ConfigDocumentPaths = ["model_patches"], + }, + new SharedFolderLayoutRule // Background Removal (e.g. BiRefNet) + { + SourceTypes = [SharedFolderType.BackgroundRemoval], + TargetRelativePaths = ["models/background_removal"], + ConfigDocumentPaths = ["background_removal"], + }, ], }; @@ -323,7 +347,14 @@ IRocmPackageHelper rocmPackageHelper public override string MainBranch => "master"; public override IEnumerable AvailableTorchIndices => - [TorchIndex.Cpu, TorchIndex.Cuda, TorchIndex.DirectMl, TorchIndex.Rocm, TorchIndex.Mps]; + [ + TorchIndex.Cpu, + TorchIndex.Cuda, + TorchIndex.DirectMl, + TorchIndex.Ipex, + TorchIndex.Mps, + TorchIndex.Rocm, + ]; public override List GetExtraCommands() { @@ -455,6 +486,7 @@ await rocmPackageHelper TorchaudioVersion = " ", // Request torchaudio without a specific version CudaIndex = isLegacyNvidia ? "cu126" : "cu130", RocmIndex = "rocm7.2", + XpuIndex = "xpu", UpgradePackages = true, PostInstallPipArgs = ["typing-extensions>=4.15.0"], }; @@ -563,6 +595,14 @@ await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage VenvRunner.UpdateEnvironmentVariables(env => GetEnvVars(env, installLocation, installedPackage)); var launchArguments = NormalizeLaunchArguments(installedPackage, options.Arguments); + // Don't leak the build-constraints env var into the running package. It points at a + // path relative to the install dir (see BaseGitPackage.SetupVenvPure), which only resolves + // when the working directory is that install dir. ComfyUI-Manager launches `uv pip install` + // from a different working directory, so the inherited value breaks with + // "File not found: venv/uv-build-constraints.txt". The constraint is only needed for our + // own setup-time builds, not for the running server. + VenvRunner.UpdateEnvironmentVariables(env => env.Remove("UV_BUILD_CONSTRAINT")); + // Check for old NVIDIA driver version with cu130 installations var isNvidia = SettingsManager.Settings.PreferredGpu?.IsNvidia ?? HardwareHelper.HasNvidiaGpu(); var isLegacyNvidia = @@ -926,6 +966,19 @@ await venvRunner venvRunner.WorkingDirectory = installScript.Directory; venvRunner.UpdateEnvironmentVariables(env => { + // Recompute UV_BUILD_CONSTRAINT relative to the new working directory, + // since the constraints file is in the ComfyUI root's venv folder. + var constraintsAbsPath = Path.Combine( + installedPackage.FullPath!, + "venv", + "uv-build-constraints.txt" + ); + var constraintsRelPath = Path.GetRelativePath( + installScript.Directory!.FullPath, + constraintsAbsPath + ); + env = env.SetItem("UV_BUILD_CONSTRAINT", constraintsRelPath); + // set env vars for Impact Pack for Face Detailer env = env.SetItem("COMFYUI_PATH", installedPackage.FullPath!); @@ -1186,6 +1239,18 @@ InstalledPackage installedPackage var hasRocmGpu = HasWindowsRocmSupport(); var selectedTorchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion(); + // Add FFmpeg to PATH if it's installed (optional - for video processing) + if (PrerequisiteHelper.IsFfmpegInstalled) + { + var ffmpegDir = Path.GetDirectoryName(PrerequisiteHelper.FfmpegPath); + if (!string.IsNullOrEmpty(ffmpegDir)) + { + var currentPath = + env.GetValueOrDefault("PATH") ?? Environment.GetEnvironmentVariable("PATH") ?? ""; + env = env.SetItem("PATH", Compat.GetEnvPathWithExtensions(ffmpegDir, currentPath)); + } + } + if (!Compat.IsWindows || !hasRocmGpu || selectedTorchIndex != TorchIndex.Rocm) return env; diff --git a/StabilityMatrix.Core/Models/Packages/FluxGym.cs b/StabilityMatrix.Core/Models/Packages/FluxGym.cs index 177915582..abf17b7b2 100644 --- a/StabilityMatrix.Core/Models/Packages/FluxGym.cs +++ b/StabilityMatrix.Core/Models/Packages/FluxGym.cs @@ -121,12 +121,20 @@ await PrerequisiteHelper var config = new PipInstallConfig { RequirementsFilePaths = ["sd-scripts/requirements.txt", "requirements.txt"], + // The fluxgym and sd-scripts requirements conflict on several pins, so we exclude the + // conflicting entries here and reinstall known-good versions via PostInstallPipArgs: + // - diffusers: fluxgym pins an unpinned git+HEAD build (now resolves to a dev release + // requiring safetensors>=0.8.0rc0) while sd-scripts pins safetensors==0.4.5, which is + // unsatisfiable. Exclude both the git HEAD and the sd-scripts diffusers[torch]==0.32.1. + // - transformers: fluxgym pins 4.49.0 and sd-scripts pins 4.54.1, also unsatisfiable. + // We then install the known-good diffusers[torch]==0.32.1 (compatible with safetensors + // 0.4.5) and transformers==4.54.1 after the requirements step. RequirementsExcludePattern = - "(diffusers\\[torch\\]==0.32.1|torch|torchvision|torchaudio|xformers|bitsandbytes|-e\\s\\.)", + "(diffusers\\[torch\\]==0.32.1|git\\+https://github\\.com/huggingface/diffusers\\.git|torch|torchvision|torchaudio|xformers|bitsandbytes|transformers.*|-e\\s\\.)", TorchaudioVersion = " ", CudaIndex = isLegacyNvidiaGpu ? "cu126" : "cu128", ExtraPipArgs = ["bitsandbytes>=0.46.0"], - PostInstallPipArgs = ["diffusers[torch]==0.32.1"], + PostInstallPipArgs = ["diffusers[torch]==0.32.1", "transformers==4.54.1"], }; await StandardPipInstallProcessAsync( diff --git a/StabilityMatrix.Core/Models/Packages/ForgeClassic.cs b/StabilityMatrix.Core/Models/Packages/ForgeClassic.cs index 1084eec83..f9365ab75 100644 --- a/StabilityMatrix.Core/Models/Packages/ForgeClassic.cs +++ b/StabilityMatrix.Core/Models/Packages/ForgeClassic.cs @@ -40,7 +40,6 @@ IPipWheelService pipWheelService public override string RepositoryName => "sd-webui-forge-classic"; public override string DisplayName { get; set; } = "Stable Diffusion WebUI Forge - Classic"; public override string MainBranch => "classic"; - public override string Blurb => "This fork is focused exclusively on SD1 and SDXL checkpoints, having various optimizations implemented, with the main goal of being the lightest WebUI without any bloatwares."; public override string LicenseUrl => diff --git a/StabilityMatrix.Core/Models/Packages/ForgeNeo.cs b/StabilityMatrix.Core/Models/Packages/ForgeNeo.cs index 8c5d65fc5..ee8f4f9f2 100644 --- a/StabilityMatrix.Core/Models/Packages/ForgeNeo.cs +++ b/StabilityMatrix.Core/Models/Packages/ForgeNeo.cs @@ -1,6 +1,7 @@ ο»Ώusing Injectio.Attributes; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Helper.Cache; +using StabilityMatrix.Core.Models; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; @@ -28,6 +29,25 @@ IPipWheelService pipWheelService public override string DisplayName { get; set; } = "Stable Diffusion WebUI Forge - Neo"; public override string MainBranch => "neo"; public override PackageType PackageType => PackageType.SdInference; + public override List LaunchOptions + { + get + { + var options = new List(base.LaunchOptions); + var insertIndex = Math.Max(0, options.Count - 1); + options.Insert( + insertIndex, + new LaunchOptionDefinition + { + Name = "Bitsandbytes NF4", + Type = LaunchOptionType.Bool, + Description = "Install bitsandbytes for low-bits (NF4) inference", + Options = ["--bnb"], + } + ); + return options; + } + } public override string Blurb => "Neo mainly serves as an continuation for the \"latest\" version of Forge. Additionally, this fork is focused on optimization and usability, with the main goal of being the lightest WebUI without any bloatwares."; diff --git a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs index eee7305f4..cde21c178 100644 --- a/StabilityMatrix.Core/Models/Packages/KohyaSs.cs +++ b/StabilityMatrix.Core/Models/Packages/KohyaSs.cs @@ -172,10 +172,12 @@ await venvRunner { PrePipInstallArgs = ["rich", "packaging", "setuptools", "uv"], RequirementsFilePaths = ["requirements_windows.txt"], - // Exclude torch ecosystem (default) AND the specific bitsandbytes version - RequirementsExcludePattern = - "(torch|torchvision|torchaudio|xformers|bitsandbytes==0\\.44\\.0)", - TorchaudioVersion = " ", + // Pin torch to match upstream's requirements_pytorch_windows.txt (torch 2.7.0+cu128, + // torchvision 0.22.0+cu128, xformers>=0.0.30). The torch ecosystem is excluded from + // the requirements file via the default pattern and installed here instead. + TorchVersion = "==2.7.0", + TorchvisionVersion = "==0.22.0", + TorchaudioVersion = "==2.7.0", XformersVersion = " ", CudaIndex = isLegacyNvidia ? "cu126" : "cu128", // Add back the generic bitsandbytes and the specific numpy version diff --git a/StabilityMatrix.Core/Models/Packages/PipInstallConfig.cs b/StabilityMatrix.Core/Models/Packages/PipInstallConfig.cs index af30cd2f7..70865129f 100644 --- a/StabilityMatrix.Core/Models/Packages/PipInstallConfig.cs +++ b/StabilityMatrix.Core/Models/Packages/PipInstallConfig.cs @@ -17,6 +17,7 @@ public record PipInstallConfig public string XformersVersion { get; init; } = ""; public string CudaIndex { get; init; } = "cu130"; public string RocmIndex { get; init; } = "rocm7.2"; + public string XpuIndex { get; init; } = "xpu"; public bool ForceReinstallTorch { get; init; } = true; public bool UpgradePackages { get; init; } = false; public bool SkipTorchInstall { get; init; } = false; diff --git a/StabilityMatrix.Core/Models/Packages/Reforge.cs b/StabilityMatrix.Core/Models/Packages/Reforge.cs index c0d9f760c..6bf79fb8c 100644 --- a/StabilityMatrix.Core/Models/Packages/Reforge.cs +++ b/StabilityMatrix.Core/Models/Packages/Reforge.cs @@ -155,6 +155,10 @@ await rocmPackageHelper progress?.Report(new ProgressReport(1f, "Install complete", isIndeterminate: false)); } + // reForge upstream pins torch==2.9.0 (modules/launch_utils.py TORCH_COMMAND); torchvision stays + // unpinned to match. Forge (the SDWebForge base) keeps its own default. + protected override string TorchVersionSpec => "==2.9.0"; + protected override ImmutableDictionary GetEnvVars( ImmutableDictionary env, InstalledPackage installedPackage diff --git a/StabilityMatrix.Core/Models/Packages/SDWebForge.cs b/StabilityMatrix.Core/Models/Packages/SDWebForge.cs index 84a45d009..517427ad1 100644 --- a/StabilityMatrix.Core/Models/Packages/SDWebForge.cs +++ b/StabilityMatrix.Core/Models/Packages/SDWebForge.cs @@ -48,6 +48,12 @@ IPipWheelService pipWheelService public override PackageDifficulty InstallerSortOrder => PackageDifficulty.Simple; public override PackageType PackageType => PackageType.Legacy; + /// + /// Torch version spec used during install. Default (" ") leaves torch unpinned (resolved from + /// the cu* index per the requirements file). Subclasses can override to pin a specific version. + /// + protected virtual string TorchVersionSpec => " "; + public override List LaunchOptions => [ new() @@ -182,7 +188,7 @@ torchIndex is TorchIndex.Cuda { PrePipInstallArgs = ["joblib"], RequirementsFilePaths = requirementsPaths, - TorchVersion = " ", + TorchVersion = TorchVersionSpec, TorchvisionVersion = " ", CudaIndex = isLegacyNvidia ? "cu126" : "cu128", RocmIndex = "rocm7.2", diff --git a/StabilityMatrix.Core/Models/Packages/StableSwarm.cs b/StabilityMatrix.Core/Models/Packages/StableSwarm.cs index 0e1ac3a23..70d8110ab 100644 --- a/StabilityMatrix.Core/Models/Packages/StableSwarm.cs +++ b/StabilityMatrix.Core/Models/Packages/StableSwarm.cs @@ -406,19 +406,21 @@ public override async Task RunPackage( ["DOTNET_ROOT"] = dotnetDir.FullPath, }; - if (aspEnvVars.TryGetValue("PATH", out var pathValue)) - { - aspEnvVars["PATH"] = Compat.GetEnvPathWithExtensions( - dotnetDir.FullPath, - portableGitBin, - pathValue - ); - } - else + // Build PATH with required directories + var pathDirs = new List { dotnetDir.FullPath, portableGitBin }; + + // Add FFmpeg to PATH if it's installed (optional - for video processing) + if (PrerequisiteHelper.IsFfmpegInstalled) { - aspEnvVars["PATH"] = Compat.GetEnvPathWithExtensions(dotnetDir.FullPath, portableGitBin); + var ffmpegDir = Path.GetDirectoryName(PrerequisiteHelper.FfmpegPath); + if (!string.IsNullOrEmpty(ffmpegDir)) + { + pathDirs.Add(ffmpegDir); + } } + aspEnvVars["PATH"] = Compat.GetEnvPathWithExtensions([.. pathDirs]); + aspEnvVars.Update(settingsManager.Settings.EnvironmentVariables); aspEnvVars.Update(BuildLinkedComfyLaunchEnvironment()); // Windows ROCm ComfyUI env var pass-through diff --git a/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs b/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs index 09cc94b85..691f92a2e 100644 --- a/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs +++ b/StabilityMatrix.Core/Models/Packages/VladAutomatic.cs @@ -332,7 +332,7 @@ public override async Task InstallPackage( // Run initial install case TorchIndex.Cuda: await venvRunner - .CustomInstall("launch.py --use-cuda --debug --test --uv", onConsoleOutput) + .CustomInstall(["launch.py", "--use-cuda", "--debug", "--test", "--uv"], onConsoleOutput) .ConfigureAwait(false); break; case TorchIndex.Rocm: diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 0ea46fbfb..11def19bc 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -450,8 +450,26 @@ await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment)); } - // Fix for distutils compatibility issue with Python 3.10 and setuptools - VenvRunner.UpdateEnvironmentVariables(env => env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib")); + // Set environment variables + VenvRunner.UpdateEnvironmentVariables(env => + { + // Fix for distutils compatibility issue with Python 3.10 and setuptools + env = env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib"); + + // Add FFmpeg to PATH if it's installed (optional - for video processing) + if (!PrerequisiteHelper.IsFfmpegInstalled) + return env; + + var ffmpegDir = Path.GetDirectoryName(PrerequisiteHelper.FfmpegPath); + if (string.IsNullOrEmpty(ffmpegDir)) + return env; + + var currentPath = + env.GetValueOrDefault("PATH") ?? Environment.GetEnvironmentVariable("PATH") ?? ""; + env = env.SetItem("PATH", Compat.GetEnvPathWithExtensions(ffmpegDir, currentPath)); + + return env; + }); // Write the Gradio logging patch wrapper script so gr.Info/Warning/Error // messages are also printed to stdout/stderr for console capture diff --git a/StabilityMatrix.Core/Models/Secrets.cs b/StabilityMatrix.Core/Models/Secrets.cs index 6ce43c844..87f680563 100644 --- a/StabilityMatrix.Core/Models/Secrets.cs +++ b/StabilityMatrix.Core/Models/Secrets.cs @@ -13,6 +13,8 @@ public readonly record struct Secrets public LykosAccountV2Tokens? LykosAccountV2 { get; init; } public string? HuggingFaceToken { get; init; } + + public string? GeminiApiKey { get; init; } } public static class SecretsExtensions diff --git a/StabilityMatrix.Core/Models/Settings/CivArchiveBrowserOptions.cs b/StabilityMatrix.Core/Models/Settings/CivArchiveBrowserOptions.cs new file mode 100644 index 000000000..db6491edd --- /dev/null +++ b/StabilityMatrix.Core/Models/Settings/CivArchiveBrowserOptions.cs @@ -0,0 +1,19 @@ +using System.Collections.Generic; +using StabilityMatrix.Core.Models.Api.CivArchive; + +namespace StabilityMatrix.Core.Models.Settings; + +public class CivArchiveBrowserOptions +{ + public string Query { get; set; } = string.Empty; + public string Tags { get; set; } = string.Empty; + public string Username { get; set; } = string.Empty; + public CivArchivePlatformOption Platform { get; set; } = CivArchivePlatformOption.All; + public CivArchiveSortOption Sort { get; set; } = CivArchiveSortOption.Top; + public CivArchivePeriodOption Period { get; set; } = CivArchivePeriodOption.All; + public CivArchiveRatingOption Rating { get; set; } = CivArchiveRatingOption.Safe; + public CivArchivePlatformStatusOption PlatformStatus { get; set; } = CivArchivePlatformStatusOption.All; + public CivArchiveKindOption Kind { get; set; } = CivArchiveKindOption.All; + public List SelectedModelTypes { get; set; } = []; + public List SelectedBaseModels { get; set; } = []; +} diff --git a/StabilityMatrix.Core/Models/Settings/ModelPickerFilterState.cs b/StabilityMatrix.Core/Models/Settings/ModelPickerFilterState.cs new file mode 100644 index 000000000..82a88fec0 --- /dev/null +++ b/StabilityMatrix.Core/Models/Settings/ModelPickerFilterState.cs @@ -0,0 +1,9 @@ +namespace StabilityMatrix.Core.Models.Settings; + +public class ModelPickerFilterState +{ + public string SearchText { get; set; } = string.Empty; + public bool ShowCheckpointsOnly { get; set; } + public bool ShowUnetsOnly { get; set; } + public List SelectedBaseModels { get; set; } = []; +} diff --git a/StabilityMatrix.Core/Models/Settings/NotificationKey.cs b/StabilityMatrix.Core/Models/Settings/NotificationKey.cs index 6778c30e6..b13207986 100644 --- a/StabilityMatrix.Core/Models/Settings/NotificationKey.cs +++ b/StabilityMatrix.Core/Models/Settings/NotificationKey.cs @@ -23,7 +23,15 @@ public record NotificationKey(string Value) : StringValue(Value), IParsable + new("Inference_BatchCompleted") + { + DefaultOption = Compat.IsLinux ? NotificationOption.AppToast : NotificationOption.NativePush, + Level = NotificationLevel.Success, + DisplayName = "Inference Batch Completed", }; public static NotificationKey Download_Completed => @@ -31,7 +39,7 @@ public record NotificationKey(string Value) : StringValue(Value), IParsable @@ -39,7 +47,7 @@ public record NotificationKey(string Value) : StringValue(Value), IParsable @@ -47,7 +55,7 @@ public record NotificationKey(string Value) : StringValue(Value), IParsable @@ -55,7 +63,7 @@ public record NotificationKey(string Value) : StringValue(Value), IParsable @@ -63,7 +71,7 @@ public record NotificationKey(string Value) : StringValue(Value), IParsable All { get; } = GetValues(); diff --git a/StabilityMatrix.Core/Models/Settings/Settings.cs b/StabilityMatrix.Core/Models/Settings/Settings.cs index eb1337e47..dd0f25ccb 100644 --- a/StabilityMatrix.Core/Models/Settings/Settings.cs +++ b/StabilityMatrix.Core/Models/Settings/Settings.cs @@ -4,6 +4,7 @@ using Semver; using StabilityMatrix.Core.Converters.Json; using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Api.CivArchive; using StabilityMatrix.Core.Models.Update; namespace StabilityMatrix.Core.Models.Settings; @@ -84,6 +85,7 @@ public InstalledPackage? PreferredWorkflowPackage public WindowSettings? WindowSettings { get; set; } public ModelSearchOptions? ModelSearchOptions { get; set; } + public CivArchiveBrowserOptions CivArchiveBrowserOptions { get; set; } = new(); /// /// Whether prompt auto completion is enabled @@ -119,6 +121,14 @@ public InstalledPackage? PreferredWorkflowPackage public bool IsDiscordRichPresenceEnabled { get; set; } + /// + /// When true, vertical scrollbars stay permanently visible at their expanded + /// (thicker) thickness instead of fading to a thin auto-hiding bar. Default + /// matches the app's preferred behavior; users can toggle back to Avalonia's + /// auto-hide default via the Appearance settings. + /// + public bool ScrollBarsAlwaysVisible { get; set; } = true; + public HashSet DisabledBaseModelTypes { get; set; } = []; public HashSet SavedInferenceDimensions { get; set; } = @@ -151,6 +161,8 @@ public InstalledPackage? PreferredWorkflowPackage [JsonPropertyName("EnvironmentVariables")] public Dictionary? UserEnvironmentVariables { get; set; } + public List? UserEnvironmentVariablesList { get; set; } + [JsonIgnore] public IReadOnlyDictionary EnvironmentVariables { @@ -164,13 +176,21 @@ public IReadOnlyDictionary EnvironmentVariables "cache" ); - if (UserEnvironmentVariables is null || UserEnvironmentVariables.Count == 0) + // Prefer new list format, fall back to legacy dict + var userVars = UserEnvironmentVariablesList is { Count: > 0 } + ? UserEnvironmentVariablesList + .Where(kvp => kvp.IsEnabled && !string.IsNullOrWhiteSpace(kvp.Key)) + .GroupBy(kvp => kvp.Key, StringComparer.Ordinal) + .ToDictionary(g => g.Key, g => g.Last().Value, StringComparer.Ordinal) + : UserEnvironmentVariables; + + if (userVars is null || userVars.Count == 0) { return DefaultEnvironmentVariables; } return DefaultEnvironmentVariables - .Concat(UserEnvironmentVariables) + .Concat(userVars) .GroupBy(pair => pair.Key) // User variables override default variables with the same key .ToDictionary(grouping => grouping.Key, grouping => grouping.Last().Value); @@ -211,6 +231,9 @@ public IReadOnlyDictionary EnvironmentVariables public int ConsoleFontSize { get; set; } = 14; public bool AutoLoadCivitModels { get; set; } = true; + [JsonPropertyName("UseLegacyModelSearch")] + public bool UseLegacySearch { get; set; } + /// /// When false, will copy files when drag/drop import happens /// Otherwise, it will move, as it states @@ -235,6 +258,10 @@ public IReadOnlyDictionary EnvironmentVariables public double CivitBrowserResizeFactor { get; set; } = 1.0d; + public double CivArchiveBrowserResizeFactor { get; set; } = 1.0d; + + public bool CivArchiveBrowserFitCardImages { get; set; } = true; + public bool HideEarlyAccessModels { get; set; } public bool CivitUseDiscoveryApi { get; set; } @@ -257,6 +284,11 @@ public IReadOnlyDictionary EnvironmentVariables public string? CivitModelBrowserFileNamePattern { get; set; } + public string? ModelOrganizationFileNamePattern { get; set; } + + public bool ModelPickerIsGridView { get; set; } + public Dictionary ModelPickerFilterStates { get; set; } = []; + public int InferenceDimensionStepChange { get; set; } = 128; [JsonIgnore] @@ -337,4 +369,7 @@ public static CultureInfo GetDefaultCulture() [JsonSerializable(typeof(string))] [JsonSerializable(typeof(LastDownloadLocationInfo))] [JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(ModelPickerFilterState))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(CivArchiveBrowserOptions))] internal partial class SettingsSerializerContext : JsonSerializerContext; diff --git a/StabilityMatrix.Core/Models/Settings/TeachingTip.cs b/StabilityMatrix.Core/Models/Settings/TeachingTip.cs index f3a9fdc2f..a170e7914 100644 --- a/StabilityMatrix.Core/Models/Settings/TeachingTip.cs +++ b/StabilityMatrix.Core/Models/Settings/TeachingTip.cs @@ -12,7 +12,6 @@ public record TeachingTip(string Value) : StringValue(Value) public static TeachingTip AccountsCredentialsStorageNotice => new("AccountsCredentialsStorageNotice"); public static TeachingTip CheckpointCategoriesTip => new("CheckpointCategoriesTip"); public static TeachingTip PackageExtensionsInstallNotice => new("PackageExtensionsInstallNotice"); - public static TeachingTip DownloadsTip => new("DownloadsTip"); public static TeachingTip WebUiButtonMovedTip => new("WebUiButtonMovedTip"); public static TeachingTip InferencePromptHelpButtonTip => new("InferencePromptHelpButtonTip"); public static TeachingTip LykosAccountMigrateTip => new("LykosAccountMigrateTip"); @@ -20,6 +19,7 @@ public record TeachingTip(string Value) : StringValue(Value) public static TeachingTip FolderMapTip => new("FolderMapTip"); public static TeachingTip InferencePromptAmplifyTip => new("InferencePromptAmplifyTip"); public static TeachingTip PromptAmplifyDisclaimer => new("PromptAmplifyDisclaimer"); + public static TeachingTip ActivityCenterTip => new("ActivityCenterTip"); /// public override string ToString() diff --git a/StabilityMatrix.Core/Models/SharedFolderType.cs b/StabilityMatrix.Core/Models/SharedFolderType.cs index 4978ad3f7..9b42a81f2 100644 --- a/StabilityMatrix.Core/Models/SharedFolderType.cs +++ b/StabilityMatrix.Core/Models/SharedFolderType.cs @@ -46,4 +46,16 @@ public enum SharedFolderType : ulong [Extensions.Description("Diffusion Models (UNet-only)")] DiffusionModels = 1ul << 31, + + [Extensions.Description("Style Models")] + StyleModels = 1ul << 32, + + [Extensions.Description("Audio Encoders")] + AudioEncoders = 1ul << 33, + + [Extensions.Description("Model Patches")] + ModelPatches = 1ul << 34, + + [Extensions.Description("Background Removal")] + BackgroundRemoval = 1ul << 35, } diff --git a/StabilityMatrix.Core/Python/PipInstallArgs.cs b/StabilityMatrix.Core/Python/PipInstallArgs.cs index 6c58430d1..42aa255e0 100644 --- a/StabilityMatrix.Core/Python/PipInstallArgs.cs +++ b/StabilityMatrix.Core/Python/PipInstallArgs.cs @@ -34,6 +34,9 @@ public PipInstallArgs WithExtraIndex(string indexUrl) => public PipInstallArgs WithTorchExtraIndex(string index) => WithExtraIndex($"https://download.pytorch.org/whl/{index}"); + public PipInstallArgs WithUvTorchExtraIndex(string index) => + this.AddKeyedArgs("--index", ["--index", $"pytorch=https://download.pytorch.org/whl/{index}"]); + public PipInstallArgs WithParsedFromRequirementsTxt( string requirements, [StringSyntax(StringSyntaxAttribute.Regex)] string? excludePattern = null @@ -129,13 +132,12 @@ public PipInstallArgs RemovePipArgKey(string argumentKey) return this with { Arguments = Arguments - .Where( - arg => - arg.HasKey - ? (arg.Key != argumentKey) - : (arg.Value != argumentKey && !arg.Value.Contains($"{argumentKey}==")) + .Where(arg => + arg.HasKey + ? (arg.Key != argumentKey) + : (arg.Value != argumentKey && !arg.Value.Contains($"{argumentKey}==")) ) - .ToImmutableList() + .ToImmutableList(), }; } diff --git a/StabilityMatrix.Core/Services/IModelIndexService.cs b/StabilityMatrix.Core/Services/IModelIndexService.cs index 9b7bfcf86..3a9e60df3 100644 --- a/StabilityMatrix.Core/Services/IModelIndexService.cs +++ b/StabilityMatrix.Core/Services/IModelIndexService.cs @@ -13,6 +13,18 @@ public interface IModelIndexService /// IReadOnlySet ModelIndexBlake3Hashes { get; } + /// + /// Set of all files SHA256 hashes (case-insensitive). + /// Synchronized with internal changes to . + /// + IReadOnlySet ModelIndexSha256Hashes { get; } + + /// + /// Set of CivArchive relative URLs for all locally installed CivArchive-sourced models. + /// Synchronized with internal changes to . + /// + IReadOnlySet ModelIndexCivArchiveUrls { get; } + /// /// Refreshes the local model file index. /// diff --git a/StabilityMatrix.Core/Services/INotificationHistoryService.cs b/StabilityMatrix.Core/Services/INotificationHistoryService.cs new file mode 100644 index 000000000..1ba9943b5 --- /dev/null +++ b/StabilityMatrix.Core/Services/INotificationHistoryService.cs @@ -0,0 +1,37 @@ +using StabilityMatrix.Core.Models.Notifications; + +namespace StabilityMatrix.Core.Services; + +/// +/// Session-only history of notifications shown by . +/// Populated regardless of whether the user has the corresponding NotificationKey suppressed, so suppressed events +/// are still visible in the activity panel. +/// +public interface INotificationHistoryService +{ + IReadOnlyList Entries { get; } + + /// Total number of entries. O(1) β€” avoids the snapshot allocation of . + int Count { get; } + + int UnreadCount { get; } + + event EventHandler? EntryAdded; + + /// Raised on bulk changes (clear / mark-all-read / remove) where per-entry events would be noisy. + event EventHandler? EntriesChanged; + + /// Add an entry. Returns the same entry (with a fresh Id if none was set) so callers can correlate. + NotificationHistoryEntry Add(NotificationHistoryEntry entry); + + void MarkRead(Guid id); + + void MarkAllRead(); + + void Remove(Guid id); + + void Clear(); + + /// Look up an entry by Id, or null if it has been evicted. + NotificationHistoryEntry? Find(Guid id); +} diff --git a/StabilityMatrix.Core/Services/IVideoThumbnailService.cs b/StabilityMatrix.Core/Services/IVideoThumbnailService.cs new file mode 100644 index 000000000..e4d8aedf1 --- /dev/null +++ b/StabilityMatrix.Core/Services/IVideoThumbnailService.cs @@ -0,0 +1,40 @@ +using StabilityMatrix.Core.Models.Progress; + +namespace StabilityMatrix.Core.Services; + +/// +/// Service for generating video thumbnails using FFmpeg. +/// +public interface IVideoThumbnailService +{ + /// + /// Supported video file extensions. + /// + IReadOnlySet SupportedVideoExtensions { get; } + + /// + /// Check if a file path represents a video file based on extension. + /// + bool IsVideoFile(string filePath); + + /// + /// Get or create a thumbnail for a video file. + /// + /// Absolute path to the video file. + /// Cancellation token. + /// Absolute path to the thumbnail image, or null if generation failed. + Task GetOrCreateThumbnailAsync(string videoPath, CancellationToken cancellationToken = default); + + /// + /// Gets the thumbnail path if it already exists (synchronous, no generation). + /// Use this for initial display to avoid async delays. + /// + /// Absolute path to the video file. + /// Absolute path to the thumbnail if it exists, otherwise null. + string? GetExistingThumbnailPath(string videoPath); + + /// + /// Ensures FFmpeg is installed before using the service. + /// + Task EnsureFfmpegInstalledAsync(IProgress? progress = null); +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/BananaVisionProviderIds.cs b/StabilityMatrix.Core/Services/ImageGeneration/BananaVisionProviderIds.cs new file mode 100644 index 000000000..c58441848 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/BananaVisionProviderIds.cs @@ -0,0 +1,54 @@ +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Contains constant provider IDs for BananaVision image generation providers. +/// Use these constants instead of hardcoded strings to avoid typos and enable refactoring. +/// +public static class BananaVisionProviderIds +{ + /// + /// Gemini 2.5 Flash image generation provider (Nano Banana) + /// + public const string Gemini25Flash = "gemini-2.5-flash"; + + /// + /// Gemini 3.1 Flash image generation provider (Nano Banana 2) with thinking support + /// + public const string Gemini31Flash = "gemini-3.1-flash"; + + /// + /// Gemini 3 Pro image generation provider with thinking support + /// + public const string Gemini3Pro = "gemini-3-pro"; + + /// + /// Flux Kontext local image generation provider (requires ComfyUI) + /// + public const string FluxKontext = "flux-kontext"; + + /// + /// Qwen Image Edit local image generation provider (requires ComfyUI) + /// + public const string QwenImageEdit = "qwen-image-edit"; + + /// + /// Flux.2 Klein local image generation provider (requires ComfyUI) + /// + public const string Flux2Klein = "flux2-klein"; + + /// + /// Check if a provider ID is a local provider that requires ComfyUI backend + /// + public static bool IsLocalProvider(string? providerId) => + providerId is FluxKontext or QwenImageEdit or Flux2Klein; + + /// + /// Check if a provider ID is a cloud/API provider (Gemini) + /// + public static bool IsCloudProvider(string? providerId) => providerId?.Contains("gemini") == true; + + /// + /// Check if a provider ID supports thinking/reasoning output + /// + public static bool SupportsThinking(string? providerId) => providerId is Gemini3Pro or Gemini31Flash; +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/Gemini31FlashImageGenerationProvider.cs b/StabilityMatrix.Core/Services/ImageGeneration/Gemini31FlashImageGenerationProvider.cs new file mode 100644 index 000000000..6388d241c --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/Gemini31FlashImageGenerationProvider.cs @@ -0,0 +1,67 @@ +using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models.Api.Gemini; + +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Image generation provider for Google Gemini 3.1 Flash (Nano Banana 2) +/// with thinking/reasoning support. Uses the newer thinking_level string config +/// instead of the integer thinking_budget used by Gemini 3 Pro. +/// +public class Gemini31FlashImageGenerationProvider( + ILogger logger, + IGeminiApi geminiApi, + ISecretsManager secretsManager +) : GeminiBaseImageGenerationProvider(logger, geminiApi, secretsManager) +{ + private const string DefaultThinkingLevel = "high"; + + public override string ProviderId => BananaVisionProviderIds.Gemini31Flash; + public override string ProviderName => "Gemini 3.1 Flash (Nano Banana 2)"; + public override string DefaultModel => "gemini-3.1-flash-image-preview"; + public override bool RequiresThoughtSignatures => true; + + protected override GeminiGenerateContentRequest BuildGeminiRequest(ImageGenerationRequest request) + { + var geminiRequest = base.BuildGeminiRequest(request); + + var enableThinking = + request.ProviderOptions?.TryGetValue("enableThinking", out var thinkingValue) == true + && thinkingValue is true or "true"; + + var thinkingLevel = + request.ProviderOptions?.TryGetValue("thinkingLevel", out var levelValue) == true + && levelValue is string level + && !string.IsNullOrWhiteSpace(level) + ? level + : DefaultThinkingLevel; + + Logger.LogInformation( + "Gemini 3.1 Flash Config - Thinking: {Thinking}, Level: {Level}", + enableThinking, + enableThinking ? thinkingLevel : "minimal" + ); + + // Only attach a thinkingConfig when the user explicitly enabled thinking. + // Omitting it lets Gemini 3.1 Flash use its server-side default ("minimal"). + if (enableThinking) + { + var existingConfig = geminiRequest.GenerationConfig ?? new GeminiGenerationConfig(); + + geminiRequest = geminiRequest with + { + GenerationConfig = existingConfig with + { + ThinkingConfig = new GeminiThinkingConfig + { + ThinkingLevel = thinkingLevel, + IncludeThoughts = true, + }, + }, + }; + } + + return geminiRequest; + } +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/Gemini3ProImageGenerationProvider.cs b/StabilityMatrix.Core/Services/ImageGeneration/Gemini3ProImageGenerationProvider.cs new file mode 100644 index 000000000..a90ebddb6 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/Gemini3ProImageGenerationProvider.cs @@ -0,0 +1,66 @@ +using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models.Api.Gemini; + +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Image generation provider for Google Gemini 3 Pro (Nano Banana Pro Preview) +/// with thinking/reasoning support +/// +public class Gemini3ProImageGenerationProvider( + ILogger logger, + IGeminiApi geminiApi, + ISecretsManager secretsManager +) : GeminiBaseImageGenerationProvider(logger, geminiApi, secretsManager) +{ + private const int DefaultThinkingBudget = 2048; + + public override string ProviderId => BananaVisionProviderIds.Gemini3Pro; + public override string ProviderName => "Gemini 3 Pro (Nano Banana Pro)"; + public override string DefaultModel => "gemini-3-pro-image-preview"; + public override bool RequiresThoughtSignatures => true; + + protected override GeminiGenerateContentRequest BuildGeminiRequest(ImageGenerationRequest request) + { + // Get the base request + var geminiRequest = base.BuildGeminiRequest(request); + + // Check if thinking is enabled + var enableThinking = + request.ProviderOptions?.TryGetValue("enableThinking", out var thinkingValue) == true + && thinkingValue is true or "true"; + + var thinkingBudget = + request.ProviderOptions?.TryGetValue("thinkingBudget", out var budgetValue) == true + && budgetValue is int budget + ? budget + : DefaultThinkingBudget; + + Logger.LogInformation( + "Gemini 3 Pro Config - Thinking: {Thinking}, Budget: {Budget}", + enableThinking, + enableThinking ? thinkingBudget : 0 + ); + + // Add thinking config if enabled + if (enableThinking) + { + var existingConfig = geminiRequest.GenerationConfig ?? new GeminiGenerationConfig(); + + geminiRequest = geminiRequest with + { + GenerationConfig = existingConfig with + { + ThinkingConfig = new GeminiThinkingConfig + { + ThinkingBudget = thinkingBudget, + IncludeThoughts = true, + }, + }, + }; + } + + return geminiRequest; + } +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/GeminiBaseImageGenerationProvider.cs b/StabilityMatrix.Core/Services/ImageGeneration/GeminiBaseImageGenerationProvider.cs new file mode 100644 index 000000000..4e75091f4 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/GeminiBaseImageGenerationProvider.cs @@ -0,0 +1,288 @@ +using Microsoft.Extensions.Logging; +using Refit; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models.Api.Gemini; + +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Base class for Gemini image generation providers +/// +public abstract class GeminiBaseImageGenerationProvider( + ILogger logger, + IGeminiApi geminiApi, + ISecretsManager secretsManager +) : IImageGenerationProvider +{ + public abstract string ProviderId { get; } + public abstract string ProviderName { get; } + public abstract string DefaultModel { get; } + public bool SupportsImageInput => true; + public bool SupportsMultiTurn => true; + public virtual bool RequiresThoughtSignatures => false; + + protected ILogger Logger => logger; + + public async Task GenerateAsync( + ImageGenerationRequest request, + CancellationToken cancellationToken = default + ) + { + // Check for API key first + var secrets = await secretsManager.SafeLoadAsync().ConfigureAwait(false); + if (string.IsNullOrEmpty(secrets.GeminiApiKey)) + { + return new ImageGenerationResponse + { + IsSuccess = false, + ErrorMessage = "Gemini API key not configured. Please add it in Settings.", + }; + } + + try + { + var geminiRequest = BuildGeminiRequest(request); + + var model = + request.ProviderOptions?.TryGetValue("model", out var modelValue) == true + ? modelValue?.ToString() ?? DefaultModel + : DefaultModel; + + logger.LogInformation("Generating image with Gemini model: {Model}", model); + + var response = await geminiApi + .GenerateContentAsync(model, geminiRequest, cancellationToken) + .ConfigureAwait(false); + + return ParseGeminiResponse(response); + } + catch (ApiException apiEx) when (apiEx.StatusCode == System.Net.HttpStatusCode.TooManyRequests) + { + logger.LogError(apiEx, "Rate limit or quota exceeded for Gemini API"); + return new ImageGenerationResponse + { + IsSuccess = false, + ErrorMessage = + "Rate limit or quota exceeded. " + + "Note: Free Gemini API keys do not support image generation - you need a paid API key. " + + "If you have a paid key, you may be hitting rate limits. Please try again in a moment.", + }; + } + catch (ApiException apiEx) + { + logger.LogError(apiEx, "Gemini API error: {StatusCode}", apiEx.StatusCode); + + const string keyRestrictionHint = + " Note: as of June 19 2026, Google blocks unrestricted Gemini API keys. " + + "If your key was created earlier, you may need to add API restrictions to it at " + + "https://aistudio.google.com/apikey."; + + var errorMessage = apiEx.StatusCode switch + { + System.Net.HttpStatusCode.Unauthorized => + "Invalid API key. Please check your Gemini API key in Settings." + keyRestrictionHint, + System.Net.HttpStatusCode.Forbidden => + "Access forbidden. Your API key may not have the required permissions." + + keyRestrictionHint, + System.Net.HttpStatusCode.BadRequest => $"Invalid request: {apiEx.Content}", + _ => $"API error ({apiEx.StatusCode}): {apiEx.Message}", + }; + + return new ImageGenerationResponse { IsSuccess = false, ErrorMessage = errorMessage }; + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to generate image with Gemini"); + return new ImageGenerationResponse { IsSuccess = false, ErrorMessage = ex.Message }; + } + } + + /// + /// Builds the Gemini request. Can be overridden to add specific configuration. + /// + protected virtual GeminiGenerateContentRequest BuildGeminiRequest(ImageGenerationRequest request) + { + var contents = new List(); + + // Add conversation history if present + if (request.ConversationHistory != null) + { + foreach (var message in request.ConversationHistory) + { + var parts = new List(); + + if (!string.IsNullOrEmpty(message.TextContent)) + { + parts.Add( + new GeminiPart + { + Text = message.TextContent, + ThoughtSignature = message.TextThoughtSignature, + } + ); + } + + if (message.ImageContent != null) + { + parts.Add( + new GeminiPart + { + InlineData = new GeminiInlineData + { + MimeType = message.ImageContent.MimeType, + Data = message.ImageContent.Base64Data, + }, + ThoughtSignature = message.ImageContent.ThoughtSignature, + } + ); + } + + if (parts.Count > 0) + { + contents.Add( + new GeminiContent + { + Role = message.Role == MessageRole.User ? "user" : "model", + Parts = parts, + } + ); + } + } + } + + // Add current request + var currentParts = new List(); + + if (!string.IsNullOrEmpty(request.TextPrompt)) + { + currentParts.Add(new GeminiPart { Text = request.TextPrompt }); + } + + if (request.InputImages != null) + { + foreach (var image in request.InputImages) + { + currentParts.Add( + new GeminiPart + { + InlineData = new GeminiInlineData + { + MimeType = image.MimeType, + Data = image.Base64Data, + }, + } + ); + } + } + + if (currentParts.Count > 0) + { + contents.Add(new GeminiContent { Role = "user", Parts = currentParts }); + } + + // Build generation config + var generationConfig = new GeminiGenerationConfig { ResponseModalities = ["TEXT", "IMAGE"] }; + + // Add aspect ratio if specified + if (request.ProviderOptions?.TryGetValue("aspectRatio", out var aspectRatioValue) == true) + { + generationConfig = generationConfig with + { + ImageConfig = new GeminiImageConfig { AspectRatio = aspectRatioValue?.ToString() }, + }; + } + + return new GeminiGenerateContentRequest { Contents = contents, GenerationConfig = generationConfig }; + } + + protected virtual ImageGenerationResponse ParseGeminiResponse(GeminiGenerateContentResponse response) + { + if (response.Candidates == null || response.Candidates.Count == 0) + { + var blockReason = response.PromptFeedback?.BlockReason; + return new ImageGenerationResponse + { + IsSuccess = false, + ErrorMessage = string.IsNullOrEmpty(blockReason) + ? "No candidates returned from Gemini" + : $"Request blocked: {blockReason}", + }; + } + + var candidate = response.Candidates[0]; + var images = new List(); + string? textResponse = null; + string? thinkingContent = null; + string? lastThoughtSignature = null; + + if (candidate.Content?.Parts != null) + { + var parts = candidate.Content.Parts; + + // For thinking models, images that appear between text parts are intermediate + // "draft" outputs from the reasoning process β€” only images at or after the last + // text part are the final result. For non-thinking models this is a no-op + // (a typical single-text + single-image response has lastTextPartIndex = 0 + // and the trailing image is correctly kept). + var lastTextPartIndex = -1; + for (var i = 0; i < parts.Count; i++) + { + if (!string.IsNullOrEmpty(parts[i].Text)) + { + lastTextPartIndex = i; + } + } + + for (var i = 0; i < parts.Count; i++) + { + var part = parts[i]; + + if (!string.IsNullOrEmpty(part.ThoughtSignature)) + { + lastThoughtSignature = part.ThoughtSignature; + } + + if (part is { Thought: true, Text: not null }) + { + thinkingContent = string.IsNullOrEmpty(thinkingContent) + ? part.Text + : thinkingContent + "\n\n" + part.Text; + continue; + } + + if (!string.IsNullOrEmpty(part.Text)) + { + textResponse = part.Text; + } + + if (part.InlineData != null && i >= lastTextPartIndex) + { + images.Add( + new GeneratedImage + { + Base64Data = part.InlineData.Data, + MimeType = part.InlineData.MimeType, + ThoughtSignature = part.ThoughtSignature, + } + ); + } + } + } + + var responseThoughtSignature = images.FirstOrDefault()?.ThoughtSignature ?? lastThoughtSignature; + + return new ImageGenerationResponse + { + IsSuccess = true, + Images = images.Count > 0 ? images : null, + TextResponse = textResponse, + ThinkingContent = thinkingContent, + ThoughtSignature = responseThoughtSignature, + Metadata = new Dictionary + { + ["finishReason"] = candidate.FinishReason ?? "unknown", + ["hasThinking"] = !string.IsNullOrEmpty(thinkingContent), + }, + }; + } +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/GeminiImageGenerationProvider.cs b/StabilityMatrix.Core/Services/ImageGeneration/GeminiImageGenerationProvider.cs new file mode 100644 index 000000000..401cc9204 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/GeminiImageGenerationProvider.cs @@ -0,0 +1,20 @@ +using Microsoft.Extensions.Logging; +using Refit; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models.Api.Gemini; + +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Image generation provider for Google Gemini (Nano Banana) +/// +public class GeminiImageGenerationProvider( + ILogger logger, + IGeminiApi geminiApi, + ISecretsManager secretsManager +) : GeminiBaseImageGenerationProvider(logger, geminiApi, secretsManager) +{ + public override string ProviderId => BananaVisionProviderIds.Gemini25Flash; + public override string ProviderName => "Gemini 2.5 Flash (Nano Banana)"; + public override string DefaultModel => "gemini-2.5-flash-image"; +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/IImageGenerationChatService.cs b/StabilityMatrix.Core/Services/ImageGeneration/IImageGenerationChatService.cs new file mode 100644 index 000000000..c4b2783e3 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/IImageGenerationChatService.cs @@ -0,0 +1,132 @@ +using StabilityMatrix.Core.Models.Database; + +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Service for managing image generation conversations +/// +public interface IImageGenerationChatService +{ + /// + /// Get all conversations + /// + Task> GetConversationsAsync(); + + /// + /// Get a specific conversation by ID + /// + Task GetConversationAsync(Guid conversationId); + + /// + /// Get all messages for a conversation + /// + Task> GetMessagesAsync(Guid conversationId); + + /// + /// Create a new conversation + /// + Task CreateConversationAsync( + string providerId, + string initialTitle = "New Conversation" + ); + + /// + /// Update a conversation + /// + Task UpdateConversationAsync(ImageGenerationConversation conversation); + + /// + /// Delete a conversation and all its messages + /// + Task DeleteConversationAsync(Guid conversationId); + + /// + /// Delete a specific message from a conversation + /// + /// The message ID to delete + /// If true, keeps the image file on disk (for regenerate). If false, deletes it (for edit/delete) + Task DeleteMessageAsync(Guid messageId, bool preserveImageFile = false); + + /// + /// Remove a specific image from a message. If it's the last image, deletes the entire message. + /// + /// The message ID + /// The specific image path to remove + /// True if the entire message was deleted, false if only the image was removed + Task RemoveImageFromMessageAsync(Guid messageId, string imagePath); + + /// + /// Get a specific message by ID + /// + Task GetMessageAsync(Guid messageId); + + /// + /// Update a message's text content without affecting other messages or triggering regeneration + /// + /// The message ID to update + /// The new text content + /// The updated message, or null if not found + Task UpdateMessageTextAsync(Guid messageId, string newTextContent); + + /// + /// Send a message and generate a response using the specified provider + /// + /// The conversation ID + /// The provider to use for this message + /// Optional text prompt + /// Optional image paths to include + /// Cancellation token + Task<(ImageGenerationMessage UserMessage, ImageGenerationMessage? AssistantMessage)> SendMessageAsync( + Guid conversationId, + string providerId, + string? textPrompt, + List? imagePaths = null, + CancellationToken cancellationToken = default + ); + + /// + /// Send a message and generate a response with provider options using the specified provider + /// + /// The conversation ID + /// The provider to use for this message + /// Optional text prompt + /// Optional image paths to include + /// Provider-specific options + /// Cancellation token + Task<(ImageGenerationMessage UserMessage, ImageGenerationMessage? AssistantMessage)> SendMessageAsync( + Guid conversationId, + string providerId, + string? textPrompt, + List? imagePaths, + Dictionary? providerOptions, + IProgress? progress = null, + CancellationToken cancellationToken = default + ); + + /// + /// Retry generation for the last user message in a conversation. + /// Does not create a new user message - just regenerates the assistant response. + /// + /// The conversation ID + /// The provider to use for regeneration + /// Provider-specific options + /// Cancellation token + /// The generated assistant message + Task RetryGenerationAsync( + Guid conversationId, + string providerId, + Dictionary? providerOptions = null, + IProgress? progress = null, + CancellationToken cancellationToken = default + ); + + /// + /// Get available providers + /// + List GetAvailableProviders(); + + /// + /// Get a provider by ID + /// + IImageGenerationProvider? GetProvider(string providerId); +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/IImageGenerationProvider.cs b/StabilityMatrix.Core/Services/ImageGeneration/IImageGenerationProvider.cs new file mode 100644 index 000000000..ad78c4aa4 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/IImageGenerationProvider.cs @@ -0,0 +1,41 @@ +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Represents a provider for image generation (e.g., Gemini, Flux Kontext) +/// +public interface IImageGenerationProvider +{ + /// + /// Unique identifier for this provider (e.g., "gemini-2.5-flash", "flux-kontext") + /// + string ProviderId { get; } + + /// + /// Display name for this provider + /// + string ProviderName { get; } + + /// + /// Whether this provider supports image input for editing/composition + /// + bool SupportsImageInput { get; } + + /// + /// Whether this provider supports multi-turn conversations + /// + bool SupportsMultiTurn { get; } + + /// + /// Whether this provider requires thought signatures on image parts (Gemini 3 Pro). + /// If true, conversations started with non-thinking providers cannot be continued. + /// + bool RequiresThoughtSignatures { get; } + + /// + /// Generate an image based on the provided request + /// + Task GenerateAsync( + ImageGenerationRequest request, + CancellationToken cancellationToken = default + ); +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationChatService.cs b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationChatService.cs new file mode 100644 index 000000000..6fe90a94f --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationChatService.cs @@ -0,0 +1,1107 @@ +using Microsoft.Extensions.Logging; +using StabilityMatrix.Core.Database; +using StabilityMatrix.Core.Models.Database; +using StabilityMatrix.Core.Models.FileInterfaces; + +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Service for managing image generation conversations +/// +public class ImageGenerationChatService( + ILogger logger, + IBananaVisionDbContext database, + ISettingsManager settingsManager, + IEnumerable providers +) : IImageGenerationChatService +{ + private readonly List providers = providers.ToList(); + + private static List GetMessageImagePaths(ImageGenerationMessage message) + { + var paths = new List(); + + if (message.ImagePaths is { Count: > 0 }) + { + paths.AddRange(message.ImagePaths.Where(p => !string.IsNullOrWhiteSpace(p))); + } + else if (!string.IsNullOrEmpty(message.ImagePath)) + { + paths.Add(message.ImagePath); + } + + return paths.Distinct(StringComparer.OrdinalIgnoreCase).ToList(); + } + + private static bool IsPathUnderDirectory(string filePath, string directoryPath) + { + var fullFilePath = Path.GetFullPath(filePath); + var fullDirectoryPath = Path.GetFullPath(directoryPath); + + var relativePath = Path.GetRelativePath(fullDirectoryPath, fullFilePath); + if (string.IsNullOrWhiteSpace(relativePath)) + return false; + + if (Path.IsPathRooted(relativePath)) + return false; + + return !relativePath.StartsWith("..", StringComparison.Ordinal) + && !relativePath.StartsWith($"..{Path.DirectorySeparatorChar}", StringComparison.Ordinal) + && !relativePath.StartsWith($"..{Path.AltDirectorySeparatorChar}", StringComparison.Ordinal); + } + + private static string CreateOutputFileName(string extension) + { + var timestamp = DateTime.UtcNow.ToString("yyyyMMdd_HHmmss_fff"); + var shortGuid = Guid.NewGuid().ToString("N")[..8]; + return $"imagelab_{timestamp}_{shortGuid}{extension}"; + } + + private static string GenerateConversationTitle(string textPrompt) + { + var title = textPrompt.Trim(); + if (title.Length == 0) + return "New Conversation"; + + var firstSentenceEnd = title.IndexOfAny(['.', '!', '?']); + if (firstSentenceEnd > 0 && firstSentenceEnd < 50) + { + title = title[..firstSentenceEnd].Trim(); + } + else if (title.Length > 50) + { + title = title[..50].TrimEnd() + "..."; + } + + return title.Length == 0 ? "New Conversation" : title; + } + + private string GetOutputDirectory() + { + return Path.Combine(settingsManager.ImagesDirectory, "ImageLab"); + } + + private string GetInputDirectory(Guid conversationId) + { + return Path.Combine(GetOutputDirectory(), "Inputs", conversationId.ToString("N")); + } + + private async Task PersistInputImageAsync( + Guid conversationId, + string sourcePath, + CancellationToken cancellationToken + ) + { + try + { + if (!File.Exists(sourcePath)) + return null; + + var inputDir = GetInputDirectory(conversationId); + Directory.CreateDirectory(inputDir); + + var extension = Path.GetExtension(sourcePath); + if (string.IsNullOrWhiteSpace(extension)) + { + extension = ".png"; + } + + var timestamp = DateTime.UtcNow.ToString("yyyyMMdd_HHmmss_fff"); + var shortGuid = Guid.NewGuid().ToString("N")[..8]; + var fileName = $"input_{timestamp}_{shortGuid}{extension}"; + var destinationPath = Path.Combine(inputDir, fileName); + + await using var sourceStream = File.OpenRead(sourcePath); + await using var destinationStream = File.Create(destinationPath); + await sourceStream.CopyToAsync(destinationStream, cancellationToken).ConfigureAwait(false); + + return destinationPath; + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to persist input image {SourcePath}", sourcePath); + return null; + } + } + + public async Task> GetConversationsAsync() + { + logger.LogDebug("Querying conversations from database..."); + var conversations = await database + .Conversations.Query() + .OrderByDescending(c => c.UpdatedAt) + .ToListAsync() + .ConfigureAwait(false); + + logger.LogInformation("Retrieved {Count} conversations from database", conversations.Count); + return conversations; + } + + public async Task GetConversationAsync(Guid conversationId) + { + return await database.Conversations.FindByIdAsync(conversationId).ConfigureAwait(false); + } + + public async Task> GetMessagesAsync(Guid conversationId) + { + var messages = await database + .Messages.Query() + .Where(m => m.ConversationId == conversationId) + .OrderBy(m => m.Timestamp) + .ToListAsync() + .ConfigureAwait(false); + + return messages; + } + + public async Task CreateConversationAsync( + string providerId, + string initialTitle = "New Conversation" + ) + { + var conversation = new ImageGenerationConversation { Title = initialTitle, ProviderId = providerId }; + + await database.Conversations.InsertAsync(conversation).ConfigureAwait(false); + + logger.LogInformation( + "Created new conversation {ConversationId} with provider {ProviderId}", + conversation.Id, + providerId + ); + + return conversation; + } + + public async Task UpdateConversationAsync(ImageGenerationConversation conversation) + { + var updated = conversation with { UpdatedAt = DateTime.UtcNow }; + await database.Conversations.UpdateAsync(updated).ConfigureAwait(false); + } + + public async Task DeleteConversationAsync(Guid conversationId) + { + // Delete all messages first + var messages = await GetMessagesAsync(conversationId).ConfigureAwait(false); + var outputDir = GetOutputDirectory(); + var inputDir = GetInputDirectory(conversationId); + var deletedImageCount = 0; + + foreach (var message in messages) + { + var imagePaths = new List(); + if (!string.IsNullOrEmpty(message.ImagePath)) + { + imagePaths.Add(message.ImagePath); + } + if (message.ImagePaths is { Count: > 0 }) + { + imagePaths.AddRange(message.ImagePaths.Where(p => !string.IsNullOrWhiteSpace(p))); + } + + foreach (var imagePathValue in imagePaths.Distinct(StringComparer.OrdinalIgnoreCase)) + { + if (!File.Exists(imagePathValue)) + continue; + + // Delete generated output images (output dir) and app-managed input copies (input dir). + // Never delete arbitrary user filesystem paths. + if ( + IsPathUnderDirectory(imagePathValue, outputDir) + || IsPathUnderDirectory(imagePathValue, inputDir) + ) + { + try + { + File.Delete(imagePathValue); + deletedImageCount++; + } + catch (Exception ex) + { + logger.LogWarning( + ex, + "Failed to delete generated image file {ImagePath}", + imagePathValue + ); + } + } + else + { + logger.LogDebug( + "Preserving user input image {ImagePath} (not in output directory)", + imagePathValue + ); + } + } + + await database.Messages.DeleteAsync(message.Id).ConfigureAwait(false); + } + + // Delete the conversation + await database.Conversations.DeleteAsync(conversationId).ConfigureAwait(false); + + logger.LogInformation( + "Deleted conversation {ConversationId} with {MessageCount} messages and {ImageCount} generated images", + conversationId, + messages.Count, + deletedImageCount + ); + } + + public async Task DeleteMessageAsync(Guid messageId, bool preserveImageFile = false) + { + var message = await database.Messages.FindByIdAsync(messageId).ConfigureAwait(false); + if (message == null) + { + logger.LogWarning("Message {MessageId} not found", messageId); + return; + } + + var imagePaths = new List(); + if (!string.IsNullOrEmpty(message.ImagePath)) + { + imagePaths.Add(message.ImagePath); + } + if (message.ImagePaths is { Count: > 0 }) + { + imagePaths.AddRange(message.ImagePaths.Where(p => !string.IsNullOrWhiteSpace(p))); + } + + // Delete generated image files if they exist (unless we're preserving output images). + // Also delete app-managed input copies when present (safe to delete). + if (!preserveImageFile) + { + var outputDir = GetOutputDirectory(); + var inputDir = GetInputDirectory(message.ConversationId); + + foreach (var imagePathValue in imagePaths.Distinct(StringComparer.OrdinalIgnoreCase)) + { + if (!File.Exists(imagePathValue)) + continue; + + if ( + !IsPathUnderDirectory(imagePathValue, outputDir) + && !IsPathUnderDirectory(imagePathValue, inputDir) + ) + continue; + + try + { + File.Delete(imagePathValue); + logger.LogDebug("Deleted managed image file {ImagePath}", imagePathValue); + } + catch (Exception ex) + { + logger.LogWarning( + ex, + "Failed to delete generated image file {ImagePath}", + imagePathValue + ); + } + } + } + else if (preserveImageFile && imagePaths.Count > 0) + { + logger.LogDebug("Preserved image file(s) for message {MessageId} (regenerate mode)", message.Id); + } + + await database.Messages.DeleteAsync(messageId).ConfigureAwait(false); + logger.LogDebug("Deleted message {MessageId}", messageId); + } + + public async Task GetMessageAsync(Guid messageId) + { + return await database.Messages.FindByIdAsync(messageId).ConfigureAwait(false); + } + + public async Task UpdateMessageTextAsync(Guid messageId, string newTextContent) + { + var message = await database.Messages.FindByIdAsync(messageId).ConfigureAwait(false); + if (message == null) + { + logger.LogWarning("Message {MessageId} not found for text update", messageId); + return null; + } + + var updatedMessage = message with { TextContent = newTextContent }; + await database.Messages.UpdateAsync(updatedMessage).ConfigureAwait(false); + + logger.LogDebug("Updated text content for message {MessageId}", messageId); + return updatedMessage; + } + + public async Task RemoveImageFromMessageAsync(Guid messageId, string imagePath) + { + var message = await database.Messages.FindByIdAsync(messageId).ConfigureAwait(false); + if (message == null) + { + logger.LogWarning("Message {MessageId} not found for image removal", messageId); + return true; // Treat as fully deleted + } + + var allImagePaths = GetMessageImagePaths(message).ToList(); + + // If this is the only image (or no images), delete the whole message + if (allImagePaths.Count <= 1) + { + await DeleteMessageAsync(messageId).ConfigureAwait(false); + return true; + } + + // Remove the specific image path + var updatedPaths = allImagePaths + .Where(p => !string.Equals(p, imagePath, StringComparison.OrdinalIgnoreCase)) + .ToList(); + + // Delete the image file + if (File.Exists(imagePath)) + { + try + { + File.Delete(imagePath); + logger.LogDebug("Deleted image file {ImagePath}", imagePath); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to delete image file {ImagePath}", imagePath); + } + } + + // Update the message with remaining paths + var updatedMessage = message with + { + ImagePath = updatedPaths.FirstOrDefault(), + ImagePaths = updatedPaths.Count > 0 ? updatedPaths : null, + }; + + await database.Messages.UpdateAsync(updatedMessage).ConfigureAwait(false); + logger.LogDebug( + "Removed image {ImagePath} from message {MessageId}, {RemainingCount} images remaining", + imagePath, + messageId, + updatedPaths.Count + ); + + return false; + } + + public Task<( + ImageGenerationMessage UserMessage, + ImageGenerationMessage? AssistantMessage + )> SendMessageAsync( + Guid conversationId, + string providerId, + string? textPrompt, + List? imagePaths = null, + CancellationToken cancellationToken = default + ) + { + return SendMessageAsync( + conversationId, + providerId, + textPrompt, + imagePaths, + null, + null, + cancellationToken + ); + } + + public async Task<( + ImageGenerationMessage UserMessage, + ImageGenerationMessage? AssistantMessage + )> SendMessageAsync( + Guid conversationId, + string providerId, + string? textPrompt, + List? imagePaths, + Dictionary? providerOptions, + IProgress? progress, + CancellationToken cancellationToken = default + ) + { + var conversation = await GetConversationAsync(conversationId).ConfigureAwait(false); + if (conversation == null) + { + throw new InvalidOperationException($"Conversation {conversationId} not found"); + } + + var provider = GetProvider(providerId); + if (provider == null) + { + throw new InvalidOperationException($"Provider {providerId} not found"); + } + + // Update conversation's provider if it changed + var providerChanged = conversation.ProviderId != providerId; + if (providerChanged) + { + logger.LogInformation( + "Switching conversation {ConversationId} provider from {OldProvider} to {NewProvider}", + conversationId, + conversation.ProviderId, + providerId + ); + conversation.ProviderId = providerId; + } + + // Check for provider compatibility - thought signature requirements + // If switching to a thinking model with incompatible history, we'll carry forward + // the last output image as an input instead of using the full history + string? carryForwardImagePath = null; + if (provider.RequiresThoughtSignatures) + { + var existingMessages = await GetMessagesAsync(conversationId).ConfigureAwait(false); + var incompatibleMessages = existingMessages + .Where(m => + m.Role == MessageRole.Assistant + && !string.IsNullOrEmpty(m.ImagePath) + && string.IsNullOrEmpty(m.ThoughtSignature) + ) + .ToList(); + + if (incompatibleMessages.Count > 0) + { + // Find the last assistant message with an image to carry forward + var lastAssistantImage = existingMessages + .Where(m => m.Role == MessageRole.Assistant && !string.IsNullOrEmpty(m.ImagePath)) + .OrderByDescending(m => m.Timestamp) + .FirstOrDefault(); + + if (lastAssistantImage != null) + { + carryForwardImagePath = GetMessageImagePaths(lastAssistantImage) + .FirstOrDefault(File.Exists); + if (carryForwardImagePath != null) + { + logger.LogInformation( + "Switching to thinking model with incompatible history. " + + "Carrying forward last image as input: {ImagePath}", + carryForwardImagePath + ); + } + } + } + } + + // Create user message + var userMessage = new ImageGenerationMessage + { + ConversationId = conversationId, + Role = MessageRole.User, + TextContent = textPrompt, + }; + + // Handle image inputs if provided + List? inputImages = null; + var persistedImagePaths = new List(); + + // If we have a carry-forward image (from incompatible history), add it first + if (!string.IsNullOrEmpty(carryForwardImagePath) && File.Exists(carryForwardImagePath)) + { + inputImages = []; + var imageBytes = await File.ReadAllBytesAsync(carryForwardImagePath, cancellationToken) + .ConfigureAwait(false); + var base64 = Convert.ToBase64String(imageBytes); + var mimeType = GetMimeTypeFromPath(carryForwardImagePath); + + inputImages.Add( + new ImageInputData + { + Base64Data = base64, + MimeType = mimeType, + FilePath = carryForwardImagePath, + } + ); + // Note: We don't persist the carry-forward image as it's already persisted + // and we don't want to duplicate it in the user message metadata + } + + if (imagePaths?.Count > 0) + { + inputImages ??= []; + foreach (var originalImagePath in imagePaths.Where(File.Exists)) + { + var persistedPath = await PersistInputImageAsync( + conversationId, + originalImagePath, + cancellationToken + ) + .ConfigureAwait(false); + var imagePathToUse = persistedPath ?? originalImagePath; + + var imageBytes = await File.ReadAllBytesAsync(imagePathToUse, cancellationToken) + .ConfigureAwait(false); + var base64 = Convert.ToBase64String(imageBytes); + var mimeType = GetMimeTypeFromPath(imagePathToUse); + + inputImages.Add( + new ImageInputData + { + Base64Data = base64, + MimeType = mimeType, + FilePath = imagePathToUse, + } + ); + persistedImagePaths.Add(imagePathToUse); + + // Stop if cancellation requested + cancellationToken.ThrowIfCancellationRequested(); + } + + if (persistedImagePaths.Count > 0) + { + userMessage = userMessage with + { + ImagePath = persistedImagePaths[0], + ImageMimeType = inputImages.First(i => i.FilePath == persistedImagePaths[0]).MimeType, + ImagePaths = persistedImagePaths, + }; + } + } + + await database.Messages.InsertAsync(userMessage).ConfigureAwait(false); + + // Update conversation title if it's still the default + if (conversation.Title == "New Conversation" && !string.IsNullOrEmpty(textPrompt)) + { + var newTitle = GenerateConversationTitle(textPrompt); + conversation = conversation with + { + Title = newTitle, + ProviderId = providerId, // Update provider ID as well + UpdatedAt = DateTime.UtcNow, + }; + + await database.Conversations.UpdateAsync(conversation).ConfigureAwait(false); + logger.LogDebug("Updated conversation title to: {Title}", newTitle); + } + + // Get conversation history (load image files and convert to base64 for providers) + // If we're carrying forward an image (due to incompatible history), skip the history + var conversationHistory = new List(); + + if (string.IsNullOrEmpty(carryForwardImagePath)) + { + var previousMessages = await GetMessagesAsync(conversationId).ConfigureAwait(false); + logger.LogInformation( + "Building conversation history with {Count} previous messages", + previousMessages.Count - 1 + ); + + foreach (var m in previousMessages.Where(msg => msg.Id != userMessage.Id)) + { + var messageImagePaths = GetMessageImagePaths(m).Where(File.Exists).ToList(); + + // If no images, keep a single message with text (and thought signature for text parts). + if (messageImagePaths.Count == 0) + { + conversationHistory.Add( + new ConversationMessage + { + Role = m.Role, + TextContent = m.TextContent, + ImageContent = null, + TextThoughtSignature = m.ThoughtSignature, + } + ); + continue; + } + + // If there are multiple images, emit one history entry per image. + // Include the message text only on the first entry to avoid duplicating prompt text. + for (var i = 0; i < messageImagePaths.Count; i++) + { + var imagePath = messageImagePaths[i]; + ImageInputData? imageContent = null; + + try + { + var imageBytes = await File.ReadAllBytesAsync(imagePath, cancellationToken) + .ConfigureAwait(false); + imageContent = new ImageInputData + { + Base64Data = Convert.ToBase64String(imageBytes), + MimeType = GetMimeTypeFromPath(imagePath), + FilePath = imagePath, + ThoughtSignature = m.ThoughtSignature, + }; + + logger.LogInformation( + "Loaded history image for {Role} message: {ImagePath} ({Size} bytes){ThoughtSig}", + m.Role, + imagePath, + imageBytes.Length, + !string.IsNullOrEmpty(m.ThoughtSignature) ? " [with thought signature]" : "" + ); + } + catch (Exception ex) + { + logger.LogWarning( + ex, + "Failed to load conversation history image from {ImagePath}", + imagePath + ); + } + + conversationHistory.Add( + new ConversationMessage + { + Role = m.Role, + TextContent = i == 0 ? m.TextContent : null, + ImageContent = imageContent, + // Include thought signature for text-only parts (we carry it only when no image). + TextThoughtSignature = null, + } + ); + } + } + } + else + { + logger.LogInformation( + "Skipping conversation history due to incompatible thought signature requirements. " + + "Using carry-forward image instead." + ); + } + + // Build request + var request = new ImageGenerationRequest + { + TextPrompt = textPrompt, + InputImages = inputImages, + ConversationHistory = conversationHistory, + ProviderOptions = providerOptions, + Progress = progress, + }; + + // Generate response + logger.LogInformation( + "Generating image with provider {ProviderId} for conversation {ConversationId}", + provider.ProviderId, + conversationId + ); + + progress?.Report( + new ImageGenerationProgress( + ProviderId: providerId, + PromptId: null, + Value: null, + Maximum: null, + RunningNode: null, + Stage: "Generating..." + ) + ); + + var response = await provider.GenerateAsync(request, cancellationToken).ConfigureAwait(false); + + if (!response.IsSuccess) + { + logger.LogError("Image generation failed: {ErrorMessage}", response.ErrorMessage); + + // Don't save error messages to the database - let the caller handle the error via UI + // Update conversation timestamp and provider + var errorUpdatedConversation = conversation with + { + ProviderId = providerId, + UpdatedAt = DateTime.UtcNow, + }; + await database.Conversations.UpdateAsync(errorUpdatedConversation).ConfigureAwait(false); + + // Throw exception so caller can handle it appropriately (show notification, etc.) + throw new ImageGenerationException(response.ErrorMessage ?? "Image generation failed"); + } + + // Save generated images + List? savedImagePaths = null; + if (response.Images?.Count > 0) + { + progress?.Report( + new ImageGenerationProgress( + ProviderId: providerId, + PromptId: null, + Value: null, + Maximum: null, + RunningNode: null, + Stage: "Saving image(s)..." + ) + ); + + var outputDir = GetOutputDirectory(); + Directory.CreateDirectory(outputDir); + + savedImagePaths = []; + + foreach (var generatedImage in response.Images) + { + var imageBytes = Convert.FromBase64String(generatedImage.Base64Data); + var extension = GetExtensionFromMimeType(generatedImage.MimeType); + var fileName = CreateOutputFileName(extension); + var savedPath = Path.Combine(outputDir, fileName); + + await File.WriteAllBytesAsync(savedPath, imageBytes, cancellationToken).ConfigureAwait(false); + savedImagePaths.Add(savedPath); + } + + logger.LogInformation( + "Saved {Count} generated image(s) to {OutputDir}", + savedImagePaths.Count, + outputDir + ); + } + + // Create assistant message - capture thought signature for multi-turn continuity + // The thought signature comes from either the image part or the response level + var thoughtSignature = + response.Images?.FirstOrDefault()?.ThoughtSignature ?? response.ThoughtSignature; + + var primarySavedImagePath = savedImagePaths?.FirstOrDefault(); + var assistantMessage = new ImageGenerationMessage + { + ConversationId = conversationId, + Role = MessageRole.Assistant, + TextContent = response.TextResponse, + ImagePath = primarySavedImagePath, + ImagePaths = savedImagePaths, + ImageMimeType = response.Images?.FirstOrDefault()?.MimeType, + ThinkingContent = response.ThinkingContent, + ThoughtSignature = thoughtSignature, + }; + + if (!string.IsNullOrEmpty(thoughtSignature)) + { + logger.LogInformation("Saved thought signature for assistant message"); + } + + await database.Messages.InsertAsync(assistantMessage).ConfigureAwait(false); + + // Update conversation title if this is the first exchange, and always update timestamp + var completionUpdate = conversation with + { + ProviderId = providerId, + UpdatedAt = DateTime.UtcNow, + }; + + await database.Conversations.UpdateAsync(completionUpdate).ConfigureAwait(false); + + return (userMessage, assistantMessage); + } + + public async Task RetryGenerationAsync( + Guid conversationId, + string providerId, + Dictionary? providerOptions = null, + IProgress? progress = null, + CancellationToken cancellationToken = default + ) + { + var conversation = await GetConversationAsync(conversationId).ConfigureAwait(false); + if (conversation == null) + { + throw new InvalidOperationException($"Conversation {conversationId} not found"); + } + + var provider = GetProvider(providerId); + if (provider == null) + { + throw new InvalidOperationException($"Provider {providerId} not found"); + } + + // Get all messages to find the last user message and build history + var allMessages = await GetMessagesAsync(conversationId).ConfigureAwait(false); + + // Check for provider compatibility - thought signature requirements + string? carryForwardImagePath = null; + if (provider.RequiresThoughtSignatures) + { + var incompatibleMessages = allMessages + .Where(m => + m.Role == MessageRole.Assistant + && !string.IsNullOrEmpty(m.ImagePath) + && string.IsNullOrEmpty(m.ThoughtSignature) + ) + .ToList(); + + if (incompatibleMessages.Count > 0) + { + // Find the last assistant message with an image to carry forward + var lastAssistantImage = allMessages + .Where(m => m.Role == MessageRole.Assistant && !string.IsNullOrEmpty(m.ImagePath)) + .OrderByDescending(m => m.Timestamp) + .FirstOrDefault(); + + if (lastAssistantImage != null) + { + carryForwardImagePath = GetMessageImagePaths(lastAssistantImage) + .FirstOrDefault(File.Exists); + if (carryForwardImagePath != null) + { + logger.LogInformation( + "Retry: Switching to thinking model with incompatible history. " + + "Carrying forward last image as input: {ImagePath}", + carryForwardImagePath + ); + } + } + } + } + + // Find the last user message + var lastUserMessage = allMessages.LastOrDefault(m => m.Role == MessageRole.User); + if (lastUserMessage == null) + { + throw new InvalidOperationException("No user message found to retry"); + } + + // Build conversation history (everything except the last user message) + // If we're carrying forward an image, skip the incompatible history + var conversationHistory = new List(); + + if (string.IsNullOrEmpty(carryForwardImagePath)) + { + foreach (var m in allMessages.Where(msg => msg.Id != lastUserMessage.Id)) + { + var messageImagePaths = GetMessageImagePaths(m).Where(File.Exists).ToList(); + + if (messageImagePaths.Count == 0) + { + conversationHistory.Add( + new ConversationMessage + { + Role = m.Role, + TextContent = m.TextContent, + ImageContent = null, + TextThoughtSignature = m.ThoughtSignature, + } + ); + continue; + } + + for (var i = 0; i < messageImagePaths.Count; i++) + { + var imagePath = messageImagePaths[i]; + ImageInputData? imageContent = null; + try + { + var imageBytes = await File.ReadAllBytesAsync(imagePath, cancellationToken) + .ConfigureAwait(false); + imageContent = new ImageInputData + { + Base64Data = Convert.ToBase64String(imageBytes), + MimeType = GetMimeTypeFromPath(imagePath), + FilePath = imagePath, + ThoughtSignature = m.ThoughtSignature, + }; + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to load history image from {ImagePath}", imagePath); + } + + conversationHistory.Add( + new ConversationMessage + { + Role = m.Role, + TextContent = i == 0 ? m.TextContent : null, + ImageContent = imageContent, + TextThoughtSignature = null, + } + ); + } + } + } + else + { + logger.LogInformation( + "Retry: Skipping conversation history due to incompatible thought signature requirements. " + + "Using carry-forward image instead." + ); + } + + // Build input images from the last user message + List? inputImages = null; + + // If we have a carry-forward image, add it first + if (!string.IsNullOrEmpty(carryForwardImagePath) && File.Exists(carryForwardImagePath)) + { + inputImages = []; + var imageBytes = await File.ReadAllBytesAsync(carryForwardImagePath, cancellationToken) + .ConfigureAwait(false); + inputImages.Add( + new ImageInputData + { + Base64Data = Convert.ToBase64String(imageBytes), + MimeType = GetMimeTypeFromPath(carryForwardImagePath), + FilePath = carryForwardImagePath, + } + ); + } + + var retryImagePaths = + lastUserMessage.ImagePaths?.Where(p => !string.IsNullOrWhiteSpace(p)).ToList() + ?? (!string.IsNullOrEmpty(lastUserMessage.ImagePath) ? [lastUserMessage.ImagePath] : []); + + retryImagePaths = retryImagePaths.Where(File.Exists).ToList(); + + if (retryImagePaths.Count > 0) + { + inputImages ??= []; + foreach (var retryImagePath in retryImagePaths) + { + var imageBytes = await File.ReadAllBytesAsync(retryImagePath, cancellationToken) + .ConfigureAwait(false); + inputImages.Add( + new ImageInputData + { + Base64Data = Convert.ToBase64String(imageBytes), + MimeType = GetMimeTypeFromPath(retryImagePath), + FilePath = retryImagePath, + } + ); + } + } + + // Build request + var request = new ImageGenerationRequest + { + TextPrompt = lastUserMessage.TextContent, + InputImages = inputImages, + ConversationHistory = conversationHistory, + ProviderOptions = providerOptions, + Progress = progress, + }; + + logger.LogInformation( + "Retrying generation with provider {ProviderId} for conversation {ConversationId}", + provider.ProviderId, + conversationId + ); + + progress?.Report( + new ImageGenerationProgress( + ProviderId: providerId, + PromptId: null, + Value: null, + Maximum: null, + RunningNode: null, + Stage: "Generating..." + ) + ); + + var response = await provider.GenerateAsync(request, cancellationToken).ConfigureAwait(false); + + if (!response.IsSuccess) + { + logger.LogError("Retry generation failed: {ErrorMessage}", response.ErrorMessage); + + // Update conversation provider and timestamp + conversation.ProviderId = providerId; + conversation.UpdatedAt = DateTime.UtcNow; + await database.Conversations.UpdateAsync(conversation).ConfigureAwait(false); + + throw new ImageGenerationException(response.ErrorMessage ?? "Image generation failed"); + } + + // Save generated images + List? savedImagePaths = null; + if (response.Images?.Count > 0) + { + progress?.Report( + new ImageGenerationProgress( + ProviderId: providerId, + PromptId: null, + Value: null, + Maximum: null, + RunningNode: null, + Stage: "Saving image(s)..." + ) + ); + + var outputDir = GetOutputDirectory(); + Directory.CreateDirectory(outputDir); + + savedImagePaths = []; + + foreach (var generatedImage in response.Images) + { + var imageBytes = Convert.FromBase64String(generatedImage.Base64Data); + var extension = GetExtensionFromMimeType(generatedImage.MimeType); + var fileName = CreateOutputFileName(extension); + var savedPath = Path.Combine(outputDir, fileName); + + await File.WriteAllBytesAsync(savedPath, imageBytes, cancellationToken).ConfigureAwait(false); + savedImagePaths.Add(savedPath); + } + + logger.LogInformation( + "Saved {Count} retry generated image(s) to {OutputDir}", + savedImagePaths.Count, + outputDir + ); + } + + // Create assistant message - capture thought signature for multi-turn continuity + var thoughtSignature = + response.Images?.FirstOrDefault()?.ThoughtSignature ?? response.ThoughtSignature; + + var assistantMessage = new ImageGenerationMessage + { + ConversationId = conversationId, + Role = MessageRole.Assistant, + TextContent = response.TextResponse, + ImagePath = savedImagePaths?.FirstOrDefault(), + ImagePaths = savedImagePaths, + ImageMimeType = response.Images?.FirstOrDefault()?.MimeType, + ThinkingContent = response.ThinkingContent, + ThoughtSignature = thoughtSignature, + }; + + if (!string.IsNullOrEmpty(thoughtSignature)) + { + logger.LogInformation("Saved thought signature for retry assistant message"); + } + + await database.Messages.InsertAsync(assistantMessage).ConfigureAwait(false); + + // Update conversation + conversation.ProviderId = providerId; + conversation.UpdatedAt = DateTime.UtcNow; + await database.Conversations.UpdateAsync(conversation).ConfigureAwait(false); + + return assistantMessage; + } + + public List GetAvailableProviders() + { + return providers; + } + + public IImageGenerationProvider? GetProvider(string providerId) + { + return providers.FirstOrDefault(p => p.ProviderId == providerId); + } + + private static string GetMimeTypeFromPath(string path) + { + var extension = Path.GetExtension(path).ToLowerInvariant(); + return extension switch + { + ".png" => "image/png", + ".jpg" or ".jpeg" => "image/jpeg", + ".webp" => "image/webp", + ".gif" => "image/gif", + _ => "image/png", + }; + } + + private static string GetExtensionFromMimeType(string mimeType) + { + return mimeType switch + { + "image/png" => ".png", + "image/jpeg" => ".jpg", + "image/webp" => ".webp", + "image/gif" => ".gif", + _ => ".png", + }; + } +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationException.cs b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationException.cs new file mode 100644 index 000000000..b160e8dd1 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationException.cs @@ -0,0 +1,13 @@ +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Exception thrown when image generation fails +/// +public class ImageGenerationException : Exception +{ + public ImageGenerationException(string message) + : base(message) { } + + public ImageGenerationException(string message, Exception innerException) + : base(message, innerException) { } +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationProgress.cs b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationProgress.cs new file mode 100644 index 000000000..bb6b3e284 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationProgress.cs @@ -0,0 +1,17 @@ +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Progress update emitted during image generation. +/// Intended to be UI-friendly and provider-agnostic. +/// +public readonly record struct ImageGenerationProgress( + string? ProviderId, + string? PromptId, + int? Value, + int? Maximum, + string? RunningNode, + string? Stage +) +{ + public int? Percent => Value is >= 0 && Maximum is > 0 ? (Value.Value * 100) / Maximum.Value : null; +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationRequest.cs b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationRequest.cs new file mode 100644 index 000000000..361b193a7 --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationRequest.cs @@ -0,0 +1,96 @@ +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Request for image generation +/// +public record ImageGenerationRequest +{ + /// + /// Text prompt for generation + /// + public string? TextPrompt { get; init; } + + /// + /// Input images for editing or composition (base64 encoded) + /// + public List? InputImages { get; init; } + + /// + /// Previous conversation history for multi-turn support + /// + public List? ConversationHistory { get; init; } + + /// + /// Provider-specific configuration options + /// + public Dictionary? ProviderOptions { get; init; } + + /// + /// Optional progress reporter for providers that can emit generation progress (e.g., local ComfyUI). + /// + public IProgress? Progress { get; init; } +} + +/// +/// Represents an input image with its data +/// +public record ImageInputData +{ + /// + /// Base64 encoded image data + /// + public required string Base64Data { get; init; } + + /// + /// MIME type (e.g., "image/png", "image/jpeg") + /// + public required string MimeType { get; init; } + + /// + /// Optional file path on disk (for local providers that can upload directly) + /// + public string? FilePath { get; init; } + + /// + /// Thought signature from Gemini API for this image. + /// Must be passed back in follow-up requests to preserve reasoning context. + /// + public string? ThoughtSignature { get; init; } +} + +/// +/// Represents a message in the conversation history +/// +public record ConversationMessage +{ + /// + /// Role of the message sender + /// + public required MessageRole Role { get; init; } + + /// + /// Text content of the message + /// + public string? TextContent { get; init; } + + /// + /// Image content (base64 encoded) + /// + public ImageInputData? ImageContent { get; init; } + + /// + /// Thought signature for text parts from Gemini API. + /// Must be passed back in follow-up requests to preserve reasoning context. + /// + public string? TextThoughtSignature { get; init; } +} + +/// +/// Role of a message sender +/// +public enum MessageRole +{ + User, + Assistant, + System, +} diff --git a/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationResponse.cs b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationResponse.cs new file mode 100644 index 000000000..0f203f3ff --- /dev/null +++ b/StabilityMatrix.Core/Services/ImageGeneration/ImageGenerationResponse.cs @@ -0,0 +1,66 @@ +namespace StabilityMatrix.Core.Services.ImageGeneration; + +/// +/// Response from image generation +/// +public record ImageGenerationResponse +{ + /// + /// Generated images (base64 encoded) + /// + public List? Images { get; init; } + + /// + /// Text response from the model (if any) + /// + public string? TextResponse { get; init; } + + /// + /// Thinking/reasoning content from the model (Gemini 3 Pro) + /// + public string? ThinkingContent { get; init; } + + /// + /// Thought signature from Gemini API response. + /// Must be stored and passed back in follow-up requests. + /// See: https://ai.google.dev/gemini-api/docs/thought-signatures + /// + public string? ThoughtSignature { get; init; } + + /// + /// Whether the generation was successful + /// + public bool IsSuccess { get; init; } + + /// + /// Error message if generation failed + /// + public string? ErrorMessage { get; init; } + + /// + /// Provider-specific metadata + /// + public Dictionary? Metadata { get; init; } +} + +/// +/// Represents a generated image +/// +public record GeneratedImage +{ + /// + /// Base64 encoded image data + /// + public required string Base64Data { get; init; } + + /// + /// MIME type (e.g., "image/png", "image/jpeg") + /// + public required string MimeType { get; init; } + + /// + /// Thought signature for this specific image from Gemini API. + /// Must be passed back in follow-up requests. + /// + public string? ThoughtSignature { get; init; } +} diff --git a/StabilityMatrix.Core/Services/ImageIndexService.cs b/StabilityMatrix.Core/Services/ImageIndexService.cs index 903e1b083..c8785d2b0 100644 --- a/StabilityMatrix.Core/Services/ImageIndexService.cs +++ b/StabilityMatrix.Core/Services/ImageIndexService.cs @@ -27,7 +27,7 @@ public ImageIndexService(ILogger logger, ISettingsManager set InferenceImages = new IndexCollection(this, file => file.AbsolutePath) { - RelativePath = "Inference" + RelativePath = "Inference", }; EventManager.Instance.ImageFileAdded += OnImageFileAdded; @@ -61,7 +61,8 @@ await Task.Run(() => { var files = searchDir .EnumerateFiles("*", EnumerationOptionConstants.AllDirectories) - .Where(file => LocalImageFile.SupportedImageExtensions.Contains(file.Extension)); + .Where(file => LocalImageFile.SupportedImageExtensions.Contains(file.Extension)) + .Where(file => !file.FullPath.Contains(".sm-thumbs")); // Exclude video thumbnail directories Parallel.ForEach( files, diff --git a/StabilityMatrix.Core/Services/MetadataImportService.cs b/StabilityMatrix.Core/Services/MetadataImportService.cs index 657e2598b..cae475554 100644 --- a/StabilityMatrix.Core/Services/MetadataImportService.cs +++ b/StabilityMatrix.Core/Services/MetadataImportService.cs @@ -95,32 +95,18 @@ public async Task ScanDirectoryForMissingInfo( continue; } - var (model, modelVersion, modelFile) = modelInfo.Value; + var updatedCmInfo = await BuildConnectedModelInfoAsync( + checkpointFilePath, + null, + modelInfo.Value, + progress + ) + .ConfigureAwait(false); - var updatedCmInfo = new ConnectedModelInfo( - model, - modelVersion, - modelFile, - DateTimeOffset.UtcNow - ); await updatedCmInfo .SaveJsonToDirectory(checkpointFilePath.Directory, fileNameWithoutExtension) .ConfigureAwait(false); - var image = modelVersion.Images?.FirstOrDefault( - img => - LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) - && img.Type == "image" - ); - if (image == null) - { - scanned++; - success++; - continue; - } - - await DownloadImage(image, checkpointFilePath, progress).ConfigureAwait(false); - scanned++; success++; } @@ -164,9 +150,7 @@ var cmInfoPath in directory.EnumerateFiles( ConnectedModelInfo? cmInfo; try { - cmInfo = JsonSerializer.Deserialize( - await cmInfoPath.ReadAllTextAsync().ConfigureAwait(false) - ); + cmInfo = await ReadConnectedModelInfo(cmInfoPath).ConfigureAwait(false); } catch (JsonException) { @@ -179,7 +163,7 @@ await cmInfoPath.ReadAllTextAsync().ConfigureAwait(false) } var success = 1; - foreach (var (filePath, cmInfoValue) in cmInfoList) + foreach (var (filePath, existingCmInfo) in cmInfoList) { progress?.Report( new ProgressReport( @@ -191,7 +175,7 @@ await cmInfoPath.ReadAllTextAsync().ConfigureAwait(false) try { - var hash = cmInfoValue.Hashes.BLAKE3; + var hash = existingCmInfo.Hashes.BLAKE3; if (string.IsNullOrWhiteSpace(hash)) continue; @@ -202,30 +186,19 @@ await cmInfoPath.ReadAllTextAsync().ConfigureAwait(false) continue; } - var (model, modelVersion, modelFile) = modelInfo.Value; - - var updatedCmInfo = new ConnectedModelInfo( - model, - modelVersion, - modelFile, - DateTimeOffset.UtcNow - ); + var updatedCmInfo = await BuildConnectedModelInfoAsync( + filePath, + existingCmInfo, + modelInfo.Value, + progress + ) + .ConfigureAwait(false); var nameWithoutCmInfo = filePath.NameWithoutExtension.Replace(".cm-info", string.Empty); await updatedCmInfo .SaveJsonToDirectory(filePath.Directory, nameWithoutCmInfo) .ConfigureAwait(false); - var image = modelVersion.Images?.FirstOrDefault( - img => - LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) - && img.Type == "image" - ); - if (image == null) - continue; - - await DownloadImage(image, filePath, progress).ConfigureAwait(false); - success++; } catch (Exception e) @@ -245,8 +218,11 @@ await updatedCmInfo var fileNameWithoutExtension = filePath.NameWithoutExtension; var cmInfoPath = filePath.Directory?.JoinFile($"{fileNameWithoutExtension}.cm-info.json"); - var cmInfoExists = File.Exists(cmInfoPath); - if (cmInfoExists && !forceReimport) + var existingCmInfo = + cmInfoPath is not null && File.Exists(cmInfoPath) + ? await ReadConnectedModelInfo(cmInfoPath).ConfigureAwait(false) + : null; + if (existingCmInfo != null && !forceReimport) return null; var hashProgress = new Progress(report => @@ -275,28 +251,74 @@ await updatedCmInfo return null; } - var (model, modelVersion, modelFile) = modelInfo.Value; - - var updatedCmInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTimeOffset.UtcNow); + var updatedCmInfo = await BuildConnectedModelInfoAsync( + filePath, + existingCmInfo, + modelInfo.Value, + progress + ) + .ConfigureAwait(false); await updatedCmInfo .SaveJsonToDirectory(filePath.Directory, fileNameWithoutExtension) .ConfigureAwait(false); - var image = modelVersion.Images?.FirstOrDefault( - img => - LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) - && img.Type == "image" + return updatedCmInfo; + } + + private async Task BuildConnectedModelInfoAsync( + FilePath modelFilePath, + ConnectedModelInfo? existingCmInfo, + ModelSearchResult modelInfo, + IProgress? progress + ) + { + var (model, modelVersion, modelFile) = modelInfo; + var updatedCmInfo = MergeRemoteMetadata( + existingCmInfo, + new ConnectedModelInfo(model, modelVersion, modelFile, DateTimeOffset.UtcNow) ); - if (image == null) + if (!string.IsNullOrWhiteSpace(updatedCmInfo.ThumbnailImageUrl)) return updatedCmInfo; - var imagePath = await DownloadImage(image, filePath, progress).ConfigureAwait(false); - updatedCmInfo.ThumbnailImageUrl = imagePath; + var image = modelVersion.Images?.FirstOrDefault(img => + LocalModelFile.SupportedImageExtensions.Contains(Path.GetExtension(img.Url)) + && img.Type == "image" + ); + + if (image != null) + { + updatedCmInfo.ThumbnailImageUrl = await DownloadImage(image, modelFilePath, progress) + .ConfigureAwait(false); + } return updatedCmInfo; } + private static ConnectedModelInfo MergeRemoteMetadata( + ConnectedModelInfo? existingCmInfo, + ConnectedModelInfo refreshedCmInfo + ) + { + if (existingCmInfo == null) + return refreshedCmInfo; + + refreshedCmInfo.ImportedAt = + existingCmInfo.ImportedAt == default ? refreshedCmInfo.ImportedAt : existingCmInfo.ImportedAt; + refreshedCmInfo.UserTitle = existingCmInfo.UserTitle; + refreshedCmInfo.ThumbnailImageUrl = existingCmInfo.ThumbnailImageUrl; + refreshedCmInfo.InferenceDefaults = existingCmInfo.InferenceDefaults; + + return refreshedCmInfo; + } + + private static async Task ReadConnectedModelInfo(FilePath cmInfoPath) + { + return JsonSerializer.Deserialize( + await cmInfoPath.ReadAllTextAsync().ConfigureAwait(false) + ); + } + private static async Task GetBlake3Hash( FilePath? cmInfoPath, FilePath checkpointFilePath, @@ -308,9 +330,7 @@ IProgress hashProgress return await FileHash.GetBlake3Async(checkpointFilePath, hashProgress).ConfigureAwait(false); } - var cmInfo = JsonSerializer.Deserialize( - await cmInfoPath.ReadAllTextAsync().ConfigureAwait(false) - ); + var cmInfo = await ReadConnectedModelInfo(cmInfoPath).ConfigureAwait(false); return cmInfo?.Hashes.BLAKE3; } diff --git a/StabilityMatrix.Core/Services/ModelIndexService.cs b/StabilityMatrix.Core/Services/ModelIndexService.cs index 924642da1..ea69d302f 100644 --- a/StabilityMatrix.Core/Services/ModelIndexService.cs +++ b/StabilityMatrix.Core/Services/ModelIndexService.cs @@ -34,6 +34,8 @@ public partial class ModelIndexService : IModelIndexService private Dictionary> _modelIndex = new(); private HashSet? _modelIndexBlake3Hashes; + private HashSet? _modelIndexSha256Hashes; + private HashSet? _modelIndexCivArchiveUrls; /// /// Whether the database has been initially loaded. @@ -53,6 +55,12 @@ private set public IReadOnlySet ModelIndexBlake3Hashes => _modelIndexBlake3Hashes ??= CollectModelHashes(ModelIndex.Values.SelectMany(x => x)); + public IReadOnlySet ModelIndexSha256Hashes => + _modelIndexSha256Hashes ??= CollectModelSha256Hashes(ModelIndex.Values.SelectMany(x => x)); + + public IReadOnlySet ModelIndexCivArchiveUrls => + _modelIndexCivArchiveUrls ??= CollectCivArchiveUrls(ModelIndex.Values.SelectMany(x => x)); + [AutoPostConstruct] private void Initialize() { @@ -688,6 +696,8 @@ public async Task UpsertModelAsync(LocalModelFile model) private void OnModelIndexReset() { _modelIndexBlake3Hashes = null; + _modelIndexSha256Hashes = null; + _modelIndexCivArchiveUrls = null; } private static HashSet CollectModelHashes(IEnumerable models) @@ -703,6 +713,32 @@ private static HashSet CollectModelHashes(IEnumerable mo return hashes; } + private static HashSet CollectModelSha256Hashes(IEnumerable models) + { + var hashes = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var model in models) + { + if (!string.IsNullOrWhiteSpace(model.HashSha256)) + { + hashes.Add(model.HashSha256); + } + } + return hashes; + } + + private static HashSet CollectCivArchiveUrls(IEnumerable models) + { + var urls = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var model in models) + { + if (model.HasCivArchiveMetadata && !string.IsNullOrWhiteSpace(model.ConnectedModelInfo.SourceUrl)) + { + urls.Add(model.ConnectedModelInfo.SourceUrl); + } + } + return urls; + } + private static bool GetHasEarlyAccessUpdateOnly(LocalModelFile model, CivitModel? remoteModel) { if (!model.HasUpdate || !model.HasCivitMetadata) diff --git a/StabilityMatrix.Core/Services/NotificationHistoryService.cs b/StabilityMatrix.Core/Services/NotificationHistoryService.cs new file mode 100644 index 000000000..bf653dd4b --- /dev/null +++ b/StabilityMatrix.Core/Services/NotificationHistoryService.cs @@ -0,0 +1,145 @@ +using Injectio.Attributes; +using StabilityMatrix.Core.Models.Notifications; + +namespace StabilityMatrix.Core.Services; + +[RegisterSingleton] +public class NotificationHistoryService : INotificationHistoryService +{ + private const int MaxEntries = 100; + + private readonly LinkedList entries = new(); + private readonly object sync = new(); + + public IReadOnlyList Entries + { + get + { + lock (sync) + return entries.ToList(); + } + } + + public int Count + { + get + { + lock (sync) + return entries.Count; + } + } + + public int UnreadCount + { + get + { + lock (sync) + return entries.Count(e => !e.IsRead); + } + } + + public event EventHandler? EntryAdded; + public event EventHandler? EntriesChanged; + + public NotificationHistoryEntry Add(NotificationHistoryEntry entry) + { + lock (sync) + { + entries.AddFirst(entry); + while (entries.Count > MaxEntries) + { + entries.RemoveLast(); + } + } + + EntryAdded?.Invoke(this, entry); + return entry; + } + + public void MarkRead(Guid id) + { + bool changed; + lock (sync) + { + var entry = entries.FirstOrDefault(e => e.Id == id); + changed = entry is { IsRead: false }; + if (entry != null) + { + entry.IsRead = true; + } + } + + if (changed) + { + EntriesChanged?.Invoke(this, EventArgs.Empty); + } + } + + public void MarkAllRead() + { + bool changed; + lock (sync) + { + changed = false; + foreach (var entry in entries.Where(e => !e.IsRead)) + { + entry.IsRead = true; + changed = true; + } + } + + if (changed) + { + EntriesChanged?.Invoke(this, EventArgs.Empty); + } + } + + public void Remove(Guid id) + { + bool changed; + lock (sync) + { + var node = entries.First; + changed = false; + while (node != null) + { + if (node.Value.Id == id) + { + entries.Remove(node); + changed = true; + break; + } + + node = node.Next; + } + } + + if (changed) + { + EntriesChanged?.Invoke(this, EventArgs.Empty); + } + } + + public void Clear() + { + bool changed; + lock (sync) + { + changed = entries.Count > 0; + entries.Clear(); + } + + if (changed) + { + EntriesChanged?.Invoke(this, EventArgs.Empty); + } + } + + public NotificationHistoryEntry? Find(Guid id) + { + lock (sync) + { + return entries.FirstOrDefault(e => e.Id == id); + } + } +} diff --git a/StabilityMatrix.Tests/Avalonia/CivArchiveBrowserViewModelTests.cs b/StabilityMatrix.Tests/Avalonia/CivArchiveBrowserViewModelTests.cs new file mode 100644 index 000000000..2206be695 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/CivArchiveBrowserViewModelTests.cs @@ -0,0 +1,548 @@ +using FluentAvalonia.UI.Media.Animation; +using NSubstitute; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api.CivArchive; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Settings; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class CivArchiveBrowserViewModelTests +{ + [TestMethod] + public void Constructor_LoadsExpectedDefaults() + { + var vm = CreateViewModel(Substitute.For(), out _, out _); + + vm.OnLoaded(); + + Assert.AreEqual(CivArchivePlatformOption.All, vm.SelectedPlatform?.Value); + Assert.AreEqual(CivArchiveSortOption.Top, vm.SelectedSort?.Value); + Assert.AreEqual(CivArchivePeriodOption.All, vm.SelectedPeriod?.Value); + Assert.AreEqual(CivArchiveRatingOption.Safe, vm.SelectedRating?.Value); + Assert.AreEqual(CivArchiveKindOption.All, vm.SelectedKind?.Value); + } + + [TestMethod] + public async Task ChangingFilters_TriggersNewQueryAndResetsPaging() + { + var apiClient = Substitute.For(); + var recordedFilters = new List(); + apiClient + .SearchAsync(Arg.Any(), Arg.Any()) + .Returns(call => + { + var filters = call.Arg(); + recordedFilters.Add(filters); + return Task.FromResult(CreateSearchResponse(filters.Page)); + }); + + var vm = CreateViewModel(apiClient, out _, out _); + // Collapse the debounce so the sort-change re-fetch happens within the test + // window instead of after the production-default 300ms idle delay. + vm.SearchDebounceInterval = TimeSpan.Zero; + vm.OnLoaded(); + + await vm.SearchModelsCommand.ExecuteAsync(false); + await vm.LoadNextPageAsync(); + + vm.SelectedSort = vm.AllSorts.First(x => x.Value == CivArchiveSortOption.Newest); + await Task.Delay(50); + + // 3 calls: initial fetch, LoadNextPage, sort change. + // (No saved selections means ApplyFilterOptions doesn't trigger a redundant re-fetch.) + Assert.AreEqual(3, recordedFilters.Count); + Assert.AreEqual(1, recordedFilters[0].Page); + Assert.AreEqual(2, recordedFilters[1].Page); + Assert.AreEqual(1, recordedFilters[2].Page); + Assert.AreEqual(CivArchiveSortOption.Newest, recordedFilters[2].Sort); + } + + [TestMethod] + public async Task OpenResult_UserPivotSetsUsernameFilter() + { + var apiClient = Substitute.For(); + CivArchiveSearchFilters? recordedFilter = null; + apiClient + .SearchAsync(Arg.Any(), Arg.Any()) + .Returns(call => + { + recordedFilter = call.Arg(); + return Task.FromResult(CreateSearchResponse(1)); + }); + + var vm = CreateViewModel(apiClient, out _, out _); + vm.OnLoaded(); + + await vm.OpenResultCommand.ExecuteAsync( + new CivArchiveSearchResult + { + Id = "user-1", + Name = "artist-name", + KindRaw = "user", + Username = "artist-name", + Url = "/users/artist-name", + } + ); + + Assert.IsNotNull(recordedFilter); + Assert.AreEqual("artist-name", recordedFilter.Username); + } + + [TestMethod] + public async Task OpenResult_VersionNavigatesToDetailsPage() + { + var apiClient = Substitute.For(); + var navigationService = Substitute.For>(); + var vm = CreateViewModel(apiClient, out _, out var serviceManager, navigationService); + + await vm.OpenResultCommand.ExecuteAsync( + new CivArchiveSearchResult + { + Id = "version-1", + Name = "Version", + KindRaw = "version", + Url = "/models/443821?modelVersionId=2581228", + } + ); + + navigationService + .Received(1) + .NavigateTo( + Arg.Is(x => + x.GetType() == typeof(CivArchiveDetailsPageViewModel) + && ((CivArchiveDetailsPageViewModel)x).RelativeUrl.Contains("modelVersionId=2581228") + ), + Arg.Any() + ); + } + + [TestMethod] + public async Task ChangingFilterWhileLoading_QueuesRefreshWithLatestFilter() + { + var apiClient = Substitute.For(); + var delayedResponse = new TaskCompletionSource(); + var recordedFilters = new List(); + + apiClient + .SearchAsync(Arg.Any(), Arg.Any()) + .Returns(call => + { + var filters = call.Arg(); + recordedFilters.Add(filters); + + // Delay call #2 β€” that's the explicit ExecuteAsync at line below (line 149 + // equivalent), so the sort change happens while it's in flight and gets + // queued. With the redundant init re-fetch removed, the second call is + // now the queueable one (used to be call #3). + return recordedFilters.Count switch + { + 2 => delayedResponse.Task, + _ => Task.FromResult(CreateSearchResponse(filters.Page)), + }; + }); + + var vm = CreateViewModel(apiClient, out _, out _); + vm.SearchDebounceInterval = TimeSpan.Zero; + vm.OnLoaded(); + + await vm.SearchModelsCommand.ExecuteAsync(false); + + var loadingSearch = vm.SearchModelsCommand.ExecuteAsync(false); + vm.SelectedSort = vm.AllSorts.First(x => x.Value == CivArchiveSortOption.Newest); + + // Let the debounced fire-and-forget task spin up: it'll await Task.Delay(0), + // call SearchModels, see IsLoading=true, and set searchQueued. After the + // delayed in-flight call completes, the queued mechanism re-fires with the + // newest sort. + await Task.Delay(50); + + delayedResponse.SetResult(CreateSearchResponse(1)); + await loadingSearch; + + // 3 calls: initial, in-flight (delayed), queued sort-change refresh. + Assert.AreEqual(3, recordedFilters.Count); + Assert.AreEqual(CivArchiveSortOption.Top, recordedFilters[1].Sort); + Assert.AreEqual(CivArchiveSortOption.Newest, recordedFilters[2].Sort); + } + + [TestMethod] + public async Task DownloadModel_UsesPrimaryFileUrlNameAndHash() + { + var apiClient = Substitute.For(); + var modelImportService = Substitute.For(); + var settingsManager = Substitute.For(); + var model = CreateDetailsModel( + new CivArchiveModelFile + { + Name = "realDream_sdxl7.ckpt", + DownloadUrl = "https://example.org/download/realDream_sdxl7.ckpt", + Sha256 = "63b1db60611f52c4fbb2cade67dbdf4029c6620c5b22f2a4ddb27a47d7601953", + IsPrimary = true, + } + ); + + IReadOnlyList? capturedUris = null; + string? capturedFileName = null; + Action? configureDownload = null; + + apiClient + .GetModelDetailsAsync(Arg.Any(), Arg.Any()) + .Returns(new CivArchiveModelDetailsResponse { Model = model }); + apiClient.GetAbsoluteUri(Arg.Any()).Returns(call => new Uri(call.Arg())); + settingsManager.IsLibraryDirSet.Returns(true); + settingsManager.ModelsDirectory.Returns(Path.GetTempPath()); + modelImportService + .DoCustomImport( + Arg.Do>(uris => capturedUris = uris.ToList()), + Arg.Do(fileName => capturedFileName = fileName), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Do?>(action => configureDownload = action) + ) + .Returns(Task.CompletedTask); + + var vm = CreateDetailsViewModel(apiClient, modelImportService, settingsManager); + vm.RelativeUrl = "/models/153568?modelVersionId=2053273"; + + await vm.OnLoadedAsync(); + await vm.DownloadModelCommand.ExecuteAsync(null); + + Assert.IsTrue(vm.HasDownloadUrl); + Assert.AreEqual("realDream_sdxl7.ckpt", capturedFileName); + Assert.AreEqual("https://example.org/download/realDream_sdxl7.ckpt", capturedUris?[0].ToString()); + + var download = new TrackedDownload + { + Id = Guid.NewGuid(), + SourceUrl = capturedUris![0], + DownloadDirectory = new DirectoryPath(Path.GetTempPath()), + FileName = capturedFileName!, + TempFileName = $"{capturedFileName}.tmp", + }; + configureDownload?.Invoke(download); + Assert.AreEqual(model.Version?.Files[0].Sha256, download.ExpectedHashSha256); + } + + [TestMethod] + public async Task DownloadModel_UsesFileMirrorUrlWhenDirectUrlIsMissing() + { + var apiClient = Substitute.For(); + var modelImportService = Substitute.For(); + var settingsManager = Substitute.For(); + var model = CreateDetailsModel( + new CivArchiveModelFile + { + Name = "mirror-only.safetensors", + Mirrors = + [ + new CivArchiveFileMirror + { + Source = "civitai", + Url = "https://example.org/mirror/mirror-only.safetensors", + }, + ], + } + ); + + IReadOnlyList? capturedUris = null; + + apiClient + .GetModelDetailsAsync(Arg.Any(), Arg.Any()) + .Returns(new CivArchiveModelDetailsResponse { Model = model }); + apiClient.GetAbsoluteUri(Arg.Any()).Returns(call => new Uri(call.Arg())); + settingsManager.IsLibraryDirSet.Returns(true); + settingsManager.ModelsDirectory.Returns(Path.GetTempPath()); + modelImportService + .DoCustomImport( + Arg.Do>(uris => capturedUris = uris.ToList()), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any?>() + ) + .Returns(Task.CompletedTask); + + var vm = CreateDetailsViewModel(apiClient, modelImportService, settingsManager); + vm.RelativeUrl = "/models/153568?modelVersionId=2053273"; + + await vm.OnLoadedAsync(); + await vm.DownloadModelCommand.ExecuteAsync(null); + + Assert.IsTrue(vm.HasDownloadUrl); + Assert.AreEqual("https://example.org/mirror/mirror-only.safetensors", capturedUris?[0].ToString()); + } + + [TestMethod] + public async Task DownloadModel_UsesImagePreviewAndSkipsVideoMedia() + { + var apiClient = Substitute.For(); + var modelImportService = Substitute.For(); + var settingsManager = Substitute.For(); + var model = CreateDetailsModel( + new CivArchiveModelFile + { + Name = "model.safetensors", + DownloadUrl = "https://example.org/download/model.safetensors", + IsPrimary = true, + } + ); + model.Version!.Images = + [ + new CivArchiveModelImage { Url = "https://c.genur.art/video-id", Type = "video" }, + new CivArchiveModelImage + { + Url = "https://img.genur.art/sig/width:450/quality:85/image-id", + Type = "image", + }, + ]; + + Uri? capturedPreviewUri = null; + + apiClient + .GetModelDetailsAsync(Arg.Any(), Arg.Any()) + .Returns(new CivArchiveModelDetailsResponse { Model = model }); + apiClient.GetAbsoluteUri(Arg.Any()).Returns(call => new Uri(call.Arg())); + settingsManager.IsLibraryDirSet.Returns(true); + settingsManager.ModelsDirectory.Returns(Path.GetTempPath()); + modelImportService + .DoCustomImport( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Do(uri => capturedPreviewUri = uri), + Arg.Any(), + Arg.Any(), + Arg.Any?>() + ) + .Returns(Task.CompletedTask); + + var vm = CreateDetailsViewModel(apiClient, modelImportService, settingsManager); + vm.RelativeUrl = "/models/1?modelVersionId=2"; + + await vm.OnLoadedAsync(); + await vm.DownloadModelCommand.ExecuteAsync(null); + + Assert.AreEqual(1, vm.Images.Count); + Assert.AreEqual("https://img.genur.art/sig/width:450/quality:85/image-id", vm.Images[0].Url); + Assert.AreEqual( + "https://img.genur.art/sig/width:450/quality:85/image-id", + capturedPreviewUri?.ToString() + ); + } + + [TestMethod] + public void ParseSearchQuery_PlainQuery_ReturnsQueryOnly() + { + var (query, tags, username) = CivArchiveBrowserViewModel.ParseSearchQuery("dragon style"); + + Assert.AreEqual("dragon style", query); + Assert.AreEqual(string.Empty, tags); + Assert.AreEqual(string.Empty, username); + } + + [TestMethod] + public void ParseSearchQuery_AtToken_ExtractsUsername() + { + var (query, tags, username) = CivArchiveBrowserViewModel.ParseSearchQuery("dragon @sinatra"); + + Assert.AreEqual("dragon", query); + Assert.AreEqual(string.Empty, tags); + Assert.AreEqual("sinatra", username); + } + + [TestMethod] + public void ParseSearchQuery_HashTokens_ExtractedAsCommaJoined() + { + var (query, tags, username) = CivArchiveBrowserViewModel.ParseSearchQuery("painting #anime #sdxl"); + + Assert.AreEqual("painting", query); + Assert.AreEqual("anime,sdxl", tags); + Assert.AreEqual(string.Empty, username); + } + + [TestMethod] + public void ParseSearchQuery_MultipleAtTokens_LastWins() + { + var (_, _, username) = CivArchiveBrowserViewModel.ParseSearchQuery("@alice @bob"); + + Assert.AreEqual("bob", username); + } + + [TestMethod] + public void ParseSearchQuery_MixedTokens_AllExtracted() + { + var (query, tags, username) = CivArchiveBrowserViewModel.ParseSearchQuery( + "dragon #anime @sinatra #sdxl knight" + ); + + Assert.AreEqual("dragon knight", query); + Assert.AreEqual("anime,sdxl", tags); + Assert.AreEqual("sinatra", username); + } + + [TestMethod] + public void ParseSearchQuery_EmptyOrWhitespace_ReturnsEmptyTuple() + { + var empty = CivArchiveBrowserViewModel.ParseSearchQuery(""); + var whitespace = CivArchiveBrowserViewModel.ParseSearchQuery(" "); + + Assert.AreEqual(string.Empty, empty.query); + Assert.AreEqual(string.Empty, empty.tags); + Assert.AreEqual(string.Empty, empty.username); + Assert.AreEqual(string.Empty, whitespace.query); + } + + [TestMethod] + public void ParseSearchQuery_BareSigil_KeptAsRegularToken() + { + // A lone "@" or "#" with nothing after isn't a username/tag prefix β€” + // the parser leaves it as part of the regular query. + var (query, tags, username) = CivArchiveBrowserViewModel.ParseSearchQuery("foo @ # bar"); + + Assert.AreEqual("foo @ # bar", query); + Assert.AreEqual(string.Empty, tags); + Assert.AreEqual(string.Empty, username); + } + + private static CivArchiveBrowserViewModel CreateViewModel( + ICivArchiveApiClient apiClient, + out Settings settings, + out IServiceManager serviceManager, + INavigationService? navigationService = null + ) + { + var localSettings = new Settings(); + settings = localSettings; + var settingsManager = Substitute.For(); + settingsManager.Settings.Returns(localSettings); + settingsManager.IsLibraryDirSet.Returns(true); + settingsManager + .When(x => x.Transaction(Arg.Any>(), Arg.Any())) + .Do(call => call.Arg>()(localSettings)); + + serviceManager = new TestServiceManager( + new CivArchiveDetailsPageViewModel( + Substitute.For(), + navigationService ?? Substitute.For>(), + Substitute.For>(), + Substitute.For(), + Substitute.For(), + Substitute.For(), + CreateModelIndexServiceStub() + ) + ); + + return new CivArchiveBrowserViewModel( + apiClient, + settingsManager, + serviceManager, + navigationService ?? Substitute.For>(), + CreateModelIndexServiceStub() + ); + } + + private static IModelIndexService CreateModelIndexServiceStub() + { + var stub = Substitute.For(); + stub.ModelIndexSha256Hashes.Returns(new HashSet()); + stub.ModelIndexBlake3Hashes.Returns(new HashSet()); + return stub; + } + + private static CivArchiveDetailsPageViewModel CreateDetailsViewModel( + ICivArchiveApiClient apiClient, + IModelImportService modelImportService, + ISettingsManager settingsManager + ) + { + return new CivArchiveDetailsPageViewModel( + apiClient, + Substitute.For>(), + Substitute.For>(), + modelImportService, + settingsManager, + Substitute.For(), + CreateModelIndexServiceStub() + ); + } + + private static CivArchiveModelDetails CreateDetailsModel(CivArchiveModelFile file) + { + return new CivArchiveModelDetails + { + Name = "Real Dream", + Type = "Checkpoint", + Version = new CivArchiveModelVersion + { + Name = "SDXL 7", + BaseModel = "SDXL 1.0", + Files = [file], + Images = [new CivArchiveModelImage { Url = "https://example.org/preview.webp" }], + }, + }; + } + + private static CivArchiveSearchResponse CreateSearchResponse(int page) + { + return new CivArchiveSearchResponse + { + Results = + [ + new CivArchiveSearchResult + { + Id = $"item-{page}-1", + Name = $"Item {page}", + KindRaw = "version", + Url = $"/models/{page}?modelVersionId={page}", + }, + ], + FilterOptions = new CivArchiveFilterOptions + { + BaseModels = ["Illustrious", "Pony"], + ModelTypes = ["LORA", "Checkpoint"], + }, + EffectiveFilters = new CivArchiveSearchFilters { Page = page }, + CanonicalUrl = "https://civarchive.com/top-models", + Hits = 1, + TotalHits = 4, + }; + } + + private sealed class TestServiceManager(T instance) : IServiceManager + { + public IServiceManager Register(TService serviceInstance) + where TService : T => this; + + public IServiceManager Register(Func provider) + where TService : T => this; + + public void Register(Type type, Func providerFunc) { } + + public IServiceManager RegisterProvider(IServiceProvider provider) + where TService : notnull, T => this; + + public IServiceManager RegisterScoped(Func provider) + where TService : T => this; + + public IServiceManagerScope CreateScope() => throw new NotImplementedException(); + + public T Get(Type serviceType) => instance; + + public TService Get() + where TService : T => (TService)instance!; + + public IServiceManager RegisterScoped(Type type, Func provider) => this; + } +} diff --git a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs index 5905aca02..bc948706b 100644 --- a/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs +++ b/StabilityMatrix.Tests/Avalonia/FileNameFormatProviderTests.cs @@ -1,5 +1,8 @@ ο»Ώusing System.ComponentModel.DataAnnotations; using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api; +using StabilityMatrix.Core.Models.Database; namespace StabilityMatrix.Tests.Avalonia; @@ -25,4 +28,50 @@ public void TestFileNameFormatProviderValidate_Invalid_ShouldThrow() Assert.AreEqual("Unknown variable 'invalid'", result.ErrorMessage); } + + [TestMethod] + public void TestFileNameFormatProviderTryResolveVariable_UsesLocalModelMetadata() + { + var provider = new FileNameFormatProvider + { + LocalModelFile = new LocalModelFile + { + RelativePath = Path.Combine("Checkpoints", "local-file.safetensors"), + SharedFolderType = SharedFolderType.StableDiffusion, + ConnectedModelInfo = new ConnectedModelInfo + { + ModelName = "Remote Model", + VersionName = "Version One", + ModelType = CivitModelType.Checkpoint, + AuthorUsername = "creator-name", + BaseModel = "SDXL", + RemoteFileName = "remote-file.safetensors", + RemoteFileId = 123, + Hashes = new CivitFileHashes(), + }, + }, + }; + + var resolved = provider.TryResolveVariable("file_name", out var fileName, out var error); + + Assert.IsTrue(resolved, error); + Assert.AreEqual("remote-file", fileName); + } + + [TestMethod] + public void GetSampleForOrganization_ResolvesAllLocalOrganizationVariables() + { + var provider = FileNameFormatProvider.GetSampleForOrganization(); + + foreach (var variable in FileNameFormatProvider.LocalOrganizationVariables) + { + var resolved = provider.TryResolveVariable(variable, out var value, out var error); + + Assert.IsTrue(resolved, $"Variable '{variable}' failed to resolve: {error}"); + Assert.IsFalse( + string.IsNullOrWhiteSpace(value), + $"Variable '{variable}' resolved to null or empty" + ); + } + } } diff --git a/StabilityMatrix.Tests/Avalonia/ModelImportServiceTests.cs b/StabilityMatrix.Tests/Avalonia/ModelImportServiceTests.cs new file mode 100644 index 000000000..269dfc628 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/ModelImportServiceTests.cs @@ -0,0 +1,85 @@ +using Avalonia.Controls.Notifications; +using NSubstitute; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class ModelImportServiceTests +{ + [TestMethod] + public async Task DoCustomImport_StartsTrackedDownloadBeforePreviewDownloadCompletes() + { + var downloadService = Substitute.For(); + var notificationService = Substitute.For(); + var trackedDownloadService = Substitute.For(); + var service = new ModelImportService(downloadService, notificationService, trackedDownloadService); + + var tempDir = Directory.CreateTempSubdirectory(); + var modelUri = new Uri("https://example.org/model.safetensors"); + var previewUri = new Uri("https://example.org/preview.webp"); + var previewDownload = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var trackedDownload = new TrackedDownload + { + Id = Guid.NewGuid(), + SourceUrl = modelUri, + DownloadDirectory = new DirectoryPath(tempDir.FullName), + FileName = "model.safetensors", + TempFileName = "model.safetensors.tmp", + }; + + try + { + downloadService + .DownloadToFileAsync( + previewUri.ToString(), + Arg.Any(), + Arg.Any?>(), + Arg.Any(), + Arg.Any() + ) + .Returns(previewDownload.Task); + + notificationService + .TryAsync(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(call => AwaitTask(call.Arg())); + + trackedDownloadService.NewDownload(modelUri, Arg.Any()).Returns(trackedDownload); + trackedDownloadService.TryStartDownload(trackedDownload).Returns(Task.CompletedTask); + + var importTask = service.DoCustomImport( + [modelUri], + "model.safetensors", + new DirectoryPath(tempDir.FullName), + previewUri, + ".webp" + ); + + var completedTask = await Task.WhenAny(importTask, Task.Delay(TimeSpan.FromSeconds(1))); + + Assert.AreSame( + importTask, + completedTask, + "The model import should not wait for preview image download completion." + ); + + await trackedDownloadService.Received(1).TryStartDownload(trackedDownload); + Assert.IsFalse(previewDownload.Task.IsCompleted); + } + finally + { + previewDownload.TrySetResult(); + tempDir.Delete(recursive: true); + } + } + + private static async Task> AwaitTask(Task task) + { + await task; + return new TaskResult(true); + } +} diff --git a/StabilityMatrix.Tests/Avalonia/ModelOrganizationServiceTests.cs b/StabilityMatrix.Tests/Avalonia/ModelOrganizationServiceTests.cs new file mode 100644 index 000000000..46bfccf14 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/ModelOrganizationServiceTests.cs @@ -0,0 +1,375 @@ +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api; +using StabilityMatrix.Core.Models.Database; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class ModelOrganizationServiceTests +{ + private readonly ModelOrganizationService service = new(); + + [TestMethod] + public void BuildPlan_UsesLocalMetadataPattern() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "Checkpoints"); + var model = CreateModelFile( + tempRoot, + Path.Combine("Checkpoints", "Source", "local-file.safetensors"), + "remote-file.safetensors", + authorUsername: "creator-name", + baseModel: "SDXL" + ); + + var plan = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "{author}/{base_model}/{file_name}" + ); + + Assert.AreEqual(1, plan.Items.Count); + Assert.AreEqual(1, plan.ReadyCount); + Assert.AreEqual( + Path.Combine(scopePath, "creator-name", "SDXL", "remote-file.safetensors"), + plan.Items[0].TargetPath + ); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public void BuildPlan_SkipsUnsupportedVariable() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "Checkpoints"); + var model = CreateModelFile( + tempRoot, + Path.Combine("Checkpoints", "local-file.safetensors"), + "remote-file.safetensors" + ); + + var plan = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "{seed}" + ); + + Assert.AreEqual(1, plan.Items.Count); + Assert.AreEqual(0, plan.ReadyCount); + StringAssert.Contains(plan.Items[0].Reason, "not supported"); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public void BuildPlan_RespectsNestedScope() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "Checkpoints", "Group"); + var model = CreateModelFile( + tempRoot, + Path.Combine("Checkpoints", "Group", "Nested", "local-file.safetensors"), + "remote-file.safetensors" + ); + + var withoutNested = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: false, + template: "{file_name}" + ); + var withNested = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "{file_name}" + ); + + Assert.AreEqual(0, withoutNested.Items.Count); + Assert.AreEqual(1, withNested.Items.Count); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public async Task ApplyPlan_MovesModelAndSidecars() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "Checkpoints"); + var relativePath = Path.Combine("Checkpoints", "Source", "local-file.safetensors"); + var model = CreateModelFile(tempRoot, relativePath, "remote-file.safetensors"); + var sourcePath = Path.Combine(tempRoot, relativePath); + await File.WriteAllTextAsync( + Path.Combine(Path.GetDirectoryName(sourcePath)!, "local-file.cm-info.json"), + "{}" + ); + await File.WriteAllTextAsync( + Path.Combine(Path.GetDirectoryName(sourcePath)!, "local-file.preview.png"), + "preview" + ); + await File.WriteAllTextAsync( + Path.Combine(Path.GetDirectoryName(sourcePath)!, "local-file.yaml"), + "config" + ); + + var plan = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "organized/{file_name}" + ); + + var result = await service.ApplyPlan(plan); + + var targetDirectory = Path.Combine(scopePath, "organized"); + Assert.AreEqual(1, result.MovedCount); + Assert.IsTrue(File.Exists(Path.Combine(targetDirectory, "remote-file.safetensors"))); + Assert.IsTrue(File.Exists(Path.Combine(targetDirectory, "remote-file.cm-info.json"))); + Assert.IsTrue(File.Exists(Path.Combine(targetDirectory, "remote-file.preview.png"))); + Assert.IsTrue(File.Exists(Path.Combine(targetDirectory, "remote-file.yaml"))); + Assert.IsFalse(File.Exists(sourcePath)); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public void BuildPlan_PreservesTypeFolderWhenScopeIsRoot() + { + var tempRoot = CreateTempDirectory(); + + try + { + // scopePath is the models root, not a type folder + var scopePath = tempRoot; + var loraModel = CreateModelFile( + tempRoot, + Path.Combine("Lora", "SD 1.5", "add_detail.safetensors"), + "add_detail.safetensors", + authorUsername: "creator", + baseModel: "SD 1.5" + ); + var checkpointModel = CreateModelFile( + tempRoot, + Path.Combine("StableDiffusion", "some_model.safetensors"), + "some_model.safetensors", + authorUsername: "creator", + baseModel: "SDXL" + ); + + var plan = service.BuildPlan( + [loraModel, checkpointModel], + tempRoot, + scopePath, + includeNested: true, + template: "{base_model}/{file_name}" + ); + + var loraItem = plan.Items.First(i => i.SourcePath.Contains("Lora")); + var checkpointItem = plan.Items.First(i => i.SourcePath.Contains("StableDiffusion")); + + // Lora should stay within the Lora type folder + Assert.IsTrue( + loraItem.TargetPath!.Contains(Path.Combine("Lora", "SD 1.5")), + $"Lora target should stay in Lora folder, got: {loraItem.TargetPath}" + ); + + // Checkpoint should stay within the StableDiffusion type folder + Assert.IsTrue( + checkpointItem.TargetPath!.Contains(Path.Combine("StableDiffusion", "SDXL")), + $"Checkpoint target should stay in StableDiffusion folder, got: {checkpointItem.TargetPath}" + ); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public void BuildPlan_PreservesDotsInFileName() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "Checkpoints"); + var model = CreateModelFile( + tempRoot, + Path.Combine("Checkpoints", "local-file.safetensors"), + "wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors" + ); + + var plan = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "{file_name}" + ); + + Assert.AreEqual(1, plan.Items.Count); + Assert.AreEqual(1, plan.ReadyCount); + Assert.AreEqual( + Path.Combine(scopePath, "wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"), + plan.Items[0].TargetPath + ); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public void BuildPlan_PreservesSelectedNestedScopeForMultiSegmentTemplates() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "StableDiffusion", "Favorites"); + var model = CreateModelFile( + tempRoot, + Path.Combine("StableDiffusion", "Favorites", "local-file.safetensors"), + "remote-file.safetensors", + baseModel: "SDXL" + ); + + var plan = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "{base_model}/{file_name}" + ); + + Assert.AreEqual(1, plan.Items.Count); + Assert.AreEqual( + Path.Combine(scopePath, "SDXL", "remote-file.safetensors"), + plan.Items[0].TargetPath + ); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public async Task ApplyPlan_RollsBackCompletedMovesWhenLaterMoveFails() + { + var tempRoot = CreateTempDirectory(); + + try + { + var scopePath = Path.Combine(tempRoot, "Checkpoints"); + var relativePath = Path.Combine("Checkpoints", "Source", "local-file.safetensors"); + var model = CreateModelFile(tempRoot, relativePath, "remote-file.safetensors"); + var sourcePath = Path.Combine(tempRoot, relativePath); + var sourceDirectory = Path.GetDirectoryName(sourcePath)!; + var cmInfoPath = Path.Combine(sourceDirectory, "local-file.cm-info.json"); + var previewPath = Path.Combine(sourceDirectory, "local-file.preview.png"); + + await File.WriteAllTextAsync(cmInfoPath, "{}"); + await File.WriteAllTextAsync(previewPath, "preview"); + + var plan = service.BuildPlan( + [model], + tempRoot, + scopePath, + includeNested: true, + template: "organized/{file_name}" + ); + + File.Delete(previewPath); + + var result = await service.ApplyPlan(plan); + + var targetDirectory = Path.Combine(scopePath, "organized"); + Assert.AreEqual(0, result.MovedCount); + Assert.AreEqual(1, result.SkippedCount); + Assert.AreEqual(1, result.Errors.Count); + Assert.IsTrue(File.Exists(sourcePath)); + Assert.IsTrue(File.Exists(cmInfoPath)); + Assert.IsFalse(File.Exists(Path.Combine(targetDirectory, "remote-file.safetensors"))); + Assert.IsFalse(File.Exists(Path.Combine(targetDirectory, "remote-file.cm-info.json"))); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + private static LocalModelFile CreateModelFile( + string root, + string relativePath, + string remoteFileName, + string? authorUsername = null, + string? baseModel = null + ) + { + var fullPath = Path.Combine(root, relativePath); + Directory.CreateDirectory(Path.GetDirectoryName(fullPath)!); + File.WriteAllText(fullPath, "model"); + + return new LocalModelFile + { + RelativePath = relativePath, + SharedFolderType = SharedFolderType.StableDiffusion, + ConnectedModelInfo = new ConnectedModelInfo + { + ModelId = 123, + ModelName = "Remote Model", + VersionId = 456, + VersionName = "Version One", + ModelType = CivitModelType.Checkpoint, + Hashes = new CivitFileHashes { BLAKE3 = "hash" }, + AuthorUsername = authorUsername, + BaseModel = baseModel, + RemoteFileName = remoteFileName, + RemoteFileId = 321, + Source = ConnectedModelSource.Civitai, + }, + }; + } + + private static string CreateTempDirectory() + { + var path = Path.Combine(Path.GetTempPath(), $"sm-organizer-tests-{Guid.NewGuid():N}"); + Directory.CreateDirectory(path); + return path; + } +} diff --git a/StabilityMatrix.Tests/Avalonia/OrganizeModelsDialogViewModelTests.cs b/StabilityMatrix.Tests/Avalonia/OrganizeModelsDialogViewModelTests.cs new file mode 100644 index 000000000..1a9eb46b9 --- /dev/null +++ b/StabilityMatrix.Tests/Avalonia/OrganizeModelsDialogViewModelTests.cs @@ -0,0 +1,39 @@ +using NSubstitute; +using StabilityMatrix.Avalonia.Models.CheckpointOrganizer; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Tests.Avalonia; + +[TestClass] +public class OrganizeModelsDialogViewModelTests +{ + [TestMethod] + public void ScanForMetadataCommand_RequestsMissingMetadataScan() + { + var viewModel = CreateViewModel(); + + viewModel.ScanForMetadataCommand.Execute(null); + + Assert.AreEqual(ModelOrganizationMetadataAction.ScanMissing, viewModel.RequestedMetadataAction); + } + + [TestMethod] + public void UpdateMetadataCommand_RequestsUpdateExistingMetadata() + { + var viewModel = CreateViewModel(); + + viewModel.UpdateMetadataCommand.Execute(null); + + Assert.AreEqual(ModelOrganizationMetadataAction.UpdateExisting, viewModel.RequestedMetadataAction); + } + + private static OrganizeModelsDialogViewModel CreateViewModel() + { + return new OrganizeModelsDialogViewModel( + Substitute.For(), + new ModelOrganizationService() + ); + } +} diff --git a/StabilityMatrix.Tests/Core/CivArchiveApiClientTests.cs b/StabilityMatrix.Tests/Core/CivArchiveApiClientTests.cs new file mode 100644 index 000000000..183af5b37 --- /dev/null +++ b/StabilityMatrix.Tests/Core/CivArchiveApiClientTests.cs @@ -0,0 +1,357 @@ +using System.Net; +using Microsoft.Extensions.Logging.Abstractions; +using NSubstitute; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models.Api.CivArchive; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class CivArchiveApiClientTests +{ + [TestMethod] + public void BuildSearchDataPath_UsesDefaultFilterValues() + { + var result = CivArchiveApiClient.BuildSearchDataPath("/top-models", new CivArchiveSearchFilters()); + + Assert.AreEqual( + "/top-models.json?platform=all&sort=top&rating=safe&platform_status=all&kind=all&period=all&page=1", + result + ); + } + + [TestMethod] + public void BuildSearchDataPath_SerializesMultiSelectFilters() + { + var result = CivArchiveApiClient.BuildSearchDataPath( + "/top-models", + new CivArchiveSearchFilters + { + Types = ["LORA", "Checkpoint"], + BaseModels = ["Illustrious", "Pony"], + Page = 2, + } + ); + + StringAssert.Contains(result, "type=LORA%2CCheckpoint"); + StringAssert.Contains(result, "base_model=Illustrious%2CPony"); + StringAssert.Contains(result, "page=2"); + } + + [TestMethod] + public void BuildDetailDataPath_RewritesModelScopeCnPlatformRoutes() + { + var result = CivArchiveApiClient.BuildDetailDataPath("/modelscope_cn/models/123/versions/456"); + + Assert.AreEqual("/models/123.json?modelVersionId=456&platform=modelscope_cn", result); + } + + [TestMethod] + public async Task SearchAsync_RefreshesBuildIdAfter404() + { + var requests = new List(); + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + new HttpResponseMessage(HttpStatusCode.NotFound), + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse( + ListResponseJson( + """{"id":"v1","name":"Model","kind":"version","url":"/models/1?modelVersionId=2"}""" + ) + ), + ] + ); + + var client = CreateClient( + new RecordingHandler( + (request, _) => + { + requests.Add(request.RequestUri!.ToString()); + return responses.Dequeue(); + } + ) + ); + + var response = await client.SearchAsync(new CivArchiveSearchFilters()); + + Assert.AreEqual(1, response.Results.Count); + Assert.IsTrue(requests.Any(x => x.Contains("/_next/data/old-build/"))); + Assert.IsTrue(requests.Any(x => x.Contains("/_next/data/new-build/"))); + } + + [TestMethod] + public async Task SearchAsync_ParsesVersionAndFileResults() + { + var listJson = ListResponseJson( + """ + {"id":"v2581228","name":"CyberRealistic Pony v16.0","type":"Checkpoint","kind":"version","download_count":59326,"url":"/models/443821?modelVersionId=2581228","base_model":"Pony","image_url":"https://example.org/image.jpg","created_at":1767967595,"username":"Cyberdelia","platform":"civitai"}, + {"id":"rf_hash","name":"realDream_14Hyper.safetensors","kind":"file","download_count":0,"url":"/sha256/a00019e86d53aece9858347e4df8a774a6d2933c30d0691faa9beb0cc56e7366","username":"Carlos2312","platform":"huggingface","created_at":1771938405} + """ + ); + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(listJson), + ] + ); + + var client = CreateClient(new RecordingHandler((_, _) => responses.Dequeue())); + var response = await client.SearchAsync(new CivArchiveSearchFilters()); + + Assert.AreEqual(2, response.Results.Count); + Assert.AreEqual(CivArchiveKindOption.Version, response.Results[0].Kind); + Assert.AreEqual(CivArchiveKindOption.File, response.Results[1].Kind); + Assert.AreEqual( + "a00019e86d53aece9858347e4df8a774a6d2933c30d0691faa9beb0cc56e7366", + response.Results[1].Sha256FromUrl + ); + } + + [TestMethod] + public async Task SearchAsync_CachesIdenticalFilterRequestsWithinTtl() + { + // Two SearchAsync calls with identical filters should produce exactly one + // /_next/data/.../top-models.json fetch β€” repeated requests within the TTL hit + // the in-memory cache, which is the main 429 mitigation when users toggle + // filters back and forth. + var dataRequestCount = 0; + var listJson = ListResponseJson( + """{"id":"v1","name":"Cached","kind":"version","url":"/models/1?modelVersionId=2"}""" + ); + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(listJson), + ] + ); + + var client = CreateClient( + new RecordingHandler( + (request, _) => + { + if (request.RequestUri!.AbsolutePath.Contains("top-models.json")) + { + dataRequestCount++; + } + return responses.Dequeue(); + } + ) + ); + + var filters = new CivArchiveSearchFilters { Platform = CivArchivePlatformOption.Civitai }; + + var first = await client.SearchAsync(filters); + var second = await client.SearchAsync(filters); + + Assert.AreEqual(1, dataRequestCount, "Second identical request should have been served from cache"); + Assert.AreSame(first, second, "Cache should return the same response instance"); + } + + [TestMethod] + public async Task GetFilterOptionsAsync_UsesParameterlessRouteAndParsesLists() + { + // CivArchive only echoes the populated baseModels / modelTypes lists when the URL + // has no query string at all β€” we hit /top-models.json directly, no filter params. + var requestedUrls = new List(); + const string filterOptionsJson = + "{\"pageProps\":{\"filterOptions\":{" + + "\"baseModels\":[\"Flux.1 D\",\"Illustrious\",\"SDXL 1.0\"]," + + "\"modelTypes\":[\"Checkpoint\",\"LORA\",\"VAE\"]}}}"; + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(filterOptionsJson), + ] + ); + + var client = CreateClient( + new RecordingHandler( + (request, _) => + { + requestedUrls.Add(request.RequestUri!.ToString()); + return responses.Dequeue(); + } + ) + ); + + var options = await client.GetFilterOptionsAsync(); + + CollectionAssert.AreEqual( + new[] { "Flux.1 D", "Illustrious", "SDXL 1.0" }, + options.BaseModels.ToArray() + ); + CollectionAssert.AreEqual(new[] { "Checkpoint", "LORA", "VAE" }, options.ModelTypes.ToArray()); + + // Critical: no query string. If any param leaks in, the API silently returns + // empty arrays and the dropdowns end up blank. + Assert.IsTrue( + requestedUrls.Any(u => u.EndsWith("/_next/data/test-build/top-models.json")), + $"Expected parameterless top-models.json fetch, got: {string.Join(", ", requestedUrls)}" + ); + } + + [TestMethod] + public async Task SearchAsync_ParsesArrayShapedPlatformEcho() + { + // Regression: when the user picks a non-default platform, CivArchive's + // pageProps.filters.platform comes back as an array (["civitai"]) rather than + // a bare string, which used to throw a JsonException at $.pageProps.filters.platform. + const string listJson = + "{\"pageProps\":{\"canonicalUrl\":\"https://civarchive.com/top-models\",\"data\":{\"results\":[],\"hits\":0,\"totalHits\":0}," + + "\"filters\":{\"q\":\"\",\"type\":\"all\",\"base_model\":\"all\",\"platform\":[\"civitai\"]," + + "\"sort\":\"top\",\"rating\":\"safe\",\"platform_status\":\"all\",\"kind\":\"all\"," + + "\"tags\":\"\",\"username\":\"\",\"period\":\"all\",\"page\":1}," + + "\"filterOptions\":{\"baseModels\":[],\"modelTypes\":[]}}}"; + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(listJson), + ] + ); + + var client = CreateClient(new RecordingHandler((_, _) => responses.Dequeue())); + var response = await client.SearchAsync( + new CivArchiveSearchFilters { Platform = CivArchivePlatformOption.Civitai } + ); + + Assert.AreEqual(CivArchivePlatformOption.Civitai, response.EffectiveFilters.Platform); + } + + [TestMethod] + public async Task GetModelDetailsAsync_ParsesFilesMirrorsAndSha256() + { + // Field naming matches the real CivArchive API (snake_case throughout) + const string detailJson = """ + {"pageProps":{"model":{"id":153568,"name":"Real Dream","type":"Checkpoint","download_count":4923593,"favorite_count":12,"rating":4.5,"rating_count":8,"created_at":"2026-04-17T23:30:06Z","creator_username":"sinatra","platform":"civitai","platform_name":"CivitAI","version":{"id":2053273,"name":"SDXL 7","base_model":"SDXL 1.0","description":"

Version description

","download_count":12345,"created_at":"2026-04-17T23:30:06Z","files":[{"id":1950275,"name":"realDream_sdxl7.safetensors","type":"Model","size_kb":6775783.6,"download_url":"https://civitai.com/api/download/models/2053273","sha256":"63b1db60611f52c4fbb2cade67dbdf4029c6620c5b22f2a4ddb27a47d7601953","is_primary":true,"created_at":"2026-04-17T23:30:06Z","mirrors":[{"filename":"realDream_sdxl7.safetensors","url":"https://civitai.com/api/download/models/2053273","source":"civitai","is_gated":false,"is_paid":false}]}],"images":[{"id":1,"url":"https://example.org/image.webp","link":"https://example.org/image.webp","type":"image"}],"mirrors":[{"platform":"tungsten","platform_url":"https://tungsten.run/model/kZ7yDBQjZP?model_version=L2KrgferKS","version_name":"SDXL 7"}]}}}} + """; + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(detailJson), + ] + ); + + var client = CreateClient(new RecordingHandler((_, _) => responses.Dequeue())); + var response = await client.GetModelDetailsAsync("/models/153568?modelVersionId=2053273"); + + Assert.AreEqual("Real Dream", response.Model.Name); + Assert.AreEqual("SDXL 7", response.Model.Version?.Name); + Assert.AreEqual(1, response.Model.Version?.Files.Count); + Assert.AreEqual( + "63b1db60611f52c4fbb2cade67dbdf4029c6620c5b22f2a4ddb27a47d7601953", + response.Model.Version?.Files[0].Sha256 + ); + Assert.AreEqual(1, response.Model.Version?.Mirrors.Count); + + // Snake-case field names β€” these all silently defaulted to 0/null when the DTO + // mapped them to camelCase (downloadCount/baseModel/sizeKB/etc.). + Assert.AreEqual(4923593, response.Model.DownloadCount); + Assert.AreEqual(12, response.Model.FavoriteCount); + Assert.AreEqual(4.5, response.Model.Rating); + Assert.AreEqual(8, response.Model.RatingCount); + Assert.IsNotNull(response.Model.CreatedAt); + Assert.AreEqual("SDXL 1.0", response.Model.Version?.BaseModel); + Assert.AreEqual(12345, response.Model.Version?.DownloadCount); + Assert.AreEqual(6775783.6, response.Model.Version?.Files[0].SizeKb); + Assert.AreEqual( + "https://civitai.com/api/download/models/2053273", + response.Model.Version?.Files[0].DownloadUrl + ); + Assert.IsTrue(response.Model.Version?.Files[0].IsPrimary); + } + + [TestMethod] + public async Task ResolveFileUrlAsync_ReturnsLinkedVersionHref() + { + // /sha256/{hash} returns pageProps.models[] (plural) with full model data inside, + // including version.href β€” which is the canonical URL we want to navigate to. + const string sha256Json = """ + {"pageProps":{"id":"file-1","models":[{"id":878387,"name":"Stable Diffusion 3.5 Large","type":"Checkpoint","versions":[{"id":983602,"name":"Workflow","href":"/models/878387?modelVersionId=983602"}],"version":{"id":983309,"name":"Large","base_model":"SD 3.5 Large","href":"/models/878387?modelVersionId=983309"}}]}} + """; + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(sha256Json), + ] + ); + + var client = CreateClient(new RecordingHandler((_, _) => responses.Dequeue())); + var resolved = await client.ResolveFileUrlAsync( + "/sha256/ffef7a279d9134626e6ce0d494fba84fc1c7e720b3c7df2d19a09dc3796d8f93" + ); + + // Prefer version.href (the version that actually contains this file) over versions[0].href. + Assert.AreEqual("/models/878387?modelVersionId=983309", resolved); + } + + [TestMethod] + public async Task ResolveFileUrlAsync_NoLinkedModel_ReturnsNull() + { + // Orphaned hash with no linked models β†’ should return null so the caller can + // fall back to opening the URL externally instead of navigating to a dead page. + const string sha256Json = """{"pageProps":{"id":"file-2","models":[]}}"""; + + var responses = new Queue( + [ + CreateJsonResponse("""""", "text/html"), + CreateJsonResponse(sha256Json), + ] + ); + + var client = CreateClient(new RecordingHandler((_, _) => responses.Dequeue())); + var resolved = await client.ResolveFileUrlAsync("/sha256/abc"); + + Assert.IsNull(resolved); + } + + private static ICivArchiveApiClient CreateClient(HttpMessageHandler handler) + { + var httpClientFactory = Substitute.For(); + httpClientFactory + .CreateClient() + .Returns(new HttpClient(handler) { BaseAddress = new Uri("https://civarchive.com") }); + + return new CivArchiveApiClient(NullLogger.Instance, httpClientFactory); + } + + private static HttpResponseMessage CreateJsonResponse( + string content, + string mediaType = "application/json" + ) + { + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(content) + { + Headers = { ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue(mediaType) }, + }, + }; + } + + private static string ListResponseJson(string resultsJson) + { + return "{\"pageProps\":{\"canonicalUrl\":\"https://civarchive.com/top-models\",\"data\":{\"results\":[" + + resultsJson + + "],\"hits\":2,\"totalHits\":2},\"filters\":{\"q\":\"\",\"type\":\"all\",\"base_model\":\"all\",\"platform\":\"all\",\"sort\":\"top\",\"rating\":\"safe\",\"platform_status\":\"all\",\"kind\":\"all\",\"tags\":\"\",\"username\":\"\",\"period\":\"all\",\"page\":1},\"filterOptions\":{\"baseModels\":[\"Illustrious\",\"Pony\"],\"modelTypes\":[\"LORA\",\"Checkpoint\"]}}}"; + } + + private sealed class RecordingHandler( + Func responder + ) : HttpMessageHandler + { + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken + ) + { + return Task.FromResult(responder(request, cancellationToken)); + } + } +} diff --git a/StabilityMatrix.Tests/Core/MetadataImportServiceTests.cs b/StabilityMatrix.Tests/Core/MetadataImportServiceTests.cs new file mode 100644 index 000000000..bd3fe4737 --- /dev/null +++ b/StabilityMatrix.Tests/Core/MetadataImportServiceTests.cs @@ -0,0 +1,202 @@ +using Microsoft.Extensions.Logging; +using NSubstitute; +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Database; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Api; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class MetadataImportServiceTests +{ + [TestMethod] + public void FromJson_AllowsMissingOrganizerFields() + { + const string json = """ + { + "ModelName": "Sample", + "Hashes": { + "BLAKE3": "hash" + } + } + """; + + var result = ConnectedModelInfo.FromJson(json); + + Assert.IsNotNull(result); + Assert.AreEqual("Sample", result.ModelName); + Assert.IsNull(result.AuthorUsername); + Assert.IsNull(result.RemoteFileName); + Assert.IsNull(result.RemoteFileId); + } + + [TestMethod] + public async Task UpdateExistingMetadata_MergesUserFieldsAndBackfillsRemoteFields() + { + var tempRoot = CreateTempDirectory(); + + try + { + var existingInfo = new ConnectedModelInfo + { + ModelName = "Old Name", + Hashes = new CivitFileHashes { BLAKE3 = "blake3-hash" }, + ImportedAt = new DateTimeOffset(2024, 1, 2, 3, 4, 5, TimeSpan.Zero), + UserTitle = "Pinned Name", + ThumbnailImageUrl = Path.Combine(tempRoot, "existing.preview.png"), + InferenceDefaults = new InferenceDefaults { Steps = 30, CfgScale = 7 }, + Source = ConnectedModelSource.Civitai, + }; + await existingInfo.SaveJsonToDirectory(tempRoot, "model"); + + var service = CreateMetadataImportService(); + ConfigureCivitLookup(service.Api, "blake3-hash"); + + await service.Instance.UpdateExistingMetadata(new DirectoryPath(tempRoot)); + + var updated = ConnectedModelInfo.FromJson( + await File.ReadAllTextAsync(Path.Combine(tempRoot, "model.cm-info.json")) + ); + + Assert.IsNotNull(updated); + Assert.AreEqual("Remote Model", updated.ModelName); + Assert.AreEqual("Pinned Name", updated.UserTitle); + Assert.AreEqual(existingInfo.ThumbnailImageUrl, updated.ThumbnailImageUrl); + Assert.AreEqual(existingInfo.ImportedAt, updated.ImportedAt); + Assert.IsNotNull(updated.InferenceDefaults); + Assert.AreEqual(30, updated.InferenceDefaults.Steps); + Assert.AreEqual("creator-name", updated.AuthorUsername); + Assert.AreEqual("remote-file.safetensors", updated.RemoteFileName); + Assert.AreEqual(321, updated.RemoteFileId); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + [TestMethod] + public async Task GetMetadataForFile_ForceReimport_PersistsDownloadedThumbnailPath() + { + var tempRoot = CreateTempDirectory(); + + try + { + var modelPath = Path.Combine(tempRoot, "model.safetensors"); + await File.WriteAllTextAsync(modelPath, "small model file"); + + var hash = await FileHash.GetBlake3Async(new FilePath(modelPath)); + var service = CreateMetadataImportService(); + ConfigureCivitLookup(service.Api, hash, includeImage: true); + service + .DownloadService.DownloadToFileAsync( + Arg.Any(), + Arg.Any(), + Arg.Any?>(), + Arg.Any(), + Arg.Any() + ) + .Returns(callInfo => + { + File.WriteAllText(callInfo.ArgAt(1), "preview"); + return Task.CompletedTask; + }); + + var result = await service.Instance.GetMetadataForFile( + new FilePath(modelPath), + forceReimport: true + ); + + Assert.IsNotNull(result); + var expectedPreviewPath = Path.Combine(tempRoot, "model.preview.png"); + Assert.AreEqual(expectedPreviewPath, result.ThumbnailImageUrl); + Assert.IsTrue(File.Exists(expectedPreviewPath)); + + var saved = ConnectedModelInfo.FromJson( + await File.ReadAllTextAsync(Path.Combine(tempRoot, "model.cm-info.json")) + ); + Assert.IsNotNull(saved); + Assert.AreEqual(expectedPreviewPath, saved.ThumbnailImageUrl); + } + finally + { + Directory.Delete(tempRoot, true); + } + } + + private static string CreateTempDirectory() + { + var path = Path.Combine(Path.GetTempPath(), $"sm-metadata-tests-{Guid.NewGuid():N}"); + Directory.CreateDirectory(path); + return path; + } + + private static void ConfigureCivitLookup(ICivitApi api, string hash, bool includeImage = false) + { + var file = new CivitFile + { + Id = 321, + Name = "remote-file.safetensors", + DownloadUrl = "https://example.invalid/model", + Type = CivitFileType.Model, + Metadata = new CivitFileMetadata { Size = "pruned" }, + Hashes = new CivitFileHashes { BLAKE3 = hash, SHA256 = "sha256" }, + }; + var version = new CivitModelVersion + { + Id = 456, + Name = "Version One", + BaseModel = "SDXL", + Files = [file], + Images = includeImage + ? [new CivitImage { Url = "https://example.invalid/preview.png", Type = "image" }] + : [], + TrainedWords = ["tag-a"], + }; + var model = new CivitModel + { + Id = 123, + Name = "Remote Model", + Type = CivitModelType.Checkpoint, + Tags = ["tag-a"], + Creator = new CivitCreator { Username = "creator-name" }, + ModelVersions = [version], + Stats = new CivitModelStats(), + }; + + api.GetModelVersionByHash(hash) + .Returns( + Task.FromResult( + new CivitModelVersionResponse( + version.Id, + model.Id, + version.Name, + version.BaseModel!, + [file], + version.Images ?? [], + file.DownloadUrl + ) + ) + ); + api.GetModelById(model.Id).Returns(Task.FromResult(model)); + } + + private static ( + MetadataImportService Instance, + ICivitApi Api, + IDownloadService DownloadService + ) CreateMetadataImportService() + { + var api = Substitute.For(); + var db = Substitute.For(); + var logger = Substitute.For>(); + var downloadService = Substitute.For(); + var finder = new ModelFinder(db, api); + return (new MetadataImportService(logger, downloadService, finder), api, downloadService); + } +} diff --git a/StabilityMatrix.Tests/Core/NotificationHistoryServiceTests.cs b/StabilityMatrix.Tests/Core/NotificationHistoryServiceTests.cs new file mode 100644 index 000000000..cfdf710f1 --- /dev/null +++ b/StabilityMatrix.Tests/Core/NotificationHistoryServiceTests.cs @@ -0,0 +1,197 @@ +using StabilityMatrix.Core.Models.Notifications; +using StabilityMatrix.Core.Models.Settings; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class NotificationHistoryServiceTests +{ + private static NotificationHistoryEntry MakeEntry(string title = "t", bool read = false) => + new() + { + Title = title, + Body = "b", + Level = NotificationLevel.Information, + IsRead = read, + }; + + [TestMethod] + public void Add_StoresEntriesNewestFirst() + { + var svc = new NotificationHistoryService(); + var a = svc.Add(MakeEntry("first")); + var b = svc.Add(MakeEntry("second")); + var c = svc.Add(MakeEntry("third")); + + var entries = svc.Entries.ToList(); + Assert.AreEqual(3, entries.Count); + Assert.AreEqual(c.Id, entries[0].Id); + Assert.AreEqual(b.Id, entries[1].Id); + Assert.AreEqual(a.Id, entries[2].Id); + } + + [TestMethod] + public void Add_RaisesEntryAddedEvent() + { + var svc = new NotificationHistoryService(); + NotificationHistoryEntry? raised = null; + svc.EntryAdded += (_, e) => raised = e; + + var entry = svc.Add(MakeEntry()); + + Assert.IsNotNull(raised); + Assert.AreEqual(entry.Id, raised!.Id); + } + + [TestMethod] + public void Add_EnforcesMaxEntriesCap() + { + var svc = new NotificationHistoryService(); + for (var i = 0; i < 150; i++) + { + svc.Add(MakeEntry($"e{i}")); + } + + Assert.AreEqual(100, svc.Entries.Count); + // Newest preserved + Assert.AreEqual("e149", svc.Entries.First().Title); + // Oldest evicted + Assert.AreEqual("e50", svc.Entries.Last().Title); + } + + [TestMethod] + public void UnreadCount_TracksUnreadEntries() + { + var svc = new NotificationHistoryService(); + svc.Add(MakeEntry()); + svc.Add(MakeEntry()); + svc.Add(MakeEntry(read: true)); + + Assert.AreEqual(2, svc.UnreadCount); + } + + [TestMethod] + public void MarkRead_FlipsEntryAndRaisesChange() + { + var svc = new NotificationHistoryService(); + var entry = svc.Add(MakeEntry()); + var raised = 0; + svc.EntriesChanged += (_, _) => raised++; + + svc.MarkRead(entry.Id); + + Assert.IsTrue(svc.Find(entry.Id)!.IsRead); + Assert.AreEqual(0, svc.UnreadCount); + Assert.AreEqual(1, raised); + } + + [TestMethod] + public void MarkRead_NoOpForAlreadyReadDoesNotRaise() + { + var svc = new NotificationHistoryService(); + var entry = svc.Add(MakeEntry(read: true)); + var raised = 0; + svc.EntriesChanged += (_, _) => raised++; + + svc.MarkRead(entry.Id); + + Assert.AreEqual(0, raised); + } + + [TestMethod] + public void MarkAllRead_FlipsEveryUnreadEntry() + { + var svc = new NotificationHistoryService(); + svc.Add(MakeEntry()); + svc.Add(MakeEntry()); + svc.Add(MakeEntry(read: true)); + + svc.MarkAllRead(); + + Assert.AreEqual(0, svc.UnreadCount); + } + + [TestMethod] + public void Remove_DropsEntryAndRaisesChange() + { + var svc = new NotificationHistoryService(); + var a = svc.Add(MakeEntry("a")); + var b = svc.Add(MakeEntry("b")); + var raised = 0; + svc.EntriesChanged += (_, _) => raised++; + + svc.Remove(a.Id); + + var remaining = svc.Entries.ToList(); + Assert.AreEqual(1, remaining.Count); + Assert.AreEqual(b.Id, remaining[0].Id); + Assert.AreEqual(1, raised); + } + + [TestMethod] + public void Clear_EmptiesEntriesAndRaisesChangeWhenNonEmpty() + { + var svc = new NotificationHistoryService(); + svc.Add(MakeEntry()); + svc.Add(MakeEntry()); + var raised = 0; + svc.EntriesChanged += (_, _) => raised++; + + svc.Clear(); + + Assert.AreEqual(0, svc.Entries.Count); + Assert.AreEqual(1, raised); + } + + [TestMethod] + public void Clear_OnEmptyDoesNotRaise() + { + var svc = new NotificationHistoryService(); + var raised = 0; + svc.EntriesChanged += (_, _) => raised++; + + svc.Clear(); + + Assert.AreEqual(0, raised); + } + + [TestMethod] + public void Find_ReturnsEntryById() + { + var svc = new NotificationHistoryService(); + var entry = svc.Add(MakeEntry("findme")); + + var found = svc.Find(entry.Id); + + Assert.IsNotNull(found); + Assert.AreEqual("findme", found!.Title); + } + + [TestMethod] + public void Find_ReturnsNullWhenMissing() + { + var svc = new NotificationHistoryService(); + + Assert.IsNull(svc.Find(Guid.NewGuid())); + } + + [TestMethod] + public void Entry_CarriesNotificationAction() + { + var svc = new NotificationHistoryService(); + var action = new OpenFolderAction(@"C:\some\path"); + + var entry = svc.Add( + new NotificationHistoryEntry + { + Title = "t", + Action = action, + Level = NotificationLevel.Success, + } + ); + + Assert.IsInstanceOfType(svc.Find(entry.Id)!.Action); + Assert.AreEqual(@"C:\some\path", ((OpenFolderAction)svc.Find(entry.Id)!.Action!).Path); + } +} diff --git a/StabilityMatrix.UITests/Fakes/TestCivArchiveApiClient.cs b/StabilityMatrix.UITests/Fakes/TestCivArchiveApiClient.cs new file mode 100644 index 000000000..dff71bffd --- /dev/null +++ b/StabilityMatrix.UITests/Fakes/TestCivArchiveApiClient.cs @@ -0,0 +1,162 @@ +using StabilityMatrix.Core.Api; +using StabilityMatrix.Core.Models.Api.CivArchive; + +namespace StabilityMatrix.UITests.Fakes; + +public class TestCivArchiveApiClient : ICivArchiveApiClient +{ + public Task GetBuildIdAsync(CancellationToken cancellationToken = default) + { + return Task.FromResult("test-build"); + } + + public Task SearchAsync( + CivArchiveSearchFilters filters, + CancellationToken cancellationToken = default + ) + { + var page = filters.Page; + IReadOnlyList results = page switch + { + 1 => new List + { + new CivArchiveSearchResult + { + Id = "v-1", + Name = "Test Version 1", + KindRaw = "version", + Type = "Checkpoint", + BaseModel = "Pony", + Url = "/models/1?modelVersionId=11", + Username = "artist-one", + Platform = "civitai", + }, + new CivArchiveSearchResult + { + Id = "f-1", + Name = "test-file-1.safetensors", + KindRaw = "file", + Url = "/sha256/abc123", + Username = "artist-one", + Platform = "huggingface", + }, + }, + 2 => new List + { + new CivArchiveSearchResult + { + Id = "v-2", + Name = "Test Version 2", + KindRaw = "version", + Type = "LORA", + BaseModel = "Illustrious", + Url = "/models/2?modelVersionId=22", + Username = "artist-two", + Platform = "seaart", + }, + new CivArchiveSearchResult + { + Id = "u-1", + Name = "artist-two", + KindRaw = "user", + Url = "/users/artist-two", + Username = "artist-two", + Platform = "seaart", + }, + }, + _ => [], + }; + + return Task.FromResult( + new CivArchiveSearchResponse + { + Results = results, + FilterOptions = new CivArchiveFilterOptions + { + BaseModels = ["Illustrious", "Pony"], + ModelTypes = ["LORA", "Checkpoint"], + }, + EffectiveFilters = new CivArchiveSearchFilters + { + Page = page, + Platform = filters.Platform, + Sort = filters.Sort, + Period = filters.Period, + Rating = filters.Rating, + PlatformStatus = filters.PlatformStatus, + Kind = filters.Kind, + Query = filters.Query, + Tags = filters.Tags, + Username = filters.Username, + }, + CanonicalUrl = "https://civarchive.com/top-models", + Hits = results.Count, + TotalHits = 4, + } + ); + } + + public Task GetFilterOptionsAsync(CancellationToken cancellationToken = default) + { + return Task.FromResult( + new CivArchiveFilterOptions + { + BaseModels = ["Illustrious", "Pony"], + ModelTypes = ["LORA", "Checkpoint"], + } + ); + } + + public Task GetModelDetailsAsync( + string relativeUrl, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult( + new CivArchiveModelDetailsResponse + { + Model = new CivArchiveModelDetails + { + Name = "Test Model", + PlatformName = "CivitAI", + Type = "Checkpoint", + Version = new CivArchiveModelVersion + { + Name = "Test Version", + BaseModel = "Pony", + Files = + [ + new CivArchiveModelFile + { + Name = "test.safetensors", + Sha256 = "abc123", + Mirrors = + [ + new CivArchiveFileMirror + { + Source = "civitai", + Url = "https://example.org/file", + }, + ], + }, + ], + }, + }, + } + ); + } + + public Task ResolveFileUrlAsync( + string sha256RelativeUrl, + CancellationToken cancellationToken = default + ) + { + // Map our seeded test file to the matching test version. Anything else returns null + // so the caller falls back to opening the URL externally. + return Task.FromResult( + sha256RelativeUrl == "/sha256/abc123" ? "/models/1?modelVersionId=11" : null + ); + } + + public Uri GetAbsoluteUri(string relativeUrl) => new($"https://civarchive.com{relativeUrl}"); +} diff --git a/StabilityMatrix.UITests/ModelBrowser/CivArchiveBrowserTests.cs b/StabilityMatrix.UITests/ModelBrowser/CivArchiveBrowserTests.cs new file mode 100644 index 000000000..b6eaa5ce1 --- /dev/null +++ b/StabilityMatrix.UITests/ModelBrowser/CivArchiveBrowserTests.cs @@ -0,0 +1,49 @@ +using Avalonia.Controls; +using Avalonia.Threading; +using FluentAvalonia.UI.Controls; +using Microsoft.Extensions.DependencyInjection; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels; +using StabilityMatrix.Avalonia.ViewModels.CheckpointBrowser; +using StabilityMatrix.Avalonia.Views; + +namespace StabilityMatrix.UITests.ModelBrowser; + +[Collection("TempDir")] +public class CivArchiveBrowserTests : TestBase +{ + [AvaloniaFact] + public async Task CivArchiveTab_ShouldLoadAndAppendResults() + { + var (window, _) = GetMainWindow(); + await DoInitialSetup(); + + var navigationService = Services.GetRequiredService>(); + navigationService.NavigateTo(); + + var frame = window.FindControl("FrameView"); + var page = Assert.IsType( + await WaitHelper.WaitForConditionAsync( + () => frame!.Content, + content => content is CheckpointBrowserPage + ) + ); + var vm = Assert.IsType(page.DataContext); + + var civArchiveTab = vm.Pages.First(item => Equals(item.Header, "CivArchive")); + vm.SelectedPage = civArchiveTab; + var civArchiveVm = Assert.IsType(civArchiveTab.Content); + civArchiveVm.OnLoaded(); + + await WaitHelper.WaitForConditionAsync(() => civArchiveVm.Results.Count, count => count == 2); + Dispatcher.UIThread.RunJobs(); + + Assert.Equal(2, civArchiveVm.Results.Count); + + await civArchiveVm.LoadNextPageAsync(); + await Task.Delay(100); + Dispatcher.UIThread.RunJobs(); + + Assert.Equal(4, civArchiveVm.Results.Count); + } +} diff --git a/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindowViewModel_ShouldOk.received.txt b/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindowViewModel_ShouldOk.received.txt new file mode 100644 index 000000000..4ab94819f --- /dev/null +++ b/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindowViewModel_ShouldOk.received.txt @@ -0,0 +1,25 @@ +ο»Ώ{ + Greeting: Welcome to Avalonia!, + ProgressManagerViewModel: { + Title: Download Manager, + IconSource: { + Type: SymbolIconSource + }, + IsOpen: false, + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + UpdateViewModel: { + IsUpdateAvailable: false, + IsProgressIndeterminate: false, + ShowProgressBar: false, + InstallUpdateCommand: UpdateViewModel.InstallUpdate(), + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + PaneWidth: 200.0, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false +} \ No newline at end of file diff --git a/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindow_ShouldOpen.received.png b/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindow_ShouldOpen.received.png new file mode 100644 index 000000000..7f017510b Binary files /dev/null and b/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindow_ShouldOpen.received.png differ diff --git a/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindow_ShouldOpen.received.txt b/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindow_ShouldOpen.received.txt new file mode 100644 index 000000000..cce8d4ea4 --- /dev/null +++ b/StabilityMatrix.UITests/Snapshots/MainWindowTests.MainWindow_ShouldOpen.received.txt @@ -0,0 +1,6020 @@ +ο»Ώ{ + Type: MainWindow, + Title: Stability Matrix, + Icon: {}, + TransparencyLevelHint: [ + {}, + {}, + {} + ], + TransparencyBackgroundFallback: Transparent, + Content: { + Type: Grid, + Children: [ + { + Type: Grid, + Background: Transparent, + Height: 32.0, + Name: TitleBarHost, + Children: [ + { + Type: Image, + Source: { + Dpi: { + X: 96.0, + Y: 96.0, + Length: 135.7645019878171, + SquaredLength: 18432.0 + }, + Size: { + AspectRatio: 1.0, + Width: 256.0, + Height: 256.0 + }, + PixelSize: { + AspectRatio: 1.0, + Width: 256, + Height: 256 + }, + Format: { + BitsPerPixel: 32 + }, + AlphaFormat: Premul + }, + IsHitTestVisible: false, + Width: 18.0, + Height: 18.0, + Margin: 12,4,12,4, + IsVisible: true, + Name: WindowIcon + }, + { + Type: TextBlock, + FontSize: 12.0, + Text: Stability Matrix, + IsHitTestVisible: false, + VerticalAlignment: Center, + IsVisible: true + }, + { + Type: Border, + Padding: 6 + } + ] + }, + { + Type: NavigationView, + Content: { + Type: Frame, + Content: { + Type: PackageManagerPage, + Content: { + Type: Grid, + Children: [ + { + Type: BreadcrumbBar, + Margin: 16,8,16,8, + Name: BreadcrumbBar + }, + { + Type: Frame, + Content: { + Type: MainPackageManagerView, + Content: { + Type: Grid, + Margin: 16, + Children: [ + { + Type: ScrollViewer, + Content: { + Type: ItemsRepeater, + Name: PackageCardsRepeater + } + }, + { + Type: TeachingTip, + Name: TeachingTip1 + }, + { + Type: TeachingTip, + MinWidth: 100.0, + Margin: 8,0,0,0, + Name: LaunchTeachingTip + }, + { + Type: Button, + Command: {}, + Content: { + Type: StackPanel, + Orientation: Horizontal, + Margin: 8, + Children: [ + { + Type: SymbolIcon + }, + { + Type: TextBlock, + Text: Add Package, + Margin: 4,0,0,0 + } + ] + }, + Margin: 0,8,0,0, + HorizontalAlignment: Stretch, + VerticalAlignment: Bottom, + Name: AddPackagesButton + } + ] + }, + DataContext: { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } + }, + Name: FrameView + } + ] + }, + DataContext: { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + SubPages: [ + { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + { + InferencePackages: [ + { + Name: stable-diffusion-webui-forge, + DisplayName: Stable Diffusion WebUI Forge, + Author: lllyasviel, + Blurb: Stable Diffusion WebUI Forge is a platform on top of Stable Diffusion WebUI (based on Gradio) to make development easier, optimize resource management, and speed up inference., + LicenseUrl: https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/LICENSE.txt, + PreviewImageUri: https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/ca5e05ed-bd86-4ced-8662-f41034648e8c, + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: output, + InstallerSortOrder: ReallyRecommended, + SharedOutputFolders: { + Text2Img: [ + output/txt2img-images + ], + Img2Img: [ + output/img2img-images + ], + Extras: [ + output/extras-images + ], + Text2ImgGrids: [ + output/txt2img-grids + ], + Img2ImgGrids: [ + output/img2img-grids + ], + SVD: [ + output/svd + ], + Saved: [ + log/images + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Pin Shared Memory, + Options: [ + --pin-shared-memory + ] + }, + { + Name: CUDA Malloc, + Options: [ + --cuda-malloc + ] + }, + { + Name: CUDA Stream, + Options: [ + --cuda-stream + ] + }, + { + Name: Always Offload from VRAM, + Options: [ + --always-offload-from-vram + ] + }, + { + Name: Always GPU, + Options: [ + --always-gpu + ] + }, + { + Name: Always CPU, + Options: [ + --always-cpu + ] + }, + { + Name: Use DirectML, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Skip Torch CUDA Test, + InitialValue: false, + Options: [ + --skip-torch-cuda-test + ] + }, + { + Name: No half-precision VAE, + InitialValue: false, + Options: [ + --no-half-vae + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + LicenseType: AGPL-3.0, + LaunchCommand: launch.py, + RelativeArgsDefinitionScriptPath: modules.cmd_args, + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + GFPGAN: [ + models/GFPGAN + ], + Codeformer: [ + models/Codeformer + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + ApproxVAE: [ + models/VAE-approx + ], + Karlo: [ + models/karlo + ], + DeepDanbooru: [ + models/deepbooru + ], + TextualInversion: [ + embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet/ControlNet + ], + LDSR: [ + models/LDSR + ], + AfterDetailer: [ + models/adetailer + ], + IpAdapter: [ + models/controlnet/IpAdapter + ], + T2IAdapter: [ + models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + models/controlnet/DiffusersIpAdapters + ], + InvokeIpAdaptersXl: [ + models/controlnet/DiffusersIpAdaptersXL + ], + SVD: [ + models/svd + ] + }, + AvailableSharedFolderMethods: [ + Symlink, + None + ], + ExtraLaunchArguments: , + RepositoryName: stable-diffusion-webui-forge, + RepositoryAuthor: lllyasviel, + GithubUrl: https://github.com/lllyasviel/stable-diffusion-webui-forge, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\stable-diffusion-webui-forge.zip, + ByAuthor: By lllyasviel, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: stable-diffusion-webui, + DisplayName: Stable Diffusion WebUI, + Author: AUTOMATIC1111, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt, + Blurb: A browser interface based on Gradio library for Stable Diffusion, + LaunchCommand: launch.py, + PreviewImageUri: https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png, + RelativeArgsDefinitionScriptPath: modules.cmd_args, + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + GFPGAN: [ + models/GFPGAN + ], + Codeformer: [ + models/Codeformer + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + ApproxVAE: [ + models/VAE-approx + ], + Karlo: [ + models/karlo + ], + DeepDanbooru: [ + models/deepbooru + ], + TextualInversion: [ + embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet/ControlNet + ], + LDSR: [ + models/LDSR + ], + AfterDetailer: [ + models/adetailer + ], + IpAdapter: [ + models/controlnet/IpAdapter + ], + T2IAdapter: [ + models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + models/controlnet/DiffusersIpAdapters + ], + InvokeIpAdaptersXl: [ + models/controlnet/DiffusersIpAdaptersXL + ], + SVD: [ + models/svd + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs/txt2img-images + ], + Img2Img: [ + outputs/img2img-images + ], + Extras: [ + outputs/extras-images + ], + Text2ImgGrids: [ + outputs/txt2img-grids + ], + Img2ImgGrids: [ + outputs/img2img-grids + ], + SVD: [ + outputs/svd + ], + Saved: [ + log/images + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: VRAM, + Options: [ + --lowvram, + --medvram, + --medvram-sdxl + ] + }, + { + Name: Xformers, + InitialValue: true, + Options: [ + --xformers + ] + }, + { + Name: API, + InitialValue: true, + Options: [ + --api + ] + }, + { + Name: Auto Launch Web UI, + InitialValue: false, + Options: [ + --autolaunch + ] + }, + { + Name: Skip Torch CUDA Check, + InitialValue: false, + Options: [ + --skip-torch-cuda-test + ] + }, + { + Name: Skip Python Version Check, + InitialValue: true, + Options: [ + --skip-python-version-check + ] + }, + { + Name: No Half, + Description: Do not switch the model to 16-bit floats, + InitialValue: false, + Options: [ + --no-half + ] + }, + { + Name: Skip SD Model Download, + InitialValue: false, + Options: [ + --no-download-sd-model + ] + }, + { + Name: Skip Install, + Options: [ + --skip-install + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm, + Mps + ], + MainBranch: master, + OutputFolderName: outputs, + ExtensionManager: { + RelativeInstallDirectory: extensions, + DefaultManifests: [ + { + Uri: https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json + } + ] + }, + ExtraLaunchArguments: , + RepositoryName: stable-diffusion-webui, + RepositoryAuthor: AUTOMATIC1111, + GithubUrl: https://github.com/AUTOMATIC1111/stable-diffusion-webui, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\stable-diffusion-webui.zip, + ByAuthor: By AUTOMATIC1111, + Disclaimer: , + OfferInOneClickInstaller: true, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: true, + AvailableVersionTypes: GithubRelease, Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: ComfyUI, + DisplayName: ComfyUI, + Author: comfyanonymous, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE, + Blurb: A powerful and modular stable diffusion GUI and backend, + LaunchCommand: main.py, + PreviewImageUri: https://github.com/comfyanonymous/ComfyUI/raw/master/comfyui_screenshot.png, + ShouldIgnoreReleases: true, + IsInferenceCompatible: true, + OutputFolderName: output, + InstallerSortOrder: InferenceCompatible, + RecommendedSharedFolderMethod: Configuration, + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet/ControlNet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + IpAdapter: [ + models/ipadapter/base + ], + T2IAdapter: [ + models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + models/ipadapter/sd15 + ], + InvokeIpAdaptersXl: [ + models/ipadapter/sdxl + ], + InvokeClipVision: [ + models/clip_vision + ], + PromptExpansion: [ + models/prompt_expansion + ] + }, + SharedOutputFolders: { + Text2Img: [ + output + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: 127.0.0.1, + Options: [ + --listen + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 8188, + Options: [ + --port + ] + }, + { + Name: VRAM, + Options: [ + --highvram, + --normalvram, + --lowvram, + --novram + ] + }, + { + Name: Preview Method, + InitialValue: --preview-method auto, + Options: [ + --preview-method auto, + --preview-method latent2rgb, + --preview-method taesd + ] + }, + { + Name: Enable DirectML, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Use CPU only, + InitialValue: false, + Options: [ + --cpu + ] + }, + { + Name: Cross Attention Method, + Options: [ + --use-split-cross-attention, + --use-quad-cross-attention, + --use-pytorch-cross-attention + ] + }, + { + Name: Force Floating Point Precision, + Options: [ + --force-fp32, + --force-fp16 + ] + }, + { + Name: VAE Precision, + Options: [ + --fp16-vae, + --fp32-vae, + --bf16-vae + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Disable upcasting of attention, + Options: [ + --dont-upcast-attention + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: master, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + ExtensionManager: { + RelativeInstallDirectory: custom_nodes, + DefaultManifests: [ + { + Uri: https://cdn.jsdelivr.net/gh/ltdrdata/ComfyUI-Manager/custom-node-list.json + }, + { + Uri: https://cdn.jsdelivr.net/gh/LykosAI/ComfyUI-Extensions-Index/custom-node-list.json + } + ] + }, + RepositoryName: ComfyUI, + RepositoryAuthor: comfyanonymous, + GithubUrl: https://github.com/comfyanonymous/ComfyUI, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\ComfyUI.zip, + ByAuthor: By comfyanonymous, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsCompatible: true, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SupportsExtensions: true, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: Fooocus, + DisplayName: Fooocus, + Author: lllyasviel, + Blurb: Fooocus is a rethinking of Stable Diffusion and Midjourney’s designs, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/lllyasviel/Fooocus/blob/main/LICENSE, + LaunchCommand: launch.py, + PreviewImageUri: https://user-images.githubusercontent.com/19834515/261830306-f79c5981-cf80-4ee3-b06b-3fef3f8bfbc7.png, + LaunchOptions: [ + { + Name: Preset, + Options: [ + --preset anime, + --preset realistic + ] + }, + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Output Directory, + Type: String, + Description: Override the output directory, + Options: [ + --output-path + ] + }, + { + Name: Language, + Type: String, + Description: Change the language of the UI, + Options: [ + --language + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Disable Image Log, + Options: [ + --disable-image-log + ] + }, + { + Name: Disable Analytics, + Options: [ + --disable-analytics + ] + }, + { + Name: Disable Preset Model Downloads, + Options: [ + --disable-preset-download + ] + }, + { + Name: Always Download Newer Models, + Options: [ + --always-download-new-model + ] + }, + { + Name: VRAM, + Options: [ + --always-high-vram, + --always-normal-vram, + --always-low-vram, + --always-no-vram + ] + }, + { + Name: Use DirectML, + Description: Use pytorch with DirectML support, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + RecommendedSharedFolderMethod: Configuration, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + InvokeClipVision: [ + models/clip_vision + ] + }, + SharedFolderLayout: { + RelativeConfigPath: config.txt, + Rules: [ + { + SourceTypes: [ + StableDiffusion + ], + TargetRelativePaths: [ + models/checkpoints + ], + ConfigDocumentPaths: [ + path_checkpoints + ] + }, + { + SourceTypes: [ + Diffusers + ], + TargetRelativePaths: [ + models/diffusers + ] + }, + { + SourceTypes: [ + CLIP + ], + TargetRelativePaths: [ + models/clip + ] + }, + { + SourceTypes: [ + GLIGEN + ], + TargetRelativePaths: [ + models/gligen + ] + }, + { + SourceTypes: [ + ESRGAN + ], + TargetRelativePaths: [ + models/upscale_models + ] + }, + { + SourceTypes: [ + Hypernetwork + ], + TargetRelativePaths: [ + models/hypernetworks + ] + }, + { + SourceTypes: [ + TextualInversion + ], + TargetRelativePaths: [ + models/embeddings + ], + ConfigDocumentPaths: [ + path_embeddings + ] + }, + { + SourceTypes: [ + VAE + ], + TargetRelativePaths: [ + models/vae + ], + ConfigDocumentPaths: [ + path_vae + ] + }, + { + SourceTypes: [ + ApproxVAE + ], + TargetRelativePaths: [ + models/vae_approx + ], + ConfigDocumentPaths: [ + path_vae_approx + ] + }, + { + SourceTypes: [ + Lora, + LyCORIS + ], + TargetRelativePaths: [ + models/loras + ], + ConfigDocumentPaths: [ + path_loras + ] + }, + { + SourceTypes: [ + InvokeClipVision + ], + TargetRelativePaths: [ + models/clip_vision + ], + ConfigDocumentPaths: [ + path_clip_vision + ] + }, + { + SourceTypes: [ + ControlNet + ], + TargetRelativePaths: [ + models/controlnet + ], + ConfigDocumentPaths: [ + path_controlnet + ] + }, + { + TargetRelativePaths: [ + models/inpaint + ], + ConfigDocumentPaths: [ + path_inpaint + ] + }, + { + TargetRelativePaths: [ + models/prompt_expansion/fooocus_expansion + ], + ConfigDocumentPaths: [ + path_fooocus_expansion + ] + }, + { + TargetRelativePaths: [ + outputs + ], + ConfigDocumentPaths: [ + path_outputs + ] + } + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm + ], + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + InstallerSortOrder: Simple, + RepositoryName: Fooocus, + RepositoryAuthor: lllyasviel, + GithubUrl: https://github.com/lllyasviel/Fooocus, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\Fooocus.zip, + ByAuthor: By lllyasviel, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: voltaML-fast-stable-diffusion, + DisplayName: VoltaML, + Author: VoltaML, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/License, + Blurb: Fast Stable Diffusion with support for AITemplate, + LaunchCommand: main.py, + PreviewImageUri: https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8, + InstallerSortOrder: Simple, + ShouldIgnoreReleases: true, + SharedFolders: { + StableDiffusion: [ + data/models + ], + Lora: [ + data/lora + ], + TextualInversion: [ + data/textual-inversion + ] + }, + SharedOutputFolders: { + Text2Img: [ + data/outputs/txt2img + ], + Img2Img: [ + data/outputs/img2img + ], + Extras: [ + data/outputs/extra + ] + }, + OutputFolderName: data/outputs, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + LaunchOptions: [ + { + Name: Log Level, + DefaultValue: --log-level INFO, + Options: [ + --log-level DEBUG, + --log-level INFO, + --log-level WARNING, + --log-level ERROR, + --log-level CRITICAL + ] + }, + { + Name: Use ngrok to expose the API, + Options: [ + --ngrok + ] + }, + { + Name: Expose the API to the network, + Options: [ + --host + ] + }, + { + Name: Skip virtualenv check, + InitialValue: true, + Options: [ + --in-container + ] + }, + { + Name: Force VoltaML to use a specific type of PyTorch distribution, + Options: [ + --pytorch-type cpu, + --pytorch-type cuda, + --pytorch-type rocm, + --pytorch-type directml, + --pytorch-type intel, + --pytorch-type vulkan + ] + }, + { + Name: Run in tandem with the Discord bot, + Options: [ + --bot + ] + }, + { + Name: Enable Cloudflare R2 bucket upload support, + Options: [ + --enable-r2 + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 5003, + Options: [ + --port + ] + }, + { + Name: Only install requirements and exit, + Options: [ + --install-only + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: main, + RepositoryName: voltaML-fast-stable-diffusion, + RepositoryAuthor: VoltaML, + GithubUrl: https://github.com/VoltaML/voltaML-fast-stable-diffusion, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\voltaML-fast-stable-diffusion.zip, + ByAuthor: By VoltaML, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: stable-diffusion-webui-ux, + DisplayName: Stable Diffusion Web UI-UX, + Author: anapnoe, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/anapnoe/stable-diffusion-webui-ux/blob/master/LICENSE.txt, + Blurb: A pixel perfect design, mobile friendly, customizable interface that adds accessibility, ease of use and extended functionallity to the stable diffusion web ui., + LaunchCommand: launch.py, + PreviewImageUri: https://raw.githubusercontent.com/anapnoe/stable-diffusion-webui-ux/master/screenshot.png, + InstallerSortOrder: Advanced, + ExtensionManager: { + RelativeInstallDirectory: extensions, + DefaultManifests: [ + { + Uri: https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json + } + ] + }, + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + Codeformer: [ + models/Codeformer + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + ApproxVAE: [ + models/VAE-approx + ], + Karlo: [ + models/karlo + ], + DeepDanbooru: [ + models/deepbooru + ], + TextualInversion: [ + embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/ControlNet + ], + LDSR: [ + models/LDSR + ], + AfterDetailer: [ + models/adetailer + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs/txt2img-images + ], + Img2Img: [ + outputs/img2img-images + ], + Extras: [ + outputs/extras-images + ], + Text2ImgGrids: [ + outputs/txt2img-grids + ], + Img2ImgGrids: [ + outputs/img2img-grids + ], + Saved: [ + log/images + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: VRAM, + Options: [ + --lowvram, + --medvram, + --medvram-sdxl + ] + }, + { + Name: Xformers, + InitialValue: true, + Options: [ + --xformers + ] + }, + { + Name: API, + InitialValue: true, + Options: [ + --api + ] + }, + { + Name: Auto Launch Web UI, + InitialValue: false, + Options: [ + --autolaunch + ] + }, + { + Name: Skip Torch CUDA Check, + InitialValue: false, + Options: [ + --skip-torch-cuda-test + ] + }, + { + Name: Skip Python Version Check, + InitialValue: true, + Options: [ + --skip-python-version-check + ] + }, + { + Name: No Half, + Description: Do not switch the model to 16-bit floats, + InitialValue: false, + Options: [ + --no-half + ] + }, + { + Name: Skip SD Model Download, + InitialValue: false, + Options: [ + --no-download-sd-model + ] + }, + { + Name: Skip Install, + Options: [ + --skip-install + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm, + Mps + ], + MainBranch: master, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + RepositoryName: stable-diffusion-webui-ux, + RepositoryAuthor: anapnoe, + GithubUrl: https://github.com/anapnoe/stable-diffusion-webui-ux, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\stable-diffusion-webui-ux.zip, + ByAuthor: By anapnoe, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: true, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: StableSwarmUI, + RepositoryName: SwarmUI, + DisplayName: SwarmUI, + Author: mcmonkeyprojects, + Blurb: A Modular Stable Diffusion Web-User-Interface, with an emphasis on making powertools easily accessible, high performance, and extensibility., + LicenseType: MIT, + LicenseUrl: https://github.com/mcmonkeyprojects/SwarmUI/blob/master/LICENSE.txt, + LaunchCommand: , + PreviewImageUri: https://github.com/mcmonkeyprojects/SwarmUI/raw/master/.github/images/swarmui.jpg, + OutputFolderName: Output, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + RecommendedSharedFolderMethod: Configuration, + OfferInOneClickInstaller: false, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: 127.0.0.1, + Options: [ + --host + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7801, + Options: [ + --port + ] + }, + { + Name: Ngrok Path, + Type: String, + Options: [ + --ngrok-path + ] + }, + { + Name: Ngrok Basic Auth, + Type: String, + Options: [ + --ngrok-basic-auth + ] + }, + { + Name: Cloudflared Path, + Type: String, + Options: [ + --cloudflared-path + ] + }, + { + Name: Proxy Region, + Type: String, + Options: [ + --proxy-region + ] + }, + { + Name: Launch Mode, + Options: [ + --launch-mode web, + --launch-mode webinstall + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + SharedFolders: { + StableDiffusion: [ + Models/Stable-Diffusion + ], + Lora: [ + Models/Lora + ], + VAE: [ + Models/VAE + ], + TextualInversion: [ + Models/Embeddings + ], + ControlNet: [ + Models/controlnet + ], + InvokeClipVision: [ + Models/clip_vision + ] + }, + SharedOutputFolders: { + Text2Img: [ + Output + ] + }, + MainBranch: master, + ShouldIgnoreReleases: true, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + InstallerSortOrder: Advanced, + Prerequisites: [ + Git, + Dotnet, + Python310, + VcRedist + ], + RepositoryAuthor: mcmonkeyprojects, + GithubUrl: https://github.com/mcmonkeyprojects/SwarmUI, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\StableSwarmUI.zip, + ByAuthor: By mcmonkeyprojects, + Disclaimer: , + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit + }, + { + Name: RuinedFooocus, + DisplayName: RuinedFooocus, + Author: runew0lf, + Blurb: RuinedFooocus combines the best aspects of Stable Diffusion and Midjourney into one seamless, cutting-edge experience, + LicenseUrl: https://github.com/runew0lf/RuinedFooocus/blob/main/LICENSE, + PreviewImageUri: https://raw.githubusercontent.com/runew0lf/pmmconfigs/main/RuinedFooocus_ss.png, + InstallerSortOrder: Expert, + AvailableSharedFolderMethods: [ + Symlink, + None + ], + LaunchOptions: [ + { + Name: Preset, + Options: [ + --preset anime, + --preset realistic + ] + }, + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Output Directory, + Type: String, + Description: Override the output directory, + Options: [ + --output-directory + ] + }, + { + Name: VRAM, + Options: [ + --highvram, + --normalvram, + --lowvram, + --novram + ] + }, + { + Name: Use DirectML, + Description: Use pytorch with DirectML support, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + LicenseType: GPL-3.0, + LaunchCommand: launch.py, + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + InvokeClipVision: [ + models/clip_vision + ] + }, + SharedFolderLayout: { + RelativeConfigPath: config.txt, + Rules: [ + { + SourceTypes: [ + StableDiffusion + ], + TargetRelativePaths: [ + models/checkpoints + ], + ConfigDocumentPaths: [ + path_checkpoints + ] + }, + { + SourceTypes: [ + Diffusers + ], + TargetRelativePaths: [ + models/diffusers + ] + }, + { + SourceTypes: [ + CLIP + ], + TargetRelativePaths: [ + models/clip + ] + }, + { + SourceTypes: [ + GLIGEN + ], + TargetRelativePaths: [ + models/gligen + ] + }, + { + SourceTypes: [ + ESRGAN + ], + TargetRelativePaths: [ + models/upscale_models + ] + }, + { + SourceTypes: [ + Hypernetwork + ], + TargetRelativePaths: [ + models/hypernetworks + ] + }, + { + SourceTypes: [ + TextualInversion + ], + TargetRelativePaths: [ + models/embeddings + ], + ConfigDocumentPaths: [ + path_embeddings + ] + }, + { + SourceTypes: [ + VAE + ], + TargetRelativePaths: [ + models/vae + ], + ConfigDocumentPaths: [ + path_vae + ] + }, + { + SourceTypes: [ + ApproxVAE + ], + TargetRelativePaths: [ + models/vae_approx + ], + ConfigDocumentPaths: [ + path_vae_approx + ] + }, + { + SourceTypes: [ + Lora, + LyCORIS + ], + TargetRelativePaths: [ + models/loras + ], + ConfigDocumentPaths: [ + path_loras + ] + }, + { + SourceTypes: [ + InvokeClipVision + ], + TargetRelativePaths: [ + models/clip_vision + ], + ConfigDocumentPaths: [ + path_clip_vision + ] + }, + { + SourceTypes: [ + ControlNet + ], + TargetRelativePaths: [ + models/controlnet + ], + ConfigDocumentPaths: [ + path_controlnet + ] + }, + { + TargetRelativePaths: [ + models/inpaint + ], + ConfigDocumentPaths: [ + path_inpaint + ] + }, + { + TargetRelativePaths: [ + models/prompt_expansion/fooocus_expansion + ], + ConfigDocumentPaths: [ + path_fooocus_expansion + ] + }, + { + TargetRelativePaths: [ + outputs + ], + ConfigDocumentPaths: [ + path_outputs + ] + } + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm + ], + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + RepositoryName: RuinedFooocus, + RepositoryAuthor: runew0lf, + GithubUrl: https://github.com/runew0lf/RuinedFooocus, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\RuinedFooocus.zip, + ByAuthor: By runew0lf, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: automatic, + DisplayName: SD.Next, + Author: vladmandic, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/vladmandic/automatic/blob/master/LICENSE.txt, + Blurb: Stable Diffusion implementation with advanced features and modern UI, + LaunchCommand: launch.py, + PreviewImageUri: https://github.com/vladmandic/automatic/raw/master/html/screenshot-modernui.jpg, + ShouldIgnoreReleases: true, + InstallerSortOrder: Expert, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Zluda + ], + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + GFPGAN: [ + models/GFPGAN + ], + BSRGAN: [ + models/BSRGAN + ], + Codeformer: [ + models/Codeformer + ], + Diffusers: [ + models/Diffusers + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/ControlNet + ], + LDSR: [ + models/LDSR + ], + CLIP: [ + models/CLIP + ], + ScuNET: [ + models/ScuNET + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs/text + ], + Img2Img: [ + outputs/image + ], + Extras: [ + outputs/extras + ], + Text2ImgGrids: [ + outputs/grids + ], + Img2ImgGrids: [ + outputs/grids + ], + Saved: [ + outputs/save + ] + }, + OutputFolderName: outputs, + ExtensionManager: { + RelativeInstallDirectory: extensions, + DefaultManifests: [ + { + Uri: https://vladmandic.github.io/sd-data/pages/extensions.json + } + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: VRAM, + Options: [ + --lowvram, + --medvram + ] + }, + { + Name: Auto-Launch Web UI, + Options: [ + --autolaunch + ] + }, + { + Name: Force use of Intel OneAPI XPU backend, + Options: [ + --use-ipex + ] + }, + { + Name: Use DirectML if no compatible GPU is detected, + InitialValue: false, + Options: [ + --use-directml + ] + }, + { + Name: Force use of Nvidia CUDA backend, + InitialValue: true, + Options: [ + --use-cuda + ] + }, + { + Name: Force use of AMD ROCm backend, + InitialValue: false, + Options: [ + --use-rocm + ] + }, + { + Name: Force use of ZLUDA backend, + InitialValue: false, + Options: [ + --use-zluda + ] + }, + { + Name: CUDA Device ID, + Type: String, + Options: [ + --device-id + ] + }, + { + Name: API, + Options: [ + --api + ] + }, + { + Name: Debug Logging, + Options: [ + --debug + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + ExtraLaunchArguments: , + MainBranch: master, + RepositoryName: automatic, + RepositoryAuthor: vladmandic, + GithubUrl: https://github.com/vladmandic/automatic, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\automatic.zip, + ByAuthor: By vladmandic, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SupportsExtensions: true, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: sdfx, + DisplayName: SDFX, + Author: sdfxai, + Blurb: The ultimate no-code platform to build and share AI apps with beautiful UI., + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/sdfxai/sdfx/blob/main/LICENSE, + LaunchCommand: setup.py, + PreviewImageUri: https://github.com/sdfxai/sdfx/raw/main/docs/static/screen-sdfx.png, + OutputFolderName: data\media\output, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + InstallerSortOrder: Expert, + RecommendedSharedFolderMethod: Configuration, + LaunchOptions: [ + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + SharedFolders: { + StableDiffusion: [ + data/models/checkpoints + ], + Lora: [ + data/models/loras + ], + ESRGAN: [ + data/models/upscale_models + ], + Diffusers: [ + data/models/diffusers + ], + VAE: [ + data/models/vae + ], + ApproxVAE: [ + data/models/vae_approx + ], + TextualInversion: [ + data/models/embeddings + ], + Hypernetwork: [ + data/models/hypernetworks + ], + ControlNet: [ + data/models/controlnet/ControlNet + ], + CLIP: [ + data/models/clip + ], + GLIGEN: [ + data/models/gligen + ], + IpAdapter: [ + data/models/ipadapter/base + ], + T2IAdapter: [ + data/models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + data/models/ipadapter/sd15 + ], + InvokeIpAdaptersXl: [ + data/models/ipadapter/sdxl + ], + InvokeClipVision: [ + data/models/clip_vision + ], + PromptExpansion: [ + data/models/prompt_expansion + ] + }, + SharedOutputFolders: { + Text2Img: [ + data/media/output + ] + }, + MainBranch: main, + ShouldIgnoreReleases: true, + Prerequisites: [ + Python310, + VcRedist, + Git, + Node + ], + RepositoryName: sdfx, + RepositoryAuthor: sdfxai, + GithubUrl: https://github.com/sdfxai/sdfx, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\sdfx.zip, + ByAuthor: By sdfxai, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SupportsExtensions: false, + AvailableVersionTypes: Commit + }, + { + Name: InvokeAI, + DisplayName: InvokeAI, + Author: invoke-ai, + LicenseType: Apache-2.0, + LicenseUrl: https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE, + Blurb: Professional Creative Tools for Stable Diffusion, + LaunchCommand: invokeai-web, + InstallerSortOrder: Nightmare, + ExtraLaunchCommands: [ + invokeai-db-maintenance, + invokeai-import-images + ], + PreviewImageUri: https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/docs/assets/canvas_preview.png, + AvailableSharedFolderMethods: [ + None + ], + RecommendedSharedFolderMethod: None, + MainBranch: main, + SharedFolders: { + StableDiffusion: [ + invokeai-root\autoimport\main + ], + Lora: [ + invokeai-root\autoimport\lora + ], + TextualInversion: [ + invokeai-root\autoimport\embedding + ], + ControlNet: [ + invokeai-root\autoimport\controlnet + ], + T2IAdapter: [ + invokeai-root\autoimport\t2i_adapter + ], + InvokeIpAdapters15: [ + invokeai-root\models\sd-1\ip_adapter + ], + InvokeIpAdaptersXl: [ + invokeai-root\models\sdxl\ip_adapter + ], + InvokeClipVision: [ + invokeai-root\models\any\clip_vision + ] + }, + SharedOutputFolders: { + Text2Img: [ + invokeai-root\outputs\images + ] + }, + OutputFolderName: invokeai-root\outputs\images, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --host + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 9090, + Options: [ + --port + ] + }, + { + Name: Allow Origins, + Type: String, + Description: List of host names or IP addresses that are allowed to connect to the InvokeAI API in the format ['host1','host2',...], + DefaultValue: [], + Options: [ + --allow-origins + ] + }, + { + Name: Precision, + Options: [ + --precision auto, + --precision float16, + --precision float32 + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm, + Mps + ], + Prerequisites: [ + Python310, + VcRedist, + Git, + Node + ], + RepositoryName: InvokeAI, + RepositoryAuthor: invoke-ai, + GithubUrl: https://github.com/invoke-ai/InvokeAI, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\InvokeAI.zip, + ByAuthor: By invoke-ai, + Disclaimer: , + OfferInOneClickInstaller: true, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: GithubRelease, Commit + }, + { + Name: Fooocus-ControlNet-SDXL, + DisplayName: Fooocus-ControlNet, + Author: fenneishi, + Blurb: Fooocus-ControlNet adds more control to the original Fooocus software., + Disclaimer: This package may no longer be actively maintained, + LicenseUrl: https://github.com/fenneishi/Fooocus-ControlNet-SDXL/blob/main/LICENSE, + PreviewImageUri: https://github.com/fenneishi/Fooocus-ControlNet-SDXL/raw/main/asset/canny/snip.png, + InstallerSortOrder: Impossible, + OfferInOneClickInstaller: false, + SharedFolderLayout: { + RelativeConfigPath: user_path_config.txt, + Rules: [ + { + SourceTypes: [ + StableDiffusion + ], + TargetRelativePaths: [ + models/checkpoints + ], + ConfigDocumentPaths: [ + path_checkpoints + ] + }, + { + SourceTypes: [ + Diffusers + ], + TargetRelativePaths: [ + models/diffusers + ] + }, + { + SourceTypes: [ + CLIP + ], + TargetRelativePaths: [ + models/clip + ] + }, + { + SourceTypes: [ + GLIGEN + ], + TargetRelativePaths: [ + models/gligen + ] + }, + { + SourceTypes: [ + ESRGAN + ], + TargetRelativePaths: [ + models/upscale_models + ] + }, + { + SourceTypes: [ + Hypernetwork + ], + TargetRelativePaths: [ + models/hypernetworks + ] + }, + { + SourceTypes: [ + TextualInversion + ], + TargetRelativePaths: [ + models/embeddings + ], + ConfigDocumentPaths: [ + path_embeddings + ] + }, + { + SourceTypes: [ + VAE + ], + TargetRelativePaths: [ + models/vae + ], + ConfigDocumentPaths: [ + path_vae + ] + }, + { + SourceTypes: [ + ApproxVAE + ], + TargetRelativePaths: [ + models/vae_approx + ], + ConfigDocumentPaths: [ + path_vae_approx + ] + }, + { + SourceTypes: [ + Lora, + LyCORIS + ], + TargetRelativePaths: [ + models/loras + ], + ConfigDocumentPaths: [ + path_loras + ] + }, + { + SourceTypes: [ + InvokeClipVision + ], + TargetRelativePaths: [ + models/clip_vision + ], + ConfigDocumentPaths: [ + path_clip_vision + ] + }, + { + SourceTypes: [ + ControlNet + ], + TargetRelativePaths: [ + models/controlnet + ], + ConfigDocumentPaths: [ + path_controlnet + ] + }, + { + TargetRelativePaths: [ + models/inpaint + ], + ConfigDocumentPaths: [ + path_inpaint + ] + }, + { + TargetRelativePaths: [ + models/prompt_expansion/fooocus_expansion + ], + ConfigDocumentPaths: [ + path_fooocus_expansion + ] + }, + { + TargetRelativePaths: [ + outputs + ], + ConfigDocumentPaths: [ + path_outputs + ] + } + ] + }, + LicenseType: GPL-3.0, + LaunchCommand: launch.py, + LaunchOptions: [ + { + Name: Preset, + Options: [ + --preset anime, + --preset realistic + ] + }, + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Output Directory, + Type: String, + Description: Override the output directory, + Options: [ + --output-path + ] + }, + { + Name: Language, + Type: String, + Description: Change the language of the UI, + Options: [ + --language + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Disable Image Log, + Options: [ + --disable-image-log + ] + }, + { + Name: Disable Analytics, + Options: [ + --disable-analytics + ] + }, + { + Name: Disable Preset Model Downloads, + Options: [ + --disable-preset-download + ] + }, + { + Name: Always Download Newer Models, + Options: [ + --always-download-new-model + ] + }, + { + Name: VRAM, + Options: [ + --always-high-vram, + --always-normal-vram, + --always-low-vram, + --always-no-vram + ] + }, + { + Name: Use DirectML, + Description: Use pytorch with DirectML support, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + RecommendedSharedFolderMethod: Configuration, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + InvokeClipVision: [ + models/clip_vision + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm + ], + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + RepositoryName: Fooocus-ControlNet-SDXL, + RepositoryAuthor: fenneishi, + GithubUrl: https://github.com/fenneishi/Fooocus-ControlNet-SDXL, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\Fooocus-ControlNet-SDXL.zip, + ByAuthor: By fenneishi, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: Fooocus-MRE, + DisplayName: Fooocus-MRE, + Author: MoonRide303, + Blurb: Fooocus-MRE is an image generating software, enhanced variant of the original Fooocus dedicated for a bit more advanced users, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/MoonRide303/Fooocus-MRE/blob/moonride-main/LICENSE, + LaunchCommand: launch.py, + PreviewImageUri: https://user-images.githubusercontent.com/130458190/265366059-ce430ea0-0995-4067-98dd-cef1d7dc1ab6.png, + Disclaimer: This package may no longer receive updates from its author. It may be removed from Stability Matrix in the future., + InstallerSortOrder: Impossible, + OfferInOneClickInstaller: false, + LaunchOptions: [ + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm + ], + MainBranch: moonride-main, + OutputFolderName: outputs, + RepositoryName: Fooocus-MRE, + RepositoryAuthor: MoonRide303, + GithubUrl: https://github.com/MoonRide303/Fooocus-MRE, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\Fooocus-MRE.zip, + ByAuthor: By MoonRide303, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: GithubRelease, Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + } + ], + TrainingPackages: [ + { + Name: OneTrainer, + DisplayName: OneTrainer, + Author: Nerogar, + Blurb: OneTrainer is a one-stop solution for all your stable diffusion training needs, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/Nerogar/OneTrainer/blob/master/LICENSE.txt, + LaunchCommand: scripts/train_ui.py, + PreviewImageUri: https://github.com/Nerogar/OneTrainer/blob/master/resources/icons/icon.png?raw=true, + OutputFolderName: , + RecommendedSharedFolderMethod: None, + AvailableTorchVersions: [ + Cuda + ], + IsCompatible: true, + PackageType: SdTraining, + AvailableSharedFolderMethods: [ + None + ], + InstallerSortOrder: Nightmare, + OfferInOneClickInstaller: false, + ShouldIgnoreReleases: true, + Prerequisites: [ + Git, + Python310, + VcRedist, + Tkinter + ], + LaunchOptions: [ + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: master, + RepositoryName: OneTrainer, + RepositoryAuthor: Nerogar, + GithubUrl: https://github.com/Nerogar/OneTrainer, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\OneTrainer.zip, + ByAuthor: By Nerogar, + Disclaimer: , + UpdateAvailable: false, + IsInferenceCompatible: false, + SupportsExtensions: false, + AvailableVersionTypes: Commit + }, + { + Name: kohya_ss, + DisplayName: kohya_ss, + Author: bmaltais, + Blurb: A Windows-focused Gradio GUI for Kohya's Stable Diffusion trainers, + LicenseType: Apache-2.0, + LicenseUrl: https://github.com/bmaltais/kohya_ss/blob/master/LICENSE.md, + LaunchCommand: kohya_gui.py, + PreviewImageUri: https://camo.githubusercontent.com/5154eea62c113d5c04393e51a0d0f76ef25a723aad29d256dcc85ead1961cd41/68747470733a2f2f696d672e796f75747562652e636f6d2f76692f6b35696d713031757655592f302e6a7067, + OutputFolderName: , + IsCompatible: true, + Disclaimer: Nvidia GPU with at least 8GB VRAM is recommended. May be unstable on Linux., + InstallerSortOrder: UltraNightmare, + PackageType: SdTraining, + OfferInOneClickInstaller: false, + RecommendedSharedFolderMethod: None, + AvailableTorchVersions: [ + Cuda + ], + AvailableSharedFolderMethods: [ + None + ], + Prerequisites: [ + Git, + Python310, + VcRedist, + Tkinter + ], + LaunchOptions: [ + { + Name: Listen Address, + Type: String, + DefaultValue: 127.0.0.1, + Options: [ + --listen + ] + }, + { + Name: Port, + Type: String, + Options: [ + --port + ] + }, + { + Name: Username, + Type: String, + Options: [ + --username + ] + }, + { + Name: Password, + Type: String, + Options: [ + --password + ] + }, + { + Name: Auto-Launch Browser, + Options: [ + --inbrowser + ] + }, + { + Name: Share, + Options: [ + --share + ] + }, + { + Name: Headless, + Options: [ + --headless + ] + }, + { + Name: Language, + Type: String, + Options: [ + --language + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: master, + RepositoryName: kohya_ss, + RepositoryAuthor: bmaltais, + GithubUrl: https://github.com/bmaltais/kohya_ss, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\kohya_ss.zip, + ByAuthor: By bmaltais, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + SupportsExtensions: false, + AvailableVersionTypes: GithubRelease, Commit + } + ], + Title: Add Package, + IconSource: { + Type: SymbolIconSource + }, + ShowIncompatiblePackages: false, + SearchFilter: , + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } + ], + CurrentPagePath: [ + { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } + ], + CurrentPage: { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } + }, + Name: FrameView + }, + Name: NavigationView + }, + { + Type: TeachingTip, + Name: UpdateAvailableTeachingTip + }, + { + Type: TeachingTip, + Name: DownloadsTeachingTip + } + ] + }, + Background: #ff101010, + FontFamily: Segoe UI Variable Text, + Width: 1400.0, + Height: 900.0, + IsVisible: true, + DataContext: { + Greeting: Welcome to Avalonia!, + ProgressManagerViewModel: { + Title: Download Manager, + IconSource: { + Type: SymbolIconSource + }, + IsOpen: false, + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + UpdateViewModel: { + IsUpdateAvailable: false, + IsProgressIndeterminate: false, + ShowProgressBar: false, + InstallUpdateCommand: UpdateViewModel.InstallUpdate(), + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + PaneWidth: 200.0, + SelectedCategory: { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + SubPages: [ + { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + { + InferencePackages: [ + { + Name: stable-diffusion-webui-forge, + DisplayName: Stable Diffusion WebUI Forge, + Author: lllyasviel, + Blurb: Stable Diffusion WebUI Forge is a platform on top of Stable Diffusion WebUI (based on Gradio) to make development easier, optimize resource management, and speed up inference., + LicenseUrl: https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/LICENSE.txt, + PreviewImageUri: https://github.com/lllyasviel/stable-diffusion-webui-forge/assets/19834515/ca5e05ed-bd86-4ced-8662-f41034648e8c, + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: output, + InstallerSortOrder: ReallyRecommended, + SharedOutputFolders: { + Text2Img: [ + output/txt2img-images + ], + Img2Img: [ + output/img2img-images + ], + Extras: [ + output/extras-images + ], + Text2ImgGrids: [ + output/txt2img-grids + ], + Img2ImgGrids: [ + output/img2img-grids + ], + SVD: [ + output/svd + ], + Saved: [ + log/images + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Pin Shared Memory, + Options: [ + --pin-shared-memory + ] + }, + { + Name: CUDA Malloc, + Options: [ + --cuda-malloc + ] + }, + { + Name: CUDA Stream, + Options: [ + --cuda-stream + ] + }, + { + Name: Always Offload from VRAM, + Options: [ + --always-offload-from-vram + ] + }, + { + Name: Always GPU, + Options: [ + --always-gpu + ] + }, + { + Name: Always CPU, + Options: [ + --always-cpu + ] + }, + { + Name: Use DirectML, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Skip Torch CUDA Test, + InitialValue: false, + Options: [ + --skip-torch-cuda-test + ] + }, + { + Name: No half-precision VAE, + InitialValue: false, + Options: [ + --no-half-vae + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + LicenseType: AGPL-3.0, + LaunchCommand: launch.py, + RelativeArgsDefinitionScriptPath: modules.cmd_args, + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + GFPGAN: [ + models/GFPGAN + ], + Codeformer: [ + models/Codeformer + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + ApproxVAE: [ + models/VAE-approx + ], + Karlo: [ + models/karlo + ], + DeepDanbooru: [ + models/deepbooru + ], + TextualInversion: [ + embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet/ControlNet + ], + LDSR: [ + models/LDSR + ], + AfterDetailer: [ + models/adetailer + ], + IpAdapter: [ + models/controlnet/IpAdapter + ], + T2IAdapter: [ + models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + models/controlnet/DiffusersIpAdapters + ], + InvokeIpAdaptersXl: [ + models/controlnet/DiffusersIpAdaptersXL + ], + SVD: [ + models/svd + ] + }, + AvailableSharedFolderMethods: [ + Symlink, + None + ], + ExtraLaunchArguments: , + RepositoryName: stable-diffusion-webui-forge, + RepositoryAuthor: lllyasviel, + GithubUrl: https://github.com/lllyasviel/stable-diffusion-webui-forge, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\stable-diffusion-webui-forge.zip, + ByAuthor: By lllyasviel, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: stable-diffusion-webui, + DisplayName: Stable Diffusion WebUI, + Author: AUTOMATIC1111, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt, + Blurb: A browser interface based on Gradio library for Stable Diffusion, + LaunchCommand: launch.py, + PreviewImageUri: https://github.com/AUTOMATIC1111/stable-diffusion-webui/raw/master/screenshot.png, + RelativeArgsDefinitionScriptPath: modules.cmd_args, + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + GFPGAN: [ + models/GFPGAN + ], + Codeformer: [ + models/Codeformer + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + ApproxVAE: [ + models/VAE-approx + ], + Karlo: [ + models/karlo + ], + DeepDanbooru: [ + models/deepbooru + ], + TextualInversion: [ + embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet/ControlNet + ], + LDSR: [ + models/LDSR + ], + AfterDetailer: [ + models/adetailer + ], + IpAdapter: [ + models/controlnet/IpAdapter + ], + T2IAdapter: [ + models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + models/controlnet/DiffusersIpAdapters + ], + InvokeIpAdaptersXl: [ + models/controlnet/DiffusersIpAdaptersXL + ], + SVD: [ + models/svd + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs/txt2img-images + ], + Img2Img: [ + outputs/img2img-images + ], + Extras: [ + outputs/extras-images + ], + Text2ImgGrids: [ + outputs/txt2img-grids + ], + Img2ImgGrids: [ + outputs/img2img-grids + ], + SVD: [ + outputs/svd + ], + Saved: [ + log/images + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: VRAM, + Options: [ + --lowvram, + --medvram, + --medvram-sdxl + ] + }, + { + Name: Xformers, + InitialValue: true, + Options: [ + --xformers + ] + }, + { + Name: API, + InitialValue: true, + Options: [ + --api + ] + }, + { + Name: Auto Launch Web UI, + InitialValue: false, + Options: [ + --autolaunch + ] + }, + { + Name: Skip Torch CUDA Check, + InitialValue: false, + Options: [ + --skip-torch-cuda-test + ] + }, + { + Name: Skip Python Version Check, + InitialValue: true, + Options: [ + --skip-python-version-check + ] + }, + { + Name: No Half, + Description: Do not switch the model to 16-bit floats, + InitialValue: false, + Options: [ + --no-half + ] + }, + { + Name: Skip SD Model Download, + InitialValue: false, + Options: [ + --no-download-sd-model + ] + }, + { + Name: Skip Install, + Options: [ + --skip-install + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm, + Mps + ], + MainBranch: master, + OutputFolderName: outputs, + ExtensionManager: { + RelativeInstallDirectory: extensions, + DefaultManifests: [ + { + Uri: https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json + } + ] + }, + ExtraLaunchArguments: , + RepositoryName: stable-diffusion-webui, + RepositoryAuthor: AUTOMATIC1111, + GithubUrl: https://github.com/AUTOMATIC1111/stable-diffusion-webui, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\stable-diffusion-webui.zip, + ByAuthor: By AUTOMATIC1111, + Disclaimer: , + OfferInOneClickInstaller: true, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: true, + AvailableVersionTypes: GithubRelease, Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: ComfyUI, + DisplayName: ComfyUI, + Author: comfyanonymous, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE, + Blurb: A powerful and modular stable diffusion GUI and backend, + LaunchCommand: main.py, + PreviewImageUri: https://github.com/comfyanonymous/ComfyUI/raw/master/comfyui_screenshot.png, + ShouldIgnoreReleases: true, + IsInferenceCompatible: true, + OutputFolderName: output, + InstallerSortOrder: InferenceCompatible, + RecommendedSharedFolderMethod: Configuration, + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet/ControlNet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + IpAdapter: [ + models/ipadapter/base + ], + T2IAdapter: [ + models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + models/ipadapter/sd15 + ], + InvokeIpAdaptersXl: [ + models/ipadapter/sdxl + ], + InvokeClipVision: [ + models/clip_vision + ], + PromptExpansion: [ + models/prompt_expansion + ] + }, + SharedOutputFolders: { + Text2Img: [ + output + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: 127.0.0.1, + Options: [ + --listen + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 8188, + Options: [ + --port + ] + }, + { + Name: VRAM, + Options: [ + --highvram, + --normalvram, + --lowvram, + --novram + ] + }, + { + Name: Preview Method, + InitialValue: --preview-method auto, + Options: [ + --preview-method auto, + --preview-method latent2rgb, + --preview-method taesd + ] + }, + { + Name: Enable DirectML, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Use CPU only, + InitialValue: false, + Options: [ + --cpu + ] + }, + { + Name: Cross Attention Method, + Options: [ + --use-split-cross-attention, + --use-quad-cross-attention, + --use-pytorch-cross-attention + ] + }, + { + Name: Force Floating Point Precision, + Options: [ + --force-fp32, + --force-fp16 + ] + }, + { + Name: VAE Precision, + Options: [ + --fp16-vae, + --fp32-vae, + --bf16-vae + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Disable upcasting of attention, + Options: [ + --dont-upcast-attention + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: master, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + ExtensionManager: { + RelativeInstallDirectory: custom_nodes, + DefaultManifests: [ + { + Uri: https://cdn.jsdelivr.net/gh/ltdrdata/ComfyUI-Manager/custom-node-list.json + }, + { + Uri: https://cdn.jsdelivr.net/gh/LykosAI/ComfyUI-Extensions-Index/custom-node-list.json + } + ] + }, + RepositoryName: ComfyUI, + RepositoryAuthor: comfyanonymous, + GithubUrl: https://github.com/comfyanonymous/ComfyUI, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\ComfyUI.zip, + ByAuthor: By comfyanonymous, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsCompatible: true, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SupportsExtensions: true, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: Fooocus, + DisplayName: Fooocus, + Author: lllyasviel, + Blurb: Fooocus is a rethinking of Stable Diffusion and Midjourney’s designs, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/lllyasviel/Fooocus/blob/main/LICENSE, + LaunchCommand: launch.py, + PreviewImageUri: https://user-images.githubusercontent.com/19834515/261830306-f79c5981-cf80-4ee3-b06b-3fef3f8bfbc7.png, + LaunchOptions: [ + { + Name: Preset, + Options: [ + --preset anime, + --preset realistic + ] + }, + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Output Directory, + Type: String, + Description: Override the output directory, + Options: [ + --output-path + ] + }, + { + Name: Language, + Type: String, + Description: Change the language of the UI, + Options: [ + --language + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Disable Image Log, + Options: [ + --disable-image-log + ] + }, + { + Name: Disable Analytics, + Options: [ + --disable-analytics + ] + }, + { + Name: Disable Preset Model Downloads, + Options: [ + --disable-preset-download + ] + }, + { + Name: Always Download Newer Models, + Options: [ + --always-download-new-model + ] + }, + { + Name: VRAM, + Options: [ + --always-high-vram, + --always-normal-vram, + --always-low-vram, + --always-no-vram + ] + }, + { + Name: Use DirectML, + Description: Use pytorch with DirectML support, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + RecommendedSharedFolderMethod: Configuration, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + InvokeClipVision: [ + models/clip_vision + ] + }, + SharedFolderLayout: { + RelativeConfigPath: config.txt, + Rules: [ + { + SourceTypes: [ + StableDiffusion + ], + TargetRelativePaths: [ + models/checkpoints + ], + ConfigDocumentPaths: [ + path_checkpoints + ] + }, + { + SourceTypes: [ + Diffusers + ], + TargetRelativePaths: [ + models/diffusers + ] + }, + { + SourceTypes: [ + CLIP + ], + TargetRelativePaths: [ + models/clip + ] + }, + { + SourceTypes: [ + GLIGEN + ], + TargetRelativePaths: [ + models/gligen + ] + }, + { + SourceTypes: [ + ESRGAN + ], + TargetRelativePaths: [ + models/upscale_models + ] + }, + { + SourceTypes: [ + Hypernetwork + ], + TargetRelativePaths: [ + models/hypernetworks + ] + }, + { + SourceTypes: [ + TextualInversion + ], + TargetRelativePaths: [ + models/embeddings + ], + ConfigDocumentPaths: [ + path_embeddings + ] + }, + { + SourceTypes: [ + VAE + ], + TargetRelativePaths: [ + models/vae + ], + ConfigDocumentPaths: [ + path_vae + ] + }, + { + SourceTypes: [ + ApproxVAE + ], + TargetRelativePaths: [ + models/vae_approx + ], + ConfigDocumentPaths: [ + path_vae_approx + ] + }, + { + SourceTypes: [ + Lora, + LyCORIS + ], + TargetRelativePaths: [ + models/loras + ], + ConfigDocumentPaths: [ + path_loras + ] + }, + { + SourceTypes: [ + InvokeClipVision + ], + TargetRelativePaths: [ + models/clip_vision + ], + ConfigDocumentPaths: [ + path_clip_vision + ] + }, + { + SourceTypes: [ + ControlNet + ], + TargetRelativePaths: [ + models/controlnet + ], + ConfigDocumentPaths: [ + path_controlnet + ] + }, + { + TargetRelativePaths: [ + models/inpaint + ], + ConfigDocumentPaths: [ + path_inpaint + ] + }, + { + TargetRelativePaths: [ + models/prompt_expansion/fooocus_expansion + ], + ConfigDocumentPaths: [ + path_fooocus_expansion + ] + }, + { + TargetRelativePaths: [ + outputs + ], + ConfigDocumentPaths: [ + path_outputs + ] + } + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm + ], + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + InstallerSortOrder: Simple, + RepositoryName: Fooocus, + RepositoryAuthor: lllyasviel, + GithubUrl: https://github.com/lllyasviel/Fooocus, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\Fooocus.zip, + ByAuthor: By lllyasviel, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: voltaML-fast-stable-diffusion, + DisplayName: VoltaML, + Author: VoltaML, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/main/License, + Blurb: Fast Stable Diffusion with support for AITemplate, + LaunchCommand: main.py, + PreviewImageUri: https://github.com/LykosAI/StabilityMatrix/assets/13956642/d9a908ed-5665-41a5-a380-98458f4679a8, + InstallerSortOrder: Simple, + ShouldIgnoreReleases: true, + SharedFolders: { + StableDiffusion: [ + data/models + ], + Lora: [ + data/lora + ], + TextualInversion: [ + data/textual-inversion + ] + }, + SharedOutputFolders: { + Text2Img: [ + data/outputs/txt2img + ], + Img2Img: [ + data/outputs/img2img + ], + Extras: [ + data/outputs/extra + ] + }, + OutputFolderName: data/outputs, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + LaunchOptions: [ + { + Name: Log Level, + DefaultValue: --log-level INFO, + Options: [ + --log-level DEBUG, + --log-level INFO, + --log-level WARNING, + --log-level ERROR, + --log-level CRITICAL + ] + }, + { + Name: Use ngrok to expose the API, + Options: [ + --ngrok + ] + }, + { + Name: Expose the API to the network, + Options: [ + --host + ] + }, + { + Name: Skip virtualenv check, + InitialValue: true, + Options: [ + --in-container + ] + }, + { + Name: Force VoltaML to use a specific type of PyTorch distribution, + Options: [ + --pytorch-type cpu, + --pytorch-type cuda, + --pytorch-type rocm, + --pytorch-type directml, + --pytorch-type intel, + --pytorch-type vulkan + ] + }, + { + Name: Run in tandem with the Discord bot, + Options: [ + --bot + ] + }, + { + Name: Enable Cloudflare R2 bucket upload support, + Options: [ + --enable-r2 + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 5003, + Options: [ + --port + ] + }, + { + Name: Only install requirements and exit, + Options: [ + --install-only + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: main, + RepositoryName: voltaML-fast-stable-diffusion, + RepositoryAuthor: VoltaML, + GithubUrl: https://github.com/VoltaML/voltaML-fast-stable-diffusion, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\voltaML-fast-stable-diffusion.zip, + ByAuthor: By VoltaML, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: stable-diffusion-webui-ux, + DisplayName: Stable Diffusion Web UI-UX, + Author: anapnoe, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/anapnoe/stable-diffusion-webui-ux/blob/master/LICENSE.txt, + Blurb: A pixel perfect design, mobile friendly, customizable interface that adds accessibility, ease of use and extended functionallity to the stable diffusion web ui., + LaunchCommand: launch.py, + PreviewImageUri: https://raw.githubusercontent.com/anapnoe/stable-diffusion-webui-ux/master/screenshot.png, + InstallerSortOrder: Advanced, + ExtensionManager: { + RelativeInstallDirectory: extensions, + DefaultManifests: [ + { + Uri: https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json + } + ] + }, + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + Codeformer: [ + models/Codeformer + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + ApproxVAE: [ + models/VAE-approx + ], + Karlo: [ + models/karlo + ], + DeepDanbooru: [ + models/deepbooru + ], + TextualInversion: [ + embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/ControlNet + ], + LDSR: [ + models/LDSR + ], + AfterDetailer: [ + models/adetailer + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs/txt2img-images + ], + Img2Img: [ + outputs/img2img-images + ], + Extras: [ + outputs/extras-images + ], + Text2ImgGrids: [ + outputs/txt2img-grids + ], + Img2ImgGrids: [ + outputs/img2img-grids + ], + Saved: [ + log/images + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: VRAM, + Options: [ + --lowvram, + --medvram, + --medvram-sdxl + ] + }, + { + Name: Xformers, + InitialValue: true, + Options: [ + --xformers + ] + }, + { + Name: API, + InitialValue: true, + Options: [ + --api + ] + }, + { + Name: Auto Launch Web UI, + InitialValue: false, + Options: [ + --autolaunch + ] + }, + { + Name: Skip Torch CUDA Check, + InitialValue: false, + Options: [ + --skip-torch-cuda-test + ] + }, + { + Name: Skip Python Version Check, + InitialValue: true, + Options: [ + --skip-python-version-check + ] + }, + { + Name: No Half, + Description: Do not switch the model to 16-bit floats, + InitialValue: false, + Options: [ + --no-half + ] + }, + { + Name: Skip SD Model Download, + InitialValue: false, + Options: [ + --no-download-sd-model + ] + }, + { + Name: Skip Install, + Options: [ + --skip-install + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm, + Mps + ], + MainBranch: master, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + RepositoryName: stable-diffusion-webui-ux, + RepositoryAuthor: anapnoe, + GithubUrl: https://github.com/anapnoe/stable-diffusion-webui-ux, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\stable-diffusion-webui-ux.zip, + ByAuthor: By anapnoe, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: true, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: StableSwarmUI, + RepositoryName: SwarmUI, + DisplayName: SwarmUI, + Author: mcmonkeyprojects, + Blurb: A Modular Stable Diffusion Web-User-Interface, with an emphasis on making powertools easily accessible, high performance, and extensibility., + LicenseType: MIT, + LicenseUrl: https://github.com/mcmonkeyprojects/SwarmUI/blob/master/LICENSE.txt, + LaunchCommand: , + PreviewImageUri: https://github.com/mcmonkeyprojects/SwarmUI/raw/master/.github/images/swarmui.jpg, + OutputFolderName: Output, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + RecommendedSharedFolderMethod: Configuration, + OfferInOneClickInstaller: false, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: 127.0.0.1, + Options: [ + --host + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7801, + Options: [ + --port + ] + }, + { + Name: Ngrok Path, + Type: String, + Options: [ + --ngrok-path + ] + }, + { + Name: Ngrok Basic Auth, + Type: String, + Options: [ + --ngrok-basic-auth + ] + }, + { + Name: Cloudflared Path, + Type: String, + Options: [ + --cloudflared-path + ] + }, + { + Name: Proxy Region, + Type: String, + Options: [ + --proxy-region + ] + }, + { + Name: Launch Mode, + Options: [ + --launch-mode web, + --launch-mode webinstall + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + SharedFolders: { + StableDiffusion: [ + Models/Stable-Diffusion + ], + Lora: [ + Models/Lora + ], + VAE: [ + Models/VAE + ], + TextualInversion: [ + Models/Embeddings + ], + ControlNet: [ + Models/controlnet + ], + InvokeClipVision: [ + Models/clip_vision + ] + }, + SharedOutputFolders: { + Text2Img: [ + Output + ] + }, + MainBranch: master, + ShouldIgnoreReleases: true, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + InstallerSortOrder: Advanced, + Prerequisites: [ + Git, + Dotnet, + Python310, + VcRedist + ], + RepositoryAuthor: mcmonkeyprojects, + GithubUrl: https://github.com/mcmonkeyprojects/SwarmUI, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\StableSwarmUI.zip, + ByAuthor: By mcmonkeyprojects, + Disclaimer: , + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit + }, + { + Name: RuinedFooocus, + DisplayName: RuinedFooocus, + Author: runew0lf, + Blurb: RuinedFooocus combines the best aspects of Stable Diffusion and Midjourney into one seamless, cutting-edge experience, + LicenseUrl: https://github.com/runew0lf/RuinedFooocus/blob/main/LICENSE, + PreviewImageUri: https://raw.githubusercontent.com/runew0lf/pmmconfigs/main/RuinedFooocus_ss.png, + InstallerSortOrder: Expert, + AvailableSharedFolderMethods: [ + Symlink, + None + ], + LaunchOptions: [ + { + Name: Preset, + Options: [ + --preset anime, + --preset realistic + ] + }, + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Output Directory, + Type: String, + Description: Override the output directory, + Options: [ + --output-directory + ] + }, + { + Name: VRAM, + Options: [ + --highvram, + --normalvram, + --lowvram, + --novram + ] + }, + { + Name: Use DirectML, + Description: Use pytorch with DirectML support, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + LicenseType: GPL-3.0, + LaunchCommand: launch.py, + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + InvokeClipVision: [ + models/clip_vision + ] + }, + SharedFolderLayout: { + RelativeConfigPath: config.txt, + Rules: [ + { + SourceTypes: [ + StableDiffusion + ], + TargetRelativePaths: [ + models/checkpoints + ], + ConfigDocumentPaths: [ + path_checkpoints + ] + }, + { + SourceTypes: [ + Diffusers + ], + TargetRelativePaths: [ + models/diffusers + ] + }, + { + SourceTypes: [ + CLIP + ], + TargetRelativePaths: [ + models/clip + ] + }, + { + SourceTypes: [ + GLIGEN + ], + TargetRelativePaths: [ + models/gligen + ] + }, + { + SourceTypes: [ + ESRGAN + ], + TargetRelativePaths: [ + models/upscale_models + ] + }, + { + SourceTypes: [ + Hypernetwork + ], + TargetRelativePaths: [ + models/hypernetworks + ] + }, + { + SourceTypes: [ + TextualInversion + ], + TargetRelativePaths: [ + models/embeddings + ], + ConfigDocumentPaths: [ + path_embeddings + ] + }, + { + SourceTypes: [ + VAE + ], + TargetRelativePaths: [ + models/vae + ], + ConfigDocumentPaths: [ + path_vae + ] + }, + { + SourceTypes: [ + ApproxVAE + ], + TargetRelativePaths: [ + models/vae_approx + ], + ConfigDocumentPaths: [ + path_vae_approx + ] + }, + { + SourceTypes: [ + Lora, + LyCORIS + ], + TargetRelativePaths: [ + models/loras + ], + ConfigDocumentPaths: [ + path_loras + ] + }, + { + SourceTypes: [ + InvokeClipVision + ], + TargetRelativePaths: [ + models/clip_vision + ], + ConfigDocumentPaths: [ + path_clip_vision + ] + }, + { + SourceTypes: [ + ControlNet + ], + TargetRelativePaths: [ + models/controlnet + ], + ConfigDocumentPaths: [ + path_controlnet + ] + }, + { + TargetRelativePaths: [ + models/inpaint + ], + ConfigDocumentPaths: [ + path_inpaint + ] + }, + { + TargetRelativePaths: [ + models/prompt_expansion/fooocus_expansion + ], + ConfigDocumentPaths: [ + path_fooocus_expansion + ] + }, + { + TargetRelativePaths: [ + outputs + ], + ConfigDocumentPaths: [ + path_outputs + ] + } + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm + ], + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + RepositoryName: RuinedFooocus, + RepositoryAuthor: runew0lf, + GithubUrl: https://github.com/runew0lf/RuinedFooocus, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\RuinedFooocus.zip, + ByAuthor: By runew0lf, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: automatic, + DisplayName: SD.Next, + Author: vladmandic, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/vladmandic/automatic/blob/master/LICENSE.txt, + Blurb: Stable Diffusion implementation with advanced features and modern UI, + LaunchCommand: launch.py, + PreviewImageUri: https://github.com/vladmandic/automatic/raw/master/html/screenshot-modernui.jpg, + ShouldIgnoreReleases: true, + InstallerSortOrder: Expert, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Zluda + ], + SharedFolders: { + StableDiffusion: [ + models/Stable-diffusion + ], + Lora: [ + models/Lora + ], + LyCORIS: [ + models/LyCORIS + ], + ESRGAN: [ + models/ESRGAN + ], + GFPGAN: [ + models/GFPGAN + ], + BSRGAN: [ + models/BSRGAN + ], + Codeformer: [ + models/Codeformer + ], + Diffusers: [ + models/Diffusers + ], + RealESRGAN: [ + models/RealESRGAN + ], + SwinIR: [ + models/SwinIR + ], + VAE: [ + models/VAE + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/ControlNet + ], + LDSR: [ + models/LDSR + ], + CLIP: [ + models/CLIP + ], + ScuNET: [ + models/ScuNET + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs/text + ], + Img2Img: [ + outputs/image + ], + Extras: [ + outputs/extras + ], + Text2ImgGrids: [ + outputs/grids + ], + Img2ImgGrids: [ + outputs/grids + ], + Saved: [ + outputs/save + ] + }, + OutputFolderName: outputs, + ExtensionManager: { + RelativeInstallDirectory: extensions, + DefaultManifests: [ + { + Uri: https://vladmandic.github.io/sd-data/pages/extensions.json + } + ] + }, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --server-name + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 7860, + Options: [ + --port + ] + }, + { + Name: VRAM, + Options: [ + --lowvram, + --medvram + ] + }, + { + Name: Auto-Launch Web UI, + Options: [ + --autolaunch + ] + }, + { + Name: Force use of Intel OneAPI XPU backend, + Options: [ + --use-ipex + ] + }, + { + Name: Use DirectML if no compatible GPU is detected, + InitialValue: false, + Options: [ + --use-directml + ] + }, + { + Name: Force use of Nvidia CUDA backend, + InitialValue: true, + Options: [ + --use-cuda + ] + }, + { + Name: Force use of AMD ROCm backend, + InitialValue: false, + Options: [ + --use-rocm + ] + }, + { + Name: Force use of ZLUDA backend, + InitialValue: false, + Options: [ + --use-zluda + ] + }, + { + Name: CUDA Device ID, + Type: String, + Options: [ + --device-id + ] + }, + { + Name: API, + Options: [ + --api + ] + }, + { + Name: Debug Logging, + Options: [ + --debug + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + ExtraLaunchArguments: , + MainBranch: master, + RepositoryName: automatic, + RepositoryAuthor: vladmandic, + GithubUrl: https://github.com/vladmandic/automatic, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\automatic.zip, + ByAuthor: By vladmandic, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SupportsExtensions: true, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: sdfx, + DisplayName: SDFX, + Author: sdfxai, + Blurb: The ultimate no-code platform to build and share AI apps with beautiful UI., + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/sdfxai/sdfx/blob/main/LICENSE, + LaunchCommand: setup.py, + PreviewImageUri: https://github.com/sdfxai/sdfx/raw/main/docs/static/screen-sdfx.png, + OutputFolderName: data\media\output, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm, + Mps + ], + InstallerSortOrder: Expert, + RecommendedSharedFolderMethod: Configuration, + LaunchOptions: [ + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + SharedFolders: { + StableDiffusion: [ + data/models/checkpoints + ], + Lora: [ + data/models/loras + ], + ESRGAN: [ + data/models/upscale_models + ], + Diffusers: [ + data/models/diffusers + ], + VAE: [ + data/models/vae + ], + ApproxVAE: [ + data/models/vae_approx + ], + TextualInversion: [ + data/models/embeddings + ], + Hypernetwork: [ + data/models/hypernetworks + ], + ControlNet: [ + data/models/controlnet/ControlNet + ], + CLIP: [ + data/models/clip + ], + GLIGEN: [ + data/models/gligen + ], + IpAdapter: [ + data/models/ipadapter/base + ], + T2IAdapter: [ + data/models/controlnet/T2IAdapter + ], + InvokeIpAdapters15: [ + data/models/ipadapter/sd15 + ], + InvokeIpAdaptersXl: [ + data/models/ipadapter/sdxl + ], + InvokeClipVision: [ + data/models/clip_vision + ], + PromptExpansion: [ + data/models/prompt_expansion + ] + }, + SharedOutputFolders: { + Text2Img: [ + data/media/output + ] + }, + MainBranch: main, + ShouldIgnoreReleases: true, + Prerequisites: [ + Python310, + VcRedist, + Git, + Node + ], + RepositoryName: sdfx, + RepositoryAuthor: sdfxai, + GithubUrl: https://github.com/sdfxai/sdfx, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\sdfx.zip, + ByAuthor: By sdfxai, + Disclaimer: , + OfferInOneClickInstaller: true, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SupportsExtensions: false, + AvailableVersionTypes: Commit + }, + { + Name: InvokeAI, + DisplayName: InvokeAI, + Author: invoke-ai, + LicenseType: Apache-2.0, + LicenseUrl: https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE, + Blurb: Professional Creative Tools for Stable Diffusion, + LaunchCommand: invokeai-web, + InstallerSortOrder: Nightmare, + ExtraLaunchCommands: [ + invokeai-db-maintenance, + invokeai-import-images + ], + PreviewImageUri: https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/docs/assets/canvas_preview.png, + AvailableSharedFolderMethods: [ + None + ], + RecommendedSharedFolderMethod: None, + MainBranch: main, + SharedFolders: { + StableDiffusion: [ + invokeai-root\autoimport\main + ], + Lora: [ + invokeai-root\autoimport\lora + ], + TextualInversion: [ + invokeai-root\autoimport\embedding + ], + ControlNet: [ + invokeai-root\autoimport\controlnet + ], + T2IAdapter: [ + invokeai-root\autoimport\t2i_adapter + ], + InvokeIpAdapters15: [ + invokeai-root\models\sd-1\ip_adapter + ], + InvokeIpAdaptersXl: [ + invokeai-root\models\sdxl\ip_adapter + ], + InvokeClipVision: [ + invokeai-root\models\any\clip_vision + ] + }, + SharedOutputFolders: { + Text2Img: [ + invokeai-root\outputs\images + ] + }, + OutputFolderName: invokeai-root\outputs\images, + LaunchOptions: [ + { + Name: Host, + Type: String, + DefaultValue: localhost, + Options: [ + --host + ] + }, + { + Name: Port, + Type: String, + DefaultValue: 9090, + Options: [ + --port + ] + }, + { + Name: Allow Origins, + Type: String, + Description: List of host names or IP addresses that are allowed to connect to the InvokeAI API in the format ['host1','host2',...], + DefaultValue: [], + Options: [ + --allow-origins + ] + }, + { + Name: Precision, + Options: [ + --precision auto, + --precision float16, + --precision float32 + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm, + Mps + ], + Prerequisites: [ + Python310, + VcRedist, + Git, + Node + ], + RepositoryName: InvokeAI, + RepositoryAuthor: invoke-ai, + GithubUrl: https://github.com/invoke-ai/InvokeAI, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\InvokeAI.zip, + ByAuthor: By invoke-ai, + Disclaimer: , + OfferInOneClickInstaller: true, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: GithubRelease, Commit + }, + { + Name: Fooocus-ControlNet-SDXL, + DisplayName: Fooocus-ControlNet, + Author: fenneishi, + Blurb: Fooocus-ControlNet adds more control to the original Fooocus software., + Disclaimer: This package may no longer be actively maintained, + LicenseUrl: https://github.com/fenneishi/Fooocus-ControlNet-SDXL/blob/main/LICENSE, + PreviewImageUri: https://github.com/fenneishi/Fooocus-ControlNet-SDXL/raw/main/asset/canny/snip.png, + InstallerSortOrder: Impossible, + OfferInOneClickInstaller: false, + SharedFolderLayout: { + RelativeConfigPath: user_path_config.txt, + Rules: [ + { + SourceTypes: [ + StableDiffusion + ], + TargetRelativePaths: [ + models/checkpoints + ], + ConfigDocumentPaths: [ + path_checkpoints + ] + }, + { + SourceTypes: [ + Diffusers + ], + TargetRelativePaths: [ + models/diffusers + ] + }, + { + SourceTypes: [ + CLIP + ], + TargetRelativePaths: [ + models/clip + ] + }, + { + SourceTypes: [ + GLIGEN + ], + TargetRelativePaths: [ + models/gligen + ] + }, + { + SourceTypes: [ + ESRGAN + ], + TargetRelativePaths: [ + models/upscale_models + ] + }, + { + SourceTypes: [ + Hypernetwork + ], + TargetRelativePaths: [ + models/hypernetworks + ] + }, + { + SourceTypes: [ + TextualInversion + ], + TargetRelativePaths: [ + models/embeddings + ], + ConfigDocumentPaths: [ + path_embeddings + ] + }, + { + SourceTypes: [ + VAE + ], + TargetRelativePaths: [ + models/vae + ], + ConfigDocumentPaths: [ + path_vae + ] + }, + { + SourceTypes: [ + ApproxVAE + ], + TargetRelativePaths: [ + models/vae_approx + ], + ConfigDocumentPaths: [ + path_vae_approx + ] + }, + { + SourceTypes: [ + Lora, + LyCORIS + ], + TargetRelativePaths: [ + models/loras + ], + ConfigDocumentPaths: [ + path_loras + ] + }, + { + SourceTypes: [ + InvokeClipVision + ], + TargetRelativePaths: [ + models/clip_vision + ], + ConfigDocumentPaths: [ + path_clip_vision + ] + }, + { + SourceTypes: [ + ControlNet + ], + TargetRelativePaths: [ + models/controlnet + ], + ConfigDocumentPaths: [ + path_controlnet + ] + }, + { + TargetRelativePaths: [ + models/inpaint + ], + ConfigDocumentPaths: [ + path_inpaint + ] + }, + { + TargetRelativePaths: [ + models/prompt_expansion/fooocus_expansion + ], + ConfigDocumentPaths: [ + path_fooocus_expansion + ] + }, + { + TargetRelativePaths: [ + outputs + ], + ConfigDocumentPaths: [ + path_outputs + ] + } + ] + }, + LicenseType: GPL-3.0, + LaunchCommand: launch.py, + LaunchOptions: [ + { + Name: Preset, + Options: [ + --preset anime, + --preset realistic + ] + }, + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Output Directory, + Type: String, + Description: Override the output directory, + Options: [ + --output-path + ] + }, + { + Name: Language, + Type: String, + Description: Change the language of the UI, + Options: [ + --language + ] + }, + { + Name: Auto-Launch, + Options: [ + --auto-launch + ] + }, + { + Name: Disable Image Log, + Options: [ + --disable-image-log + ] + }, + { + Name: Disable Analytics, + Options: [ + --disable-analytics + ] + }, + { + Name: Disable Preset Model Downloads, + Options: [ + --disable-preset-download + ] + }, + { + Name: Always Download Newer Models, + Options: [ + --always-download-new-model + ] + }, + { + Name: VRAM, + Options: [ + --always-high-vram, + --always-normal-vram, + --always-low-vram, + --always-no-vram + ] + }, + { + Name: Use DirectML, + Description: Use pytorch with DirectML support, + InitialValue: false, + Options: [ + --directml + ] + }, + { + Name: Disable Xformers, + InitialValue: false, + Options: [ + --disable-xformers + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + RecommendedSharedFolderMethod: Configuration, + AvailableSharedFolderMethods: [ + Symlink, + Configuration, + None + ], + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ], + InvokeClipVision: [ + models/clip_vision + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + DirectMl, + Rocm + ], + MainBranch: main, + ShouldIgnoreReleases: true, + OutputFolderName: outputs, + RepositoryName: Fooocus-ControlNet-SDXL, + RepositoryAuthor: fenneishi, + GithubUrl: https://github.com/fenneishi/Fooocus-ControlNet-SDXL, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\Fooocus-ControlNet-SDXL.zip, + ByAuthor: By fenneishi, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + }, + { + Name: Fooocus-MRE, + DisplayName: Fooocus-MRE, + Author: MoonRide303, + Blurb: Fooocus-MRE is an image generating software, enhanced variant of the original Fooocus dedicated for a bit more advanced users, + LicenseType: GPL-3.0, + LicenseUrl: https://github.com/MoonRide303/Fooocus-MRE/blob/moonride-main/LICENSE, + LaunchCommand: launch.py, + PreviewImageUri: https://user-images.githubusercontent.com/130458190/265366059-ce430ea0-0995-4067-98dd-cef1d7dc1ab6.png, + Disclaimer: This package may no longer receive updates from its author. It may be removed from Stability Matrix in the future., + InstallerSortOrder: Impossible, + OfferInOneClickInstaller: false, + LaunchOptions: [ + { + Name: Port, + Type: String, + Description: Sets the listen port, + Options: [ + --port + ] + }, + { + Name: Share, + Description: Set whether to share on Gradio, + Options: [ + --share + ] + }, + { + Name: Listen, + Type: String, + Description: Set the listen interface, + Options: [ + --listen + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + AvailableSharedFolderMethods: [ + Symlink, + None + ], + SharedFolders: { + StableDiffusion: [ + models/checkpoints + ], + Lora: [ + models/loras + ], + ESRGAN: [ + models/upscale_models + ], + Diffusers: [ + models/diffusers + ], + VAE: [ + models/vae + ], + ApproxVAE: [ + models/vae_approx + ], + TextualInversion: [ + models/embeddings + ], + Hypernetwork: [ + models/hypernetworks + ], + ControlNet: [ + models/controlnet + ], + CLIP: [ + models/clip + ], + GLIGEN: [ + models/gligen + ] + }, + SharedOutputFolders: { + Text2Img: [ + outputs + ] + }, + AvailableTorchVersions: [ + Cpu, + Cuda, + Rocm + ], + MainBranch: moonride-main, + OutputFolderName: outputs, + RepositoryName: Fooocus-MRE, + RepositoryAuthor: MoonRide303, + GithubUrl: https://github.com/MoonRide303/Fooocus-MRE, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\Fooocus-MRE.zip, + ByAuthor: By MoonRide303, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + IsCompatible: true, + SupportsExtensions: false, + AvailableVersionTypes: GithubRelease, Commit, + Prerequisites: [ + Git, + Python310, + VcRedist + ] + } + ], + TrainingPackages: [ + { + Name: OneTrainer, + DisplayName: OneTrainer, + Author: Nerogar, + Blurb: OneTrainer is a one-stop solution for all your stable diffusion training needs, + LicenseType: AGPL-3.0, + LicenseUrl: https://github.com/Nerogar/OneTrainer/blob/master/LICENSE.txt, + LaunchCommand: scripts/train_ui.py, + PreviewImageUri: https://github.com/Nerogar/OneTrainer/blob/master/resources/icons/icon.png?raw=true, + OutputFolderName: , + RecommendedSharedFolderMethod: None, + AvailableTorchVersions: [ + Cuda + ], + IsCompatible: true, + PackageType: SdTraining, + AvailableSharedFolderMethods: [ + None + ], + InstallerSortOrder: Nightmare, + OfferInOneClickInstaller: false, + ShouldIgnoreReleases: true, + Prerequisites: [ + Git, + Python310, + VcRedist, + Tkinter + ], + LaunchOptions: [ + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: master, + RepositoryName: OneTrainer, + RepositoryAuthor: Nerogar, + GithubUrl: https://github.com/Nerogar/OneTrainer, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\OneTrainer.zip, + ByAuthor: By Nerogar, + Disclaimer: , + UpdateAvailable: false, + IsInferenceCompatible: false, + SupportsExtensions: false, + AvailableVersionTypes: Commit + }, + { + Name: kohya_ss, + DisplayName: kohya_ss, + Author: bmaltais, + Blurb: A Windows-focused Gradio GUI for Kohya's Stable Diffusion trainers, + LicenseType: Apache-2.0, + LicenseUrl: https://github.com/bmaltais/kohya_ss/blob/master/LICENSE.md, + LaunchCommand: kohya_gui.py, + PreviewImageUri: https://camo.githubusercontent.com/5154eea62c113d5c04393e51a0d0f76ef25a723aad29d256dcc85ead1961cd41/68747470733a2f2f696d672e796f75747562652e636f6d2f76692f6b35696d713031757655592f302e6a7067, + OutputFolderName: , + IsCompatible: true, + Disclaimer: Nvidia GPU with at least 8GB VRAM is recommended. May be unstable on Linux., + InstallerSortOrder: UltraNightmare, + PackageType: SdTraining, + OfferInOneClickInstaller: false, + RecommendedSharedFolderMethod: None, + AvailableTorchVersions: [ + Cuda + ], + AvailableSharedFolderMethods: [ + None + ], + Prerequisites: [ + Git, + Python310, + VcRedist, + Tkinter + ], + LaunchOptions: [ + { + Name: Listen Address, + Type: String, + DefaultValue: 127.0.0.1, + Options: [ + --listen + ] + }, + { + Name: Port, + Type: String, + Options: [ + --port + ] + }, + { + Name: Username, + Type: String, + Options: [ + --username + ] + }, + { + Name: Password, + Type: String, + Options: [ + --password + ] + }, + { + Name: Auto-Launch Browser, + Options: [ + --inbrowser + ] + }, + { + Name: Share, + Options: [ + --share + ] + }, + { + Name: Headless, + Options: [ + --headless + ] + }, + { + Name: Language, + Type: String, + Options: [ + --language + ] + }, + { + Name: Extra Launch Arguments, + Type: String, + Options: [ + + ] + } + ], + MainBranch: master, + RepositoryName: kohya_ss, + RepositoryAuthor: bmaltais, + GithubUrl: https://github.com/bmaltais/kohya_ss, + DownloadLocation: {TempPath}StabilityMatrixTest\AppDataHome\Packages\kohya_ss.zip, + ByAuthor: By bmaltais, + ShouldIgnoreReleases: false, + UpdateAvailable: false, + IsInferenceCompatible: false, + SupportsExtensions: false, + AvailableVersionTypes: GithubRelease, Commit + } + ], + Title: Add Package, + IconSource: { + Type: SymbolIconSource + }, + ShowIncompatiblePackages: false, + SearchFilter: , + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } + ], + CurrentPagePath: [ + { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } + ], + CurrentPage: { + Title: Packages, + IconSource: { + Type: SymbolIconSource + }, + NavigateToSubPageCommand: MainPackageManagerViewModel.NavigateToSubPage(Type viewModelType), + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + CanNavigateNext: false, + CanNavigatePrevious: false, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + }, + RemoveFromParentListCommand: ViewModelBase.RemoveFromParentList(), + HasErrors: false + } +} \ No newline at end of file diff --git a/StabilityMatrix.UITests/TestAppBuilder.cs b/StabilityMatrix.UITests/TestAppBuilder.cs index bff55647b..c3b752c24 100644 --- a/StabilityMatrix.UITests/TestAppBuilder.cs +++ b/StabilityMatrix.UITests/TestAppBuilder.cs @@ -1,15 +1,18 @@ ο»Ώusing Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using NSubstitute; using NSubstitute.Extensions; using Semver; using StabilityMatrix.Avalonia; using StabilityMatrix.Avalonia.ViewModels.Dialogs; +using StabilityMatrix.Core.Api; using StabilityMatrix.Core.Helper; using StabilityMatrix.Core.Models.Update; using StabilityMatrix.Core.Services; using StabilityMatrix.Core.Updater; using StabilityMatrix.UITests; +using StabilityMatrix.UITests.Fakes; [assembly: AvaloniaTestApplication(typeof(TestAppBuilder))] @@ -42,7 +45,9 @@ private static void ConfigureGlobals() private static void ConfigureAppServices(IServiceCollection serviceCollection) { // ISettingsManager - var settingsManager = Substitute.ForPartsOf(); + var settingsManager = Substitute.ForPartsOf(NullLogger.Instance); + settingsManager.Settings.Analytics.LastSeenConsentVersion = Compat.AppVersion; + settingsManager.Settings.Analytics.LastSeenConsentAccepted = true; serviceCollection.AddSingleton(settingsManager); // IUpdateHelper @@ -56,7 +61,7 @@ private static void ConfigureAppServices(IServiceCollection serviceCollection) Changelog = new Uri("https://example.org"), HashBlake3 = "46e11a5216c55d4c9d3c54385f62f3e1022537ae191615237f05e06d6f8690d0", Signature = - "IX5/CCXWJQG0oGkYWVnuF34gTqF/dJSrDrUd6fuNMYnncL39G3HSvkXrjvJvR18MA2rQNB5z13h3/qBSf9c7DA==" + "IX5/CCXWJQG0oGkYWVnuF34gTqF/dJSrDrUd6fuNMYnncL39G3HSvkXrjvJvR18MA2rQNB5z13h3/qBSf9c7DA==", }; var updateHelper = Substitute.For(); @@ -68,15 +73,36 @@ private static void ConfigureAppServices(IServiceCollection serviceCollection) serviceCollection.AddSingleton(updateHelper); - // UpdateViewModel - var updateViewModel = Substitute.ForPartsOf( - Substitute.For>(), - settingsManager, - null, - updateHelper + var httpClientFactory = new StaticHttpClientFactory( + new HttpClient(new StaticHttpMessageHandler()) { BaseAddress = new Uri("https://example.org") } ); - updateViewModel.Configure().GetReleaseNotes("").Returns("Test"); + serviceCollection.AddSingleton(httpClientFactory); - serviceCollection.AddSingleton(updateViewModel); + serviceCollection.AddSingleton(); + } + + private sealed class StaticHttpClientFactory(HttpClient httpClient) : IHttpClientFactory + { + public HttpClient CreateClient(string name) => httpClient; + } + + private sealed class StaticHttpMessageHandler : HttpMessageHandler + { + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken + ) + { + return Task.FromResult( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StringContent( + "# Test changelog", + System.Text.Encoding.UTF8, + "text/markdown" + ), + } + ); + } } } diff --git a/StabilityMatrix/Assets/7za - LICENSE.txt b/StabilityMatrix/Assets/7za - LICENSE.txt index 80473a66f..dae57cb4f 100644 --- a/StabilityMatrix/Assets/7za - LICENSE.txt +++ b/StabilityMatrix/Assets/7za - LICENSE.txt @@ -1,43 +1,123 @@ -7-Zip Extra 18.01 ------------------ + 7-Zip Extra + ~~~~~~~~~~~ + License for use and distribution + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -7-Zip Extra is package of extra modules of 7-Zip. + Copyright (C) 1999-2026 Igor Pavlov. -7-Zip Copyright (C) 1999-2018 Igor Pavlov. + The licenses for files are: -7-Zip is free software. Read License.txt for more information about license. + - 7za.exe: + - The "GNU LGPL" as main license for most of the code + - The "BSD 3-clause License" for some code + - The "BSD 2-clause License" for some code + - All other files: the "GNU LGPL". -Source code of binaries can be found at: - http://www.7-zip.org/ + Redistributions in binary form must reproduce related license information from this file. + Note: + You can use 7-Zip Extra on any computer, including a computer in a commercial + organization. You don't need to register or pay for 7-Zip. -7-Zip Extra -~~~~~~~~~~~ -License for use and distribution -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + It is allowed to digitally sign DLL and EXE files included into this package + with arbitrary signatures of third parties. -Copyright (C) 1999-2018 Igor Pavlov. -7-Zip Extra files are under the GNU LGPL license. +GNU LGPL information +-------------------- + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. -Notes: - You can use 7-Zip Extra on any computer, including a computer in a commercial - organization. You don't need to register or pay for 7-Zip. + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + You can receive a copy of the GNU Lesser General Public License from + http://www.gnu.org/ -GNU LGPL information --------------------- - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. +BSD 3-clause License in 7-Zip code +---------------------------------- + + The "BSD 3-clause License" is used for the following code in 7za.exe + - ZSTD data decompression. + that code was developed using original zstd decoder code as reference code. + The original zstd decoder code was developed by Facebook Inc, + that also uses the "BSD 3-clause License". + + Copyright (c) Facebook, Inc. All rights reserved. + Copyright (c) 2023-2025 Igor Pavlov. + +Text of the "BSD 3-clause License" +---------------------------------- + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--- + + + + +BSD 2-clause License in 7-Zip code +---------------------------------- + + The "BSD 2-clause License" is used for the XXH64 code in 7za.exe. + + XXH64 code in 7-Zip was derived from the original XXH64 code developed by Yann Collet. + + Copyright (c) 2012-2021 Yann Collet. + Copyright (c) 2023-2025 Igor Pavlov. + +Text of the "BSD 2-clause License" +---------------------------------- + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - You can receive a copy of the GNU Lesser General Public License from - http://www.gnu.org/ +--- diff --git a/StabilityMatrix/Assets/7za.exe b/StabilityMatrix/Assets/7za.exe index a67de9158..25773795e 100644 Binary files a/StabilityMatrix/Assets/7za.exe and b/StabilityMatrix/Assets/7za.exe differ