From 0399d1af1ac0be89f84b41bab9253e4aeeb49c8f Mon Sep 17 00:00:00 2001 From: firestar99 Date: Sun, 24 May 2026 19:28:41 +0200 Subject: [PATCH] GPU: add support for N-many input images and not just exactly one --- .../per_pixel_adjust_runtime.rs | 141 ++++++++++-------- .../src/shader_nodes/per_pixel_adjust.rs | 22 +-- node-graph/nodes/raster/src/blending_nodes.rs | 2 +- 3 files changed, 91 insertions(+), 74 deletions(-) diff --git a/node-graph/libraries/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs b/node-graph/libraries/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs index f5d793e9a2..ec66b4ed2d 100644 --- a/node-graph/libraries/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs +++ b/node-graph/libraries/wgpu-executor/src/shader_runtime/per_pixel_adjust_runtime.rs @@ -33,7 +33,8 @@ impl PerPixelAdjustShaderRuntime { } impl ShaderRuntime { - pub async fn run_per_pixel_adjust(&self, shaders: &Shaders<'_>, textures: List>, args: Option<&T>) -> List> { + pub async fn run_per_pixel_adjust(&self, shaders: &Shaders<'_>, textures: &[List>], args: Option<&T>) -> List> { + assert_eq!(shaders.input_images, textures.len()); let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; let pipeline = cache .entry(shaders.fragment_shader_name.to_owned()) @@ -54,11 +55,13 @@ impl ShaderRuntime { pub struct Shaders<'a> { pub wgsl_shader: &'a str, pub fragment_shader_name: &'a str, + pub input_images: usize, pub has_uniform: bool, } pub struct PerPixelAdjustGraphicsPipeline { name: String, + input_images: usize, has_uniform: bool, pipeline: wgpu::RenderPipeline, } @@ -76,32 +79,23 @@ impl PerPixelAdjustGraphicsPipeline { source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)), }); - let entries: &[_] = if info.has_uniform { - &[ - BindGroupLayoutEntry { - binding: 0, - visibility: ShaderStages::FRAGMENT, - ty: BindingType::Buffer { - ty: BufferBindingType::Storage { read_only: true }, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }, - BindGroupLayoutEntry { - binding: 1, - visibility: ShaderStages::FRAGMENT, - ty: BindingType::Texture { - sample_type: TextureSampleType::Float { filterable: false }, - view_dimension: TextureViewDimension::D2, - multisampled: false, - }, - count: None, + let mut binding_alloc = Counter::default(); + let mut entries = Vec::new(); + if info.has_uniform { + entries.push(BindGroupLayoutEntry { + binding: binding_alloc.alloc(), + visibility: ShaderStages::FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, }, - ] - } else { - &[BindGroupLayoutEntry { - binding: 0, + count: None, + }); + } + for _ in 0..info.input_images { + entries.push(BindGroupLayoutEntry { + binding: binding_alloc.alloc(), visibility: ShaderStages::FRAGMENT, ty: BindingType::Texture { sample_type: TextureSampleType::Float { filterable: false }, @@ -109,13 +103,13 @@ impl PerPixelAdjustGraphicsPipeline { multisampled: false, }, count: None, - }] - }; + }); + } let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor { label: Some(&format!("PerPixelAdjust {name} PipelineLayout")), bind_group_layouts: &[Some(&device.create_bind_group_layout(&BindGroupLayoutDescriptor { label: Some(&format!("PerPixelAdjust {name} BindGroupLayout 0")), - entries, + entries: &entries, }))], ..Default::default() }); @@ -157,61 +151,73 @@ impl PerPixelAdjustGraphicsPipeline { pipeline, name, has_uniform: info.has_uniform, + input_images: info.input_images, } } - pub fn dispatch(&self, context: &WgpuContext, textures: List>, arg_buffer: Option) -> List> { + pub fn dispatch(&self, context: &WgpuContext, in_textures: &[List>], arg_buffer: Option) -> List> { assert_eq!(self.has_uniform, arg_buffer.is_some()); + assert_eq!(self.input_images, in_textures.len()); let device = &context.device; let name = self.name.as_str(); + // Assumption: when we have multiple input images to our node, each input's List of images can have a different + // length. Only process the minimum between all input images, same as `impl Blend for List>`. + let dispatch_cnt = match in_textures.iter().map(|t| t.len()).min() { + None => { + return List::new(); + } + Some(e) => e, + }; + let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(&format!("{name} cmd encoder")), }); - let out = (0..textures.len()) - .map(|index| { - let element = textures.element(index).unwrap(); - let tex_in = &element.texture; - let view_in = tex_in.create_view(&TextureViewDescriptor::default()); - let format = tex_in.format(); - - let entries: &[_] = if let Some(arg_buffer) = arg_buffer.as_ref() { - &[ - BindGroupEntry { - binding: 0, - resource: BindingResource::Buffer(BufferBinding { - buffer: arg_buffer, - offset: 0, - size: None, - }), - }, - BindGroupEntry { - binding: 1, - resource: BindingResource::TextureView(&view_in), - }, - ] - } else { - &[BindGroupEntry { - binding: 0, + let out = (0..dispatch_cnt) + .map(|dispatch_id| { + let mut binding_alloc = Counter::default(); + let mut entries = Vec::new(); + if let Some(arg_buffer) = arg_buffer.as_ref() { + entries.push(BindGroupEntry { + binding: binding_alloc.alloc(), + resource: BindingResource::Buffer(BufferBinding { + buffer: arg_buffer, + offset: 0, + size: None, + }), + }); + } + let in_texture_views = in_textures.iter().map(|texture| { + let element = texture.element(dispatch_id).unwrap(); + element.texture.create_view(&TextureViewDescriptor::default()) + }).collect::>(); + for view_in in &in_texture_views { + entries.push(BindGroupEntry { + binding: binding_alloc.alloc(), resource: BindingResource::TextureView(&view_in), - }] - }; + }); + } + let bind_group = device.create_bind_group(&BindGroupDescriptor { label: Some(&format!("{name} bind group")), // `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that layout: &self.pipeline.get_bind_group_layout(0), - entries, + entries: &entries, }); + // Assumption: The output texture has the same size and format as the first input texture. Like the + // blend node, that writes the output directly back into the first texture. + let outref_list = &in_textures[0]; + let outref_tex = &outref_list.element(dispatch_id).unwrap().texture; let tex_out = device.create_texture(&TextureDescriptor { label: Some(&format!("{name} texture out")), - size: tex_in.size(), + size: outref_tex.size(), mip_level_count: 1, sample_count: 1, dimension: TextureDimension::D2, - format, + format: outref_tex.format(), usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST | wgpu::TextureUsages::COPY_SRC | wgpu::TextureUsages::RENDER_ATTACHMENT, - view_formats: &[format], + view_formats: &[outref_tex.format()], }); let view_out = tex_out.create_view(&TextureViewDescriptor::default()); @@ -233,7 +239,7 @@ impl PerPixelAdjustGraphicsPipeline { rp.set_bind_group(0, Some(&bind_group), &[]); rp.draw(0..3, 0..1); - let attributes = textures.clone_item_attributes(index); + let attributes = outref_list.clone_item_attributes(dispatch_id); Item::from_parts(Raster::new(GPU { texture: tex_out }), attributes) }) .collect::>(); @@ -241,3 +247,14 @@ impl PerPixelAdjustGraphicsPipeline { out } } + +#[derive(Clone, Debug, Default)] +pub struct Counter(pub u32); + +impl Counter { + pub fn alloc(&mut self) -> u32 { + let out = self.0; + self.0 += 1; + out + } +} diff --git a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs index f551e11821..3cafc4953e 100644 --- a/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs +++ b/node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs @@ -248,16 +248,15 @@ impl PerPixelAdjustCodegen<'_> { is_data_field: false, }); - // find exactly one gpu_image field, runtime doesn't support more than 1 atm - let gpu_image_field = { - let mut iter = fields.iter().filter(|f| matches!(f.ty, ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }))); - match (iter.next(), iter.next()) { - (Some(v), None) => Ok(v), - (Some(_), Some(more)) => Err(syn::Error::new_spanned(&more.pat_ident, "No more than one parameter must be annotated with `#[gpu_image]`")), - (None, _) => Err(syn::Error::new_spanned(&self.parsed.fn_name, "At least one parameter must be annotated with `#[gpu_image]`")), - }? - }; - let gpu_image = &gpu_image_field.pat_ident.ident; + // find gpu_image fields + let gpu_images = fields + .iter() + .filter_map(|f| match f.ty { + ParsedFieldType::Regular(RegularParsedField { gpu_image: true, .. }) => Some(&f.pat_ident.ident), + _ => None, + }) + .collect::>(); + let input_images = gpu_images.len(); // uniform buffer struct construction let has_uniform = self.has_uniform; @@ -287,7 +286,8 @@ impl PerPixelAdjustCodegen<'_> { wgsl_shader: crate::WGSL_SHADER, fragment_shader_name: super::#entry_point_name, has_uniform: #has_uniform, - }, #gpu_image, #uniform_buffer).await + input_images: #input_images, + }, &[#(#gpu_images),*], #uniform_buffer).await } }; diff --git a/node-graph/nodes/raster/src/blending_nodes.rs b/node-graph/nodes/raster/src/blending_nodes.rs index 1b6c4f6977..389954970a 100644 --- a/node-graph/nodes/raster/src/blending_nodes.rs +++ b/node-graph/nodes/raster/src/blending_nodes.rs @@ -141,7 +141,7 @@ pub fn apply_blend_mode(foreground: Color, background: Color, blend_mode: BlendM } } -#[node_macro::node(category("Raster"), cfg(feature = "std"))] +#[node_macro::node(category("Raster"), shader_node(PerPixelAdjust))] fn mix + Send>( _: impl Ctx, #[implementations(