rust_sasa/
lib.rs

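//! RustSASA is a Rust library for computing the absolute solvent accessible surface area (ASA/SASA) of each atom in a given protein structure using the Shrake-Rupley algorithm[1].
//!
//! [1] A. Shrake and J. A. Rupley, "Environment and exposure to solvent of protein atoms. Lysozyme and insulin", J. Mol. Biol. 79(2):351–371, 1973.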
2mod consts;
3mod test;
4mod utils;
5
6use crate::consts::POLAR_AMINO_ACIDS;
7use crate::utils::{serialize_chain_id, simd_sum};
8use nalgebra::{Point3, Vector3};
9use pdbtbx::PDB;
10use rayon::prelude::*;
11use rstar::{PointDistance, RTree, RTreeObject, AABB};
12use snafu::prelude::*;
13use snafu::OptionExt;
14use std::collections::HashMap;
15use std::sync::Arc;
16
17/// This struct represents an individual Atom
18#[derive(Clone)]
19pub struct Atom {
20    /// The 3D position of the atom
21    pub position: Point3<f32>,
22    /// The Van Der Walls radius of the atom
23    pub radius: f32,
24    /// A unique Id for the atom
25    pub id: usize,
26    /// Parent Id
27    pub parent_id: Option<isize>,
28}
29
/// Output resolution of [`calculate_sasa`]; the returned [`SASAResult`] variant matches it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SASALevel {
    /// One SASA value per atom.
    Atom,
    /// One value per residue, together with its name, chain and polarity.
    Residue,
    /// One value per chain.
    Chain,
    /// Whole-structure totals, split into polar and non-polar residue contributions.
    Protein,
}
37
/// SASA of one chain, as produced by [`calculate_sasa`] at [`SASALevel::Chain`].
#[derive(Debug, Clone, PartialEq)]
pub struct ChainResult {
    /// Chain identifier (the chain id as reported by pdbtbx).
    pub name: String,
    /// Sum of the SASA of every atom in the chain.
    pub value: f32,
}
45
/// SASA of one residue, as produced by [`calculate_sasa`] at [`SASALevel::Residue`].
#[derive(Debug, Clone, PartialEq)]
pub struct ResidueResult {
    /// Residue serial number (not unique across chains; pair it with `chain_id`).
    pub serial_number: isize,
    /// Sum of the SASA of every atom in the residue.
    pub value: f32,
    /// The name of the residue.
    pub name: String,
    /// Whether the residue name is listed in the crate's polar amino-acid table.
    pub is_polar: bool,
    /// Identifier of the chain the residue belongs to.
    pub chain_id: String,
}
59
/// Whole-structure SASA totals, as produced by [`calculate_sasa`] at [`SASALevel::Protein`].
#[derive(Debug, Clone, PartialEq)]
pub struct ProteinResult {
    /// The total SASA value for the entire protein (sum over all atoms).
    pub global_total: f32,
    /// Total SASA of residues whose name is listed in the crate's polar amino-acid table.
    pub polar_total: f32,
    /// Total SASA of all other residues (the *non*-polar part).
    pub non_polar_total: f32,
}
69
70#[derive(Debug, PartialEq)]
71pub enum SASAResult {
72    Atom(Vec<f32>),
73    Residue(Vec<ResidueResult>),
74    Chain(Vec<ChainResult>),
75    Protein(ProteinResult),
76}
77
78#[derive(Debug, Snafu)]
79pub enum SASACalcError {
80    #[snafu(display("Element missing for atom"))]
81    ElementMissing,
82
83    #[snafu(display("Van der Waals radius missing for element"))]
84    VanDerWaalsMissing,
85
86    #[snafu(display("Failed to map atoms back to level element"))]
87    AtomMapToLevelElementFailed,
88
89    #[snafu(display("Failed to get residue name"))]
90    FailedToGetResidueName,
91}
92
93impl RTreeObject for Atom {
94    type Envelope = AABB<[f32; 3]>;
95
96    fn envelope(&self) -> Self::Envelope {
97        AABB::from_point(<[f32; 3]>::from(self.position))
98    }
99}
100
101impl PointDistance for Atom {
102    fn distance_2(&self, other: &[f32; 3]) -> f32 {
103        let xyz = self.position.coords.xyz();
104        let z = xyz[2];
105        let y = xyz[1];
106        let x = xyz[0];
107        // No square root as that is required by the package
108        (other[2] - z).mul_add(
109            other[2] - z,
110            (other[1] - y).mul_add(other[1] - y, (other[0] - x).powi(2)),
111        )
112    }
113}
114
115/// Generates points on a sphere using the Golden Section Spiral algorithm
116fn generate_sphere_points(n_points: usize) -> Vec<Vector3<f32>> {
117    let mut points = Vec::with_capacity(n_points);
118    let golden_ratio = (1.0 + 5f32.sqrt()) / 2.0;
119    let angle_increment = 2.0 * std::f32::consts::PI * golden_ratio;
120
121    for i in 0..n_points {
122        let t = i as f32 / n_points as f32;
123        let inclination = (1.0 - 2.0 * t).acos();
124        let azimuth = angle_increment * i as f32;
125
126        let x = inclination.sin() * azimuth.cos();
127        let y = inclination.sin() * azimuth.sin();
128        let z = inclination.cos();
129
130        points.push(Vector3::new(x, y, z));
131    }
132
133    points
134}
135
136fn is_accessible_rstar(
137    test_point: &Point3<f32>,
138    atom: &Atom,
139    atoms: &RTree<Atom>,
140    probe_radius: f32,
141    max_radii: f32,
142) -> bool {
143    let xyz = test_point.coords.xyz();
144    let sr = probe_radius + (max_radii * 2.0);
145    let candidates = atoms.locate_within_distance([xyz[0], xyz[1], xyz[2]], sr * sr);
146    for candidate in candidates {
147        if atom.id != candidate.id
148            && (test_point - candidate.position).norm() < (candidate.radius + probe_radius)
149        {
150            return false;
151        }
152    }
153    true
154}
155
156/// Takes the probe radius and number of points to use along with a list of Atoms as inputs and returns a Vec with SASA values for each atom.
157/// For most users it is recommend that you use `calculate_sasa` instead. This method can be used directly if you do not want to use pdbtbx to load PDB/mmCIF files or want to load them from a different source.
158/// Probe Radius Default: 1.4
159/// Point Count Default: 100
160/// ## Example using pdbtbx:
161/// ```
162/// use nalgebra::{Point3, Vector3};
163/// use pdbtbx::StrictnessLevel;
164/// use rust_sasa::{Atom, calculate_sasa_internal};
165/// let (mut pdb, _errors) = pdbtbx::open(
166///             "./example.cif",
167///         ).unwrap();
168/// let mut atoms = vec![];
169/// for atom in pdb.atoms() {
170///     atoms.push(Atom {
171///                 position: Point3::new(atom.pos().0 as f32, atom.pos().1 as f32, atom.pos().2 as f32),
172///                 radius: atom.element().unwrap().atomic_radius().van_der_waals.unwrap() as f32,
173///                 id: atom.serial_number(),
174///                 parent_id: None
175///     })
176///  }
177///  let sasa = calculate_sasa_internal(&atoms, None, None);
178/// ```
179pub fn calculate_sasa_internal(
180    atoms: &[Atom],
181    in_probe_radius: Option<f32>,
182    in_n_points: Option<usize>,
183) -> Vec<f32> {
184    // Load defaults if not specified
185    let mut probe_radius = 1.4;
186    let mut n_points = 100;
187    if let Some(in_probe_radius) = in_probe_radius {
188        probe_radius = in_probe_radius;
189    }
190    if let Some(in_n_points) = in_n_points {
191        n_points = in_n_points;
192    }
193    //
194    let sphere_points = generate_sphere_points(n_points);
195
196    // Create R*-tree from atoms for spatial lookup
197    let tree = RTree::bulk_load(atoms.to_vec());
198    let tree_arc = Arc::new(tree); // Use Arc for safe sharing among threads
199    let mut max_radii = 0.0;
200    for atom in atoms {
201        if atom.radius > max_radii {
202            max_radii = atom.radius;
203        }
204    }
205    atoms
206        .par_iter()
207        .map(|atom| {
208            let mut accessible_points = 0;
209
210            for sphere_point in &sphere_points {
211                let test_point = atom.position + sphere_point * (atom.radius + probe_radius);
212                if is_accessible_rstar(&test_point, atom, &tree_arc, probe_radius, max_radii) {
213                    accessible_points += 1;
214                }
215            }
216            4.0 * std::f32::consts::PI
217                * (atom.radius + probe_radius).powi(2)
218                * (accessible_points as f32)
219                / (n_points as f32)
220        })
221        .collect()
222}
223
224/// This function calculates the SASA for a given protein. The output level can be specified with the level attribute e.g: (SASALevel::Atom,SASALevel::Residue,etc...).
225/// Probe radius and n_points can be customized if not customized will default to 1.4, and 100 respectively.
226/// If you want more fine-grained control you may want to use [calculate_sasa_internal] instead.
227/// ## Example
228/// ```
229/// use pdbtbx::StrictnessLevel;
230/// use rust_sasa::{Atom, calculate_sasa, calculate_sasa_internal, SASALevel};
231/// let (mut pdb, _errors) = pdbtbx::open(
232///             "./example.cif",
233/// ).unwrap();
234/// let result = calculate_sasa(&pdb,None,None,SASALevel::Residue);
235/// ```
236pub fn calculate_sasa(
237    pdb: &PDB,
238    probe_radius: Option<f32>,
239    n_points: Option<usize>,
240    level: SASALevel,
241) -> Result<SASAResult, SASACalcError> {
242    let mut atoms = vec![];
243    let mut parent_to_atoms = HashMap::new();
244    match level {
245        SASALevel::Atom => {
246            for atom in pdb.atoms() {
247                atoms.push(Atom {
248                    position: Point3::new(
249                        atom.pos().0 as f32,
250                        atom.pos().1 as f32,
251                        atom.pos().2 as f32,
252                    ),
253                    radius: atom
254                        .element()
255                        .context(ElementMissingSnafu)?
256                        .atomic_radius()
257                        .van_der_waals
258                        .context(VanDerWaalsMissingSnafu)? as f32,
259                    id: atom.serial_number(),
260                    parent_id: None,
261                })
262            }
263        }
264        SASALevel::Residue | SASALevel::Protein => {
265            let mut i = 0;
266            for residue in pdb.residues() {
267                let mut temp = vec![];
268                for atom in residue.atoms() {
269                    atoms.push(Atom {
270                        position: Point3::new(
271                            atom.pos().0 as f32,
272                            atom.pos().1 as f32,
273                            atom.pos().2 as f32,
274                        ),
275                        radius: atom
276                            .element()
277                            .context(ElementMissingSnafu)?
278                            .atomic_radius()
279                            .van_der_waals
280                            .context(VanDerWaalsMissingSnafu)?
281                            as f32,
282                        id: atom.serial_number(),
283                        parent_id: Some(residue.serial_number()),
284                    });
285                    temp.push(i);
286                    i += 1;
287                }
288                parent_to_atoms.insert(residue.serial_number(), temp);
289            }
290        }
291        SASALevel::Chain => {
292            let mut i = 0;
293            for chain in pdb.chains() {
294                let mut temp = vec![];
295                let chain_id = serialize_chain_id(chain.id());
296                for atom in chain.atoms() {
297                    atoms.push(Atom {
298                        position: Point3::new(
299                            atom.pos().0 as f32,
300                            atom.pos().1 as f32,
301                            atom.pos().2 as f32,
302                        ),
303                        radius: atom
304                            .element()
305                            .context(ElementMissingSnafu)?
306                            .atomic_radius()
307                            .van_der_waals
308                            .context(VanDerWaalsMissingSnafu)?
309                            as f32,
310                        id: atom.serial_number(),
311                        parent_id: Some(chain_id),
312                    });
313                    temp.push(i);
314                    i += 1
315                }
316                parent_to_atoms.insert(chain_id, temp);
317            }
318        }
319    }
320    let atom_sasa = calculate_sasa_internal(&atoms, probe_radius, n_points);
321    return match level {
322        SASALevel::Atom => Ok(SASAResult::Atom(atom_sasa)),
323        SASALevel::Chain => {
324            let mut chain_sasa = vec![];
325            for chain in pdb.chains() {
326                let chain_id = serialize_chain_id(chain.id());
327                let chain_atom_index = parent_to_atoms
328                    .get(&chain_id)
329                    .context(AtomMapToLevelElementFailedSnafu)?;
330                let chain_atoms: Vec<_> = chain_atom_index
331                    .iter()
332                    .map(|&index| atom_sasa[index])
333                    .collect();
334                let sum = simd_sum(chain_atoms.as_slice());
335                chain_sasa.push(ChainResult {
336                    name: chain.id().to_string(),
337                    value: sum,
338                })
339            }
340            Ok(SASAResult::Chain(chain_sasa))
341        }
342        SASALevel::Residue => {
343            let mut residue_sasa = vec![];
344            for chain in pdb.chains() {
345                for residue in chain.residues() {
346                    let residue_atom_index = parent_to_atoms
347                        .get(&residue.serial_number())
348                        .context(AtomMapToLevelElementFailedSnafu)?;
349                    let residue_atoms: Vec<_> = residue_atom_index
350                        .iter()
351                        .map(|&index| atom_sasa[index])
352                        .collect();
353                    let sum = simd_sum(residue_atoms.as_slice());
354                    let name = residue
355                        .name()
356                        .context(FailedToGetResidueNameSnafu)?
357                        .to_string();
358                    residue_sasa.push(ResidueResult {
359                        serial_number: residue.serial_number(),
360                        value: sum,
361                        is_polar: POLAR_AMINO_ACIDS.contains(&name),
362                        chain_id: chain.id().to_string(),
363                        name,
364                    })
365                }
366            }
367            Ok(SASAResult::Residue(residue_sasa))
368        }
369        SASALevel::Protein => {
370            let mut polar_total: f32 = 0.0;
371            let mut non_polar_total: f32 = 0.0;
372            for residue in pdb.residues() {
373                let residue_atom_index = parent_to_atoms
374                    .get(&residue.serial_number())
375                    .context(AtomMapToLevelElementFailedSnafu)?;
376                let residue_atoms: Vec<_> = residue_atom_index
377                    .iter()
378                    .map(|&index| atom_sasa[index])
379                    .collect();
380                let sum = simd_sum(residue_atoms.as_slice());
381                let name = residue
382                    .name()
383                    .context(FailedToGetResidueNameSnafu)?
384                    .to_string();
385                if POLAR_AMINO_ACIDS.contains(&name) {
386                    polar_total += sum
387                } else {
388                    non_polar_total += sum
389                }
390            }
391            let global_sum = simd_sum(atom_sasa.as_slice());
392            Ok(SASAResult::Protein(ProteinResult {
393                global_total: global_sum,
394                polar_total,
395                non_polar_total,
396            }))
397        }
398    };
399}