Skip to main content

sci_form/gpu/
density_grid_gpu.rs

1//! GPU electron-density grid evaluation.
2
3use nalgebra::DMatrix;
4
5use super::backend_report::OrbitalGridReport;
6use super::context::{
7    bytes_to_f32_vec, f32_slice_to_bytes, ComputeBindingDescriptor, ComputeBindingKind,
8    ComputeDispatchDescriptor, GpuContext,
9};
10use crate::gpu::orbital_grid::{evaluate_density_cpu, GridParams};
11use crate::scf::basis::{BasisFunction, BasisSet};
12
13fn pack_basis_for_gpu(basis: &BasisSet) -> (Vec<u8>, Vec<u8>) {
14    let mut basis_bytes = Vec::new();
15    let mut prim_bytes = Vec::new();
16
17    for bf in &basis.functions {
18        basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
19        basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
20        basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
21        basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
22        basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
23        basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
24        basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
25        let norm = BasisFunction::normalization(
26            bf.primitives.first().map(|p| p.alpha).unwrap_or(1.0),
27            bf.angular[0],
28            bf.angular[1],
29            bf.angular[2],
30        );
31        basis_bytes.extend_from_slice(&(norm as f32).to_ne_bytes());
32
33        for index in 0..3 {
34            if index < bf.primitives.len() {
35                prim_bytes.extend_from_slice(&(bf.primitives[index].alpha as f32).to_ne_bytes());
36                prim_bytes
37                    .extend_from_slice(&(bf.primitives[index].coefficient as f32).to_ne_bytes());
38            } else {
39                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
40                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
41            }
42        }
43    }
44
45    (basis_bytes, prim_bytes)
46}
47
48pub fn evaluate_density_with_report(
49    basis: &BasisSet,
50    density: &DMatrix<f64>,
51    params: &GridParams,
52) -> (Vec<f64>, OrbitalGridReport) {
53    let ctx = GpuContext::best_available();
54    if ctx.is_gpu_available() {
55        match evaluate_density_gpu(&ctx, basis, density, params) {
56            Ok(grid) => {
57                return (
58                    grid,
59                    OrbitalGridReport {
60                        backend: ctx.capabilities.backend.clone(),
61                        used_gpu: true,
62                        attempted_gpu: true,
63                        n_points: params.n_points(),
64                        note: format!("GPU density-grid dispatch on {}", ctx.capabilities.backend),
65                    },
66                );
67            }
68            Err(_err) => {}
69        }
70    }
71
72    let grid = evaluate_density_cpu(basis, density, params);
73    (
74        grid,
75        OrbitalGridReport {
76            backend: "CPU".to_string(),
77            used_gpu: false,
78            attempted_gpu: ctx.is_gpu_available(),
79            n_points: params.n_points(),
80            note: if ctx.is_gpu_available() {
81                "GPU available but density-grid dispatch failed; CPU fallback used".to_string()
82            } else {
83                "CPU density-grid evaluation (GPU not available)".to_string()
84            },
85        },
86    )
87}
88
89pub fn evaluate_density_gpu(
90    ctx: &GpuContext,
91    basis: &BasisSet,
92    density: &DMatrix<f64>,
93    params: &GridParams,
94) -> Result<Vec<f64>, String> {
95    let n_basis = basis.n_basis;
96    let n_points = params.n_points();
97    let (basis_bytes, prim_bytes) = pack_basis_for_gpu(basis);
98    let density_flat: Vec<f32> = (0..n_basis)
99        .flat_map(|mu| (0..n_basis).map(move |nu| density[(mu, nu)] as f32))
100        .collect();
101
102    let mut params_bytes = Vec::with_capacity(32);
103    for value in &params.origin {
104        params_bytes.extend_from_slice(&(*value as f32).to_ne_bytes());
105    }
106    params_bytes.extend_from_slice(&(params.spacing as f32).to_ne_bytes());
107    for dim in &params.dimensions {
108        params_bytes.extend_from_slice(&(*dim as u32).to_ne_bytes());
109    }
110    params_bytes.extend_from_slice(&(n_basis as u32).to_ne_bytes());
111
112    let [nx, ny, nz] = params.dimensions;
113    let descriptor = ComputeDispatchDescriptor {
114        label: "density grid".to_string(),
115        shader_source: DENSITY_GRID_SHADER.to_string(),
116        entry_point: "main".to_string(),
117        workgroup_count: [
118            (nx as u32).div_ceil(8),
119            (ny as u32).div_ceil(8),
120            (nz as u32).div_ceil(4),
121        ],
122        bindings: vec![
123            ComputeBindingDescriptor {
124                label: "basis".to_string(),
125                kind: ComputeBindingKind::StorageReadOnly,
126                bytes: basis_bytes,
127            },
128            ComputeBindingDescriptor {
129                label: "density".to_string(),
130                kind: ComputeBindingKind::StorageReadOnly,
131                bytes: f32_slice_to_bytes(&density_flat),
132            },
133            ComputeBindingDescriptor {
134                label: "primitives".to_string(),
135                kind: ComputeBindingKind::StorageReadOnly,
136                bytes: prim_bytes,
137            },
138            ComputeBindingDescriptor {
139                label: "params".to_string(),
140                kind: ComputeBindingKind::Uniform,
141                bytes: params_bytes,
142            },
143            ComputeBindingDescriptor {
144                label: "output".to_string(),
145                kind: ComputeBindingKind::StorageReadWrite,
146                bytes: f32_slice_to_bytes(&vec![0.0f32; n_points]),
147            },
148        ],
149    };
150
151    let mut outputs = ctx.run_compute(&descriptor)?.outputs;
152    let bytes = outputs.pop().ok_or("No output from density grid kernel")?;
153    let values = bytes_to_f32_vec(&bytes);
154    if values.len() != n_points {
155        return Err(format!(
156            "Output size mismatch: expected {}, got {}",
157            n_points,
158            values.len()
159        ));
160    }
161    Ok(values.into_iter().map(|value| value as f64).collect())
162}
163
/// WGSL compute kernel used by [`evaluate_density_gpu`].
///
/// Each invocation handles one grid point `(ix, iy, iz)`. It writes
/// `rho = sum_{mu,nu} density[mu * n_basis + nu] * phi_mu * phi_nu` to
/// `output[ix * dims_y * dims_z + iy * dims_z + iz]`. Any `mu` whose basis
/// value at the point is below `1e-7` in magnitude is skipped.
///
/// Bindings (group 0):
/// - 0 = `BasisFunc` records;
/// - 1 = flattened density matrix (`f32`);
/// - 2 = `(alpha, coefficient)` pairs with a fixed stride of 3 per basis
///   function;
/// - 3 = `GridParams` uniform;
/// - 4 = output densities.
pub const DENSITY_GRID_SHADER: &str = r#"
struct BasisFunc {
    center_x: f32, center_y: f32, center_z: f32,
    lx: u32, ly: u32, lz: u32,
    n_primitives: u32,
    norm_coeff: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    n_basis: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> density: array<f32>;
@group(0) @binding(2) var<storage, read> primitives: array<vec2<f32>>;
@group(0) @binding(3) var<uniform> params: GridParams;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

// Value of basis function `mu` at (rx, ry, rz):
// norm * dx^lx * dy^ly * dz^lz * sum_p c_p * exp(-alpha_p * r^2).
fn phi_at(mu: u32, rx: f32, ry: f32, rz: f32) -> f32 {
    let bf = basis[mu];
    let dx = rx - bf.center_x;
    let dy = ry - bf.center_y;
    let dz = rz - bf.center_z;
    let r2 = dx * dx + dy * dy + dz * dz;

    var angular: f32 = 1.0;
    for (var i: u32 = 0u; i < bf.lx; i = i + 1u) { angular *= dx; }
    for (var i: u32 = 0u; i < bf.ly; i = i + 1u) { angular *= dy; }
    for (var i: u32 = 0u; i < bf.lz; i = i + 1u) { angular *= dz; }

    // Primitives are packed with a fixed stride of 3 per basis function, so
    // never read past this function's own slots.
    let n_prims = min(bf.n_primitives, 3u);
    var radial: f32 = 0.0;
    for (var p: u32 = 0u; p < n_prims; p = p + 1u) {
        let prim = primitives[mu * 3u + p];
        radial += prim.y * exp(-prim.x * r2);
    }
    return bf.norm_coeff * angular * radial;
}

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    // The dispatch is rounded up to whole workgroups; skip padding invocations.
    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;
    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;

    var rho: f32 = 0.0;
    for (var mu: u32 = 0u; mu < params.n_basis; mu = mu + 1u) {
        let phi_mu = phi_at(mu, rx, ry, rz);
        if (abs(phi_mu) < 1e-7) { continue; }
        for (var nu: u32 = 0u; nu < params.n_basis; nu = nu + 1u) {
            let phi_nu = phi_at(nu, rx, ry, rz);
            rho += density[mu * params.n_basis + nu] * phi_mu * phi_nu;
        }
    }

    output[flat_idx] = rho;
}
"#;