1use super::backend_report::OrbitalGridReport;
10use super::context::{
11 bytes_to_f32_vec, f32_slice_to_bytes, ComputeBindingDescriptor, ComputeBindingKind,
12 ComputeDispatchDescriptor, GpuContext,
13};
14use crate::scf::basis::{BasisFunction, BasisSet};
15use nalgebra::DMatrix;
16
/// A regular, axis-aligned 3-D grid of sample points.
///
/// Point `(ix, iy, iz)` sits at `origin + spacing * (ix, iy, iz)`. Values sampled on the
/// grid are stored in a flat buffer with `z` varying fastest (see [`GridParams::flat_index`]).
#[derive(Debug, Clone, PartialEq)]
pub struct GridParams {
    /// Coordinates of grid point `(0, 0, 0)`.
    pub origin: [f64; 3],
    /// Distance between neighbouring points along every axis.
    pub spacing: f64,
    /// Number of points along x, y and z.
    pub dimensions: [usize; 3],
}

impl GridParams {
    /// Builds a grid enclosing every position in `positions`, extended by `padding` on
    /// each side of the bounding box.
    ///
    /// Along each axis the grid has `ceil(extent / spacing) + 1` points, where `extent` is
    /// the padded bounding-box width, so the far edge is always covered. An empty
    /// `positions` slice is treated as a single point at the coordinate origin. Without
    /// that, the bounds would stay at `f64::MAX`/`f64::MIN` and produce a meaningless origin.
    ///
    /// # Panics
    /// Panics if `spacing` is not finite and strictly positive. Such a spacing would make
    /// the point count infinite or undefined.
    pub fn from_molecule(positions: &[[f64; 3]], spacing: f64, padding: f64) -> Self {
        assert!(
            spacing.is_finite() && spacing > 0.0,
            "grid spacing must be finite and positive, got {spacing}"
        );

        let (min, max) = if positions.is_empty() {
            ([0.0; 3], [0.0; 3])
        } else {
            let mut lo = [f64::MAX; 3];
            let mut hi = [f64::MIN; 3];
            for pos in positions {
                for k in 0..3 {
                    lo[k] = lo[k].min(pos[k]);
                    hi[k] = hi[k].max(pos[k]);
                }
            }
            (lo, hi)
        };

        let origin: [f64; 3] = std::array::from_fn(|k| min[k] - padding);
        let dimensions: [usize; 3] = std::array::from_fn(|k| {
            ((max[k] - min[k] + 2.0 * padding) / spacing).ceil() as usize + 1
        });

        Self {
            origin,
            spacing,
            dimensions,
        }
    }

    /// Total number of grid points (`nx * ny * nz`).
    pub fn n_points(&self) -> usize {
        self.dimensions.iter().product()
    }

    /// Cartesian coordinates of grid point `(ix, iy, iz)`.
    pub fn point(&self, ix: usize, iy: usize, iz: usize) -> [f64; 3] {
        [
            self.origin[0] + ix as f64 * self.spacing,
            self.origin[1] + iy as f64 * self.spacing,
            self.origin[2] + iz as f64 * self.spacing,
        ]
    }

    /// Index of point `(ix, iy, iz)` in a flat buffer of `n_points()` values, with `z`
    /// varying fastest, then `y`, then `x`.
    pub fn flat_index(&self, ix: usize, iy: usize, iz: usize) -> usize {
        // Out-of-range indices would silently alias a different point.
        debug_assert!(
            ix < self.dimensions[0] && iy < self.dimensions[1] && iz < self.dimensions[2],
            "grid index ({ix}, {iy}, {iz}) outside dimensions {:?}",
            self.dimensions
        );
        ix * self.dimensions[1] * self.dimensions[2] + iy * self.dimensions[2] + iz
    }
}
74
75#[derive(Debug, Clone)]
77pub struct OrbitalGrid {
78 pub values: Vec<f64>,
80 pub params: GridParams,
81 pub orbital_index: usize,
82}
83
84pub fn evaluate_orbital_with_report(
88 basis: &BasisSet,
89 mo_coefficients: &DMatrix<f64>,
90 orbital_index: usize,
91 params: &GridParams,
92) -> (OrbitalGrid, OrbitalGridReport) {
93 let ctx = GpuContext::best_available();
94
95 if ctx.is_gpu_available() {
96 match evaluate_orbital_gpu(&ctx, basis, mo_coefficients, orbital_index, params) {
97 Ok(grid) => {
98 let report = OrbitalGridReport {
99 backend: ctx.capabilities.backend.clone(),
100 used_gpu: true,
101 attempted_gpu: true,
102 n_points: params.n_points(),
103 note: format!("GPU dispatch on {}", ctx.capabilities.backend),
104 };
105 return (grid, report);
106 }
107 Err(_err) => {
108 }
110 }
111 }
112
113 let grid = evaluate_orbital_cpu(basis, mo_coefficients, orbital_index, params);
114 let report = OrbitalGridReport {
115 backend: "CPU".to_string(),
116 used_gpu: false,
117 attempted_gpu: ctx.is_gpu_available(),
118 n_points: params.n_points(),
119 note: if ctx.is_gpu_available() {
120 "GPU available but dispatch failed; CPU fallback used".to_string()
121 } else {
122 "CPU evaluation (GPU not available)".to_string()
123 },
124 };
125 (grid, report)
126}
127
128pub fn evaluate_orbital_cpu(
130 basis: &BasisSet,
131 mo_coefficients: &DMatrix<f64>,
132 orbital_index: usize,
133 params: &GridParams,
134) -> OrbitalGrid {
135 let n_points = params.n_points();
136 let mut values = vec![0.0; n_points];
137 let n_basis = basis.n_basis;
138 let [nx, ny, nz] = params.dimensions;
139
140 for ix in 0..nx {
141 for iy in 0..ny {
142 for iz in 0..nz {
143 let r = params.point(ix, iy, iz);
144 let idx = params.flat_index(ix, iy, iz);
145
146 let mut psi = 0.0;
147 for mu in 0..n_basis {
148 let c_mu = mo_coefficients[(mu, orbital_index)];
149 if c_mu.abs() < 1e-15 {
150 continue;
151 }
152 let phi_mu = evaluate_basis_function(&basis.functions[mu], &r);
153 psi += c_mu * phi_mu;
154 }
155 values[idx] = psi;
156 }
157 }
158 }
159
160 OrbitalGrid {
161 values,
162 params: params.clone(),
163 orbital_index,
164 }
165}
166
167pub fn evaluate_density_cpu(
169 basis: &BasisSet,
170 density: &DMatrix<f64>,
171 params: &GridParams,
172) -> Vec<f64> {
173 let n_points = params.n_points();
174 let mut values = vec![0.0; n_points];
175 let n_basis = basis.n_basis;
176 let [nx, ny, nz] = params.dimensions;
177
178 for ix in 0..nx {
179 for iy in 0..ny {
180 for iz in 0..nz {
181 let r = params.point(ix, iy, iz);
182 let idx = params.flat_index(ix, iy, iz);
183
184 let phi: Vec<f64> = (0..n_basis)
185 .map(|mu| evaluate_basis_function(&basis.functions[mu], &r))
186 .collect();
187
188 let mut rho = 0.0;
189 for mu in 0..n_basis {
190 if phi[mu].abs() < 1e-15 {
191 continue;
192 }
193 for nu in 0..n_basis {
194 rho += density[(mu, nu)] * phi[mu] * phi[nu];
195 }
196 }
197 values[idx] = rho;
198 }
199 }
200 }
201 values
202}
203
204fn evaluate_basis_function(bf: &BasisFunction, r: &[f64; 3]) -> f64 {
206 let dx = r[0] - bf.center[0];
207 let dy = r[1] - bf.center[1];
208 let dz = r[2] - bf.center[2];
209 let r2 = dx * dx + dy * dy + dz * dz;
210
211 let angular = dx.powi(bf.angular[0] as i32)
212 * dy.powi(bf.angular[1] as i32)
213 * dz.powi(bf.angular[2] as i32);
214
215 let mut radial = 0.0;
216 for prim in &bf.primitives {
217 radial += prim.coefficient * (-prim.alpha * r2).exp();
218 }
219
220 BasisFunction::normalization(
221 bf.primitives.first().map(|p| p.alpha).unwrap_or(1.0),
222 bf.angular[0],
223 bf.angular[1],
224 bf.angular[2],
225 ) * angular
226 * radial
227}
228
229fn pack_basis_for_gpu(basis: &BasisSet) -> (Vec<u8>, Vec<u8>) {
238 let mut basis_bytes = Vec::new();
239 let mut prim_bytes = Vec::new();
240
241 for bf in &basis.functions {
242 basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
244 basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
245 basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
246 basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
248 basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
249 basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
250 basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
252 let norm = BasisFunction::normalization(
254 bf.primitives.first().map(|p| p.alpha).unwrap_or(1.0),
255 bf.angular[0],
256 bf.angular[1],
257 bf.angular[2],
258 );
259 basis_bytes.extend_from_slice(&(norm as f32).to_ne_bytes());
260
261 for i in 0..3 {
263 if i < bf.primitives.len() {
264 prim_bytes.extend_from_slice(&(bf.primitives[i].alpha as f32).to_ne_bytes());
265 prim_bytes.extend_from_slice(&(bf.primitives[i].coefficient as f32).to_ne_bytes());
266 } else {
267 prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
268 prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
269 }
270 }
271 }
272
273 (basis_bytes, prim_bytes)
274}
275
276fn evaluate_orbital_gpu(
278 ctx: &GpuContext,
279 basis: &BasisSet,
280 mo_coefficients: &DMatrix<f64>,
281 orbital_index: usize,
282 params: &GridParams,
283) -> Result<OrbitalGrid, String> {
284 let n_basis = basis.n_basis;
285 let n_points = params.n_points();
286
287 let (basis_bytes, prim_bytes) = pack_basis_for_gpu(basis);
289
290 let mo_coeffs: Vec<f32> = (0..n_basis)
292 .map(|mu| mo_coefficients[(mu, orbital_index)] as f32)
293 .collect();
294
295 let mut params_bytes = Vec::with_capacity(32);
297 for v in ¶ms.origin {
298 params_bytes.extend_from_slice(&(*v as f32).to_ne_bytes());
299 }
300 params_bytes.extend_from_slice(&(params.spacing as f32).to_ne_bytes());
301 for d in ¶ms.dimensions {
302 params_bytes.extend_from_slice(&(*d as u32).to_ne_bytes());
303 }
304 params_bytes.extend_from_slice(&(orbital_index as u32).to_ne_bytes());
305
306 let output_seed = vec![0.0f32; n_points];
308
309 let [nx, ny, nz] = params.dimensions;
310 let wg = [
311 (nx as u32).div_ceil(8),
312 (ny as u32).div_ceil(8),
313 (nz as u32).div_ceil(4),
314 ];
315
316 let descriptor = ComputeDispatchDescriptor {
317 label: "orbital grid".to_string(),
318 shader_source: ORBITAL_GRID_SHADER.to_string(),
319 entry_point: "main".to_string(),
320 workgroup_count: wg,
321 bindings: vec![
322 ComputeBindingDescriptor {
323 label: "basis".to_string(),
324 kind: ComputeBindingKind::StorageReadOnly,
325 bytes: basis_bytes,
326 },
327 ComputeBindingDescriptor {
328 label: "mo_coeffs".to_string(),
329 kind: ComputeBindingKind::StorageReadOnly,
330 bytes: f32_slice_to_bytes(&mo_coeffs),
331 },
332 ComputeBindingDescriptor {
333 label: "primitives".to_string(),
334 kind: ComputeBindingKind::StorageReadOnly,
335 bytes: prim_bytes,
336 },
337 ComputeBindingDescriptor {
338 label: "params".to_string(),
339 kind: ComputeBindingKind::Uniform,
340 bytes: params_bytes,
341 },
342 ComputeBindingDescriptor {
343 label: "output".to_string(),
344 kind: ComputeBindingKind::StorageReadWrite,
345 bytes: f32_slice_to_bytes(&output_seed),
346 },
347 ],
348 };
349
350 let mut result = ctx.run_compute(&descriptor)?.outputs;
351 let bytes = result.pop().ok_or("No output from orbital grid kernel")?;
352 let f32_values = bytes_to_f32_vec(&bytes);
353
354 if f32_values.len() != n_points {
355 return Err(format!(
356 "Output size mismatch: expected {}, got {}",
357 n_points,
358 f32_values.len()
359 ));
360 }
361
362 let values: Vec<f64> = f32_values.iter().map(|v| *v as f64).collect();
363
364 Ok(OrbitalGrid {
365 values,
366 params: params.clone(),
367 orbital_index,
368 })
369}
370
/// WGSL compute kernel that evaluates one molecular orbital on a regular grid.
///
/// Each invocation handles one grid point (`@workgroup_size(8, 8, 4)`).
///
/// Bindings (group 0):
/// - 0 `basis`: one `BasisFunc` record per basis function.
/// - 1 `mo_coeffs`: one f32 per basis function; its length is taken as the basis size.
/// - 2 `primitives`: three `(alpha, coefficient)` slots per basis function.
/// - 3 `params`: uniform grid description.
/// - 4 `output`: one f32 per grid point, written at `ix * dims_y * dims_z + iy * dims_z + iz`.
///
/// Coefficients with `|c| < 1e-7` are skipped. The primitive loop is clamped to the
/// three packed slots so a larger `n_primitives` can never read another function's data.
pub const ORBITAL_GRID_SHADER: &str = r#"
// Molecular orbital evaluation on a regular 3-D grid.

struct BasisFunc {
    center_x: f32, center_y: f32, center_z: f32,
    lx: u32, ly: u32, lz: u32,
    n_primitives: u32,
    norm_coeff: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    orbital_index: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> mo_coeffs: array<f32>;
@group(0) @binding(2) var<storage, read> primitives: array<vec2<f32>>;
@group(0) @binding(3) var<uniform> params: GridParams;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;

    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;
    let n_basis = arrayLength(&mo_coeffs);

    var psi: f32 = 0.0;

    for (var mu: u32 = 0u; mu < n_basis; mu = mu + 1u) {
        let c_mu = mo_coeffs[mu];
        if (abs(c_mu) < 1e-7) {
            continue;
        }

        let bf = basis[mu];
        let dx = rx - bf.center_x;
        let dy = ry - bf.center_y;
        let dz = rz - bf.center_z;
        let r2 = dx * dx + dy * dy + dz * dz;

        // Angular part
        var angular: f32 = 1.0;
        for (var i: u32 = 0u; i < bf.lx; i = i + 1u) { angular *= dx; }
        for (var i: u32 = 0u; i < bf.ly; i = i + 1u) { angular *= dy; }
        for (var i: u32 = 0u; i < bf.lz; i = i + 1u) { angular *= dz; }

        // Radial part: contracted Gaussian, at most 3 packed primitives per function.
        let n_prims = min(bf.n_primitives, 3u);
        var radial: f32 = 0.0;
        for (var p: u32 = 0u; p < n_prims; p = p + 1u) {
            let prim = primitives[mu * 3u + p];
            radial += prim.y * exp(-prim.x * r2);
        }

        psi += c_mu * bf.norm_coeff * angular * radial;
    }

    output[flat_idx] = psi;
}
"#;
446
#[cfg(test)]
mod tests {
    use super::*;

    /// H2 in STO-3G: two hydrogens on the z axis, 1.4 apart.
    fn h2_basis() -> BasisSet {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]];
        BasisSet::sto3g(&elements, &positions)
    }

    /// Equal-weight combination of the first two basis functions, stored in column 0.
    fn bonding_coefficients(n: usize) -> DMatrix<f64> {
        let c = 1.0 / (2.0f64).sqrt();
        let mut coeffs = DMatrix::zeros(n, n);
        coeffs[(0, 0)] = c;
        if n > 1 {
            coeffs[(1, 0)] = c;
        }
        coeffs
    }

    #[test]
    fn test_grid_params_from_molecule() {
        let positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let params = GridParams::from_molecule(&positions, 0.5, 3.0);
        assert!(params.dimensions[0] > 0);
        assert!(params.n_points() > 0);
        assert!(params.origin[0] < -2.0);
        // Bounding box [0, 2] x [0, 0] x [0, 0], padded by 3 on every side, spacing 0.5.
        assert_eq!(params.dimensions, [17, 13, 13]);
        for k in 0..3 {
            assert!((params.origin[k] + 3.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_grid_point_coordinates() {
        let params = GridParams {
            origin: [0.0, 0.0, 0.0],
            spacing: 1.0,
            dimensions: [3, 3, 3],
        };
        let p = params.point(1, 2, 0);
        assert!((p[0] - 1.0).abs() < 1e-12);
        assert!((p[1] - 2.0).abs() < 1e-12);
        assert!(p[2].abs() < 1e-12);
    }

    #[test]
    fn test_flat_index() {
        let params = GridParams {
            origin: [0.0, 0.0, 0.0],
            spacing: 1.0,
            dimensions: [3, 4, 5],
        };
        assert_eq!(params.flat_index(0, 0, 0), 0);
        assert_eq!(params.flat_index(0, 0, 1), 1);
        assert_eq!(params.flat_index(0, 1, 0), 5);
        assert_eq!(params.flat_index(1, 0, 0), 20);
        assert_eq!(params.flat_index(2, 3, 4), params.n_points() - 1);
    }

    #[test]
    fn test_evaluate_orbital_cpu_h2() {
        let basis = h2_basis();
        let coeffs = bonding_coefficients(basis.n_basis);

        let params = GridParams {
            origin: [-2.0, -2.0, -2.0],
            spacing: 0.5,
            dimensions: [5, 5, 13],
        };

        let grid = evaluate_orbital_cpu(&basis, &coeffs, 0, &params);
        assert_eq!(grid.values.len(), params.n_points());
        assert_eq!(grid.orbital_index, 0);
        assert!(grid.values.iter().all(|v| v.is_finite()));

        // Point (-1, -1, 0.5): near the bond, so the orbital is clearly non-zero.
        let probe_idx = params.flat_index(2, 2, 5);
        assert!(grid.values[probe_idx].abs() > 1e-6);

        // Both atoms sit on the z axis and carry only s functions, so swapping x and y
        // must not change the value.
        let a = grid.values[params.flat_index(1, 3, 5)];
        let b = grid.values[params.flat_index(3, 1, 5)];
        assert!((a - b).abs() < 1e-12);
    }

    #[test]
    fn test_evaluate_orbital_with_report() {
        let basis = h2_basis();
        let coeffs = bonding_coefficients(basis.n_basis);

        let params = GridParams {
            origin: [-1.0, -1.0, -1.0],
            spacing: 1.0,
            dimensions: [3, 3, 5],
        };

        let (grid, report) = evaluate_orbital_with_report(&basis, &coeffs, 0, &params);
        assert_eq!(grid.values.len(), params.n_points());
        assert!(!report.backend.is_empty());
        assert_eq!(report.n_points, params.n_points());
        assert!(report.attempted_gpu || !report.used_gpu);

        // The CPU fallback must agree exactly with a direct CPU evaluation.
        if !report.used_gpu {
            let cpu = evaluate_orbital_cpu(&basis, &coeffs, 0, &params);
            assert_eq!(grid.values, cpu.values);
        }
    }

    #[test]
    fn test_evaluate_density_cpu() {
        let basis = h2_basis();
        let n = basis.n_basis;
        let density = DMatrix::from_fn(n, n, |i, j| if i == j { 1.0 } else { 0.3 });

        let params = GridParams {
            origin: [-1.0, -1.0, -1.0],
            spacing: 1.0,
            dimensions: [3, 3, 4],
        };

        let values = evaluate_density_cpu(&basis, &density, &params);
        assert_eq!(values.len(), params.n_points());
        assert!(values.iter().all(|v| v.is_finite()));
        // The density matrix is positive definite, so rho is positive wherever any basis
        // function is non-negligible, which includes the grid points next to the atoms.
        assert!(values.iter().any(|&v| v > 0.0));
    }

    #[test]
    fn test_pack_basis_for_gpu() {
        let elements = [1u8];
        let positions = [[0.0, 0.0, 0.0]];
        let basis = BasisSet::sto3g(&elements, &positions);

        let (basis_bytes, prim_bytes) = pack_basis_for_gpu(&basis);
        // 32-byte `BasisFunc` record and 3 × (alpha, coefficient) f32 slots per function.
        assert_eq!(basis_bytes.len(), basis.n_basis * 32);
        assert_eq!(prim_bytes.len(), basis.n_basis * 24);

        // The packed primitive count (bytes 24..28) must fit the three available slots.
        let n_prims = u32::from_ne_bytes([
            basis_bytes[24],
            basis_bytes[25],
            basis_bytes[26],
            basis_bytes[27],
        ]);
        assert!((1..=3).contains(&n_prims));
    }
}