1use super::memory_budget::{GpuMemoryBudget, MemoryError};
9use super::shader_registry::ShaderDescriptor;
10
/// One unit of GPU-eligible work: the buffers it binds and the shape it is dispatched with.
///
/// The shader itself is not stored here; callers resolve `shader_name` to a
/// `ShaderDescriptor` and pass both to `PipelineCoordinator::evaluate_step`.
#[derive(Debug, Clone)]
pub struct PipelineStep {
    /// Human-readable name, copied into the plan and printed by `PipelinePlan::report`.
    pub label: String,
    /// Registry name of the shader this step runs (e.g. `"orbital_grid"`).
    pub shader_name: &'static str,
    /// Size in bytes of each buffer the step binds. Each entry is checked individually and
    /// counted against the per-stage storage-buffer binding limit.
    pub buffer_sizes: Vec<u64>,
    /// Invocation shape; converted to workgroups by `DispatchShape::workgroup_count`.
    pub dispatch: DispatchShape,
}
23
/// Number of shader invocations a step needs, expressed in invocations (not workgroups).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DispatchShape {
    /// A flat range of `n` invocations along x.
    Linear(u32),
    /// A 3-D grid of `[x, y, z]` invocations.
    Grid3D([u32; 3]),
}

impl DispatchShape {
    /// Converts this invocation shape into a workgroup count for a shader whose workgroup has
    /// dimensions `workgroup_size`, rounding every axis up so all invocations are covered.
    ///
    /// `Linear` shapes only divide by `workgroup_size[0]`; y and z are dispatched with a
    /// single workgroup each.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if a workgroup dimension used by the shape is `0`.
    pub fn workgroup_count(&self, workgroup_size: [u32; 3]) -> [u32; 3] {
        match *self {
            DispatchShape::Linear(n) => [n.div_ceil(workgroup_size[0]), 1, 1],
            DispatchShape::Grid3D(dims) => {
                std::array::from_fn(|axis| dims[axis].div_ceil(workgroup_size[axis]))
            }
        }
    }
}
46
/// Outcome of `PipelineCoordinator::evaluate_step` for a single step.
#[derive(Debug, Clone)]
pub enum StepDecision {
    /// Run the step on the GPU.
    GpuDispatch {
        /// Workgroups to dispatch along x, y and z.
        workgroup_count: [u32; 3],
        /// Sum of the step's buffer sizes, in bytes.
        total_buffer_bytes: u64,
    },
    /// Run on the CPU by choice: either no GPU is available or the shader's tier does not
    /// recommend the GPU. `reason` is a human-readable explanation.
    CpuPreferred { reason: String },
    /// Run on the CPU because a buffer-size, binding-count or workgroup check against the
    /// memory budget failed; `error` is the failing check's error.
    CpuRequired { error: MemoryError },
}
60
61#[derive(Debug)]
63pub struct PipelineCoordinator {
64 budget: GpuMemoryBudget,
65 gpu_available: bool,
66}
67
68impl PipelineCoordinator {
69 pub fn new(budget: GpuMemoryBudget, gpu_available: bool) -> Self {
70 Self {
71 budget,
72 gpu_available,
73 }
74 }
75
76 pub fn cpu_only() -> Self {
78 Self {
79 budget: GpuMemoryBudget::webgpu_default(),
80 gpu_available: false,
81 }
82 }
83
84 pub fn evaluate_step(&self, step: &PipelineStep, shader: &ShaderDescriptor) -> StepDecision {
86 if !self.gpu_available {
87 return StepDecision::CpuPreferred {
88 reason: "No GPU available".to_string(),
89 };
90 }
91
92 if !shader.tier.gpu_recommended() {
94 return StepDecision::CpuPreferred {
95 reason: format!("{} — {}", shader.tier, shader.description),
96 };
97 }
98
99 let total_bytes: u64 = step.buffer_sizes.iter().sum();
101 for (i, &size) in step.buffer_sizes.iter().enumerate() {
102 if let Err(e) = self.budget.check_buffer(size) {
103 return StepDecision::CpuRequired { error: e };
104 }
105 if i as u32 >= self.budget.limits.max_storage_buffers_per_stage {
107 return StepDecision::CpuRequired {
108 error: MemoryError::TooManyBindings {
109 current: i as u32 + 1,
110 max: self.budget.limits.max_storage_buffers_per_stage,
111 },
112 };
113 }
114 }
115
116 let wg_count = step.dispatch.workgroup_count(shader.workgroup_size);
118 if let Err(e) = self.budget.check_workgroup(shader.workgroup_size, wg_count) {
119 return StepDecision::CpuRequired { error: e };
120 }
121
122 StepDecision::GpuDispatch {
123 workgroup_count: wg_count,
124 total_buffer_bytes: total_bytes,
125 }
126 }
127
128 pub fn plan_pipeline(
130 &self,
131 steps: &[PipelineStep],
132 shaders: &[&ShaderDescriptor],
133 ) -> PipelinePlan {
134 assert_eq!(steps.len(), shaders.len(), "Steps and shaders must match");
135
136 let decisions: Vec<(String, StepDecision)> = steps
137 .iter()
138 .zip(shaders.iter())
139 .map(|(step, shader)| {
140 let decision = self.evaluate_step(step, shader);
141 (step.label.clone(), decision)
142 })
143 .collect();
144
145 let gpu_steps = decisions
146 .iter()
147 .filter(|(_, d)| matches!(d, StepDecision::GpuDispatch { .. }))
148 .count();
149 let cpu_steps = decisions.len() - gpu_steps;
150
151 PipelinePlan {
152 decisions,
153 gpu_steps,
154 cpu_steps,
155 }
156 }
157
158 pub fn compute_chunks(&self, n_atoms: usize, bytes_per_pair: u64) -> Vec<(usize, usize)> {
163 let max_pairs = self.budget.limits.max_storage_buffer_binding_size / bytes_per_pair;
164 let max_chunk = (max_pairs as f64).sqrt() as usize;
165
166 if n_atoms <= max_chunk {
167 return vec![(0, n_atoms)];
168 }
169
170 let mut chunks = Vec::new();
171 let mut start = 0;
172 while start < n_atoms {
173 let end = (start + max_chunk).min(n_atoms);
174 chunks.push((start, end));
175 start = end;
176 }
177 chunks
178 }
179}
180
181#[derive(Debug)]
183pub struct PipelinePlan {
184 pub decisions: Vec<(String, StepDecision)>,
186 pub gpu_steps: usize,
188 pub cpu_steps: usize,
190}
191
192impl PipelinePlan {
193 pub fn fully_gpu(&self) -> bool {
195 self.cpu_steps == 0
196 }
197
198 pub fn report(&self) -> String {
200 let mut out = format!(
201 "Pipeline Plan: {} GPU / {} CPU steps\n",
202 self.gpu_steps, self.cpu_steps
203 );
204 for (label, decision) in &self.decisions {
205 let status = match decision {
206 StepDecision::GpuDispatch {
207 workgroup_count,
208 total_buffer_bytes,
209 } => {
210 format!(
211 "GPU → wg[{},{},{}], {:.1} KB",
212 workgroup_count[0],
213 workgroup_count[1],
214 workgroup_count[2],
215 *total_buffer_bytes as f64 / 1024.0
216 )
217 }
218 StepDecision::CpuPreferred { reason } => {
219 format!("CPU (preferred) — {reason}")
220 }
221 StepDecision::CpuRequired { error } => {
222 format!("CPU (required) — {error}")
223 }
224 };
225 out.push_str(&format!(" [{label}] {status}\n"));
226 }
227 out
228 }
229}
230
231pub fn orbital_visualization_pipeline(
237 n_basis: usize,
238 grid_dims: [u32; 3],
239 max_primitives: usize,
240) -> Vec<PipelineStep> {
241 let basis_bytes = (n_basis * 32) as u64;
242 let mo_bytes = (n_basis * 4) as u64;
243 let prim_bytes = (n_basis * max_primitives * 8) as u64;
244 let grid_points = grid_dims[0] as u64 * grid_dims[1] as u64 * grid_dims[2] as u64;
245 let grid_bytes = grid_points * 4;
246
247 let est_triangles = (grid_points / 10) * 5;
249 let vertex_bytes = est_triangles * 24; vec![
252 PipelineStep {
253 label: "Orbital Grid".to_string(),
254 shader_name: "orbital_grid",
255 buffer_sizes: vec![basis_bytes, mo_bytes, prim_bytes, 32, grid_bytes],
256 dispatch: DispatchShape::Grid3D(grid_dims),
257 },
258 PipelineStep {
259 label: "Marching Cubes (positive lobe)".to_string(),
260 shader_name: "marching_cubes",
261 buffer_sizes: vec![
262 grid_bytes, 256 * 4, 256 * 16 * 4, vertex_bytes, 4, 32, ],
269 dispatch: DispatchShape::Grid3D([grid_dims[0] - 1, grid_dims[1] - 1, grid_dims[2] - 1]),
270 },
271 PipelineStep {
272 label: "Marching Cubes (negative lobe)".to_string(),
273 shader_name: "marching_cubes",
274 buffer_sizes: vec![grid_bytes, 256 * 4, 256 * 16 * 4, vertex_bytes, 4, 32],
275 dispatch: DispatchShape::Grid3D([grid_dims[0] - 1, grid_dims[1] - 1, grid_dims[2] - 1]),
276 },
277 ]
278}
279
#[cfg(test)]
mod tests {
    use super::super::shader_registry;
    use super::*;

    /// Coordinator using `webgpu_default()` limits with a GPU reported as present.
    fn gpu_coordinator() -> PipelineCoordinator {
        PipelineCoordinator::new(GpuMemoryBudget::webgpu_default(), true)
    }

    #[test]
    fn test_dispatch_1d() {
        // 1000 invocations / 64 per workgroup = 15.6, rounded up to 16.
        let shape = DispatchShape::Linear(1000);
        let wg = shape.workgroup_count([64, 1, 1]);
        assert_eq!(wg, [16, 1, 1]);
    }

    #[test]
    fn test_dispatch_3d() {
        let shape = DispatchShape::Grid3D([64, 64, 64]);
        let wg = shape.workgroup_count([8, 8, 4]);
        assert_eq!(wg, [8, 8, 16]);
    }

    #[test]
    fn test_gpu_dispatch_feasible() {
        let coord = gpu_coordinator();
        let step = PipelineStep {
            label: "test".to_string(),
            shader_name: "orbital_grid",
            buffer_sizes: vec![1024, 512, 2048, 32, 500_000],
            dispatch: DispatchShape::Grid3D([50, 50, 50]),
        };
        let shader = shader_registry::find_shader("orbital_grid").unwrap();
        let decision = coord.evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::GpuDispatch { .. }));
    }

    #[test]
    fn test_cpu_fallback_when_no_gpu() {
        let coord = PipelineCoordinator::cpu_only();
        let step = PipelineStep {
            label: "test".to_string(),
            shader_name: "orbital_grid",
            buffer_sizes: vec![1024],
            dispatch: DispatchShape::Linear(100),
        };
        let shader = shader_registry::find_shader("orbital_grid").unwrap();
        let decision = coord.evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::CpuPreferred { .. }));
    }

    #[test]
    fn test_cpu_preferred_for_tier4() {
        let coord = gpu_coordinator();
        let step = PipelineStep {
            label: "smoke test".to_string(),
            shader_name: "vector_add",
            buffer_sizes: vec![1024, 1024, 32, 1024],
            dispatch: DispatchShape::Linear(256),
        };
        let shader = shader_registry::find_shader("vector_add").unwrap();
        let decision = coord.evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::CpuPreferred { .. }));
    }

    #[test]
    fn test_buffer_too_large_forces_cpu() {
        let coord = gpu_coordinator();
        // 300 MB is expected to exceed the binding-size limit of `webgpu_default()`.
        let step = PipelineStep {
            label: "huge".to_string(),
            shader_name: "orbital_grid",
            buffer_sizes: vec![300_000_000],
            dispatch: DispatchShape::Linear(1),
        };
        let shader = shader_registry::find_shader("orbital_grid").unwrap();
        let decision = coord.evaluate_step(&step, shader);
        assert!(matches!(decision, StepDecision::CpuRequired { .. }));
    }

    #[test]
    fn test_pipeline_plan() {
        let coord = gpu_coordinator();
        let steps = vec![
            PipelineStep {
                label: "grid".to_string(),
                shader_name: "orbital_grid",
                buffer_sizes: vec![1024, 256, 2048, 32, 500_000],
                dispatch: DispatchShape::Grid3D([50, 50, 50]),
            },
            PipelineStep {
                label: "mc".to_string(),
                shader_name: "marching_cubes",
                buffer_sizes: vec![500_000, 1024, 16384, 200_000, 4, 32],
                dispatch: DispatchShape::Grid3D([49, 49, 49]),
            },
        ];
        let shaders: Vec<&ShaderDescriptor> = vec![
            shader_registry::find_shader("orbital_grid").unwrap(),
            shader_registry::find_shader("marching_cubes").unwrap(),
        ];
        let plan = coord.plan_pipeline(&steps, &shaders);
        assert_eq!(plan.gpu_steps, 2);
        assert_eq!(plan.cpu_steps, 0);
        assert!(plan.fully_gpu());
    }

    #[test]
    fn test_compute_chunks_small() {
        let coord = gpu_coordinator();
        let chunks = coord.compute_chunks(100, 4);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], (0, 100));
    }

    #[test]
    fn test_compute_chunks_large() {
        let coord = gpu_coordinator();
        let chunks = coord.compute_chunks(50, 1_000_000);
        assert!(chunks.len() > 1);
        // Chunks must be non-empty, contiguous, and cover every atom exactly once.
        assert_eq!(chunks.first().map(|c| c.0), Some(0));
        assert_eq!(chunks.last().map(|c| c.1), Some(50));
        assert!(chunks.iter().all(|&(start, end)| start < end));
        assert!(chunks.windows(2).all(|pair| pair[0].1 == pair[1].0));
    }

    #[test]
    fn test_orbital_visualization_pipeline() {
        let steps = orbital_visualization_pipeline(7, [50, 50, 50], 3);
        assert_eq!(steps.len(), 3);
        assert_eq!(steps[0].label, "Orbital Grid");
        assert_eq!(steps[1].label, "Marching Cubes (positive lobe)");
        assert_eq!(steps[2].label, "Marching Cubes (negative lobe)");
        assert_eq!(steps[0].shader_name, "orbital_grid");
        assert_eq!(steps[1].shader_name, "marching_cubes");
        assert_eq!(steps[2].shader_name, "marching_cubes");
    }

    #[test]
    fn test_pipeline_report() {
        let coord = gpu_coordinator();
        let steps = orbital_visualization_pipeline(7, [30, 30, 30], 3);
        let shaders: Vec<&ShaderDescriptor> = steps
            .iter()
            .map(|s| shader_registry::find_shader(s.shader_name).unwrap())
            .collect();
        let plan = coord.plan_pipeline(&steps, &shaders);
        assert_eq!(plan.decisions.len(), steps.len());
        assert_eq!(plan.gpu_steps + plan.cpu_steps, steps.len());

        let report = plan.report();
        assert!(report.starts_with("Pipeline Plan:"));
        // Every step label must appear, not just the first one.
        for step in &steps {
            assert!(report.contains(&step.label), "missing {} in report", step.label);
        }
        assert!(report.contains("GPU"));
    }
}