// uff_relax/forcefield/mod.rs

1pub mod interactions;
2pub mod parallel;
3pub mod sequential;
4
5use crate::atom::{Atom, Bond, UffAtomType};
6use crate::cell::UnitCell;
7use crate::params::element_symbol;
8use glam::DVec3;
9
/// Atom count at or above which `System::compute_forces_with_threads` takes the
/// parallel path when the caller leaves the thread count on automatic (`0`).
const PARALLEL_THRESHOLD: usize = 1_000;
11
/// Per-term energy breakdown produced by a force evaluation.
///
/// Every field defaults to `0.0`. When filled in by `System::compute_forces*`,
/// `total` is the sum of the four individual terms.
#[derive(Clone, Copy, Debug, Default)]
pub struct EnergyTerms {
    /// Bond-stretching contribution.
    pub bond: f64,
    /// Angle-bending contribution.
    pub angle: f64,
    /// Torsional contribution.
    pub torsion: f64,
    /// Non-bonded (pairwise, cut-off limited) contribution.
    pub non_bonded: f64,
    /// Sum of all of the above.
    pub total: f64,
}
20
21/// Represents a molecular system consisting of atoms, bonds, and an optional unit cell.
22pub struct System {
23    /// List of atoms in the system.
24    pub atoms: Vec<Atom>,
25    /// List of chemical bonds.
26    pub bonds: Vec<Bond>,
27    /// Unit cell for periodic boundary conditions.
28    pub cell: UnitCell,
29}
30
31impl System {
32    /// Creates a new molecular system and automatically assigns UFF atom types.
33    ///
34    /// # Arguments
35    /// * `atoms` - Initial atom positions and elements.
36    /// * `bonds` - Connectivity and bond orders.
37    /// * `cell` - Boundary conditions (use `UnitCell::new_none()` for gas phase).
38    pub fn new(atoms: Vec<Atom>, bonds: Vec<Bond>, cell: UnitCell) -> Self {
39        let mut system = Self { atoms, bonds, cell };
40        system.auto_assign_uff_types();
41        system
42    }
43
44    /// Automatically infers UFF atom types based on element, connectivity, and bond orders.
45    pub fn auto_assign_uff_types(&mut self) {
46        let n = self.atoms.len();
47        let mut adj = vec![Vec::new(); n];
48        for bond in &self.bonds {
49            adj[bond.atom_indices.0].push(bond);
50            adj[bond.atom_indices.1].push(bond);
51        }
52
53        for i in 0..n {
54            let z = self.atoms[i].element;
55            let symbol = element_symbol(z);
56            let neighbors = &adj[i];
57            let n_neighbors = neighbors.len();
58            let has_order_1_5 = neighbors.iter().any(|b| (b.order - 1.5).abs() < 0.1);
59            let has_order_2_0 = neighbors.iter().any(|b| (b.order - 2.0).abs() < 0.1);
60
61            let label = match z {
62                6 => { // Carbon
63                    match n_neighbors {
64                        4 => "C_3".to_string(),
65                        3 => if has_order_1_5 || has_order_2_0 { "C_R".to_string() } else { "C_2".to_string() },
66                        2 => "C_1".to_string(),
67                        _ => "C_3".to_string(),
68                    }
69                }
70                1 => "H_".to_string(),
71                7 => { // Nitrogen
72                    match n_neighbors {
73                        3 => if has_order_1_5 { "N_R".to_string() } else { "N_3".to_string() },
74                        2 => "N_2".to_string(),
75                        1 => "N_1".to_string(),
76                        _ => "N_3".to_string(),
77                    }
78                }
79                8 => { // Oxygen
80                    if has_order_1_5 { "O_R".to_string() }
81                    else if n_neighbors == 1 && has_order_2_0 { "O_2".to_string() }
82                    else { "O_3".to_string() }
83                }
84                _ => {
85                    if n_neighbors == 0 { format!("{}_", symbol) } 
86                    else { format!("{}_{}", symbol, n_neighbors) }
87                }
88            };
89            self.atoms[i].uff_type = UffAtomType(label);
90        }
91    }
92
93    /// Computes forces and total energy breakdown.
94    pub fn compute_forces(&mut self) -> EnergyTerms {
95        self.compute_forces_with_threads(0, 6.0) // Default auto, cutoff 6.0
96    }
97
98    pub fn compute_forces_with_threads(&mut self, num_threads: usize, cutoff: f64) -> EnergyTerms {
99        if num_threads == 1 {
100            return self.compute_forces_serial(cutoff);
101        }
102
103        let use_parallel = if num_threads > 1 {
104            true
105        } else {
106            self.atoms.len() >= PARALLEL_THRESHOLD
107        };
108
109        if use_parallel {
110            if num_threads > 1 {
111                let pool = rayon::ThreadPoolBuilder::new().num_threads(num_threads).build().unwrap();
112                pool.install(|| self.compute_forces_parallel(cutoff))
113            } else {
114                crate::init_parallelism(None);
115                self.compute_forces_parallel(cutoff)
116            }
117        } else {
118            self.compute_forces_serial(cutoff)
119        }
120    }
121
122    fn compute_forces_serial(&mut self, cutoff: f64) -> EnergyTerms {
123        let mut energy = EnergyTerms::default();
124        for atom in &mut self.atoms { atom.force = DVec3::ZERO; }
125        
126        let mut adj = vec![Vec::new(); self.atoms.len()];
127        for b in &self.bonds {
128            let (u, v) = b.atom_indices;
129            adj[u].push(v);
130            adj[v].push(u);
131        }
132
133        energy.bond = self.compute_bond_forces_sequential();
134        energy.angle = self.compute_angle_forces_sequential();
135        energy.torsion = self.compute_torsion_forces_sequential();
136        energy.non_bonded = self.compute_non_bonded_forces_sequential_cell_list(&adj, cutoff);
137        energy.total = energy.bond + energy.angle + energy.torsion + energy.non_bonded;
138        
139        energy
140    }
141
142    fn compute_forces_parallel(&mut self, cutoff: f64) -> EnergyTerms {
143        let mut energy = EnergyTerms::default();
144        for atom in &mut self.atoms { atom.force = DVec3::ZERO; }
145        
146        let mut adj = vec![Vec::new(); self.atoms.len()];
147        for b in &self.bonds {
148            let (u, v) = b.atom_indices;
149            adj[u].push(v);
150            adj[v].push(u);
151        }
152
153        energy.bond = self.compute_bond_forces_parallel();
154        energy.angle = self.compute_angle_forces_parallel();
155        energy.torsion = self.compute_torsion_forces_parallel();
156        energy.non_bonded = self.compute_non_bonded_forces_parallel_cell_list(&adj, cutoff);
157        energy.total = energy.bond + energy.angle + energy.torsion + energy.non_bonded;
158        
159        energy
160    }
161
162    pub(crate) fn get_cell_neighbors(&self, cl: &crate::spatial::CellList, pos: DVec3, _cutoff: f64) -> Vec<usize> {
163        let mut neighbors = Vec::new();
164        let rel = pos - cl.min_p;
165        let ix = (rel.x / cl.cell_size.x) as i32;
166        let iy = (rel.y / cl.cell_size.y) as i32;
167        let iz = (rel.z / cl.cell_size.z) as i32;
168
169        for dx in -1..=1 {
170            for dy in -1..=1 {
171                for dz in -1..=1 {
172                    let nx = ix + dx; let ny = iy + dy; let nz = iz + dz;
173                    if nx >= 0 && nx < cl.dx as i32 && ny >= 0 && ny < cl.dy as i32 && nz >= 0 && nz < cl.dz as i32 {
174                        let idx = (nx as usize * cl.dy * cl.dz) + (ny as usize * cl.dz) + nz as usize;
175                        neighbors.extend(&cl.cells[idx]);
176                    }
177                }
178            }
179        }
180        neighbors
181    }
182
183    pub(crate) fn get_exclusion_scale(&self, i: usize, j: usize, adj: &[Vec<usize>]) -> (bool, f64) {
184        for &n1 in &adj[i] {
185            if n1 == j { return (true, 0.0); }
186        }
187        for &n1 in &adj[i] {
188            for &n2 in &adj[n1] {
189                if n2 == j { return (true, 0.0); }
190            }
191        }
192        for &n1 in &adj[i] {
193            for &n2 in &adj[n1] {
194                for &n3 in &adj[n2] {
195                    if n3 == j { return (false, 0.5); }
196                }
197            }
198        }
199        (false, 1.0)
200    }
201}