1use super::context::{
11 ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
12};
13use crate::scf::basis::BasisSet;
14use crate::scf::two_electron::TwoElectronIntegrals;
15
/// Number of primitive slots reserved for every basis function in the packed GPU buffers.
///
/// The WGSL kernel hard-codes this stride as `3u` when indexing `primitives`, so the two
/// values must always be changed together.
const MAX_PRIMITIVES: usize = 3;

/// Smallest basis size (number of functions) that `compute_eris_gpu` will dispatch to the
/// GPU; smaller basis sets are rejected with an `Err`.
const GPU_DISPATCH_THRESHOLD: usize = 4;
21
22fn pack_basis_eri(basis: &BasisSet) -> (Vec<u8>, Vec<u8>) {
30 let mut basis_bytes = Vec::with_capacity(basis.n_basis * 32);
31 let mut prim_bytes = Vec::with_capacity(basis.n_basis * 24);
32
33 for bf in &basis.functions {
34 basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
36 basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
37 basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
38 basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
40 basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
41 basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
42 basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
44 basis_bytes.extend_from_slice(&0u32.to_ne_bytes());
45
46 for i in 0..MAX_PRIMITIVES {
48 if i < bf.primitives.len() {
49 let norm = crate::scf::basis::BasisFunction::normalization(
50 bf.primitives[i].alpha,
51 bf.angular[0],
52 bf.angular[1],
53 bf.angular[2],
54 );
55 prim_bytes.extend_from_slice(&(bf.primitives[i].alpha as f32).to_ne_bytes());
56 prim_bytes.extend_from_slice(
57 &((bf.primitives[i].coefficient * norm) as f32).to_ne_bytes(),
58 );
59 } else {
60 prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
61 prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
62 }
63 }
64 }
65
66 (basis_bytes, prim_bytes)
67}
68
/// Enumerates the canonical index quartets `(i, j, k, l)` of the two-electron tensor.
///
/// A quartet is canonical when all of the following hold:
/// - `i >= j`;
/// - `k >= l`;
/// - the pair `(k, l)` does not come after `(i, j)` in lexicographic order.
///
/// That selects one representative from each group of up to eight index permutations
/// related by `(ij|kl)` symmetry. With `P = n(n + 1) / 2` unique pairs there are
/// `P(P + 1) / 2` canonical quartets. They are produced in lexicographic order of
/// `(i, j, k, l)`.
///
/// Returns the quartet count and the quartets packed as native-endian `u32` 4-tuples
/// (16 bytes each), the layout of the shader's `array<vec4<u32>>`.
fn enumerate_unique_quartets(n: usize) -> (usize, Vec<u8>) {
    let n_pairs = n * (n + 1) / 2;
    let count = n_pairs * (n_pairs + 1) / 2;
    let mut quartets = Vec::with_capacity(count * 16);

    for i in 0..n {
        for j in 0..=i {
            // Visit only the (k, l) pairs with k >= l that do not exceed (i, j):
            // every pair for k < i, and l <= j on the diagonal row k == i.
            for k in 0..=i {
                let l_max = if k == i { j } else { k };
                for l in 0..=l_max {
                    // Indices are < n, which callers keep well inside u32.
                    for index in [i, j, k, l] {
                        quartets.extend_from_slice(&(index as u32).to_ne_bytes());
                    }
                }
            }
        }
    }

    debug_assert_eq!(quartets.len(), count * 16);
    (count, quartets)
}
97
98pub fn compute_eris_gpu(
104 ctx: &GpuContext,
105 basis: &BasisSet,
106) -> Result<TwoElectronIntegrals, String> {
107 let n = basis.n_basis;
108
109 if n < GPU_DISPATCH_THRESHOLD {
110 return Err("Basis too small for GPU dispatch".to_string());
111 }
112
113 let output_size = (n * n * n * n * 4) as u64;
115 if output_size > ctx.capabilities.max_storage_buffer_size {
116 return Err(format!(
117 "ERI tensor ({} bytes) exceeds GPU storage limit ({} bytes)",
118 output_size, ctx.capabilities.max_storage_buffer_size
119 ));
120 }
121
122 let (basis_bytes, prim_bytes) = pack_basis_eri(basis);
123 let (n_quartets, quartet_bytes) = enumerate_unique_quartets(n);
124
125 let mut params = Vec::with_capacity(16);
127 params.extend_from_slice(&(n as u32).to_ne_bytes());
128 params.extend_from_slice(&(n_quartets as u32).to_ne_bytes());
129 params.extend_from_slice(&0u32.to_ne_bytes());
130 params.extend_from_slice(&0u32.to_ne_bytes());
131
132 let output_seed = vec![0u8; n * n * n * n * 4];
134
135 let workgroup_count = [(n_quartets as u32).div_ceil(64), 1, 1];
136
137 let descriptor = ComputeDispatchDescriptor {
138 label: "two-electron ERI".to_string(),
139 shader_source: TWO_ELECTRON_SHADER.to_string(),
140 entry_point: "main".to_string(),
141 workgroup_count,
142 bindings: vec![
143 ComputeBindingDescriptor {
144 label: "basis".to_string(),
145 kind: ComputeBindingKind::StorageReadOnly,
146 bytes: basis_bytes,
147 },
148 ComputeBindingDescriptor {
149 label: "primitives".to_string(),
150 kind: ComputeBindingKind::StorageReadOnly,
151 bytes: prim_bytes,
152 },
153 ComputeBindingDescriptor {
154 label: "quartets".to_string(),
155 kind: ComputeBindingKind::StorageReadOnly,
156 bytes: quartet_bytes,
157 },
158 ComputeBindingDescriptor {
159 label: "params".to_string(),
160 kind: ComputeBindingKind::Uniform,
161 bytes: params,
162 },
163 ComputeBindingDescriptor {
164 label: "output".to_string(),
165 kind: ComputeBindingKind::StorageReadWrite,
166 bytes: output_seed,
167 },
168 ],
169 };
170
171 let mut result = ctx.run_compute(&descriptor)?;
172 let bytes = result.outputs.pop().ok_or("No output from ERI kernel")?;
173
174 if bytes.len() != n * n * n * n * 4 {
176 return Err(format!(
177 "ERI output size mismatch: expected {}, got {}",
178 n * n * n * n * 4,
179 bytes.len()
180 ));
181 }
182
183 let data: Vec<f64> = bytes
184 .chunks_exact(4)
185 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]) as f64)
186 .collect();
187
188 Ok(TwoElectronIntegrals::from_raw(data, n))
189}
190
/// WGSL compute kernel that evaluates contracted `(ss|ss)` two-electron integrals.
///
/// Each invocation handles one quartet from the `quartets` buffer. It writes the result to
/// all eight symmetry-equivalent positions of the row-major `n⁴` `output` tensor.
///
/// Binding layout:
/// - 0: basis records;
/// - 1: `(alpha, norm * coeff)` primitive pairs, 3 slots per function, which must equal
///   `MAX_PRIMITIVES`;
/// - 2: quartets;
/// - 3: `Params` uniform;
/// - 4: output.
///
/// `lx/ly/lz` are carried in `BasisFunc` but not used. Integrals involving functions with
/// non-zero angular momentum are therefore not correct.
pub const TWO_ELECTRON_SHADER: &str = r#"
// (ss|ss) electron-repulsion integrals over contracted Gaussians.
//
// Every basis function owns exactly 3 consecutive entries in `primitives`; this stride
// must equal MAX_PRIMITIVES on the host. Unused entries are zero-filled and skipped via
// n_prims. lx/ly/lz are carried in BasisFunc but not used by this kernel.

struct BasisFunc {
    cx: f32, cy: f32, cz: f32,
    lx: u32, ly: u32, lz: u32,
    n_prims: u32, _pad: u32,
};

struct Params {
    n_basis: u32,
    n_quartets: u32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> primitives: array<vec2<f32>>; // (alpha, norm*coeff)
@group(0) @binding(2) var<storage, read> quartets: array<vec4<u32>>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

// Boys function F0(x).
//
// Small x: series exp(-x) * sum_k (2x)^k / (1*3*...*(2k+1)), capped at 50 terms.
// x > 15: asymptotic sqrt(pi) / (2 sqrt(x)). Its relative error, erfc(sqrt(x)), is below
// 5e-8 there, which is under f32 resolution. The capped series instead loses accuracy as
// x grows, with relative error up to ~1e-3 near x = 30.
fn boys_f0(x: f32) -> f32 {
    if (x < 1e-7) {
        return 1.0;
    }
    if (x > 15.0) {
        return 0.8862269 / sqrt(x); // sqrt(pi) / 2 = 0.8862269
    }
    var sum: f32 = 1.0;
    var term: f32 = 1.0;
    for (var k: u32 = 1u; k < 50u; k = k + 1u) {
        term *= 2.0 * x / f32(2u * k + 1u);
        sum += term;
        if (abs(term) < 1e-6 * abs(sum)) {
            break;
        }
    }
    return exp(-x) * sum;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.n_quartets) {
        return;
    }

    let q = quartets[idx];
    let i = q.x;
    let j = q.y;
    let k = q.z;
    let l = q.w;

    let bf_i = basis[i];
    let bf_j = basis[j];
    let bf_k = basis[k];
    let bf_l = basis[l];

    // Squared centre separations depend only on the basis functions, not on the
    // primitives, so they are computed once per invocation.
    let ab2 = (bf_i.cx - bf_j.cx) * (bf_i.cx - bf_j.cx)
        + (bf_i.cy - bf_j.cy) * (bf_i.cy - bf_j.cy)
        + (bf_i.cz - bf_j.cz) * (bf_i.cz - bf_j.cz);
    let cd2 = (bf_k.cx - bf_l.cx) * (bf_k.cx - bf_l.cx)
        + (bf_k.cy - bf_l.cy) * (bf_k.cy - bf_l.cy)
        + (bf_k.cz - bf_l.cz) * (bf_k.cz - bf_l.cz);

    var eri: f32 = 0.0;

    // Contract over primitives (at most 3 per function, stride 3).
    for (var pi: u32 = 0u; pi < bf_i.n_prims; pi = pi + 1u) {
        let prim_a = primitives[i * 3u + pi];
        let alpha = prim_a.x;
        let ca = prim_a.y;

        for (var pj: u32 = 0u; pj < bf_j.n_prims; pj = pj + 1u) {
            let prim_b = primitives[j * 3u + pj];
            let beta = prim_b.x;
            let cb = prim_b.y;

            let p = alpha + beta;
            let mu_ab = alpha * beta / p;
            let k_ab = exp(-mu_ab * ab2);

            let px = (alpha * bf_i.cx + beta * bf_j.cx) / p;
            let py = (alpha * bf_i.cy + beta * bf_j.cy) / p;
            let pz = (alpha * bf_i.cz + beta * bf_j.cz) / p;

            for (var pk: u32 = 0u; pk < bf_k.n_prims; pk = pk + 1u) {
                let prim_c = primitives[k * 3u + pk];
                let gamma = prim_c.x;
                let cc = prim_c.y;

                for (var pl: u32 = 0u; pl < bf_l.n_prims; pl = pl + 1u) {
                    let prim_d = primitives[l * 3u + pl];
                    let delta = prim_d.x;
                    let cd = prim_d.y;

                    let qq = gamma + delta;
                    let mu_cd = gamma * delta / qq;
                    let k_cd = exp(-mu_cd * cd2);

                    let qx = (gamma * bf_k.cx + delta * bf_l.cx) / qq;
                    let qy = (gamma * bf_k.cy + delta * bf_l.cy) / qq;
                    let qz = (gamma * bf_k.cz + delta * bf_l.cz) / qq;

                    let pq2 = (px - qx) * (px - qx)
                        + (py - qy) * (py - qy)
                        + (pz - qz) * (pz - qz);
                    let alpha_pq = p * qq / (p + qq);

                    // prefactor = 2 pi^(5/2) / (p q sqrt(p + q))
                    let prefactor = 2.0 * pow(3.14159265, 2.5) / (p * qq * sqrt(p + qq));

                    eri += ca * cb * cc * cd * prefactor * k_ab * k_cd * boys_f0(alpha_pq * pq2);
                }
            }
        }
    }

    let n = params.n_basis;
    let n2 = n * n;

    // Store with 8-fold permutational symmetry (row-major n^4 layout).
    output[i * n * n2 + j * n2 + k * n + l] = eri;
    output[j * n * n2 + i * n2 + k * n + l] = eri;
    output[i * n * n2 + j * n2 + l * n + k] = eri;
    output[j * n * n2 + i * n2 + l * n + k] = eri;
    output[k * n * n2 + l * n2 + i * n + j] = eri;
    output[l * n * n2 + k * n2 + i * n + j] = eri;
    output[k * n * n2 + l * n2 + j * n + i] = eri;
    output[l * n * n2 + k * n2 + j * n + i] = eri;
}
"#;
329
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Canonical quartet count for `n` functions: `P(P + 1) / 2` with `P = n(n + 1) / 2`.
    fn expected_quartet_count(n: usize) -> usize {
        let pairs = n * (n + 1) / 2;
        pairs * (pairs + 1) / 2
    }

    /// Decodes the native-endian `u32` 4-tuples produced by `enumerate_unique_quartets`.
    fn decode_quartets(bytes: &[u8]) -> Vec<[u32; 4]> {
        bytes
            .chunks_exact(16)
            .map(|q| {
                let mut quad = [0u32; 4];
                for (slot, word) in quad.iter_mut().zip(q.chunks_exact(4)) {
                    *slot = u32::from_ne_bytes([word[0], word[1], word[2], word[3]]);
                }
                quad
            })
            .collect()
    }

    #[test]
    fn test_enumerate_quartets_h2() {
        let (count, bytes) = enumerate_unique_quartets(2);
        assert_eq!(count, 6);
        assert_eq!(bytes.len(), count * 16);
        assert_eq!(
            decode_quartets(&bytes),
            vec![
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [1, 0, 1, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ]
        );
    }

    #[test]
    fn test_enumerate_quartets_single() {
        let (count, bytes) = enumerate_unique_quartets(1);
        assert_eq!(count, 1);
        assert_eq!(decode_quartets(&bytes), vec![[0, 0, 0, 0]]);
    }

    #[test]
    fn test_enumerate_quartets_canonical_and_complete() {
        for n in 0..6 {
            let (count, bytes) = enumerate_unique_quartets(n);
            let quads = decode_quartets(&bytes);
            assert_eq!(count, expected_quartet_count(n));
            assert_eq!(quads.len(), count);

            // Distinct + canonical + the right count => exactly the canonical set.
            let distinct: HashSet<[u32; 4]> = quads.iter().copied().collect();
            assert_eq!(distinct.len(), count);
            for [i, j, k, l] in quads {
                assert!(j <= i && l <= k, "pair not ordered: {:?}", [i, j, k, l]);
                assert!((k, l) <= (i, j), "bra/ket not ordered: {:?}", [i, j, k, l]);
            }
        }
    }

    #[test]
    fn test_pack_basis_eri() {
        let basis =
            crate::scf::basis::BasisSet::sto3g(&[1, 1], &[[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]);
        let (basis_bytes, prim_bytes) = pack_basis_eri(&basis);
        assert_eq!(basis_bytes.len(), basis.n_basis * 32);
        assert_eq!(prim_bytes.len(), basis.n_basis * 24);
    }

    #[test]
    fn test_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let basis = crate::scf::basis::BasisSet::sto3g(&[1], &[[0.0, 0.0, 0.0]]);
        let result = compute_eris_gpu(&ctx, &basis);
        assert!(result.is_err());
    }
}