rust_sasa/
options.rs

1// Copyright (c) 2024 Maxwell Campbell. Licensed under the MIT License.
2use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
3use crate::utils::consts::{POLAR_AMINO_ACIDS, load_radii_from_file};
4use crate::utils::{combine_hash, get_radius, serialize_chain_id, simd_sum};
5use crate::{Atom, calculate_sasa_internal};
6use fnv::FnvHashMap;
7use pdbtbx::PDB;
8use snafu::OptionExt;
9use snafu::prelude::*;
10use std::marker::PhantomData;
11
/// Options for configuring SASA (Solvent Accessible Surface Area) calculations.
///
/// This struct provides configuration options for SASA calculations at different levels
/// of granularity (atom, residue, chain, or protein level). The type parameter `T`
/// determines the output type and processing behavior.
///
/// All fields are private; configure them through the `with_*` builder methods, which
/// consume and return the options so calls can be chained.
///
/// # Type Parameters
///
/// * `T` - The processing level. Calling `process` requires `T` to implement
///   [`SASAProcessor`]; the builder methods are available for any `T`. Available levels:
///   - [`AtomLevel`] - Returns SASA values for individual atoms
///   - [`ResidueLevel`] - Returns SASA values aggregated by residue
///   - [`ChainLevel`] - Returns SASA values aggregated by chain
///   - [`ProteinLevel`] - Returns SASA values aggregated for the entire protein
///
/// The derived `Debug` and `Clone` impls carry the usual derive bounds (`T: Debug`,
/// `T: Clone`), so they are only available when the level marker implements them.
///
/// # Fields
///
/// * `probe_radius` - Radius of the solvent probe sphere in Angstroms (default: 1.4)
/// * `n_points` - Number of points on the sphere surface for sampling (default: 100)
/// * `threads` - Number of threads to use for parallel processing (default: -1 for all cores)
/// * `include_hydrogens` - Whether to include hydrogen atoms in calculations (default: false)
/// * `radii_config` - Optional custom radii configuration (default: uses embedded protor.config)
/// * `allow_vdw_fallback` - Allow fallback to PDBTBX van der Waals radii when a radius is not found in the radii file (default: false)
/// * `include_hetatms` - Whether to include HETATM records (e.g. non-standard amino acids) in calculations (default: false)
///
/// # Examples
///
/// ```rust
/// use rust_sasa::options::{SASAOptions, ResidueLevel};
/// use pdbtbx::PDB;
///
/// // Create options with default settings
/// let options = SASAOptions::<ResidueLevel>::new();
///
/// // Customize the configuration
/// let custom_options = SASAOptions::<ResidueLevel>::new()
///     .with_probe_radius(1.2)
///     .with_n_points(200)
///     .with_threads(-1)
///     .with_include_hydrogens(false)
///     .with_allow_vdw_fallback(true)
///     .with_include_hetatms(false);
///
/// // Process a PDB structure
/// # let pdb = PDB::new();
/// let result = custom_options.process(&pdb)?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Debug, Clone)]
pub struct SASAOptions<T> {
    // Solvent probe radius in Angstroms.
    probe_radius: f32,
    // Number of sample points per atom sphere.
    n_points: usize,
    // Thread count forwarded to `calculate_sasa_internal`; -1 means "all cores".
    threads: isize,
    include_hydrogens: bool,
    // `None` means no custom table was loaded (see `with_radii_file`).
    radii_config: Option<FnvHashMap<String, FnvHashMap<String, f32>>>,
    allow_vdw_fallback: bool,
    include_hetatms: bool,
    // Carries the level type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
70
// Zero-sized marker types selecting the aggregation level of `SASAOptions<T>`.
//
// They derive the standard value traits because `#[derive(Debug, Clone)]` on the
// generic `SASAOptions<T>` only applies when `T: Debug` / `T: Clone`. Without these
// derives, `SASAOptions<ResidueLevel>` (etc.) was neither printable nor cloneable.

/// Level marker: one SASA value per included atom (`Vec<f32>`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AtomLevel;

/// Level marker: SASA aggregated per residue.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ResidueLevel;

/// Level marker: SASA aggregated per chain.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ChainLevel;

/// Level marker: SASA aggregated over the whole structure (total, polar, non-polar).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ProteinLevel;
76
/// Return type of [`SASAProcessor::build_atoms_and_mapping`]: the atoms handed to the
/// SASA kernel, paired with a map from a level-specific parent key to the indices of that
/// parent's atoms in the returned atom vector (the map is empty at atom level).
pub type AtomsMappingResult = Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError>;
78
/// Resolves the radius for one atom and appends a matching [`Atom`] to `$atoms`.
///
/// Radius lookup order:
/// 1. `get_radius(residue_name, atom_name, radii_config)`;
/// 2. if that yields nothing and `$allow_vdw_fallback` is true, the element's van der Waals
///    radius from pdbtbx (a missing vdW radius propagates `VanDerWaalsMissing` via `?`);
/// 3. otherwise the *enclosing function* returns `Err(SASACalcError::RadiusMissing { .. })`.
///
/// Because it uses `?` and `return`, this macro must be expanded inside a function whose
/// return type is `Result<_, SASACalcError>`.
macro_rules! build_atom {
    ($atoms:expr, $atom:expr, $element:expr, $residue_name:expr, $atom_name:expr, $parent_id:expr, $radii_config:expr, $allow_vdw_fallback:expr, $id:expr) => {{
        let radius = if let Some(configured) = get_radius($residue_name, $atom_name, $radii_config) {
            configured
        } else if $allow_vdw_fallback {
            $element
                .atomic_radius()
                .van_der_waals
                .context(VanDerWaalsMissingSnafu)? as f32
        } else {
            return Err(SASACalcError::RadiusMissing {
                residue_name: $residue_name.to_string(),
                atom_name: $atom_name.to_string(),
                element: $element.to_string(),
            });
        };

        // Coordinates are stored as f32 for the SASA kernel.
        let (x, y, z) = $atom.pos();
        $atoms.push(Atom {
            position: [x as f32, y as f32, z as f32],
            radius,
            id: $id,
            parent_id: $parent_id,
        });
    }};
}
112
113// Trait that defines the processing behavior for each level
114pub trait SASAProcessor {
115    type Output;
116
117    fn process_atoms(
118        atoms: &[Atom],
119        atom_sasa: &[f32],
120        pdb: &PDB,
121        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
122    ) -> Result<Self::Output, SASACalcError>;
123
124    fn build_atoms_and_mapping(
125        pdb: &PDB,
126        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
127        allow_vdw_fallback: bool,
128        include_hydrogens: bool,
129        include_hetatms: bool,
130    ) -> AtomsMappingResult;
131}
132
133impl SASAProcessor for AtomLevel {
134    type Output = Vec<f32>;
135
136    fn process_atoms(
137        _atoms: &[Atom],
138        atom_sasa: &[f32],
139        _pdb: &PDB,
140        _parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
141    ) -> Result<Self::Output, SASACalcError> {
142        Ok(atom_sasa.to_vec())
143    }
144
145    fn build_atoms_and_mapping(
146        pdb: &PDB,
147        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
148        allow_vdw_fallback: bool,
149        include_hydrogens: bool,
150        include_hetatms: bool,
151    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
152        let mut atoms = vec![];
153        for residue in pdb.residues() {
154            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
155            for atom in residue.atoms() {
156                let element = atom.element().context(ElementMissingSnafu)?;
157                let atom_name = atom.name();
158                if element == &pdbtbx::Element::H && !include_hydrogens {
159                    continue;
160                };
161                if atom.hetero() && !include_hetatms {
162                    continue;
163                }
164                build_atom!(
165                    atoms,
166                    atom,
167                    element,
168                    residue_name,
169                    atom_name,
170                    None,
171                    radii_config,
172                    allow_vdw_fallback,
173                    combine_hash("", atom.serial_number())
174                );
175            }
176        }
177        Ok((atoms, FnvHashMap::default()))
178    }
179}
180
181impl SASAProcessor for ResidueLevel {
182    type Output = Vec<ResidueResult>;
183
184    fn process_atoms(
185        _atoms: &[Atom],
186        atom_sasa: &[f32],
187        pdb: &PDB,
188        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
189    ) -> Result<Self::Output, SASACalcError> {
190        let mut residue_sasa = vec![];
191        for chain in pdb.chains() {
192            for residue in chain.residues() {
193                let residue_atom_index = parent_to_atoms
194                    .get(&residue.serial_number())
195                    .context(AtomMapToLevelElementFailedSnafu)?;
196                let residue_atoms: Vec<_> = residue_atom_index
197                    .iter()
198                    .map(|&index| atom_sasa[index])
199                    .collect();
200                let sum = simd_sum(residue_atoms.as_slice());
201                let name = residue
202                    .name()
203                    .context(FailedToGetResidueNameSnafu)?
204                    .to_string();
205                residue_sasa.push(ResidueResult {
206                    serial_number: residue.serial_number(),
207                    value: sum,
208                    is_polar: POLAR_AMINO_ACIDS.contains(&name),
209                    chain_id: chain.id().to_string(),
210                    name,
211                })
212            }
213        }
214        Ok(residue_sasa)
215    }
216
217    fn build_atoms_and_mapping(
218        pdb: &PDB,
219        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
220        allow_vdw_fallback: bool,
221        include_hydrogens: bool,
222        include_hetatms: bool,
223    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
224        let mut atoms = vec![];
225        let mut parent_to_atoms = FnvHashMap::default();
226        let mut i = 0;
227        for residue in pdb.residues() {
228            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
229            let mut temp = vec![];
230            for conformer in residue.conformers() {
231                for atom in conformer.atoms() {
232                    let element = atom.element().context(ElementMissingSnafu)?;
233                    let atom_name = atom.name();
234                    if element == &pdbtbx::Element::H && !include_hydrogens {
235                        continue;
236                    };
237                    if atom.hetero() && !include_hetatms {
238                        continue;
239                    }
240                    let conformer_alt = conformer.alternative_location().unwrap_or("");
241                    build_atom!(
242                        atoms,
243                        atom,
244                        element,
245                        residue_name,
246                        atom_name,
247                        Some(residue.serial_number()),
248                        radii_config,
249                        allow_vdw_fallback,
250                        combine_hash(conformer_alt, atom.serial_number())
251                    );
252                    temp.push(i);
253                    i += 1;
254                }
255            }
256            parent_to_atoms.insert(residue.serial_number(), temp);
257        }
258        Ok((atoms, parent_to_atoms))
259    }
260}
261
262impl SASAProcessor for ChainLevel {
263    type Output = Vec<ChainResult>;
264
265    fn process_atoms(
266        _atoms: &[Atom],
267        atom_sasa: &[f32],
268        pdb: &PDB,
269        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
270    ) -> Result<Self::Output, SASACalcError> {
271        let mut chain_sasa = vec![];
272        for chain in pdb.chains() {
273            let chain_id = serialize_chain_id(chain.id());
274            let chain_atom_index = parent_to_atoms
275                .get(&chain_id)
276                .context(AtomMapToLevelElementFailedSnafu)?;
277            let chain_atoms: Vec<_> = chain_atom_index
278                .iter()
279                .map(|&index| atom_sasa[index])
280                .collect();
281            let sum = simd_sum(chain_atoms.as_slice());
282            chain_sasa.push(ChainResult {
283                name: chain.id().to_string(),
284                value: sum,
285            })
286        }
287        Ok(chain_sasa)
288    }
289
290    fn build_atoms_and_mapping(
291        pdb: &PDB,
292        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
293        allow_vdw_fallback: bool,
294        include_hydrogens: bool,
295        include_hetatms: bool,
296    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
297        let mut atoms = vec![];
298        let mut parent_to_atoms = FnvHashMap::default();
299        let mut i = 0;
300        for chain in pdb.chains() {
301            let mut temp = vec![];
302            let chain_id = serialize_chain_id(chain.id());
303            for residue in chain.residues() {
304                let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
305                for atom in residue.atoms() {
306                    let element = atom.element().context(ElementMissingSnafu)?;
307                    let atom_name = atom.name();
308                    if element == &pdbtbx::Element::H && !include_hydrogens {
309                        continue;
310                    };
311                    if atom.hetero() && !include_hetatms {
312                        continue;
313                    }
314                    build_atom!(
315                        atoms,
316                        atom,
317                        element,
318                        residue_name,
319                        atom_name,
320                        Some(chain_id),
321                        radii_config,
322                        allow_vdw_fallback,
323                        combine_hash("", atom.serial_number())
324                    );
325                    temp.push(i);
326                    i += 1
327                }
328            }
329            parent_to_atoms.insert(chain_id, temp);
330        }
331        Ok((atoms, parent_to_atoms))
332    }
333}
334
335impl SASAProcessor for ProteinLevel {
336    type Output = ProteinResult;
337
338    fn process_atoms(
339        _atoms: &[Atom],
340        atom_sasa: &[f32],
341        pdb: &PDB,
342        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
343    ) -> Result<Self::Output, SASACalcError> {
344        let mut polar_total: f32 = 0.0;
345        let mut non_polar_total: f32 = 0.0;
346        for residue in pdb.residues() {
347            let residue_atom_index = parent_to_atoms
348                .get(&residue.serial_number())
349                .context(AtomMapToLevelElementFailedSnafu)?;
350            let residue_atoms: Vec<_> = residue_atom_index
351                .iter()
352                .map(|&index| atom_sasa[index])
353                .collect();
354            let sum = simd_sum(residue_atoms.as_slice());
355            let name = residue
356                .name()
357                .context(FailedToGetResidueNameSnafu)?
358                .to_string();
359            if POLAR_AMINO_ACIDS.contains(&name) {
360                polar_total += sum
361            } else {
362                non_polar_total += sum
363            }
364        }
365        let global_sum = simd_sum(atom_sasa);
366        Ok(ProteinResult {
367            global_total: global_sum,
368            polar_total,
369            non_polar_total,
370        })
371    }
372
373    fn build_atoms_and_mapping(
374        pdb: &PDB,
375        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
376        allow_vdw_fallback: bool,
377        include_hydrogens: bool,
378        include_hetatms: bool,
379    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
380        let mut atoms = vec![];
381        let mut parent_to_atoms = FnvHashMap::default();
382        let mut i = 0;
383        for residue in pdb.residues() {
384            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
385            let mut temp = vec![];
386            for atom in residue.atoms() {
387                let element = atom.element().context(ElementMissingSnafu)?;
388                let atom_name = atom.name();
389                if element == &pdbtbx::Element::H && !include_hydrogens {
390                    continue;
391                };
392                if atom.hetero() && !include_hetatms {
393                    continue;
394                }
395                build_atom!(
396                    atoms,
397                    atom,
398                    element,
399                    residue_name,
400                    atom_name,
401                    Some(residue.serial_number()),
402                    radii_config,
403                    allow_vdw_fallback,
404                    combine_hash("", atom.serial_number())
405                );
406                temp.push(i);
407                i += 1;
408            }
409            parent_to_atoms.insert(residue.serial_number(), temp);
410        }
411        Ok((atoms, parent_to_atoms))
412    }
413}
414
415#[derive(Debug, Snafu)]
416pub enum SASACalcError {
417    #[snafu(display("Element missing for atom"))]
418    ElementMissing,
419
420    #[snafu(display("Van der Waals radius missing for element"))]
421    VanDerWaalsMissing,
422
423    #[snafu(display(
424        "Radius not found for residue '{}' atom '{}' of type '{}'. This error can can be ignored, if you are using the CLI pass --allow-vdw-fallback or use with_allow_vdw_fallback if you are using the API.",
425        residue_name,
426        atom_name,
427        element
428    ))]
429    RadiusMissing {
430        residue_name: String,
431        atom_name: String,
432        element: String,
433    },
434
435    #[snafu(display("Failed to map atoms back to level element"))]
436    AtomMapToLevelElementFailed,
437
438    #[snafu(display("Failed to get residue name"))]
439    FailedToGetResidueName,
440
441    #[snafu(display("Failed to load radii file: {source}"))]
442    RadiiFileLoad { source: std::io::Error },
443}
444
445impl<T> SASAOptions<T> {
446    /// Create a new SASAOptions with the specified level type
447    pub fn new() -> SASAOptions<T> {
448        SASAOptions {
449            probe_radius: 1.4,
450            n_points: 100,
451            threads: -1,
452            include_hydrogens: false,
453            radii_config: None,
454            allow_vdw_fallback: false,
455            include_hetatms: false,
456            _marker: PhantomData,
457        }
458    }
459
460    /// Set the probe radius (default: 1.4 Angstroms)
461    pub fn with_probe_radius(mut self, radius: f32) -> Self {
462        self.probe_radius = radius;
463        self
464    }
465
466    /// Include or exclude HETATM records in protein.
467    pub fn with_include_hetatms(mut self, include_hetatms: bool) -> Self {
468        self.include_hetatms = include_hetatms;
469        self
470    }
471
472    /// Set the number of points on the sphere for sampling (default: 100)
473    pub fn with_n_points(mut self, points: usize) -> Self {
474        self.n_points = points;
475        self
476    }
477
478    /// Configure the number of threads to use for parallel processing
479    ///   - `-1`: Use all available CPU cores (default)
480    ///   - `1`: Single-threaded execution (disables parallelism)
481    ///   - `> 1`: Use specified number of threads
482    pub fn with_threads(mut self, threads: isize) -> Self {
483        self.threads = threads;
484        self
485    }
486
487    /// Include or exclude hydrogen atoms in calculations (default: false)
488    pub fn with_include_hydrogens(mut self, include_hydrogens: bool) -> Self {
489        self.include_hydrogens = include_hydrogens;
490        self
491    }
492
493    /// Load custom radii configuration from a file (default: uses embedded protor.config)
494    pub fn with_radii_file(mut self, path: &str) -> Result<Self, std::io::Error> {
495        self.radii_config = Some(load_radii_from_file(path)?);
496        Ok(self)
497    }
498
499    /// Allow fallback to PDBTBX van der Waals radii when radius is not found in radii config file (default: false)
500    pub fn with_allow_vdw_fallback(mut self, allow: bool) -> Self {
501        self.allow_vdw_fallback = allow;
502        self
503    }
504}
505
506// Convenience constructors for each level
507impl SASAOptions<AtomLevel> {
508    pub fn atom_level() -> Self {
509        Self::new()
510    }
511}
512
513impl SASAOptions<ResidueLevel> {
514    pub fn residue_level() -> Self {
515        Self::new()
516    }
517}
518
519impl SASAOptions<ChainLevel> {
520    pub fn chain_level() -> Self {
521        Self::new()
522    }
523}
524
525impl SASAOptions<ProteinLevel> {
526    pub fn protein_level() -> Self {
527        Self::new()
528    }
529}
530
531impl<T> Default for SASAOptions<T> {
532    fn default() -> Self {
533        Self::new()
534    }
535}
536
537impl<T: SASAProcessor> SASAOptions<T> {
538    /// This function calculates the SASA for a given protein. The output type is determined by the level type parameter.
539    /// Probe radius and n_points can be customized, defaulting to 1.4 and 100 respectively.
540    /// If you want more fine-grained control you may want to use [calculate_sasa_internal] instead.
541    /// ## Example
542    /// ```
543    /// use pdbtbx::StrictnessLevel;
544    /// use rust_sasa::options::{SASAOptions, ResidueLevel};
545    /// let (mut pdb, _errors) = pdbtbx::open("./tests/data/pdbs/example.cif").unwrap();
546    /// let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
547    /// ```
548    pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
549        let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(
550            pdb,
551            self.radii_config.as_ref(),
552            self.allow_vdw_fallback,
553            self.include_hydrogens,
554            self.include_hetatms,
555        )?;
556        let atom_sasa =
557            calculate_sasa_internal(&atoms, self.probe_radius, self.n_points, self.threads);
558        T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
559    }
560}