Skip to main content

sci_form/gpu/
orbital_grid.rs

1//! Molecular orbital evaluation on 3D grids (GPU + CPU).
2//!
3//! Computes ψ_i(r) = Σ_μ C_{μi} φ_μ(r) on a regular 3D grid.
4//! Each grid point can be evaluated independently — ideal for GPU.
5//!
6//! GPU path: dispatches ORBITAL_GRID_SHADER via wgpu.
7//! CPU path: direct triple-nested loop (always available as fallback).
8
9use super::backend_report::OrbitalGridReport;
10use super::context::{
11    bytes_to_f32_vec, f32_slice_to_bytes, ComputeBindingDescriptor, ComputeBindingKind,
12    ComputeDispatchDescriptor, GpuContext,
13};
14use crate::scf::basis::{BasisFunction, BasisSet};
15use nalgebra::DMatrix;
16
/// 3D grid parameters.
///
/// Describes an axis-aligned regular grid with the same spacing on every axis.
/// Point `(ix, iy, iz)` sits at `origin + spacing * (ix, iy, iz)`; flat storage
/// is row-major with `x` varying slowest and `z` fastest
/// (see [`GridParams::flat_index`]).
#[derive(Debug, Clone, PartialEq)]
pub struct GridParams {
    /// Grid origin (x, y, z) in Bohr.
    pub origin: [f64; 3],
    /// Grid spacing in Bohr.
    pub spacing: f64,
    /// Number of grid points [nx, ny, nz].
    pub dimensions: [usize; 3],
}

impl GridParams {
    /// Create grid params enclosing the molecule with padding.
    ///
    /// The grid spans the axis-aligned bounding box of `positions`, extended by
    /// `padding` on every side and sampled every `spacing` Bohr. Each dimension
    /// is rounded up so the padded box is fully covered, plus one extra point
    /// for the closing edge.
    ///
    /// An empty `positions` slice is treated as a single point at the origin.
    /// The result is then a `2 * padding` cube centred on `(0, 0, 0)`, instead
    /// of a degenerate grid built from `f64::MAX`/`f64::MIN` sentinels.
    ///
    /// # Panics
    ///
    /// Panics if `spacing` is not a finite, strictly positive number.
    pub fn from_molecule(positions: &[[f64; 3]], spacing: f64, padding: f64) -> Self {
        assert!(
            spacing.is_finite() && spacing > 0.0,
            "grid spacing must be finite and positive, got {spacing}"
        );

        // Start from the sentinels only when there is something to fold;
        // otherwise fall back to a zero-size box at the origin.
        let (mut min, mut max) = if positions.is_empty() {
            ([0.0; 3], [0.0; 3])
        } else {
            ([f64::MAX; 3], [f64::MIN; 3])
        };

        for pos in positions {
            for k in 0..3 {
                min[k] = min[k].min(pos[k]);
                max[k] = max[k].max(pos[k]);
            }
        }

        let origin: [f64; 3] = std::array::from_fn(|k| min[k] - padding);
        let dimensions: [usize; 3] = std::array::from_fn(|k| {
            ((max[k] - min[k] + 2.0 * padding) / spacing).ceil() as usize + 1
        });

        Self {
            origin,
            spacing,
            dimensions,
        }
    }

    /// Total number of grid points (`nx * ny * nz`).
    pub fn n_points(&self) -> usize {
        self.dimensions[0] * self.dimensions[1] * self.dimensions[2]
    }

    /// 3D coordinate (Bohr) of grid point (ix, iy, iz).
    pub fn point(&self, ix: usize, iy: usize, iz: usize) -> [f64; 3] {
        [
            self.origin[0] + ix as f64 * self.spacing,
            self.origin[1] + iy as f64 * self.spacing,
            self.origin[2] + iz as f64 * self.spacing,
        ]
    }

    /// Flat (row-major, `z` fastest) index from 3D indices.
    pub fn flat_index(&self, ix: usize, iy: usize, iz: usize) -> usize {
        ix * self.dimensions[1] * self.dimensions[2] + iy * self.dimensions[2] + iz
    }
}
74
/// Result of orbital grid evaluation.
///
/// Produced by the CPU and GPU evaluators in this module; `values[i]` is the
/// orbital amplitude at the grid point whose flat index is `i`.
#[derive(Debug, Clone)]
pub struct OrbitalGrid {
    /// Grid values (flat, row-major: x varies slowest).
    ///
    /// Indexed by [`GridParams::flat_index`]; the evaluators in this module
    /// fill exactly `params.n_points()` entries.
    pub values: Vec<f64>,
    /// Grid on which `values` were sampled.
    pub params: GridParams,
    /// Column of the MO coefficient matrix that was evaluated.
    pub orbital_index: usize,
}
83
84/// Evaluate a molecular orbital on a 3D grid with explicit backend reporting.
85///
86/// Attempts GPU dispatch when available; falls back to CPU otherwise.
87pub fn evaluate_orbital_with_report(
88    basis: &BasisSet,
89    mo_coefficients: &DMatrix<f64>,
90    orbital_index: usize,
91    params: &GridParams,
92) -> (OrbitalGrid, OrbitalGridReport) {
93    let ctx = GpuContext::best_available();
94
95    if ctx.is_gpu_available() {
96        match evaluate_orbital_gpu(&ctx, basis, mo_coefficients, orbital_index, params) {
97            Ok(grid) => {
98                let report = OrbitalGridReport {
99                    backend: ctx.capabilities.backend.clone(),
100                    used_gpu: true,
101                    attempted_gpu: true,
102                    n_points: params.n_points(),
103                    note: format!("GPU dispatch on {}", ctx.capabilities.backend),
104                };
105                return (grid, report);
106            }
107            Err(_err) => {
108                // Fall through to CPU
109            }
110        }
111    }
112
113    let grid = evaluate_orbital_cpu(basis, mo_coefficients, orbital_index, params);
114    let report = OrbitalGridReport {
115        backend: "CPU".to_string(),
116        used_gpu: false,
117        attempted_gpu: ctx.is_gpu_available(),
118        n_points: params.n_points(),
119        note: if ctx.is_gpu_available() {
120            "GPU available but dispatch failed; CPU fallback used".to_string()
121        } else {
122            "CPU evaluation (GPU not available)".to_string()
123        },
124    };
125    (grid, report)
126}
127
128/// Evaluate orbital on CPU (always available).
129pub fn evaluate_orbital_cpu(
130    basis: &BasisSet,
131    mo_coefficients: &DMatrix<f64>,
132    orbital_index: usize,
133    params: &GridParams,
134) -> OrbitalGrid {
135    let n_points = params.n_points();
136    let mut values = vec![0.0; n_points];
137    let n_basis = basis.n_basis;
138    let [nx, ny, nz] = params.dimensions;
139
140    for ix in 0..nx {
141        for iy in 0..ny {
142            for iz in 0..nz {
143                let r = params.point(ix, iy, iz);
144                let idx = params.flat_index(ix, iy, iz);
145
146                let mut psi = 0.0;
147                for mu in 0..n_basis {
148                    let c_mu = mo_coefficients[(mu, orbital_index)];
149                    if c_mu.abs() < 1e-15 {
150                        continue;
151                    }
152                    let phi_mu = evaluate_basis_function(&basis.functions[mu], &r);
153                    psi += c_mu * phi_mu;
154                }
155                values[idx] = psi;
156            }
157        }
158    }
159
160    OrbitalGrid {
161        values,
162        params: params.clone(),
163        orbital_index,
164    }
165}
166
167/// Evaluate electron density ρ(r) = Σ_{μν} P_{μν} φ_μ(r) φ_ν(r) on a 3D grid.
168pub fn evaluate_density_cpu(
169    basis: &BasisSet,
170    density: &DMatrix<f64>,
171    params: &GridParams,
172) -> Vec<f64> {
173    let n_points = params.n_points();
174    let mut values = vec![0.0; n_points];
175    let n_basis = basis.n_basis;
176    let [nx, ny, nz] = params.dimensions;
177
178    for ix in 0..nx {
179        for iy in 0..ny {
180            for iz in 0..nz {
181                let r = params.point(ix, iy, iz);
182                let idx = params.flat_index(ix, iy, iz);
183
184                let phi: Vec<f64> = (0..n_basis)
185                    .map(|mu| evaluate_basis_function(&basis.functions[mu], &r))
186                    .collect();
187
188                let mut rho = 0.0;
189                for mu in 0..n_basis {
190                    if phi[mu].abs() < 1e-15 {
191                        continue;
192                    }
193                    for nu in 0..n_basis {
194                        rho += density[(mu, nu)] * phi[mu] * phi[nu];
195                    }
196                }
197                values[idx] = rho;
198            }
199        }
200    }
201    values
202}
203
204/// Evaluate a single contracted Gaussian basis function at point r.
205fn evaluate_basis_function(bf: &BasisFunction, r: &[f64; 3]) -> f64 {
206    let dx = r[0] - bf.center[0];
207    let dy = r[1] - bf.center[1];
208    let dz = r[2] - bf.center[2];
209    let r2 = dx * dx + dy * dy + dz * dz;
210
211    let angular = dx.powi(bf.angular[0] as i32)
212        * dy.powi(bf.angular[1] as i32)
213        * dz.powi(bf.angular[2] as i32);
214
215    let mut radial = 0.0;
216    for prim in &bf.primitives {
217        radial += prim.coefficient * (-prim.alpha * r2).exp();
218    }
219
220    BasisFunction::normalization(
221        bf.primitives.first().map(|p| p.alpha).unwrap_or(1.0),
222        bf.angular[0],
223        bf.angular[1],
224        bf.angular[2],
225    ) * angular
226        * radial
227}
228
229// ─── GPU dispatch ────────────────────────────────────────────────────────────
230
231/// Pack basis function data for the GPU shader.
232///
233/// Each basis function → GpuBasisFunc (32 bytes):
234///   center: vec3<f32>, lx: u32, ly: u32, lz: u32, n_primitives: u32, coefficient: f32
235///
236/// Primitives → (alpha: f32, coeff: f32) pairs, max 3 per basis function (STO-3G).
237fn pack_basis_for_gpu(basis: &BasisSet) -> (Vec<u8>, Vec<u8>) {
238    let mut basis_bytes = Vec::new();
239    let mut prim_bytes = Vec::new();
240
241    for bf in &basis.functions {
242        // center xyz
243        basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
244        basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
245        basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
246        // lx, ly, lz
247        basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
248        basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
249        basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
250        // n_primitives
251        basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
252        // normalization coefficient
253        let norm = BasisFunction::normalization(
254            bf.primitives.first().map(|p| p.alpha).unwrap_or(1.0),
255            bf.angular[0],
256            bf.angular[1],
257            bf.angular[2],
258        );
259        basis_bytes.extend_from_slice(&(norm as f32).to_ne_bytes());
260
261        // Pack primitives (max 3 for STO-3G)
262        for i in 0..3 {
263            if i < bf.primitives.len() {
264                prim_bytes.extend_from_slice(&(bf.primitives[i].alpha as f32).to_ne_bytes());
265                prim_bytes.extend_from_slice(&(bf.primitives[i].coefficient as f32).to_ne_bytes());
266            } else {
267                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
268                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
269            }
270        }
271    }
272
273    (basis_bytes, prim_bytes)
274}
275
276/// GPU-accelerated orbital grid evaluation.
277fn evaluate_orbital_gpu(
278    ctx: &GpuContext,
279    basis: &BasisSet,
280    mo_coefficients: &DMatrix<f64>,
281    orbital_index: usize,
282    params: &GridParams,
283) -> Result<OrbitalGrid, String> {
284    let n_basis = basis.n_basis;
285    let n_points = params.n_points();
286
287    // Pack basis functions and primitives
288    let (basis_bytes, prim_bytes) = pack_basis_for_gpu(basis);
289
290    // Pack MO coefficients for this orbital
291    let mo_coeffs: Vec<f32> = (0..n_basis)
292        .map(|mu| mo_coefficients[(mu, orbital_index)] as f32)
293        .collect();
294
295    // Pack grid params: origin (3×f32) + spacing (f32) + dims (3×u32) + orbital_index (u32) = 32 bytes
296    let mut params_bytes = Vec::with_capacity(32);
297    for v in &params.origin {
298        params_bytes.extend_from_slice(&(*v as f32).to_ne_bytes());
299    }
300    params_bytes.extend_from_slice(&(params.spacing as f32).to_ne_bytes());
301    for d in &params.dimensions {
302        params_bytes.extend_from_slice(&(*d as u32).to_ne_bytes());
303    }
304    params_bytes.extend_from_slice(&(orbital_index as u32).to_ne_bytes());
305
306    // Output buffer
307    let output_seed = vec![0.0f32; n_points];
308
309    let [nx, ny, nz] = params.dimensions;
310    let wg = [
311        (nx as u32).div_ceil(8),
312        (ny as u32).div_ceil(8),
313        (nz as u32).div_ceil(4),
314    ];
315
316    let descriptor = ComputeDispatchDescriptor {
317        label: "orbital grid".to_string(),
318        shader_source: ORBITAL_GRID_SHADER.to_string(),
319        entry_point: "main".to_string(),
320        workgroup_count: wg,
321        bindings: vec![
322            ComputeBindingDescriptor {
323                label: "basis".to_string(),
324                kind: ComputeBindingKind::StorageReadOnly,
325                bytes: basis_bytes,
326            },
327            ComputeBindingDescriptor {
328                label: "mo_coeffs".to_string(),
329                kind: ComputeBindingKind::StorageReadOnly,
330                bytes: f32_slice_to_bytes(&mo_coeffs),
331            },
332            ComputeBindingDescriptor {
333                label: "primitives".to_string(),
334                kind: ComputeBindingKind::StorageReadOnly,
335                bytes: prim_bytes,
336            },
337            ComputeBindingDescriptor {
338                label: "params".to_string(),
339                kind: ComputeBindingKind::Uniform,
340                bytes: params_bytes,
341            },
342            ComputeBindingDescriptor {
343                label: "output".to_string(),
344                kind: ComputeBindingKind::StorageReadWrite,
345                bytes: f32_slice_to_bytes(&output_seed),
346            },
347        ],
348    };
349
350    let mut result = ctx.run_compute(&descriptor)?.outputs;
351    let bytes = result.pop().ok_or("No output from orbital grid kernel")?;
352    let f32_values = bytes_to_f32_vec(&bytes);
353
354    if f32_values.len() != n_points {
355        return Err(format!(
356            "Output size mismatch: expected {}, got {}",
357            n_points,
358            f32_values.len()
359        ));
360    }
361
362    let values: Vec<f64> = f32_values.iter().map(|v| *v as f64).collect();
363
364    Ok(OrbitalGrid {
365        values,
366        params: params.clone(),
367        orbital_index,
368    })
369}
370
/// WGSL compute shader for orbital grid evaluation.
///
/// Evaluates ψ_i(r) = Σ_μ C_{μi} φ_μ(r) at each grid point.
/// Workgroup size: (8, 8, 4) = 256 threads.
///
/// Bindings (group 0):
/// - `0`: basis functions, one 32-byte `BasisFunc` record each.
/// - `1`: MO coefficients of the selected orbital, one `f32` per basis
///   function. Its `arrayLength` is used as the basis-function count.
/// - `2`: primitives as `(alpha, coeff)` pairs in 3 fixed slots per basis
///   function. At most 3 are read even if `n_primitives` is larger.
/// - `3`: grid parameters (uniform, 32 bytes).
/// - `4`: output, one `f32` per grid point, flattened like
///   [`GridParams::flat_index`].
///
/// Coefficients with `|c| < 1e-7` are skipped. This threshold is coarser than
/// the CPU path's `1e-15`.
pub const ORBITAL_GRID_SHADER: &str = r#"
struct BasisFunc {
    center_x: f32, center_y: f32, center_z: f32,
    lx: u32, ly: u32, lz: u32,
    n_primitives: u32,
    norm_coeff: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    orbital_index: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> mo_coeffs: array<f32>;
@group(0) @binding(2) var<storage, read> primitives: array<vec2<f32>>;
@group(0) @binding(3) var<uniform> params: GridParams;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;

    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;
    let n_basis = arrayLength(&mo_coeffs);

    var psi: f32 = 0.0;

    for (var mu: u32 = 0u; mu < n_basis; mu = mu + 1u) {
        let c_mu = mo_coeffs[mu];
        if (abs(c_mu) < 1e-7) {
            continue;
        }

        let bf = basis[mu];
        let dx = rx - bf.center_x;
        let dy = ry - bf.center_y;
        let dz = rz - bf.center_z;
        let r2 = dx * dx + dy * dy + dz * dz;

        // Angular part
        var angular: f32 = 1.0;
        for (var i: u32 = 0u; i < bf.lx; i = i + 1u) { angular *= dx; }
        for (var i: u32 = 0u; i < bf.ly; i = i + 1u) { angular *= dy; }
        for (var i: u32 = 0u; i < bf.lz; i = i + 1u) { angular *= dz; }

        // Radial part (contracted, max 3 primitives for STO-3G).
        // Clamp to the 3 packed slots so a larger n_primitives can never
        // read into the next basis function's primitives.
        var radial: f32 = 0.0;
        let n_prim = min(bf.n_primitives, 3u);
        for (var p: u32 = 0u; p < n_prim; p = p + 1u) {
            let prim = primitives[mu * 3u + p];
            radial += prim.y * exp(-prim.x * r2);
        }

        psi += c_mu * bf.norm_coeff * angular * radial;
    }

    output[flat_idx] = psi;
}
"#;
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grid_params_from_molecule() {
        let positions = vec![[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let params = GridParams::from_molecule(&positions, 0.5, 3.0);
        assert!(params.dimensions[0] > 0);
        assert!(params.n_points() > 0);
        assert!(params.origin[0] < -2.0);
        // Bounding box [0,2]×[0,0]×[0,0] padded by 3 Bohr on each side gives
        // extents of 8, 6, 6 Bohr → 16/12/12 intervals at 0.5 Bohr spacing.
        assert_eq!(params.origin, [-3.0, -3.0, -3.0]);
        assert_eq!(params.dimensions, [17, 13, 13]);
    }

    #[test]
    fn test_grid_point_coordinates() {
        let params = GridParams {
            origin: [0.0, 0.0, 0.0],
            spacing: 1.0,
            dimensions: [3, 3, 3],
        };
        let p = params.point(1, 2, 0);
        assert!((p[0] - 1.0).abs() < 1e-12);
        assert!((p[1] - 2.0).abs() < 1e-12);
        assert!(p[2].abs() < 1e-12);
    }

    #[test]
    fn test_flat_index() {
        let params = GridParams {
            origin: [0.0, 0.0, 0.0],
            spacing: 1.0,
            dimensions: [3, 4, 5],
        };
        assert_eq!(params.flat_index(0, 0, 0), 0);
        assert_eq!(params.flat_index(0, 0, 1), 1);
        assert_eq!(params.flat_index(0, 1, 0), 5);
        assert_eq!(params.flat_index(1, 0, 0), 20);
        assert_eq!(params.flat_index(2, 3, 4), params.n_points() - 1);
    }

    #[test]
    fn test_evaluate_orbital_cpu_h2() {
        // Build H2 basis
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]]; // ~0.74 Å in Bohr
        let basis = BasisSet::sto3g(&elements, &positions);

        // Simple MO coefficients (bonding orbital: equal contribution)
        let n = basis.n_basis;
        let mut coeffs = DMatrix::zeros(n, n);
        let c = 1.0 / (2.0f64).sqrt();
        coeffs[(0, 0)] = c;
        if n > 1 {
            coeffs[(1, 0)] = c;
        }

        let params = GridParams {
            origin: [-2.0, -2.0, -2.0],
            spacing: 0.5,
            dimensions: [5, 5, 13],
        };

        let grid = evaluate_orbital_cpu(&basis, &coeffs, 0, &params);
        assert_eq!(grid.values.len(), params.n_points());
        assert_eq!(grid.orbital_index, 0);

        // ψ should be non-zero near the bond axis
        let center_idx = params.flat_index(2, 2, 5); // near midpoint
        assert!(grid.values[center_idx].abs() > 1e-6);
    }

    #[test]
    fn test_evaluate_orbital_with_report() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]];
        let basis = BasisSet::sto3g(&elements, &positions);

        let n = basis.n_basis;
        let mut coeffs = DMatrix::zeros(n, n);
        coeffs[(0, 0)] = 1.0 / (2.0f64).sqrt();
        if n > 1 {
            coeffs[(1, 0)] = 1.0 / (2.0f64).sqrt();
        }

        let params = GridParams {
            origin: [-1.0, -1.0, -1.0],
            spacing: 1.0,
            dimensions: [3, 3, 5],
        };

        let (grid, report) = evaluate_orbital_with_report(&basis, &coeffs, 0, &params);
        assert_eq!(grid.values.len(), params.n_points());
        assert!(!report.backend.is_empty());
        assert_eq!(report.n_points, params.n_points());
        // The GPU can only have been used if it was attempted.
        assert!(!report.used_gpu || report.attempted_gpu);
        assert!(!report.note.is_empty());
    }

    #[test]
    fn test_evaluate_density_cpu() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]];
        let basis = BasisSet::sto3g(&elements, &positions);

        let n = basis.n_basis;
        // Simple density matrix
        let density = DMatrix::from_fn(n, n, |i, j| if i == j { 1.0 } else { 0.3 });

        let params = GridParams {
            origin: [-1.0, -1.0, -1.0],
            spacing: 1.0,
            dimensions: [3, 3, 4],
        };

        let values = evaluate_density_cpu(&basis, &density, &params);
        assert_eq!(values.len(), params.n_points());
        // A unit diagonal with 0.3 off-diagonal entries is positive definite
        // for up to 4 basis functions, so ρ = φᵀPφ must be non-negative
        // (up to rounding) and positive somewhere near the nuclei.
        assert!(values.iter().all(|&v| v >= -1e-12));
        assert!(values.iter().any(|&v| v > 0.0));
    }

    #[test]
    fn test_pack_basis_for_gpu() {
        let elements = [1u8];
        let positions = [[0.0, 0.0, 0.0]];
        let basis = BasisSet::sto3g(&elements, &positions);

        let (basis_bytes, prim_bytes) = pack_basis_for_gpu(&basis);
        // Each basis function: 8 × f32/u32 = 32 bytes
        assert_eq!(basis_bytes.len(), basis.n_basis * 32);
        // Each basis function: 3 primitives × 2 × f32 = 24 bytes
        assert_eq!(prim_bytes.len(), basis.n_basis * 24);
    }
}