Skip to main content

sci_form/gpu/
memory_budget.rs

//! GPU Memory Budget — WebGPU/WASM limit enforcement.
//!
//! WebGPU imposes strict hardware limits that vary by platform. This module
//! checks buffer sizes, binding counts, and dispatch shapes against these limits,
//! providing pre-flight checks before GPU dispatch to prevent silent failures or
//! out-of-memory errors.
//!
//! ## WebGPU Default Limits (W3C spec)
//!
//! | Limit | Default | Notes |
//! |-------|---------|-------|
//! | maxStorageBufferBindingSize | 128 MiB | Per storage-buffer binding |
//! | maxBufferSize | 256 MiB | Total per buffer object |
//! | maxUniformBufferBindingSize | 64 KiB | Uniform blocks are small |
//! | maxStorageBuffersPerShaderStage | 8 | Max storage-buffer bindings per stage |
//! | maxComputeWorkgroupStorageSize | 16 KiB | Shared workgroup memory |
//! | maxComputeInvocationsPerWorkgroup | 256 | Threads per workgroup |
//! | maxComputeWorkgroupSizeX | 256 | |
//! | maxComputeWorkgroupSizeY | 256 | |
//! | maxComputeWorkgroupSizeZ | 64 | |
//! | maxComputeWorkgroupsPerDimension | 65535 | |
21
/// WebGPU device limits relevant to buffer sizing and compute dispatch.
///
/// [`Default`] yields the W3C WebGPU default limits, which are the values a device
/// gets when no higher limits are requested. To model an adapter that reports
/// different limits, use struct-update syntax, e.g.
/// `GpuMemoryLimits { max_buffer_size: 1 << 30, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuMemoryLimits {
    /// Max bytes per storage buffer binding (default: 128 MiB).
    pub max_storage_buffer_binding_size: u64,
    /// Max total bytes per buffer object (default: 256 MiB).
    pub max_buffer_size: u64,
    /// Max bytes per uniform buffer binding (default: 64 KiB).
    pub max_uniform_buffer_binding_size: u64,
    /// Max storage buffers per shader stage (default: 8).
    pub max_storage_buffers_per_stage: u32,
    /// Max shared workgroup memory in bytes (default: 16 KiB).
    pub max_workgroup_storage_size: u32,
    /// Max invocations per workgroup, i.e. the product of the three axis sizes (default: 256).
    pub max_invocations_per_workgroup: u32,
    /// Max workgroup size X (default: 256).
    pub max_workgroup_size_x: u32,
    /// Max workgroup size Y (default: 256).
    pub max_workgroup_size_y: u32,
    /// Max workgroup size Z (default: 64).
    pub max_workgroup_size_z: u32,
    /// Max workgroups per dispatch dimension (default: 65535).
    pub max_workgroups_per_dimension: u32,
}

impl Default for GpuMemoryLimits {
    /// The W3C WebGPU default limits.
    fn default() -> Self {
        const KIB: u64 = 1024;
        const MIB: u64 = 1024 * KIB;
        Self {
            max_storage_buffer_binding_size: 128 * MIB,
            max_buffer_size: 256 * MIB,
            max_uniform_buffer_binding_size: 64 * KIB,
            max_storage_buffers_per_stage: 8,
            max_workgroup_storage_size: 16 * 1024,
            max_invocations_per_workgroup: 256,
            max_workgroup_size_x: 256,
            max_workgroup_size_y: 256,
            max_workgroup_size_z: 64,
            max_workgroups_per_dimension: 65535,
        }
    }
}
63
/// Tracks GPU buffer allocations and enforces limits.
///
/// Pairs a set of [`GpuMemoryLimits`] with usage counters. All pre-flight checks on
/// this type evaluate against `limits`.
#[derive(Debug, Clone)]
pub struct GpuMemoryBudget {
    /// Device limits that every check is evaluated against.
    pub limits: GpuMemoryLimits,
    /// Total bytes currently allocated across all buffers.
    ///
    /// NOTE(review): no method in this file reads or updates this counter; callers
    /// appear to be expected to maintain it themselves. Confirm that this is intended.
    pub allocated_bytes: u64,
    /// Number of storage buffer bindings in use; `check_storage_binding` compares it
    /// against `limits.max_storage_buffers_per_stage`.
    ///
    /// NOTE(review): like `allocated_bytes`, nothing in this file increments it.
    pub storage_bindings_used: u32,
}
73
74impl GpuMemoryBudget {
75    pub fn new(limits: GpuMemoryLimits) -> Self {
76        Self {
77            limits,
78            allocated_bytes: 0,
79            storage_bindings_used: 0,
80        }
81    }
82
83    /// WebGPU default limits.
84    pub fn webgpu_default() -> Self {
85        Self::new(GpuMemoryLimits::default())
86    }
87
88    /// Check if a buffer allocation would exceed limits.
89    pub fn check_buffer(&self, size_bytes: u64) -> Result<(), MemoryError> {
90        if size_bytes > self.limits.max_buffer_size {
91            return Err(MemoryError::BufferTooLarge {
92                requested: size_bytes,
93                max: self.limits.max_buffer_size,
94            });
95        }
96        if size_bytes > self.limits.max_storage_buffer_binding_size {
97            return Err(MemoryError::BindingTooLarge {
98                requested: size_bytes,
99                max: self.limits.max_storage_buffer_binding_size,
100            });
101        }
102        Ok(())
103    }
104
105    /// Check if a storage binding can be added.
106    pub fn check_storage_binding(&self) -> Result<(), MemoryError> {
107        if self.storage_bindings_used >= self.limits.max_storage_buffers_per_stage {
108            return Err(MemoryError::TooManyBindings {
109                current: self.storage_bindings_used,
110                max: self.limits.max_storage_buffers_per_stage,
111            });
112        }
113        Ok(())
114    }
115
116    /// Check workgroup configuration validity.
117    pub fn check_workgroup(&self, size: [u32; 3], count: [u32; 3]) -> Result<(), MemoryError> {
118        if size[0] > self.limits.max_workgroup_size_x
119            || size[1] > self.limits.max_workgroup_size_y
120            || size[2] > self.limits.max_workgroup_size_z
121        {
122            return Err(MemoryError::WorkgroupSizeExceeded {
123                requested: size,
124                max: [
125                    self.limits.max_workgroup_size_x,
126                    self.limits.max_workgroup_size_y,
127                    self.limits.max_workgroup_size_z,
128                ],
129            });
130        }
131        let total = size[0] as u64 * size[1] as u64 * size[2] as u64;
132        if total > self.limits.max_invocations_per_workgroup as u64 {
133            return Err(MemoryError::TooManyInvocations {
134                requested: total as u32,
135                max: self.limits.max_invocations_per_workgroup,
136            });
137        }
138        for (i, &c) in count.iter().enumerate() {
139            if c > self.limits.max_workgroups_per_dimension {
140                return Err(MemoryError::TooManyWorkgroups {
141                    dimension: i as u32,
142                    requested: c,
143                    max: self.limits.max_workgroups_per_dimension,
144                });
145            }
146        }
147        Ok(())
148    }
149
150    /// Pre-flight check for an orbital grid computation.
151    ///
152    /// Returns the estimated total GPU memory needed in bytes.
153    pub fn estimate_orbital_grid_memory(
154        &self,
155        n_basis: usize,
156        grid_points: usize,
157        max_primitives_per_basis: usize,
158    ) -> Result<OrbitalGridMemoryEstimate, MemoryError> {
159        // Buffers:
160        // 1. basis functions: n_basis × 32 bytes (BasisFunc struct)
161        // 2. MO coefficients: n_basis × 4 bytes (f32)
162        // 3. primitives: n_basis × max_primitives × 8 bytes (vec2<f32>)
163        // 4. grid params: 32 bytes (uniform)
164        // 5. output: grid_points × 4 bytes (f32)
165        let basis_bytes = (n_basis * 32) as u64;
166        let mo_bytes = (n_basis * 4) as u64;
167        let prim_bytes = (n_basis * max_primitives_per_basis * 8) as u64;
168        let params_bytes = 32u64;
169        let output_bytes = (grid_points * 4) as u64;
170
171        let total = basis_bytes + mo_bytes + prim_bytes + params_bytes + output_bytes;
172
173        // Check individual buffers
174        self.check_buffer(basis_bytes)?;
175        self.check_buffer(mo_bytes)?;
176        self.check_buffer(prim_bytes)?;
177        self.check_buffer(output_bytes)?;
178
179        // Need 5 bindings: 3 storage read, 1 uniform, 1 storage read-write
180        if 5 > self.limits.max_storage_buffers_per_stage {
181            return Err(MemoryError::TooManyBindings {
182                current: 5,
183                max: self.limits.max_storage_buffers_per_stage,
184            });
185        }
186
187        Ok(OrbitalGridMemoryEstimate {
188            basis_bytes,
189            mo_coefficients_bytes: mo_bytes,
190            primitives_bytes: prim_bytes,
191            params_bytes,
192            output_bytes,
193            total_bytes: total,
194            fits_in_webgpu: total <= self.limits.max_buffer_size,
195        })
196    }
197
198    /// Estimate memory for D4 dispersion pairwise computation.
199    pub fn estimate_d4_dispersion_memory(
200        &self,
201        n_atoms: usize,
202    ) -> Result<PairwiseMemoryEstimate, MemoryError> {
203        // Buffers:
204        // 1. positions: n_atoms × 16 bytes (vec4<f32>: x,y,z,Z)
205        // 2. params: n_atoms × 32 bytes (D4 params)
206        // 3. config: 32 bytes (uniform)
207        // 4. output pairwise: n_atoms × n_atoms × 4 bytes (f32)
208        // 5. output energy: 4 bytes (f32 reduction)
209        let pos_bytes = (n_atoms * 16) as u64;
210        let params_bytes = (n_atoms * 32) as u64;
211        let config_bytes = 32u64;
212        let pairwise_bytes = (n_atoms * n_atoms * 4) as u64;
213        let output_bytes = (n_atoms * 4) as u64;
214
215        let total = pos_bytes + params_bytes + config_bytes + pairwise_bytes + output_bytes;
216
217        self.check_buffer(pairwise_bytes)?;
218
219        Ok(PairwiseMemoryEstimate {
220            positions_bytes: pos_bytes,
221            params_bytes,
222            pairwise_bytes,
223            total_bytes: total,
224            fits_in_webgpu: total <= self.limits.max_buffer_size,
225            max_atoms_for_limit: ((self.limits.max_storage_buffer_binding_size / 4) as f64).sqrt()
226                as usize,
227        })
228    }
229
230    /// Compute optimal workgroup dispatch for a 3D grid.
231    pub fn optimal_grid_dispatch(&self, dims: [u32; 3]) -> ([u32; 3], [u32; 3]) {
232        let wg_size = [
233            8u32.min(self.limits.max_workgroup_size_x),
234            8u32.min(self.limits.max_workgroup_size_y),
235            4u32.min(self.limits.max_workgroup_size_z),
236        ];
237        let wg_count = [
238            dims[0].div_ceil(wg_size[0]),
239            dims[1].div_ceil(wg_size[1]),
240            dims[2].div_ceil(wg_size[2]),
241        ];
242        (wg_size, wg_count)
243    }
244
245    /// Compute optimal workgroup dispatch for a 1D array.
246    pub fn optimal_1d_dispatch(&self, n: u32) -> (u32, u32) {
247        let wg_size = 64u32.min(self.limits.max_workgroup_size_x);
248        let wg_count = n.div_ceil(wg_size);
249        (wg_size, wg_count)
250    }
251}
252
/// Memory estimation for orbital grid dispatch, in bytes per buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrbitalGridMemoryEstimate {
    /// Basis-function descriptor buffer.
    pub basis_bytes: u64,
    /// Molecular-orbital coefficient buffer.
    pub mo_coefficients_bytes: u64,
    /// Gaussian primitive buffer.
    pub primitives_bytes: u64,
    /// Uniform grid-parameter block.
    pub params_bytes: u64,
    /// Output buffer holding one value per grid point.
    pub output_bytes: u64,
    /// Sum of all of the buffers above.
    pub total_bytes: u64,
    /// Whether `total_bytes` is within the device's per-buffer size limit.
    pub fits_in_webgpu: bool,
}
264
/// Memory estimation for pairwise computations (D4, EEQ Coulomb), in bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PairwiseMemoryEstimate {
    /// Per-atom position buffer.
    pub positions_bytes: u64,
    /// Per-atom parameter buffer.
    pub params_bytes: u64,
    /// The n × n pairwise result matrix, which is the dominant buffer.
    pub pairwise_bytes: u64,
    /// Sum of every buffer in the dispatch. This includes buffers not broken out above,
    /// such as the uniform config and the energy output.
    pub total_bytes: u64,
    /// Whether `total_bytes` is within the device's per-buffer size limit.
    pub fits_in_webgpu: bool,
    /// Maximum atom count whose n × n f32 pairwise matrix fits in one storage buffer.
    pub max_atoms_for_limit: usize,
}
276
/// Errors from GPU memory budget checks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryError {
    /// A buffer is larger than the per-buffer-object size limit.
    BufferTooLarge {
        /// Requested size in bytes.
        requested: u64,
        /// Limit in bytes.
        max: u64,
    },
    /// A buffer is larger than its per-binding size limit.
    BindingTooLarge {
        /// Requested size in bytes.
        requested: u64,
        /// Limit in bytes.
        max: u64,
    },
    /// More storage-buffer bindings are needed than one shader stage allows.
    TooManyBindings {
        /// Bindings in use or required.
        current: u32,
        /// Per-stage limit.
        max: u32,
    },
    /// A workgroup axis is larger than its per-axis limit.
    WorkgroupSizeExceeded {
        /// Requested workgroup size (x, y, z).
        requested: [u32; 3],
        /// Per-axis limits (x, y, z).
        max: [u32; 3],
    },
    /// The product of the workgroup axes exceeds the invocation limit.
    TooManyInvocations {
        /// Requested invocations per workgroup.
        requested: u32,
        /// Invocation limit.
        max: u32,
    },
    /// Too many workgroups are dispatched along one dimension.
    TooManyWorkgroups {
        /// Dimension index: 0 = x, 1 = y, 2 = z.
        dimension: u32,
        /// Requested workgroup count.
        requested: u32,
        /// Per-dimension limit.
        max: u32,
    },
}

impl std::fmt::Display for MemoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MemoryError::BufferTooLarge { requested, max } => {
                write!(f, "Buffer {requested} bytes exceeds max {max} bytes")
            }
            MemoryError::BindingTooLarge { requested, max } => {
                write!(f, "Binding {requested} bytes exceeds max {max} bytes")
            }
            MemoryError::TooManyBindings { current, max } => {
                write!(f, "Need {current} bindings, max {max}")
            }
            MemoryError::WorkgroupSizeExceeded { requested, max } => {
                write!(
                    f,
                    "Workgroup [{},{},{}] exceeds max [{},{},{}]",
                    requested[0], requested[1], requested[2], max[0], max[1], max[2]
                )
            }
            MemoryError::TooManyInvocations { requested, max } => {
                write!(f, "{requested} invocations exceeds max {max}")
            }
            MemoryError::TooManyWorkgroups {
                dimension,
                requested,
                max,
            } => {
                write!(
                    f,
                    "Dimension {dimension}: {requested} workgroups exceeds max {max}"
                )
            }
        }
    }
}

/// Lets `MemoryError` propagate through `?` into `Box<dyn std::error::Error>`
/// and other generic error types.
impl std::error::Error for MemoryError {}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh budget with the W3C default limits, shared by the tests below.
    fn default_budget() -> GpuMemoryBudget {
        GpuMemoryBudget::webgpu_default()
    }

    #[test]
    fn test_webgpu_defaults() {
        const MIB: u64 = 1024 * 1024;
        let defaults = GpuMemoryLimits::default();
        assert_eq!(defaults.max_storage_buffer_binding_size, 128 * MIB);
        assert_eq!(defaults.max_buffer_size, 256 * MIB);
        assert_eq!(defaults.max_uniform_buffer_binding_size, 64 * 1024);
    }

    #[test]
    fn test_buffer_check_within_limits() {
        let one_mib = 1 << 20;
        assert!(default_budget().check_buffer(one_mib).is_ok());
    }

    #[test]
    fn test_buffer_check_exceeds_limits() {
        // 300 MB is above the 256 MiB maxBufferSize default.
        assert!(default_budget().check_buffer(300_000_000).is_err());
    }

    #[test]
    fn test_orbital_grid_small_molecule() {
        // H₂O: 7 basis functions with up to 3 primitives each, on a 50³ grid.
        let grid_points = 50 * 50 * 50;
        let estimate = default_budget()
            .estimate_orbital_grid_memory(7, grid_points, 3)
            .unwrap();
        assert!(estimate.fits_in_webgpu);
        assert!(estimate.total_bytes < 1_000_000, "expected under 1 MB");
    }

    #[test]
    fn test_workgroup_check_valid() {
        let outcome = default_budget().check_workgroup([8, 8, 4], [100, 100, 50]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_workgroup_check_too_large() {
        // x = 512 exceeds the 256 per-axis default.
        let outcome = default_budget().check_workgroup([512, 1, 1], [1, 1, 1]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_d4_memory_small_system() {
        let estimate = default_budget().estimate_d4_dispersion_memory(100).unwrap();
        assert!(estimate.fits_in_webgpu);
    }

    #[test]
    fn test_d4_max_atoms_calculable() {
        // sqrt(128 MiB / 4 bytes) ≈ 5792 atoms.
        let estimate = default_budget().estimate_d4_dispersion_memory(10).unwrap();
        assert!(estimate.max_atoms_for_limit > 5000);
    }

    #[test]
    fn test_optimal_grid_dispatch() {
        let (size, count) = default_budget().optimal_grid_dispatch([64, 64, 64]);
        assert_eq!(size, [8, 8, 4]);
        assert_eq!(count, [8, 8, 16]);
    }

    #[test]
    fn test_optimal_1d_dispatch() {
        // ceil(1000 / 64) = 16 workgroups.
        let (size, count) = default_budget().optimal_1d_dispatch(1000);
        assert_eq!(size, 64);
        assert_eq!(count, 16);
    }
}