1use super::backend_report::OrbitalGridReport;
4use super::context::{
5 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
6 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
7 ComputeDispatchDescriptor, GpuContext, UniformValue,
8};
9use crate::esp::{compute_esp_grid, EspGrid};
10
11pub fn compute_esp_grid_with_report(
12 elements: &[u8],
13 positions: &[[f64; 3]],
14 mulliken_charges: &[f64],
15 spacing: f64,
16 padding: f64,
17) -> (EspGrid, OrbitalGridReport) {
18 let ctx = GpuContext::best_available();
19 if ctx.is_gpu_available() {
20 match compute_esp_grid_gpu(&ctx, positions, mulliken_charges, spacing, padding) {
21 Ok(grid) => {
22 let n_points = grid.dims[0] * grid.dims[1] * grid.dims[2];
23 return (
24 grid,
25 OrbitalGridReport {
26 backend: ctx.capabilities.backend.clone(),
27 used_gpu: true,
28 attempted_gpu: true,
29 n_points,
30 note: format!("GPU ESP-grid dispatch on {}", ctx.capabilities.backend),
31 },
32 );
33 }
34 Err(_err) => {}
35 }
36 }
37
38 let grid = compute_esp_grid(elements, positions, mulliken_charges, spacing, padding);
39 let n_points = grid.dims[0] * grid.dims[1] * grid.dims[2];
40 (
41 grid,
42 OrbitalGridReport {
43 backend: "CPU".to_string(),
44 used_gpu: false,
45 attempted_gpu: ctx.is_gpu_available(),
46 n_points,
47 note: if ctx.is_gpu_available() {
48 "GPU available but ESP-grid dispatch failed; CPU fallback used".to_string()
49 } else {
50 "CPU ESP-grid evaluation (GPU not available)".to_string()
51 },
52 },
53 )
54}
55
56pub fn compute_esp_grid_gpu(
57 ctx: &GpuContext,
58 positions: &[[f64; 3]],
59 mulliken_charges: &[f64],
60 spacing: f64,
61 padding: f64,
62) -> Result<EspGrid, String> {
63 if positions.len() != mulliken_charges.len() {
64 return Err("positions/charges length mismatch".to_string());
65 }
66
67 let mut min = [f64::MAX; 3];
68 let mut max = [f64::MIN; 3];
69 for pos in positions {
70 for axis in 0..3 {
71 min[axis] = min[axis].min(pos[axis]);
72 max[axis] = max[axis].max(pos[axis]);
73 }
74 }
75
76 let origin = [min[0] - padding, min[1] - padding, min[2] - padding];
77 let dims = [
78 ((max[0] - min[0] + 2.0 * padding) / spacing).ceil() as usize + 1,
79 ((max[1] - min[1] + 2.0 * padding) / spacing).ceil() as usize + 1,
80 ((max[2] - min[2] + 2.0 * padding) / spacing).ceil() as usize + 1,
81 ];
82 let total = dims[0] * dims[1] * dims[2];
83
84 let params_bytes = pack_uniform_values(&[
85 UniformValue::F32(origin[0] as f32),
86 UniformValue::F32(origin[1] as f32),
87 UniformValue::F32(origin[2] as f32),
88 UniformValue::F32(spacing as f32),
89 UniformValue::U32(dims[0] as u32),
90 UniformValue::U32(dims[1] as u32),
91 UniformValue::U32(dims[2] as u32),
92 UniformValue::U32(positions.len() as u32),
93 ]);
94
95 let descriptor = ComputeDispatchDescriptor {
96 label: "esp grid".to_string(),
97 shader_source: ESP_GRID_SHADER.to_string(),
98 entry_point: "main".to_string(),
99 workgroup_count: [
100 ceil_div_u32(dims[0], 8),
101 ceil_div_u32(dims[1], 8),
102 ceil_div_u32(dims[2], 4),
103 ],
104 bindings: vec![
105 ComputeBindingDescriptor {
106 label: "positions".to_string(),
107 kind: ComputeBindingKind::StorageReadOnly,
108 bytes: pack_vec3_positions_f32(positions),
109 },
110 ComputeBindingDescriptor {
111 label: "charges".to_string(),
112 kind: ComputeBindingKind::StorageReadOnly,
113 bytes: f32_slice_to_bytes(
114 &mulliken_charges
115 .iter()
116 .map(|value| *value as f32)
117 .collect::<Vec<_>>(),
118 ),
119 },
120 ComputeBindingDescriptor {
121 label: "params".to_string(),
122 kind: ComputeBindingKind::Uniform,
123 bytes: params_bytes,
124 },
125 ComputeBindingDescriptor {
126 label: "output".to_string(),
127 kind: ComputeBindingKind::StorageReadWrite,
128 bytes: f32_slice_to_bytes(&vec![0.0f32; total]),
129 },
130 ],
131 };
132
133 let mut outputs = ctx.run_compute(&descriptor)?.outputs;
134 let bytes = outputs.pop().ok_or("No output from ESP grid kernel")?;
135 let values = bytes_to_f64_vec_from_f32(&bytes);
136 if values.len() != total {
137 return Err(format!(
138 "Output size mismatch: expected {}, got {}",
139 total,
140 values.len()
141 ));
142 }
143
144 Ok(EspGrid {
145 origin,
146 spacing,
147 dims,
148 values,
149 })
150}
151
/// WGSL compute kernel that evaluates a point-charge electrostatic potential on
/// a regular 3-D grid.
///
/// Bindings (all in `@group(0)`):
/// - `0`: `positions`, read-only storage array of `AtomPos` (x, y, z plus one
///   padding `f32`, i.e. 16 bytes per atom);
/// - `1`: `charges`, read-only storage array of `f32`, one per atom;
/// - `2`: `params`, uniform `GridParams` (origin xyz, spacing, dims xyz, atom
///   count, in that order);
/// - `3`: `output`, read-write storage array of `f32`, one value per grid point.
///
/// Each invocation handles one grid point `(ix, iy, iz)` located at
/// `origin + spacing * (ix, iy, iz)` and writes it to flat index
/// `ix * dims_y * dims_z + iy * dims_z + iz` (z fastest). Invocations outside
/// `dims` return early, so the dispatch may round up to whole `(8, 8, 4)`
/// workgroups.
///
/// The potential is `sum(charge / (dist * 1.88972599))` over all atoms. Any atom
/// closer than `0.01` to the grid point is skipped, which keeps the `1 / dist`
/// term from blowing up. The factor 1.88972599 matches the bohr-per-ångström
/// conversion, so presumably positions are in Å and the result is in atomic
/// units. NOTE(review): confirm this against the CPU reference `compute_esp_grid`.
pub const ESP_GRID_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct GridParams {
    origin_x: f32, origin_y: f32, origin_z: f32,
    spacing: f32,
    dims_x: u32, dims_y: u32, dims_z: u32,
    n_atoms: u32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> charges: array<f32>;
@group(0) @binding(2) var<uniform> params: GridParams;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(8, 8, 4)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ix = gid.x;
    let iy = gid.y;
    let iz = gid.z;

    if (ix >= params.dims_x || iy >= params.dims_y || iz >= params.dims_z) {
        return;
    }

    let rx = params.origin_x + f32(ix) * params.spacing;
    let ry = params.origin_y + f32(iy) * params.spacing;
    let rz = params.origin_z + f32(iz) * params.spacing;
    let flat_idx = ix * params.dims_y * params.dims_z + iy * params.dims_z + iz;

    var phi: f32 = 0.0;
    for (var atom: u32 = 0u; atom < params.n_atoms; atom = atom + 1u) {
        let pos = positions[atom];
        let dx = rx - pos.x;
        let dy = ry - pos.y;
        let dz = rz - pos.z;
        let dist = sqrt(dx * dx + dy * dy + dz * dz);
        if (dist < 0.01) { continue; }
        phi += charges[atom] / (dist * 1.88972599);
    }
    output[flat_idx] = phi;
}
"#;