Skip to main content

uff_relax/forcefield/
mod.rs

1pub mod interactions;
2#[cfg(not(target_arch = "wasm32"))]
3pub mod parallel;
4pub mod sequential;
5
6use crate::atom::{Atom, Bond, UffAtomType};
7use crate::cell::UnitCell;
8use crate::params::element_symbol;
9use glam::DVec3;
10
/// Atom count at or above which automatic mode (`num_threads == 0`) routes force
/// evaluation through the rayon-parallel path instead of the serial one.
#[cfg(not(target_arch = "wasm32"))]
const PARALLEL_THRESHOLD: usize = 1_000;
13
/// Per-term energy breakdown produced by a force evaluation.
///
/// The force routines fill `total` with the sum of the other five terms.
#[derive(Clone, Copy, Debug, Default)]
pub struct EnergyTerms {
    /// Energy from the bond terms.
    pub bond: f64,
    /// Energy from the angle terms.
    pub angle: f64,
    /// Energy from the torsion terms.
    pub torsion: f64,
    /// Electrostatic (charge-charge) energy.
    pub electrostatic: f64,
    /// Non-bonded (Lennard-Jones) energy, reported separately from electrostatics.
    pub non_bonded: f64,
    /// Sum of all the terms above.
    pub total: f64,
}
23
24/// Represents a molecular system consisting of atoms, bonds, and an optional unit cell.
25pub struct System {
26    /// List of atoms in the system.
27    pub atoms: Vec<Atom>,
28    /// List of chemical bonds.
29    pub bonds: Vec<Bond>,
30    /// Unit cell for periodic boundary conditions.
31    pub cell: UnitCell,
32}
33
34impl System {
35    /// Creates a new molecular system and automatically assigns UFF atom types.
36    ///
37    /// # Arguments
38    /// * `atoms` - Initial atom positions and elements.
39    /// * `bonds` - Connectivity and bond orders.
40    /// * `cell` - Boundary conditions (use `UnitCell::new_none()` for gas phase).
41    pub fn new(atoms: Vec<Atom>, bonds: Vec<Bond>, cell: UnitCell) -> Self {
42        let mut system = Self { atoms, bonds, cell };
43        system.auto_assign_uff_types();
44        system
45    }
46
47    /// Automatically infers UFF atom types based on element, connectivity, and bond orders.
48    pub fn auto_assign_uff_types(&mut self) {
49        let n = self.atoms.len();
50        let mut adj = vec![Vec::new(); n];
51        for bond in &self.bonds {
52            adj[bond.atom_indices.0].push(bond);
53            adj[bond.atom_indices.1].push(bond);
54        }
55
56        for i in 0..n {
57            let z = self.atoms[i].element;
58            let symbol = element_symbol(z);
59            let neighbors = &adj[i];
60            let n_neighbors = neighbors.len();
61            let has_order_1_5 = neighbors.iter().any(|b| (b.order - 1.5).abs() < 0.1);
62            let has_order_2_0 = neighbors.iter().any(|b| (b.order - 2.0).abs() < 0.1);
63            let bond_order_sum: f32 = neighbors.iter().map(|b| b.order).sum();
64
65            let label = match z {
66                1 => "H_".to_string(),
67                6 => { // Carbon
68                    match n_neighbors {
69                        4 => "C_3".to_string(),
70                        3 => if has_order_1_5 || has_order_2_0 { "C_R".to_string() } else { "C_2".to_string() },
71                        2 => "C_1".to_string(),
72                        _ => "C_3".to_string(),
73                    }
74                }
75                7 => { // Nitrogen
76                    match n_neighbors {
77                        3 => if has_order_1_5 { "N_R".to_string() } else { "N_3".to_string() },
78                        2 => "N_2".to_string(),
79                        1 => "N_1".to_string(),
80                        _ => "N_3".to_string(),
81                    }
82                }
83                8 => { // Oxygen
84                    if has_order_1_5 { "O_R".to_string() }
85                    else if n_neighbors == 1 && has_order_2_0 { "O_2".to_string() }
86                    else { "O_3".to_string() }
87                }
88                9 => "F_".to_string(),
89                15 => { // Phosphorus
90                    if n_neighbors >= 4 || bond_order_sum > 4.0 { "P_3+5".to_string() }
91                    else { "P_3+3".to_string() }
92                }
93                16 => { // Sulfur
94                    if has_order_1_5 { "S_R".to_string() }
95                    else if has_order_2_0 && n_neighbors == 1 { "S_2".to_string() }
96                    else if n_neighbors == 3 || (bond_order_sum > 3.0 && bond_order_sum < 5.0) { "S_3+4".to_string() }
97                    else if n_neighbors >= 4 || bond_order_sum >= 5.0 { "S_3+6".to_string() }
98                    else { "S_3+2".to_string() }
99                }
100                17 => "Cl".to_string(),
101                35 => "Br".to_string(),
102                53 => "I_".to_string(),
103                _ => {
104                    if n_neighbors == 0 { format!("{}_", symbol) } 
105                    else { format!("{}_", symbol) } // Always use symbol_ for generic match
106                }
107            };
108            self.atoms[i].uff_type = UffAtomType(label);
109        }
110    }
111
112    /// Computes forces and total energy breakdown.
113    pub fn compute_forces(&mut self) -> EnergyTerms {
114        self.compute_forces_with_threads(0, 6.0) // Default auto, cutoff 6.0
115    }
116
117    pub fn compute_forces_with_threads(&mut self, _num_threads: usize, _cutoff: f64) -> EnergyTerms {
118        #[cfg(target_arch = "wasm32")]
119        {
120            return self.compute_forces_serial(_cutoff);
121        }
122
123        #[cfg(not(target_arch = "wasm32"))]
124        {
125            let num_threads = _num_threads;
126            let cutoff = _cutoff;
127            if num_threads == 1 {
128                return self.compute_forces_serial(cutoff);
129            }
130
131            let use_parallel = if num_threads > 1 {
132                true
133            } else {
134                self.atoms.len() >= PARALLEL_THRESHOLD
135            };
136
137            if use_parallel {
138                let threads = if num_threads > 0 { 
139                    num_threads 
140                } else { 
141                    std::env::var("RAYON_NUM_THREADS")
142                        .ok()
143                        .and_then(|s| s.parse().ok())
144                        .unwrap_or(4)
145                };
146                let pool = rayon::ThreadPoolBuilder::new().num_threads(threads).build().unwrap();
147                pool.install(|| self.compute_forces_parallel(cutoff))
148            } else {
149                self.compute_forces_serial(cutoff)
150            }
151        }
152    }
153
154    fn compute_forces_serial(&mut self, cutoff: f64) -> EnergyTerms {
155        let mut energy = EnergyTerms::default();
156        for atom in &mut self.atoms { atom.force = DVec3::ZERO; }
157        
158        let mut adj = vec![Vec::new(); self.atoms.len()];
159        for b in &self.bonds {
160            let (u, v) = b.atom_indices;
161            adj[u].push(v);
162            adj[v].push(u);
163        }
164
165        energy.bond = self.compute_bond_forces_sequential();
166        energy.angle = self.compute_angle_forces_sequential();
167        energy.torsion = self.compute_torsion_forces_sequential();
168        let (nb, el) = self.compute_non_bonded_forces_sequential_cell_list(&adj, cutoff);
169        energy.non_bonded = nb;
170        energy.electrostatic = el;
171        energy.total = energy.bond + energy.angle + energy.torsion + energy.non_bonded + energy.electrostatic;
172        
173        energy
174    }
175
176    #[cfg(not(target_arch = "wasm32"))]
177    fn compute_forces_parallel(&mut self, cutoff: f64) -> EnergyTerms {
178        let mut energy = EnergyTerms::default();
179        for atom in &mut self.atoms { atom.force = DVec3::ZERO; }
180        
181        let mut adj = vec![Vec::new(); self.atoms.len()];
182        for b in &self.bonds {
183            let (u, v) = b.atom_indices;
184            adj[u].push(v);
185            adj[v].push(u);
186        }
187
188        energy.bond = self.compute_bond_forces_parallel();
189        energy.angle = self.compute_angle_forces_parallel();
190        energy.torsion = self.compute_torsion_forces_parallel();
191        let (nb, el) = self.compute_non_bonded_forces_parallel_cell_list(&adj, cutoff);
192        energy.non_bonded = nb;
193        energy.electrostatic = el;
194        energy.total = energy.bond + energy.angle + energy.torsion + energy.non_bonded + energy.electrostatic;
195        
196        energy
197    }
198
199    pub(crate) fn get_cell_neighbors(&self, cl: &crate::spatial::CellList, pos: DVec3, _cutoff: f64) -> Vec<usize> {
200        let mut neighbors = Vec::new();
201        let rel = pos - cl.min_p;
202        let ix = (rel.x / cl.cell_size.x) as i32;
203        let iy = (rel.y / cl.cell_size.y) as i32;
204        let iz = (rel.z / cl.cell_size.z) as i32;
205
206        for dx in -1..=1 {
207            for dy in -1..=1 {
208                for dz in -1..=1 {
209                    let mut nx = ix + dx; let mut ny = iy + dy; let mut nz = iz + dz;
210                    
211                    // PBC wrap for cells
212                    if nx < 0 { nx += cl.dx as i32; } else if nx >= cl.dx as i32 { nx -= cl.dx as i32; }
213                    if ny < 0 { ny += cl.dy as i32; } else if ny >= cl.dy as i32 { ny -= cl.dy as i32; }
214                    if nz < 0 { nz += cl.dz as i32; } else if nz >= cl.dz as i32 { nz -= cl.dz as i32; }
215
216                    if nx >= 0 && nx < cl.dx as i32 && ny >= 0 && ny < cl.dy as i32 && nz >= 0 && nz < cl.dz as i32 {
217                        let idx = (nx as usize * cl.dy * cl.dz) + (ny as usize * cl.dz) + nz as usize;
218                        neighbors.extend(&cl.cells[idx]);
219                    }
220                }
221            }
222        }
223        
224        // Remove duplicates (e.g. if cell size is large or system is small)
225        neighbors.sort_unstable();
226        neighbors.dedup();
227        neighbors
228    }
229
230        pub(crate) fn get_exclusion_scale(&self, i: usize, j: usize, adj: &[Vec<usize>]) -> (bool, f64) {
231
232            if i == j { return (true, 0.0); }
233
234            
235
236            // 1-2 neighbors: Strictly excluded from LJ
237
238            for &n1 in &adj[i] {
239
240                if n1 == j { return (true, 0.0); }
241
242            }
243
244            
245
246            // 1-3 neighbors: Small scale (0.1) to prevent collapse
247
248            for &n1 in &adj[i] {
249
250                for &n2 in &adj[n1] {
251
252                    if n2 == j { return (false, 0.1); }
253
254                }
255
256            }
257
258            
259
260            // 1-4 neighbors: Standard scale 0.5
261
262            for &n1 in &adj[i] {
263
264                for &n2 in &adj[n1] {
265
266                    for &n3 in &adj[n2] {
267
268                        if n3 == j { return (false, 0.5); }
269
270                    }
271
272                }
273
274            }
275
276            
277
278            (false, 1.0)
279
280        }
281
282    }
283
284