1use super::memory_budget::GpuMemoryBudget;
7
/// Expected benefit of running a workload on the GPU instead of the CPU.
///
/// Variants are listed from the largest expected speedup (`Tier1`) down to
/// workloads where the CPU is preferred (`Tier4`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuTier {
    /// Massive speedup expected from the GPU.
    Tier1,
    /// Significant speedup expected from the GPU.
    Tier2,
    /// Moderate speedup expected from the GPU.
    Tier3,
    /// CPU execution preferred.
    Tier4,
}
22
23impl GpuTier {
24 pub fn gpu_recommended(&self) -> bool {
26 matches!(self, GpuTier::Tier1 | GpuTier::Tier2)
27 }
28}
29
30impl std::fmt::Display for GpuTier {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 GpuTier::Tier1 => write!(f, "Tier 1 (massive speedup)"),
34 GpuTier::Tier2 => write!(f, "Tier 2 (significant speedup)"),
35 GpuTier::Tier3 => write!(f, "Tier 3 (moderate speedup)"),
36 GpuTier::Tier4 => write!(f, "Tier 4 (CPU preferred)"),
37 }
38 }
39}
40
41#[derive(Debug, Clone)]
43pub struct ShaderDescriptor {
44 pub name: &'static str,
46 pub tier: GpuTier,
48 pub workgroup_size: [u32; 3],
50 pub entry_point: &'static str,
52 pub storage_bindings: u32,
54 pub uniform_bindings: u32,
56 pub description: &'static str,
58}
59
60impl ShaderDescriptor {
61 pub fn total_bindings(&self) -> u32 {
63 self.storage_bindings + self.uniform_bindings
64 }
65
66 pub fn check_against_budget(&self, budget: &GpuMemoryBudget) -> Result<(), String> {
68 let total = self.storage_bindings + self.uniform_bindings;
69 if total > budget.limits.max_storage_buffers_per_stage + 4 {
70 return Err(format!(
71 "Shader '{}' needs {} bindings, budget allows {}",
72 self.name, total, budget.limits.max_storage_buffers_per_stage
73 ));
74 }
75 let invocations = self.workgroup_size[0] as u64
76 * self.workgroup_size[1] as u64
77 * self.workgroup_size[2] as u64;
78 if invocations > budget.limits.max_invocations_per_workgroup as u64 {
79 return Err(format!(
80 "Shader '{}' workgroup has {} invocations, max {}",
81 self.name, invocations, budget.limits.max_invocations_per_workgroup
82 ));
83 }
84 Ok(())
85 }
86}
87
88pub const SHADER_VECTOR_ADD: ShaderDescriptor = ShaderDescriptor {
92 name: "vector_add",
93 tier: GpuTier::Tier4,
94 workgroup_size: [64, 1, 1],
95 entry_point: "main",
96 storage_bindings: 3, uniform_bindings: 1, description: "Element-wise vector addition (GPU smoke test)",
99};
100
101pub const SHADER_ORBITAL_GRID: ShaderDescriptor = ShaderDescriptor {
103 name: "orbital_grid",
104 tier: GpuTier::Tier1,
105 workgroup_size: [8, 8, 4],
106 entry_point: "main",
107 storage_bindings: 4, uniform_bindings: 1, description: "MO wavefunction on 3D grid (GPU Tier 1: O(grid × N_basis))",
110};
111
112pub const SHADER_MARCHING_CUBES: ShaderDescriptor = ShaderDescriptor {
114 name: "marching_cubes",
115 tier: GpuTier::Tier2,
116 workgroup_size: [4, 4, 4],
117 entry_point: "main",
118 storage_bindings: 5, uniform_bindings: 1, description: "Isosurface extraction via marching cubes (GPU Tier 2: O(voxels))",
121};
122
123pub const SHADER_ESP_GRID: ShaderDescriptor = ShaderDescriptor {
125 name: "esp_grid",
126 tier: GpuTier::Tier1,
127 workgroup_size: [8, 8, 4],
128 entry_point: "main",
129 storage_bindings: 4, uniform_bindings: 1, description: "Electrostatic potential on 3D grid (GPU Tier 1: O(grid × N²))",
132};
133
134pub const SHADER_D4_DISPERSION: ShaderDescriptor = ShaderDescriptor {
136 name: "d4_dispersion",
137 tier: GpuTier::Tier3,
138 workgroup_size: [16, 16, 1],
139 entry_point: "main",
140 storage_bindings: 3, uniform_bindings: 1, description: "D4 pairwise dispersion (GPU Tier 3: O(N²))",
143};
144
145pub const SHADER_EEQ_COULOMB: ShaderDescriptor = ShaderDescriptor {
147 name: "eeq_coulomb",
148 tier: GpuTier::Tier3,
149 workgroup_size: [16, 16, 1],
150 entry_point: "main",
151 storage_bindings: 3, uniform_bindings: 1, description: "EEQ damped Coulomb matrix gamma_ij (GPU Tier 3: O(N²))",
154};
155
156pub const SHADER_DENSITY_GRID: ShaderDescriptor = ShaderDescriptor {
158 name: "density_grid",
159 tier: GpuTier::Tier1,
160 workgroup_size: [8, 8, 4],
161 entry_point: "main",
162 storage_bindings: 4, uniform_bindings: 1, description: "Electron density on 3D grid (GPU Tier 1: O(grid × N²))",
165};
166
167pub const SHADER_TWO_ELECTRON: ShaderDescriptor = ShaderDescriptor {
169 name: "two_electron_eri",
170 tier: GpuTier::Tier1,
171 workgroup_size: [64, 1, 1],
172 entry_point: "main",
173 storage_bindings: 4, uniform_bindings: 1, description: "Two-electron repulsion integrals (GPU Tier 1: O(N⁴))",
176};
177
178pub const SHADER_FOCK_BUILD: ShaderDescriptor = ShaderDescriptor {
180 name: "fock_build",
181 tier: GpuTier::Tier1,
182 workgroup_size: [16, 16, 1],
183 entry_point: "main",
184 storage_bindings: 4, uniform_bindings: 1, description: "Fock matrix construction G(P) (GPU Tier 1: O(N⁴))",
187};
188
189pub const SHADER_ONE_ELECTRON: ShaderDescriptor = ShaderDescriptor {
191 name: "one_electron",
192 tier: GpuTier::Tier2,
193 workgroup_size: [16, 16, 1],
194 entry_point: "main",
195 storage_bindings: 4, uniform_bindings: 1, description: "One-electron matrices S,T,V (GPU Tier 2: O(N²))",
198};
199
200pub const SHADER_GAMMA_MATRIX: ShaderDescriptor = ShaderDescriptor {
202 name: "gamma_matrix",
203 tier: GpuTier::Tier3,
204 workgroup_size: [16, 16, 1],
205 entry_point: "main",
206 storage_bindings: 3, uniform_bindings: 1, description: "SCC-DFTB gamma matrix (GPU Tier 3: O(N²) pairwise Coulomb)",
209};
210
211pub const SHADER_ALPB_BORN_RADII: ShaderDescriptor = ShaderDescriptor {
213 name: "alpb_born_radii",
214 tier: GpuTier::Tier3,
215 workgroup_size: [64, 1, 1],
216 entry_point: "main",
217 storage_bindings: 3, uniform_bindings: 1, description: "ALPB Born radii (GPU Tier 3: O(N²) descreening)",
220};
221
222pub const SHADER_CPM_COULOMB: ShaderDescriptor = ShaderDescriptor {
224 name: "cpm_coulomb",
225 tier: GpuTier::Tier3,
226 workgroup_size: [16, 16, 1],
227 entry_point: "main",
228 storage_bindings: 2, uniform_bindings: 1, description: "CPM Coulomb matrix J_ij (GPU Tier 3: O(N²) pairwise electrostatics)",
231};
232
233pub const ALL_SHADERS: &[&ShaderDescriptor] = &[
235 &SHADER_VECTOR_ADD,
236 &SHADER_ORBITAL_GRID,
237 &SHADER_MARCHING_CUBES,
238 &SHADER_ESP_GRID,
239 &SHADER_D4_DISPERSION,
240 &SHADER_EEQ_COULOMB,
241 &SHADER_DENSITY_GRID,
242 &SHADER_TWO_ELECTRON,
243 &SHADER_FOCK_BUILD,
244 &SHADER_ONE_ELECTRON,
245 &SHADER_GAMMA_MATRIX,
246 &SHADER_ALPB_BORN_RADII,
247 &SHADER_CPM_COULOMB,
248];
249
250pub fn find_shader(name: &str) -> Option<&'static ShaderDescriptor> {
252 ALL_SHADERS.iter().find(|s| s.name == name).copied()
253}
254
255pub fn shaders_by_tier(tier: GpuTier) -> Vec<&'static ShaderDescriptor> {
257 ALL_SHADERS
258 .iter()
259 .filter(|s| s.tier == tier)
260 .copied()
261 .collect()
262}
263
264pub fn shader_catalogue_report() -> String {
266 let mut report = String::from("GPU Shader Catalogue\n====================\n\n");
267 for tier in &[
268 GpuTier::Tier1,
269 GpuTier::Tier2,
270 GpuTier::Tier3,
271 GpuTier::Tier4,
272 ] {
273 let shaders = shaders_by_tier(*tier);
274 if shaders.is_empty() {
275 continue;
276 }
277 report.push_str(&format!("{tier}\n"));
278 for s in &shaders {
279 report.push_str(&format!(
280 " {} — wg[{},{},{}], {} bindings — {}\n",
281 s.name,
282 s.workgroup_size[0],
283 s.workgroup_size[1],
284 s.workgroup_size[2],
285 s.total_bindings(),
286 s.description,
287 ));
288 }
289 report.push('\n');
290 }
291 report
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tier_display() {
        let top = GpuTier::Tier1;
        assert_eq!(top.to_string(), "Tier 1 (massive speedup)");
        assert!(top.gpu_recommended());
        assert!(!GpuTier::Tier4.gpu_recommended());
    }

    #[test]
    fn test_shader_lookup() {
        let shader = find_shader("orbital_grid").expect("orbital_grid is registered");
        assert_eq!(shader.tier, GpuTier::Tier1);
        assert_eq!(shader.workgroup_size, [8, 8, 4]);
    }

    #[test]
    fn test_shader_lookup_missing() {
        assert!(find_shader("nonexistent").is_none());
    }

    #[test]
    fn test_shaders_by_tier() {
        let tier1 = shaders_by_tier(GpuTier::Tier1);
        assert!(tier1.len() >= 3);
        for shader in &tier1 {
            assert_eq!(shader.tier, GpuTier::Tier1);
        }
    }

    #[test]
    fn test_budget_check_passes() {
        let budget = GpuMemoryBudget::webgpu_default();
        for shader in [&SHADER_ORBITAL_GRID, &SHADER_MARCHING_CUBES] {
            assert!(shader.check_against_budget(&budget).is_ok());
        }
    }

    #[test]
    fn test_catalogue_report() {
        let report = shader_catalogue_report();
        for needle in ["orbital_grid", "Tier 1", "Tier 3"] {
            assert!(report.contains(needle));
        }
    }

    #[test]
    fn test_total_bindings() {
        assert_eq!(SHADER_ORBITAL_GRID.total_bindings(), 5);
        assert_eq!(SHADER_VECTOR_ADD.total_bindings(), 4);
    }

    #[test]
    fn test_all_shaders_registered() {
        assert_eq!(ALL_SHADERS.len(), 13);
    }
}