rust_sasa/
lib.rs

//! RustSASA is a Rust library for computing the absolute solvent accessible surface area (ASA/SASA) of each atom in a given protein structure using the Shrake-Rupley algorithm.
//!
//! # Example
//! ```rust
//! use rust_sasa::{SASAOptions, ResidueLevel};
//!
//! let (pdb, _errors) = pdbtbx::open("./pdbs/example.cif").unwrap();
//! let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
//! ```
10pub mod options;
11// Re-export the new level types and processor trait
12pub use options::{AtomLevel, ChainLevel, ProteinLevel, ResidueLevel, SASAProcessor};
13use utils::consts::ANGLE_INCREMENT;
14pub mod structures;
15mod test;
16mod utils;
17
18pub use crate::options::*;
19pub use crate::structures::atomic::*;
20
21use structures::spatial_grid::SpatialGrid;
22// Re-export io functions for use in the binary crate
23use pulp::Arch;
24use rayon::prelude::*;
25pub use utils::io::{sasa_result_to_json, sasa_result_to_protein_object, sasa_result_to_xml};
26
/// Unit-sphere sample points in structure-of-arrays layout.
///
/// Point `i` is `(x[i], y[i], z[i])`. Each axis lives in its own contiguous vector so the SASA
/// kernel can view it directly as a slice of SIMD lanes.
struct SpherePointsSoA {
    x: Vec<f32>,
    y: Vec<f32>,
    z: Vec<f32>,
}

impl SpherePointsSoA {
    /// Number of stored points.
    ///
    /// Assumes the three coordinate vectors have equal length (as produced by the point
    /// generator); only the `x` vector is inspected.
    fn len(&self) -> usize {
        let Self { x, .. } = self;
        Vec::len(x)
    }
}
38
39/// Generates points on a sphere using the Golden Section Spiral algorithm
40fn generate_sphere_points(n_points: usize) -> SpherePointsSoA {
41    let mut x = Vec::with_capacity(n_points);
42    let mut y = Vec::with_capacity(n_points);
43    let mut z = Vec::with_capacity(n_points);
44
45    let inv_n_points = 1.0 / n_points as f32;
46
47    for i in 0..n_points {
48        let i_f32 = i as f32;
49        let t = i_f32 * inv_n_points;
50        let inclination = (1.0 - 2.0 * t).acos();
51        let azimuth = ANGLE_INCREMENT * i_f32;
52
53        // Use sin_cos for better performance
54        let (sin_azimuth, cos_azimuth) = azimuth.sin_cos();
55        let sin_inclination = inclination.sin();
56
57        x.push(sin_inclination * cos_azimuth);
58        y.push(sin_inclination * sin_azimuth);
59        z.push(inclination.cos());
60    }
61
62    SpherePointsSoA { x, y, z }
63}
64
65#[inline(never)]
66fn precompute_neighbors(
67    atoms: &[Atom],
68    probe_radius: f32,
69    max_radii: f32,
70) -> Vec<Vec<NeighborData>> {
71    let cell_size = probe_radius + max_radii;
72    let grid = SpatialGrid::new(atoms, cell_size);
73
74    let mut neighbors = Vec::with_capacity(atoms.len());
75    let mut temp_candidates = Vec::with_capacity(64); // Reuse buffer to avoid allocations
76    let sr = probe_radius + (max_radii * 2.0);
77    let sr_squared = sr * sr;
78
79    for atom in atoms.iter() {
80        let xyz = atom.position.coords.xyz();
81        grid.locate_within_distance([xyz[0], xyz[1], xyz[2]], sr_squared, &mut temp_candidates);
82        // Sort the candidates so the closest neighbors appears first.
83        // This maximizes the chance of an early exit in is_accessible_precomputed
84        let center_pos = atom.position;
85        temp_candidates.sort_unstable_by(|&a_idx, &b_idx| {
86            let pos_a = atoms[a_idx].position;
87            let pos_b = atoms[b_idx].position;
88
89            let dx_a = center_pos.x - pos_a.x;
90            let dy_a = center_pos.y - pos_a.y;
91            let dz_a = center_pos.z - pos_a.z;
92            let dist_sq_a = dx_a * dx_a + dy_a * dy_a + dz_a * dz_a;
93
94            let dx_b = center_pos.x - pos_b.x;
95            let dy_b = center_pos.y - pos_b.y;
96            let dz_b = center_pos.z - pos_b.z;
97            let dist_sq_b = dx_b * dx_b + dy_b * dy_b + dz_b * dz_b;
98
99            dist_sq_a
100                .partial_cmp(&dist_sq_b)
101                .unwrap_or(std::cmp::Ordering::Equal)
102        });
103
104        // Precompute squared thresholds for each neighbor
105        let mut neighbor_data = Vec::with_capacity(temp_candidates.len());
106        for &idx in &temp_candidates {
107            let neighbor = &atoms[idx];
108            let threshold = neighbor.radius + probe_radius;
109            neighbor_data.push(NeighborData {
110                idx: idx as u32,
111                threshold_squared: threshold * threshold,
112            });
113        }
114
115        neighbors.push(neighbor_data);
116    }
117
118    neighbors
119}
120
121struct AtomSasaKernel<'a> {
122    atom_index: usize,
123    atoms: &'a [Atom],
124    neighbors: &'a [NeighborData],
125    sphere_points: &'a SpherePointsSoA,
126    probe_radius: f32,
127}
128
129impl<'a> pulp::WithSimd for AtomSasaKernel<'a> {
130    type Output = f32;
131
132    #[inline(always)]
133    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
134        let atom = &self.atoms[self.atom_index];
135        let center_pos = atom.position;
136        let r = atom.radius + self.probe_radius;
137        let r2 = r * r;
138
139        let (sx_chunks, sx_rem) = S::as_simd_f32s(&self.sphere_points.x);
140        let (sy_chunks, sy_rem) = S::as_simd_f32s(&self.sphere_points.y);
141        let (sz_chunks, sz_rem) = S::as_simd_f32s(&self.sphere_points.z);
142
143        let mut accessible_points = 0.0;
144
145        // Precompute true mask for NOT operation
146        let zero = simd.splat_f32s(0.0);
147        let true_mask = simd.equal_f32s(zero, zero);
148
149        // Process chunks
150        for i in 0..sx_chunks.len() {
151            let sx = sx_chunks[i];
152            let sy = sy_chunks[i];
153            let sz = sz_chunks[i];
154
155            // Initialize with all false (0.0).
156            let mut chunk_mask = simd.less_than_f32s(simd.splat_f32s(1.0), simd.splat_f32s(0.0));
157
158            for neighbor in self.neighbors {
159                if self.atoms[neighbor.idx as usize].id == atom.id {
160                    continue;
161                }
162
163                let neighbor_pos = self.atoms[neighbor.idx as usize].position;
164                let vx_scalar = center_pos.x - neighbor_pos.x;
165                let vy_scalar = center_pos.y - neighbor_pos.y;
166                let vz_scalar = center_pos.z - neighbor_pos.z;
167                let v_mag_sq =
168                    vx_scalar * vx_scalar + vy_scalar * vy_scalar + vz_scalar * vz_scalar;
169
170                let t = neighbor.threshold_squared;
171                let limit_scalar = (t - v_mag_sq - r2) / (2.0 * r);
172
173                let vx = simd.splat_f32s(vx_scalar);
174                let vy = simd.splat_f32s(vy_scalar);
175                let vz = simd.splat_f32s(vz_scalar);
176                let limit = simd.splat_f32s(limit_scalar);
177
178                let dot =
179                    simd.mul_add_f32s(sx, vx, simd.mul_add_f32s(sy, vy, simd.mul_f32s(sz, vz)));
180
181                let occ = simd.less_than_f32s(dot, limit);
182                chunk_mask = simd.or_m32s(chunk_mask, occ);
183
184                let not_occ = simd.xor_m32s(chunk_mask, true_mask);
185                if simd.first_true_m32s(not_occ) == S::F32_LANES {
186                    break;
187                }
188            }
189
190            // accessible points are !chunk_mask
191            let not_occ = simd.xor_m32s(chunk_mask, true_mask);
192            let contribution =
193                simd.select_f32s(not_occ, simd.splat_f32s(1.0), simd.splat_f32s(0.0));
194            accessible_points += simd.reduce_sum_f32s(contribution);
195        }
196
197        // Process remainder
198        for i in 0..sx_rem.len() {
199            let sx = sx_rem[i];
200            let sy = sy_rem[i];
201            let sz = sz_rem[i];
202            let mut occluded = false;
203
204            for neighbor in self.neighbors {
205                if self.atoms[neighbor.idx as usize].id == atom.id {
206                    continue;
207                }
208                let n_pos = self.atoms[neighbor.idx as usize].position;
209                let vx = center_pos.x - n_pos.x;
210                let vy = center_pos.y - n_pos.y;
211                let vz = center_pos.z - n_pos.z;
212                let v_mag_sq = vx * vx + vy * vy + vz * vz;
213
214                let t = neighbor.threshold_squared;
215                let limit = (t - v_mag_sq - r2) / (2.0 * r);
216
217                let dot = sx * vx + sy * vy + sz * vz;
218                if dot < limit {
219                    occluded = true;
220                    break;
221                }
222            }
223
224            if !occluded {
225                accessible_points += 1.0;
226            }
227        }
228
229        let surface_area = 4.0 * std::f32::consts::PI * r2;
230        let inv_n_points = 1.0 / (self.sphere_points.len() as f32);
231        surface_area * accessible_points * inv_n_points
232    }
233}
234
235/// Takes the probe radius and number of points to use along with a list of Atoms as inputs and returns a Vec with SASA values for each atom.
236/// For most users it is recommend that you use `calculate_sasa` instead. This method can be used directly if you do not want to use pdbtbx to load PDB/mmCIF files or want to load them from a different source.
237/// Probe Radius Default: 1.4
238/// Point Count Default: 100
239/// ## Example using pdbtbx:
240/// ```
241/// use nalgebra::{Point3, Vector3};
242/// use pdbtbx::StrictnessLevel;
243/// use rust_sasa::{Atom, calculate_sasa_internal};
244/// let (mut pdb, _errors) = pdbtbx::open(
245///             "./pdbs/example.cif",
246///         ).unwrap();
247/// let mut atoms = vec![];
248/// for atom in pdb.atoms() {
249///     atoms.push(Atom {
250///                 position: Point3::new(atom.pos().0 as f32, atom.pos().1 as f32, atom.pos().2 as f32),
251///                 radius: atom.element().unwrap().atomic_radius().van_der_waals.unwrap() as f32,
252///                 id: atom.serial_number(),
253///                 parent_id: None
254///     })
255///  }
256///  let sasa = calculate_sasa_internal(&atoms, 1.4, 100,true);
257/// ```
258pub fn calculate_sasa_internal(
259    atoms: &[Atom],
260    probe_radius: f32,
261    n_points: usize,
262    parallel: bool,
263) -> Vec<f32> {
264    let sphere_points = generate_sphere_points(n_points);
265
266    let mut max_radii = 0.0;
267    for atom in atoms {
268        if atom.radius > max_radii {
269            max_radii = atom.radius;
270        }
271    }
272
273    // Use precomputed neighbors
274    let neighbor_indices = precompute_neighbors(atoms, probe_radius, max_radii);
275
276    let arch = Arch::new();
277
278    // Helper closure to wrap the kernel dispatch
279    let process_atom = |(i, neighbors): (usize, &Vec<NeighborData>)| {
280        arch.dispatch(AtomSasaKernel {
281            atom_index: i,
282            atoms,
283            neighbors,
284            sphere_points: &sphere_points,
285            probe_radius,
286        })
287    };
288
289    if parallel {
290        neighbor_indices
291            .par_iter()
292            .enumerate()
293            .map(process_atom)
294            .collect()
295    } else {
296        neighbor_indices
297            .iter()
298            .enumerate()
299            .map(process_atom)
300            .collect()
301    }
302}