diff --git a/Cargo.lock b/Cargo.lock index 4aae673311..83c2ceb3f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,7 +54,15 @@ checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] - +[[package]] +name = "ai-models" +version = "0.0.0" +dependencies = [ + "log", + "serde", + "serde_json", + "thiserror 2.0.18", +] [[package]] name = "allocator-api2" version = "0.2.21" diff --git a/Cargo.toml b/Cargo.toml index 468b88ea79..089b247445 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,8 @@ members = [ "desktop/platform/mac", "desktop/platform/win", "editor", + "frontend/wasm", + "libraries/ai-models", "frontend/wrapper", "libraries/dyn-any", "libraries/math-parser", diff --git a/editor/src/messages/portfolio/document/document_message_handler.rs b/editor/src/messages/portfolio/document/document_message_handler.rs index 68b79a8d7f..19a4c85ad1 100644 --- a/editor/src/messages/portfolio/document/document_message_handler.rs +++ b/editor/src/messages/portfolio/document/document_message_handler.rs @@ -221,7 +221,7 @@ impl MessageHandler> for DocumentMes // Send the overlays message to the overlays message handler self.overlays_message_handler - .process_message(message, responses, OverlaysMessageContext { visibility_settings, viewport }); + .process_message(message, responses, OverlaysMessageContext { visibility_settings, viewport, animation_time: ipp.time as f64 }); } DocumentMessage::PropertiesPanel(message) => { let context = PropertiesPanelMessageContext { diff --git a/editor/src/messages/portfolio/document/graph_operation/graph_operation_message_handler.rs b/editor/src/messages/portfolio/document/graph_operation/graph_operation_message_handler.rs index 03fc166dd4..7be44a94ea 100644 --- a/editor/src/messages/portfolio/document/graph_operation/graph_operation_message_handler.rs +++ b/editor/src/messages/portfolio/document/graph_operation/graph_operation_message_handler.rs @@ -665,6 +665,8 @@ fn apply_usvg_stroke(stroke: &usvg::Stroke, modify_inputs: &mut ModifyInputsCont if let usvg::Paint::Color(color) = &stroke.paint() { modify_inputs.stroke_set(Stroke { color: Some(usvg_color(*color, stroke.opacity().get())), + //Added the gradient field to the Stroke struct + gradient: None, weight: stroke.width().get() as f64, dash_lengths: stroke.dasharray().as_ref().map(|lengths| lengths.iter().map(|&length| length as f64).collect()).unwrap_or_default(), dash_offset: stroke.dashoffset() as f64, diff --git a/editor/src/messages/portfolio/document/overlays/overlays_message_handler.rs b/editor/src/messages/portfolio/document/overlays/overlays_message_handler.rs index 3bf436251c..52831db81a 100644 --- a/editor/src/messages/portfolio/document/overlays/overlays_message_handler.rs +++ b/editor/src/messages/portfolio/document/overlays/overlays_message_handler.rs @@ -5,6 +5,8 @@ use crate::messages::prelude::*; pub struct OverlaysMessageContext<'a> { pub visibility_settings: OverlaysVisibilitySettings, pub viewport: &'a ViewportMessageHandler, + /// Current time in milliseconds passed from the input preprocessor, used to drive overlay animations (e.g. marching ants). + pub animation_time: f64, } #[derive(Debug, Clone, Default, ExtractField)] @@ -19,7 +21,7 @@ pub struct OverlaysMessageHandler { #[message_handler_data] impl MessageHandler> for OverlaysMessageHandler { fn process_message(&mut self, message: OverlaysMessage, responses: &mut VecDeque, context: OverlaysMessageContext) { - let OverlaysMessageContext { visibility_settings, viewport, .. } = context; + let OverlaysMessageContext { visibility_settings, viewport, animation_time } = context; match message { #[cfg(target_family = "wasm")] @@ -55,6 +57,7 @@ impl MessageHandler> for OverlaysMes render_context: canvas_context.clone(), visibility_settings: visibility_settings.clone(), viewport: *viewport, + animation_time, }, }); for provider in &self.overlay_providers { @@ -62,6 +65,7 @@ impl MessageHandler> for OverlaysMes render_context: canvas_context.clone(), visibility_settings: visibility_settings.clone(), viewport: *viewport, + animation_time, })); } } @@ -70,7 +74,7 @@ impl MessageHandler> for OverlaysMes OverlaysMessage::Draw => { use super::utility_types::OverlayContext; - let overlay_context = OverlayContext::new(*viewport, visibility_settings); + let overlay_context = OverlayContext::new(*viewport, visibility_settings, animation_time); if visibility_settings.all() { responses.add(DocumentMessage::GridOverlays { context: overlay_context.clone() }); @@ -83,7 +87,7 @@ impl MessageHandler> for OverlaysMes } #[cfg(all(not(target_family = "wasm"), test))] OverlaysMessage::Draw => { - let _ = (responses, visibility_settings, viewport); + let _ = (responses, visibility_settings, viewport, animation_time); } OverlaysMessage::AddProvider { provider: message } => { self.overlay_providers.insert(message); diff --git a/editor/src/messages/portfolio/document/overlays/utility_types_native.rs b/editor/src/messages/portfolio/document/overlays/utility_types_native.rs index 9991d39c57..e0be90f8e4 100644 --- a/editor/src/messages/portfolio/document/overlays/utility_types_native.rs +++ b/editor/src/messages/portfolio/document/overlays/utility_types_native.rs @@ -170,6 +170,8 @@ pub struct OverlayContext { internal: Arc>, pub viewport: ViewportMessageHandler, pub visibility_settings: OverlaysVisibilitySettings, + /// Current time in milliseconds, used to animate effects like marching ants. + pub animation_time: f64, } impl Clone for OverlayContext { @@ -181,6 +183,7 @@ impl Clone for OverlayContext { internal: self.internal.clone(), viewport: self.viewport, visibility_settings, + animation_time: self.animation_time, } } } @@ -198,6 +201,7 @@ impl std::fmt::Debug for OverlayContext { .field("scene", &"Scene { ... }") .field("viewport", &self.viewport) .field("visibility_settings", &self.visibility_settings) + .field("animation_time", &self.animation_time) .finish() } } @@ -209,6 +213,7 @@ impl Default for OverlayContext { internal: Mutex::new(OverlayContextInternal::default()).into(), viewport: ViewportMessageHandler::default(), visibility_settings: OverlaysVisibilitySettings::default(), + animation_time: 0., } } } @@ -220,7 +225,7 @@ impl core::hash::Hash for OverlayContext { impl OverlayContext { #[allow(dead_code)] - pub(super) fn new(viewport: ViewportMessageHandler, visibility_settings: OverlaysVisibilitySettings) -> Self { + pub(super) fn new(viewport: ViewportMessageHandler, visibility_settings: OverlaysVisibilitySettings, animation_time: f64) -> Self { Self { internal: Arc::new(Mutex::new(OverlayContextInternal::new(viewport, visibility_settings))), viewport, diff --git a/editor/src/messages/portfolio/document/overlays/utility_types_web.rs b/editor/src/messages/portfolio/document/overlays/utility_types_web.rs index c03ba387d3..8518881fdc 100644 --- a/editor/src/messages/portfolio/document/overlays/utility_types_web.rs +++ b/editor/src/messages/portfolio/document/overlays/utility_types_web.rs @@ -160,6 +160,8 @@ pub struct OverlayContext { pub render_context: web_sys::CanvasRenderingContext2d, pub viewport: ViewportMessageHandler, pub visibility_settings: OverlaysVisibilitySettings, + /// Current time in milliseconds (e.g. from `js_sys::Date::now()`), used to animate effects like marching ants. + pub animation_time: f64, } // Message hashing isn't used but is required by the message system macros impl core::hash::Hash for OverlayContext { diff --git a/editor/src/messages/tool/tool_messages/select_tool.rs b/editor/src/messages/tool/tool_messages/select_tool.rs index b8c6751565..7a7dd07bf2 100644 --- a/editor/src/messages/tool/tool_messages/select_tool.rs +++ b/editor/src/messages/tool/tool_messages/select_tool.rs @@ -406,6 +406,8 @@ struct SelectToolData { snap_candidates: Vec, auto_panning: AutoPanning, drag_start_center: ViewportPosition, + /// Whether the tool is currently subscribed to animation frame events to drive the marching ants animation. + marching_ants_subscribed: bool, } impl SelectToolData { @@ -421,6 +423,27 @@ impl SelectToolData { } } } +/// Subscribe to per-frame animation ticks so the marching ants selection border animates continuously. + fn start_marching_ants(&mut self, responses: &mut VecDeque) { + if !self.marching_ants_subscribed { + self.marching_ants_subscribed = true; + responses.add(BroadcastMessage::SubscribeEvent { + on: EventMessage::AnimationFrame, + send: Box::new(OverlaysMessage::Draw.into()), + }); + } + } + + /// Unsubscribe from per-frame animation ticks when the selection box is no longer being drawn. + fn stop_marching_ants(&mut self, responses: &mut VecDeque) { + if self.marching_ants_subscribed { + self.marching_ants_subscribed = false; + responses.add(BroadcastMessage::UnsubscribeEvent { + on: EventMessage::AnimationFrame, + send: Box::new(OverlaysMessage::Draw.into()), + }); + } + } pub fn selection_quad(&self) -> Quad { let bbox = self.selection_box(); @@ -965,10 +988,17 @@ impl Fsm for SelectToolFsmState { let fill_color = Some(COLOR_OVERLAY_BLUE_05); let polygon = &tool_data.lasso_polygon; + // Animate the dash offset to produce the "marching ants" effect. The dash pattern repeats every 8 px (4 px dash + 4 px gap), + // so wrapping the time to [0, 8) via the modulo gives a smooth, continuously looping animation. + // MARCHING_ANTS_PIXELS_PER_SECOND controls how fast the dashes march around the selection border. + const MARCHING_ANTS_PIXELS_PER_SECOND: f64 = 100.; // How many pixels the pattern advances per second + const MARCHING_ANTS_PERIOD: f64 = 8.; // One full cycle = dash length (4 px) + gap length (4 px) + let marching_ants_offset = (overlay_context.animation_time / 1000. * MARCHING_ANTS_PIXELS_PER_SECOND) % MARCHING_ANTS_PERIOD; + match (selection_shape, current_selection_mode) { - (SelectionShapeType::Box, SelectionMode::Enclosed) => overlay_context.dashed_quad(quad, None, fill_color, Some(4.), Some(4.), Some(0.5)), - (SelectionShapeType::Lasso, SelectionMode::Enclosed) => overlay_context.dashed_polygon(polygon, None, fill_color, Some(4.), Some(4.), Some(0.5)), + (SelectionShapeType::Box, SelectionMode::Enclosed) => overlay_context.dashed_quad(quad, None, fill_color, Some(4.), Some(4.), Some(marching_ants_offset)), + (SelectionShapeType::Lasso, SelectionMode::Enclosed) => overlay_context.dashed_polygon(polygon, None, fill_color, Some(4.), Some(4.), Some(marching_ants_offset)), (SelectionShapeType::Box, _) => overlay_context.quad(quad, None, fill_color), (SelectionShapeType::Lasso, _) => overlay_context.polygon(polygon, None, fill_color), } @@ -1125,6 +1155,8 @@ impl Fsm for SelectToolFsmState { } } else { let selection_shape = if input.keyboard.key(lasso_select) { SelectionShapeType::Lasso } else { SelectionShapeType::Box }; + // Subscribe to animation frames so the marching ants selection border animates continuously. + tool_data.start_marching_ants(responses); SelectToolFsmState::Drawing { selection_shape, has_drawn: false } } }; @@ -1556,7 +1588,8 @@ impl Fsm for SelectToolFsmState { } tool_data.lasso_polygon.clear(); - + // Unsubscribe from animation frames now that the selection box is finalized. + tool_data.stop_marching_ants(responses); responses.add(OverlaysMessage::Draw); let selection = tool_data.nested_selection_behavior; @@ -1603,6 +1636,8 @@ impl Fsm for SelectToolFsmState { responses.add(DocumentMessage::AbortTransaction); tool_data.snap_manager.cleanup(responses); tool_data.lasso_polygon.clear(); + // Unsubscribe from marching ants animation in case we were in Drawing state. + tool_data.stop_marching_ants(responses); responses.add(OverlaysMessage::Draw); let selection = tool_data.nested_selection_behavior; diff --git a/libraries/ai-models/Cargo.toml b/libraries/ai-models/Cargo.toml new file mode 100644 index 0000000000..a52752c456 --- /dev/null +++ b/libraries/ai-models/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "ai-models" +version = "0.0.0" +rust-version = "1.88" +edition = "2024" +authors = ["Graphite Authors "] +description = "Model Registry & Metadata Schema for Graphite AI capabilities" +license = "Apache-2.0" + +[dependencies] +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +log = { workspace = true } diff --git a/libraries/ai-models/src/lib.rs b/libraries/ai-models/src/lib.rs new file mode 100644 index 0000000000..0bdcfd95eb --- /dev/null +++ b/libraries/ai-models/src/lib.rs @@ -0,0 +1,14 @@ +//! # AI Model Registry & Metadata Schema +//! +//! This crate is the central nervous system for Graphite's AI capabilities. +//! It manages how the editor identifies, validates, and prepares to launch +//! various machine learning models through three logical layers: +//! +//! 1. **[`ModelManifest`]** – the serialisable "identity card" of a model. +//! 2. **[`License`]** – a safety gate that blocks non-permissive models. +//! 3. **[`ModelRegistry`]** – the centralised service that tracks every model's lifecycle. +pub mod manifest; +pub mod registry; + +pub use manifest::{License, ModelManifest, TensorShape}; +pub use registry::{ModelRegistry, ModelStatus, RegistryError}; diff --git a/libraries/ai-models/src/manifest.rs b/libraries/ai-models/src/manifest.rs new file mode 100644 index 0000000000..5c9e3bca74 --- /dev/null +++ b/libraries/ai-models/src/manifest.rs @@ -0,0 +1,178 @@ +//! Model manifest – the serialisable identity of a machine-learning model. +use serde::{Deserialize, Serialize}; + +/// The shape of a single model input or output tensor. +/// +/// Each element is the size of that dimension; a value of `None` indicates a +/// dynamic (batch / variable-length) dimension. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct TensorShape(pub Vec>); + +impl TensorShape { + /// Convenience constructor from a fixed-size shape. + pub fn fixed(dims: impl IntoIterator) -> Self { + Self(dims.into_iter().map(Some).collect()) + } + + /// Convenience constructor that marks the first dimension as dynamic + /// (batch) and the rest as fixed. + pub fn batched(dims: impl IntoIterator) -> Self { + let mut shape: Vec> = dims.into_iter().map(Some).collect(); + if let Some(first) = shape.first_mut() { + *first = None; + } + Self(shape) + } +} + +/// The open-source licence under which a model is distributed. +/// +/// Only the three variants listed in [`License::is_permissive`] are considered +/// compatible with Graphite's licensing standards. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum License { + /// MIT Licence + Mit, + /// BSD 2-Clause or 3-Clause licence + Bsd, + /// Apache Licence 2.0 + Apache2, + /// Any other licence whose SPDX identifier is not explicitly recognised. + Other(String), +} + +impl License { + /// Returns `true` only for licences that are permissive enough to be + /// distributed alongside Graphite without additional restrictions. + /// + /// Currently the permissive set is **MIT**, **BSD**, and **Apache-2.0**. + pub fn is_permissive(&self) -> bool { + matches!(self, License::Mit | License::Bsd | License::Apache2) + } +} + +impl std::fmt::Display for License { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + License::Mit => write!(f, "MIT"), + License::Bsd => write!(f, "BSD"), + License::Apache2 => write!(f, "Apache-2.0"), + License::Other(id) => write!(f, "{id}"), + } + } +} + +/// The complete identity and capability description of a machine-learning model. +/// +/// A manifest can be serialised to / deserialised from JSON so that it can be +/// shipped alongside the model weights on the CDN. +/// +/// # Example +/// ```rust +/// use ai_models::manifest::{License, ModelManifest, TensorShape}; +/// +/// let manifest = ModelManifest { +/// model_id: "sam2-base".to_string(), +/// version: "1.0.0".to_string(), +/// display_name: "SAM 2 (base)".to_string(), +/// description: "Segment Anything Model 2, base variant".to_string(), +/// license: License::Apache2, +/// input_shapes: vec![TensorShape::batched([3, 1024, 1024])], +/// output_shapes: vec![TensorShape::batched([1, 1024, 1024])], +/// download_url: "https://cdn.graphite.art/models/sam2-base/weights.bin".to_string(), +/// size_bytes: 358_000_000, +/// }; +/// +/// assert!(manifest.license.is_permissive()); +/// assert_eq!(manifest.model_id, "sam2-base"); +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct ModelManifest { + /// Short machine-readable identifier, e.g. `"sam2-base"`. + pub model_id: String, + + /// Semantic version of the model weights, e.g. `"1.0.0"`. + pub version: String, + + /// Human-readable name shown in the UI, e.g. `"SAM 2 (base)"`. + pub display_name: String, + + /// Short description of what the model does. + pub description: String, + + /// Licence under which the model weights are distributed. + pub license: License, + + /// Expected shapes of the model's input tensors (one entry per input port). + pub input_shapes: Vec, + + /// Expected shapes of the model's output tensors (one entry per output port). + pub output_shapes: Vec, + + /// URL from which the model weights can be downloaded. + pub download_url: String, + + /// Total download size in bytes. + pub size_bytes: u64, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sam2_manifest() -> ModelManifest { + ModelManifest { + model_id: "sam2-base".to_string(), + version: "1.0.0".to_string(), + display_name: "SAM 2 (base)".to_string(), + description: "Segment Anything Model 2, base variant".to_string(), + license: License::Apache2, + input_shapes: vec![TensorShape::batched([3, 1024, 1024])], + output_shapes: vec![TensorShape::batched([1, 1024, 1024])], + download_url: "https://cdn.graphite.art/models/sam2-base/weights.bin".to_string(), + size_bytes: 358_000_000, + } + } + + #[test] + fn apache2_is_permissive() { + assert!(License::Apache2.is_permissive()); + } + + #[test] + fn mit_is_permissive() { + assert!(License::Mit.is_permissive()); + } + + #[test] + fn bsd_is_permissive() { + assert!(License::Bsd.is_permissive()); + } + + #[test] + fn other_is_not_permissive() { + assert!(!License::Other("GPL-3.0".to_string()).is_permissive()); + } + + #[test] + fn manifest_roundtrip_json() { + let manifest = sam2_manifest(); + let json = serde_json::to_string(&manifest).expect("serialise"); + let back: ModelManifest = serde_json::from_str(&json).expect("deserialise"); + assert_eq!(manifest, back); + } + + #[test] + fn tensor_shape_fixed() { + let shape = TensorShape::fixed([3, 224, 224]); + assert_eq!(shape.0, vec![Some(3), Some(224), Some(224)]); + } + + #[test] + fn tensor_shape_batched_first_dim_is_none() { + let shape = TensorShape::batched([3, 1024, 1024]); + assert_eq!(shape.0[0], None); + assert_eq!(shape.0[1], Some(1024)); + } +} diff --git a/libraries/ai-models/src/registry.rs b/libraries/ai-models/src/registry.rs new file mode 100644 index 0000000000..6e319f0130 --- /dev/null +++ b/libraries/ai-models/src/registry.rs @@ -0,0 +1,310 @@ +//! Model registry – the centralised service that tracks every model's lifecycle. +use std::collections::HashMap; + +use thiserror::Error; + +use crate::manifest::ModelManifest; + +/// The lifecycle state of a single model. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ModelStatus { + /// The model weights are present on disk and have been verified. + Ready, + /// The model is listed in the registry but has not been downloaded yet. + NotStarted, + /// The model weights are currently being fetched from the CDN. + Downloading { + /// Download progress in the range `[0.0, 1.0]`. + progress: f32, + }, + /// A previous download or verification attempt failed. + Failed { + /// Human-readable reason for the failure. + reason: String, + }, +} + +impl PartialEq for ModelStatus { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ModelStatus::Ready, ModelStatus::Ready) => true, + (ModelStatus::NotStarted, ModelStatus::NotStarted) => true, + (ModelStatus::Downloading { progress: a }, ModelStatus::Downloading { progress: b }) => a.to_bits() == b.to_bits(), + (ModelStatus::Failed { reason: a }, ModelStatus::Failed { reason: b }) => a == b, + _ => false, + } + } +} + +impl std::fmt::Display for ModelStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ModelStatus::Ready => write!(f, "Ready"), + ModelStatus::NotStarted => write!(f, "Not Started"), + ModelStatus::Downloading { progress } => write!(f, "Downloading ({:.0}%)", progress * 100.0), + ModelStatus::Failed { reason } => write!(f, "Failed: {reason}"), + } + } +} + +/// Errors that can be returned by [`ModelRegistry`] operations. +#[derive(Debug, Error, PartialEq)] +pub enum RegistryError { + /// The model ID was not found in the registry. + #[error("model '{0}' is not registered")] + NotFound(String), + + /// The model's licence is not permissive enough for Graphite to use. + #[error("model '{id}' uses a non-permissive licence ({license}); only MIT, BSD, and Apache-2.0 are accepted")] + LicenceNotPermissive { + /// The model identifier. + id: String, + /// The non-permissive licence that was rejected. + license: String, + }, + + /// Attempted to register a model whose ID already exists in the registry. + #[error("model '{0}' is already registered")] + AlreadyRegistered(String), +} + +/// Entry stored for each model inside the registry. +#[derive(Clone, Debug)] +struct RegistryEntry { + manifest: ModelManifest, + status: ModelStatus, +} + +/// The central manager that keeps track of which models are available and their +/// current lifecycle status. +/// +/// # Usage +/// ```rust +/// use ai_models::manifest::{License, ModelManifest, TensorShape}; +/// use ai_models::registry::{ModelRegistry, ModelStatus}; +/// +/// let mut registry = ModelRegistry::new(); +/// +/// let manifest = ModelManifest { +/// model_id: "sam2-base".to_string(), +/// version: "1.0.0".to_string(), +/// display_name: "SAM 2 (base)".to_string(), +/// description: "Segment Anything Model 2".to_string(), +/// license: License::Apache2, +/// input_shapes: vec![TensorShape::batched([3, 1024, 1024])], +/// output_shapes: vec![TensorShape::batched([1, 1024, 1024])], +/// download_url: "https://cdn.graphite.art/models/sam2-base/weights.bin".to_string(), +/// size_bytes: 358_000_000, +/// }; +/// +/// registry.register(manifest).expect("register model"); +/// assert_eq!(registry.status("sam2-base").unwrap(), &ModelStatus::NotStarted); +/// ``` +#[derive(Debug, Default)] +pub struct ModelRegistry { + entries: HashMap, +} + +impl ModelRegistry { + /// Creates a new, empty registry. + pub fn new() -> Self { + Self::default() + } + + /// Registers a new model manifest. + /// + /// # Errors + /// * [`RegistryError::LicenceNotPermissive`] – the manifest's licence is not MIT, BSD, or Apache-2.0. + /// * [`RegistryError::AlreadyRegistered`] – a model with the same `model_id` already exists. + pub fn register(&mut self, manifest: ModelManifest) -> Result<(), RegistryError> { + if !manifest.license.is_permissive() { + return Err(RegistryError::LicenceNotPermissive { + id: manifest.model_id.clone(), + license: manifest.license.to_string(), + }); + } + + if self.entries.contains_key(&manifest.model_id) { + return Err(RegistryError::AlreadyRegistered(manifest.model_id)); + } + + log::info!("Registering model '{}' v{}", manifest.model_id, manifest.version); + + self.entries.insert( + manifest.model_id.clone(), + RegistryEntry { + manifest, + status: ModelStatus::NotStarted, + }, + ); + Ok(()) + } + + /// Returns the current [`ModelStatus`] for `model_id`. + /// + /// # Errors + /// [`RegistryError::NotFound`] if the model is not registered. + pub fn status(&self, model_id: &str) -> Result<&ModelStatus, RegistryError> { + self.entries + .get(model_id) + .map(|e| &e.status) + .ok_or_else(|| RegistryError::NotFound(model_id.to_string())) + } + + /// Updates the status of a registered model. + /// + /// # Errors + /// [`RegistryError::NotFound`] if the model is not registered. + pub fn set_status(&mut self, model_id: &str, status: ModelStatus) -> Result<(), RegistryError> { + self.entries + .get_mut(model_id) + .map(|e| { + log::debug!("Model '{}' status → {status}", model_id); + e.status = status; + }) + .ok_or_else(|| RegistryError::NotFound(model_id.to_string())) + } + + /// Returns the [`ModelManifest`] for `model_id`. + /// + /// # Errors + /// [`RegistryError::NotFound`] if the model is not registered. + pub fn manifest(&self, model_id: &str) -> Result<&ModelManifest, RegistryError> { + self.entries + .get(model_id) + .map(|e| &e.manifest) + .ok_or_else(|| RegistryError::NotFound(model_id.to_string())) + } + + /// Returns `true` if the model is registered **and** its status is [`ModelStatus::Ready`]. + pub fn is_ready(&self, model_id: &str) -> bool { + self.entries.get(model_id).is_some_and(|e| matches!(e.status, ModelStatus::Ready)) + } + + /// Returns an iterator over all registered model IDs. + pub fn model_ids(&self) -> impl Iterator { + self.entries.keys().map(String::as_str) + } + + /// Returns a list of (model_id, status) pairs for every registered model. + pub fn all_statuses(&self) -> Vec<(&str, &ModelStatus)> { + let mut pairs: Vec<(&str, &ModelStatus)> = self.entries.iter().map(|(id, e)| (id.as_str(), &e.status)).collect(); + pairs.sort_by_key(|(id, _)| *id); + pairs + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::manifest::{License, TensorShape}; + + fn make_manifest(id: &str, license: License) -> ModelManifest { + ModelManifest { + model_id: id.to_string(), + version: "1.0.0".to_string(), + display_name: id.to_string(), + description: String::new(), + license, + input_shapes: vec![TensorShape::fixed([3, 224, 224])], + output_shapes: vec![TensorShape::fixed([1000])], + download_url: format!("https://cdn.example.com/{id}/weights.bin"), + size_bytes: 1_000_000, + } + } + + #[test] + fn register_and_default_status_is_not_started() { + let mut registry = ModelRegistry::new(); + registry.register(make_manifest("model-a", License::Mit)).unwrap(); + assert_eq!(registry.status("model-a").unwrap(), &ModelStatus::NotStarted); + } + + #[test] + fn register_non_permissive_licence_is_blocked() { + let mut registry = ModelRegistry::new(); + let err = registry.register(make_manifest("bad-model", License::Other("GPL-3.0".to_string()))).unwrap_err(); + assert!(matches!(err, RegistryError::LicenceNotPermissive { .. })); + } + + #[test] + fn duplicate_registration_returns_error() { + let mut registry = ModelRegistry::new(); + registry.register(make_manifest("model-b", License::Bsd)).unwrap(); + let err = registry.register(make_manifest("model-b", License::Bsd)).unwrap_err(); + assert_eq!(err, RegistryError::AlreadyRegistered("model-b".to_string())); + } + + #[test] + fn set_status_ready() { + let mut registry = ModelRegistry::new(); + registry.register(make_manifest("model-c", License::Apache2)).unwrap(); + registry.set_status("model-c", ModelStatus::Ready).unwrap(); + assert!(registry.is_ready("model-c")); + } + + #[test] + fn set_status_downloading() { + let mut registry = ModelRegistry::new(); + registry.register(make_manifest("model-d", License::Mit)).unwrap(); + registry.set_status("model-d", ModelStatus::Downloading { progress: 0.42 }).unwrap(); + assert!(!registry.is_ready("model-d")); + assert_eq!(registry.status("model-d").unwrap(), &ModelStatus::Downloading { progress: 0.42 }); + } + + #[test] + fn set_status_failed() { + let mut registry = ModelRegistry::new(); + registry.register(make_manifest("model-e", License::Apache2)).unwrap(); + registry.set_status("model-e", ModelStatus::Failed { reason: "network error".to_string() }).unwrap(); + assert!(!registry.is_ready("model-e")); + } + + #[test] + fn status_of_unknown_model_returns_not_found() { + let registry = ModelRegistry::new(); + assert_eq!(registry.status("ghost"), Err(RegistryError::NotFound("ghost".to_string()))); + } + + #[test] + fn is_ready_returns_false_for_unknown_model() { + let registry = ModelRegistry::new(); + assert!(!registry.is_ready("ghost")); + } + + #[test] + fn manifest_retrieval() { + let mut registry = ModelRegistry::new(); + let m = make_manifest("model-f", License::Mit); + registry.register(m.clone()).unwrap(); + assert_eq!(registry.manifest("model-f").unwrap().model_id, "model-f"); + } + + #[test] + fn all_statuses_is_sorted() { + let mut registry = ModelRegistry::new(); + registry.register(make_manifest("zzz-model", License::Mit)).unwrap(); + registry.register(make_manifest("aaa-model", License::Bsd)).unwrap(); + let statuses = registry.all_statuses(); + assert_eq!(statuses[0].0, "aaa-model"); + assert_eq!(statuses[1].0, "zzz-model"); + } + + #[test] + fn model_status_display() { + assert_eq!(ModelStatus::Ready.to_string(), "Ready"); + assert_eq!(ModelStatus::NotStarted.to_string(), "Not Started"); + assert_eq!(ModelStatus::Downloading { progress: 0.5 }.to_string(), "Downloading (50%)"); + assert_eq!(ModelStatus::Failed { reason: "err".to_string() }.to_string(), "Failed: err"); + } +} +Cargo.lock +Cargo.toml +libraries/ai-models +Cargo.toml +src +lib.rs +manifest.rs +registry.rs + diff --git a/node-graph/libraries/rendering/src/render_ext.rs b/node-graph/libraries/rendering/src/render_ext.rs index d7736f804b..2de73ddc41 100644 --- a/node-graph/libraries/rendering/src/render_ext.rs +++ b/node-graph/libraries/rendering/src/render_ext.rs @@ -107,10 +107,23 @@ impl RenderExt for Stroke { render_params: &RenderParams, ) -> Self::Output { // Don't render a stroke at all if it would be invisible - let Some(color) = self.color else { return String::new() }; if !self.has_renderable_stroke() { return String::new(); } + let paint = match (&self.gradient, self.color) { + (Some(gradient), _) => { + let gradient_id = gradient.render(_svg_defs, _element_transform, _stroke_transform, _bounds, _transformed_bounds, render_params); + format!(r##" stroke="url('#{gradient_id}')""##) + } + (_, Some(color)) => { + let mut result = format!(r##" stroke="#{}""##, color.to_rgb_hex_srgb_from_gamma()); + if color.a() < 1. { + let _ = write!(result, r#" stroke-opacity="{}""#, (color.a() * 1000.).round() / 1000.); + } + result + } + _ => return String::new(), + }; let default_weight = if self.align != StrokeAlign::Center && render_params.aligned_strokes { 1. / 2. } else { 1. }; @@ -125,10 +138,7 @@ impl RenderExt for Stroke { let paint_order = (self.paint_order != PaintOrder::StrokeAbove || render_params.override_paint_order).then_some(PaintOrder::StrokeBelow); // Render the needed stroke attributes - let mut attributes = format!(r##" stroke="#{}""##, color.to_rgb_hex_srgb_from_gamma()); - if color.a() < 1. { - let _ = write!(&mut attributes, r#" stroke-opacity="{}""#, (color.a() * 1000.).round() / 1000.); - } + let mut attributes = paint; if let Some(mut weight) = weight { if stroke_align.is_some() && render_params.aligned_strokes { weight *= 2.; diff --git a/node-graph/libraries/rendering/src/renderer.rs b/node-graph/libraries/rendering/src/renderer.rs index c589e6177d..f9a42d03e3 100644 --- a/node-graph/libraries/rendering/src/renderer.rs +++ b/node-graph/libraries/rendering/src/renderer.rs @@ -1115,34 +1115,90 @@ impl Render for Table { }; let do_stroke = |scene: &mut Scene, width_scale: f64| { - if let Some(stroke) = row.element.style.stroke() { - let color = match stroke.color { - Some(color) => peniko::Color::new([color.r(), color.g(), color.b(), color.a()]), - None => peniko::Color::TRANSPARENT, - }; - let cap = match stroke.cap { + if let Some(stroke_style) = row.element.style.stroke() { + let cap = match stroke_style.cap { StrokeCap::Butt => Cap::Butt, StrokeCap::Round => Cap::Round, StrokeCap::Square => Cap::Square, }; - let join = match stroke.join { + let join = match stroke_style.join { StrokeJoin::Miter => Join::Miter, StrokeJoin::Bevel => Join::Bevel, StrokeJoin::Round => Join::Round, }; - let dash_pattern = stroke.dash_lengths.iter().map(|l| l.max(0.)).collect(); - let stroke = kurbo::Stroke { - width: stroke.weight * width_scale, - miter_limit: stroke.join_miter_limit, + let dash_pattern = stroke_style.dash_lengths.iter().map(|l| l.max(0.)).collect(); + let kurbo_stroke = kurbo::Stroke { + width: stroke_style.weight * width_scale, + miter_limit: stroke_style.join_miter_limit, join, start_cap: cap, end_cap: cap, dash_pattern, - dash_offset: stroke.dash_offset, + dash_offset: stroke_style.dash_offset, }; - if stroke.width > 0. { - scene.stroke(&stroke, kurbo::Affine::new(element_transform.to_cols_array()), color, None, &path); + if kurbo_stroke.width > 0. { + let (brush, brush_transform) = if let Some(gradient) = stroke_style.gradient.as_ref() { + let mut stops = peniko::ColorStops::new(); + for (position, color, _) in gradient.stops.interpolated_samples() { + stops.push(peniko::ColorStop { + offset: position as f32, + color: peniko::color::DynamicColor::from_alpha_color(peniko::Color::new([color.r(), color.g(), color.b(), color.a()])), + }); + } + + let bounds = row.element.nonzero_bounding_box(); + let bound_transform = DAffine2::from_scale_angle_translation(bounds[1] - bounds[0], 0., bounds[0]); + + let inverse_parent_transform = if parent_transform.matrix2.determinant() != 0. { + parent_transform.inverse() + } else { + Default::default() + }; + let mod_points = inverse_parent_transform * multiplied_transform * bound_transform; + + let start = mod_points.transform_point2(gradient.start); + let end = mod_points.transform_point2(gradient.end); + + let brush = peniko::Brush::Gradient(peniko::Gradient { + kind: match gradient.gradient_type { + GradientType::Linear => peniko::LinearGradientPosition { + start: to_point(start), + end: to_point(end), + } + .into(), + GradientType::Radial => { + let radius = start.distance(end); + peniko::RadialGradientPosition { + start_center: to_point(start), + start_radius: 0., + end_center: to_point(start), + end_radius: radius as f32, + } + .into() + } + }, + stops, + interpolation_alpha_space: peniko::InterpolationAlphaSpace::Premultiplied, + ..Default::default() + }); + let inverse_element_transform = if element_transform.matrix2.determinant() != 0. { + element_transform.inverse() + } else { + Default::default() + }; + let brush_transform = kurbo::Affine::new((inverse_element_transform * parent_transform).to_cols_array()); + + (brush, Some(brush_transform)) + } else { + let color = stroke_style + .color + .map(|color| peniko::Color::new([color.r(), color.g(), color.b(), color.a()])) + .unwrap_or(peniko::Color::TRANSPARENT); + (peniko::Brush::Solid(color), None) + }; + + scene.stroke(&kurbo_stroke, kurbo::Affine::new(element_transform.to_cols_array()), &brush, brush_transform, &path); } } }; diff --git a/node-graph/libraries/vector-types/src/vector/style.rs b/node-graph/libraries/vector-types/src/vector/style.rs index 0828c4e6f2..ce096e7ad1 100644 --- a/node-graph/libraries/vector-types/src/vector/style.rs +++ b/node-graph/libraries/vector-types/src/vector/style.rs @@ -304,6 +304,8 @@ fn daffine2_identity() -> DAffine2 { pub struct Stroke { /// Stroke color pub color: Option, + /// Optional gradient paint. If set, overrides `color`. + pub gradient: Option, /// Line thickness pub weight: f64, pub dash_lengths: Vec, @@ -325,6 +327,7 @@ pub struct Stroke { impl std::hash::Hash for Stroke { fn hash(&self, state: &mut H) { self.color.hash(state); + self.gradient.hash(state); self.weight.to_bits().hash(state); { self.dash_lengths.len().hash(state); @@ -344,6 +347,7 @@ impl Stroke { pub const fn new(color: Option, weight: f64) -> Self { Self { color, + gradient: None, weight, dash_lengths: Vec::new(), dash_offset: 0., @@ -359,6 +363,12 @@ impl Stroke { pub fn lerp(&self, other: &Self, time: f64) -> Self { Self { color: self.color.map(|color| color.lerp(&other.color.unwrap_or(color), time as f32)), + gradient: match (&self.gradient, &other.gradient) { + (Some(a), Some(b)) => Some(a.lerp(b, time)), + (Some(a), None) if time < 0.5 => Some(a.clone()), + (None, Some(b)) if time >= 0.5 => Some(b.clone()), + _ => None, + }, weight: self.weight + (other.weight - self.weight) * time, dash_lengths: self.dash_lengths.iter().zip(other.dash_lengths.iter()).map(|(a, b)| a + (b - a) * time).collect(), dash_offset: self.dash_offset + (other.dash_offset - self.dash_offset) * time, @@ -398,6 +408,10 @@ impl Stroke { pub fn color(&self) -> Option { self.color } + /// Get the current stroke gradient. + pub fn gradient(&self) -> Option<&Gradient> { + self.gradient.as_ref() + } /// Get the current stroke weight. pub fn weight(&self) -> f64 { @@ -440,9 +454,20 @@ impl Stroke { pub fn with_color(mut self, color: &Option) -> Option { self.color = *color; + if color.is_some() { + self.gradient = None; + } Some(self) } + /// Set the stroke's gradient, replacing the color if necessary. + pub fn with_gradient(mut self, gradient: Option) -> Self { + self.gradient = gradient; + if self.gradient.is_some() { + self.color = None; + } + self + } pub fn with_weight(mut self, weight: f64) -> Self { self.weight = weight; @@ -488,7 +513,14 @@ impl Stroke { } pub fn has_renderable_stroke(&self) -> bool { - self.weight > 0. && self.color.is_some_and(|color| color.a() != 0.) + if self.weight <= 0. { + return false; + } + + let has_color_alpha = self.color.is_some_and(|color| color.a() != 0.); + let has_gradient_alpha = self.gradient.as_ref().is_some_and(|gradient| gradient.stops.color.iter().any(|color| color.a() != 0.)); + + has_color_alpha || has_gradient_alpha } } @@ -498,6 +530,7 @@ impl Default for Stroke { Self { weight: 0., color: Some(Color::from_rgba8_srgb(0, 0, 0, 255)), + gradient: None, dash_lengths: Vec::new(), dash_offset: 0., cap: StrokeCap::Butt, @@ -530,7 +563,14 @@ impl std::fmt::Display for PathStyle { let fill = &self.fill; let stroke = match &self.stroke { - Some(stroke) => format!("#{} (Weight: {} px)", stroke.color.map_or("None".to_string(), |c| c.to_rgba_hex_srgb()), stroke.weight), + Some(stroke) => { + let paint = match (&stroke.gradient, stroke.color) { + (Some(_), _) => "Gradient".to_string(), + (_, Some(color)) => format!("#{}", color.to_rgba_hex_srgb()), + _ => "None".to_string(), + }; + format!("{paint} (Weight: {} px)", stroke.weight) + } None => "None".to_string(), }; diff --git a/node-graph/nodes/vector/src/vector_nodes.rs b/node-graph/nodes/vector/src/vector_nodes.rs index f33d1a17fc..aea254524b 100644 --- a/node-graph/nodes/vector/src/vector_nodes.rs +++ b/node-graph/nodes/vector/src/vector_nodes.rs @@ -204,6 +204,7 @@ where { let stroke = Stroke { color: color.into(), + gradient: None, weight, dash_lengths: dash_lengths.into_vec(), dash_offset,