1#[cfg(feature = "experimental-gpu")]
11use crate::gpu::context::{
12 ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
13};
14
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Mmff94GpuResult {
20 pub total_energy: f64,
22 pub vdw_energy: f64,
24 pub electrostatic_energy: f64,
26 pub n_pairs: usize,
28 pub used_gpu: bool,
30 pub backend: String,
32}
33
/// Minimum atom count for which `compute_mmff94_nonbonded_gpu` dispatches to
/// the GPU. Smaller systems go straight to the CPU implementation, presumably
/// because dispatch and buffer-upload overhead dominate at that size.
// Only the feature-gated GPU entry point reads this constant, so it is dead
// code exactly when `experimental-gpu` is disabled; suppress the lint only then.
#[cfg_attr(not(feature = "experimental-gpu"), allow(dead_code))]
const GPU_DISPATCH_THRESHOLD: usize = 50;
37
/// Evaluates the non-bonded energy (buffered 14-7 van der Waals plus buffered
/// Coulomb) over all unique atom pairs on the GPU.
///
/// The kernel computes and accumulates in `f32`, so results carry
/// single-precision rounding; [`compute_mmff94_nonbonded_cpu`] works in `f64`.
///
/// The call falls back to [`compute_mmff94_nonbonded_cpu`] (result has
/// `used_gpu == false`) when:
/// * there are fewer than `GPU_DISPATCH_THRESHOLD` atoms,
/// * the dense exclusion bitmap plus atom buffer would exceed the buffer budget, or
/// * the pair count needs more workgroups than one 1-D dispatch allows
///   (WebGPU default `maxComputeWorkgroupsPerDimension` = 65 535).
///
/// # Arguments
/// * `coords` – flat `[x0, y0, z0, x1, ...]`; must hold at least `3 * charges.len()` values.
/// * `charges`, `vdw_radii`, `vdw_epsilon` – per-atom parameters; `charges.len()`
///   defines the atom count.
/// * `exclusions_14` – atom pairs (either order) skipped entirely; pairs naming
///   atoms outside the system are ignored, as on the CPU path.
///
/// # Errors
/// Returns `Err` if, on the GPU path, `coords`, `vdw_radii` or `vdw_epsilon`
/// are too short for the atom count. Also returns `Err` if the compute dispatch
/// fails or the kernel returns fewer than 8 output bytes.
///
/// On the GPU path `n_pairs` is the total number of candidate pairs
/// `n * (n - 1) / 2`, not the number that survived exclusions and the cutoff.
#[cfg(feature = "experimental-gpu")]
pub fn compute_mmff94_nonbonded_gpu(
    ctx: &GpuContext,
    coords: &[f64],
    charges: &[f64],
    vdw_radii: &[f64],
    vdw_epsilon: &[f64],
    exclusions_14: &[(usize, usize)],
) -> Result<Mmff94GpuResult, String> {
    // Size of one shader `Atom` record: six f32 fields plus two f32 of padding.
    const ATOM_STRIDE_BYTES: usize = 32;
    // Must match `@workgroup_size` in `MMFF94_NB_SHADER`.
    const WORKGROUP_SIZE: u32 = 64;
    // WebGPU's default `maxComputeWorkgroupsPerDimension`. The kernel indexes
    // pairs with `global_invocation_id.x` only, so every workgroup goes in x.
    const MAX_WORKGROUPS_PER_DIM: u32 = 65_535;
    // Budget for the exclusion bitmap plus the packed atom array.
    const MAX_GPU_BUFFER: usize = 512 * 1024 * 1024;

    let n_atoms = charges.len();
    let cpu_fallback = |note: &str| {
        compute_mmff94_nonbonded_cpu(coords, charges, vdw_radii, vdw_epsilon, exclusions_14, note)
    };

    if n_atoms < GPU_DISPATCH_THRESHOLD {
        return Ok(cpu_fallback("Atom count below GPU threshold"));
    }

    // Validate before packing so malformed input yields an error, not a panic.
    if coords.len() < 3 * n_atoms {
        return Err(format!(
            "coords holds {} values but {} atoms need {}",
            coords.len(),
            n_atoms,
            3 * n_atoms
        ));
    }
    if vdw_radii.len() < n_atoms || vdw_epsilon.len() < n_atoms {
        return Err(format!(
            "vdw_radii ({}) and vdw_epsilon ({}) must each cover all {} atoms",
            vdw_radii.len(),
            vdw_epsilon.len(),
            n_atoms
        ));
    }

    // Dense n x n exclusion bitmap: bit (i * n + j) set => pair (i, j) is skipped.
    let excl_words = match n_atoms.checked_mul(n_atoms) {
        Some(n_bits) => n_bits.div_ceil(32),
        None => return Ok(cpu_fallback("System too large for GPU buffers")),
    };
    let total_gpu_memory = excl_words * 4 + n_atoms * ATOM_STRIDE_BYTES;
    if total_gpu_memory > MAX_GPU_BUFFER {
        return Ok(cpu_fallback("System too large for GPU buffers"));
    }

    let n_pairs = n_atoms * (n_atoms - 1) / 2;
    let wg_needed = n_pairs.div_ceil(WORKGROUP_SIZE as usize);
    if wg_needed > MAX_WORKGROUPS_PER_DIM as usize {
        return Ok(cpu_fallback("Pair count exceeds GPU dispatch limit"));
    }
    // Both casts are lossless: the dispatch limit above caps n_atoms at 2896.
    let wg_count = wg_needed as u32;
    let n_atoms_u32 = n_atoms as u32;

    // WGSL host-shareable memory is little-endian, so encode explicitly.
    let mut atom_bytes = Vec::with_capacity(n_atoms * ATOM_STRIDE_BYTES);
    let per_atom = coords
        .chunks_exact(3)
        .zip(charges)
        .zip(vdw_radii)
        .zip(vdw_epsilon);
    for (((xyz, &charge), &radius), &epsilon) in per_atom {
        // Field order must match `struct Atom` in the shader; last two are padding.
        for value in [xyz[0], xyz[1], xyz[2], charge, radius, epsilon, 0.0, 0.0] {
            atom_bytes.extend_from_slice(&(value as f32).to_le_bytes());
        }
    }

    let mut excl_bits = vec![0u32; excl_words];
    for &(a, b) in exclusions_14 {
        // Pairs naming atoms outside the system can never match a real pair.
        // Skipping them mirrors the CPU path and avoids indexing past, or
        // aliasing other entries inside, the bitmap.
        if a >= n_atoms || b >= n_atoms {
            continue;
        }
        for bit in [a * n_atoms + b, b * n_atoms + a] {
            excl_bits[bit / 32] |= 1 << (bit % 32);
        }
    }
    let excl_bytes: Vec<u8> = excl_bits.iter().flat_map(|w| w.to_le_bytes()).collect();

    // `Params` uniform: n_atoms followed by three u32 of padding (16 bytes).
    let mut params = Vec::with_capacity(16);
    for word in [n_atoms_u32, 0, 0, 0] {
        params.extend_from_slice(&word.to_le_bytes());
    }

    // Two f32 accumulators [vdw, elec]; all-zero bytes encode 0.0f32.
    let output_bytes = vec![0u8; 8];

    let descriptor = ComputeDispatchDescriptor {
        label: "MMFF94 non-bonded".to_string(),
        shader_source: MMFF94_NB_SHADER.to_string(),
        entry_point: "main".to_string(),
        workgroup_count: [wg_count, 1, 1],
        bindings: vec![
            ComputeBindingDescriptor {
                label: "atoms".to_string(),
                kind: ComputeBindingKind::StorageReadOnly,
                bytes: atom_bytes,
            },
            ComputeBindingDescriptor {
                label: "exclusions".to_string(),
                kind: ComputeBindingKind::StorageReadOnly,
                bytes: excl_bytes,
            },
            ComputeBindingDescriptor {
                label: "params".to_string(),
                kind: ComputeBindingKind::Uniform,
                bytes: params,
            },
            ComputeBindingDescriptor {
                label: "output".to_string(),
                kind: ComputeBindingKind::StorageReadWrite,
                bytes: output_bytes,
            },
        ],
    };

    let mut result = ctx.run_compute(&descriptor)?;
    let bytes = result
        .outputs
        .pop()
        .ok_or("No output from MMFF94 kernel")?;

    if bytes.len() < 8 {
        return Err("Insufficient output from MMFF94 kernel".to_string());
    }

    let vdw = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as f64;
    let elec = f32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as f64;

    Ok(Mmff94GpuResult {
        total_energy: vdw + elec,
        vdw_energy: vdw,
        electrostatic_energy: elec,
        n_pairs,
        used_gpu: true,
        backend: ctx.capabilities.backend.clone(),
    })
}
168
169pub fn compute_mmff94_nonbonded_cpu(
171 coords: &[f64],
172 charges: &[f64],
173 vdw_radii: &[f64],
174 vdw_epsilon: &[f64],
175 exclusions_14: &[(usize, usize)],
176 note: &str,
177) -> Mmff94GpuResult {
178 let n_atoms = charges.len();
179 let mut vdw_energy = 0.0;
180 let mut elec_energy = 0.0;
181 let mut n_pairs = 0;
182
183 let mut excl_set = std::collections::HashSet::new();
185 for &(i, j) in exclusions_14 {
186 excl_set.insert((i.min(j), i.max(j)));
187 }
188
189 for i in 0..n_atoms {
190 let xi = coords[i * 3];
191 let yi = coords[i * 3 + 1];
192 let zi = coords[i * 3 + 2];
193
194 for j in (i + 1)..n_atoms {
195 if excl_set.contains(&(i, j)) {
196 continue;
197 }
198
199 let dx = xi - coords[j * 3];
200 let dy = yi - coords[j * 3 + 1];
201 let dz = zi - coords[j * 3 + 2];
202 let r2 = dx * dx + dy * dy + dz * dz;
203 let r = r2.sqrt();
204
205 if !(0.1..=15.0).contains(&r) {
206 continue;
207 }
208
209 let r_star = vdw_radii[i] + vdw_radii[j];
211 let eps = (vdw_epsilon[i] * vdw_epsilon[j]).sqrt();
212 let rho = r / r_star;
213 let rho7 = rho.powi(7);
214 let e_vdw = eps * (1.07 / (rho + 0.07)).powi(7) * ((1.12 / (rho7 + 0.12)) - 2.0);
215 vdw_energy += e_vdw;
216
217 let e_elec = 332.0716 * charges[i] * charges[j] / (r + 0.05);
219 elec_energy += e_elec;
220
221 n_pairs += 1;
222 }
223 }
224
225 Mmff94GpuResult {
226 total_energy: vdw_energy + elec_energy,
227 vdw_energy,
228 electrostatic_energy: elec_energy,
229 n_pairs,
230 used_gpu: false,
231 backend: format!("CPU ({})", note),
232 }
233}
234
/// WGSL kernel used by `compute_mmff94_nonbonded_gpu`. It runs one invocation
/// per unique atom pair `(i, j)` with `i < j`, in workgroups of 64.
///
/// Bindings:
/// * 0 – packed `Atom` records (8 x f32 each);
/// * 1 – exclusion bitmap (bit `i * n + j`);
/// * 2 – `Params { n_atoms, .. }`;
/// * 3 – two `f32` accumulators `[vdw, elec]` stored as raw `u32` bits, which
///   the host must zero-initialise.
///
/// WGSL has no floating-point atomics. Each workgroup therefore reduces its
/// energies in workgroup memory and then adds the partial sums with a
/// compare-exchange loop. Adding the `u32` bit patterns of floats with
/// `atomicAdd` would produce meaningless results.
#[cfg(feature = "experimental-gpu")]
pub const MMFF94_NB_SHADER: &str = r#"
struct Atom {
    x: f32, y: f32, z: f32,
    charge: f32,
    r_vdw: f32, eps_vdw: f32,
    _pad0: f32, _pad1: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32, _pad1: u32, _pad2: u32,
};

@group(0) @binding(0) var<storage, read> atoms: array<Atom>;
@group(0) @binding(1) var<storage, read> exclusions: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
// output[0] = van der Waals sum, output[1] = electrostatic sum (f32 bit patterns).
@group(0) @binding(3) var<storage, read_write> output: array<atomic<u32>>;

// Per-workgroup partial sums; the size must equal @workgroup_size below.
var<workgroup> partial_vdw: array<f32, 64>;
var<workgroup> partial_elec: array<f32, 64>;

fn pair_to_ij(pair_idx: u32, n: u32) -> vec2<u32> {
    // Convert linear pair index to (i, j) with i < j
    var i: u32 = 0u;
    var remaining: u32 = pair_idx;
    loop {
        let row_size = n - 1u - i;
        if (remaining < row_size) {
            break;
        }
        remaining -= row_size;
        i += 1u;
    }
    let j = i + 1u + remaining;
    return vec2<u32>(i, j);
}

// (E_vdw, E_elec) for one pair, or zeros if it is excluded or outside the cutoff window.
fn pair_energy(pair_idx: u32, n: u32) -> vec2<f32> {
    let ij = pair_to_ij(pair_idx, n);
    let i = ij.x;
    let j = ij.y;

    // Check exclusion bitmap
    let bit_idx = i * n + j;
    if (((exclusions[bit_idx / 32u] >> (bit_idx % 32u)) & 1u) != 0u) {
        return vec2<f32>(0.0, 0.0);
    }

    let a = atoms[i];
    let b = atoms[j];

    let dx = a.x - b.x;
    let dy = a.y - b.y;
    let dz = a.z - b.z;
    let r = sqrt(dx * dx + dy * dy + dz * dz);

    if (r < 0.1 || r > 15.0) {
        return vec2<f32>(0.0, 0.0);
    }

    // Buffered 14-7 VdW
    let r_star = a.r_vdw + b.r_vdw;
    let eps = sqrt(a.eps_vdw * b.eps_vdw);
    let rho = r / r_star;
    let rho7 = pow(rho, 7.0);
    let e_vdw = eps * pow(1.07 / (rho + 0.07), 7.0) * ((1.12 / (rho7 + 0.12)) - 2.0);

    // Buffered Coulomb, constant dielectric (0.05 distance buffer)
    let e_elec = 332.0716 * a.charge * b.charge / (r + 0.05);

    return vec2<f32>(e_vdw, e_elec);
}

// Adds `value` to the f32 held (as raw bits) in output[slot]. Retries the
// compare-exchange until no other invocation changed the slot in between.
fn atomic_add_f32(slot: u32, value: f32) {
    var expected = atomicLoad(&output[slot]);
    loop {
        let desired = bitcast<u32>(bitcast<f32>(expected) + value);
        let exchange_result = atomicCompareExchangeWeak(&output[slot], expected, desired);
        if (exchange_result.exchanged) {
            break;
        }
        expected = exchange_result.old_value;
    }
}

@compute @workgroup_size(64, 1, 1)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(local_invocation_index) lid: u32,
) {
    let n = params.n_atoms;
    let n_pairs = n * (n - 1u) / 2u;

    // Out-of-range invocations contribute zero but must still reach every barrier.
    var energy = vec2<f32>(0.0, 0.0);
    if (gid.x < n_pairs) {
        energy = pair_energy(gid.x, n);
    }
    partial_vdw[lid] = energy.x;
    partial_elec[lid] = energy.y;
    workgroupBarrier();

    // Tree reduction over the 64 workgroup slots.
    for (var stride = 32u; stride > 0u; stride = stride / 2u) {
        if (lid < stride) {
            partial_vdw[lid] = partial_vdw[lid] + partial_vdw[lid + stride];
            partial_elec[lid] = partial_elec[lid] + partial_elec[lid + stride];
        }
        workgroupBarrier();
    }

    if (lid == 0u) {
        atomic_add_f32(0u, partial_vdw[0]);
        atomic_add_f32(1u, partial_elec[0]);
    }
}
"#;
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference energies for a single pair at separation `r`, written with the
    /// same formulas as `compute_mmff94_nonbonded_cpu`.
    fn pair_energies(r: f64, radii: (f64, f64), eps: (f64, f64), q: (f64, f64)) -> (f64, f64) {
        let rho = r / (radii.0 + radii.1);
        let e_vdw = (eps.0 * eps.1).sqrt()
            * (1.07 / (rho + 0.07)).powi(7)
            * ((1.12 / (rho.powi(7) + 0.12)) - 2.0);
        let e_elec = 332.0716 * q.0 * q.1 / (r + 0.05);
        (e_vdw, e_elec)
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= 1e-12 * expected.abs().max(1.0),
            "expected {expected}, got {actual}"
        );
    }

    /// Three atoms whose pairwise distances (1, 2, sqrt(5)) are all in range.
    fn triangle() -> ([f64; 9], [f64; 3], [f64; 3], [f64; 3]) {
        (
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0],
            [0.1, 0.2, -0.3],
            [1.0, 1.0, 1.0],
            [0.1, 0.1, 0.1],
        )
    }

    #[test]
    fn test_mmff94_cpu_fallback() {
        let coords = [0.0, 0.0, 0.0, 3.0, 0.0, 0.0];
        let charges = [0.3, -0.3];
        let vdw_radii = [1.5, 1.7];
        let vdw_epsilon = [0.05, 0.06];

        let result =
            compute_mmff94_nonbonded_cpu(&coords, &charges, &vdw_radii, &vdw_epsilon, &[], "test");

        let (e_vdw, e_elec) = pair_energies(3.0, (1.5, 1.7), (0.05, 0.06), (0.3, -0.3));
        assert_eq!(result.n_pairs, 1);
        assert_close(result.vdw_energy, e_vdw);
        assert_close(result.electrostatic_energy, e_elec);
        assert_close(result.total_energy, e_vdw + e_elec);
        assert!(result.electrostatic_energy < 0.0);
        assert!(!result.used_gpu);
        assert_eq!(result.backend, "CPU (test)");
    }

    #[test]
    fn excluded_pairs_are_skipped_in_either_order() {
        let (coords, charges, radii, eps) = triangle();
        let forward = compute_mmff94_nonbonded_cpu(&coords, &charges, &radii, &eps, &[(0, 1)], "t");
        let reversed = compute_mmff94_nonbonded_cpu(&coords, &charges, &radii, &eps, &[(1, 0)], "t");

        // Only (0, 2) at r = 2 and (1, 2) at r = sqrt(5) remain.
        let (v02, e02) = pair_energies(2.0, (1.0, 1.0), (0.1, 0.1), (0.1, -0.3));
        let (v12, e12) = pair_energies(5.0_f64.sqrt(), (1.0, 1.0), (0.1, 0.1), (0.2, -0.3));
        for result in [&forward, &reversed] {
            assert_eq!(result.n_pairs, 2);
            assert_close(result.vdw_energy, v02 + v12);
            assert_close(result.electrostatic_energy, e02 + e12);
        }
    }

    #[test]
    fn out_of_range_exclusions_are_ignored() {
        let (coords, charges, radii, eps) = triangle();
        let result = compute_mmff94_nonbonded_cpu(&coords, &charges, &radii, &eps, &[(7, 9)], "t");
        assert_eq!(result.n_pairs, 3);
    }

    #[test]
    fn pairs_outside_distance_window_are_skipped() {
        let (charges, radii, eps) = ([0.3, -0.3], [1.5, 1.7], [0.05, 0.06]);
        for separation in [0.05, 20.0, f64::NAN] {
            let coords = [0.0, 0.0, 0.0, separation, 0.0, 0.0];
            let result = compute_mmff94_nonbonded_cpu(&coords, &charges, &radii, &eps, &[], "t");
            assert_eq!(result.n_pairs, 0, "separation {separation}");
            assert_eq!(result.total_energy, 0.0);
        }
    }

    #[test]
    fn empty_system_has_no_pairs() {
        let result = compute_mmff94_nonbonded_cpu(&[], &[], &[], &[], &[], "empty");
        assert_eq!(result.n_pairs, 0);
        assert_eq!(result.total_energy, 0.0);
        assert!(!result.used_gpu);
        assert_eq!(result.backend, "CPU (empty)");
    }
}