Skip to main content

sci_form/gpu/
buffer_manager.rs

1//! GPU buffer lifecycle management for quantum chemistry workloads.
2//!
//! Manages allocation, tracking, and release of GPU storage/uniform buffers,
//! enforcing the WebGPU per-binding storage limit (128 MB) and a caller-supplied
//! total memory budget (for reference, the WebGPU default `maxBufferSize` is 256 MB).
5
6use super::aligned_types::{GpuMatrixParams, GpuScfParams};
7
/// Classifies a GPU buffer by what it holds, so allocation slots can be tracked and reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BufferRole {
    /// Read-only atom positions and per-atom parameters.
    Atoms,
    /// Read-only basis-function descriptors.
    BasisFunctions,
    /// Read-only table of Gaussian primitive contractions.
    Primitives,
    /// Read-write overlap matrix S.
    OverlapMatrix,
    /// Read-write kinetic-energy matrix T.
    KineticMatrix,
    /// Read-write nuclear-attraction matrix V.
    NuclearMatrix,
    /// Read-write core Hamiltonian H = T + V.
    CoreHamiltonian,
    /// Read-write density matrix P.
    DensityMatrix,
    /// Read-write Fock matrix F.
    FockMatrix,
    /// Read-write (and potentially large) two-electron integral storage.
    TwoElectronIntegrals,
    /// Uniform block holding SCF parameters.
    ScfParams,
    /// Uniform block holding matrix-dimension parameters.
    MatrixParams,
    /// Read-write electrostatic-potential grid output.
    EspGrid,
    /// Read-write orbital grid output.
    OrbitalGrid,
}

impl BufferRole {
    /// Returns `true` when buffers of this role are bound as uniform buffers,
    /// `false` when they are bound as storage buffers.
    pub fn is_uniform(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match *self {
            Self::ScfParams | Self::MatrixParams => true,
            Self::Atoms
            | Self::BasisFunctions
            | Self::Primitives
            | Self::OverlapMatrix
            | Self::KineticMatrix
            | Self::NuclearMatrix
            | Self::CoreHamiltonian
            | Self::DensityMatrix
            | Self::FockMatrix
            | Self::TwoElectronIntegrals
            | Self::EspGrid
            | Self::OrbitalGrid => false,
        }
    }
}
47
48/// CPU-side representation of a managed GPU buffer.
49#[derive(Debug, Clone)]
50pub struct ManagedBuffer {
51    pub role: BufferRole,
52    pub size_bytes: usize,
53    pub label: String,
54    pub is_mapped: bool,
55    /// CPU shadow copy (for readback or initial upload).
56    pub cpu_data: Option<Vec<u8>>,
57}
58
59impl ManagedBuffer {
60    pub fn new(role: BufferRole, size_bytes: usize, label: &str) -> Self {
61        Self {
62            role,
63            size_bytes,
64            label: label.to_string(),
65            is_mapped: false,
66            cpu_data: None,
67        }
68    }
69
70    pub fn with_data(role: BufferRole, data: Vec<u8>, label: &str) -> Self {
71        let size = data.len();
72        Self {
73            role,
74            size_bytes: size,
75            label: label.to_string(),
76            is_mapped: false,
77            cpu_data: Some(data),
78        }
79    }
80}
81
/// Largest storage buffer we allow: the WebGPU default `maxStorageBufferBindingSize`,
/// 128 MiB (134,217,728 bytes).
const MAX_STORAGE_BUFFER_SIZE: usize = 128 << 20;
84
85/// Manages the lifecycle of all GPU buffers for a quantum chemistry calculation.
86pub struct BufferManager {
87    pub buffers: Vec<ManagedBuffer>,
88    pub total_allocated: usize,
89    pub max_budget: usize,
90}
91
92impl BufferManager {
93    pub fn new(max_budget_mb: usize) -> Self {
94        Self {
95            buffers: Vec::new(),
96            total_allocated: 0,
97            max_budget: max_budget_mb * 1024 * 1024,
98        }
99    }
100
101    /// Allocate a new managed buffer. Returns error if it exceeds WebGPU limits.
102    pub fn allocate(
103        &mut self,
104        role: BufferRole,
105        size_bytes: usize,
106        label: &str,
107    ) -> Result<usize, String> {
108        if !role.is_uniform() && size_bytes > MAX_STORAGE_BUFFER_SIZE {
109            return Err(format!(
110                "Buffer '{}' ({} MB) exceeds WebGPU maxStorageBufferBindingSize (128 MB)",
111                label,
112                size_bytes / (1024 * 1024)
113            ));
114        }
115
116        if self.total_allocated + size_bytes > self.max_budget {
117            return Err(format!(
118                "Buffer '{}' ({} bytes) would exceed GPU memory budget ({} MB)",
119                label,
120                size_bytes,
121                self.max_budget / (1024 * 1024)
122            ));
123        }
124
125        let buf = ManagedBuffer::new(role, size_bytes, label);
126        let idx = self.buffers.len();
127        self.total_allocated += size_bytes;
128        self.buffers.push(buf);
129        Ok(idx)
130    }
131
132    /// Allocate a buffer with initial CPU data.
133    pub fn allocate_with_data(
134        &mut self,
135        role: BufferRole,
136        data: Vec<u8>,
137        label: &str,
138    ) -> Result<usize, String> {
139        let size = data.len();
140        if !role.is_uniform() && size > MAX_STORAGE_BUFFER_SIZE {
141            return Err(format!(
142                "Buffer '{}' ({} MB) exceeds WebGPU maxStorageBufferBindingSize",
143                label,
144                size / (1024 * 1024)
145            ));
146        }
147        if self.total_allocated + size > self.max_budget {
148            return Err(format!(
149                "Allocation would exceed budget: {} + {} > {}",
150                self.total_allocated, size, self.max_budget
151            ));
152        }
153
154        let buf = ManagedBuffer::with_data(role, data, label);
155        let idx = self.buffers.len();
156        self.total_allocated += size;
157        self.buffers.push(buf);
158        Ok(idx)
159    }
160
161    /// Compute buffer size for a dense n×n matrix of f32.
162    pub fn matrix_size(n_basis: usize) -> usize {
163        n_basis * n_basis * std::mem::size_of::<f32>()
164    }
165
166    /// Plan all buffers needed for an SCF calculation.
167    pub fn plan_scf_buffers(
168        &mut self,
169        n_atoms: usize,
170        n_basis: usize,
171        n_primitives: usize,
172    ) -> Result<ScfBufferPlan, String> {
173        let atom_size = n_atoms * std::mem::size_of::<super::aligned_types::GpuAtom>();
174        let basis_size = n_basis * std::mem::size_of::<super::aligned_types::GpuBasisFunction>();
175        let prim_size =
176            n_primitives * std::mem::size_of::<super::aligned_types::GpuGaussianPrimitive>();
177        let mat_size = Self::matrix_size(n_basis);
178
179        let atoms = self.allocate(BufferRole::Atoms, atom_size, "atoms")?;
180        let basis = self.allocate(BufferRole::BasisFunctions, basis_size, "basis_functions")?;
181        let prims = self.allocate(BufferRole::Primitives, prim_size, "primitives")?;
182        let overlap = self.allocate(BufferRole::OverlapMatrix, mat_size, "overlap_matrix")?;
183        let kinetic = self.allocate(BufferRole::KineticMatrix, mat_size, "kinetic_matrix")?;
184        let nuclear = self.allocate(BufferRole::NuclearMatrix, mat_size, "nuclear_matrix")?;
185        let core_h = self.allocate(BufferRole::CoreHamiltonian, mat_size, "core_hamiltonian")?;
186        let density = self.allocate(BufferRole::DensityMatrix, mat_size, "density_matrix")?;
187        let fock = self.allocate(BufferRole::FockMatrix, mat_size, "fock_matrix")?;
188
189        let params_size = std::mem::size_of::<GpuMatrixParams>();
190        let scf_params_size = std::mem::size_of::<GpuScfParams>();
191        let mat_params = self.allocate(BufferRole::MatrixParams, params_size, "matrix_params")?;
192        let scf_params = self.allocate(BufferRole::ScfParams, scf_params_size, "scf_params")?;
193
194        Ok(ScfBufferPlan {
195            atoms,
196            basis,
197            prims,
198            overlap,
199            kinetic,
200            nuclear,
201            core_h,
202            density,
203            fock,
204            mat_params,
205            scf_params,
206            total_bytes: self.total_allocated,
207        })
208    }
209
210    /// Release all buffers and reset allocation tracking.
211    pub fn release_all(&mut self) {
212        self.buffers.clear();
213        self.total_allocated = 0;
214    }
215
216    /// Summary of all allocated buffers for debugging.
217    pub fn summary(&self) -> String {
218        let mut s = format!(
219            "BufferManager: {} buffers, {:.2} MB / {:.2} MB\n",
220            self.buffers.len(),
221            self.total_allocated as f64 / (1024.0 * 1024.0),
222            self.max_budget as f64 / (1024.0 * 1024.0),
223        );
224        for (i, buf) in self.buffers.iter().enumerate() {
225            s.push_str(&format!(
226                "  [{}] {:?} '{}' — {} bytes\n",
227                i, buf.role, buf.label, buf.size_bytes
228            ));
229        }
230        s
231    }
232}
233
/// Buffer indices (positions in `BufferManager::buffers`) produced when planning
/// the buffers for one SCF calculation.
#[derive(Clone, Debug)]
pub struct ScfBufferPlan {
    /// Atom data buffer.
    pub atoms: usize,
    /// Basis-function descriptor buffer.
    pub basis: usize,
    /// Gaussian primitive table buffer.
    pub prims: usize,
    /// Overlap matrix S.
    pub overlap: usize,
    /// Kinetic-energy matrix T.
    pub kinetic: usize,
    /// Nuclear-attraction matrix V.
    pub nuclear: usize,
    /// Core Hamiltonian H.
    pub core_h: usize,
    /// Density matrix P.
    pub density: usize,
    /// Fock matrix F.
    pub fock: usize,
    /// Matrix-dimension uniform block.
    pub mat_params: usize,
    /// SCF-parameter uniform block.
    pub scf_params: usize,
    /// Manager's total allocated bytes at the moment the plan was produced.
    pub total_bytes: usize,
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, for readable sizes below.
    const MIB: usize = 1024 * 1024;

    #[test]
    fn test_buffer_allocation() {
        let mut manager = BufferManager::new(256);
        let slot = manager
            .allocate(BufferRole::OverlapMatrix, 1024, "test_overlap")
            .expect("a 1 KiB overlap buffer fits a 256 MB budget");
        assert_eq!(slot, 0);
        assert_eq!(manager.total_allocated, 1024);
    }

    #[test]
    fn test_budget_overflow() {
        // A 2 MB request against a 1 MB budget must be refused.
        let mut manager = BufferManager::new(1);
        let outcome = manager.allocate(BufferRole::TwoElectronIntegrals, 2 * MIB, "too_large");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_scf_plan() {
        let mut manager = BufferManager::new(256);
        let plan = manager
            .plan_scf_buffers(10, 20, 60)
            .expect("a small SCF plan fits a 256 MB budget");
        assert!(plan.total_bytes > 0);
        // 3 input buffers + 6 n×n matrices + 2 uniform parameter blocks.
        assert_eq!(manager.buffers.len(), 3 + 6 + 2);
    }

    #[test]
    fn test_release_all() {
        let mut manager = BufferManager::new(256);
        for &(role, size, label) in [
            (BufferRole::Atoms, 512, "atoms"),
            (BufferRole::OverlapMatrix, 1024, "overlap"),
        ]
        .iter()
        {
            manager.allocate(role, size, label).unwrap();
        }
        assert_eq!(manager.buffers.len(), 2);

        manager.release_all();
        assert!(manager.buffers.is_empty());
        assert_eq!(manager.total_allocated, 0);
    }

    #[test]
    fn test_webgpu_storage_limit() {
        // 200 MB fits the 512 MB budget but exceeds the 128 MB per-binding limit.
        let mut manager = BufferManager::new(512);
        let message = manager
            .allocate(BufferRole::TwoElectronIntegrals, 200 * MIB, "huge_eri")
            .expect_err("a 200 MB storage buffer must be rejected");
        assert!(message.contains("maxStorageBufferBindingSize"));
    }
}