Skip to main content

viewport_lib/resources/
gpu_marching_cubes.rs

1//! GPU marching cubes — Phase 17.
2//!
3//! Three-pass GPU compute pipeline for isosurface extraction:
4//!   1. Classify — computes case index and triangle count per cell.
5//!   2. Prefix sum — hierarchical exclusive scan to build triangle offsets.
6//!   3. Generate — interpolates vertex positions and normals into a vertex buffer.
7//!
8//! The output is drawn with a lightweight Phong render pipeline via `draw_indirect`.
9
10use bytemuck::{Pod, Zeroable};
11use wgpu::util::DeviceExt as _;
12
13use crate::{
14    geometry::marching_cubes::{TRI_TABLE, VolumeData},
15    resources::{DualPipeline, ViewportGpuResources},
16    scene::material::Material,
17};
18
19// ---------------------------------------------------------------------------
20// Public API
21// ---------------------------------------------------------------------------
22
/// Handle to a volume scalar field uploaded for GPU marching cubes.
///
/// Returned by [`ViewportGpuResources::upload_volume_for_mc`]. Pass to
/// [`GpuMarchingCubesJob`] to select which volume to triangulate each frame.
///
/// The wrapped value is the slot index of the volume inside the resource
/// manager's `mc_volumes` list. Slots freed by `remove_mc_volume` are handed out
/// again by later uploads, so a handle kept after removal may end up referring
/// to a different volume.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VolumeGpuId(pub(crate) usize);
29
/// One GPU marching cubes draw job submitted per frame.
///
/// The volume referenced by `volume_id` is triangulated on the GPU at `isovalue`
/// and drawn with `material`. No CPU readback occurs; the vertex count is
/// determined by an indirect draw call.
pub struct GpuMarchingCubesJob {
    /// Volume to triangulate (must remain alive).
    ///
    /// NOTE(review): "alive" presumably means the volume has not been passed to
    /// `remove_mc_volume` — confirm against the job dispatch code.
    pub volume_id: VolumeGpuId,
    /// Isovalue at which to extract the surface.
    pub isovalue: f32,
    /// Surface material (colour + roughness).
    pub material: Material,
    /// Pick ID for unified selection API. `0` = not selectable.
    pub id: u64,
    /// Per-item appearance overrides (hidden, unlit, opacity, wireframe).
    pub appearance: crate::scene::material::AppearanceSettings,
    /// If `true`, draws an outline ring around the marching cubes surface.
    pub selected: bool,
    /// CPU-side volume data for `pick()` and `pick_rect()`.
    ///
    /// When set, the CPU picker ray-marches the actual scalar field and detects
    /// isovalue crossings rather than falling back to the volume AABB. `None`
    /// means the item is not reachable by the CPU picking path.
    ///
    /// Held behind an `Arc`, so the same volume data can be shared between jobs
    /// without copying the scalar field.
    pub cpu_data: Option<std::sync::Arc<crate::geometry::marching_cubes::VolumeData>>,
}
55
56// ---------------------------------------------------------------------------
57// GPU-internal types
58// ---------------------------------------------------------------------------
59
/// GPU buffers for one Z-axis slab of an uploaded volume.
///
/// A slab covers `dims[2]` scalar Z-layers (`dims[2] - 1` cell layers).
/// Adjacent slabs share exactly one scalar Z-layer at their boundary so MC
/// edge interpolation produces no seams.
///
/// All buffers are allocated once at upload time, sized for the worst case of
/// 15 vertices (5 triangles) per cell.
pub(crate) struct McSlabGpuData {
    /// Scalar samples of the slab, one `f32` per node. Usage: STORAGE | COPY_DST.
    pub scalar_buf: wgpu::Buffer,
    /// One `u32` per slab cell. Usage: STORAGE.
    /// Presumably the per-cell triangle count written by the classify pass.
    pub counts_buf: wgpu::Buffer,
    /// One `u32` per slab cell. Usage: STORAGE.
    /// Presumably the per-cell marching-cubes case index.
    pub case_idx_buf: wgpu::Buffer,
    /// One `u32` per slab cell. Usage: STORAGE.
    /// Presumably the exclusive prefix sum of `counts_buf`.
    pub offsets_buf: wgpu::Buffer,
    /// One `u32` per 256-cell block (`block_count` entries). Usage: STORAGE.
    pub block_sums_buf: wgpu::Buffer,
    /// Generated vertices, six `f32` (24 bytes) each. Usage: STORAGE | VERTEX.
    pub vertex_buf: wgpu::Buffer,
    /// Four `u32` indirect draw arguments for the surface draw, initialised to
    /// `[0, 1, 0, 0]`. Usage: STORAGE | INDIRECT | COPY_DST.
    pub indirect_buf: wgpu::Buffer,
    /// Four `u32` indirect draw arguments for the wireframe draw, initialised to
    /// `[0, 1, 0, 0]`. Usage: STORAGE | INDIRECT | COPY_DST.
    pub wire_indirect_buf: wgpu::Buffer,
    /// `[nx, ny, slab_nz]`, counted in scalar layers (nodes), not cells.
    pub dims: [u32; 3],
    /// World-space origin of the slab; only `z` differs between slabs of one volume.
    pub origin: [f32; 3],
    /// Node spacing per axis, copied unchanged from the uploaded volume.
    pub spacing: [f32; 3],
    /// Number of cells in this slab: `(nx - 1) * (ny - 1) * (slab_nz - 1)`.
    pub cell_count: u32,
    /// `cell_count` divided by 256, rounded up.
    pub block_count: u32,
}
80
/// Persistent GPU resources for one uploaded volume, split into Z-axis slabs.
///
/// Z-axis chunking keeps every allocation within `device.limits().max_buffer_size`
/// regardless of volume size. The single-slab path is equivalent to the old layout.
pub(crate) struct McVolumeGpuData {
    /// Slabs in ascending Z order; slab `s` starts at cell layer
    /// `s * z_cells_per_slab` of the source volume.
    pub slabs: Vec<McSlabGpuData>,
    /// False after `remove_mc_volume` is called; the slot is reused lazily.
    ///
    /// `upload_volume_for_mc` overwrites the first slot it finds with
    /// `alive == false` before it considers growing the list.
    pub alive: bool,
}
90
/// Per-frame data for one MC job, consumed by the render phase.
pub(crate) struct McFrameData {
    /// NOTE(review): presumably the index of the job's volume in `mc_volumes`
    /// (the value wrapped by `VolumeGpuId`) — confirm against `run_mc_jobs`.
    pub volume_idx: usize,
    /// NOTE(review): presumably the bind group carrying the per-draw material
    /// uniform for the surface pipeline — confirm against `run_mc_jobs`.
    pub render_bg: wgpu::BindGroup,
    /// True if this job was submitted with `appearance.wireframe = true`.
    pub wireframe: bool,
    /// Per-slab bind groups for the wireframe pipeline (binding 0 = vertex storage buffer).
    pub wire_slab_bgs: Vec<wgpu::BindGroup>,
}
100
/// Per-selected MC job data for the outline mask pass.
pub(crate) struct McOutlineItem {
    /// Index into `mc_gpu_data` (frame-level array of processed MC jobs).
    pub mc_gpu_idx: usize,
    /// NOTE(review): the leading underscore suggests this buffer is stored only to
    /// keep it alive for as long as `mask_bind_group` references it — confirm.
    pub _uniform_buf: wgpu::Buffer,
    /// Bind group used when drawing this job into the outline mask.
    pub mask_bind_group: wgpu::BindGroup,
}
108
109// ---------------------------------------------------------------------------
110// Raw uniform buffer layouts (bytemuck-safe)
111// ---------------------------------------------------------------------------
112
/// Uniform block for the classify compute pass.
///
/// `repr(C)` with four 4-byte fields: 16 bytes, no implicit padding. The field
/// order must match the corresponding struct in `mc_classify.wgsl`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct ClassifyParams {
    // NOTE(review): nx/ny/nz are presumably the slab's scalar-node dimensions
    // (`McSlabGpuData::dims`) rather than cell counts — confirm against the shader.
    nx: u32,
    ny: u32,
    nz: u32,
    // Isovalue at which the surface is extracted.
    isovalue: f32,
}
121
/// Uniform block for the prefix-sum compute pass.
///
/// `repr(C)` with four 4-byte fields: 16 bytes, no implicit padding. The field
/// order must match the corresponding struct in `mc_prefix_sum.wgsl`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct PrefixSumParams {
    // Number of cells in the slab being scanned.
    cell_count: u32,
    // Number of 256-cell blocks covering `cell_count`.
    block_count: u32,
    // NOTE(review): presumably selects which stage of the hierarchical scan the
    // dispatch performs — confirm against the shader.
    level: u32,
    // Explicit padding that rounds the struct up to 16 bytes.
    _pad: u32,
}
130
/// Uniform block for the generate compute pass.
///
/// `repr(C)` with twelve 4-byte fields: 48 bytes, laid out as three 16-byte rows
/// (dims + isovalue, origin + pad, spacing + pad). The field order must match the
/// corresponding struct in `mc_generate.wgsl`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GenerateParams {
    // NOTE(review): nx/ny/nz are presumably the slab's scalar-node dimensions
    // (`McSlabGpuData::dims`) — confirm against the shader.
    nx: u32,
    ny: u32,
    nz: u32,
    // Isovalue at which the surface is extracted.
    isovalue: f32,
    // Origin of the slab (see `McSlabGpuData::origin`).
    origin_x: f32,
    origin_y: f32,
    origin_z: f32,
    // Explicit padding closing the second 16-byte row.
    _pad0: f32,
    // Node spacing along each axis.
    spacing_x: f32,
    spacing_y: f32,
    spacing_z: f32,
    // Explicit padding closing the third 16-byte row.
    _pad1: f32,
}
147
/// Raw per-draw surface material block.
///
/// `repr(C)`, 32 bytes in total: colour + roughness fill the first 16 bytes,
/// `unlit`, `opacity` and two padding words fill the second.
/// NOTE(review): presumably the uniform bound at binding 0 of the surface render
/// bind group — confirm against `mc_surface.wgsl`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct McSurfaceRaw {
    // RGB base colour.
    base_colour: [f32; 3],
    roughness: f32,
    // Boolean flag stored as a `u32`.
    unlit: u32,
    opacity: f32,
    // Explicit padding that rounds the struct up to 32 bytes.
    _pad: [u32; 2],
}
157
158// ---------------------------------------------------------------------------
159// Lookup table helpers
160// ---------------------------------------------------------------------------
161
162/// Triangle count per case: derived from TRI_TABLE by counting non-sentinel entries.
163fn case_triangle_count_table() -> [u32; 256] {
164    let mut out = [0u32; 256];
165    for (i, row) in TRI_TABLE.iter().enumerate() {
166        let mut count = 0u32;
167        let mut j = 0;
168        while j < 15 && row[j] >= 0 {
169            count += 1;
170            j += 3;
171        }
172        out[i] = count;
173    }
174    out
175}
176
177/// Flat TRI_TABLE for the GPU: 256 × 16 i32 values.
178fn case_table_flat() -> [i32; 256 * 16] {
179    let mut out = [-1i32; 256 * 16];
180    for (i, row) in TRI_TABLE.iter().enumerate() {
181        for (j, &v) in row.iter().enumerate() {
182            out[i * 16 + j] = v as i32;
183        }
184    }
185    out
186}
187
188// ---------------------------------------------------------------------------
189// Pipeline init and volume upload (impl ViewportGpuResources)
190// ---------------------------------------------------------------------------
191
192impl ViewportGpuResources {
193    /// Lazily create all GPU MC pipelines and shared lookup buffers.
194    ///
195    /// No-op if already initialised.
196    pub(crate) fn ensure_mc_pipelines(&mut self, device: &wgpu::Device) {
197        if self.mc_classify_pipeline.is_some() {
198            return;
199        }
200
201        // ----------------------------------------------------------------
202        // Shared lookup buffers (uploaded once).
203        // ----------------------------------------------------------------
204        let count_table = case_triangle_count_table();
205        let mc_case_count_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
206            label: Some("mc_case_count_buf"),
207            contents: bytemuck::cast_slice(&count_table),
208            usage: wgpu::BufferUsages::STORAGE,
209        });
210
211        let flat_table = case_table_flat();
212        let mc_case_table_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
213            label: Some("mc_case_table_buf"),
214            contents: bytemuck::cast_slice(&flat_table),
215            usage: wgpu::BufferUsages::STORAGE,
216        });
217
218        // ----------------------------------------------------------------
219        // Bind group layouts.
220        // ----------------------------------------------------------------
221
222        // Classify: 5 bindings (uniform + 2 read storage + 2 rw storage).
223        let classify_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
224            label: Some("mc_classify_bgl"),
225            entries: &[
226                bgl_uniform(0),
227                bgl_storage_ro(1),
228                bgl_storage_ro(2),
229                bgl_storage_rw(3),
230                bgl_storage_rw(4),
231            ],
232        });
233
234        // Prefix sum: 6 bindings (uniform + ro + 3 rw + wire_indirect_buf rw).
235        let prefix_sum_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
236            label: Some("mc_prefix_sum_bgl"),
237            entries: &[
238                bgl_uniform(0),
239                bgl_storage_ro(1),
240                bgl_storage_rw(2),
241                bgl_storage_rw(3),
242                bgl_storage_rw(4),
243                bgl_storage_rw(5), // wire_indirect_buf
244            ],
245        });
246
247        // Generate: 6 bindings (uniform + 3 ro + 2 rw [case_indices ro, vertex_buf rw]).
248        let generate_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
249            label: Some("mc_generate_bgl"),
250            entries: &[
251                bgl_uniform(0),
252                bgl_storage_ro(1),
253                bgl_storage_ro(2),
254                bgl_storage_ro(3),
255                bgl_storage_ro(4),
256                bgl_storage_rw(5),
257            ],
258        });
259
260        // Surface render: one per-draw material uniform.
261        let render_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
262            label: Some("mc_render_bgl"),
263            entries: &[wgpu::BindGroupLayoutEntry {
264                binding: 0,
265                visibility: wgpu::ShaderStages::FRAGMENT,
266                ty: wgpu::BindingType::Buffer {
267                    ty: wgpu::BufferBindingType::Uniform,
268                    has_dynamic_offset: false,
269                    min_binding_size: None,
270                },
271                count: None,
272            }],
273        });
274
275        // ----------------------------------------------------------------
276        // Compute pipelines.
277        // ----------------------------------------------------------------
278        let classify_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
279            label: Some("mc_classify_shader"),
280            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/mc_classify.wgsl").into()),
281        });
282        let classify_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
283            label: Some("mc_classify_layout"),
284            bind_group_layouts: &[&classify_bgl],
285            push_constant_ranges: &[],
286        });
287        let classify_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
288            label: Some("mc_classify_pipeline"),
289            layout: Some(&classify_layout),
290            module: &classify_shader,
291            entry_point: Some("main"),
292            compilation_options: wgpu::PipelineCompilationOptions::default(),
293            cache: None,
294        });
295
296        let prefix_sum_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
297            label: Some("mc_prefix_sum_shader"),
298            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/mc_prefix_sum.wgsl").into()),
299        });
300        let prefix_sum_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
301            label: Some("mc_prefix_sum_layout"),
302            bind_group_layouts: &[&prefix_sum_bgl],
303            push_constant_ranges: &[],
304        });
305        let prefix_sum_pipeline =
306            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
307                label: Some("mc_prefix_sum_pipeline"),
308                layout: Some(&prefix_sum_layout),
309                module: &prefix_sum_shader,
310                entry_point: Some("main"),
311                compilation_options: wgpu::PipelineCompilationOptions::default(),
312                cache: None,
313            });
314
315        let generate_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
316            label: Some("mc_generate_shader"),
317            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/mc_generate.wgsl").into()),
318        });
319        let generate_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
320            label: Some("mc_generate_layout"),
321            bind_group_layouts: &[&generate_bgl],
322            push_constant_ranges: &[],
323        });
324        let generate_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
325            label: Some("mc_generate_pipeline"),
326            layout: Some(&generate_layout),
327            module: &generate_shader,
328            entry_point: Some("main"),
329            compilation_options: wgpu::PipelineCompilationOptions::default(),
330            cache: None,
331        });
332
333        // ----------------------------------------------------------------
334        // Surface render pipeline.
335        // ----------------------------------------------------------------
336        let surface_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
337            label: Some("mc_surface_shader"),
338            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/mc_surface.wgsl").into()),
339        });
340        let surface_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
341            label: Some("mc_surface_layout"),
342            bind_group_layouts: &[&self.camera_bind_group_layout, &render_bgl],
343            push_constant_ranges: &[],
344        });
345
346        let vertex_attrs = [
347            wgpu::VertexAttribute {
348                format: wgpu::VertexFormat::Float32x3,
349                offset: 0,
350                shader_location: 0,
351            },
352            wgpu::VertexAttribute {
353                format: wgpu::VertexFormat::Float32x3,
354                offset: 12,
355                shader_location: 1,
356            },
357        ];
358        let vertex_layout = wgpu::VertexBufferLayout {
359            array_stride: 24,
360            step_mode: wgpu::VertexStepMode::Vertex,
361            attributes: &vertex_attrs,
362        };
363
364        let make_surface = |fmt: wgpu::TextureFormat| {
365            device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
366                label: Some("mc_surface_pipeline"),
367                layout: Some(&surface_layout),
368                vertex: wgpu::VertexState {
369                    module: &surface_shader,
370                    entry_point: Some("vs_main"),
371                    buffers: &[vertex_layout.clone()],
372                    compilation_options: wgpu::PipelineCompilationOptions::default(),
373                },
374                fragment: Some(wgpu::FragmentState {
375                    module: &surface_shader,
376                    entry_point: Some("fs_main"),
377                    targets: &[Some(wgpu::ColorTargetState {
378                        format: fmt,
379                        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
380                        write_mask: wgpu::ColorWrites::ALL,
381                    })],
382                    compilation_options: wgpu::PipelineCompilationOptions::default(),
383                }),
384                primitive: wgpu::PrimitiveState {
385                    topology: wgpu::PrimitiveTopology::TriangleList,
386                    cull_mode: None,
387                    ..Default::default()
388                },
389                depth_stencil: Some(wgpu::DepthStencilState {
390                    format: wgpu::TextureFormat::Depth24PlusStencil8,
391                    depth_write_enabled: true,
392                    depth_compare: wgpu::CompareFunction::LessEqual,
393                    stencil: wgpu::StencilState::default(),
394                    bias: wgpu::DepthBiasState::default(),
395                }),
396                multisample: wgpu::MultisampleState {
397                    count: 1,
398                    mask: !0,
399                    alpha_to_coverage_enabled: false,
400                },
401                multiview: None,
402                cache: None,
403            })
404        };
405
406        // ----------------------------------------------------------------
407        // Wireframe render pipeline.
408        // ----------------------------------------------------------------
409        let wireframe_render_bgl =
410            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
411                label: Some("mc_wireframe_render_bgl"),
412                entries: &[wgpu::BindGroupLayoutEntry {
413                    binding: 0,
414                    visibility: wgpu::ShaderStages::VERTEX,
415                    ty: wgpu::BindingType::Buffer {
416                        ty: wgpu::BufferBindingType::Storage { read_only: true },
417                        has_dynamic_offset: false,
418                        min_binding_size: None,
419                    },
420                    count: None,
421                }],
422            });
423
424        let wireframe_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
425            label: Some("mc_wireframe_shader"),
426            source: wgpu::ShaderSource::Wgsl(
427                include_str!("../shaders/mc_wireframe.wgsl").into(),
428            ),
429        });
430        let wireframe_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
431            label: Some("mc_wireframe_layout"),
432            bind_group_layouts: &[&self.camera_bind_group_layout, &wireframe_render_bgl],
433            push_constant_ranges: &[],
434        });
435        let make_wireframe = |fmt: wgpu::TextureFormat| {
436            device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
437                label: Some("mc_wireframe_pipeline"),
438                layout: Some(&wireframe_layout),
439                vertex: wgpu::VertexState {
440                    module: &wireframe_shader,
441                    entry_point: Some("vs_main"),
442                    buffers: &[], // positions read from storage buffer
443                    compilation_options: wgpu::PipelineCompilationOptions::default(),
444                },
445                fragment: Some(wgpu::FragmentState {
446                    module: &wireframe_shader,
447                    entry_point: Some("fs_main"),
448                    targets: &[Some(wgpu::ColorTargetState {
449                        format: fmt,
450                        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
451                        write_mask: wgpu::ColorWrites::ALL,
452                    })],
453                    compilation_options: wgpu::PipelineCompilationOptions::default(),
454                }),
455                primitive: wgpu::PrimitiveState {
456                    topology: wgpu::PrimitiveTopology::LineList,
457                    cull_mode: None,
458                    ..Default::default()
459                },
460                depth_stencil: Some(wgpu::DepthStencilState {
461                    format: wgpu::TextureFormat::Depth24PlusStencil8,
462                    depth_write_enabled: true,
463                    depth_compare: wgpu::CompareFunction::LessEqual,
464                    stencil: wgpu::StencilState::default(),
465                    bias: wgpu::DepthBiasState::default(),
466                }),
467                multisample: wgpu::MultisampleState {
468                    count: 1,
469                    mask: !0,
470                    alpha_to_coverage_enabled: false,
471                },
472                multiview: None,
473                cache: None,
474            })
475        };
476
477        // ----------------------------------------------------------------
478        // Commit all resources.
479        // ----------------------------------------------------------------
480        self.mc_case_count_buf = Some(mc_case_count_buf);
481        self.mc_case_table_buf = Some(mc_case_table_buf);
482        self.mc_classify_bgl = Some(classify_bgl);
483        self.mc_prefix_sum_bgl = Some(prefix_sum_bgl);
484        self.mc_generate_bgl = Some(generate_bgl);
485        self.mc_render_bgl = Some(render_bgl);
486        self.mc_classify_pipeline = Some(classify_pipeline);
487        self.mc_prefix_sum_pipeline = Some(prefix_sum_pipeline);
488        self.mc_generate_pipeline = Some(generate_pipeline);
489        self.mc_surface_pipeline = Some(DualPipeline {
490            ldr: make_surface(self.target_format),
491            hdr: make_surface(wgpu::TextureFormat::Rgba16Float),
492        });
493        self.mc_wireframe_render_bgl = Some(wireframe_render_bgl);
494        self.mc_wireframe_pipeline = Some(DualPipeline {
495            ldr: make_wireframe(self.target_format),
496            hdr: make_wireframe(wgpu::TextureFormat::Rgba16Float),
497        });
498    }
499
500    /// Upload a [`VolumeData`] to GPU, pre-allocating all intermediate and output
501    /// buffers for GPU marching cubes.
502    ///
503    /// The returned [`VolumeGpuId`] is stable until [`remove_mc_volume`] is called.
504    ///
505    /// Returns `Err(ViewportError::McBufferTooLarge)` if any required buffer exceeds
506    /// the device's `max_buffer_size`; the caller should fall back to CPU isosurface
507    /// extraction.
508    pub fn upload_volume_for_mc(
509        &mut self,
510        device: &wgpu::Device,
511        queue: &wgpu::Queue,
512        vol: &VolumeData,
513    ) -> crate::ViewportResult<VolumeGpuId> {
514        let [nx, ny, nz] = vol.dims;
515        // The vertex buffer is bound as both STORAGE (compute) and VERTEX (render).
516        // The binding limit for compute shaders is max_storage_buffer_binding_size, which
517        // is often half of max_buffer_size (e.g. 128 MiB vs 256 MiB). Use the smaller of
518        // the two so slab sizing respects both constraints.
519        let max_binding = device.limits().max_storage_buffer_binding_size as u64;
520        let max_buf = device.limits().max_buffer_size;
521        let max_limit = max_binding.min(max_buf);
522
523        // Worst-case vertex buffer bytes per Z-cell-layer:
524        // (nx-1)*(ny-1) cells × 5 triangles × 3 vertices × 24 bytes = cells_xy × 360.
525        // Compute how many Z-cell layers fit within the effective limit.
526        let cells_xy = (nx - 1) as u64 * (ny - 1) as u64;
527        let max_cells_per_slab = max_limit / (15 * 24);
528        let z_cells_per_slab = if cells_xy > 0 {
529            (max_cells_per_slab / cells_xy).min((nz - 1) as u64) as u32
530        } else {
531            nz - 1
532        };
533        if z_cells_per_slab == 0 {
534            // Even a single Z-layer of cells exceeds the effective binding limit.
535            return Err(crate::ViewportError::McBufferTooLarge {
536                buffer: "vertex_buf",
537                needed: cells_xy * 15 * 24,
538                limit: max_limit,
539            });
540        }
541
542        let nz_cells_total = nz - 1;
543        let slab_count = nz_cells_total.div_ceil(z_cells_per_slab);
544        let nodes_per_z = (nx * ny) as usize;
545
546        let mut slabs = Vec::with_capacity(slab_count as usize);
547
548        for s in 0..slab_count {
549            let z_cell_start = s * z_cells_per_slab;
550            let z_cell_end = (z_cell_start + z_cells_per_slab).min(nz_cells_total);
551            let slab_z_cells = z_cell_end - z_cell_start; // cell layers in this slab
552            let slab_nz = slab_z_cells + 1; // scalar layers in this slab
553
554            // slab_cell_count is bounded by max_cells_per_slab, which fits in u32
555            // at any realistic max_buffer_size value.
556            let slab_cell_count = (cells_xy * slab_z_cells as u64) as u32;
557            let slab_block_count = slab_cell_count.div_ceil(256);
558            let slab_cell_bytes = (slab_cell_count as u64) * 4;
559            let slab_block_bytes = (slab_block_count as u64) * 4;
560            // At most 15 vertices per cell (5 triangles × 3 vertices) × 24 bytes each.
561            let slab_vertex_bytes = (slab_cell_count as u64) * 15 * 24;
562
563            // Scalar data is x-fastest: index = x + y*nx + z*nx*ny.
564            // A Z-slab covering scalar layers z_cell_start..z_cell_start+slab_nz is
565            // a contiguous slice — no copying required.
566            let scalar_start = z_cell_start as usize * nodes_per_z;
567            let scalar_end = (z_cell_start + slab_nz) as usize * nodes_per_z;
568            let slab_origin_z = vol.origin[2] + z_cell_start as f32 * vol.spacing[2];
569
570            let scalar_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
571                label: Some("mc_scalar_buf"),
572                contents: bytemuck::cast_slice(&vol.data[scalar_start..scalar_end]),
573                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
574            });
575            let counts_buf = device.create_buffer(&wgpu::BufferDescriptor {
576                label: Some("mc_counts_buf"),
577                size: slab_cell_bytes,
578                usage: wgpu::BufferUsages::STORAGE,
579                mapped_at_creation: false,
580            });
581            let case_idx_buf = device.create_buffer(&wgpu::BufferDescriptor {
582                label: Some("mc_case_idx_buf"),
583                size: slab_cell_bytes,
584                usage: wgpu::BufferUsages::STORAGE,
585                mapped_at_creation: false,
586            });
587            let offsets_buf = device.create_buffer(&wgpu::BufferDescriptor {
588                label: Some("mc_offsets_buf"),
589                size: slab_cell_bytes,
590                usage: wgpu::BufferUsages::STORAGE,
591                mapped_at_creation: false,
592            });
593            let block_sums_buf = device.create_buffer(&wgpu::BufferDescriptor {
594                label: Some("mc_block_sums_buf"),
595                size: slab_block_bytes,
596                usage: wgpu::BufferUsages::STORAGE,
597                mapped_at_creation: false,
598            });
599            let vertex_buf = device.create_buffer(&wgpu::BufferDescriptor {
600                label: Some("mc_vertex_buf"),
601                size: slab_vertex_bytes,
602                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::VERTEX,
603                mapped_at_creation: false,
604            });
605            let initial_indirect = bytemuck::cast_slice(&[0u32, 1u32, 0u32, 0u32]);
606            let indirect_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
607                label: Some("mc_indirect_buf"),
608                // Initial: 0 vertices, 1 instance, 0 first_vertex, 0 first_instance.
609                contents: initial_indirect,
610                usage: wgpu::BufferUsages::STORAGE
611                    | wgpu::BufferUsages::INDIRECT
612                    | wgpu::BufferUsages::COPY_DST,
613            });
614            let wire_indirect_buf =
615                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
616                    label: Some("mc_wire_indirect_buf"),
617                    contents: initial_indirect,
618                    usage: wgpu::BufferUsages::STORAGE
619                        | wgpu::BufferUsages::INDIRECT
620                        | wgpu::BufferUsages::COPY_DST,
621                });
622
623            slabs.push(McSlabGpuData {
624                scalar_buf,
625                counts_buf,
626                case_idx_buf,
627                offsets_buf,
628                block_sums_buf,
629                vertex_buf,
630                indirect_buf,
631                wire_indirect_buf,
632                dims: [nx, ny, slab_nz],
633                origin: [vol.origin[0], vol.origin[1], slab_origin_z],
634                spacing: vol.spacing,
635                cell_count: slab_cell_count,
636                block_count: slab_block_count,
637            });
638        }
639
640        let _ = queue; // retained for potential future use (e.g. scalar updates)
641
642        let gpu_data = McVolumeGpuData { slabs, alive: true };
643
644        // Find a free slot (from a previous remove_mc_volume call) or push a new one.
645        let idx = if let Some(free_idx) = self.mc_volumes.iter().position(|v| !v.alive) {
646            self.mc_volumes[free_idx] = gpu_data;
647            free_idx
648        } else {
649            self.mc_volumes.push(gpu_data);
650            self.mc_volumes.len() - 1
651        };
652
653        Ok(VolumeGpuId(idx))
654    }
655
656    /// Mark a MC volume slot as free. The GPU buffers are dropped immediately.
657    pub fn remove_mc_volume(&mut self, id: VolumeGpuId) {
658        if let Some(v) = self.mc_volumes.get_mut(id.0) {
659            v.alive = false;
660        }
661    }
662
663    /// Lazily create the MC surface outline mask pipeline.
664    ///
665    /// Layout is `[camera_bind_group_layout, outline_bind_group_layout]`. The vertex
666    /// buffer matches the MC output format: stride 24 (position f32x3 at offset 0,
667    /// normal f32x3 at offset 12). Uses the existing `outline_mask.wgsl` shader since
668    /// only position is needed. No-op if already created.
669    pub(crate) fn ensure_mc_outline_mask_pipeline(&mut self, device: &wgpu::Device) {
670        if self.mc_outline_mask_pipeline.is_some() {
671            return;
672        }
673
674        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
675            label: Some("mc_outline_mask_shader"),
676            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/outline_mask.wgsl").into()),
677        });
678
679        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
680            label: Some("mc_outline_mask_pipeline_layout"),
681            bind_group_layouts: &[
682                &self.camera_bind_group_layout,
683                &self.outline_bind_group_layout,
684            ],
685            push_constant_ranges: &[],
686        });
687
688        let vert_attrs = [wgpu::VertexAttribute {
689            offset: 0,
690            shader_location: 0,
691            format: wgpu::VertexFormat::Float32x3,
692        }];
693        let vert_layout = wgpu::VertexBufferLayout {
694            array_stride: 24, // position (12 bytes) + normal (12 bytes)
695            step_mode: wgpu::VertexStepMode::Vertex,
696            attributes: &vert_attrs,
697        };
698
699        self.mc_outline_mask_pipeline = Some(device.create_render_pipeline(
700            &wgpu::RenderPipelineDescriptor {
701                label: Some("mc_outline_mask_pipeline"),
702                layout: Some(&layout),
703                vertex: wgpu::VertexState {
704                    module: &shader,
705                    entry_point: Some("vs_main"),
706                    buffers: &[vert_layout],
707                    compilation_options: wgpu::PipelineCompilationOptions::default(),
708                },
709                fragment: Some(wgpu::FragmentState {
710                    module: &shader,
711                    entry_point: Some("fs_main"),
712                    targets: &[Some(wgpu::ColorTargetState {
713                        format: wgpu::TextureFormat::R8Unorm,
714                        blend: None,
715                        write_mask: wgpu::ColorWrites::ALL,
716                    })],
717                    compilation_options: wgpu::PipelineCompilationOptions::default(),
718                }),
719                primitive: wgpu::PrimitiveState {
720                    topology: wgpu::PrimitiveTopology::TriangleList,
721                    cull_mode: None,
722                    ..Default::default()
723                },
724                depth_stencil: Some(wgpu::DepthStencilState {
725                    format: wgpu::TextureFormat::Depth24PlusStencil8,
726                    depth_write_enabled: true,
727                    depth_compare: wgpu::CompareFunction::Less,
728                    stencil: wgpu::StencilState::default(),
729                    bias: wgpu::DepthBiasState::default(),
730                }),
731                multisample: wgpu::MultisampleState {
732                    count: 1,
733                    mask: !0,
734                    alpha_to_coverage_enabled: false,
735                },
736                multiview: None,
737                cache: None,
738            },
739        ));
740    }
741
742    /// Dispatch all three compute passes for every pending MC job.
743    ///
744    /// Returns the per-frame render data to be stored in `ViewportRenderer.mc_gpu_data`.
745    pub(crate) fn run_mc_jobs(
746        &self,
747        device: &wgpu::Device,
748        queue: &wgpu::Queue,
749        jobs: &[GpuMarchingCubesJob],
750    ) -> Vec<McFrameData> {
751        if jobs.is_empty() {
752            return Vec::new();
753        }
754
755        let classify_pipeline = self.mc_classify_pipeline.as_ref().expect("mc pipelines");
756        let prefix_sum_pipeline = self.mc_prefix_sum_pipeline.as_ref().unwrap();
757        let generate_pipeline = self.mc_generate_pipeline.as_ref().unwrap();
758        let classify_bgl = self.mc_classify_bgl.as_ref().unwrap();
759        let prefix_sum_bgl = self.mc_prefix_sum_bgl.as_ref().unwrap();
760        let generate_bgl = self.mc_generate_bgl.as_ref().unwrap();
761        let render_bgl = self.mc_render_bgl.as_ref().unwrap();
762        let case_count_buf = self.mc_case_count_buf.as_ref().unwrap();
763        let case_table_buf = self.mc_case_table_buf.as_ref().unwrap();
764
765        let mut frame_data = Vec::with_capacity(jobs.len());
766        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
767            label: Some("mc_compute_encoder"),
768        });
769
770        for job in jobs {
771            let vol = &self.mc_volumes[job.volume_id.0];
772            if !vol.alive {
773                continue;
774            }
775
776            // ----------------------------------------------------------
777            // Per-job surface material (one bind group shared by all slabs).
778            // ----------------------------------------------------------
779            let mat_raw = McSurfaceRaw {
780                base_colour: job.material.base_colour,
781                roughness: job.material.roughness,
782                unlit: job.appearance.unlit as u32,
783                opacity: job.appearance.opacity,
784                _pad: [0; 2],
785            };
786            let mat_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
787                label: Some("mc_surface_mat"),
788                contents: bytemuck::bytes_of(&mat_raw),
789                usage: wgpu::BufferUsages::UNIFORM,
790            });
791            let render_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
792                label: Some("mc_render_bg"),
793                layout: render_bgl,
794                entries: &[wgpu::BindGroupEntry {
795                    binding: 0,
796                    resource: mat_buf.as_entire_binding(),
797                }],
798            });
799
800            // Run all three compute passes for each slab independently.
801            for slab in &vol.slabs {
802                let cc = slab.cell_count;
803                let bc = slab.block_count;
804
805                // ----------------------------------------------------------
806                // Per-slab classify uniform.
807                // ----------------------------------------------------------
808                let classify_params = ClassifyParams {
809                    nx: slab.dims[0],
810                    ny: slab.dims[1],
811                    nz: slab.dims[2],
812                    isovalue: job.isovalue,
813                };
814                let classify_uniform =
815                    device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
816                        label: Some("mc_classify_uniform"),
817                        contents: bytemuck::bytes_of(&classify_params),
818                        usage: wgpu::BufferUsages::UNIFORM,
819                    });
820
821                let classify_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
822                    label: Some("mc_classify_bg"),
823                    layout: classify_bgl,
824                    entries: &[
825                        wgpu::BindGroupEntry {
826                            binding: 0,
827                            resource: classify_uniform.as_entire_binding(),
828                        },
829                        wgpu::BindGroupEntry {
830                            binding: 1,
831                            resource: slab.scalar_buf.as_entire_binding(),
832                        },
833                        wgpu::BindGroupEntry {
834                            binding: 2,
835                            resource: case_count_buf.as_entire_binding(),
836                        },
837                        wgpu::BindGroupEntry {
838                            binding: 3,
839                            resource: slab.counts_buf.as_entire_binding(),
840                        },
841                        wgpu::BindGroupEntry {
842                            binding: 4,
843                            resource: slab.case_idx_buf.as_entire_binding(),
844                        },
845                    ],
846                });
847
848                // ----------------------------------------------------------
849                // Per-slab prefix-sum uniforms (one per level).
850                // ----------------------------------------------------------
851                let ps_uniforms: [wgpu::Buffer; 3] = std::array::from_fn(|level| {
852                    let params = PrefixSumParams {
853                        cell_count: cc,
854                        block_count: bc,
855                        level: level as u32,
856                        _pad: 0,
857                    };
858                    device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
859                        label: Some("mc_ps_uniform"),
860                        contents: bytemuck::bytes_of(&params),
861                        usage: wgpu::BufferUsages::UNIFORM,
862                    })
863                });
864
865                let ps_bgs: [wgpu::BindGroup; 3] = std::array::from_fn(|level| {
866                    device.create_bind_group(&wgpu::BindGroupDescriptor {
867                        label: Some("mc_ps_bg"),
868                        layout: prefix_sum_bgl,
869                        entries: &[
870                            wgpu::BindGroupEntry {
871                                binding: 0,
872                                resource: ps_uniforms[level].as_entire_binding(),
873                            },
874                            wgpu::BindGroupEntry {
875                                binding: 1,
876                                resource: slab.counts_buf.as_entire_binding(),
877                            },
878                            wgpu::BindGroupEntry {
879                                binding: 2,
880                                resource: slab.offsets_buf.as_entire_binding(),
881                            },
882                            wgpu::BindGroupEntry {
883                                binding: 3,
884                                resource: slab.block_sums_buf.as_entire_binding(),
885                            },
886                            wgpu::BindGroupEntry {
887                                binding: 4,
888                                resource: slab.indirect_buf.as_entire_binding(),
889                            },
890                            wgpu::BindGroupEntry {
891                                binding: 5,
892                                resource: slab.wire_indirect_buf.as_entire_binding(),
893                            },
894                        ],
895                    })
896                });
897
898                // ----------------------------------------------------------
899                // Per-slab generate uniform (origin_z shifted by slab offset).
900                // ----------------------------------------------------------
901                let generate_params = GenerateParams {
902                    nx: slab.dims[0],
903                    ny: slab.dims[1],
904                    nz: slab.dims[2],
905                    isovalue: job.isovalue,
906                    origin_x: slab.origin[0],
907                    origin_y: slab.origin[1],
908                    origin_z: slab.origin[2],
909                    _pad0: 0.0,
910                    spacing_x: slab.spacing[0],
911                    spacing_y: slab.spacing[1],
912                    spacing_z: slab.spacing[2],
913                    _pad1: 0.0,
914                };
915                let generate_uniform =
916                    device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
917                        label: Some("mc_generate_uniform"),
918                        contents: bytemuck::bytes_of(&generate_params),
919                        usage: wgpu::BufferUsages::UNIFORM,
920                    });
921
922                let generate_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
923                    label: Some("mc_generate_bg"),
924                    layout: generate_bgl,
925                    entries: &[
926                        wgpu::BindGroupEntry {
927                            binding: 0,
928                            resource: generate_uniform.as_entire_binding(),
929                        },
930                        wgpu::BindGroupEntry {
931                            binding: 1,
932                            resource: slab.scalar_buf.as_entire_binding(),
933                        },
934                        wgpu::BindGroupEntry {
935                            binding: 2,
936                            resource: case_table_buf.as_entire_binding(),
937                        },
938                        wgpu::BindGroupEntry {
939                            binding: 3,
940                            resource: slab.offsets_buf.as_entire_binding(),
941                        },
942                        wgpu::BindGroupEntry {
943                            binding: 4,
944                            resource: slab.case_idx_buf.as_entire_binding(),
945                        },
946                        wgpu::BindGroupEntry {
947                            binding: 5,
948                            resource: slab.vertex_buf.as_entire_binding(),
949                        },
950                    ],
951                });
952
953                // ----------------------------------------------------------
954                // Pass 1: classify.
955                // ----------------------------------------------------------
956                {
957                    let mut cp = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
958                        label: Some("mc_classify_pass"),
959                        timestamp_writes: None,
960                    });
961                    cp.set_pipeline(classify_pipeline);
962                    cp.set_bind_group(0, &classify_bg, &[]);
963                    cp.dispatch_workgroups(cc.div_ceil(256), 1, 1);
964                }
965
966                // ----------------------------------------------------------
967                // Pass 2a: prefix sum level 0.
968                // ----------------------------------------------------------
969                {
970                    let mut cp = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
971                        label: Some("mc_ps_level0_pass"),
972                        timestamp_writes: None,
973                    });
974                    cp.set_pipeline(prefix_sum_pipeline);
975                    cp.set_bind_group(0, &ps_bgs[0], &[]);
976                    cp.dispatch_workgroups(bc, 1, 1);
977                }
978
979                // ----------------------------------------------------------
980                // Pass 2b: prefix sum level 1 (single workgroup, sequential).
981                // ----------------------------------------------------------
982                {
983                    let mut cp = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
984                        label: Some("mc_ps_level1_pass"),
985                        timestamp_writes: None,
986                    });
987                    cp.set_pipeline(prefix_sum_pipeline);
988                    cp.set_bind_group(0, &ps_bgs[1], &[]);
989                    cp.dispatch_workgroups(1, 1, 1);
990                }
991
992                // ----------------------------------------------------------
993                // Pass 2c: prefix sum level 2 (propagate block offsets).
994                // ----------------------------------------------------------
995                {
996                    let mut cp = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
997                        label: Some("mc_ps_level2_pass"),
998                        timestamp_writes: None,
999                    });
1000                    cp.set_pipeline(prefix_sum_pipeline);
1001                    cp.set_bind_group(0, &ps_bgs[2], &[]);
1002                    cp.dispatch_workgroups(bc, 1, 1);
1003                }
1004
1005                // ----------------------------------------------------------
1006                // Pass 3: generate vertices.
1007                // ----------------------------------------------------------
1008                {
1009                    let mut cp = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1010                        label: Some("mc_generate_pass"),
1011                        timestamp_writes: None,
1012                    });
1013                    cp.set_pipeline(generate_pipeline);
1014                    cp.set_bind_group(0, &generate_bg, &[]);
1015                    cp.dispatch_workgroups(cc.div_ceil(256), 1, 1);
1016                }
1017            }
1018
1019            let wire_slab_bgs: Vec<wgpu::BindGroup> =
1020                if let Some(ref wire_bgl) = self.mc_wireframe_render_bgl {
1021                    vol.slabs
1022                        .iter()
1023                        .map(|slab| {
1024                            device.create_bind_group(&wgpu::BindGroupDescriptor {
1025                                label: Some("mc_wire_slab_bg"),
1026                                layout: wire_bgl,
1027                                entries: &[wgpu::BindGroupEntry {
1028                                    binding: 0,
1029                                    resource: slab.vertex_buf.as_entire_binding(),
1030                                }],
1031                            })
1032                        })
1033                        .collect()
1034                } else {
1035                    Vec::new()
1036                };
1037
1038            frame_data.push(McFrameData {
1039                volume_idx: job.volume_id.0,
1040                render_bg,
1041                wireframe: job.appearance.wireframe,
1042                wire_slab_bgs,
1043            });
1044        }
1045
1046        queue.submit(std::iter::once(encoder.finish()));
1047        frame_data
1048    }
1049}
1050
1051// ---------------------------------------------------------------------------
1052// Bind group layout entry helpers
1053// ---------------------------------------------------------------------------
1054
1055fn bgl_uniform(binding: u32) -> wgpu::BindGroupLayoutEntry {
1056    wgpu::BindGroupLayoutEntry {
1057        binding,
1058        visibility: wgpu::ShaderStages::COMPUTE,
1059        ty: wgpu::BindingType::Buffer {
1060            ty: wgpu::BufferBindingType::Uniform,
1061            has_dynamic_offset: false,
1062            min_binding_size: None,
1063        },
1064        count: None,
1065    }
1066}
1067
1068fn bgl_storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
1069    wgpu::BindGroupLayoutEntry {
1070        binding,
1071        visibility: wgpu::ShaderStages::COMPUTE,
1072        ty: wgpu::BindingType::Buffer {
1073            ty: wgpu::BufferBindingType::Storage { read_only: true },
1074            has_dynamic_offset: false,
1075            min_binding_size: None,
1076        },
1077        count: None,
1078    }
1079}
1080
1081fn bgl_storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
1082    wgpu::BindGroupLayoutEntry {
1083        binding,
1084        visibility: wgpu::ShaderStages::COMPUTE,
1085        ty: wgpu::BindingType::Buffer {
1086            ty: wgpu::BufferBindingType::Storage { read_only: false },
1087            has_dynamic_offset: false,
1088            min_binding_size: None,
1089        },
1090        count: None,
1091    }
1092}