1use nalgebra::{DMatrix, DVector};
4
5use super::context::{
6 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
7 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
8 ComputeDispatchDescriptor, GpuContext, UniformValue,
9};
10use crate::charges_eeq::{fractional_coordination, get_eeq_params, EeqChargeResult, EeqConfig};
11
/// Minimum number of atoms accepted by the GPU entry points in this module.
///
/// Inputs with fewer atoms are rejected with an error by both
/// `compute_eeq_charges_gpu` and `build_eeq_coulomb_gpu`.
const GPU_DISPATCH_THRESHOLD: usize = 2;
13
14pub fn compute_eeq_charges_gpu(
15 ctx: &GpuContext,
16 elements: &[u8],
17 positions: &[[f64; 3]],
18 config: &EeqConfig,
19) -> Result<EeqChargeResult, String> {
20 let n = elements.len();
21 if positions.len() != n {
22 return Err("elements/position length mismatch".to_string());
23 }
24 if n < GPU_DISPATCH_THRESHOLD {
25 return Err("System too small for GPU dispatch".to_string());
26 }
27
28 let params: Vec<_> = elements.iter().map(|&z| get_eeq_params(z)).collect();
29 let cn = fractional_coordination(elements, positions);
30 let gamma = build_eeq_coulomb_gpu(
31 ctx,
32 ¶ms.iter().map(|p| p.r_eeq).collect::<Vec<_>>(),
33 positions,
34 )?;
35
36 let dim = n + 1;
37 let mut a = DMatrix::zeros(dim, dim);
38 let mut b_vec = vec![0.0; dim];
39
40 for i in 0..n {
41 a[(i, i)] = params[i].eta + config.regularization;
42
43 for j in (i + 1)..n {
44 let gij = gamma[(i, j)];
45 a[(i, j)] = gij;
46 a[(j, i)] = gij;
47 }
48
49 a[(i, n)] = 1.0;
50 a[(n, i)] = 1.0;
51
52 let cn_correction = -0.1 * (cn[i] - 2.0);
53 b_vec[i] = -(params[i].chi + cn_correction);
54 }
55
56 b_vec[n] = config.total_charge;
57
58 let solution = a.lu().solve(&DVector::from_vec(b_vec));
59 let charges = match solution {
60 Some(sol) => (0..n).map(|i| sol[i]).collect(),
61 None => vec![0.0; n],
62 };
63 let total_charge = charges.iter().sum();
64
65 Ok(EeqChargeResult {
66 charges,
67 coordination_numbers: cn,
68 total_charge,
69 })
70}
71
72pub fn build_eeq_coulomb_gpu(
73 ctx: &GpuContext,
74 radii: &[f64],
75 positions: &[[f64; 3]],
76) -> Result<DMatrix<f64>, String> {
77 let n = radii.len();
78 if positions.len() != n {
79 return Err("radii/position length mismatch".to_string());
80 }
81 if n < GPU_DISPATCH_THRESHOLD {
82 return Err("Matrix too small for GPU dispatch".to_string());
83 }
84
85 let descriptor = ComputeDispatchDescriptor {
86 label: "eeq coulomb".to_string(),
87 shader_source: EEQ_COULOMB_SHADER.to_string(),
88 entry_point: "main".to_string(),
89 workgroup_count: [ceil_div_u32(n, 16), ceil_div_u32(n, 16), 1],
90 bindings: vec![
91 ComputeBindingDescriptor {
92 label: "positions".to_string(),
93 kind: ComputeBindingKind::StorageReadOnly,
94 bytes: pack_vec3_positions_f32(positions),
95 },
96 ComputeBindingDescriptor {
97 label: "radii".to_string(),
98 kind: ComputeBindingKind::StorageReadOnly,
99 bytes: f32_slice_to_bytes(
100 &radii.iter().map(|value| *value as f32).collect::<Vec<_>>(),
101 ),
102 },
103 ComputeBindingDescriptor {
104 label: "params".to_string(),
105 kind: ComputeBindingKind::Uniform,
106 bytes: pack_uniform_values(&[
107 UniformValue::U32(n as u32),
108 UniformValue::U32(0),
109 UniformValue::U32(0),
110 UniformValue::U32(0),
111 ]),
112 },
113 ComputeBindingDescriptor {
114 label: "output".to_string(),
115 kind: ComputeBindingKind::StorageReadWrite,
116 bytes: f32_slice_to_bytes(&vec![0.0f32; n * n]),
117 },
118 ],
119 };
120
121 let mut outputs = ctx.run_compute(&descriptor)?.outputs;
122 let bytes = outputs.pop().ok_or("No output from EEQ Coulomb kernel")?;
123 if bytes.len() != n * n * 4 {
124 return Err(format!(
125 "EEQ Coulomb output size mismatch: expected {}, got {}",
126 n * n * 4,
127 bytes.len()
128 ));
129 }
130
131 Ok(DMatrix::from_row_slice(
132 n,
133 n,
134 &bytes_to_f64_vec_from_f32(&bytes),
135 ))
136}
137
/// WGSL compute shader that fills the EEQ Coulomb interaction matrix.
///
/// Bindings (all in group 0):
///
/// * binding 0: `positions`, read-only storage array of `AtomPos`
///   (`x`, `y`, `z` plus one padding `f32`).
/// * binding 1: `radii`, read-only storage array of `f32`.
/// * binding 2: `params`, uniform holding `n_atoms` and three padding words.
/// * binding 3: `output`, read-write storage array of `f32` with `n_atoms * n_atoms`
///   entries, indexed as `i * n_atoms + j`.
///
/// Each invocation handles one `(i, j)` pair, with `i` taken from
/// `global_invocation_id.x` and `j` from `.y`, in workgroups of `16 x 16`.
/// Out-of-range invocations return without writing. Diagonal entries and pairs
/// closer than `1e-10` are written as `0.0`; every other entry is
/// `erf_approx(sqrt(2) * r / sqrt(ri^2 + rj^2)) / r`.
///
/// `erf_approx` uses the five-term rational approximation with the constants of
/// Abramowitz & Stegun formula 7.1.26.
pub const EEQ_COULOMB_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> radii: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

fn erf_approx(x: f32) -> f32 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = select(-1.0, 1.0, x >= 0.0);
    let ax = abs(x);
    let t = 1.0 / (1.0 + p * ax);
    let y = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * exp(-ax * ax);
    return sign * y;
}

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let j = gid.y;
    let n = params.n_atoms;
    if (i >= n || j >= n) {
        return;
    }
    if (i == j) {
        output[i * n + j] = 0.0;
        return;
    }

    let pi = positions[i];
    let pj = positions[j];
    let dx = pi.x - pj.x;
    let dy = pi.y - pj.y;
    let dz = pi.z - pj.z;
    let r = sqrt(dx * dx + dy * dy + dz * dz);
    if (r < 1e-10) {
        output[i * n + j] = 0.0;
        return;
    }

    let ri = radii[i];
    let rj = radii[j];
    let sigma = sqrt(ri * ri + rj * rj);
    let arg = 1.41421356237 * r / sigma;
    output[i * n + j] = erf_approx(arg) / r;
}
"#;
201
#[cfg(test)]
mod tests {
    use super::*;

    /// A single atom is below `GPU_DISPATCH_THRESHOLD` and must be rejected
    /// before any GPU work is attempted.
    #[test]
    fn test_eeq_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let result =
            compute_eeq_charges_gpu(&ctx, &[8], &[[0.0, 0.0, 0.0]], &EeqConfig::default());
        assert_eq!(
            result.err().as_deref(),
            Some("System too small for GPU dispatch")
        );
    }

    /// Mismatched element/position counts are reported before the size check.
    #[test]
    fn test_eeq_gpu_length_mismatch() {
        let ctx = GpuContext::cpu_fallback();
        let result =
            compute_eeq_charges_gpu(&ctx, &[8, 1], &[[0.0, 0.0, 0.0]], &EeqConfig::default());
        assert_eq!(
            result.err().as_deref(),
            Some("elements/position length mismatch")
        );
    }
}