rust_sasa/
lib.rs

1//! RustSASA is a Rust library for computing the absolute solvent accessible surface area (ASA/SASA) of each atom in a given protein structure using the Shrake-Rupley algorithm.
2//! Example:
3//! ```rust
4//! use pdbtbx::StrictnessLevel;
5//! use rust_sasa::{SASAOptions, ResidueLevel};
6//!
7//! let (mut pdb, _errors) = pdbtbx::open("./pdbs/example.cif").unwrap();
8//! let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
9//! ```
10pub mod options;
11// Re-export the new level types and processor trait
12pub use options::{AtomLevel, ChainLevel, ProteinLevel, ResidueLevel, SASAProcessor};
13use utils::consts::ANGLE_INCREMENT;
14pub mod structures;
15mod test;
16mod utils;
17
18pub use crate::options::*;
19pub use crate::structures::atomic::*;
20use crate::utils::ARCH;
21use structures::spatial_grid::SpatialGrid;
22// Re-export io functions for use in the binary crate
23use rayon::prelude::*;
24pub use utils::io::{sasa_result_to_json, sasa_result_to_protein_object, sasa_result_to_xml};
25
/// Sample points on the unit sphere, stored as a structure of arrays.
///
/// Point `i` is `(x[i], y[i], z[i])`. Each axis lives in its own contiguous
/// vector so the SASA kernel can view the coordinates as SIMD chunks.
struct SpherePointsSoA {
    /// X components of the sample points.
    x: Vec<f32>,
    /// Y components of the sample points.
    y: Vec<f32>,
    /// Z components of the sample points.
    z: Vec<f32>,
}

impl SpherePointsSoA {
    /// Number of sample points, read from the length of the `x` vector.
    fn len(&self) -> usize {
        let Self { x, .. } = self;
        x.len()
    }
}
37
38/// Generates points on a sphere using the Golden Section Spiral algorithm
39fn generate_sphere_points(n_points: usize) -> SpherePointsSoA {
40    let mut x = Vec::with_capacity(n_points);
41    let mut y = Vec::with_capacity(n_points);
42    let mut z = Vec::with_capacity(n_points);
43
44    let inv_n_points = 1.0 / n_points as f32;
45
46    for i in 0..n_points {
47        let i_f32 = i as f32;
48        let t = i_f32 * inv_n_points;
49        let inclination = (1.0 - 2.0 * t).acos();
50        let azimuth = ANGLE_INCREMENT * i_f32;
51
52        // Use sin_cos for better performance
53        let (sin_azimuth, cos_azimuth) = azimuth.sin_cos();
54        let sin_inclination = inclination.sin();
55
56        x.push(sin_inclination * cos_azimuth);
57        y.push(sin_inclination * sin_azimuth);
58        z.push(inclination.cos());
59    }
60
61    SpherePointsSoA { x, y, z }
62}
63
64#[inline(never)]
65fn precompute_neighbors(
66    atoms: &[Atom],
67    active_indices: &[usize],
68    probe_radius: f32,
69    max_radii: f32,
70) -> Vec<Vec<NeighborData>> {
71    let cell_size = probe_radius + max_radii;
72    let grid = SpatialGrid::new(atoms, active_indices, cell_size);
73
74    let mut neighbors = Vec::with_capacity(active_indices.len());
75    let mut temp_candidates = Vec::with_capacity(64);
76    let sr = probe_radius + (max_radii * 2.0);
77    let sr_squared = sr * sr;
78
79    for &orig_idx in active_indices {
80        let atom = &atoms[orig_idx];
81        let xyz = atom.position.coords.xyz();
82        grid.locate_within_distance([xyz[0], xyz[1], xyz[2]], sr_squared, &mut temp_candidates);
83
84        // Sort candidates by distance (closest first for early exit optimization)
85        let center_pos = atom.position;
86        temp_candidates.sort_unstable_by(|&a_idx, &b_idx| {
87            let pos_a = atoms[a_idx].position;
88            let pos_b = atoms[b_idx].position;
89
90            let dx_a = center_pos.x - pos_a.x;
91            let dy_a = center_pos.y - pos_a.y;
92            let dz_a = center_pos.z - pos_a.z;
93            let dist_sq_a = dx_a * dx_a + dy_a * dy_a + dz_a * dz_a;
94
95            let dx_b = center_pos.x - pos_b.x;
96            let dy_b = center_pos.y - pos_b.y;
97            let dz_b = center_pos.z - pos_b.z;
98            let dist_sq_b = dx_b * dx_b + dy_b * dy_b + dz_b * dz_b;
99
100            dist_sq_a
101                .partial_cmp(&dist_sq_b)
102                .unwrap_or(std::cmp::Ordering::Equal)
103        });
104
105        // Precompute squared thresholds for each neighbor
106        let mut neighbor_data = Vec::with_capacity(temp_candidates.len());
107        for &idx in &temp_candidates {
108            let neighbor = &atoms[idx];
109            let threshold = neighbor.radius + probe_radius;
110            neighbor_data.push(NeighborData {
111                idx: idx as u32,
112                threshold_squared: threshold * threshold,
113            });
114        }
115
116        neighbors.push(neighbor_data);
117    }
118
119    neighbors
120}
121
122struct AtomSasaKernel<'a> {
123    atom_index: usize,
124    atoms: &'a [Atom],
125    neighbors: &'a [NeighborData],
126    sphere_points: &'a SpherePointsSoA,
127    probe_radius: f32,
128}
129
130impl<'a> pulp::WithSimd for AtomSasaKernel<'a> {
131    type Output = f32;
132
133    #[inline(always)]
134    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
135        let atom = &self.atoms[self.atom_index];
136        let center_pos = atom.position;
137        let r = atom.radius + self.probe_radius;
138        let r2 = r * r;
139
140        let (sx_chunks, sx_rem) = S::as_simd_f32s(&self.sphere_points.x);
141        let (sy_chunks, sy_rem) = S::as_simd_f32s(&self.sphere_points.y);
142        let (sz_chunks, sz_rem) = S::as_simd_f32s(&self.sphere_points.z);
143
144        let mut accessible_points = 0.0;
145
146        // Precompute true mask for NOT operation
147        let zero = simd.splat_f32s(0.0);
148        let true_mask = simd.equal_f32s(zero, zero);
149
150        // Process chunks
151        for i in 0..sx_chunks.len() {
152            let sx = sx_chunks[i];
153            let sy = sy_chunks[i];
154            let sz = sz_chunks[i];
155
156            // Initialize with all false (0.0).
157            let mut chunk_mask = simd.less_than_f32s(simd.splat_f32s(1.0), simd.splat_f32s(0.0));
158
159            for neighbor in self.neighbors {
160                if self.atoms[neighbor.idx as usize].id == atom.id {
161                    continue;
162                }
163
164                let neighbor_pos = self.atoms[neighbor.idx as usize].position;
165                let vx_scalar = center_pos.x - neighbor_pos.x;
166                let vy_scalar = center_pos.y - neighbor_pos.y;
167                let vz_scalar = center_pos.z - neighbor_pos.z;
168                let v_mag_sq =
169                    vx_scalar * vx_scalar + vy_scalar * vy_scalar + vz_scalar * vz_scalar;
170
171                let t = neighbor.threshold_squared;
172                let limit_scalar = (t - v_mag_sq - r2) / (2.0 * r);
173
174                let vx = simd.splat_f32s(vx_scalar);
175                let vy = simd.splat_f32s(vy_scalar);
176                let vz = simd.splat_f32s(vz_scalar);
177                let limit = simd.splat_f32s(limit_scalar);
178
179                let dot =
180                    simd.mul_add_f32s(sx, vx, simd.mul_add_f32s(sy, vy, simd.mul_f32s(sz, vz)));
181
182                let occ = simd.less_than_f32s(dot, limit);
183                chunk_mask = simd.or_m32s(chunk_mask, occ);
184
185                let not_occ = simd.xor_m32s(chunk_mask, true_mask);
186                if simd.first_true_m32s(not_occ) == S::F32_LANES {
187                    break;
188                }
189            }
190
191            // accessible points are !chunk_mask
192            let not_occ = simd.xor_m32s(chunk_mask, true_mask);
193            let contribution =
194                simd.select_f32s(not_occ, simd.splat_f32s(1.0), simd.splat_f32s(0.0));
195            accessible_points += simd.reduce_sum_f32s(contribution);
196        }
197
198        // Process remainder
199        for i in 0..sx_rem.len() {
200            let sx = sx_rem[i];
201            let sy = sy_rem[i];
202            let sz = sz_rem[i];
203            let mut occluded = false;
204
205            for neighbor in self.neighbors {
206                if self.atoms[neighbor.idx as usize].id == atom.id {
207                    continue;
208                }
209                let n_pos = self.atoms[neighbor.idx as usize].position;
210                let vx = center_pos.x - n_pos.x;
211                let vy = center_pos.y - n_pos.y;
212                let vz = center_pos.z - n_pos.z;
213                let v_mag_sq = vx * vx + vy * vy + vz * vz;
214
215                let t = neighbor.threshold_squared;
216                let limit = (t - v_mag_sq - r2) / (2.0 * r);
217
218                let dot = sx * vx + sy * vy + sz * vz;
219                if dot < limit {
220                    occluded = true;
221                    break;
222                }
223            }
224
225            if !occluded {
226                accessible_points += 1.0;
227            }
228        }
229
230        let surface_area = 4.0 * std::f32::consts::PI * r2;
231        let inv_n_points = 1.0 / (self.sphere_points.len() as f32);
232        surface_area * accessible_points * inv_n_points
233    }
234}
235
236/// Takes the probe radius and number of points to use along with a list of Atoms as inputs and returns a Vec with SASA values for each atom.
237/// For most users it is recommend that you use `calculate_sasa` instead. This method can be used directly if you do not want to use pdbtbx to load PDB/mmCIF files or want to load them from a different source.
238/// Probe Radius Default: 1.4
239/// Point Count Default: 100
240/// Include Hydrogens Default: false
241/// ## Example using pdbtbx:
242/// ```
243/// use nalgebra::{Point3, Vector3};
244/// use pdbtbx::StrictnessLevel;
245/// use rust_sasa::{Atom, calculate_sasa_internal};
246/// let (mut pdb, _errors) = pdbtbx::open(
247///             "./pdbs/example.cif",
248///         ).unwrap();
249/// let mut atoms = vec![];
250/// for atom in pdb.atoms() {
251///     atoms.push(Atom {
252///                 position: Point3::new(atom.pos().0 as f32, atom.pos().1 as f32, atom.pos().2 as f32),
253///                 radius: atom.element().unwrap().atomic_radius().van_der_waals.unwrap() as f32,
254///                 id: atom.serial_number(),
255///                 parent_id: None,
256///                 is_hydrogen: atom.element() == Some(&pdbtbx::Element::H)
257///     })
258///  }
259///  let sasa = calculate_sasa_internal(&atoms, 1.4, 100, true, false);
260/// ```
261pub fn calculate_sasa_internal(
262    atoms: &[Atom],
263    probe_radius: f32,
264    n_points: usize,
265    parallel: bool,
266    include_hydrogens: bool,
267) -> Vec<f32> {
268    let active_indices: Vec<usize> = if include_hydrogens {
269        (0..atoms.len()).collect()
270    } else {
271        atoms
272            .iter()
273            .enumerate()
274            .filter_map(|(i, atom)| (!atom.is_hydrogen).then_some(i))
275            .collect()
276    };
277
278    let sphere_points = generate_sphere_points(n_points);
279
280    let max_radii = active_indices
281        .iter()
282        .map(|&i| atoms[i].radius)
283        .fold(0.0f32, f32::max);
284
285    let neighbor_lists = precompute_neighbors(atoms, &active_indices, probe_radius, max_radii);
286
287    let process_atom = |(list_idx, neighbors): (usize, &Vec<NeighborData>)| {
288        let orig_idx = active_indices[list_idx];
289        ARCH.dispatch(AtomSasaKernel {
290            atom_index: orig_idx,
291            atoms,
292            neighbors,
293            sphere_points: &sphere_points,
294            probe_radius,
295        })
296    };
297
298    let active_results: Vec<f32> = if parallel {
299        neighbor_lists
300            .par_iter()
301            .enumerate()
302            .map(process_atom)
303            .collect()
304    } else {
305        neighbor_lists
306            .iter()
307            .enumerate()
308            .map(process_atom)
309            .collect()
310    };
311
312    // Map results back to original indices (hydrogens get 0.0)
313    let mut results = vec![0.0; atoms.len()];
314    for (list_idx, &orig_idx) in active_indices.iter().enumerate() {
315        results[orig_idx] = active_results[list_idx];
316    }
317    results
318}