sci_form/gpu/
hct_born_gpu.rs1use super::context::{
4 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
5 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
6 ComputeDispatchDescriptor, GpuContext, UniformValue,
7};
8
/// Minimum number of elements `compute_hct_born_radii_gpu` accepts; inputs with
/// fewer entries are rejected with an error instead of being dispatched.
9const GPU_DISPATCH_THRESHOLD: usize = 8;
10
/// Returns the intrinsic radius associated with element code `z`.
///
/// Codes listed in the table map to their tabulated radius; every other code
/// falls back to `1.70`.
///
/// NOTE(review): `z` is presumably the atomic number (the element symbols in
/// the comments below assume so) and the values look like van der Waals radii
/// in angstroms — confirm against the CPU implementation.
fn intrinsic_born_radius(z: u8) -> f64 {
    /// Radius used for any element code without a table entry.
    const DEFAULT_RADIUS: f64 = 1.70;
    /// `(element code, radius)` pairs.
    const RADIUS_TABLE: [(u8, f64); 10] = [
        (1, 1.20),  // H
        (6, 1.70),  // C
        (7, 1.55),  // N
        (8, 1.52),  // O
        (9, 1.47),  // F
        (15, 1.80), // P
        (16, 1.80), // S
        (17, 1.75), // Cl
        (35, 1.85), // Br
        (53, 1.98), // I
    ];

    RADIUS_TABLE
        .iter()
        .find(|entry| entry.0 == z)
        .map_or(DEFAULT_RADIUS, |entry| entry.1)
}
27
/// Returns the HCT scaling factor for element code `z`.
///
/// Codes 1, 6, 7, 8 and 9 have dedicated factors; any other code gets `0.80`.
///
/// NOTE(review): `z` is presumably the atomic number (H, C, N, O, F for the
/// listed codes) — confirm against the caller.
fn hct_scale(z: u8) -> f64 {
    /// `(element code, scaling factor)` pairs.
    const SCALE_TABLE: &[(u8, f64)] = &[
        (1, 0.85),
        (6, 0.72),
        (7, 0.79),
        (8, 0.85),
        (9, 0.88),
    ];

    for &(code, factor) in SCALE_TABLE {
        if code == z {
            return factor;
        }
    }
    // No dedicated entry: use the generic factor.
    0.80
}
39
40pub fn compute_hct_born_radii_gpu(
42 ctx: &GpuContext,
43 elements: &[u8],
44 positions: &[[f64; 3]],
45) -> Result<Vec<f64>, String> {
46 let n = elements.len();
47 if n < GPU_DISPATCH_THRESHOLD {
48 return Err("System too small for GPU dispatch".to_string());
49 }
50
51 let rho: Vec<f32> = elements
52 .iter()
53 .map(|&z| intrinsic_born_radius(z) as f32)
54 .collect();
55 let scale: Vec<f32> = elements.iter().map(|&z| hct_scale(z) as f32).collect();
56
57 let params = pack_uniform_values(&[
58 UniformValue::U32(n as u32),
59 UniformValue::U32(0),
60 UniformValue::U32(0),
61 UniformValue::U32(0),
62 ]);
63
64 let descriptor = ComputeDispatchDescriptor {
65 label: "hct born radii".to_string(),
66 shader_source: HCT_BORN_RADII_SHADER.to_string(),
67 entry_point: "main".to_string(),
68 workgroup_count: [ceil_div_u32(n, 64), 1, 1],
69 bindings: vec![
70 ComputeBindingDescriptor {
71 label: "positions".to_string(),
72 kind: ComputeBindingKind::StorageReadOnly,
73 bytes: pack_vec3_positions_f32(positions),
74 },
75 ComputeBindingDescriptor {
76 label: "rho".to_string(),
77 kind: ComputeBindingKind::StorageReadOnly,
78 bytes: f32_slice_to_bytes(&rho),
79 },
80 ComputeBindingDescriptor {
81 label: "scale".to_string(),
82 kind: ComputeBindingKind::StorageReadOnly,
83 bytes: f32_slice_to_bytes(&scale),
84 },
85 ComputeBindingDescriptor {
86 label: "params".to_string(),
87 kind: ComputeBindingKind::Uniform,
88 bytes: params,
89 },
90 ComputeBindingDescriptor {
91 label: "output".to_string(),
92 kind: ComputeBindingKind::StorageReadWrite,
93 bytes: f32_slice_to_bytes(&vec![0.0f32; n]),
94 },
95 ],
96 };
97
98 let mut result = ctx.run_compute(&descriptor)?;
99 let bytes = result
100 .outputs
101 .pop()
102 .ok_or("No output from HCT Born kernel")?;
103 if bytes.len() != n * 4 {
104 return Err(format!(
105 "HCT Born output size mismatch: expected {}, got {}",
106 n * 4,
107 bytes.len()
108 ));
109 }
110
111 Ok(bytes_to_f64_vec_from_f32(&bytes))
112}
113
/// WGSL source of the compute kernel dispatched by `compute_hct_born_radii_gpu`.
///
/// Bindings (all in group 0):
/// * binding 0 — `positions`: read-only storage array of `AtomPos`
///   (`x`, `y`, `z` plus one padding `f32`).
/// * binding 1 — `rho`: read-only storage array of `f32`, one per atom.
/// * binding 2 — `scale`: read-only storage array of `f32`, one per atom.
/// * binding 3 — `params`: uniform holding `n_atoms` followed by three padding
///   `u32` words.
/// * binding 4 — `output`: read-write storage array of `f32`, one per atom.
///
/// The entry point `main` uses a workgroup size of 64 and handles one atom `i`
/// per invocation; invocations with `i >= n_atoms` return without writing.
/// For atom `i` the kernel loops over every other atom `j`, computes the
/// distance `rij` and the scaled radius `rho[j] * scale[j]`, and accumulates a
/// pairwise term into `integral` using one of two branches selected by how
/// `rij`, `rho[i]` and the scaled radius compare (pairs matching neither
/// branch contribute nothing). The radius is then `1 / (1 / rho[i] - integral)`
/// when that inverse exceeds `1e-10`, otherwise `50.0`, and the value written
/// to `output[i]` is never smaller than `rho[i]`.
114pub const HCT_BORN_RADII_SHADER: &str = r#"
115struct AtomPos {
116 x: f32, y: f32, z: f32, _pad: f32,
117};
118
119struct Params {
120 n_atoms: u32,
121 _pad0: u32,
122 _pad1: u32,
123 _pad2: u32,
124};
125
126@group(0) @binding(0) var<storage, read> positions: array<AtomPos>;
127@group(0) @binding(1) var<storage, read> rho: array<f32>;
128@group(0) @binding(2) var<storage, read> scale: array<f32>;
129@group(0) @binding(3) var<uniform> params: Params;
130@group(0) @binding(4) var<storage, read_write> output: array<f32>;
131
132@compute @workgroup_size(64, 1, 1)
133fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
134 let i = gid.x;
135 let n = params.n_atoms;
136 if (i >= n) {
137 return;
138 }
139
140 let rho_i = rho[i];
141 let pos_i = positions[i];
142 var integral: f32 = 0.0;
143
144 for (var j: u32 = 0u; j < n; j = j + 1u) {
145 if (i == j) {
146 continue;
147 }
148
149 let pos_j = positions[j];
150 let dx = pos_i.x - pos_j.x;
151 let dy = pos_i.y - pos_j.y;
152 let dz = pos_i.z - pos_j.z;
153 let rij = sqrt(dx * dx + dy * dy + dz * dz);
154 let scaled_rj = rho[j] * scale[j];
155
156 if (rij > rho_i + scaled_rj) {
157 let denom1 = max(rij - scaled_rj, 1e-10);
158 let denom2 = rij + scaled_rj;
159 let denom3 = max(abs(rij * rij - scaled_rj * scaled_rj), 1e-10);
160 var ljr: f32 = 0.0;
161 if (rij > scaled_rj && scaled_rj > 1e-10) {
162 ljr = log(rij / scaled_rj);
163 }
164 integral += 0.5 * (1.0 / denom1 - 1.0 / denom2)
165 + scaled_rj * ljr / (2.0 * rij * max(denom3, 1e-10));
166 } else if (rij + rho_i > scaled_rj) {
167 let denom = max(abs(rij - scaled_rj), 1e-10);
168 integral += 0.5 * (1.0 / denom - 1.0 / (rij + scaled_rj));
169 }
170 }
171
172 let inv_r = 1.0 / rho_i - integral;
173 var born_r: f32;
174 if (inv_r > 1e-10) {
175 born_r = 1.0 / inv_r;
176 } else {
177 born_r = 50.0;
178 }
179 output[i] = max(born_r, rho_i);
180}
181"#;
182
#[cfg(test)]
mod tests {
    use super::HCT_BORN_RADII_SHADER;

    /// Smoke test: the module builds and the embedded shader source is present.
    #[test]
    fn test_hct_born_gpu_module_compiles() {
        assert!(!HCT_BORN_RADII_SHADER.is_empty());
    }

    /// The host code dispatches entry point `main` and sizes the grid in
    /// groups of 64 atoms, so the shader must declare both.
    #[test]
    fn test_shader_declares_entry_point_and_workgroup_size() {
        assert!(
            HCT_BORN_RADII_SHADER.contains("fn main("),
            "shader is missing the `main` entry point"
        );
        assert!(
            HCT_BORN_RADII_SHADER.contains("@workgroup_size(64"),
            "shader workgroup size no longer matches the host-side divisor of 64"
        );
    }

    /// The host builds five binding descriptors (positions, rho, scale,
    /// params, output); the shader must declare a matching slot for each.
    #[test]
    fn test_shader_declares_all_bindings() {
        for slot in 0..5 {
            let declaration = format!("@group(0) @binding({})", slot);
            assert!(
                HCT_BORN_RADII_SHADER.contains(&declaration),
                "shader is missing `{}`",
                declaration
            );
        }
    }
}