// sci_form/gpu/alpb_born_gpu.rs

//! GPU-accelerated ALPB Born-radii evaluation.

use std::convert::TryFrom;

use super::context::{
    bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
    pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
    ComputeDispatchDescriptor, GpuContext, UniformValue,
};
use crate::solvation_alpb::{intrinsic_radius, AlpbBornRadii};

/// Minimum number of atoms for which `compute_born_radii_gpu` will dispatch.
///
/// Inputs with fewer atoms are rejected with an error instead of being sent to
/// the GPU.
const GPU_DISPATCH_THRESHOLD: usize = 3;

12pub fn compute_born_radii_gpu(
13    ctx: &GpuContext,
14    elements: &[u8],
15    positions: &[[f64; 3]],
16    probe_radius: f64,
17) -> Result<AlpbBornRadii, String> {
18    let n = elements.len();
19    if n < GPU_DISPATCH_THRESHOLD {
20        return Err("System too small for GPU dispatch".to_string());
21    }
22    if positions.len() != n {
23        return Err("elements/position length mismatch".to_string());
24    }
25
26    let intrinsic: Vec<f64> = elements.iter().map(|&z| intrinsic_radius(z)).collect();
27    let rho: Vec<f32> = intrinsic
28        .iter()
29        .map(|radius| (*radius + probe_radius * 0.1) as f32)
30        .collect();
31
32    let params = pack_uniform_values(&[
33        UniformValue::U32(n as u32),
34        UniformValue::U32(0),
35        UniformValue::U32(0),
36        UniformValue::U32(0),
37        UniformValue::F32(1.0),
38        UniformValue::F32(0.8),
39        UniformValue::F32(4.85),
40        UniformValue::F32(0.0),
41    ]);
42
43    let descriptor = ComputeDispatchDescriptor {
44        label: "alpb born radii".to_string(),
45        shader_source: ALPB_BORN_RADII_SHADER.to_string(),
46        entry_point: "main".to_string(),
47        workgroup_count: [ceil_div_u32(n, 64), 1, 1],
48        bindings: vec![
49            ComputeBindingDescriptor {
50                label: "positions".to_string(),
51                kind: ComputeBindingKind::StorageReadOnly,
52                bytes: pack_vec3_positions_f32(positions),
53            },
54            ComputeBindingDescriptor {
55                label: "rho".to_string(),
56                kind: ComputeBindingKind::StorageReadOnly,
57                bytes: f32_slice_to_bytes(&rho),
58            },
59            ComputeBindingDescriptor {
60                label: "params".to_string(),
61                kind: ComputeBindingKind::Uniform,
62                bytes: params,
63            },
64            ComputeBindingDescriptor {
65                label: "output".to_string(),
66                kind: ComputeBindingKind::StorageReadWrite,
67                bytes: f32_slice_to_bytes(&vec![0.0f32; n]),
68            },
69        ],
70    };
71
72    let mut result = ctx.run_compute(&descriptor)?;
73    let bytes = result
74        .outputs
75        .pop()
76        .ok_or("No output from ALPB Born kernel")?;
77    if bytes.len() != n * 4 {
78        return Err(format!(
79            "ALPB Born output size mismatch: expected {}, got {}",
80            n * 4,
81            bytes.len()
82        ));
83    }
84
85    let radii = bytes_to_f64_vec_from_f32(&bytes);
86
87    Ok(AlpbBornRadii { radii, intrinsic })
88}
89
/// WGSL compute shader that evaluates one Born radius per atom.
///
/// Bindings (all in group 0):
/// * `0` - `positions`: read-only storage, array of `AtomPos` (x, y, z, pad).
/// * `1` - `rho`: read-only storage, one `f32` radius per atom.
/// * `2` - `params`: uniform `Params` (`n_atoms`, three `u32` pads, `alpha`,
///   `beta`, `gamma`, one `f32` pad).
/// * `3` - `output`: read-write storage, one `f32` result per atom.
///
/// The entry point `main` uses a workgroup size of 64 along x; invocations
/// whose global x index is `>= n_atoms` return immediately. Each remaining
/// invocation `i` loops over every other atom `j`, and for pairs with
/// `r_ij > rho[j]` accumulates a pairwise term into `psi` using the bounds
/// `l_ij = max(rho_i, r_ij - rho_j)` and `u_ij = r_ij + rho_j`. `psi * rho_i`
/// is then passed through `tanh(alpha*p - beta*p^2 + gamma*p^3)` to form the
/// inverse effective radius `(1 - tanh) / rho_i`; the result written is its
/// reciprocal, or `100.0` when the inverse is not greater than `1e-10`.
pub const ALPB_BORN_RADII_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
    alpha: f32,
    beta: f32,
    gamma: f32,
    _pad3: f32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> rho: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n_atoms;
    if (i >= n) {
        return;
    }

    let rho_i = rho[i];
    let pos_i = positions[i];
    var psi: f32 = 0.0;

    for (var j: u32 = 0u; j < n; j = j + 1u) {
        if (i == j) {
            continue;
        }

        let pos_j = positions[j];
        let dx = pos_i.x - pos_j.x;
        let dy = pos_i.y - pos_j.y;
        let dz = pos_i.z - pos_j.z;
        let r_ij = sqrt(dx * dx + dy * dy + dz * dz);
        let rho_j = rho[j];

        if (r_ij > rho_j) {
            let l_ij = max(rho_i, r_ij - rho_j);
            let u_ij = r_ij + rho_j;
            if (u_ij > l_ij) {
                psi += 0.5 * (
                    (1.0 / l_ij) - (1.0 / u_ij)
                    + 0.25 * ((1.0 / u_ij) - (1.0 / l_ij)) * (r_ij * r_ij - rho_j * rho_j)
                    + 0.5 * ((1.0 / (u_ij * u_ij)) - (1.0 / (l_ij * l_ij))) * r_ij
                );
            }
        }
    }

    let psi_scaled = psi * rho_i;
    let tanh_val = tanh(
        params.alpha * psi_scaled
        - params.beta * psi_scaled * psi_scaled
        + params.gamma * psi_scaled * psi_scaled * psi_scaled
    );

    let inv_r_eff = 1.0 / rho_i - tanh_val / rho_i;
    output[i] = select(100.0, 1.0 / inv_r_eff, inv_r_eff > 1e-10);
}
"#;

#[cfg(test)]
mod tests {
    use super::*;

    /// Two atoms are below the dispatch threshold, so the call must be rejected
    /// with the threshold error rather than any error from the dispatch itself.
    #[test]
    fn test_alpb_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let result =
            compute_born_radii_gpu(&ctx, &[8, 1], &[[0.0, 0.0, 0.0], [0.7, 0.0, 0.0]], 1.4);
        assert_eq!(
            result.err().as_deref(),
            Some("System too small for GPU dispatch")
        );
    }

    /// Three elements with only two positions must be rejected before any
    /// buffers are built.
    #[test]
    fn test_alpb_gpu_length_mismatch() {
        let ctx = GpuContext::cpu_fallback();
        let result = compute_born_radii_gpu(
            &ctx,
            &[8, 1, 1],
            &[[0.0, 0.0, 0.0], [0.7, 0.0, 0.0]],
            1.4,
        );
        assert_eq!(
            result.err().as_deref(),
            Some("elements/position length mismatch")
        );
    }
}