Skip to main content

sci_form/gpu/
ani_aev_gpu.rs

//! GPU-accelerated ANI Atomic Environment Vector (AEV) computation.
//!
//! Offloads the O(N²) pairwise radial symmetry-function computation to the GPU.
4
5use super::context::{
6    bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
7    pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
8    ComputeDispatchDescriptor, GpuContext, UniformValue,
9};
10
11const GPU_DISPATCH_THRESHOLD: usize = 8;
12
13/// Compute radial AEV components on GPU.
14///
15/// Each atom i gets a vector of radial symmetry function values summed over
16/// all neighbors j within the cutoff radius.
17pub fn compute_radial_aev_gpu(
18    ctx: &GpuContext,
19    species: &[u32],
20    positions: &[[f64; 3]],
21    eta: &[f64],
22    rs: &[f64],
23    cutoff: f64,
24    n_species: usize,
25) -> Result<Vec<f64>, String> {
26    let n = positions.len();
27    if n < GPU_DISPATCH_THRESHOLD {
28        return Err("System too small for GPU dispatch".to_string());
29    }
30
31    let n_radial = eta.len() * rs.len();
32    let output_len = n * n_species * n_radial;
33
34    let species_f32: Vec<f32> = species.iter().map(|&s| s as f32).collect();
35    let eta_f32: Vec<f32> = eta.iter().map(|&v| v as f32).collect();
36    let rs_f32: Vec<f32> = rs.iter().map(|&v| v as f32).collect();
37
38    // Pad eta/rs to fixed sizes (max 16 each)
39    let mut eta_padded = [0.0f32; 16];
40    let mut rs_padded = [0.0f32; 16];
41    for (i, &v) in eta_f32.iter().enumerate().take(16) {
42        eta_padded[i] = v;
43    }
44    for (i, &v) in rs_f32.iter().enumerate().take(16) {
45        rs_padded[i] = v;
46    }
47
48    let params = pack_uniform_values(&[
49        UniformValue::U32(n as u32),
50        UniformValue::U32(n_species as u32),
51        UniformValue::U32(eta.len() as u32),
52        UniformValue::U32(rs.len() as u32),
53        UniformValue::F32(cutoff as f32),
54        UniformValue::F32(0.0),
55        UniformValue::F32(0.0),
56        UniformValue::F32(0.0),
57    ]);
58
59    let descriptor = ComputeDispatchDescriptor {
60        label: "ani radial aev".to_string(),
61        shader_source: ANI_RADIAL_AEV_SHADER.to_string(),
62        entry_point: "main".to_string(),
63        workgroup_count: [ceil_div_u32(n, 64), 1, 1],
64        bindings: vec![
65            ComputeBindingDescriptor {
66                label: "positions".to_string(),
67                kind: ComputeBindingKind::StorageReadOnly,
68                bytes: pack_vec3_positions_f32(positions),
69            },
70            ComputeBindingDescriptor {
71                label: "species".to_string(),
72                kind: ComputeBindingKind::StorageReadOnly,
73                bytes: f32_slice_to_bytes(&species_f32),
74            },
75            ComputeBindingDescriptor {
76                label: "eta".to_string(),
77                kind: ComputeBindingKind::StorageReadOnly,
78                bytes: f32_slice_to_bytes(&eta_padded),
79            },
80            ComputeBindingDescriptor {
81                label: "rs".to_string(),
82                kind: ComputeBindingKind::StorageReadOnly,
83                bytes: f32_slice_to_bytes(&rs_padded),
84            },
85            ComputeBindingDescriptor {
86                label: "params".to_string(),
87                kind: ComputeBindingKind::Uniform,
88                bytes: params,
89            },
90            ComputeBindingDescriptor {
91                label: "output".to_string(),
92                kind: ComputeBindingKind::StorageReadWrite,
93                bytes: f32_slice_to_bytes(&vec![0.0f32; output_len]),
94            },
95        ],
96    };
97
98    let mut result = ctx.run_compute(&descriptor)?;
99    let bytes = result
100        .outputs
101        .pop()
102        .ok_or("No output from ANI AEV kernel")?;
103    Ok(bytes_to_f64_vec_from_f32(&bytes))
104}
105
/// WGSL compute shader used by [`compute_radial_aev_gpu`].
///
/// One invocation handles one atom `i` (workgroup size 64). For every other atom `j`
/// with `r_ij < params.cutoff`, it adds `exp(-eta * (r_ij - rs)^2) * f_c(r_ij)` into
/// `output[i * n_species * n_radial + species[j] * n_radial + ie * n_rs + ir]`,
/// where `f_c` is the cosine cutoff.
///
/// Bindings (group 0):
/// * `@binding(0)` `positions`: `array<AtomPos>` (x, y, z, padding), read-only
/// * `@binding(1)` `species`: species index per atom stored as `f32`, read-only
/// * `@binding(2)` `eta`, `@binding(3)` `rs`: radial parameters, read-only
/// * `@binding(4)` `params`: uniform `Params`
/// * `@binding(5)` `output`: read-write accumulator; must be zero-initialized by the host
///
/// Neighbors with a species index `>= n_species` are skipped, so each invocation
/// writes only its own row of `output` and no atomics are needed.
pub const ANI_RADIAL_AEV_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    n_species: u32,
    n_eta: u32,
    n_rs: u32,
    cutoff: f32,
    _pad0: f32,
    _pad1: f32,
    _pad2: f32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> species: array<f32>;
@group(0) @binding(2) var<storage, read> eta: array<f32>;
@group(0) @binding(3) var<storage, read> rs: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;

fn cosine_cutoff(r: f32, rc: f32) -> f32 {
    if (r >= rc) {
        return 0.0;
    }
    return 0.5 * (1.0 + cos(3.14159265358979 * r / rc));
}

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n_atoms;
    if (i >= n) {
        return;
    }

    let pos_i = positions[i];
    let n_radial = params.n_eta * params.n_rs;
    let stride = params.n_species * n_radial;

    for (var j: u32 = 0u; j < n; j = j + 1u) {
        if (i == j) {
            continue;
        }

        let pos_j = positions[j];
        let dx = pos_i.x - pos_j.x;
        let dy = pos_i.y - pos_j.y;
        let dz = pos_i.z - pos_j.z;
        let rij = sqrt(dx * dx + dy * dy + dz * dz);

        if (rij >= params.cutoff) {
            continue;
        }

        let sj = u32(species[j]);
        // An out-of-range species would index outside row i (or past the end of output).
        if (sj >= params.n_species) {
            continue;
        }
        let fc = cosine_cutoff(rij, params.cutoff);
        let base = i * stride + sj * n_radial;

        var k: u32 = 0u;
        for (var ie: u32 = 0u; ie < params.n_eta; ie = ie + 1u) {
            let e = eta[ie];
            for (var ir: u32 = 0u; ir < params.n_rs; ir = ir + 1u) {
                let dr = rij - rs[ir];
                output[base + k] += exp(-e * dr * dr) * fc;
                k = k + 1u;
            }
        }
    }
}
"#;
179
#[cfg(test)]
mod tests {
    use super::ANI_RADIAL_AEV_SHADER;

    /// GPU-free sanity check of the WGSL source: it must be non-empty and expose
    /// the entry point and workgroup size the host dispatch relies on
    /// (entry point "main", one workgroup per 64 atoms).
    #[test]
    fn test_ani_aev_gpu_module_compiles() {
        assert!(!ANI_RADIAL_AEV_SHADER.is_empty());
        assert!(ANI_RADIAL_AEV_SHADER.contains("fn main("));
        assert!(ANI_RADIAL_AEV_SHADER.contains("@workgroup_size(64, 1, 1)"));
    }

    /// The host descriptor lists six bindings (positions, species, eta, rs, params,
    /// output). They are assumed to map to binding indices 0..=5 in order, so the
    /// shader must declare each of them in group 0.
    #[test]
    fn test_ani_aev_shader_declares_all_bindings() {
        for binding in 0..6 {
            let decl = format!("@group(0) @binding({})", binding);
            assert!(
                ANI_RADIAL_AEV_SHADER.contains(&decl),
                "shader is missing declaration `{}`",
                decl
            );
        }
    }
}