Skip to main content

sci_form/gpu/
hct_born_gpu.rs

1//! GPU-accelerated HCT Born-radii evaluation for Generalized Born solvation.
2
3use super::context::{
4    bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
5    pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
6    ComputeDispatchDescriptor, GpuContext, UniformValue,
7};
8
/// Minimum atom count for which [`compute_hct_born_radii_gpu`] dispatches the
/// kernel; smaller systems get an `Err` instead (presumably so the caller can
/// fall back to a CPU path — confirm at call sites).
const GPU_DISPATCH_THRESHOLD: usize = 8;
10
/// Intrinsic (unscaled) Born radius for the HCT model, keyed by atomic number.
///
/// Intended to match the table in `solvation.rs`. Any element without an
/// explicit entry — carbon included — receives the 1.70 default.
fn intrinsic_born_radius(z: u8) -> f64 {
    const DEFAULT_RADIUS: f64 = 1.70;

    match z {
        1 => 1.20,
        7 => 1.55,
        8 => 1.52,
        9 => 1.47,
        15 | 16 => 1.80,
        17 => 1.75,
        35 => 1.85,
        53 => 1.98,
        // Carbon (6) shares the default value, so it needs no arm of its own.
        _ => DEFAULT_RADIUS,
    }
}
27
/// HCT descreening scale factor for an atom with atomic number `z`.
///
/// The kernel multiplies a neighbour's intrinsic radius by this factor to get
/// its descreening radius; unlisted elements use 0.80.
fn hct_scale(z: u8) -> f64 {
    match z {
        1 | 8 => 0.85,
        6 => 0.72,
        7 => 0.79,
        9 => 0.88,
        _ => 0.80,
    }
}
39
40/// Compute HCT Born radii on GPU.
41pub fn compute_hct_born_radii_gpu(
42    ctx: &GpuContext,
43    elements: &[u8],
44    positions: &[[f64; 3]],
45) -> Result<Vec<f64>, String> {
46    let n = elements.len();
47    if n < GPU_DISPATCH_THRESHOLD {
48        return Err("System too small for GPU dispatch".to_string());
49    }
50
51    let rho: Vec<f32> = elements
52        .iter()
53        .map(|&z| intrinsic_born_radius(z) as f32)
54        .collect();
55    let scale: Vec<f32> = elements.iter().map(|&z| hct_scale(z) as f32).collect();
56
57    let params = pack_uniform_values(&[
58        UniformValue::U32(n as u32),
59        UniformValue::U32(0),
60        UniformValue::U32(0),
61        UniformValue::U32(0),
62    ]);
63
64    let descriptor = ComputeDispatchDescriptor {
65        label: "hct born radii".to_string(),
66        shader_source: HCT_BORN_RADII_SHADER.to_string(),
67        entry_point: "main".to_string(),
68        workgroup_count: [ceil_div_u32(n, 64), 1, 1],
69        bindings: vec![
70            ComputeBindingDescriptor {
71                label: "positions".to_string(),
72                kind: ComputeBindingKind::StorageReadOnly,
73                bytes: pack_vec3_positions_f32(positions),
74            },
75            ComputeBindingDescriptor {
76                label: "rho".to_string(),
77                kind: ComputeBindingKind::StorageReadOnly,
78                bytes: f32_slice_to_bytes(&rho),
79            },
80            ComputeBindingDescriptor {
81                label: "scale".to_string(),
82                kind: ComputeBindingKind::StorageReadOnly,
83                bytes: f32_slice_to_bytes(&scale),
84            },
85            ComputeBindingDescriptor {
86                label: "params".to_string(),
87                kind: ComputeBindingKind::Uniform,
88                bytes: params,
89            },
90            ComputeBindingDescriptor {
91                label: "output".to_string(),
92                kind: ComputeBindingKind::StorageReadWrite,
93                bytes: f32_slice_to_bytes(&vec![0.0f32; n]),
94            },
95        ],
96    };
97
98    let mut result = ctx.run_compute(&descriptor)?;
99    let bytes = result
100        .outputs
101        .pop()
102        .ok_or("No output from HCT Born kernel")?;
103    if bytes.len() != n * 4 {
104        return Err(format!(
105            "HCT Born output size mismatch: expected {}, got {}",
106            n * 4,
107            bytes.len()
108        ));
109    }
110
111    Ok(bytes_to_f64_vec_from_f32(&bytes))
112}
113
/// WGSL compute kernel used by [`compute_hct_born_radii_gpu`].
///
/// Each invocation handles one atom `i = global_invocation_id.x`. Invocations
/// with `i >= params.n_atoms` return immediately. For every other atom `j`, it
/// accumulates a pairwise descreening term using the scaled radius
/// `rho[j] * scale[j]`; the loop is O(n) per atom and O(n²) overall. The
/// result is `1 / (1 / rho[i] - integral)`, clamped below by `rho[i]` and
/// written to `output[i]`. When `1 / rho[i] - integral <= 1e-10`, the radius
/// is set to 50.0 instead.
///
/// Bind group 0 layout (must match the host-side binding order):
/// - 0: `positions` — read-only, 16-byte `AtomPos` records (x, y, z, pad);
/// - 1: `rho` — read-only intrinsic radii;
/// - 2: `scale` — read-only HCT scale factors;
/// - 3: `params` — uniform, `n_atoms` followed by three padding words;
/// - 4: `output` — read-write Born radii.
///
/// NOTE(review): the pairwise terms are a simplified form and do not contain
/// every term of the published HCT integral (e.g. the `r/4` terms and the
/// contribution when atom `i` lies entirely inside `j`'s descreening sphere).
/// This is presumably kept to match the CPU implementation in `solvation.rs` —
/// confirm before comparing against reference GB codes.
pub const HCT_BORN_RADII_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(1) var<storage, read> rho: array<f32>;
@group(0) @binding(2) var<storage, read> scale: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    let n = params.n_atoms;
    if (i >= n) {
        return;
    }

    let rho_i = rho[i];
    let pos_i = positions[i];
    var integral: f32 = 0.0;

    for (var j: u32 = 0u; j < n; j = j + 1u) {
        if (i == j) {
            continue;
        }

        let pos_j = positions[j];
        let dx = pos_i.x - pos_j.x;
        let dy = pos_i.y - pos_j.y;
        let dz = pos_i.z - pos_j.z;
        let rij = sqrt(dx * dx + dy * dy + dz * dz);
        let scaled_rj = rho[j] * scale[j];

        if (rij > rho_i + scaled_rj) {
            let denom1 = max(rij - scaled_rj, 1e-10);
            let denom2 = rij + scaled_rj;
            let denom3 = max(abs(rij * rij - scaled_rj * scaled_rj), 1e-10);
            var ljr: f32 = 0.0;
            if (rij > scaled_rj && scaled_rj > 1e-10) {
                ljr = log(rij / scaled_rj);
            }
            integral += 0.5 * (1.0 / denom1 - 1.0 / denom2)
                + scaled_rj * ljr / (2.0 * rij * max(denom3, 1e-10));
        } else if (rij + rho_i > scaled_rj) {
            let denom = max(abs(rij - scaled_rj), 1e-10);
            integral += 0.5 * (1.0 / denom - 1.0 / (rij + scaled_rj));
        }
    }

    let inv_r = 1.0 / rho_i - integral;
    var born_r: f32;
    if (inv_r > 1e-10) {
        born_r = 1.0 / inv_r;
    } else {
        born_r = 50.0;
    }
    output[i] = max(born_r, rho_i);
}
"#;
182
#[cfg(test)]
mod tests {
    use super::{hct_scale, intrinsic_born_radius, HCT_BORN_RADII_SHADER};

    /// Pins the host/shader contract. The host dispatches `ceil(n / 64)`
    /// workgroups against entry point `main` with five bindings, so the WGSL
    /// source must declare the same entry point, workgroup size and binding
    /// slots.
    #[test]
    fn test_hct_born_gpu_module_compiles() {
        assert!(!HCT_BORN_RADII_SHADER.is_empty());
        assert!(HCT_BORN_RADII_SHADER.contains("fn main("));
        assert!(HCT_BORN_RADII_SHADER.contains("@workgroup_size(64, 1, 1)"));
        for binding in 0..5 {
            let tag = format!("@binding({})", binding);
            assert!(
                HCT_BORN_RADII_SHADER.contains(&tag),
                "shader is missing {}",
                tag
            );
        }
    }

    /// Every atomic number must map to a positive radius and a scale in (0, 1);
    /// otherwise the kernel's divisions and logarithm would misbehave.
    #[test]
    fn test_parameter_tables_are_well_formed() {
        for z in 0..=u8::MAX {
            assert!(intrinsic_born_radius(z) > 0.0, "radius for Z={}", z);
            let s = hct_scale(z);
            assert!(s > 0.0 && s < 1.0, "scale for Z={}", z);
        }
        assert_eq!(intrinsic_born_radius(1), 1.20);
        assert_eq!(intrinsic_born_radius(200), 1.70);
        assert_eq!(hct_scale(1), 0.85);
        assert_eq!(hct_scale(200), 0.80);
    }
}