Skip to main content

sci_form/gpu/
d4_dispersion_gpu.rs

1//! GPU-accelerated two-body D4 dispersion accumulation.
2
3use super::context::{
4    bytes_to_f32_vec, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
5    pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
6    ComputeDispatchDescriptor, GpuContext, UniformValue,
7};
8use crate::dispersion::{c8_from_c6, d4_coordination_number, dynamic_c6, D4Config, D4Result};
9
/// Ångström → Bohr conversion factor: the reciprocal of the Bohr radius
/// expressed in Ångström (0.529177 Å).
const ANG_TO_BOHR: f64 = 1.0 / 0.529_177;
11
12pub fn compute_d4_energy_gpu(
13    ctx: &GpuContext,
14    elements: &[u8],
15    positions: &[[f64; 3]],
16    config: &D4Config,
17) -> Result<D4Result, String> {
18    let n = elements.len();
19    if positions.len() != n {
20        return Err("elements/position length mismatch".to_string());
21    }
22    if n < 2 {
23        return Err("System too small for GPU dispatch".to_string());
24    }
25
26    let cn = d4_coordination_number(elements, positions);
27    let mut pair_params = vec![0.0f32; n * n * 4];
28    for i in 0..n {
29        for j in (i + 1)..n {
30            let c6 = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
31            let c8 = c8_from_c6(c6, elements[i], elements[j]);
32            let r0 = if c6 > 1e-10 { (c8 / c6).sqrt() } else { 5.0 };
33            let r_cut = config.a1 * r0 + config.a2;
34            let base = (i * n + j) * 4;
35            pair_params[base] = c6 as f32;
36            pair_params[base + 1] = c8 as f32;
37            pair_params[base + 2] = r_cut as f32;
38            pair_params[base + 3] = r_cut as f32;
39        }
40    }
41
42    let params_bytes = pack_uniform_values(&[
43        UniformValue::U32(n as u32),
44        UniformValue::U32(0),
45        UniformValue::U32(0),
46        UniformValue::U32(0),
47        UniformValue::F32(config.s6 as f32),
48        UniformValue::F32(config.s8 as f32),
49        UniformValue::F32(ANG_TO_BOHR as f32),
50        UniformValue::F32(0.0),
51    ]);
52
53    let descriptor = ComputeDispatchDescriptor {
54        label: "d4 dispersion".to_string(),
55        shader_source: D4_DISPERSION_SHADER.to_string(),
56        entry_point: "main".to_string(),
57        workgroup_count: [ceil_div_u32(n, 16), ceil_div_u32(n, 16), 1],
58        bindings: vec![
59            ComputeBindingDescriptor {
60                label: "positions".to_string(),
61                kind: ComputeBindingKind::StorageReadOnly,
62                bytes: pack_vec3_positions_f32(positions),
63            },
64            ComputeBindingDescriptor {
65                label: "pair_params".to_string(),
66                kind: ComputeBindingKind::StorageReadOnly,
67                bytes: f32_slice_to_bytes(&pair_params),
68            },
69            ComputeBindingDescriptor {
70                label: "params".to_string(),
71                kind: ComputeBindingKind::Uniform,
72                bytes: params_bytes,
73            },
74            ComputeBindingDescriptor {
75                label: "output".to_string(),
76                kind: ComputeBindingKind::StorageReadWrite,
77                bytes: f32_slice_to_bytes(&vec![0.0f32; n * n]),
78            },
79        ],
80    };
81
82    let mut outputs = ctx.run_compute(&descriptor)?.outputs;
83    let bytes = outputs.pop().ok_or("No output from D4 dispersion kernel")?;
84    let pair_energies = bytes_to_f32_vec(&bytes);
85    if pair_energies.len() != n * n {
86        return Err(format!(
87            "Output size mismatch: expected {}, got {}",
88            n * n,
89            pair_energies.len()
90        ));
91    }
92
93    let mut e2 = 0.0;
94    for i in 0..n {
95        for j in (i + 1)..n {
96            e2 += pair_energies[i * n + j] as f64;
97        }
98    }
99
100    let e3 = if config.three_body {
101        crate::dispersion::compute_d4_energy(elements, positions, config).e3_body
102    } else {
103        0.0
104    };
105    let total = e2 + e3;
106
107    Ok(D4Result {
108        e2_body: e2,
109        e3_body: e3,
110        total_energy: total,
111        total_kcal_mol: total * 627.509,
112        coordination_numbers: cn,
113    })
114}
115
/// WGSL kernel that evaluates one damped two-body dispersion energy per atom pair.
///
/// Bindings (group 0):
/// - `0` `positions`: one `AtomPos` (x, y, z, pad) per atom. The kernel treats
///   these as Ångström and converts to Bohr via `params.ang_to_bohr`. The stride
///   is presumably produced by `pack_vec3_positions_f32`; confirm it is 16 bytes.
/// - `1` `pair_params`: row-major `n x n` table, four `f32` per pair
///   `[C6, C8, R_damp6, R_damp8]`. Only entries with `j > i` are read.
/// - `2` `params`: uniform `Params` (atom count, `s6`, `s8`, Å→Bohr factor).
/// - `3` `output`: `n x n` pair energies. Entries with `j <= i`, negligible C6,
///   or coincident atoms are written as `0.0`.
///
/// Thread `(i, j)` is `global_invocation_id.xy`, dispatched as 16 x 16 workgroups.
///
/// The pair energy uses rational damping,
/// `-s6*C6/(r^6 + R6^6) - s8*C8/(r^8 + R8^8)`.
///
/// It is written directly as a quotient rather than as
/// `C/r^n * r^n/(r^n + R^n)`. In `f32` that product form yields NaN when `r^n`
/// overflows (`inf/inf`, r ≳ 6.5e4 bohr for n = 8) or underflows (`inf*0`, for
/// near-coincident atoms). The direct form stays finite in both cases.
///
/// Powers are built by multiplication instead of `pow`. WGSL leaves `pow`
/// undefined for negative bases and derives its accuracy from `exp2(y*log2(x))`.
pub const D4_DISPERSION_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
    s6: f32,
    s8: f32,
    ang_to_bohr: f32,
    _pad3: f32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> pair_params: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let j = gid.y;
    let n = params.n_atoms;
    if (i >= n || j >= n) { return; }
    let idx = i * n + j;
    // Only the strict upper triangle carries pair energies.
    if (j <= i) {
        output[idx] = 0.0;
        return;
    }

    let base = idx * 4u;
    let c6 = pair_params[base];
    if (c6 <= 1e-10) {
        output[idx] = 0.0;
        return;
    }
    let c8 = pair_params[base + 1u];
    let r_cut6 = pair_params[base + 2u];
    let r_cut8 = pair_params[base + 3u];

    let pi = positions[i];
    let pj = positions[j];
    let dx = (pi.x - pj.x) * params.ang_to_bohr;
    let dy = (pi.y - pj.y) * params.ang_to_bohr;
    let dz = (pi.z - pj.z) * params.ang_to_bohr;
    let r2 = dx * dx + dy * dy + dz * dz;
    // Coincident atoms (r < 1e-10 bohr) contribute nothing.
    if (r2 < 1e-20) {
        output[idx] = 0.0;
        return;
    }

    let r6 = r2 * r2 * r2;
    let r8 = r6 * r2;
    let rc6_sq = r_cut6 * r_cut6;
    let rc6 = rc6_sq * rc6_sq * rc6_sq;
    let rc8_sq = r_cut8 * r_cut8;
    let rc8_4 = rc8_sq * rc8_sq;
    let rc8 = rc8_4 * rc8_4;

    // Algebraically equal to c/r^n * r^n/(r^n + rc^n), but finite when r^n
    // overflows or underflows in f32.
    output[idx] = -params.s6 * c6 / (r6 + rc6) - params.s8 * c8 / (r8 + rc8);
}
"#;