Skip to main content

sci_form/gpu/
mmff94_gpu.rs

1//! GPU-accelerated MMFF94 force field evaluation.
2//!
3//! Computes MMFF94 non-bonded interactions (van der Waals + electrostatic)
4//! on the GPU. The O(N²) pairwise terms are the bottleneck for large
5//! molecules and parallelize naturally on GPU hardware.
6//!
7//! The bonded terms (bond stretching, angle bending, torsion) remain on CPU
8//! as they are O(N) and memory-access-heavy.
9
10#[cfg(feature = "experimental-gpu")]
11use crate::gpu::context::{
12    ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
13};
14
15use serde::{Deserialize, Serialize};
16
17/// Result of GPU MMFF94 evaluation.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Mmff94GpuResult {
20    /// Total energy (kcal/mol).
21    pub total_energy: f64,
22    /// Van der Waals energy (kcal/mol).
23    pub vdw_energy: f64,
24    /// Electrostatic energy (kcal/mol).
25    pub electrostatic_energy: f64,
26    /// Number of atom pairs evaluated.
27    pub n_pairs: usize,
28    /// Whether GPU was actually used.
29    pub used_gpu: bool,
30    /// Backend description.
31    pub backend: String,
32}
33
/// Minimum atom count for which the GPU path is attempted; smaller systems
/// are evaluated by the CPU fallback instead.
// Only the feature-gated GPU entry point reads this, so it is dead code
// exactly when `experimental-gpu` is disabled — scope the allow to that case.
#[cfg_attr(not(feature = "experimental-gpu"), allow(dead_code))]
const GPU_DISPATCH_THRESHOLD: usize = 50;
37
38/// Compute MMFF94 non-bonded energy on GPU.
39///
40/// Packs atom coordinates, charges, VdW parameters, and 1-4 exclusion
41/// data into GPU buffers. A single WGSL kernel evaluates all N(N-1)/2
42/// pair interactions in parallel.
43#[cfg(feature = "experimental-gpu")]
44pub fn compute_mmff94_nonbonded_gpu(
45    ctx: &GpuContext,
46    coords: &[f64],
47    charges: &[f64],
48    vdw_radii: &[f64],
49    vdw_epsilon: &[f64],
50    exclusions_14: &[(usize, usize)],
51) -> Result<Mmff94GpuResult, String> {
52    let n_atoms = charges.len();
53
54    if n_atoms < GPU_DISPATCH_THRESHOLD {
55        return Ok(compute_mmff94_nonbonded_cpu(
56            coords,
57            charges,
58            vdw_radii,
59            vdw_epsilon,
60            exclusions_14,
61            "Atom count below GPU threshold",
62        ));
63    }
64
65    // Guard against excessive memory usage for exclusion bitmap (N²/8 bytes)
66    let excl_bitmap_bytes = (n_atoms * n_atoms).div_ceil(8);
67    let atom_buffer_bytes = n_atoms * 32;
68    let total_gpu_memory = excl_bitmap_bytes + atom_buffer_bytes;
69    const MAX_GPU_BUFFER: usize = 512 * 1024 * 1024; // 512 MB
70    if total_gpu_memory > MAX_GPU_BUFFER {
71        return Ok(compute_mmff94_nonbonded_cpu(
72            coords,
73            charges,
74            vdw_radii,
75            vdw_epsilon,
76            exclusions_14,
77            "System too large for GPU buffers",
78        ));
79    }
80
81    // Pack atom data: [x, y, z, charge, r_vdw, eps_vdw, pad, pad] per atom (32 bytes)
82    let mut atom_bytes = Vec::with_capacity(n_atoms * 32);
83    for i in 0..n_atoms {
84        atom_bytes.extend_from_slice(&(coords[i * 3] as f32).to_ne_bytes());
85        atom_bytes.extend_from_slice(&(coords[i * 3 + 1] as f32).to_ne_bytes());
86        atom_bytes.extend_from_slice(&(coords[i * 3 + 2] as f32).to_ne_bytes());
87        atom_bytes.extend_from_slice(&(charges[i] as f32).to_ne_bytes());
88        atom_bytes.extend_from_slice(&(vdw_radii[i] as f32).to_ne_bytes());
89        atom_bytes.extend_from_slice(&(vdw_epsilon[i] as f32).to_ne_bytes());
90        atom_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
91        atom_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
92    }
93
94    // Pack exclusions as a bitmap for fast lookup
95    let excl_size = (n_atoms * n_atoms).div_ceil(32);
96    let mut excl_bits = vec![0u32; excl_size];
97    for &(i, j) in exclusions_14 {
98        let bit_idx = i * n_atoms + j;
99        excl_bits[bit_idx / 32] |= 1 << (bit_idx % 32);
100        let bit_idx2 = j * n_atoms + i;
101        excl_bits[bit_idx2 / 32] |= 1 << (bit_idx2 % 32);
102    }
103    let excl_bytes: Vec<u8> = excl_bits.iter().flat_map(|b| b.to_ne_bytes()).collect();
104
105    // Params: n_atoms, pad, pad, pad
106    let mut params = Vec::with_capacity(16);
107    params.extend_from_slice(&(n_atoms as u32).to_ne_bytes());
108    params.extend_from_slice(&0u32.to_ne_bytes());
109    params.extend_from_slice(&0u32.to_ne_bytes());
110    params.extend_from_slice(&0u32.to_ne_bytes());
111
112    // Output: per-pair energies (vdw + elec), reduced on GPU
113    let n_output = 2; // [total_vdw, total_elec]
114    let output_bytes = vec![0u8; n_output * 4];
115
116    let wg_size = 64u32;
117    let n_pairs = n_atoms * (n_atoms - 1) / 2;
118    let wg_count = (n_pairs as u32).div_ceil(wg_size);
119
120    let descriptor = ComputeDispatchDescriptor {
121        label: "MMFF94 non-bonded".to_string(),
122        shader_source: MMFF94_NB_SHADER.to_string(),
123        entry_point: "main".to_string(),
124        workgroup_count: [wg_count, 1, 1],
125        bindings: vec![
126            ComputeBindingDescriptor {
127                label: "atoms".to_string(),
128                kind: ComputeBindingKind::StorageReadOnly,
129                bytes: atom_bytes,
130            },
131            ComputeBindingDescriptor {
132                label: "exclusions".to_string(),
133                kind: ComputeBindingKind::StorageReadOnly,
134                bytes: excl_bytes,
135            },
136            ComputeBindingDescriptor {
137                label: "params".to_string(),
138                kind: ComputeBindingKind::Uniform,
139                bytes: params,
140            },
141            ComputeBindingDescriptor {
142                label: "output".to_string(),
143                kind: ComputeBindingKind::StorageReadWrite,
144                bytes: output_bytes,
145            },
146        ],
147    };
148
149    let mut result = ctx.run_compute(&descriptor)?;
150    let bytes = result.outputs.pop().ok_or("No output from MMFF94 kernel")?;
151
152    if bytes.len() < 8 {
153        return Err("Insufficient output from MMFF94 kernel".to_string());
154    }
155
156    let vdw = f32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as f64;
157    let elec = f32::from_ne_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as f64;
158
159    Ok(Mmff94GpuResult {
160        total_energy: vdw + elec,
161        vdw_energy: vdw,
162        electrostatic_energy: elec,
163        n_pairs,
164        used_gpu: true,
165        backend: ctx.capabilities.backend.clone(),
166    })
167}
168
169/// CPU fallback for MMFF94 non-bonded energy.
170pub fn compute_mmff94_nonbonded_cpu(
171    coords: &[f64],
172    charges: &[f64],
173    vdw_radii: &[f64],
174    vdw_epsilon: &[f64],
175    exclusions_14: &[(usize, usize)],
176    note: &str,
177) -> Mmff94GpuResult {
178    let n_atoms = charges.len();
179    let mut vdw_energy = 0.0;
180    let mut elec_energy = 0.0;
181    let mut n_pairs = 0;
182
183    // Build exclusion set
184    let mut excl_set = std::collections::HashSet::new();
185    for &(i, j) in exclusions_14 {
186        excl_set.insert((i.min(j), i.max(j)));
187    }
188
189    for i in 0..n_atoms {
190        let xi = coords[i * 3];
191        let yi = coords[i * 3 + 1];
192        let zi = coords[i * 3 + 2];
193
194        for j in (i + 1)..n_atoms {
195            if excl_set.contains(&(i, j)) {
196                continue;
197            }
198
199            let dx = xi - coords[j * 3];
200            let dy = yi - coords[j * 3 + 1];
201            let dz = zi - coords[j * 3 + 2];
202            let r2 = dx * dx + dy * dy + dz * dz;
203            let r = r2.sqrt();
204
205            if !(0.1..=15.0).contains(&r) {
206                continue;
207            }
208
209            // MMFF94 Buffered 14-7 van der Waals
210            let r_star = vdw_radii[i] + vdw_radii[j];
211            let eps = (vdw_epsilon[i] * vdw_epsilon[j]).sqrt();
212            let rho = r / r_star;
213            let rho7 = rho.powi(7);
214            let e_vdw = eps * (1.07 / (rho + 0.07)).powi(7) * ((1.12 / (rho7 + 0.12)) - 2.0);
215            vdw_energy += e_vdw;
216
217            // Coulomb with dielectric screening
218            let e_elec = 332.0716 * charges[i] * charges[j] / (r + 0.05);
219            elec_energy += e_elec;
220
221            n_pairs += 1;
222        }
223    }
224
225    Mmff94GpuResult {
226        total_energy: vdw_energy + elec_energy,
227        vdw_energy,
228        electrostatic_energy: elec_energy,
229        n_pairs,
230        used_gpu: false,
231        backend: format!("CPU ({})", note),
232    }
233}
234
235/// WGSL compute shader for MMFF94 non-bonded interactions.
236///
237/// Each workgroup processes a batch of atom pairs, computing buffered-14-7
238/// VdW and Coulomb electrostatics. Exclusions are checked via bitmap.
239#[cfg(feature = "experimental-gpu")]
240pub const MMFF94_NB_SHADER: &str = r#"
241struct Atom {
242    x: f32, y: f32, z: f32,
243    charge: f32,
244    r_vdw: f32, eps_vdw: f32,
245    _pad0: f32, _pad1: f32,
246};
247
248struct Params {
249    n_atoms: u32,
250    _pad0: u32, _pad1: u32, _pad2: u32,
251};
252
253@group(0) @binding(0) var<storage, read> atoms: array<Atom>;
254@group(0) @binding(1) var<storage, read> exclusions: array<u32>;
255@group(0) @binding(2) var<uniform> params: Params;
256@group(0) @binding(3) var<storage, read_write> output: array<atomic<u32>>;
257
258fn pair_to_ij(pair_idx: u32, n: u32) -> vec2<u32> {
259    // Convert linear pair index to (i, j) with i < j
260    var i: u32 = 0u;
261    var remaining: u32 = pair_idx;
262    loop {
263        let row_size = n - 1u - i;
264        if (remaining < row_size) {
265            break;
266        }
267        remaining -= row_size;
268        i += 1u;
269    }
270    let j = i + 1u + remaining;
271    return vec2<u32>(i, j);
272}
273
274@compute @workgroup_size(64, 1, 1)
275fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
276    let n = params.n_atoms;
277    let n_pairs = n * (n - 1u) / 2u;
278    let pair_idx = gid.x;
279
280    if (pair_idx >= n_pairs) {
281        return;
282    }
283
284    let ij = pair_to_ij(pair_idx, n);
285    let i = ij.x;
286    let j = ij.y;
287
288    // Check exclusion bitmap
289    let bit_idx = i * n + j;
290    let word = exclusions[bit_idx / 32u];
291    if ((word >> (bit_idx % 32u)) & 1u) != 0u {
292        return;
293    }
294
295    let a = atoms[i];
296    let b = atoms[j];
297
298    let dx = a.x - b.x;
299    let dy = a.y - b.y;
300    let dz = a.z - b.z;
301    let r = sqrt(dx * dx + dy * dy + dz * dz);
302
303    if (r < 0.1 || r > 15.0) {
304        return;
305    }
306
307    // Buffered 14-7 VdW
308    let r_star = a.r_vdw + b.r_vdw;
309    let eps = sqrt(a.eps_vdw * b.eps_vdw);
310    let rho = r / r_star;
311    let rho7 = pow(rho, 7.0);
312    let e_vdw = eps * pow(1.07 / (rho + 0.07), 7.0) * ((1.12 / (rho7 + 0.12)) - 2.0);
313
314    // Coulomb with distance-dependent dielectric
315    let e_elec = 332.0716 * a.charge * b.charge / (r + 0.05);
316
317    // Atomic add to output (using integer representation)
318    let vdw_bits = bitcast<u32>(e_vdw);
319    let elec_bits = bitcast<u32>(e_elec);
320    atomicAdd(&output[0], vdw_bits);
321    atomicAdd(&output[1], elec_bits);
322}
323"#;
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Flat coordinates for two atoms `separation` Å apart on the x axis.
    fn two_atoms(separation: f64) -> Vec<f64> {
        vec![0.0, 0.0, 0.0, separation, 0.0, 0.0]
    }

    #[test]
    fn test_mmff94_cpu_fallback() {
        // Simple 2-atom system
        let coords = two_atoms(3.0);
        let charges = vec![0.3, -0.3];
        let vdw_radii = vec![1.5, 1.7];
        let vdw_epsilon = vec![0.05, 0.06];

        let result =
            compute_mmff94_nonbonded_cpu(&coords, &charges, &vdw_radii, &vdw_epsilon, &[], "test");
        assert_eq!(result.n_pairs, 1);
        assert!(result.total_energy.is_finite());
        assert!(result.electrostatic_energy < 0.0); // opposite charges attract
        assert!(!result.used_gpu);
        assert_eq!(result.backend, "CPU (test)");
        assert_eq!(
            result.total_energy,
            result.vdw_energy + result.electrostatic_energy
        );
    }

    #[test]
    fn test_mmff94_cpu_coulomb_value() {
        let result = compute_mmff94_nonbonded_cpu(
            &two_atoms(3.0),
            &[0.3, -0.3],
            &[1.5, 1.7],
            &[0.05, 0.06],
            &[],
            "test",
        );
        let expected = 332.0716 * 0.3 * -0.3 / 3.05;
        assert!((result.electrostatic_energy - expected).abs() < 1e-9);
    }

    #[test]
    fn test_mmff94_cpu_vdw_minimum() {
        // At r == R* the buffered 14-7 term equals -eps exactly.
        let result = compute_mmff94_nonbonded_cpu(
            &two_atoms(3.0),
            &[0.0, 0.0],
            &[1.5, 1.5],
            &[0.04, 0.04],
            &[],
            "test",
        );
        assert!((result.vdw_energy + 0.04).abs() < 1e-12);
        assert_eq!(result.electrostatic_energy, 0.0);
    }

    #[test]
    fn test_mmff94_cpu_exclusion_either_order() {
        for excl in [(0, 1), (1, 0)] {
            let result = compute_mmff94_nonbonded_cpu(
                &two_atoms(3.0),
                &[0.3, -0.3],
                &[1.5, 1.7],
                &[0.05, 0.06],
                &[excl],
                "test",
            );
            assert_eq!(result.n_pairs, 0);
            assert_eq!(result.total_energy, 0.0);
        }
    }

    #[test]
    fn test_mmff94_cpu_out_of_range_exclusion_ignored() {
        let result = compute_mmff94_nonbonded_cpu(
            &two_atoms(3.0),
            &[0.3, -0.3],
            &[1.5, 1.7],
            &[0.05, 0.06],
            &[(0, 5)],
            "test",
        );
        assert_eq!(result.n_pairs, 1);
    }

    #[test]
    fn test_mmff94_cpu_distance_window() {
        for separation in [0.05, 20.0] {
            let result = compute_mmff94_nonbonded_cpu(
                &two_atoms(separation),
                &[0.3, -0.3],
                &[1.5, 1.7],
                &[0.05, 0.06],
                &[],
                "test",
            );
            assert_eq!(result.n_pairs, 0);
            assert_eq!(result.total_energy, 0.0);
        }
    }
}