sci_form/gpu/
ani_aev_gpu.rs1use super::context::{
6 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
7 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
8 ComputeDispatchDescriptor, GpuContext, UniformValue,
9};
10
11const GPU_DISPATCH_THRESHOLD: usize = 8;
12
13pub fn compute_radial_aev_gpu(
18 ctx: &GpuContext,
19 species: &[u32],
20 positions: &[[f64; 3]],
21 eta: &[f64],
22 rs: &[f64],
23 cutoff: f64,
24 n_species: usize,
25) -> Result<Vec<f64>, String> {
26 let n = positions.len();
27 if n < GPU_DISPATCH_THRESHOLD {
28 return Err("System too small for GPU dispatch".to_string());
29 }
30
31 let n_radial = eta.len() * rs.len();
32 let output_len = n * n_species * n_radial;
33
34 let species_f32: Vec<f32> = species.iter().map(|&s| s as f32).collect();
35 let eta_f32: Vec<f32> = eta.iter().map(|&v| v as f32).collect();
36 let rs_f32: Vec<f32> = rs.iter().map(|&v| v as f32).collect();
37
38 let mut eta_padded = [0.0f32; 16];
40 let mut rs_padded = [0.0f32; 16];
41 for (i, &v) in eta_f32.iter().enumerate().take(16) {
42 eta_padded[i] = v;
43 }
44 for (i, &v) in rs_f32.iter().enumerate().take(16) {
45 rs_padded[i] = v;
46 }
47
48 let params = pack_uniform_values(&[
49 UniformValue::U32(n as u32),
50 UniformValue::U32(n_species as u32),
51 UniformValue::U32(eta.len() as u32),
52 UniformValue::U32(rs.len() as u32),
53 UniformValue::F32(cutoff as f32),
54 UniformValue::F32(0.0),
55 UniformValue::F32(0.0),
56 UniformValue::F32(0.0),
57 ]);
58
59 let descriptor = ComputeDispatchDescriptor {
60 label: "ani radial aev".to_string(),
61 shader_source: ANI_RADIAL_AEV_SHADER.to_string(),
62 entry_point: "main".to_string(),
63 workgroup_count: [ceil_div_u32(n, 64), 1, 1],
64 bindings: vec![
65 ComputeBindingDescriptor {
66 label: "positions".to_string(),
67 kind: ComputeBindingKind::StorageReadOnly,
68 bytes: pack_vec3_positions_f32(positions),
69 },
70 ComputeBindingDescriptor {
71 label: "species".to_string(),
72 kind: ComputeBindingKind::StorageReadOnly,
73 bytes: f32_slice_to_bytes(&species_f32),
74 },
75 ComputeBindingDescriptor {
76 label: "eta".to_string(),
77 kind: ComputeBindingKind::StorageReadOnly,
78 bytes: f32_slice_to_bytes(&eta_padded),
79 },
80 ComputeBindingDescriptor {
81 label: "rs".to_string(),
82 kind: ComputeBindingKind::StorageReadOnly,
83 bytes: f32_slice_to_bytes(&rs_padded),
84 },
85 ComputeBindingDescriptor {
86 label: "params".to_string(),
87 kind: ComputeBindingKind::Uniform,
88 bytes: params,
89 },
90 ComputeBindingDescriptor {
91 label: "output".to_string(),
92 kind: ComputeBindingKind::StorageReadWrite,
93 bytes: f32_slice_to_bytes(&vec![0.0f32; output_len]),
94 },
95 ],
96 };
97
98 let mut result = ctx.run_compute(&descriptor)?;
99 let bytes = result
100 .outputs
101 .pop()
102 .ok_or("No output from ANI AEV kernel")?;
103 Ok(bytes_to_f64_vec_from_f32(&bytes))
104}
105
/// WGSL compute kernel for ANI radial AEV terms, one invocation per atom.
///
/// Bind group 0 layout:
/// * `@binding(0)` `positions`: `array<AtomPos>`, i.e. x, y, z and one padding
///   `f32` per atom (assumes the host packs positions with that stride —
///   confirm against `pack_vec3_positions_f32`).
/// * `@binding(1)` `species`: per-atom species index stored as `f32`.
/// * `@binding(2)` `eta` and `@binding(3)` `rs`: radial parameters; at least
///   `n_eta` / `n_rs` entries must be bound.
/// * `@binding(4)` `params`: uniform `Params` (four `u32` counts, then
///   `cutoff` and three padding `f32`s).
/// * `@binding(5)` `output`: `[atom][species][eta][rs]` accumulator. It must
///   be zero-initialised by the host because the kernel adds with `+=`.
///
/// Each invocation writes only into its own atom's row of `output`. A
/// neighbour whose species index is `>= n_species` is skipped rather than
/// allowed to write outside that row.
pub const ANI_RADIAL_AEV_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    n_species: u32,
    n_eta: u32,
    n_rs: u32,
    cutoff: f32,
    _pad0: f32,
    _pad1: f32,
    _pad2: f32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> species: array<f32>;
@group(0) @binding(2) var<storage, read> eta: array<f32>;
@group(0) @binding(3) var<storage, read> rs: array<f32>;
@group(0) @binding(4) var<uniform> params: Params;
@group(0) @binding(5) var<storage, read_write> output: array<f32>;

fn cosine_cutoff(r: f32, rc: f32) -> f32 {
    if (r >= rc) {
        return 0.0;
    }
    return 0.5 * (1.0 + cos(3.14159265358979 * r / rc));
}

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n_atoms;
    if (i >= n) {
        return;
    }

    let pos_i = positions[i];
    let n_radial = params.n_eta * params.n_rs;
    let stride = params.n_species * n_radial;

    for (var j: u32 = 0u; j < n; j = j + 1u) {
        if (i == j) {
            continue;
        }

        let pos_j = positions[j];
        let dx = pos_i.x - pos_j.x;
        let dy = pos_i.y - pos_j.y;
        let dz = pos_i.z - pos_j.z;
        let rij = sqrt(dx * dx + dy * dy + dz * dz);

        if (rij >= params.cutoff) {
            continue;
        }

        let sj = u32(species[j]);
        // An out-of-range species index would address another atom's row
        // (or past the end of output); skip the neighbour instead.
        if (sj >= params.n_species) {
            continue;
        }

        let fc = cosine_cutoff(rij, params.cutoff);
        let base = i * stride + sj * n_radial;

        var k: u32 = 0u;
        for (var ie: u32 = 0u; ie < params.n_eta; ie = ie + 1u) {
            let e = eta[ie];
            for (var ir: u32 = 0u; ir < params.n_rs; ir = ir + 1u) {
                let dr = rij - rs[ir];
                output[base + k] += exp(-e * dr * dr) * fc;
                k = k + 1u;
            }
        }
    }
}
"#;
179
#[cfg(test)]
mod tests {
    use super::ANI_RADIAL_AEV_SHADER;

    /// Smoke test: the embedded WGSL source is present.
    #[test]
    fn test_ani_aev_gpu_module_compiles() {
        assert!(!ANI_RADIAL_AEV_SHADER.is_empty());
    }

    /// The host dispatch uses entry point `main` and binds resources 0..=5 in
    /// group 0; the shader must declare every one of them.
    #[test]
    fn test_ani_aev_shader_declares_entry_point_and_bindings() {
        assert!(ANI_RADIAL_AEV_SHADER.contains("@compute"));
        assert!(ANI_RADIAL_AEV_SHADER.contains("fn main("));
        for binding in 0..6 {
            let decl = format!("@group(0) @binding({})", binding);
            assert!(
                ANI_RADIAL_AEV_SHADER.contains(&decl),
                "missing declaration for binding {}",
                binding
            );
        }
    }

    /// The host computes the workgroup count as `ceil_div_u32(n, 64)`, so the
    /// kernel's workgroup size must stay at 64 invocations along x.
    #[test]
    fn test_ani_aev_shader_workgroup_size_matches_dispatch() {
        assert!(ANI_RADIAL_AEV_SHADER.contains("@workgroup_size(64, 1, 1)"));
    }
}