// sci_form/gpu/pipeline_coordinator.rs

//! GPU Pipeline Coordinator — multi-kernel dispatch with memory-aware scheduling.
//!
//! Coordinates sequential GPU dispatches for multi-step pipelines
//! (e.g., orbital grid → marching cubes → isosurface). Enforces
//! memory budget checks before each dispatch and provides tier-aware
//! scheduling decisions.

use std::fmt::Write;

use super::memory_budget::{GpuMemoryBudget, MemoryError};
use super::shader_registry::ShaderDescriptor;

11/// A single step in a GPU computation pipeline.
12#[derive(Debug, Clone)]
13pub struct PipelineStep {
14    /// Human-readable label for this step.
15    pub label: String,
16    /// Reference to the shader descriptor.
17    pub shader_name: &'static str,
18    /// Estimated buffer sizes for each binding (bytes).
19    pub buffer_sizes: Vec<u64>,
20    /// Grid dimensions to dispatch over, or element count for 1D.
21    pub dispatch: DispatchShape,
22}
23
/// How to dispatch a compute shader.
#[derive(Debug, Clone)]
pub enum DispatchShape {
    /// 1D dispatch over `n` elements (shader divides by workgroup size).
    Linear(u32),
    /// 3D dispatch over a grid [nx, ny, nz].
    Grid3D([u32; 3]),
}

impl DispatchShape {
    /// Number of workgroups per axis required to cover this shape,
    /// rounding up so a partial trailing workgroup is still dispatched.
    pub fn workgroup_count(&self, workgroup_size: [u32; 3]) -> [u32; 3] {
        match self {
            // 1D: only the x axis carries work; y and z stay at 1.
            Self::Linear(n) => [n.div_ceil(workgroup_size[0]), 1, 1],
            // 3D: divide each axis by its workgroup extent, rounding up.
            Self::Grid3D(dims) => {
                std::array::from_fn(|axis| dims[axis].div_ceil(workgroup_size[axis]))
            }
        }
    }
}

47/// Outcome of validating a pipeline step.
48#[derive(Debug, Clone)]
49pub enum StepDecision {
50    /// GPU dispatch is feasible and recommended.
51    GpuDispatch {
52        workgroup_count: [u32; 3],
53        total_buffer_bytes: u64,
54    },
55    /// GPU dispatch possible but CPU may be faster for this tier/size.
56    CpuPreferred { reason: String },
57    /// Must fall back to CPU — exceeds GPU limits.
58    CpuRequired { error: MemoryError },
59}
60
61/// Coordinates multi-step GPU pipelines.
62#[derive(Debug)]
63pub struct PipelineCoordinator {
64    budget: GpuMemoryBudget,
65    gpu_available: bool,
66}
67
68impl PipelineCoordinator {
69    pub fn new(budget: GpuMemoryBudget, gpu_available: bool) -> Self {
70        Self {
71            budget,
72            gpu_available,
73        }
74    }
75
76    /// Create a coordinator that always falls back to CPU.
77    pub fn cpu_only() -> Self {
78        Self {
79            budget: GpuMemoryBudget::webgpu_default(),
80            gpu_available: false,
81        }
82    }
83
84    /// Evaluate whether a pipeline step should run on GPU or CPU.
85    pub fn evaluate_step(&self, step: &PipelineStep, shader: &ShaderDescriptor) -> StepDecision {
86        if !self.gpu_available {
87            return StepDecision::CpuPreferred {
88                reason: "No GPU available".to_string(),
89            };
90        }
91
92        // Check tier recommendation
93        if !shader.tier.gpu_recommended() {
94            return StepDecision::CpuPreferred {
95                reason: format!("{} — {}", shader.tier, shader.description),
96            };
97        }
98
99        // Check buffer sizes against budget
100        let total_bytes: u64 = step.buffer_sizes.iter().sum();
101        for (i, &size) in step.buffer_sizes.iter().enumerate() {
102            if let Err(e) = self.budget.check_buffer(size) {
103                return StepDecision::CpuRequired { error: e };
104            }
105            // Also check storage binding limit for buffer i
106            if i as u32 >= self.budget.limits.max_storage_buffers_per_stage {
107                return StepDecision::CpuRequired {
108                    error: MemoryError::TooManyBindings {
109                        current: i as u32 + 1,
110                        max: self.budget.limits.max_storage_buffers_per_stage,
111                    },
112                };
113            }
114        }
115
116        // Compute workgroup count and validate
117        let wg_count = step.dispatch.workgroup_count(shader.workgroup_size);
118        if let Err(e) = self.budget.check_workgroup(shader.workgroup_size, wg_count) {
119            return StepDecision::CpuRequired { error: e };
120        }
121
122        StepDecision::GpuDispatch {
123            workgroup_count: wg_count,
124            total_buffer_bytes: total_bytes,
125        }
126    }
127
128    /// Validate an entire multi-step pipeline and return decisions for each step.
129    pub fn plan_pipeline(
130        &self,
131        steps: &[PipelineStep],
132        shaders: &[&ShaderDescriptor],
133    ) -> PipelinePlan {
134        assert_eq!(steps.len(), shaders.len(), "Steps and shaders must match");
135
136        let decisions: Vec<(String, StepDecision)> = steps
137            .iter()
138            .zip(shaders.iter())
139            .map(|(step, shader)| {
140                let decision = self.evaluate_step(step, shader);
141                (step.label.clone(), decision)
142            })
143            .collect();
144
145        let gpu_steps = decisions
146            .iter()
147            .filter(|(_, d)| matches!(d, StepDecision::GpuDispatch { .. }))
148            .count();
149        let cpu_steps = decisions.len() - gpu_steps;
150
151        PipelinePlan {
152            decisions,
153            gpu_steps,
154            cpu_steps,
155        }
156    }
157
158    /// Determine if a large pairwise computation should be chunked.
159    ///
160    /// Returns chunk sizes for atoms when the pairwise O(N²) matrix
161    /// exceeds the storage buffer binding limit.
162    pub fn compute_chunks(&self, n_atoms: usize, bytes_per_pair: u64) -> Vec<(usize, usize)> {
163        let max_pairs = self.budget.limits.max_storage_buffer_binding_size / bytes_per_pair;
164        let max_chunk = (max_pairs as f64).sqrt() as usize;
165
166        if n_atoms <= max_chunk {
167            return vec![(0, n_atoms)];
168        }
169
170        let mut chunks = Vec::new();
171        let mut start = 0;
172        while start < n_atoms {
173            let end = (start + max_chunk).min(n_atoms);
174            chunks.push((start, end));
175            start = end;
176        }
177        chunks
178    }
179}
180
181/// Result of planning an entire pipeline.
182#[derive(Debug)]
183pub struct PipelinePlan {
184    /// Decision for each step: (label, decision).
185    pub decisions: Vec<(String, StepDecision)>,
186    /// Number of steps that will run on GPU.
187    pub gpu_steps: usize,
188    /// Number of steps that will fall back to CPU.
189    pub cpu_steps: usize,
190}
191
192impl PipelinePlan {
193    /// Whether all steps will run on GPU.
194    pub fn fully_gpu(&self) -> bool {
195        self.cpu_steps == 0
196    }
197
198    /// Generate a text report of the pipeline plan.
199    pub fn report(&self) -> String {
200        let mut out = format!(
201            "Pipeline Plan: {} GPU / {} CPU steps\n",
202            self.gpu_steps, self.cpu_steps
203        );
204        for (label, decision) in &self.decisions {
205            let status = match decision {
206                StepDecision::GpuDispatch {
207                    workgroup_count,
208                    total_buffer_bytes,
209                } => {
210                    format!(
211                        "GPU → wg[{},{},{}], {:.1} KB",
212                        workgroup_count[0],
213                        workgroup_count[1],
214                        workgroup_count[2],
215                        *total_buffer_bytes as f64 / 1024.0
216                    )
217                }
218                StepDecision::CpuPreferred { reason } => {
219                    format!("CPU (preferred) — {reason}")
220                }
221                StepDecision::CpuRequired { error } => {
222                    format!("CPU (required) — {error}")
223                }
224            };
225            out.push_str(&format!("  [{label}] {status}\n"));
226        }
227        out
228    }
229}
230
// ─── Pre-built pipeline templates ──────────────────────────────────

233/// Build a standard orbital visualization pipeline.
234///
235/// Steps: orbital grid → marching cubes (positive lobe) → marching cubes (negative lobe).
236pub fn orbital_visualization_pipeline(
237    n_basis: usize,
238    grid_dims: [u32; 3],
239    max_primitives: usize,
240) -> Vec<PipelineStep> {
241    let basis_bytes = (n_basis * 32) as u64;
242    let mo_bytes = (n_basis * 4) as u64;
243    let prim_bytes = (n_basis * max_primitives * 8) as u64;
244    let grid_points = grid_dims[0] as u64 * grid_dims[1] as u64 * grid_dims[2] as u64;
245    let grid_bytes = grid_points * 4;
246
247    // Estimate worst-case triangle output: 5 triangles per active voxel, 10% active
248    let est_triangles = (grid_points / 10) * 5;
249    let vertex_bytes = est_triangles * 24; // 24 bytes per Vertex (vec3 pos + vec3 normal)
250
251    vec![
252        PipelineStep {
253            label: "Orbital Grid".to_string(),
254            shader_name: "orbital_grid",
255            buffer_sizes: vec![basis_bytes, mo_bytes, prim_bytes, 32, grid_bytes],
256            dispatch: DispatchShape::Grid3D(grid_dims),
257        },
258        PipelineStep {
259            label: "Marching Cubes (positive lobe)".to_string(),
260            shader_name: "marching_cubes",
261            buffer_sizes: vec![
262                grid_bytes,   // scalar_field
263                256 * 4,      // edge_table
264                256 * 16 * 4, // tri_table
265                vertex_bytes, // vertices output
266                4,            // tri_count atomic
267                32,           // params
268            ],
269            dispatch: DispatchShape::Grid3D([grid_dims[0] - 1, grid_dims[1] - 1, grid_dims[2] - 1]),
270        },
271        PipelineStep {
272            label: "Marching Cubes (negative lobe)".to_string(),
273            shader_name: "marching_cubes",
274            buffer_sizes: vec![grid_bytes, 256 * 4, 256 * 16 * 4, vertex_bytes, 4, 32],
275            dispatch: DispatchShape::Grid3D([grid_dims[0] - 1, grid_dims[1] - 1, grid_dims[2] - 1]),
276        },
277    ]
278}
279
#[cfg(test)]
mod tests {
    use super::super::shader_registry;
    use super::*;

    /// Coordinator with default WebGPU limits and a GPU present.
    fn gpu_coordinator() -> PipelineCoordinator {
        PipelineCoordinator::new(GpuMemoryBudget::webgpu_default(), true)
    }

    /// Shorthand for building a test step.
    fn make_step(
        label: &str,
        shader: &'static str,
        sizes: Vec<u64>,
        dispatch: DispatchShape,
    ) -> PipelineStep {
        PipelineStep {
            label: label.to_string(),
            shader_name: shader,
            buffer_sizes: sizes,
            dispatch,
        }
    }

    #[test]
    fn test_dispatch_1d() {
        // ceil(1000 / 64) = 16 workgroups along x.
        assert_eq!(
            DispatchShape::Linear(1000).workgroup_count([64, 1, 1]),
            [16, 1, 1]
        );
    }

    #[test]
    fn test_dispatch_3d() {
        assert_eq!(
            DispatchShape::Grid3D([64, 64, 64]).workgroup_count([8, 8, 4]),
            [8, 8, 16]
        );
    }

    #[test]
    fn test_gpu_dispatch_feasible() {
        let step = make_step(
            "test",
            "orbital_grid",
            vec![1024, 512, 2048, 32, 500_000],
            DispatchShape::Grid3D([50, 50, 50]),
        );
        let shader = shader_registry::find_shader("orbital_grid").unwrap();
        let decision = gpu_coordinator().evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::GpuDispatch { .. }));
    }

    #[test]
    fn test_cpu_fallback_when_no_gpu() {
        let step = make_step("test", "orbital_grid", vec![1024], DispatchShape::Linear(100));
        let shader = shader_registry::find_shader("orbital_grid").unwrap();
        let decision = PipelineCoordinator::cpu_only().evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::CpuPreferred { .. }));
    }

    #[test]
    fn test_cpu_preferred_for_tier4() {
        let step = make_step(
            "smoke test",
            "vector_add",
            vec![1024, 1024, 32, 1024],
            DispatchShape::Linear(256),
        );
        let shader = shader_registry::find_shader("vector_add").unwrap();
        let decision = gpu_coordinator().evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::CpuPreferred { .. }));
    }

    #[test]
    fn test_buffer_too_large_forces_cpu() {
        // 300 MB exceeds the 256 MB default binding limit.
        let step = make_step(
            "huge",
            "orbital_grid",
            vec![300_000_000],
            DispatchShape::Linear(1),
        );
        let shader = shader_registry::find_shader("orbital_grid").unwrap();
        let decision = gpu_coordinator().evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::CpuRequired { .. }));
    }

    #[test]
    fn test_pipeline_plan() {
        let steps = vec![
            make_step(
                "grid",
                "orbital_grid",
                vec![1024, 256, 2048, 32, 500_000],
                DispatchShape::Grid3D([50, 50, 50]),
            ),
            make_step(
                "mc",
                "marching_cubes",
                vec![500_000, 1024, 16384, 200_000, 4, 32],
                DispatchShape::Grid3D([49, 49, 49]),
            ),
        ];
        let shaders: Vec<&ShaderDescriptor> = vec![
            shader_registry::find_shader("orbital_grid").unwrap(),
            shader_registry::find_shader("marching_cubes").unwrap(),
        ];
        let plan = gpu_coordinator().plan_pipeline(&steps, &shaders);
        assert_eq!(plan.gpu_steps, 2);
        assert_eq!(plan.cpu_steps, 0);
        assert!(plan.fully_gpu());
    }

    #[test]
    fn test_compute_chunks_small() {
        // 100 atoms at 4 bytes/pair fits comfortably in one chunk.
        let chunks = gpu_coordinator().compute_chunks(100, 4);
        assert_eq!(chunks, vec![(0, 100)]);
    }

    #[test]
    fn test_compute_chunks_large() {
        // Force chunking: 1 MB per pair with 128 MB limit → max ~11 atoms per chunk
        let chunks = gpu_coordinator().compute_chunks(50, 1_000_000);
        assert!(chunks.len() > 1);
    }

    #[test]
    fn test_orbital_visualization_pipeline() {
        let steps = orbital_visualization_pipeline(7, [50, 50, 50], 3);
        let labels: Vec<&str> = steps.iter().map(|s| s.label.as_str()).collect();
        assert_eq!(
            labels,
            [
                "Orbital Grid",
                "Marching Cubes (positive lobe)",
                "Marching Cubes (negative lobe)"
            ]
        );
    }

    #[test]
    fn test_pipeline_report() {
        let steps = orbital_visualization_pipeline(7, [30, 30, 30], 3);
        let shaders: Vec<&ShaderDescriptor> = steps
            .iter()
            .map(|s| shader_registry::find_shader(s.shader_name).unwrap())
            .collect();
        let report = gpu_coordinator().plan_pipeline(&steps, &shaders).report();
        assert!(report.contains("Orbital Grid"));
        assert!(report.contains("GPU"));
    }
}