rust_sasa/
options.rs

1use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
2use crate::utils::consts::POLAR_AMINO_ACIDS;
3use crate::utils::{serialize_chain_id, simd_sum};
4use crate::{Atom, calculate_sasa_internal};
5use nalgebra::Point3;
6use pdbtbx::PDB;
7use snafu::OptionExt;
8use snafu::prelude::*;
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
/// Options for configuring SASA (Solvent Accessible Surface Area) calculations.
///
/// This struct provides configuration options for SASA calculations at different levels
/// of granularity (atom, residue, chain, or protein level). The type parameter `T`
/// determines the output type and processing behavior.
///
/// # Type Parameters
///
/// * `T` - The processing level, which must implement [`SASAProcessor`]. Available levels:
///   - [`AtomLevel`] - Returns SASA values for individual atoms
///   - [`ResidueLevel`] - Returns SASA values aggregated by residue
///   - [`ChainLevel`] - Returns SASA values aggregated by chain
///   - [`ProteinLevel`] - Returns SASA values aggregated for the entire protein
///
/// # Fields
///
/// * `probe_radius` - Radius of the solvent probe sphere in Angstroms (default: 1.4)
/// * `n_points` - Number of points on the sphere surface for sampling (default: 100)
/// * `parallel` - Whether to use parallel processing: `false` when created with
///   [`SASAOptions::new`] (or the per-level constructors such as `residue_level()`),
///   `true` when created through the `Default` implementation
///
/// # Examples
///
/// ```rust
/// use rust_sasa::options::{SASAOptions, ResidueLevel};
/// use pdbtbx::PDB;
///
/// // Create options with the standard settings (parallel processing off)
/// let options = SASAOptions::<ResidueLevel>::new();
///
/// // Customize the configuration
/// let custom_options = SASAOptions::<ResidueLevel>::new()
///     .with_probe_radius(1.2)
///     .with_n_points(200)
///     .with_parallel(true);
///
/// // Process a PDB structure
/// # let pdb = PDB::new();
/// let result = custom_options.process(&pdb)?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
pub struct SASAOptions<T> {
    probe_radius: f32,
    n_points: usize,
    parallel: bool,
    _marker: PhantomData<T>,
}

// `Clone` and `Debug` are written by hand because `#[derive]` adds `T: Clone` /
// `T: Debug` bounds even though `T` is only held in a `PhantomData`. With the
// derive, `SASAOptions<Level>` was only `Clone`/`Debug` when the level marker
// type itself implemented those traits.
impl<T> Clone for SASAOptions<T> {
    fn clone(&self) -> Self {
        Self {
            probe_radius: self.probe_radius,
            n_points: self.n_points,
            parallel: self.parallel,
            _marker: PhantomData,
        }
    }
}

// Produces the same text a derived `Debug` would, without requiring `T: Debug`.
impl<T> std::fmt::Debug for SASAOptions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SASAOptions")
            .field("probe_radius", &self.probe_radius)
            .field("n_points", &self.n_points)
            .field("parallel", &self.parallel)
            .field("_marker", &self._marker)
            .finish()
    }
}
59
// Zero-sized marker types for each level. They only select behavior through the
// type parameter of `SASAOptions` / `SASAProcessor`. The cheap standard traits are
// derived so that `SASAOptions<Level>` and generic code can copy, compare and print them.

/// Level marker: `process` yields one SASA value per atom (`Vec<f32>`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AtomLevel;
/// Level marker: `process` yields one `ResidueResult` per residue.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ResidueLevel;
/// Level marker: `process` yields one `ChainResult` per chain.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ChainLevel;
/// Level marker: `process` yields a single `ProteinResult` for the whole structure.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ProteinLevel;
65
/// Return type of [`SASAProcessor::build_atoms_and_mapping`]: the atoms to feed into
/// the SASA calculation, plus a map from a level-specific key to the indices (into
/// that atom vector) of the atoms belonging to each key.
pub type AtomsMappingResult = Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError>;

/// Trait that defines the processing behavior for each level.
///
/// [`SASAOptions::process`] first calls `build_atoms_and_mapping`, runs the SASA
/// calculation on the returned atoms, and then hands the resulting per-atom values,
/// together with the same mapping, to `process_atoms`. Each implementation decides
/// what its mapping keys mean. `process_atoms` should therefore only receive a mapping
/// produced by the same implementation's `build_atoms_and_mapping` for the same `pdb`.
pub trait SASAProcessor {
    /// Result type produced for this level.
    type Output;

    /// Aggregates per-atom SASA values into [`Self::Output`].
    ///
    /// * `atoms` - the atoms returned by `build_atoms_and_mapping`
    /// * `atom_sasa` - the SASA value computed for each atom. Implementations index it
    ///   with the atom indices stored in `parent_to_atoms` (presumably aligned with
    ///   `atoms` index-for-index; confirm against `calculate_sasa_internal`)
    /// * `pdb` - the structure the atoms were built from
    /// * `parent_to_atoms` - the mapping returned by `build_atoms_and_mapping`
    fn process_atoms(
        atoms: &[Atom],
        atom_sasa: &[f32],
        pdb: &PDB,
        parent_to_atoms: &HashMap<isize, Vec<usize>>,
    ) -> Result<Self::Output, SASACalcError>;

    /// Builds the atom list for the SASA calculation from `pdb`, plus the
    /// level-specific atom mapping described on [`AtomsMappingResult`].
    fn build_atoms_and_mapping(pdb: &PDB) -> AtomsMappingResult;
}
81
82impl SASAProcessor for AtomLevel {
83    type Output = Vec<f32>;
84
85    fn process_atoms(
86        _atoms: &[Atom],
87        atom_sasa: &[f32],
88        _pdb: &PDB,
89        _parent_to_atoms: &HashMap<isize, Vec<usize>>,
90    ) -> Result<Self::Output, SASACalcError> {
91        Ok(atom_sasa.to_vec())
92    }
93
94    fn build_atoms_and_mapping(
95        pdb: &PDB,
96    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
97        let mut atoms = vec![];
98        for atom in pdb.atoms() {
99            atoms.push(Atom {
100                position: Point3::new(
101                    atom.pos().0 as f32,
102                    atom.pos().1 as f32,
103                    atom.pos().2 as f32,
104                ),
105                radius: atom
106                    .element()
107                    .context(ElementMissingSnafu)?
108                    .atomic_radius()
109                    .van_der_waals
110                    .context(VanDerWaalsMissingSnafu)? as f32,
111                id: atom.serial_number(),
112                parent_id: None,
113            })
114        }
115        Ok((atoms, HashMap::new()))
116    }
117}
118
119impl SASAProcessor for ResidueLevel {
120    type Output = Vec<ResidueResult>;
121
122    fn process_atoms(
123        _atoms: &[Atom],
124        atom_sasa: &[f32],
125        pdb: &PDB,
126        parent_to_atoms: &HashMap<isize, Vec<usize>>,
127    ) -> Result<Self::Output, SASACalcError> {
128        let mut residue_sasa = vec![];
129        for chain in pdb.chains() {
130            for residue in chain.residues() {
131                let residue_atom_index = parent_to_atoms
132                    .get(&residue.serial_number())
133                    .context(AtomMapToLevelElementFailedSnafu)?;
134                let residue_atoms: Vec<_> = residue_atom_index
135                    .iter()
136                    .map(|&index| atom_sasa[index])
137                    .collect();
138                let sum = simd_sum(residue_atoms.as_slice());
139                let name = residue
140                    .name()
141                    .context(FailedToGetResidueNameSnafu)?
142                    .to_string();
143                residue_sasa.push(ResidueResult {
144                    serial_number: residue.serial_number(),
145                    value: sum,
146                    is_polar: POLAR_AMINO_ACIDS.contains(&name),
147                    chain_id: chain.id().to_string(),
148                    name,
149                })
150            }
151        }
152        Ok(residue_sasa)
153    }
154
155    fn build_atoms_and_mapping(
156        pdb: &PDB,
157    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
158        let mut atoms = vec![];
159        let mut parent_to_atoms = HashMap::new();
160        let mut i = 0;
161        for residue in pdb.residues() {
162            let mut temp = vec![];
163            for atom in residue.atoms() {
164                atoms.push(Atom {
165                    position: Point3::new(
166                        atom.pos().0 as f32,
167                        atom.pos().1 as f32,
168                        atom.pos().2 as f32,
169                    ),
170                    radius: atom
171                        .element()
172                        .context(ElementMissingSnafu)?
173                        .atomic_radius()
174                        .van_der_waals
175                        .context(VanDerWaalsMissingSnafu)? as f32,
176                    id: atom.serial_number(),
177                    parent_id: Some(residue.serial_number()),
178                });
179                temp.push(i);
180                i += 1;
181            }
182            parent_to_atoms.insert(residue.serial_number(), temp);
183        }
184        Ok((atoms, parent_to_atoms))
185    }
186}
187
188impl SASAProcessor for ChainLevel {
189    type Output = Vec<ChainResult>;
190
191    fn process_atoms(
192        _atoms: &[Atom],
193        atom_sasa: &[f32],
194        pdb: &PDB,
195        parent_to_atoms: &HashMap<isize, Vec<usize>>,
196    ) -> Result<Self::Output, SASACalcError> {
197        let mut chain_sasa = vec![];
198        for chain in pdb.chains() {
199            let chain_id = serialize_chain_id(chain.id());
200            let chain_atom_index = parent_to_atoms
201                .get(&chain_id)
202                .context(AtomMapToLevelElementFailedSnafu)?;
203            let chain_atoms: Vec<_> = chain_atom_index
204                .iter()
205                .map(|&index| atom_sasa[index])
206                .collect();
207            let sum = simd_sum(chain_atoms.as_slice());
208            chain_sasa.push(ChainResult {
209                name: chain.id().to_string(),
210                value: sum,
211            })
212        }
213        Ok(chain_sasa)
214    }
215
216    fn build_atoms_and_mapping(
217        pdb: &PDB,
218    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
219        let mut atoms = vec![];
220        let mut parent_to_atoms = HashMap::new();
221        let mut i = 0;
222        for chain in pdb.chains() {
223            let mut temp = vec![];
224            let chain_id = serialize_chain_id(chain.id());
225            for atom in chain.atoms() {
226                atoms.push(Atom {
227                    position: Point3::new(
228                        atom.pos().0 as f32,
229                        atom.pos().1 as f32,
230                        atom.pos().2 as f32,
231                    ),
232                    radius: atom
233                        .element()
234                        .context(ElementMissingSnafu)?
235                        .atomic_radius()
236                        .van_der_waals
237                        .context(VanDerWaalsMissingSnafu)? as f32,
238                    id: atom.serial_number(),
239                    parent_id: Some(chain_id),
240                });
241                temp.push(i);
242                i += 1
243            }
244            parent_to_atoms.insert(chain_id, temp);
245        }
246        Ok((atoms, parent_to_atoms))
247    }
248}
249
250impl SASAProcessor for ProteinLevel {
251    type Output = ProteinResult;
252
253    fn process_atoms(
254        _atoms: &[Atom],
255        atom_sasa: &[f32],
256        pdb: &PDB,
257        parent_to_atoms: &HashMap<isize, Vec<usize>>,
258    ) -> Result<Self::Output, SASACalcError> {
259        let mut polar_total: f32 = 0.0;
260        let mut non_polar_total: f32 = 0.0;
261        for residue in pdb.residues() {
262            let residue_atom_index = parent_to_atoms
263                .get(&residue.serial_number())
264                .context(AtomMapToLevelElementFailedSnafu)?;
265            let residue_atoms: Vec<_> = residue_atom_index
266                .iter()
267                .map(|&index| atom_sasa[index])
268                .collect();
269            let sum = simd_sum(residue_atoms.as_slice());
270            let name = residue
271                .name()
272                .context(FailedToGetResidueNameSnafu)?
273                .to_string();
274            if POLAR_AMINO_ACIDS.contains(&name) {
275                polar_total += sum
276            } else {
277                non_polar_total += sum
278            }
279        }
280        let global_sum = simd_sum(atom_sasa);
281        Ok(ProteinResult {
282            global_total: global_sum,
283            polar_total,
284            non_polar_total,
285        })
286    }
287
288    fn build_atoms_and_mapping(
289        pdb: &PDB,
290    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
291        let mut atoms = vec![];
292        let mut parent_to_atoms = HashMap::new();
293        let mut i = 0;
294        for residue in pdb.residues() {
295            let mut temp = vec![];
296            for atom in residue.atoms() {
297                atoms.push(Atom {
298                    position: Point3::new(
299                        atom.pos().0 as f32,
300                        atom.pos().1 as f32,
301                        atom.pos().2 as f32,
302                    ),
303                    radius: atom
304                        .element()
305                        .context(ElementMissingSnafu)?
306                        .atomic_radius()
307                        .van_der_waals
308                        .context(VanDerWaalsMissingSnafu)? as f32,
309                    id: atom.serial_number(),
310                    parent_id: Some(residue.serial_number()),
311                });
312                temp.push(i);
313                i += 1;
314            }
315            parent_to_atoms.insert(residue.serial_number(), temp);
316        }
317        Ok((atoms, parent_to_atoms))
318    }
319}
320
/// Errors that can occur while preparing atoms for, or aggregating the results of,
/// a SASA calculation.
#[derive(Debug, Snafu)]
pub enum SASACalcError {
    /// An input atom has no element assigned.
    #[snafu(display("Element missing for atom"))]
    ElementMissing,

    /// The atom's element has no van der Waals radius, which is used as the atom radius.
    #[snafu(display("Van der Waals radius missing for element"))]
    VanDerWaalsMissing,

    /// A residue or chain had no entry in the atom mapping built for its level.
    #[snafu(display("Failed to map atoms back to level element"))]
    AtomMapToLevelElementFailed,

    /// A residue has no name, so it cannot be reported or classified as polar.
    #[snafu(display("Failed to get residue name"))]
    FailedToGetResidueName,
}
335
336impl Default for SASAOptions<ResidueLevel> {
337    fn default() -> Self {
338        Self {
339            probe_radius: 1.4, // Standard water probe radius in Angstroms
340            n_points: 100,     // Number of points on sphere for sampling
341            parallel: true,    // Parallel processing by default
342            _marker: PhantomData,
343        }
344    }
345}
346
347impl<T> SASAOptions<T> {
348    /// Create a new SASAOptions with the specified level type
349    pub fn new() -> SASAOptions<T> {
350        SASAOptions {
351            probe_radius: 1.4,
352            n_points: 100,
353            parallel: false,
354            _marker: PhantomData,
355        }
356    }
357
358    /// Set the probe radius (default: 1.4 Angstroms)
359    pub fn with_probe_radius(mut self, radius: f32) -> Self {
360        self.probe_radius = radius;
361        self
362    }
363
364    /// Set the number of points on the sphere for sampling (default: 100)
365    pub fn with_n_points(mut self, points: usize) -> Self {
366        self.n_points = points;
367        self
368    }
369
370    /// Enable or disable parallel processing (default: true)
371    pub fn with_parallel(mut self, parallel: bool) -> Self {
372        self.parallel = parallel;
373        self
374    }
375}
376
377// Convenience constructors for each level
378impl SASAOptions<AtomLevel> {
379    pub fn atom_level() -> Self {
380        Self::new()
381    }
382}
383
384impl SASAOptions<ResidueLevel> {
385    pub fn residue_level() -> Self {
386        Self::new()
387    }
388}
389
390impl SASAOptions<ChainLevel> {
391    pub fn chain_level() -> Self {
392        Self::new()
393    }
394}
395
396impl SASAOptions<ProteinLevel> {
397    pub fn protein_level() -> Self {
398        Self::new()
399    }
400}
401
402impl<T: SASAProcessor> SASAOptions<T> {
403    /// This function calculates the SASA for a given protein. The output type is determined by the level type parameter.
404    /// Probe radius and n_points can be customized, defaulting to 1.4 and 100 respectively.
405    /// If you want more fine-grained control you may want to use [calculate_sasa_internal] instead.
406    /// ## Example
407    /// ```
408    /// use pdbtbx::StrictnessLevel;
409    /// use rust_sasa::options::{SASAOptions, ResidueLevel};
410    /// let (mut pdb, _errors) = pdbtbx::open("./pdbs/example.cif").unwrap();
411    /// let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
412    /// ```
413    pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
414        let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(pdb)?;
415        let atom_sasa =
416            calculate_sasa_internal(&atoms, self.probe_radius, self.n_points, self.parallel);
417        T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
418    }
419}