1mod consts;
3mod test;
4mod utils;
5
6use crate::consts::POLAR_AMINO_ACIDS;
7use crate::utils::{serialize_chain_id, simd_sum};
8use nalgebra::{Point3, Vector3};
9use pdbtbx::PDB;
10use rayon::prelude::*;
11use rstar::{PointDistance, RTree, RTreeObject, AABB};
12use snafu::prelude::*;
13use snafu::OptionExt;
14use std::collections::HashMap;
15use std::sync::Arc;
16
17#[derive(Clone)]
19pub struct Atom {
20 pub position: Point3<f32>,
22 pub radius: f32,
24 pub id: usize,
26 pub parent_id: Option<isize>,
28}
29
/// Granularity at which [`calculate_sasa`] aggregates the per-atom SASA values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SASALevel {
    /// One SASA value per atom ([`SASAResult::Atom`]).
    Atom,
    /// One entry per residue ([`SASAResult::Residue`]).
    Residue,
    /// One entry per chain ([`SASAResult::Chain`]).
    Chain,
    /// Totals for the whole structure ([`SASAResult::Protein`]).
    Protein,
}
37
/// SASA aggregated over one chain of the structure.
#[derive(Debug, PartialEq)]
pub struct ChainResult {
    /// Chain identifier, as reported by the structure.
    pub name: String,
    /// Sum of the SASA of every atom in the chain.
    pub value: f32,
}
45
/// SASA aggregated over one residue of the structure.
#[derive(Debug, PartialEq)]
pub struct ResidueResult {
    /// Residue serial number, as reported by the structure.
    pub serial_number: isize,
    /// Sum of the SASA of every atom in the residue.
    pub value: f32,
    /// Residue name (e.g. the three-letter amino-acid code).
    pub name: String,
    /// Whether `name` appears in the crate's list of polar amino acids.
    pub is_polar: bool,
    /// Identifier of the chain that contains the residue.
    pub chain_id: String,
}
59
/// SASA totals for the whole structure.
#[derive(Debug, PartialEq)]
pub struct ProteinResult {
    /// Sum of the SASA of every atom.
    pub global_total: f32,
    /// Sum of the SASA of residues whose name is in the polar amino-acid list.
    pub polar_total: f32,
    /// Sum of the SASA of all remaining residues.
    pub non_polar_total: f32,
}
69
70#[derive(Debug, PartialEq)]
71pub enum SASAResult {
72 Atom(Vec<f32>),
73 Residue(Vec<ResidueResult>),
74 Chain(Vec<ChainResult>),
75 Protein(ProteinResult),
76}
77
78#[derive(Debug, Snafu)]
79pub enum SASACalcError {
80 #[snafu(display("Element missing for atom"))]
81 ElementMissing,
82
83 #[snafu(display("Van der Waals radius missing for element"))]
84 VanDerWaalsMissing,
85
86 #[snafu(display("Failed to map atoms back to level element"))]
87 AtomMapToLevelElementFailed,
88
89 #[snafu(display("Failed to get residue name"))]
90 FailedToGetResidueName,
91}
92
93impl RTreeObject for Atom {
94 type Envelope = AABB<[f32; 3]>;
95
96 fn envelope(&self) -> Self::Envelope {
97 AABB::from_point(<[f32; 3]>::from(self.position))
98 }
99}
100
101impl PointDistance for Atom {
102 fn distance_2(&self, other: &[f32; 3]) -> f32 {
103 let xyz = self.position.coords.xyz();
104 let z = xyz[2];
105 let y = xyz[1];
106 let x = xyz[0];
107 (other[2] - z).mul_add(
109 other[2] - z,
110 (other[1] - y).mul_add(other[1] - y, (other[0] - x).powi(2)),
111 )
112 }
113}
114
115fn generate_sphere_points(n_points: usize) -> Vec<Vector3<f32>> {
117 let mut points = Vec::with_capacity(n_points);
118 let golden_ratio = (1.0 + 5f32.sqrt()) / 2.0;
119 let angle_increment = 2.0 * std::f32::consts::PI * golden_ratio;
120
121 for i in 0..n_points {
122 let t = i as f32 / n_points as f32;
123 let inclination = (1.0 - 2.0 * t).acos();
124 let azimuth = angle_increment * i as f32;
125
126 let x = inclination.sin() * azimuth.cos();
127 let y = inclination.sin() * azimuth.sin();
128 let z = inclination.cos();
129
130 points.push(Vector3::new(x, y, z));
131 }
132
133 points
134}
135
136fn is_accessible_rstar(
137 test_point: &Point3<f32>,
138 atom: &Atom,
139 atoms: &RTree<Atom>,
140 probe_radius: f32,
141 max_radii: f32,
142) -> bool {
143 let xyz = test_point.coords.xyz();
144 let sr = probe_radius + (max_radii * 2.0);
145 let candidates = atoms.locate_within_distance([xyz[0], xyz[1], xyz[2]], sr * sr);
146 for candidate in candidates {
147 if atom.id != candidate.id
148 && (test_point - candidate.position).norm() < (candidate.radius + probe_radius)
149 {
150 return false;
151 }
152 }
153 true
154}
155
156pub fn calculate_sasa_internal(
180 atoms: &[Atom],
181 in_probe_radius: Option<f32>,
182 in_n_points: Option<usize>,
183) -> Vec<f32> {
184 let mut probe_radius = 1.4;
186 let mut n_points = 100;
187 if let Some(in_probe_radius) = in_probe_radius {
188 probe_radius = in_probe_radius;
189 }
190 if let Some(in_n_points) = in_n_points {
191 n_points = in_n_points;
192 }
193 let sphere_points = generate_sphere_points(n_points);
195
196 let tree = RTree::bulk_load(atoms.to_vec());
198 let tree_arc = Arc::new(tree); let mut max_radii = 0.0;
200 for atom in atoms {
201 if atom.radius > max_radii {
202 max_radii = atom.radius;
203 }
204 }
205 atoms
206 .par_iter()
207 .map(|atom| {
208 let mut accessible_points = 0;
209
210 for sphere_point in &sphere_points {
211 let test_point = atom.position + sphere_point * (atom.radius + probe_radius);
212 if is_accessible_rstar(&test_point, atom, &tree_arc, probe_radius, max_radii) {
213 accessible_points += 1;
214 }
215 }
216 4.0 * std::f32::consts::PI
217 * (atom.radius + probe_radius).powi(2)
218 * (accessible_points as f32)
219 / (n_points as f32)
220 })
221 .collect()
222}
223
224pub fn calculate_sasa(
237 pdb: &PDB,
238 probe_radius: Option<f32>,
239 n_points: Option<usize>,
240 level: SASALevel,
241) -> Result<SASAResult, SASACalcError> {
242 let mut atoms = vec![];
243 let mut parent_to_atoms = HashMap::new();
244 match level {
245 SASALevel::Atom => {
246 for atom in pdb.atoms() {
247 atoms.push(Atom {
248 position: Point3::new(
249 atom.pos().0 as f32,
250 atom.pos().1 as f32,
251 atom.pos().2 as f32,
252 ),
253 radius: atom
254 .element()
255 .context(ElementMissingSnafu)?
256 .atomic_radius()
257 .van_der_waals
258 .context(VanDerWaalsMissingSnafu)? as f32,
259 id: atom.serial_number(),
260 parent_id: None,
261 })
262 }
263 }
264 SASALevel::Residue | SASALevel::Protein => {
265 let mut i = 0;
266 for residue in pdb.residues() {
267 let mut temp = vec![];
268 for atom in residue.atoms() {
269 atoms.push(Atom {
270 position: Point3::new(
271 atom.pos().0 as f32,
272 atom.pos().1 as f32,
273 atom.pos().2 as f32,
274 ),
275 radius: atom
276 .element()
277 .context(ElementMissingSnafu)?
278 .atomic_radius()
279 .van_der_waals
280 .context(VanDerWaalsMissingSnafu)?
281 as f32,
282 id: atom.serial_number(),
283 parent_id: Some(residue.serial_number()),
284 });
285 temp.push(i);
286 i += 1;
287 }
288 parent_to_atoms.insert(residue.serial_number(), temp);
289 }
290 }
291 SASALevel::Chain => {
292 let mut i = 0;
293 for chain in pdb.chains() {
294 let mut temp = vec![];
295 let chain_id = serialize_chain_id(chain.id());
296 for atom in chain.atoms() {
297 atoms.push(Atom {
298 position: Point3::new(
299 atom.pos().0 as f32,
300 atom.pos().1 as f32,
301 atom.pos().2 as f32,
302 ),
303 radius: atom
304 .element()
305 .context(ElementMissingSnafu)?
306 .atomic_radius()
307 .van_der_waals
308 .context(VanDerWaalsMissingSnafu)?
309 as f32,
310 id: atom.serial_number(),
311 parent_id: Some(chain_id),
312 });
313 temp.push(i);
314 i += 1
315 }
316 parent_to_atoms.insert(chain_id, temp);
317 }
318 }
319 }
320 let atom_sasa = calculate_sasa_internal(&atoms, probe_radius, n_points);
321 return match level {
322 SASALevel::Atom => Ok(SASAResult::Atom(atom_sasa)),
323 SASALevel::Chain => {
324 let mut chain_sasa = vec![];
325 for chain in pdb.chains() {
326 let chain_id = serialize_chain_id(chain.id());
327 let chain_atom_index = parent_to_atoms
328 .get(&chain_id)
329 .context(AtomMapToLevelElementFailedSnafu)?;
330 let chain_atoms: Vec<_> = chain_atom_index
331 .iter()
332 .map(|&index| atom_sasa[index])
333 .collect();
334 let sum = simd_sum(chain_atoms.as_slice());
335 chain_sasa.push(ChainResult {
336 name: chain.id().to_string(),
337 value: sum,
338 })
339 }
340 Ok(SASAResult::Chain(chain_sasa))
341 }
342 SASALevel::Residue => {
343 let mut residue_sasa = vec![];
344 for chain in pdb.chains() {
345 for residue in chain.residues() {
346 let residue_atom_index = parent_to_atoms
347 .get(&residue.serial_number())
348 .context(AtomMapToLevelElementFailedSnafu)?;
349 let residue_atoms: Vec<_> = residue_atom_index
350 .iter()
351 .map(|&index| atom_sasa[index])
352 .collect();
353 let sum = simd_sum(residue_atoms.as_slice());
354 let name = residue
355 .name()
356 .context(FailedToGetResidueNameSnafu)?
357 .to_string();
358 residue_sasa.push(ResidueResult {
359 serial_number: residue.serial_number(),
360 value: sum,
361 is_polar: POLAR_AMINO_ACIDS.contains(&name),
362 chain_id: chain.id().to_string(),
363 name,
364 })
365 }
366 }
367 Ok(SASAResult::Residue(residue_sasa))
368 }
369 SASALevel::Protein => {
370 let mut polar_total: f32 = 0.0;
371 let mut non_polar_total: f32 = 0.0;
372 for residue in pdb.residues() {
373 let residue_atom_index = parent_to_atoms
374 .get(&residue.serial_number())
375 .context(AtomMapToLevelElementFailedSnafu)?;
376 let residue_atoms: Vec<_> = residue_atom_index
377 .iter()
378 .map(|&index| atom_sasa[index])
379 .collect();
380 let sum = simd_sum(residue_atoms.as_slice());
381 let name = residue
382 .name()
383 .context(FailedToGetResidueNameSnafu)?
384 .to_string();
385 if POLAR_AMINO_ACIDS.contains(&name) {
386 polar_total += sum
387 } else {
388 non_polar_total += sum
389 }
390 }
391 let global_sum = simd_sum(atom_sasa.as_slice());
392 Ok(SASAResult::Protein(ProteinResult {
393 global_total: global_sum,
394 polar_total,
395 non_polar_total,
396 }))
397 }
398 };
399}