Skip to main content

sci_form/gpu/
esp_grid_gpu.rs

1//! GPU point-charge ESP grid evaluation.
2
3use super::backend_report::OrbitalGridReport;
4use super::context::{
5    bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
6    pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
7    ComputeDispatchDescriptor, GpuContext, UniformValue,
8};
9use crate::esp::{compute_esp_grid, EspGrid};
10
11pub fn compute_esp_grid_with_report(
12    elements: &[u8],
13    positions: &[[f64; 3]],
14    mulliken_charges: &[f64],
15    spacing: f64,
16    padding: f64,
17) -> (EspGrid, OrbitalGridReport) {
18    let ctx = GpuContext::best_available();
19    if ctx.is_gpu_available() {
20        match compute_esp_grid_gpu(&ctx, positions, mulliken_charges, spacing, padding) {
21            Ok(grid) => {
22                let n_points = grid.dims[0] * grid.dims[1] * grid.dims[2];
23                return (
24                    grid,
25                    OrbitalGridReport {
26                        backend: ctx.capabilities.backend.clone(),
27                        used_gpu: true,
28                        attempted_gpu: true,
29                        n_points,
30                        note: format!("GPU ESP-grid dispatch on {}", ctx.capabilities.backend),
31                    },
32                );
33            }
34            Err(_err) => {}
35        }
36    }
37
38    let grid = compute_esp_grid(elements, positions, mulliken_charges, spacing, padding);
39    let n_points = grid.dims[0] * grid.dims[1] * grid.dims[2];
40    (
41        grid,
42        OrbitalGridReport {
43            backend: "CPU".to_string(),
44            used_gpu: false,
45            attempted_gpu: ctx.is_gpu_available(),
46            n_points,
47            note: if ctx.is_gpu_available() {
48                "GPU available but ESP-grid dispatch failed; CPU fallback used".to_string()
49            } else {
50                "CPU ESP-grid evaluation (GPU not available)".to_string()
51            },
52        },
53    )
54}
55
56pub fn compute_esp_grid_gpu(
57    ctx: &GpuContext,
58    positions: &[[f64; 3]],
59    mulliken_charges: &[f64],
60    spacing: f64,
61    padding: f64,
62) -> Result<EspGrid, String> {
63    if positions.len() != mulliken_charges.len() {
64        return Err("positions/charges length mismatch".to_string());
65    }
66
67    let mut min = [f64::MAX; 3];
68    let mut max = [f64::MIN; 3];
69    for pos in positions {
70        for axis in 0..3 {
71            min[axis] = min[axis].min(pos[axis]);
72            max[axis] = max[axis].max(pos[axis]);
73        }
74    }
75
76    let origin = [min[0] - padding, min[1] - padding, min[2] - padding];
77    let dims = [
78        ((max[0] - min[0] + 2.0 * padding) / spacing).ceil() as usize + 1,
79        ((max[1] - min[1] + 2.0 * padding) / spacing).ceil() as usize + 1,
80        ((max[2] - min[2] + 2.0 * padding) / spacing).ceil() as usize + 1,
81    ];
82    let total = dims[0] * dims[1] * dims[2];
83
84    let params_bytes = pack_uniform_values(&[
85        UniformValue::F32(origin[0] as f32),
86        UniformValue::F32(origin[1] as f32),
87        UniformValue::F32(origin[2] as f32),
88        UniformValue::F32(spacing as f32),
89        UniformValue::U32(dims[0] as u32),
90        UniformValue::U32(dims[1] as u32),
91        UniformValue::U32(dims[2] as u32),
92        UniformValue::U32(positions.len() as u32),
93    ]);
94
95    let descriptor = ComputeDispatchDescriptor {
96        label: "esp grid".to_string(),
97        shader_source: ESP_GRID_SHADER.to_string(),
98        entry_point: "main".to_string(),
99        workgroup_count: [
100            ceil_div_u32(dims[0], 8),
101            ceil_div_u32(dims[1], 8),
102            ceil_div_u32(dims[2], 4),
103        ],
104        bindings: vec![
105            ComputeBindingDescriptor {
106                label: "positions".to_string(),
107                kind: ComputeBindingKind::StorageReadOnly,
108                bytes: pack_vec3_positions_f32(positions),
109            },
110            ComputeBindingDescriptor {
111                label: "charges".to_string(),
112                kind: ComputeBindingKind::StorageReadOnly,
113                bytes: f32_slice_to_bytes(
114                    &mulliken_charges
115                        .iter()
116                        .map(|value| *value as f32)
117                        .collect::<Vec<_>>(),
118                ),
119            },
120            ComputeBindingDescriptor {
121                label: "params".to_string(),
122                kind: ComputeBindingKind::Uniform,
123                bytes: params_bytes,
124            },
125            ComputeBindingDescriptor {
126                label: "output".to_string(),
127                kind: ComputeBindingKind::StorageReadWrite,
128                bytes: f32_slice_to_bytes(&vec![0.0f32; total]),
129            },
130        ],
131    };
132
133    let mut outputs = ctx.run_compute(&descriptor)?.outputs;
134    let bytes = outputs.pop().ok_or("No output from ESP grid kernel")?;
135    let values = bytes_to_f64_vec_from_f32(&bytes);
136    if values.len() != total {
137        return Err(format!(
138            "Output size mismatch: expected {}, got {}",
139            total,
140            values.len()
141        ));
142    }
143
144    Ok(EspGrid {
145        origin,
146        spacing,
147        dims,
148        values,
149    })
150}
151
/// WGSL compute shader used by [`compute_esp_grid_gpu`] to evaluate the point-charge ESP.
///
/// Bindings (all in group 0):
/// - 0: atom positions as 16-byte records (`x, y, z, _pad`).
/// - 1: per-atom charges.
/// - 2: the `GridParams` uniform (origin, spacing, grid dims, atom count).
/// - 3: the output potential buffer.
///
/// Workgroups are 8x8x4. Each invocation handles one grid point `(ix, iy, iz)`, exits if it is
/// outside `dims`, and writes to flat index `ix * dims_y * dims_z + iy * dims_z + iz`, so z
/// varies fastest. The potential is `sum q / (dist * 1.88972599)`, which looks like an
/// angstrom-to-bohr conversion (NOTE(review): confirm input units). Atoms within 0.01 of the
/// grid point are skipped to avoid the 1/r singularity.
pub const ESP_GRID_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    n_atoms: u32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> charges: array<f32>;
@group(0) @binding(2) var<uniform> params: GridParams;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;
    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;

    var phi: f32 = 0.0;
    for (var atom: u32 = 0u; atom < params.n_atoms; atom = atom + 1u) {
        let pos = positions[atom];
        let dx = rx - pos.x;
        let dy = ry - pos.y;
        let dz = rz - pos.z;
        let dist = sqrt(dx * dx + dy * dy + dz * dz);
        if (dist < 0.01) { continue; }
        phi += charges[atom] / (dist * 1.88972599);
    }
    output[flat_idx] = phi;
}
"#;