rust_sasa/
lib.rs

1//! RustSASA is a Rust library for computing the absolute solvent accessible surface area (ASA/SASA) of each atom in a given protein structure using the Shrake-Rupley algorithm.
2//! Example:
3//! ```rust
4//! use pdbtbx::StrictnessLevel;
5//! use rust_sasa::{SASAOptions, ResidueLevel};
6//!
7//! let (mut pdb, _errors) = pdbtbx::open("./tests/data/pdbs/example.cif").unwrap();
8//! let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
9//! ```
10// Copyright (c) 2024 Maxwell Campbell. Licensed under the MIT License.
11pub mod options;
12// Re-export the new level types and processor trait
13pub use options::{AtomLevel, ChainLevel, ProteinLevel, ResidueLevel, SASAProcessor};
14use utils::consts::ANGLE_INCREMENT;
15pub mod structures;
16pub mod utils;
17
18pub use crate::options::*;
19pub use crate::structures::atomic::*;
20use crate::utils::ARCH;
21use structures::spatial_grid::SpatialGrid;
22// Re-export io functions for use in the binary crate
23use rayon::prelude::*;
24#[cfg(feature = "serde_json")]
25pub use utils::io::sasa_result_to_json;
26pub use utils::io::sasa_result_to_protein_object;
27#[cfg(feature = "quick-xml")]
28pub use utils::io::sasa_result_to_xml;
29
/// Sample points on a sphere stored as a structure of arrays: one `Vec<f32>`
/// per coordinate axis, so each axis can be handed out as a contiguous `f32`
/// slice (the SIMD kernel in this file reads them through `as_simd_f32s`).
///
/// Entry `i` of `x`, `y` and `z` together describe point `i`;
/// `generate_sphere_points` pushes exactly one value onto each vector per point.
struct SpherePointsSoA {
    /// X coordinate of every sample point.
    x: Vec<f32>,
    /// Y coordinate of every sample point.
    y: Vec<f32>,
    /// Z coordinate of every sample point.
    z: Vec<f32>,
}

impl SpherePointsSoA {
    /// Number of sample points, taken from the length of the `x` vector.
    fn len(&self) -> usize {
        self.x.len()
    }
}
41
42/// Generates points on a sphere using the Golden Section Spiral algorithm
43fn generate_sphere_points(n_points: usize) -> SpherePointsSoA {
44    let mut x = Vec::with_capacity(n_points);
45    let mut y = Vec::with_capacity(n_points);
46    let mut z = Vec::with_capacity(n_points);
47
48    let inv_n_points = 1.0 / n_points as f32;
49
50    for i in 0..n_points {
51        let i_f32 = i as f32;
52        let t = i_f32 * inv_n_points;
53        let inclination = (1.0 - 2.0 * t).acos();
54        let azimuth = ANGLE_INCREMENT * i_f32;
55
56        // Use sin_cos for better performance
57        let (sin_azimuth, cos_azimuth) = azimuth.sin_cos();
58        let sin_inclination = inclination.sin();
59
60        x.push(sin_inclination * cos_azimuth);
61        y.push(sin_inclination * sin_azimuth);
62        z.push(inclination.cos());
63    }
64
65    SpherePointsSoA { x, y, z }
66}
67
68#[inline(never)]
69pub fn precompute_neighbors(
70    atoms: &[Atom],
71    active_indices: &[usize],
72    probe_radius: f32,
73    max_radii: f32,
74) -> Vec<Vec<NeighborData>> {
75    // Same cell_size as original
76    let cell_size = probe_radius + max_radii;
77
78    // Maximum search radius from original: atom.radius + max_radii + 2.0 * probe_radius
79    // Worst case is when atom.radius == max_radii
80    let max_search_radius = max_radii + max_radii + 2.0 * probe_radius;
81
82    let grid = SpatialGrid::new(atoms, active_indices, cell_size, max_search_radius);
83    grid.build_all_neighbor_lists(atoms, active_indices, probe_radius, max_radii)
84}
85
/// Work item for a single atom: its `pulp::WithSimd` implementation estimates
/// the solvent accessible surface area of `atoms[atom_index]`.
/// `calculate_sasa_internal` builds one of these per atom and runs it through
/// `ARCH.dispatch`.
struct AtomSasaKernel<'a> {
    /// Index into `atoms` of the atom whose surface is being sampled.
    atom_index: usize,
    /// All atoms; neighbor entries refer into this slice through their `idx`.
    atoms: &'a [Atom],
    /// Neighbor list of the atom at `atom_index`.
    neighbors: &'a [NeighborData],
    /// Sample directions shared by every atom.
    sphere_points: &'a SpherePointsSoA,
    /// Probe radius; added to the atom radius to get the sampled sphere radius.
    probe_radius: f32,
}
93
94impl<'a> pulp::WithSimd for AtomSasaKernel<'a> {
95    type Output = f32;
96
97    #[inline(always)]
98    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
99        let atom = &self.atoms[self.atom_index];
100        let center_pos = atom.position;
101        let r = atom.radius + self.probe_radius;
102        let r2 = r * r;
103
104        let (sx_chunks, sx_rem) = S::as_simd_f32s(&self.sphere_points.x);
105        let (sy_chunks, sy_rem) = S::as_simd_f32s(&self.sphere_points.y);
106        let (sz_chunks, sz_rem) = S::as_simd_f32s(&self.sphere_points.z);
107
108        let mut accessible_points = 0.0;
109
110        // Precompute true mask for NOT operation
111        let zero = simd.splat_f32s(0.0);
112        let true_mask = simd.equal_f32s(zero, zero);
113
114        // Process chunks
115        for i in 0..sx_chunks.len() {
116            let sx = sx_chunks[i];
117            let sy = sy_chunks[i];
118            let sz = sz_chunks[i];
119
120            // Initialize with all false (0.0).
121            let mut chunk_mask = simd.less_than_f32s(simd.splat_f32s(1.0), simd.splat_f32s(0.0));
122
123            for neighbor in self.neighbors {
124                if self.atoms[neighbor.idx as usize].id == atom.id {
125                    continue;
126                }
127
128                let neighbor_pos = self.atoms[neighbor.idx as usize].position;
129                let vx_scalar = center_pos[0] - neighbor_pos[0];
130                let vy_scalar = center_pos[1] - neighbor_pos[1];
131                let vz_scalar = center_pos[2] - neighbor_pos[2];
132                let v_mag_sq =
133                    vx_scalar * vx_scalar + vy_scalar * vy_scalar + vz_scalar * vz_scalar;
134
135                let t = neighbor.threshold_squared;
136                let limit_scalar = (t - v_mag_sq - r2) / (2.0 * r);
137
138                let vx = simd.splat_f32s(vx_scalar);
139                let vy = simd.splat_f32s(vy_scalar);
140                let vz = simd.splat_f32s(vz_scalar);
141                let limit = simd.splat_f32s(limit_scalar);
142
143                let dot =
144                    simd.mul_add_f32s(sx, vx, simd.mul_add_f32s(sy, vy, simd.mul_f32s(sz, vz)));
145
146                let occ = simd.less_than_f32s(dot, limit);
147                chunk_mask = simd.or_m32s(chunk_mask, occ);
148
149                let not_occ = simd.xor_m32s(chunk_mask, true_mask);
150                if simd.first_true_m32s(not_occ) == S::F32_LANES {
151                    break;
152                }
153            }
154
155            // accessible points are !chunk_mask
156            let not_occ = simd.xor_m32s(chunk_mask, true_mask);
157            let contribution =
158                simd.select_f32s(not_occ, simd.splat_f32s(1.0), simd.splat_f32s(0.0));
159            accessible_points += simd.reduce_sum_f32s(contribution);
160        }
161
162        // Process remainder
163        let mut current_nb = 0;
164        for i in 0..sx_rem.len() {
165            let sx = sx_rem[i];
166            let sy = sy_rem[i];
167            let sz = sz_rem[i];
168            let mut occluded = false;
169
170            // First check the neighbor that occluded the previous point
171            if current_nb < self.neighbors.len() {
172                let neighbor = &self.neighbors[current_nb];
173                if self.atoms[neighbor.idx as usize].id != atom.id {
174                    if self.atoms[neighbor.idx as usize].id == atom.id {
175                        continue;
176                    }
177                    let n_pos = self.atoms[neighbor.idx as usize].position;
178                    let vx = center_pos[0] - n_pos[0];
179                    let vy = center_pos[1] - n_pos[1];
180                    let vz = center_pos[2] - n_pos[2];
181                    let v_mag_sq = vx * vx + vy * vy + vz * vz;
182
183                    let t = neighbor.threshold_squared;
184                    let limit = (t - v_mag_sq - r2) / (2.0 * r);
185                    let dot = sx * vx + sy * vy + sz * vz;
186                    if dot <= limit {
187                        occluded = true;
188                    }
189                }
190            }
191
192            // Only search all neighbors if the cached one didn't occlude
193            if !occluded {
194                for (idx, neighbor) in self.neighbors.iter().enumerate() {
195                    if self.atoms[neighbor.idx as usize].id == atom.id {
196                        continue;
197                    }
198                    let n_pos = self.atoms[neighbor.idx as usize].position;
199                    let vx = center_pos[0] - n_pos[0];
200                    let vy = center_pos[1] - n_pos[1];
201                    let vz = center_pos[2] - n_pos[2];
202                    let v_mag_sq = vx * vx + vy * vy + vz * vz;
203
204                    let t = neighbor.threshold_squared;
205                    let limit = (t - v_mag_sq - r2) / (2.0 * r);
206                    let dot = sx * vx + sy * vy + sz * vz;
207                    if dot <= limit {
208                        occluded = true;
209                        current_nb = idx; // Cache for next point
210                        break;
211                    }
212                }
213            }
214
215            if !occluded {
216                accessible_points += 1.0;
217            }
218        }
219
220        let surface_area = 4.0 * std::f32::consts::PI * r2;
221        let inv_n_points = 1.0 / (self.sphere_points.len() as f32);
222        surface_area * accessible_points * inv_n_points
223    }
224}
225
226/// Takes the probe radius and number of points to use along with a list of Atoms as inputs and returns a Vec with SASA values for each atom.
227/// For most users it is recommend that you use the SASAOptions interface instead. This method can be used directly if you do not want to use pdbtbx to load PDB/mmCIF files or want to load them from a different source.
228/// Probe Radius Default: 1.4
229/// Point Count Default: 100
230/// Threads Default: -1 (use all cores)
231/// ## Example using pdbtbx:
232/// ```
233/// use pdbtbx::StrictnessLevel;
234/// use rust_sasa::{Atom, calculate_sasa_internal};
235/// let (mut pdb, _errors) = pdbtbx::open(
236///             "./tests/data/pdbs/example.cif",
237///         ).unwrap();
238/// let mut atoms = vec![];
239/// for atom in pdb.atoms() {
240///     atoms.push(Atom {
241///                 position: [atom.pos().0 as f32, atom.pos().1 as f32, atom.pos().2 as f32],
242///                 radius: atom.element().unwrap().atomic_radius().van_der_waals.unwrap() as f32,
243///                 id: atom.serial_number(),
244///                 parent_id: None,
245///     })
246///  }
247///  let sasa = calculate_sasa_internal(&atoms, 1.4, 100, -1);
248/// ```
249pub fn calculate_sasa_internal(
250    atoms: &[Atom],
251    probe_radius: f32,
252    n_points: usize,
253    threads: isize,
254) -> Vec<f32> {
255    let active_indices: Vec<usize> = (0..atoms.len()).collect();
256
257    let sphere_points = generate_sphere_points(n_points);
258
259    let max_radii = active_indices
260        .iter()
261        .map(|&i| atoms[i].radius)
262        .fold(0.0f32, f32::max);
263
264    let neighbor_lists = precompute_neighbors(atoms, &active_indices, probe_radius, max_radii);
265
266    let process_atom = |(list_idx, neighbors): (usize, &Vec<NeighborData>)| {
267        let orig_idx = active_indices[list_idx];
268        ARCH.dispatch(AtomSasaKernel {
269            atom_index: orig_idx,
270            atoms,
271            neighbors,
272            sphere_points: &sphere_points,
273            probe_radius,
274        })
275    };
276
277    // Use sequential iteration when threads == 1 to avoid thread pool overhead
278    let active_results: Vec<f32> = if threads == 1 {
279        neighbor_lists
280            .iter()
281            .enumerate()
282            .map(process_atom)
283            .collect()
284    } else {
285        neighbor_lists
286            .par_iter()
287            .enumerate()
288            .map(process_atom)
289            .collect()
290    };
291
292    // Map results back to original indices (hydrogens get 0.0)
293    let mut results = vec![0.0; atoms.len()];
294    for (list_idx, &orig_idx) in active_indices.iter().enumerate() {
295        results[orig_idx] = active_results[list_idx];
296    }
297    results
298}