1use super::context::{
4 bytes_to_f32_vec, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
5 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
6 ComputeDispatchDescriptor, GpuContext, UniformValue,
7};
8use crate::dispersion::{c8_from_c6, d4_coordination_number, dynamic_c6, D4Config, D4Result};
9
/// Scale factor handed to the shader as `ang_to_bohr`, where it multiplies coordinate
/// differences. Defined as the reciprocal of 0.529177 (going by the name, the length of
/// one bohr in Ångström, i.e. an Å → bohr conversion).
const ANG_TO_BOHR: f64 = 1.0_f64 / 0.529_177_f64;
11
12pub fn compute_d4_energy_gpu(
13 ctx: &GpuContext,
14 elements: &[u8],
15 positions: &[[f64; 3]],
16 config: &D4Config,
17) -> Result<D4Result, String> {
18 let n = elements.len();
19 if positions.len() != n {
20 return Err("elements/position length mismatch".to_string());
21 }
22 if n < 2 {
23 return Err("System too small for GPU dispatch".to_string());
24 }
25
26 let cn = d4_coordination_number(elements, positions);
27 let mut pair_params = vec![0.0f32; n * n * 4];
28 for i in 0..n {
29 for j in (i + 1)..n {
30 let c6 = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
31 let c8 = c8_from_c6(c6, elements[i], elements[j]);
32 let r0 = if c6 > 1e-10 { (c8 / c6).sqrt() } else { 5.0 };
33 let r_cut = config.a1 * r0 + config.a2;
34 let base = (i * n + j) * 4;
35 pair_params[base] = c6 as f32;
36 pair_params[base + 1] = c8 as f32;
37 pair_params[base + 2] = r_cut as f32;
38 pair_params[base + 3] = r_cut as f32;
39 }
40 }
41
42 let params_bytes = pack_uniform_values(&[
43 UniformValue::U32(n as u32),
44 UniformValue::U32(0),
45 UniformValue::U32(0),
46 UniformValue::U32(0),
47 UniformValue::F32(config.s6 as f32),
48 UniformValue::F32(config.s8 as f32),
49 UniformValue::F32(ANG_TO_BOHR as f32),
50 UniformValue::F32(0.0),
51 ]);
52
53 let descriptor = ComputeDispatchDescriptor {
54 label: "d4 dispersion".to_string(),
55 shader_source: D4_DISPERSION_SHADER.to_string(),
56 entry_point: "main".to_string(),
57 workgroup_count: [ceil_div_u32(n, 16), ceil_div_u32(n, 16), 1],
58 bindings: vec![
59 ComputeBindingDescriptor {
60 label: "positions".to_string(),
61 kind: ComputeBindingKind::StorageReadOnly,
62 bytes: pack_vec3_positions_f32(positions),
63 },
64 ComputeBindingDescriptor {
65 label: "pair_params".to_string(),
66 kind: ComputeBindingKind::StorageReadOnly,
67 bytes: f32_slice_to_bytes(&pair_params),
68 },
69 ComputeBindingDescriptor {
70 label: "params".to_string(),
71 kind: ComputeBindingKind::Uniform,
72 bytes: params_bytes,
73 },
74 ComputeBindingDescriptor {
75 label: "output".to_string(),
76 kind: ComputeBindingKind::StorageReadWrite,
77 bytes: f32_slice_to_bytes(&vec![0.0f32; n * n]),
78 },
79 ],
80 };
81
82 let mut outputs = ctx.run_compute(&descriptor)?.outputs;
83 let bytes = outputs.pop().ok_or("No output from D4 dispersion kernel")?;
84 let pair_energies = bytes_to_f32_vec(&bytes);
85 if pair_energies.len() != n * n {
86 return Err(format!(
87 "Output size mismatch: expected {}, got {}",
88 n * n,
89 pair_energies.len()
90 ));
91 }
92
93 let mut e2 = 0.0;
94 for i in 0..n {
95 for j in (i + 1)..n {
96 e2 += pair_energies[i * n + j] as f64;
97 }
98 }
99
100 let e3 = if config.three_body {
101 crate::dispersion::compute_d4_energy(elements, positions, config).e3_body
102 } else {
103 0.0
104 };
105 let total = e2 + e3;
106
107 Ok(D4Result {
108 e2_body: e2,
109 e3_body: e3,
110 total_energy: total,
111 total_kcal_mol: total * 627.509,
112 coordination_numbers: cn,
113 })
114}
115
/// WGSL compute kernel that evaluates one damped two-body dispersion energy per atom pair.
///
/// Bindings (group 0):
/// - 0: `positions` — read-only storage, one `AtomPos` (`x`, `y`, `z`, pad) per atom;
/// - 1: `pair_params` — read-only storage, four `f32` per `(i, j)` slot:
///   `c6`, `c8`, `r_cut6`, `r_cut8`;
/// - 2: `params` — uniform holding `n_atoms`, `s6`, `s8` and `ang_to_bohr`;
/// - 3: `output` — read-write storage, one `f32` energy per `(i, j)` slot.
///
/// Each invocation handles the pair `(i, j) = (gid.x, gid.y)` with 16 × 16 workgroups.
/// Slots with `j <= i`, with `c6 <= 1e-10`, or with a scaled separation below `1e-10`
/// are written as `0.0`; invocations outside the `n_atoms × n_atoms` grid do nothing.
///
/// The damped terms are evaluated as `C_n / (r^n + r_cut^n)`. This is algebraically equal
/// to `(C_n / r^n) * (r^n / (r^n + r_cut^n))` but avoids the `inf * 0` and `inf / inf`
/// intermediates (and the resulting NaN) that the factored form produces when `r^n`
/// underflows to zero or overflows in `f32`.
pub const D4_DISPERSION_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
    s6: f32,
    s8: f32,
    ang_to_bohr: f32,
    _pad3: f32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> pair_params: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let j = gid.y;
    let n = params.n_atoms;
    if (i >= n || j >= n) { return; }
    if (j <= i) {
        output[i * n + j] = 0.0;
        return;
    }

    let base = (i * n + j) * 4u;
    let c6 = pair_params[base];
    if (c6 <= 1e-10) {
        output[i * n + j] = 0.0;
        return;
    }
    let c8 = pair_params[base + 1u];
    let r_cut6 = pair_params[base + 2u];
    let r_cut8 = pair_params[base + 3u];

    let pi = positions[i];
    let pj = positions[j];
    let dx = (pi.x - pj.x) * params.ang_to_bohr;
    let dy = (pi.y - pj.y) * params.ang_to_bohr;
    let dz = (pi.z - pj.z) * params.ang_to_bohr;
    let r = sqrt(dx * dx + dy * dy + dz * dz);
    if (r < 1e-10) {
        output[i * n + j] = 0.0;
        return;
    }

    let r2 = r * r;
    let r6 = r2 * r2 * r2;
    let r8 = r6 * r2;

    // C_n / (r^n + r_cut^n): equal to (C_n / r^n) * r^n / (r^n + r_cut^n), but without
    // the inf * 0 or inf / inf intermediates when r^n underflows or overflows.
    let e6 = params.s6 * c6 / (r6 + pow(r_cut6, 6.0));
    let e8 = params.s8 * c8 / (r8 + pow(r_cut8, 8.0));
    output[i * n + j] = -e6 - e8;
}
"#;