// sci_form/gpu/shader_registry.rs

1//! Centralized WGSL shader registry.
2//!
3//! All GPU compute shaders are catalogued here with metadata for
4//! validation, tier classification, and dispatch parameter guidance.
5
6use super::memory_budget::GpuMemoryBudget;
7
/// GPU compute tier classification (from algorithm analysis).
///
/// Tiers are ranked by expected benefit from GPU acceleration: `Tier1`
/// gains the most, while `Tier4` work is better left on the CPU.
///
/// The type is `Copy` and `Hash`, so it can be passed by value and used as a
/// key in `HashMap`/`HashSet` (for example to group shaders by tier).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuTier {
    /// Massive speedup: O(N⁴) two-electron integrals, O(grid × N_basis) orbital grids.
    Tier1,
    /// Significant speedup: O(N²) matrix builds (overlap, Fock, Coulomb), O(voxels) marching cubes.
    Tier2,
    /// Moderate speedup: O(N²) pairwise (D4 dispersion, EEQ Coulomb, KPM Chebyshev).
    Tier3,
    /// CPU-preferred: SCF loop control, DIIS, eigensolve (latency-bound).
    Tier4,
}
22
23impl GpuTier {
24    /// Whether GPU dispatch is recommended for this tier.
25    pub fn gpu_recommended(&self) -> bool {
26        matches!(self, GpuTier::Tier1 | GpuTier::Tier2)
27    }
28}
29
30impl std::fmt::Display for GpuTier {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            GpuTier::Tier1 => write!(f, "Tier 1 (massive speedup)"),
34            GpuTier::Tier2 => write!(f, "Tier 2 (significant speedup)"),
35            GpuTier::Tier3 => write!(f, "Tier 3 (moderate speedup)"),
36            GpuTier::Tier4 => write!(f, "Tier 4 (CPU preferred)"),
37        }
38    }
39}
40
41/// Descriptor for a registered GPU shader.
42#[derive(Debug, Clone)]
43pub struct ShaderDescriptor {
44    /// Human-readable name.
45    pub name: &'static str,
46    /// GPU tier classification.
47    pub tier: GpuTier,
48    /// Workgroup size [x, y, z].
49    pub workgroup_size: [u32; 3],
50    /// Entry point function name.
51    pub entry_point: &'static str,
52    /// Number of storage bindings required.
53    pub storage_bindings: u32,
54    /// Number of uniform bindings required.
55    pub uniform_bindings: u32,
56    /// Brief description of what this shader computes.
57    pub description: &'static str,
58}
59
60impl ShaderDescriptor {
61    /// Total bindings (storage + uniform).
62    pub fn total_bindings(&self) -> u32 {
63        self.storage_bindings + self.uniform_bindings
64    }
65
66    /// Pre-flight check against memory budget limits.
67    pub fn check_against_budget(&self, budget: &GpuMemoryBudget) -> Result<(), String> {
68        let total = self.storage_bindings + self.uniform_bindings;
69        if total > budget.limits.max_storage_buffers_per_stage + 4 {
70            return Err(format!(
71                "Shader '{}' needs {} bindings, budget allows {}",
72                self.name, total, budget.limits.max_storage_buffers_per_stage
73            ));
74        }
75        let invocations = self.workgroup_size[0] as u64
76            * self.workgroup_size[1] as u64
77            * self.workgroup_size[2] as u64;
78        if invocations > budget.limits.max_invocations_per_workgroup as u64 {
79            return Err(format!(
80                "Shader '{}' workgroup has {} invocations, max {}",
81                self.name, invocations, budget.limits.max_invocations_per_workgroup
82            ));
83        }
84        Ok(())
85    }
86}
87
88// ─── Shader catalogue ──────────────────────────────────────────────
89
90/// Vector addition — smoke-test / validation shader.
91pub const SHADER_VECTOR_ADD: ShaderDescriptor = ShaderDescriptor {
92    name: "vector_add",
93    tier: GpuTier::Tier4,
94    workgroup_size: [64, 1, 1],
95    entry_point: "main",
96    storage_bindings: 3, // lhs, rhs, out
97    uniform_bindings: 1, // params
98    description: "Element-wise vector addition (GPU smoke test)",
99};
100
101/// Orbital grid evaluation: ψ_i(r) = Σ_μ C_{μi} φ_μ(r).
102pub const SHADER_ORBITAL_GRID: ShaderDescriptor = ShaderDescriptor {
103    name: "orbital_grid",
104    tier: GpuTier::Tier1,
105    workgroup_size: [8, 8, 4],
106    entry_point: "main",
107    storage_bindings: 4, // basis, mo_coeffs, primitives, output
108    uniform_bindings: 1, // grid_params
109    description: "MO wavefunction on 3D grid (GPU Tier 1: O(grid × N_basis))",
110};
111
112/// Marching cubes isosurface extraction.
113pub const SHADER_MARCHING_CUBES: ShaderDescriptor = ShaderDescriptor {
114    name: "marching_cubes",
115    tier: GpuTier::Tier2,
116    workgroup_size: [4, 4, 4],
117    entry_point: "main",
118    storage_bindings: 5, // scalar_field, edge_table, tri_table, vertices, tri_count
119    uniform_bindings: 1, // mc_params
120    description: "Isosurface extraction via marching cubes (GPU Tier 2: O(voxels))",
121};
122
123/// ESP grid: V(r) = Σ_A Z_A/|r-R_A| - Σ_μν P_μν ∫ φ_μ(r') φ_ν(r')/|r-r'| dr'.
124pub const SHADER_ESP_GRID: ShaderDescriptor = ShaderDescriptor {
125    name: "esp_grid",
126    tier: GpuTier::Tier1,
127    workgroup_size: [8, 8, 4],
128    entry_point: "main",
129    storage_bindings: 4, // atoms, density, basis_info, output
130    uniform_bindings: 1, // grid_params
131    description: "Electrostatic potential on 3D grid (GPU Tier 1: O(grid × N²))",
132};
133
134/// D4 dispersion pairwise energy.
135pub const SHADER_D4_DISPERSION: ShaderDescriptor = ShaderDescriptor {
136    name: "d4_dispersion",
137    tier: GpuTier::Tier3,
138    workgroup_size: [16, 16, 1],
139    entry_point: "main",
140    storage_bindings: 3, // positions, d4_params, output_energies
141    uniform_bindings: 1, // config
142    description: "D4 pairwise dispersion (GPU Tier 3: O(N²))",
143};
144
145/// EEQ Coulomb matrix.
146pub const SHADER_EEQ_COULOMB: ShaderDescriptor = ShaderDescriptor {
147    name: "eeq_coulomb",
148    tier: GpuTier::Tier3,
149    workgroup_size: [16, 16, 1],
150    entry_point: "main",
151    storage_bindings: 3, // positions, radii, coulomb_matrix
152    uniform_bindings: 1, // config
153    description: "EEQ damped Coulomb matrix gamma_ij (GPU Tier 3: O(N²))",
154};
155
156/// Electron density grid: ρ(r) = Σ_μν P_μν φ_μ(r) φ_ν(r).
157pub const SHADER_DENSITY_GRID: ShaderDescriptor = ShaderDescriptor {
158    name: "density_grid",
159    tier: GpuTier::Tier1,
160    workgroup_size: [8, 8, 4],
161    entry_point: "main",
162    storage_bindings: 4, // basis, density_matrix, primitives, output
163    uniform_bindings: 1, // grid_params
164    description: "Electron density on 3D grid (GPU Tier 1: O(grid × N²))",
165};
166
167/// Two-electron repulsion integrals: (μν|λσ).
168pub const SHADER_TWO_ELECTRON: ShaderDescriptor = ShaderDescriptor {
169    name: "two_electron_eri",
170    tier: GpuTier::Tier1,
171    workgroup_size: [64, 1, 1],
172    entry_point: "main",
173    storage_bindings: 4, // basis, primitives, quartets, output
174    uniform_bindings: 1, // params
175    description: "Two-electron repulsion integrals (GPU Tier 1: O(N⁴))",
176};
177
178/// Fock matrix build: F = H + G(P).
179pub const SHADER_FOCK_BUILD: ShaderDescriptor = ShaderDescriptor {
180    name: "fock_build",
181    tier: GpuTier::Tier1,
182    workgroup_size: [16, 16, 1],
183    entry_point: "main",
184    storage_bindings: 4, // h_core, density, eris, output
185    uniform_bindings: 1, // params
186    description: "Fock matrix construction G(P) (GPU Tier 1: O(N⁴))",
187};
188
189/// One-electron matrices: S, T, V.
190pub const SHADER_ONE_ELECTRON: ShaderDescriptor = ShaderDescriptor {
191    name: "one_electron",
192    tier: GpuTier::Tier2,
193    workgroup_size: [16, 16, 1],
194    entry_point: "main",
195    storage_bindings: 4, // basis, primitives, atoms, output
196    uniform_bindings: 1, // params
197    description: "One-electron matrices S,T,V (GPU Tier 2: O(N²))",
198};
199
200/// SCC-DFTB gamma matrix.
201pub const SHADER_GAMMA_MATRIX: ShaderDescriptor = ShaderDescriptor {
202    name: "gamma_matrix",
203    tier: GpuTier::Tier3,
204    workgroup_size: [16, 16, 1],
205    entry_point: "main",
206    storage_bindings: 3, // eta, positions, output
207    uniform_bindings: 1, // params
208    description: "SCC-DFTB gamma matrix (GPU Tier 3: O(N²) pairwise Coulomb)",
209};
210
211/// ALPB Born radii.
212pub const SHADER_ALPB_BORN_RADII: ShaderDescriptor = ShaderDescriptor {
213    name: "alpb_born_radii",
214    tier: GpuTier::Tier3,
215    workgroup_size: [64, 1, 1],
216    entry_point: "main",
217    storage_bindings: 3, // positions, rho, output
218    uniform_bindings: 1, // params
219    description: "ALPB Born radii (GPU Tier 3: O(N²) descreening)",
220};
221
222/// CPM Coulomb matrix.
223pub const SHADER_CPM_COULOMB: ShaderDescriptor = ShaderDescriptor {
224    name: "cpm_coulomb",
225    tier: GpuTier::Tier3,
226    workgroup_size: [16, 16, 1],
227    entry_point: "main",
228    storage_bindings: 2, // positions, output
229    uniform_bindings: 1, // params
230    description: "CPM Coulomb matrix J_ij (GPU Tier 3: O(N²) pairwise electrostatics)",
231};
232
233/// All registered shaders.
234pub const ALL_SHADERS: &[&ShaderDescriptor] = &[
235    &SHADER_VECTOR_ADD,
236    &SHADER_ORBITAL_GRID,
237    &SHADER_MARCHING_CUBES,
238    &SHADER_ESP_GRID,
239    &SHADER_D4_DISPERSION,
240    &SHADER_EEQ_COULOMB,
241    &SHADER_DENSITY_GRID,
242    &SHADER_TWO_ELECTRON,
243    &SHADER_FOCK_BUILD,
244    &SHADER_ONE_ELECTRON,
245    &SHADER_GAMMA_MATRIX,
246    &SHADER_ALPB_BORN_RADII,
247    &SHADER_CPM_COULOMB,
248];
249
250/// Look up a shader by name.
251pub fn find_shader(name: &str) -> Option<&'static ShaderDescriptor> {
252    ALL_SHADERS.iter().find(|s| s.name == name).copied()
253}
254
255/// List all shaders in a given tier.
256pub fn shaders_by_tier(tier: GpuTier) -> Vec<&'static ShaderDescriptor> {
257    ALL_SHADERS
258        .iter()
259        .filter(|s| s.tier == tier)
260        .copied()
261        .collect()
262}
263
264/// Generate a summary report of all registered shaders.
265pub fn shader_catalogue_report() -> String {
266    let mut report = String::from("GPU Shader Catalogue\n====================\n\n");
267    for tier in &[
268        GpuTier::Tier1,
269        GpuTier::Tier2,
270        GpuTier::Tier3,
271        GpuTier::Tier4,
272    ] {
273        let shaders = shaders_by_tier(*tier);
274        if shaders.is_empty() {
275            continue;
276        }
277        report.push_str(&format!("{tier}\n"));
278        for s in &shaders {
279            report.push_str(&format!(
280                "  {} — wg[{},{},{}], {} bindings — {}\n",
281                s.name,
282                s.workgroup_size[0],
283                s.workgroup_size[1],
284                s.workgroup_size[2],
285                s.total_bindings(),
286                s.description,
287            ));
288        }
289        report.push('\n');
290    }
291    report
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Tier labels render as expected and only the top tiers recommend the GPU.
    #[test]
    fn test_tier_display() {
        let label = GpuTier::Tier1.to_string();
        assert_eq!(label, "Tier 1 (massive speedup)");
        assert!(GpuTier::Tier1.gpu_recommended());
        assert!(!GpuTier::Tier4.gpu_recommended());
    }

    /// A registered name resolves to its descriptor.
    #[test]
    fn test_shader_lookup() {
        let found = find_shader("orbital_grid").unwrap();
        assert_eq!(found.tier, GpuTier::Tier1);
        assert_eq!(found.workgroup_size, [8, 8, 4]);
    }

    /// An unknown name yields `None`.
    #[test]
    fn test_shader_lookup_missing() {
        let missing = find_shader("nonexistent");
        assert!(missing.is_none());
    }

    /// Tier filtering returns only descriptors of the requested tier.
    #[test]
    fn test_shaders_by_tier() {
        let tier_one = shaders_by_tier(GpuTier::Tier1);
        // At least orbital_grid, esp_grid and density_grid are Tier 1.
        assert!(tier_one.len() >= 3);
        for shader in &tier_one {
            assert!(shader.tier == GpuTier::Tier1);
        }
    }

    /// The two checked descriptors fit the default WebGPU budget.
    #[test]
    fn test_budget_check_passes() {
        let budget = GpuMemoryBudget::webgpu_default();
        for shader in &[SHADER_ORBITAL_GRID, SHADER_MARCHING_CUBES] {
            assert!(shader.check_against_budget(&budget).is_ok());
        }
    }

    /// The report mentions a known shader and the populated tier headings.
    #[test]
    fn test_catalogue_report() {
        let report = shader_catalogue_report();
        for needle in &["orbital_grid", "Tier 1", "Tier 3"] {
            assert!(report.contains(needle));
        }
    }

    /// Binding totals are the sum of storage and uniform counts.
    #[test]
    fn test_total_bindings() {
        assert_eq!(SHADER_ORBITAL_GRID.total_bindings(), 5);
        assert_eq!(SHADER_VECTOR_ADD.total_bindings(), 4);
    }

    /// Pins the registry size so additions and removals are deliberate.
    #[test]
    fn test_all_shaders_registered() {
        assert_eq!(ALL_SHADERS.len(), 13);
    }
}
349}