1use super::aligned_types::{GpuMatrixParams, GpuScfParams};
7
/// Identifies what a [`ManagedBuffer`] is used for.
///
/// The role decides whether a buffer is treated as a uniform buffer
/// (see [`BufferRole::is_uniform`]); every other role is subject to the
/// storage-buffer size limit enforced by `BufferManager::allocate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BufferRole {
    /// Atom records; `plan_scf_buffers` sizes this as one `GpuAtom` per atom.
    Atoms,
    /// Basis-function records; planned as one `GpuBasisFunction` per basis function.
    BasisFunctions,
    /// Primitive records; planned as one `GpuGaussianPrimitive` per primitive.
    Primitives,
    /// Overlap matrix; planned as `n_basis * n_basis` `f32` values.
    OverlapMatrix,
    /// Kinetic matrix; planned as `n_basis * n_basis` `f32` values.
    KineticMatrix,
    /// Nuclear matrix; planned as `n_basis * n_basis` `f32` values.
    NuclearMatrix,
    /// Core Hamiltonian matrix; planned as `n_basis * n_basis` `f32` values.
    CoreHamiltonian,
    /// Density matrix; planned as `n_basis * n_basis` `f32` values.
    DensityMatrix,
    /// Fock matrix; planned as `n_basis * n_basis` `f32` values.
    FockMatrix,
    /// Two-electron integrals. Not allocated by `plan_scf_buffers`.
    TwoElectronIntegrals,
    /// SCF parameter block (`GpuScfParams`); a uniform role.
    ScfParams,
    /// Matrix parameter block (`GpuMatrixParams`); a uniform role.
    MatrixParams,
    /// Grid buffer, presumably for electrostatic-potential evaluation — TODO confirm.
    /// Not allocated by `plan_scf_buffers`.
    EspGrid,
    /// Grid buffer, presumably for orbital evaluation — TODO confirm.
    /// Not allocated by `plan_scf_buffers`.
    OrbitalGrid,
}
40
41impl BufferRole {
42 pub fn is_uniform(&self) -> bool {
44 matches!(self, BufferRole::ScfParams | BufferRole::MatrixParams)
45 }
46}
47
48#[derive(Debug, Clone)]
50pub struct ManagedBuffer {
51 pub role: BufferRole,
52 pub size_bytes: usize,
53 pub label: String,
54 pub is_mapped: bool,
55 pub cpu_data: Option<Vec<u8>>,
57}
58
59impl ManagedBuffer {
60 pub fn new(role: BufferRole, size_bytes: usize, label: &str) -> Self {
61 Self {
62 role,
63 size_bytes,
64 label: label.to_string(),
65 is_mapped: false,
66 cpu_data: None,
67 }
68 }
69
70 pub fn with_data(role: BufferRole, data: Vec<u8>, label: &str) -> Self {
71 let size = data.len();
72 Self {
73 role,
74 size_bytes: size,
75 label: label.to_string(),
76 is_mapped: false,
77 cpu_data: Some(data),
78 }
79 }
80}
81
82const MAX_STORAGE_BUFFER_SIZE: usize = 128 * 1024 * 1024; pub struct BufferManager {
87 pub buffers: Vec<ManagedBuffer>,
88 pub total_allocated: usize,
89 pub max_budget: usize,
90}
91
92impl BufferManager {
93 pub fn new(max_budget_mb: usize) -> Self {
94 Self {
95 buffers: Vec::new(),
96 total_allocated: 0,
97 max_budget: max_budget_mb * 1024 * 1024,
98 }
99 }
100
101 pub fn allocate(
103 &mut self,
104 role: BufferRole,
105 size_bytes: usize,
106 label: &str,
107 ) -> Result<usize, String> {
108 if !role.is_uniform() && size_bytes > MAX_STORAGE_BUFFER_SIZE {
109 return Err(format!(
110 "Buffer '{}' ({} MB) exceeds WebGPU maxStorageBufferBindingSize (128 MB)",
111 label,
112 size_bytes / (1024 * 1024)
113 ));
114 }
115
116 if self.total_allocated + size_bytes > self.max_budget {
117 return Err(format!(
118 "Buffer '{}' ({} bytes) would exceed GPU memory budget ({} MB)",
119 label,
120 size_bytes,
121 self.max_budget / (1024 * 1024)
122 ));
123 }
124
125 let buf = ManagedBuffer::new(role, size_bytes, label);
126 let idx = self.buffers.len();
127 self.total_allocated += size_bytes;
128 self.buffers.push(buf);
129 Ok(idx)
130 }
131
132 pub fn allocate_with_data(
134 &mut self,
135 role: BufferRole,
136 data: Vec<u8>,
137 label: &str,
138 ) -> Result<usize, String> {
139 let size = data.len();
140 if !role.is_uniform() && size > MAX_STORAGE_BUFFER_SIZE {
141 return Err(format!(
142 "Buffer '{}' ({} MB) exceeds WebGPU maxStorageBufferBindingSize",
143 label,
144 size / (1024 * 1024)
145 ));
146 }
147 if self.total_allocated + size > self.max_budget {
148 return Err(format!(
149 "Allocation would exceed budget: {} + {} > {}",
150 self.total_allocated, size, self.max_budget
151 ));
152 }
153
154 let buf = ManagedBuffer::with_data(role, data, label);
155 let idx = self.buffers.len();
156 self.total_allocated += size;
157 self.buffers.push(buf);
158 Ok(idx)
159 }
160
161 pub fn matrix_size(n_basis: usize) -> usize {
163 n_basis * n_basis * std::mem::size_of::<f32>()
164 }
165
166 pub fn plan_scf_buffers(
168 &mut self,
169 n_atoms: usize,
170 n_basis: usize,
171 n_primitives: usize,
172 ) -> Result<ScfBufferPlan, String> {
173 let atom_size = n_atoms * std::mem::size_of::<super::aligned_types::GpuAtom>();
174 let basis_size = n_basis * std::mem::size_of::<super::aligned_types::GpuBasisFunction>();
175 let prim_size =
176 n_primitives * std::mem::size_of::<super::aligned_types::GpuGaussianPrimitive>();
177 let mat_size = Self::matrix_size(n_basis);
178
179 let atoms = self.allocate(BufferRole::Atoms, atom_size, "atoms")?;
180 let basis = self.allocate(BufferRole::BasisFunctions, basis_size, "basis_functions")?;
181 let prims = self.allocate(BufferRole::Primitives, prim_size, "primitives")?;
182 let overlap = self.allocate(BufferRole::OverlapMatrix, mat_size, "overlap_matrix")?;
183 let kinetic = self.allocate(BufferRole::KineticMatrix, mat_size, "kinetic_matrix")?;
184 let nuclear = self.allocate(BufferRole::NuclearMatrix, mat_size, "nuclear_matrix")?;
185 let core_h = self.allocate(BufferRole::CoreHamiltonian, mat_size, "core_hamiltonian")?;
186 let density = self.allocate(BufferRole::DensityMatrix, mat_size, "density_matrix")?;
187 let fock = self.allocate(BufferRole::FockMatrix, mat_size, "fock_matrix")?;
188
189 let params_size = std::mem::size_of::<GpuMatrixParams>();
190 let scf_params_size = std::mem::size_of::<GpuScfParams>();
191 let mat_params = self.allocate(BufferRole::MatrixParams, params_size, "matrix_params")?;
192 let scf_params = self.allocate(BufferRole::ScfParams, scf_params_size, "scf_params")?;
193
194 Ok(ScfBufferPlan {
195 atoms,
196 basis,
197 prims,
198 overlap,
199 kinetic,
200 nuclear,
201 core_h,
202 density,
203 fock,
204 mat_params,
205 scf_params,
206 total_bytes: self.total_allocated,
207 })
208 }
209
210 pub fn release_all(&mut self) {
212 self.buffers.clear();
213 self.total_allocated = 0;
214 }
215
216 pub fn summary(&self) -> String {
218 let mut s = format!(
219 "BufferManager: {} buffers, {:.2} MB / {:.2} MB\n",
220 self.buffers.len(),
221 self.total_allocated as f64 / (1024.0 * 1024.0),
222 self.max_budget as f64 / (1024.0 * 1024.0),
223 );
224 for (i, buf) in self.buffers.iter().enumerate() {
225 s.push_str(&format!(
226 " [{}] {:?} '{}' — {} bytes\n",
227 i, buf.role, buf.label, buf.size_bytes
228 ));
229 }
230 s
231 }
232}
233
/// Result of `BufferManager::plan_scf_buffers`.
///
/// Every field except `total_bytes` is the index of the corresponding record
/// in `BufferManager::buffers`.
#[derive(Debug, Clone)]
pub struct ScfBufferPlan {
    /// Index of the atoms buffer.
    pub atoms: usize,
    /// Index of the basis-functions buffer.
    pub basis: usize,
    /// Index of the primitives buffer.
    pub prims: usize,
    /// Index of the overlap-matrix buffer.
    pub overlap: usize,
    /// Index of the kinetic-matrix buffer.
    pub kinetic: usize,
    /// Index of the nuclear-matrix buffer.
    pub nuclear: usize,
    /// Index of the core-Hamiltonian buffer.
    pub core_h: usize,
    /// Index of the density-matrix buffer.
    pub density: usize,
    /// Index of the Fock-matrix buffer.
    pub fock: usize,
    /// Index of the matrix-parameters buffer.
    pub mat_params: usize,
    /// Index of the SCF-parameters buffer.
    pub scf_params: usize,
    /// The manager's `total_allocated` value, in bytes, when the plan was built.
    pub total_bytes: usize,
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, used to spell out buffer sizes in the tests below.
    const MIB: usize = 1024 * 1024;

    /// A single allocation gets index 0 and is counted in the running total.
    #[test]
    fn test_buffer_allocation() {
        let mut manager = BufferManager::new(256);
        let slot = manager.allocate(BufferRole::OverlapMatrix, 1024, "test_overlap");
        assert_eq!(slot.unwrap(), 0);
        assert_eq!(manager.total_allocated, 1024);
    }

    /// A 2 MiB request must be rejected by a manager with a 1 MiB budget.
    #[test]
    fn test_budget_overflow() {
        let mut manager = BufferManager::new(1);
        let outcome = manager.allocate(BufferRole::TwoElectronIntegrals, 2 * MIB, "too_large");
        assert!(outcome.is_err());
    }

    /// Planning an SCF run creates eleven buffers and reports a non-zero total.
    #[test]
    fn test_scf_plan() {
        let mut manager = BufferManager::new(256);
        let plan = manager.plan_scf_buffers(10, 20, 60).unwrap();
        assert!(plan.total_bytes > 0);
        assert_eq!(manager.buffers.len(), 11);
    }

    /// `release_all` empties the buffer list and zeroes the running total.
    #[test]
    fn test_release_all() {
        let mut manager = BufferManager::new(256);
        manager.allocate(BufferRole::Atoms, 512, "atoms").unwrap();
        manager
            .allocate(BufferRole::OverlapMatrix, 1024, "overlap")
            .unwrap();
        assert_eq!(manager.buffers.len(), 2);

        manager.release_all();
        assert_eq!(manager.buffers.len(), 0);
        assert_eq!(manager.total_allocated, 0);
    }

    /// A 200 MiB storage buffer is refused even though the budget would allow it.
    #[test]
    fn test_webgpu_storage_limit() {
        let mut manager = BufferManager::new(512);
        let outcome = manager.allocate(BufferRole::TwoElectronIntegrals, 200 * MIB, "huge_eri");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("maxStorageBufferBindingSize"));
    }
}