sci_form/gpu/
alpb_born_gpu.rs1use super::context::{
4 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
5 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
6 ComputeDispatchDescriptor, GpuContext, UniformValue,
7};
8use crate::solvation_alpb::{intrinsic_radius, AlpbBornRadii};
9
10const GPU_DISPATCH_THRESHOLD: usize = 3;
11
12pub fn compute_born_radii_gpu(
13 ctx: &GpuContext,
14 elements: &[u8],
15 positions: &[[f64; 3]],
16 probe_radius: f64,
17) -> Result<AlpbBornRadii, String> {
18 let n = elements.len();
19 if n < GPU_DISPATCH_THRESHOLD {
20 return Err("System too small for GPU dispatch".to_string());
21 }
22 if positions.len() != n {
23 return Err("elements/position length mismatch".to_string());
24 }
25
26 let intrinsic: Vec<f64> = elements.iter().map(|&z| intrinsic_radius(z)).collect();
27 let rho: Vec<f32> = intrinsic
28 .iter()
29 .map(|radius| (*radius + probe_radius * 0.1) as f32)
30 .collect();
31
32 let params = pack_uniform_values(&[
33 UniformValue::U32(n as u32),
34 UniformValue::U32(0),
35 UniformValue::U32(0),
36 UniformValue::U32(0),
37 UniformValue::F32(1.0),
38 UniformValue::F32(0.8),
39 UniformValue::F32(4.85),
40 UniformValue::F32(0.0),
41 ]);
42
43 let descriptor = ComputeDispatchDescriptor {
44 label: "alpb born radii".to_string(),
45 shader_source: ALPB_BORN_RADII_SHADER.to_string(),
46 entry_point: "main".to_string(),
47 workgroup_count: [ceil_div_u32(n, 64), 1, 1],
48 bindings: vec![
49 ComputeBindingDescriptor {
50 label: "positions".to_string(),
51 kind: ComputeBindingKind::StorageReadOnly,
52 bytes: pack_vec3_positions_f32(positions),
53 },
54 ComputeBindingDescriptor {
55 label: "rho".to_string(),
56 kind: ComputeBindingKind::StorageReadOnly,
57 bytes: f32_slice_to_bytes(&rho),
58 },
59 ComputeBindingDescriptor {
60 label: "params".to_string(),
61 kind: ComputeBindingKind::Uniform,
62 bytes: params,
63 },
64 ComputeBindingDescriptor {
65 label: "output".to_string(),
66 kind: ComputeBindingKind::StorageReadWrite,
67 bytes: f32_slice_to_bytes(&vec![0.0f32; n]),
68 },
69 ],
70 };
71
72 let mut result = ctx.run_compute(&descriptor)?;
73 let bytes = result
74 .outputs
75 .pop()
76 .ok_or("No output from ALPB Born kernel")?;
77 if bytes.len() != n * 4 {
78 return Err(format!(
79 "ALPB Born output size mismatch: expected {}, got {}",
80 n * 4,
81 bytes.len()
82 ));
83 }
84
85 let radii = bytes_to_f64_vec_from_f32(&bytes);
86
87 Ok(AlpbBornRadii { radii, intrinsic })
88}
89
/// WGSL compute kernel dispatched by [`compute_born_radii_gpu`].
///
/// One invocation per atom (`@workgroup_size(64, 1, 1)`); invocations with
/// `gid.x >= params.n_atoms` return immediately.
///
/// Bindings (group 0):
/// - `0` `positions`: per-atom `x, y, z` plus one `f32` of padding.
/// - `1` `rho`: per-atom radius `rho_i`.
/// - `2` `params`: atom count (followed by three `u32` pads), then the
///   `alpha`, `beta`, `gamma` coefficients (followed by one `f32` pad).
/// - `3` `output`: one result per atom.
///
/// For atom `i` the kernel accumulates a pair term `psi` over every `j != i`
/// with `r_ij > rho_j` and `r_ij + rho_j > max(rho_i, r_ij - rho_j)`. It then
/// forms `s = psi * rho_i` and writes
/// `1 / (1/rho_i - tanh(alpha*s - beta*s^2 + gamma*s^3) / rho_i)`.
/// If that inverse radius is not greater than `1e-10`, `100.0` is written instead.
///
/// NOTE(review): the pair term does not appear to be the standard HCT
/// descreening integral. It has no logarithmic term, and it adds terms of
/// different length dimension, e.g. `1/l_ij` versus
/// `(1/u_ij - 1/l_ij) * (r_ij^2 - rho_j^2)`. Pairs with `r_ij <= rho_j` (atom `i`
/// inside sphere `j`) contribute nothing. Confirm it matches the CPU reference
/// in `crate::solvation_alpb` before relying on GPU/CPU agreement.
pub const ALPB_BORN_RADII_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
    alpha: f32,
    beta: f32,
    gamma: f32,
    _pad3: f32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> rho: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n_atoms;
    if (i >= n) {
        return;
    }

    let rho_i = rho[i];
    let pos_i = positions[i];
    var psi: f32 = 0.0;

    for (var j: u32 = 0u; j < n; j = j + 1u) {
        if (i == j) {
            continue;
        }

        let pos_j = positions[j];
        let dx = pos_i.x - pos_j.x;
        let dy = pos_i.y - pos_j.y;
        let dz = pos_i.z - pos_j.z;
        let r_ij = sqrt(dx * dx + dy * dy + dz * dz);
        let rho_j = rho[j];

        if (r_ij > rho_j) {
            let l_ij = max(rho_i, r_ij - rho_j);
            let u_ij = r_ij + rho_j;
            if (u_ij > l_ij) {
                psi += 0.5 * (
                    (1.0 / l_ij) - (1.0 / u_ij)
                    + 0.25 * ((1.0 / u_ij) - (1.0 / l_ij)) * (r_ij * r_ij - rho_j * rho_j)
                    + 0.5 * ((1.0 / (u_ij * u_ij)) - (1.0 / (l_ij * l_ij))) * r_ij
                );
            }
        }
    }

    let psi_scaled = psi * rho_i;
    let tanh_val = tanh(
        params.alpha * psi_scaled
        - params.beta * psi_scaled * psi_scaled
        + params.gamma * psi_scaled * psi_scaled * psi_scaled
    );

    let inv_r_eff = 1.0 / rho_i - tanh_val / rho_i;
    output[i] = select(100.0, 1.0 / inv_r_eff, inv_r_eff > 1e-10);
}
"#;
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-atom system must be rejected by the size threshold specifically,
    /// not by some unrelated failure of the fallback context.
    #[test]
    fn test_alpb_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let result =
            compute_born_radii_gpu(&ctx, &[8, 1], &[[0.0, 0.0, 0.0], [0.7, 0.0, 0.0]], 1.4);
        // `.err()` avoids requiring `Debug` on the success type.
        assert_eq!(
            result.err().as_deref(),
            Some("System too small for GPU dispatch")
        );
    }

    /// Mismatched `elements`/`positions` lengths must be reported as such.
    #[test]
    fn test_alpb_gpu_length_mismatch() {
        let ctx = GpuContext::cpu_fallback();
        let result = compute_born_radii_gpu(
            &ctx,
            &[8, 1, 1],
            &[[0.0, 0.0, 0.0], [0.96, 0.0, 0.0]],
            1.4,
        );
        assert_eq!(
            result.err().as_deref(),
            Some("elements/position length mismatch")
        );
    }
}