/// Device limits that bound GPU buffer sizes, storage-buffer bindings, and compute
/// dispatch shapes.
///
/// Each field corresponds to the WebGPU `GPUSupportedLimits` entry named in its doc.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuMemoryLimits {
    /// Largest byte size of one storage-buffer binding (`maxStorageBufferBindingSize`).
    pub max_storage_buffer_binding_size: u64,
    /// Largest byte size of any single buffer (`maxBufferSize`).
    pub max_buffer_size: u64,
    /// Largest byte size of one uniform-buffer binding (`maxUniformBufferBindingSize`).
    pub max_uniform_buffer_binding_size: u64,
    /// Storage buffers usable from one shader stage (`maxStorageBuffersPerShaderStage`).
    pub max_storage_buffers_per_stage: u32,
    /// Workgroup storage size in bytes (`maxComputeWorkgroupStorageSize`).
    pub max_workgroup_storage_size: u32,
    /// Upper bound on `x * y * z` of a workgroup size (`maxComputeInvocationsPerWorkgroup`).
    pub max_invocations_per_workgroup: u32,
    /// Largest workgroup size along x (`maxComputeWorkgroupSizeX`).
    pub max_workgroup_size_x: u32,
    /// Largest workgroup size along y (`maxComputeWorkgroupSizeY`).
    pub max_workgroup_size_y: u32,
    /// Largest workgroup size along z (`maxComputeWorkgroupSizeZ`).
    pub max_workgroup_size_z: u32,
    /// Largest workgroup count in any one dispatch dimension
    /// (`maxComputeWorkgroupsPerDimension`).
    pub max_workgroups_per_dimension: u32,
}
46
47impl Default for GpuMemoryLimits {
48 fn default() -> Self {
49 Self {
50 max_storage_buffer_binding_size: 134_217_728, max_buffer_size: 268_435_456, max_uniform_buffer_binding_size: 65_536, max_storage_buffers_per_stage: 8,
54 max_workgroup_storage_size: 16_384, max_invocations_per_workgroup: 256,
56 max_workgroup_size_x: 256,
57 max_workgroup_size_y: 256,
58 max_workgroup_size_z: 64,
59 max_workgroups_per_dimension: 65535,
60 }
61 }
62}
63
64#[derive(Debug, Clone)]
66pub struct GpuMemoryBudget {
67 pub limits: GpuMemoryLimits,
68 pub allocated_bytes: u64,
70 pub storage_bindings_used: u32,
72}
73
74impl GpuMemoryBudget {
75 pub fn new(limits: GpuMemoryLimits) -> Self {
76 Self {
77 limits,
78 allocated_bytes: 0,
79 storage_bindings_used: 0,
80 }
81 }
82
83 pub fn webgpu_default() -> Self {
85 Self::new(GpuMemoryLimits::default())
86 }
87
88 pub fn check_buffer(&self, size_bytes: u64) -> Result<(), MemoryError> {
90 if size_bytes > self.limits.max_buffer_size {
91 return Err(MemoryError::BufferTooLarge {
92 requested: size_bytes,
93 max: self.limits.max_buffer_size,
94 });
95 }
96 if size_bytes > self.limits.max_storage_buffer_binding_size {
97 return Err(MemoryError::BindingTooLarge {
98 requested: size_bytes,
99 max: self.limits.max_storage_buffer_binding_size,
100 });
101 }
102 Ok(())
103 }
104
105 pub fn check_storage_binding(&self) -> Result<(), MemoryError> {
107 if self.storage_bindings_used >= self.limits.max_storage_buffers_per_stage {
108 return Err(MemoryError::TooManyBindings {
109 current: self.storage_bindings_used,
110 max: self.limits.max_storage_buffers_per_stage,
111 });
112 }
113 Ok(())
114 }
115
116 pub fn check_workgroup(&self, size: [u32; 3], count: [u32; 3]) -> Result<(), MemoryError> {
118 if size[0] > self.limits.max_workgroup_size_x
119 || size[1] > self.limits.max_workgroup_size_y
120 || size[2] > self.limits.max_workgroup_size_z
121 {
122 return Err(MemoryError::WorkgroupSizeExceeded {
123 requested: size,
124 max: [
125 self.limits.max_workgroup_size_x,
126 self.limits.max_workgroup_size_y,
127 self.limits.max_workgroup_size_z,
128 ],
129 });
130 }
131 let total = size[0] as u64 * size[1] as u64 * size[2] as u64;
132 if total > self.limits.max_invocations_per_workgroup as u64 {
133 return Err(MemoryError::TooManyInvocations {
134 requested: total as u32,
135 max: self.limits.max_invocations_per_workgroup,
136 });
137 }
138 for (i, &c) in count.iter().enumerate() {
139 if c > self.limits.max_workgroups_per_dimension {
140 return Err(MemoryError::TooManyWorkgroups {
141 dimension: i as u32,
142 requested: c,
143 max: self.limits.max_workgroups_per_dimension,
144 });
145 }
146 }
147 Ok(())
148 }
149
150 pub fn estimate_orbital_grid_memory(
154 &self,
155 n_basis: usize,
156 grid_points: usize,
157 max_primitives_per_basis: usize,
158 ) -> Result<OrbitalGridMemoryEstimate, MemoryError> {
159 let basis_bytes = (n_basis * 32) as u64;
166 let mo_bytes = (n_basis * 4) as u64;
167 let prim_bytes = (n_basis * max_primitives_per_basis * 8) as u64;
168 let params_bytes = 32u64;
169 let output_bytes = (grid_points * 4) as u64;
170
171 let total = basis_bytes + mo_bytes + prim_bytes + params_bytes + output_bytes;
172
173 self.check_buffer(basis_bytes)?;
175 self.check_buffer(mo_bytes)?;
176 self.check_buffer(prim_bytes)?;
177 self.check_buffer(output_bytes)?;
178
179 if 5 > self.limits.max_storage_buffers_per_stage {
181 return Err(MemoryError::TooManyBindings {
182 current: 5,
183 max: self.limits.max_storage_buffers_per_stage,
184 });
185 }
186
187 Ok(OrbitalGridMemoryEstimate {
188 basis_bytes,
189 mo_coefficients_bytes: mo_bytes,
190 primitives_bytes: prim_bytes,
191 params_bytes,
192 output_bytes,
193 total_bytes: total,
194 fits_in_webgpu: total <= self.limits.max_buffer_size,
195 })
196 }
197
198 pub fn estimate_d4_dispersion_memory(
200 &self,
201 n_atoms: usize,
202 ) -> Result<PairwiseMemoryEstimate, MemoryError> {
203 let pos_bytes = (n_atoms * 16) as u64;
210 let params_bytes = (n_atoms * 32) as u64;
211 let config_bytes = 32u64;
212 let pairwise_bytes = (n_atoms * n_atoms * 4) as u64;
213 let output_bytes = (n_atoms * 4) as u64;
214
215 let total = pos_bytes + params_bytes + config_bytes + pairwise_bytes + output_bytes;
216
217 self.check_buffer(pairwise_bytes)?;
218
219 Ok(PairwiseMemoryEstimate {
220 positions_bytes: pos_bytes,
221 params_bytes,
222 pairwise_bytes,
223 total_bytes: total,
224 fits_in_webgpu: total <= self.limits.max_buffer_size,
225 max_atoms_for_limit: ((self.limits.max_storage_buffer_binding_size / 4) as f64).sqrt()
226 as usize,
227 })
228 }
229
230 pub fn optimal_grid_dispatch(&self, dims: [u32; 3]) -> ([u32; 3], [u32; 3]) {
232 let wg_size = [
233 8u32.min(self.limits.max_workgroup_size_x),
234 8u32.min(self.limits.max_workgroup_size_y),
235 4u32.min(self.limits.max_workgroup_size_z),
236 ];
237 let wg_count = [
238 dims[0].div_ceil(wg_size[0]),
239 dims[1].div_ceil(wg_size[1]),
240 dims[2].div_ceil(wg_size[2]),
241 ];
242 (wg_size, wg_count)
243 }
244
245 pub fn optimal_1d_dispatch(&self, n: u32) -> (u32, u32) {
247 let wg_size = 64u32.min(self.limits.max_workgroup_size_x);
248 let wg_count = n.div_ceil(wg_size);
249 (wg_size, wg_count)
250 }
251}
252
/// Per-buffer byte sizes for the orbital-on-grid kernel, as computed by
/// `GpuMemoryBudget::estimate_orbital_grid_memory`.
#[derive(Clone, Debug)]
pub struct OrbitalGridMemoryEstimate {
    /// Bytes for the basis-function buffer.
    pub basis_bytes: u64,
    /// Bytes for the MO coefficient buffer.
    pub mo_coefficients_bytes: u64,
    /// Bytes for the primitive buffer.
    pub primitives_bytes: u64,
    /// Bytes for the fixed-size parameter block.
    pub params_bytes: u64,
    /// Bytes for the per-grid-point output buffer.
    pub output_bytes: u64,
    /// Sum of all of the above.
    pub total_bytes: u64,
    /// Whether `total_bytes` is within the budget's `max_buffer_size`.
    pub fits_in_webgpu: bool,
}
264
/// Byte sizes for a pairwise (all-atoms-by-all-atoms) kernel, as computed by
/// `GpuMemoryBudget::estimate_d4_dispersion_memory`.
#[derive(Clone, Debug)]
pub struct PairwiseMemoryEstimate {
    /// Bytes for the atom position buffer.
    pub positions_bytes: u64,
    /// Bytes for the per-atom parameter buffer.
    pub params_bytes: u64,
    /// Bytes for the `n_atoms * n_atoms` pairwise buffer.
    pub pairwise_bytes: u64,
    /// Sum of all buffers used by the kernel, including any not broken out here.
    pub total_bytes: u64,
    /// Whether `total_bytes` is within the budget's `max_buffer_size`.
    pub fits_in_webgpu: bool,
    /// Largest atom count whose 4-byte-per-pair buffer fits one storage binding.
    pub max_atoms_for_limit: usize,
}
276
/// Reasons a GPU memory or dispatch request violates a [`GpuMemoryBudget`]'s limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryError {
    /// A buffer is larger than `max_buffer_size`.
    BufferTooLarge {
        /// Requested size in bytes.
        requested: u64,
        /// Limit in bytes.
        max: u64,
    },
    /// A buffer is larger than `max_storage_buffer_binding_size`.
    BindingTooLarge {
        /// Requested size in bytes.
        requested: u64,
        /// Limit in bytes.
        max: u64,
    },
    /// More storage-buffer bindings are needed than `max_storage_buffers_per_stage` allows.
    TooManyBindings {
        /// Bindings in use or needed.
        current: u32,
        /// Per-stage limit.
        max: u32,
    },
    /// A workgroup size exceeds the per-axis limit on at least one axis.
    WorkgroupSizeExceeded {
        /// Requested `[x, y, z]` workgroup size.
        requested: [u32; 3],
        /// Per-axis limits `[x, y, z]`.
        max: [u32; 3],
    },
    /// The product of the workgroup size exceeds `max_invocations_per_workgroup`.
    TooManyInvocations {
        /// Requested invocations per workgroup.
        requested: u32,
        /// Limit on invocations per workgroup.
        max: u32,
    },
    /// A dispatch dimension requests more workgroups than `max_workgroups_per_dimension`.
    TooManyWorkgroups {
        /// Zero-based dimension index (0 = x, 1 = y, 2 = z).
        dimension: u32,
        /// Requested workgroup count in that dimension.
        requested: u32,
        /// Limit per dimension.
        max: u32,
    },
}
306
307impl std::fmt::Display for MemoryError {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 MemoryError::BufferTooLarge { requested, max } => {
311 write!(f, "Buffer {requested} bytes exceeds max {max} bytes")
312 }
313 MemoryError::BindingTooLarge { requested, max } => {
314 write!(f, "Binding {requested} bytes exceeds max {max} bytes")
315 }
316 MemoryError::TooManyBindings { current, max } => {
317 write!(f, "Need {current} bindings, max {max}")
318 }
319 MemoryError::WorkgroupSizeExceeded { requested, max } => {
320 write!(
321 f,
322 "Workgroup [{},{},{}] exceeds max [{},{},{}]",
323 requested[0], requested[1], requested[2], max[0], max[1], max[2]
324 )
325 }
326 MemoryError::TooManyInvocations { requested, max } => {
327 write!(f, "{requested} invocations exceeds max {max}")
328 }
329 MemoryError::TooManyWorkgroups {
330 dimension,
331 requested,
332 max,
333 } => {
334 write!(
335 f,
336 "Dimension {dimension}: {requested} workgroups exceeds max {max}"
337 )
338 }
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test below starts from the WebGPU default budget.
    fn budget() -> GpuMemoryBudget {
        GpuMemoryBudget::webgpu_default()
    }

    #[test]
    fn test_webgpu_defaults() {
        const KIB: u64 = 1 << 10;
        const MIB: u64 = 1 << 20;
        let GpuMemoryLimits {
            max_storage_buffer_binding_size,
            max_buffer_size,
            max_uniform_buffer_binding_size,
            ..
        } = GpuMemoryLimits::default();
        assert_eq!(max_storage_buffer_binding_size, 128 * MIB);
        assert_eq!(max_buffer_size, 256 * MIB);
        assert_eq!(max_uniform_buffer_binding_size, 64 * KIB);
    }

    #[test]
    fn test_buffer_check_within_limits() {
        // 1 MiB is well below both the buffer and the binding limit.
        assert!(budget().check_buffer(1 << 20).is_ok());
    }

    #[test]
    fn test_buffer_check_exceeds_limits() {
        // About 286 MiB, above the 256 MiB buffer limit.
        assert!(budget().check_buffer(300_000_000).is_err());
    }

    #[test]
    fn test_orbital_grid_small_molecule() {
        let estimate = budget()
            .estimate_orbital_grid_memory(7, 125_000, 3)
            .unwrap();
        assert!(estimate.fits_in_webgpu);
        assert!(estimate.total_bytes < 1_000_000);
    }

    #[test]
    fn test_workgroup_check_valid() {
        let (size, count) = ([8, 8, 4], [100, 100, 50]);
        assert!(budget().check_workgroup(size, count).is_ok());
    }

    #[test]
    fn test_workgroup_check_too_large() {
        // x = 512 exceeds the default max_workgroup_size_x of 256.
        assert!(budget().check_workgroup([512, 1, 1], [1, 1, 1]).is_err());
    }

    #[test]
    fn test_d4_memory_small_system() {
        let estimate = budget().estimate_d4_dispersion_memory(100).unwrap();
        assert!(estimate.fits_in_webgpu);
    }

    #[test]
    fn test_d4_max_atoms_calculable() {
        let estimate = budget().estimate_d4_dispersion_memory(10).unwrap();
        assert!(estimate.max_atoms_for_limit > 5000);
    }

    #[test]
    fn test_optimal_grid_dispatch() {
        assert_eq!(
            budget().optimal_grid_dispatch([64, 64, 64]),
            ([8, 8, 4], [8, 8, 16])
        );
    }

    #[test]
    fn test_optimal_1d_dispatch() {
        assert_eq!(budget().optimal_1d_dispatch(1000), (64, 16));
    }
}