rust_sasa/
options.rs

1use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
2use crate::utils::consts::{POLAR_AMINO_ACIDS, load_radii_from_file};
3use crate::utils::{get_radius, serialize_chain_id, simd_sum};
4use crate::{Atom, calculate_sasa_internal};
5use fnv::FnvHashMap;
6use nalgebra::Point3;
7use pdbtbx::PDB;
8use snafu::OptionExt;
9use snafu::prelude::*;
10use std::marker::PhantomData;
11
12/// Options for configuring SASA (Solvent Accessible Surface Area) calculations.
13///
14/// This struct provides configuration options for SASA calculations at different levels
15/// of granularity (atom, residue, chain, or protein level). The type parameter `T`
16/// determines the output type and processing behavior.
17///
18/// # Type Parameters
19///
20/// * `T` - The processing level, which must implement [`SASAProcessor`]. Available levels:
21///   - [`AtomLevel`] - Returns SASA values for individual atoms
22///   - [`ResidueLevel`] - Returns SASA values aggregated by residue
23///   - [`ChainLevel`] - Returns SASA values aggregated by chain
24///   - [`ProteinLevel`] - Returns SASA values aggregated for the entire protein
25///
26/// # Fields
27///
28/// * `probe_radius` - Radius of the solvent probe sphere in Angstroms (default: 1.4)
29/// * `n_points` - Number of points on the sphere surface for sampling (default: 100)
30/// * `threads` - Number of threads to use for parallel processing (default: -1 for all cores)
31/// * `include_hydrogens` - Whether to include hydrogen atoms in calculations (default: false)
32/// * `radii_config` - Optional custom radii configuration (default: uses embedded protor.config)
33///
34/// # Examples
35///
36/// ```rust
37/// use rust_sasa::options::{SASAOptions, ResidueLevel};
38/// use pdbtbx::PDB;
39///
40/// // Create options with default settings
41/// let options = SASAOptions::<ResidueLevel>::new();
42///
43/// // Customize the configuration
44/// let custom_options = SASAOptions::<ResidueLevel>::new()
45///     .with_probe_radius(1.2)
46///     .with_n_points(200)
47///     .with_threads(-1)
48///     .with_include_hydrogens(false);
49///
50/// // Process a PDB structure
51/// # let pdb = PDB::new();
52/// let result = custom_options.process(&pdb)?;
53/// # Ok::<(), Box<dyn std::error::Error>>(())
54/// ```
55#[derive(Debug, Clone)]
56pub struct SASAOptions<T> {
57    probe_radius: f32,
58    n_points: usize,
59    threads: isize,
60    include_hydrogens: bool,
61    radii_config: Option<FnvHashMap<String, FnvHashMap<String, f32>>>,
62    allow_vdw_fallback: bool,
63    include_hetatms: bool,
64    _marker: PhantomData<T>,
65}
66
// Zero-sized marker types for each level.
//
// They derive the common traits so that `#[derive(Debug, Clone)]` on `SASAOptions<T>`
// (which bounds `T: Debug` / `T: Clone` through its `PhantomData<T>` field) actually
// applies to `SASAOptions<AtomLevel>` and friends.

/// Level marker: SASA reported per atom.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AtomLevel;

/// Level marker: SASA aggregated per residue.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ResidueLevel;

/// Level marker: SASA aggregated per chain.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ChainLevel;

/// Level marker: SASA aggregated over the whole protein.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ProteinLevel;
72
73pub type AtomsMappingResult = Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError>;
74
/// Appends one [`Atom`], built from a pdbtbx atom, to `$atoms`.
///
/// The radius comes from `get_radius($residue_name, $atom_name, $radii_config)`. If that
/// lookup yields nothing, the element's van der Waals radius is used when
/// `$allow_vdw_fallback` is true; otherwise the *enclosing function* returns
/// `SASACalcError::RadiusMissing`. Because the expansion uses `?` and `return`, it may only
/// appear inside a function returning `Result<_, SASACalcError>`.
macro_rules! build_atom {
    ($atoms:expr, $atom:expr, $element:expr, $residue_name:expr, $atom_name:expr, $parent_id:expr, $radii_config:expr, $allow_vdw_fallback:expr) => {{
        let radius = if let Some(configured) = get_radius($residue_name, $atom_name, $radii_config) {
            configured
        } else if $allow_vdw_fallback {
            // No configured radius: use the element's van der Waals radius, failing if
            // the element does not provide one.
            $element
                .atomic_radius()
                .van_der_waals
                .context(VanDerWaalsMissingSnafu)? as f32
        } else {
            return Err(SASACalcError::RadiusMissing {
                residue_name: $residue_name.to_string(),
                atom_name: $atom_name.to_string(),
                element: $element.to_string(),
            });
        };

        let (x, y, z) = $atom.pos();
        $atoms.push(Atom {
            position: Point3::new(x as f32, y as f32, z as f32),
            radius,
            id: $atom.serial_number(),
            parent_id: $parent_id,
        });
    }};
}
108
109// Trait that defines the processing behavior for each level
110pub trait SASAProcessor {
111    type Output;
112
113    fn process_atoms(
114        atoms: &[Atom],
115        atom_sasa: &[f32],
116        pdb: &PDB,
117        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
118    ) -> Result<Self::Output, SASACalcError>;
119
120    fn build_atoms_and_mapping(
121        pdb: &PDB,
122        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
123        allow_vdw_fallback: bool,
124        include_hydrogens: bool,
125        include_hetatms: bool,
126    ) -> AtomsMappingResult;
127}
128
129impl SASAProcessor for AtomLevel {
130    type Output = Vec<f32>;
131
132    fn process_atoms(
133        _atoms: &[Atom],
134        atom_sasa: &[f32],
135        _pdb: &PDB,
136        _parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
137    ) -> Result<Self::Output, SASACalcError> {
138        Ok(atom_sasa.to_vec())
139    }
140
141    fn build_atoms_and_mapping(
142        pdb: &PDB,
143        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
144        allow_vdw_fallback: bool,
145        include_hydrogens: bool,
146        include_hetatms: bool,
147    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
148        let mut atoms = vec![];
149        for residue in pdb.residues() {
150            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
151            for atom in residue.atoms() {
152                let element = atom.element().context(ElementMissingSnafu)?;
153                let atom_name = atom.name();
154                if element == &pdbtbx::Element::H && !include_hydrogens {
155                    continue;
156                };
157                if atom.hetero() && !include_hetatms {
158                    continue;
159                }
160                build_atom!(
161                    atoms,
162                    atom,
163                    element,
164                    residue_name,
165                    atom_name,
166                    None,
167                    radii_config,
168                    allow_vdw_fallback
169                );
170            }
171        }
172        Ok((atoms, FnvHashMap::default()))
173    }
174}
175
176impl SASAProcessor for ResidueLevel {
177    type Output = Vec<ResidueResult>;
178
179    fn process_atoms(
180        _atoms: &[Atom],
181        atom_sasa: &[f32],
182        pdb: &PDB,
183        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
184    ) -> Result<Self::Output, SASACalcError> {
185        let mut residue_sasa = vec![];
186        for chain in pdb.chains() {
187            for residue in chain.residues() {
188                let residue_atom_index = parent_to_atoms
189                    .get(&residue.serial_number())
190                    .context(AtomMapToLevelElementFailedSnafu)?;
191                let residue_atoms: Vec<_> = residue_atom_index
192                    .iter()
193                    .map(|&index| atom_sasa[index])
194                    .collect();
195                let sum = simd_sum(residue_atoms.as_slice());
196                let name = residue
197                    .name()
198                    .context(FailedToGetResidueNameSnafu)?
199                    .to_string();
200                residue_sasa.push(ResidueResult {
201                    serial_number: residue.serial_number(),
202                    value: sum,
203                    is_polar: POLAR_AMINO_ACIDS.contains(&name),
204                    chain_id: chain.id().to_string(),
205                    name,
206                })
207            }
208        }
209        Ok(residue_sasa)
210    }
211
212    fn build_atoms_and_mapping(
213        pdb: &PDB,
214        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
215        allow_vdw_fallback: bool,
216        include_hydrogens: bool,
217        include_hetatms: bool,
218    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
219        let mut atoms = vec![];
220        let mut parent_to_atoms = FnvHashMap::default();
221        let mut i = 0;
222        for residue in pdb.residues() {
223            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
224            let mut temp = vec![];
225            for atom in residue.atoms() {
226                let element = atom.element().context(ElementMissingSnafu)?;
227                let atom_name = atom.name();
228                if element == &pdbtbx::Element::H && !include_hydrogens {
229                    continue;
230                };
231                if atom.hetero() && !include_hetatms {
232                    continue;
233                }
234                build_atom!(
235                    atoms,
236                    atom,
237                    element,
238                    residue_name,
239                    atom_name,
240                    Some(residue.serial_number()),
241                    radii_config,
242                    allow_vdw_fallback
243                );
244                temp.push(i);
245                i += 1;
246            }
247            parent_to_atoms.insert(residue.serial_number(), temp);
248        }
249        Ok((atoms, parent_to_atoms))
250    }
251}
252
253impl SASAProcessor for ChainLevel {
254    type Output = Vec<ChainResult>;
255
256    fn process_atoms(
257        _atoms: &[Atom],
258        atom_sasa: &[f32],
259        pdb: &PDB,
260        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
261    ) -> Result<Self::Output, SASACalcError> {
262        let mut chain_sasa = vec![];
263        for chain in pdb.chains() {
264            let chain_id = serialize_chain_id(chain.id());
265            let chain_atom_index = parent_to_atoms
266                .get(&chain_id)
267                .context(AtomMapToLevelElementFailedSnafu)?;
268            let chain_atoms: Vec<_> = chain_atom_index
269                .iter()
270                .map(|&index| atom_sasa[index])
271                .collect();
272            let sum = simd_sum(chain_atoms.as_slice());
273            chain_sasa.push(ChainResult {
274                name: chain.id().to_string(),
275                value: sum,
276            })
277        }
278        Ok(chain_sasa)
279    }
280
281    fn build_atoms_and_mapping(
282        pdb: &PDB,
283        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
284        allow_vdw_fallback: bool,
285        include_hydrogens: bool,
286        include_hetatms: bool,
287    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
288        let mut atoms = vec![];
289        let mut parent_to_atoms = FnvHashMap::default();
290        let mut i = 0;
291        for chain in pdb.chains() {
292            let mut temp = vec![];
293            let chain_id = serialize_chain_id(chain.id());
294            for residue in chain.residues() {
295                let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
296                for atom in residue.atoms() {
297                    let element = atom.element().context(ElementMissingSnafu)?;
298                    let atom_name = atom.name();
299                    if element == &pdbtbx::Element::H && !include_hydrogens {
300                        continue;
301                    };
302                    if atom.hetero() && !include_hetatms {
303                        continue;
304                    }
305                    build_atom!(
306                        atoms,
307                        atom,
308                        element,
309                        residue_name,
310                        atom_name,
311                        Some(chain_id),
312                        radii_config,
313                        allow_vdw_fallback
314                    );
315                    temp.push(i);
316                    i += 1
317                }
318            }
319            parent_to_atoms.insert(chain_id, temp);
320        }
321        Ok((atoms, parent_to_atoms))
322    }
323}
324
325impl SASAProcessor for ProteinLevel {
326    type Output = ProteinResult;
327
328    fn process_atoms(
329        _atoms: &[Atom],
330        atom_sasa: &[f32],
331        pdb: &PDB,
332        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
333    ) -> Result<Self::Output, SASACalcError> {
334        let mut polar_total: f32 = 0.0;
335        let mut non_polar_total: f32 = 0.0;
336        for residue in pdb.residues() {
337            let residue_atom_index = parent_to_atoms
338                .get(&residue.serial_number())
339                .context(AtomMapToLevelElementFailedSnafu)?;
340            let residue_atoms: Vec<_> = residue_atom_index
341                .iter()
342                .map(|&index| atom_sasa[index])
343                .collect();
344            let sum = simd_sum(residue_atoms.as_slice());
345            let name = residue
346                .name()
347                .context(FailedToGetResidueNameSnafu)?
348                .to_string();
349            if POLAR_AMINO_ACIDS.contains(&name) {
350                polar_total += sum
351            } else {
352                non_polar_total += sum
353            }
354        }
355        let global_sum = simd_sum(atom_sasa);
356        Ok(ProteinResult {
357            global_total: global_sum,
358            polar_total,
359            non_polar_total,
360        })
361    }
362
363    fn build_atoms_and_mapping(
364        pdb: &PDB,
365        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
366        allow_vdw_fallback: bool,
367        include_hydrogens: bool,
368        include_hetatms: bool,
369    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
370        let mut atoms = vec![];
371        let mut parent_to_atoms = FnvHashMap::default();
372        let mut i = 0;
373        for residue in pdb.residues() {
374            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
375            let mut temp = vec![];
376            for atom in residue.atoms() {
377                let element = atom.element().context(ElementMissingSnafu)?;
378                let atom_name = atom.name();
379                if element == &pdbtbx::Element::H && !include_hydrogens {
380                    continue;
381                };
382                if atom.hetero() && !include_hetatms {
383                    continue;
384                }
385                build_atom!(
386                    atoms,
387                    atom,
388                    element,
389                    residue_name,
390                    atom_name,
391                    Some(residue.serial_number()),
392                    radii_config,
393                    allow_vdw_fallback
394                );
395                temp.push(i);
396                i += 1;
397            }
398            parent_to_atoms.insert(residue.serial_number(), temp);
399        }
400        Ok((atoms, parent_to_atoms))
401    }
402}
403
404#[derive(Debug, Snafu)]
405pub enum SASACalcError {
406    #[snafu(display("Element missing for atom"))]
407    ElementMissing,
408
409    #[snafu(display("Van der Waals radius missing for element"))]
410    VanDerWaalsMissing,
411
412    #[snafu(display(
413        "Radius not found for residue '{}' atom '{}' of type '{}'. This error can can be ignored, if you are using the CLI pass --allow-vdw-fallback or use with_allow_vdw_fallback if you are using the API.",
414        residue_name,
415        atom_name,
416        element
417    ))]
418    RadiusMissing {
419        residue_name: String,
420        atom_name: String,
421        element: String,
422    },
423
424    #[snafu(display("Failed to map atoms back to level element"))]
425    AtomMapToLevelElementFailed,
426
427    #[snafu(display("Failed to get residue name"))]
428    FailedToGetResidueName,
429
430    #[snafu(display("Failed to load radii file: {source}"))]
431    RadiiFileLoad { source: std::io::Error },
432}
433
434impl<T> SASAOptions<T> {
435    /// Create a new SASAOptions with the specified level type
436    pub fn new() -> SASAOptions<T> {
437        SASAOptions {
438            probe_radius: 1.4,
439            n_points: 100,
440            threads: -1,
441            include_hydrogens: false,
442            radii_config: None,
443            allow_vdw_fallback: false,
444            include_hetatms: false,
445            _marker: PhantomData,
446        }
447    }
448
449    /// Set the probe radius (default: 1.4 Angstroms)
450    pub fn with_probe_radius(mut self, radius: f32) -> Self {
451        self.probe_radius = radius;
452        self
453    }
454
455    /// Include or exclude HETATM records in protein.
456    pub fn with_include_hetatms(mut self, include_hetatms: bool) -> Self {
457        self.include_hetatms = include_hetatms;
458        self
459    }
460
461    /// Set the number of points on the sphere for sampling (default: 100)
462    pub fn with_n_points(mut self, points: usize) -> Self {
463        self.n_points = points;
464        self
465    }
466
467    /// Configure the number of threads to use for parallel processing
468    ///   - `-1`: Use all available CPU cores (default)
469    ///   - `1`: Single-threaded execution (disables parallelism)
470    ///   - `> 1`: Use specified number of threads
471    pub fn with_threads(mut self, threads: isize) -> Self {
472        self.threads = threads;
473        self
474    }
475
476    /// Include or exclude hydrogen atoms in calculations (default: false)
477    pub fn with_include_hydrogens(mut self, include_hydrogens: bool) -> Self {
478        self.include_hydrogens = include_hydrogens;
479        self
480    }
481
482    /// Load custom radii configuration from a file
483    pub fn with_radii_file(mut self, path: &str) -> Result<Self, std::io::Error> {
484        self.radii_config = Some(load_radii_from_file(path)?);
485        Ok(self)
486    }
487
488    /// Allow fallback to van der Waals radii when custom radius is not found (default: false)
489    pub fn with_allow_vdw_fallback(mut self, allow: bool) -> Self {
490        self.allow_vdw_fallback = allow;
491        self
492    }
493}
494
495// Convenience constructors for each level
496impl SASAOptions<AtomLevel> {
497    pub fn atom_level() -> Self {
498        Self::new()
499    }
500}
501
502impl SASAOptions<ResidueLevel> {
503    pub fn residue_level() -> Self {
504        Self::new()
505    }
506}
507
508impl SASAOptions<ChainLevel> {
509    pub fn chain_level() -> Self {
510        Self::new()
511    }
512}
513
514impl SASAOptions<ProteinLevel> {
515    pub fn protein_level() -> Self {
516        Self::new()
517    }
518}
519
520impl<T> Default for SASAOptions<T> {
521    fn default() -> Self {
522        Self::new()
523    }
524}
525
526impl<T: SASAProcessor> SASAOptions<T> {
527    /// This function calculates the SASA for a given protein. The output type is determined by the level type parameter.
528    /// Probe radius and n_points can be customized, defaulting to 1.4 and 100 respectively.
529    /// If you want more fine-grained control you may want to use [calculate_sasa_internal] instead.
530    /// ## Example
531    /// ```
532    /// use pdbtbx::StrictnessLevel;
533    /// use rust_sasa::options::{SASAOptions, ResidueLevel};
534    /// let (mut pdb, _errors) = pdbtbx::open("./pdbs/example.cif").unwrap();
535    /// let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
536    /// ```
537    pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
538        let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(
539            pdb,
540            self.radii_config.as_ref(),
541            self.allow_vdw_fallback,
542            self.include_hydrogens,
543            self.include_hetatms,
544        )?;
545        let atom_sasa =
546            calculate_sasa_internal(&atoms, self.probe_radius, self.n_points, self.threads);
547        T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
548    }
549}