Skip to main content

sci_form/forcefield/
energy.rs

1use nalgebra::Vector3;
2use petgraph::visit::EdgeRef;
3
/// UFF van der Waals parameters `(x1, d1)` for an element given by atomic number, where
/// `x1` is the VDW distance in Å and `d1` the well depth in kcal/mol.
///
/// Covers H, B, C, N, O, F, Si, P, S, Cl, Br and I; any other atomic number falls back to
/// the carbon parameters.
pub fn uff_vdw_params(element: u8) -> (f32, f32) {
    // (atomic number, x1, d1)
    const TABLE: [(u8, f32, f32); 12] = [
        (1, 2.886, 0.044),  // H
        (5, 3.637, 0.180),  // B
        (6, 3.851, 0.105),  // C
        (7, 3.660, 0.069),  // N
        (8, 3.500, 0.060),  // O
        (9, 3.364, 0.050),  // F
        (14, 4.295, 0.402), // Si
        (15, 4.147, 0.305), // P
        (16, 4.035, 0.274), // S
        (17, 3.947, 0.227), // Cl
        (35, 4.189, 0.251), // Br
        (53, 4.500, 0.339), // I
    ];
    const CARBON: (f32, f32) = (3.851, 0.105);

    TABLE
        .iter()
        .find(|&&(z, _, _)| z == element)
        .map_or(CARBON, |&(_, x1, d1)| (x1, d1))
}
22
23/// LJ 12-6 VDW energy between two atoms
24pub fn vdw_energy(p1: &Vector3<f32>, p2: &Vector3<f32>, r_star: f32, epsilon: f32) -> f32 {
25    let r = (p1 - p2).norm();
26    if !(0.5..=8.0).contains(&r) {
27        return 0.0;
28    }
29    let u = r_star / r;
30    let u6 = u * u * u * u * u * u;
31    let u12 = u6 * u6;
32    epsilon * (u12 - 2.0 * u6)
33}
34
35/// Flat-bottom distance constraint energy (matching RDKit's DistanceConstraintContribs)
36/// Zero energy when minLen <= d <= maxLen, harmonic penalty outside
37pub fn distance_constraint_energy(
38    p1: &Vector3<f32>,
39    p2: &Vector3<f32>,
40    min_len: f32,
41    max_len: f32,
42    k: f32,
43) -> f32 {
44    let d2 = (p1 - p2).norm_squared();
45    if d2 < min_len * min_len {
46        let d = d2.sqrt();
47        let diff = min_len - d;
48        0.5 * k * diff * diff
49    } else if d2 > max_len * max_len {
50        let d = d2.sqrt();
51        let diff = d - max_len;
52        0.5 * k * diff * diff
53    } else {
54        0.0
55    }
56}
57
58/// Harmonic bond stretching penalty (Hooke's Law)
59pub fn bond_stretch_energy(p1: &Vector3<f32>, p2: &Vector3<f32>, k_b: f32, r_eq: f32) -> f32 {
60    let r = (p1 - p2).norm();
61    0.5 * k_b * (r - r_eq).powi(2)
62}
63
64/// Harmonic angle bending penalty
65pub fn angle_bend_energy(
66    p1: &Vector3<f32>,
67    p2: &Vector3<f32>, // central atom
68    p3: &Vector3<f32>,
69    k_theta: f32,
70    theta_eq: f32,
71) -> f32 {
72    let v1 = p1 - p2;
73    let v2 = p3 - p2;
74    let r1 = v1.norm();
75    let r2 = v2.norm();
76    if r1 < 1e-4 || r2 < 1e-4 {
77        return 0.0;
78    }
79    let cos_th = (v1.dot(&v2) / (r1 * r2)).clamp(-0.999999, 0.999999);
80
81    // Linear angle special case: use cosine potential for stability
82    if (theta_eq - std::f32::consts::PI).abs() < 1e-4 {
83        // Special linear potential: E = k * (1 + cos(theta))
84        return k_theta * (1.0 + cos_th);
85    }
86
87    let theta = cos_th.acos();
88    0.5 * k_theta * (theta - theta_eq).powi(2)
89}
90
91/// Harmonic torsional potential (Dihedral)
92pub fn torsional_energy(
93    p1: &Vector3<f32>,
94    p2: &Vector3<f32>,
95    p3: &Vector3<f32>,
96    p4: &Vector3<f32>,
97    v: f32,
98    n: f32,
99    gamma: f32,
100) -> f32 {
101    let b1 = p2 - p1;
102    let b2 = p3 - p2;
103    let b3 = p4 - p3;
104
105    let n1 = b1.cross(&b2).normalize();
106    let n2 = b2.cross(&b3).normalize();
107    let m1 = n1.cross(&b2.normalize());
108
109    let x = n1.dot(&n2);
110    let y = m1.dot(&n2);
111    let phi = y.atan2(x);
112
113    v * (1.0 + (n * phi - gamma).cos())
114}
115
116/// Distance-based penalty (used in embedding refinement)
117/// Matches RDKit's DistViolationContribs formulation
118pub fn bounds_energy(
119    p1: &Vector3<f32>,
120    p2: &Vector3<f32>,
121    lower: f32,
122    upper: f32,
123    k_bounds: f32,
124) -> f32 {
125    let r2 = (p1 - p2).norm_squared();
126    let u2 = upper * upper;
127    let l2 = lower * lower;
128    if r2 > u2 && u2 > 1e-6 {
129        let val = r2 / u2 - 1.0;
130        k_bounds * val * val
131    } else if r2 < l2 && l2 > 1e-6 {
132        // RDKit formula: val = 2L²/(L²+d²) - 1
133        let val = 2.0 * l2 / (l2 + r2.max(1e-6)) - 1.0;
134        k_bounds * val * val
135    } else {
136        0.0
137    }
138}
139
140/// Harmonic Out-of-Plane bending penalty
141pub fn oop_energy(
142    p_center: &Vector3<f32>,
143    p1: &Vector3<f32>,
144    p2: &Vector3<f32>,
145    p3: &Vector3<f32>,
146    k_oop: f32,
147    phi_eq: f32,
148) -> f32 {
149    let v1 = p1 - p_center;
150    let v2 = p2 - p_center;
151    let v3 = p3 - p_center;
152
153    let normal = v2.cross(&v3).normalize();
154    let dist = v1.dot(&normal);
155    let sin_phi = dist / v1.norm().max(1e-4);
156    let phi = sin_phi.asin();
157
158    0.5 * k_oop * (phi - phi_eq).powi(2)
159}
160
161/// Chiral volume penalty
162pub fn chirality_energy(
163    p_center: &Vector3<f32>,
164    p1: &Vector3<f32>,
165    p2: &Vector3<f32>,
166    p3: &Vector3<f32>,
167    target_vol: f32,
168    k_chiral: f32,
169) -> f32 {
170    let v1 = p1 - p_center;
171    let v2 = p2 - p_center;
172    let v3 = p3 - p_center;
173    let vol = v1.dot(&v2.cross(&v3));
174    0.5 * k_chiral * (vol - target_vol).powi(2)
175}
176
/// Force constants weighting the individual terms of [`calculate_total_energy`].
#[derive(Clone, Debug, PartialEq)]
pub struct FFParams {
    /// Bond-stretching force constant, passed to `bond_stretch_energy` for every bond.
    pub kb: f32,
    /// Angle-bending force constant, passed to `angle_bend_energy` for every angle.
    pub k_theta: f32,
    /// Torsion scale. It is multiplied by the hybridization weight from `torsion_params`,
    /// and values near zero disable the UFF-style torsion term.
    pub k_omega: f32,
    /// Weight of the squared triple-product out-of-plane term for three-coordinate SP2
    /// atoms; values near zero disable it.
    pub k_oop: f32,
    /// Weight of the distance-bounds violation term (`bounds_energy`).
    pub k_bounds: f32,
    /// Weight intended for `chirality_energy`; not read by `calculate_total_energy`.
    pub k_chiral: f32,
    /// Global scale for the van der Waals term; values near zero (the default) disable it.
    pub k_vdw: f32,
}

impl Default for FFParams {
    /// Default weights, with the van der Waals term disabled (`k_vdw = 0.0`).
    fn default() -> Self {
        Self {
            kb: 500.0,
            k_theta: 300.0,
            k_omega: 20.0,
            k_oop: 40.0,
            k_bounds: 200.0,
            k_chiral: 100.0,
            k_vdw: 0.0,
        }
    }
}
201
202pub fn calculate_total_energy(
203    coords: &nalgebra::DMatrix<f32>,
204    mol: &crate::graph::Molecule,
205    params: &FFParams,
206    bounds_matrix: &nalgebra::DMatrix<f64>,
207) -> f32 {
208    let n = mol.graph.node_count();
209    let mut energy = 0.0;
210
211    // 1. Bond Stretch
212    for edge in mol.graph.edge_references() {
213        let idx1 = edge.source().index();
214        let idx2 = edge.target().index();
215        let p1 = Vector3::new(coords[(idx1, 0)], coords[(idx1, 1)], coords[(idx1, 2)]);
216        let p2 = Vector3::new(coords[(idx2, 0)], coords[(idx2, 1)], coords[(idx2, 2)]);
217        let r_eq = crate::distgeom::get_bond_length(mol, edge.source(), edge.target()) as f32;
218        energy += bond_stretch_energy(&p1, &p2, params.kb, r_eq);
219    }
220
221    // 2. Angles
222    for i in 0..n {
223        let ni = petgraph::graph::NodeIndex::new(i);
224        let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
225        for j in 0..nbs.len() {
226            for k in (j + 1)..nbs.len() {
227                let n1 = nbs[j];
228                let n2 = nbs[k];
229                let p1 = Vector3::new(
230                    coords[(n1.index(), 0)],
231                    coords[(n1.index(), 1)],
232                    coords[(n1.index(), 2)],
233                );
234                let pc = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
235                let p2 = Vector3::new(
236                    coords[(n2.index(), 0)],
237                    coords[(n2.index(), 1)],
238                    coords[(n2.index(), 2)],
239                );
240                let ideal = crate::graph::get_corrected_ideal_angle(mol, ni, n1, n2) as f32;
241                energy += angle_bend_energy(&p1, &pc, &p2, params.k_theta, ideal);
242            }
243        }
244    }
245
246    // 3. Distance Bounds
247    for i in 0..n {
248        for j in (i + 1)..n {
249            let upper = bounds_matrix[(i, j)] as f32;
250            let lower = bounds_matrix[(j, i)] as f32;
251            let p1 = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
252            let p2 = Vector3::new(coords[(j, 0)], coords[(j, 1)], coords[(j, 2)]);
253            energy += bounds_energy(&p1, &p2, lower, upper, params.k_bounds);
254        }
255    }
256
257    // 4. Out-of-Plane bending for SP2 atoms with 3 neighbors (volume-based)
258    if params.k_oop.abs() > 1e-8 {
259        for i in 0..n {
260            let ni = petgraph::graph::NodeIndex::new(i);
261            if mol.graph[ni].hybridization != crate::graph::Hybridization::SP2 {
262                continue;
263            }
264            let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
265            if nbs.len() != 3 {
266                continue;
267            }
268            let pc = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
269            let p1 = Vector3::new(
270                coords[(nbs[0].index(), 0)],
271                coords[(nbs[0].index(), 1)],
272                coords[(nbs[0].index(), 2)],
273            );
274            let p2 = Vector3::new(
275                coords[(nbs[1].index(), 0)],
276                coords[(nbs[1].index(), 1)],
277                coords[(nbs[1].index(), 2)],
278            );
279            let p3 = Vector3::new(
280                coords[(nbs[2].index(), 0)],
281                coords[(nbs[2].index(), 1)],
282                coords[(nbs[2].index(), 2)],
283            );
284            // Triple product: V = (p1-pc)·((p2-pc)×(p3-pc))
285            let v1 = p1 - pc;
286            let v2 = p2 - pc;
287            let v3 = p3 - pc;
288            let vol = v1.dot(&v2.cross(&v3));
289            energy += params.k_oop * vol * vol;
290        }
291    }
292
293    // 5. Torsional energy (UFF-style hybridization-dependent)
294    if n >= 4 && params.k_omega.abs() > 1e-8 {
295        for edge in mol.graph.edge_references() {
296            let u = edge.source();
297            let v = edge.target();
298            let hyb_u = mol.graph[u].hybridization;
299            let hyb_v = mol.graph[v].hybridization;
300            if hyb_u == crate::graph::Hybridization::SP || hyb_v == crate::graph::Hybridization::SP
301            {
302                continue;
303            }
304
305            // Determine torsion params based on central bond hybridization
306            let (n_fold, gamma, weight) = torsion_params(hyb_u, hyb_v);
307
308            let neighbors_u: Vec<_> = mol.graph.neighbors(u).filter(|&x| x != v).collect();
309            let neighbors_v: Vec<_> = mol.graph.neighbors(v).filter(|&x| x != u).collect();
310
311            for &nu in &neighbors_u {
312                for &nv in &neighbors_v {
313                    let (p1, p2, p3, p4) = (
314                        Vector3::new(
315                            coords[(nu.index(), 0)],
316                            coords[(nu.index(), 1)],
317                            coords[(nu.index(), 2)],
318                        ),
319                        Vector3::new(
320                            coords[(u.index(), 0)],
321                            coords[(u.index(), 1)],
322                            coords[(u.index(), 2)],
323                        ),
324                        Vector3::new(
325                            coords[(v.index(), 0)],
326                            coords[(v.index(), 1)],
327                            coords[(v.index(), 2)],
328                        ),
329                        Vector3::new(
330                            coords[(nv.index(), 0)],
331                            coords[(nv.index(), 1)],
332                            coords[(nv.index(), 2)],
333                        ),
334                    );
335                    energy += torsional_energy(
336                        &p1,
337                        &p2,
338                        &p3,
339                        &p4,
340                        params.k_omega * weight,
341                        n_fold,
342                        gamma,
343                    );
344                }
345            }
346        }
347    }
348
349    // 6. ETKDG-lite M6 torsion preferences (applied to non-ring rotatable bonds)
350    if n >= 4 {
351        for edge in mol.graph.edge_references() {
352            let u = edge.source();
353            let v = edge.target();
354            // Skip ring bonds — ETKDG preferences are for rotatable bonds
355            if crate::graph::min_path_excluding2(mol, u, v, u, v, 7).is_some() {
356                continue;
357            }
358            let m6 =
359                crate::forcefield::etkdg_lite::infer_etkdg_parameters(mol, u.index(), v.index());
360            // Skip if all coefficients are zero
361            if m6.v.iter().all(|&x| x.abs() < 1e-6) {
362                continue;
363            }
364
365            let neighbors_u: Vec<_> = mol.graph.neighbors(u).filter(|&x| x != v).collect();
366            let neighbors_v: Vec<_> = mol.graph.neighbors(v).filter(|&x| x != u).collect();
367            if neighbors_u.is_empty() || neighbors_v.is_empty() {
368                continue;
369            }
370            // Use first neighbor pair only (matching RDKit's approach for ETKDG)
371            let nu = neighbors_u[0];
372            let nv = neighbors_v[0];
373            let (p1, p2, p3, p4) = (
374                Vector3::new(
375                    coords[(nu.index(), 0)],
376                    coords[(nu.index(), 1)],
377                    coords[(nu.index(), 2)],
378                ),
379                Vector3::new(
380                    coords[(u.index(), 0)],
381                    coords[(u.index(), 1)],
382                    coords[(u.index(), 2)],
383                ),
384                Vector3::new(
385                    coords[(v.index(), 0)],
386                    coords[(v.index(), 1)],
387                    coords[(v.index(), 2)],
388                ),
389                Vector3::new(
390                    coords[(nv.index(), 0)],
391                    coords[(nv.index(), 1)],
392                    coords[(nv.index(), 2)],
393                ),
394            );
395            energy +=
396                crate::forcefield::etkdg_lite::calc_torsion_energy_m6(&p1, &p2, &p3, &p4, &m6);
397        }
398    }
399
400    // 7. VDW non-bonded interactions (1-4+ pairs)
401    if params.k_vdw.abs() > 1e-8 {
402        // Build exclusion set: 1-2 and 1-3 pairs
403        let mut excluded = std::collections::HashSet::new();
404        for edge in mol.graph.edge_references() {
405            let a = edge.source().index();
406            let b = edge.target().index();
407            let (lo, hi) = if a < b { (a, b) } else { (b, a) };
408            excluded.insert((lo, hi));
409        }
410        // 1-3 pairs: i-center-j for each angle
411        for center in 0..n {
412            let nc = petgraph::graph::NodeIndex::new(center);
413            let nbs: Vec<_> = mol.graph.neighbors(nc).collect();
414            for j in 0..nbs.len() {
415                for k in (j + 1)..nbs.len() {
416                    let a = nbs[j].index();
417                    let b = nbs[k].index();
418                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
419                    excluded.insert((lo, hi));
420                }
421            }
422        }
423        // 1-4 pairs: need to identify for 0.5 scaling
424        let mut is_14 = std::collections::HashSet::new();
425        for edge in mol.graph.edge_references() {
426            let u = edge.source();
427            let v = edge.target();
428            let neighbors_u: Vec<_> = mol.graph.neighbors(u).filter(|&x| x != v).collect();
429            let neighbors_v: Vec<_> = mol.graph.neighbors(v).filter(|&x| x != u).collect();
430            for &nu in &neighbors_u {
431                for &nv in &neighbors_v {
432                    let a = nu.index();
433                    let b = nv.index();
434                    if a == b {
435                        continue;
436                    }
437                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
438                    if !excluded.contains(&(lo, hi)) {
439                        is_14.insert((lo, hi));
440                    }
441                }
442            }
443        }
444
445        for i in 0..n {
446            let ei = mol.graph[petgraph::graph::NodeIndex::new(i)].element;
447            let (xi, di) = uff_vdw_params(ei);
448            let pi = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
449            for j in (i + 1)..n {
450                if excluded.contains(&(i, j)) {
451                    continue;
452                }
453                let ej = mol.graph[petgraph::graph::NodeIndex::new(j)].element;
454                let (xj, dj) = uff_vdw_params(ej);
455                let r_star = (xi + xj) * 0.5;
456                let eps_full = (di * dj).sqrt();
457                let scale = if is_14.contains(&(i, j)) { 0.5 } else { 1.0 };
458                let pj = Vector3::new(coords[(j, 0)], coords[(j, 1)], coords[(j, 2)]);
459                energy += params.k_vdw * scale * vdw_energy(&pi, &pj, r_star, eps_full);
460            }
461        }
462    }
463
464    energy
465}
466
467/// Determine torsion periodicity, phase, and relative weight based on UFF rules
468pub fn torsion_params(
469    hyb_u: crate::graph::Hybridization,
470    hyb_v: crate::graph::Hybridization,
471) -> (f32, f32, f32) {
472    use crate::graph::Hybridization::*;
473    let pi = std::f32::consts::PI;
474    match (hyb_u, hyb_v) {
475        (SP3, SP3) => (3.0, 0.0, 1.0),             // staggered, normal weight
476        (SP2, SP2) => (2.0, pi, 5.0),              // planar, strong weight
477        (SP2, SP3) | (SP3, SP2) => (6.0, pi, 0.5), // 6-fold weak barrier
478        _ => (3.0, 0.0, 1.0),
479    }
480}