1use nalgebra::Vector3;
2use petgraph::visit::EdgeRef;
3
/// UFF van der Waals parameters for one element.
///
/// `element` is treated as an atomic number (keys 1, 6, 7, 8, … correspond to
/// H, C, N, O, …). The returned pair is `(x_i, D_i)`: the UFF vdW distance and
/// well depth. These are presumably in Å and kcal/mol as in the UFF tables — confirm.
/// Elements missing from the table fall back to the carbon values.
pub fn uff_vdw_params(element: u8) -> (f32, f32) {
    // (atomic number, (x_i, D_i))
    const TABLE: [(u8, (f32, f32)); 12] = [
        (1, (2.886, 0.044)),  // H
        (5, (3.637, 0.180)),  // B
        (6, (3.851, 0.105)),  // C
        (7, (3.660, 0.069)),  // N
        (8, (3.500, 0.060)),  // O
        (9, (3.364, 0.050)),  // F
        (14, (4.295, 0.402)), // Si
        (15, (4.147, 0.305)), // P
        (16, (4.035, 0.274)), // S
        (17, (3.947, 0.227)), // Cl
        (35, (4.189, 0.251)), // Br
        (53, (4.500, 0.339)), // I
    ];
    // Unknown elements are treated like carbon.
    const FALLBACK: (f32, f32) = (3.851, 0.105);

    TABLE
        .iter()
        .find(|&&(z, _)| z == element)
        .map_or(FALLBACK, |&(_, vdw)| vdw)
}
22
23pub fn vdw_energy(p1: &Vector3<f32>, p2: &Vector3<f32>, r_star: f32, epsilon: f32) -> f32 {
25 let r = (p1 - p2).norm();
26 if !(0.5..=8.0).contains(&r) {
27 return 0.0;
28 }
29 let u = r_star / r;
30 let u6 = u * u * u * u * u * u;
31 let u12 = u6 * u6;
32 epsilon * (u12 - 2.0 * u6)
33}
34
35pub fn distance_constraint_energy(
38 p1: &Vector3<f32>,
39 p2: &Vector3<f32>,
40 min_len: f32,
41 max_len: f32,
42 k: f32,
43) -> f32 {
44 let d2 = (p1 - p2).norm_squared();
45 if d2 < min_len * min_len {
46 let d = d2.sqrt();
47 let diff = min_len - d;
48 0.5 * k * diff * diff
49 } else if d2 > max_len * max_len {
50 let d = d2.sqrt();
51 let diff = d - max_len;
52 0.5 * k * diff * diff
53 } else {
54 0.0
55 }
56}
57
58pub fn bond_stretch_energy(p1: &Vector3<f32>, p2: &Vector3<f32>, k_b: f32, r_eq: f32) -> f32 {
60 let r = (p1 - p2).norm();
61 0.5 * k_b * (r - r_eq).powi(2)
62}
63
64pub fn angle_bend_energy(
66 p1: &Vector3<f32>,
67 p2: &Vector3<f32>, p3: &Vector3<f32>,
69 k_theta: f32,
70 theta_eq: f32,
71) -> f32 {
72 let v1 = p1 - p2;
73 let v2 = p3 - p2;
74 let r1 = v1.norm();
75 let r2 = v2.norm();
76 if r1 < 1e-4 || r2 < 1e-4 {
77 return 0.0;
78 }
79 let cos_th = (v1.dot(&v2) / (r1 * r2)).clamp(-0.999999, 0.999999);
80
81 if (theta_eq - std::f32::consts::PI).abs() < 1e-4 {
83 return k_theta * (1.0 + cos_th);
85 }
86
87 let theta = cos_th.acos();
88 0.5 * k_theta * (theta - theta_eq).powi(2)
89}
90
91pub fn torsional_energy(
93 p1: &Vector3<f32>,
94 p2: &Vector3<f32>,
95 p3: &Vector3<f32>,
96 p4: &Vector3<f32>,
97 v: f32,
98 n: f32,
99 gamma: f32,
100) -> f32 {
101 let b1 = p2 - p1;
102 let b2 = p3 - p2;
103 let b3 = p4 - p3;
104
105 let n1 = b1.cross(&b2).normalize();
106 let n2 = b2.cross(&b3).normalize();
107 let m1 = n1.cross(&b2.normalize());
108
109 let x = n1.dot(&n2);
110 let y = m1.dot(&n2);
111 let phi = y.atan2(x);
112
113 v * (1.0 + (n * phi - gamma).cos())
114}
115
116pub fn bounds_energy(
119 p1: &Vector3<f32>,
120 p2: &Vector3<f32>,
121 lower: f32,
122 upper: f32,
123 k_bounds: f32,
124) -> f32 {
125 let r2 = (p1 - p2).norm_squared();
126 let u2 = upper * upper;
127 let l2 = lower * lower;
128 if r2 > u2 && u2 > 1e-6 {
129 let val = r2 / u2 - 1.0;
130 k_bounds * val * val
131 } else if r2 < l2 && l2 > 1e-6 {
132 let val = 2.0 * l2 / (l2 + r2.max(1e-6)) - 1.0;
134 k_bounds * val * val
135 } else {
136 0.0
137 }
138}
139
140pub fn oop_energy(
142 p_center: &Vector3<f32>,
143 p1: &Vector3<f32>,
144 p2: &Vector3<f32>,
145 p3: &Vector3<f32>,
146 k_oop: f32,
147 phi_eq: f32,
148) -> f32 {
149 let v1 = p1 - p_center;
150 let v2 = p2 - p_center;
151 let v3 = p3 - p_center;
152
153 let normal = v2.cross(&v3).normalize();
154 let dist = v1.dot(&normal);
155 let sin_phi = dist / v1.norm().max(1e-4);
156 let phi = sin_phi.asin();
157
158 0.5 * k_oop * (phi - phi_eq).powi(2)
159}
160
161pub fn chirality_energy(
163 p_center: &Vector3<f32>,
164 p1: &Vector3<f32>,
165 p2: &Vector3<f32>,
166 p3: &Vector3<f32>,
167 target_vol: f32,
168 k_chiral: f32,
169) -> f32 {
170 let v1 = p1 - p_center;
171 let v2 = p2 - p_center;
172 let v3 = p3 - p_center;
173 let vol = v1.dot(&v2.cross(&v3));
174 0.5 * k_chiral * (vol - target_vol).powi(2)
175}
176
/// Weights of the force-field terms summed by `calculate_total_energy`.
///
/// `PartialEq` is derived so parameter sets can be compared, e.g. against
/// `FFParams::default()`.
#[derive(Clone, Debug, PartialEq)]
pub struct FFParams {
    /// Bond-stretch force constant, passed as `k_b` to `bond_stretch_energy`.
    pub kb: f32,
    /// Angle-bend force constant, passed to `angle_bend_energy`.
    pub k_theta: f32,
    /// Base torsion barrier, multiplied by the per-bond weight from
    /// `torsion_params`. Generic torsions are skipped when `|k_omega| <= 1e-8`.
    pub k_omega: f32,
    /// Weight of the SP2 planarity penalty `k_oop * volume²`. The penalty is
    /// skipped when `|k_oop| <= 1e-8`.
    pub k_oop: f32,
    /// Weight of the pairwise distance-bounds penalty (`bounds_energy`).
    pub k_bounds: f32,
    /// Chirality restraint weight.
    /// NOTE(review): not read by `calculate_total_energy` in this file;
    /// presumably intended for `chirality_energy` callers — confirm.
    pub k_chiral: f32,
    /// Global van der Waals scale. `0.0` (the default) disables the vdW term.
    pub k_vdw: f32,
}

impl Default for FFParams {
    fn default() -> Self {
        Self {
            kb: 500.0,
            k_theta: 300.0,
            k_omega: 20.0,
            k_oop: 40.0,
            k_bounds: 200.0,
            k_chiral: 100.0,
            k_vdw: 0.0,
        }
    }
}
201
202pub fn calculate_total_energy(
203 coords: &nalgebra::DMatrix<f32>,
204 mol: &crate::graph::Molecule,
205 params: &FFParams,
206 bounds_matrix: &nalgebra::DMatrix<f64>,
207) -> f32 {
208 let n = mol.graph.node_count();
209 let mut energy = 0.0;
210
211 for edge in mol.graph.edge_references() {
213 let idx1 = edge.source().index();
214 let idx2 = edge.target().index();
215 let p1 = Vector3::new(coords[(idx1, 0)], coords[(idx1, 1)], coords[(idx1, 2)]);
216 let p2 = Vector3::new(coords[(idx2, 0)], coords[(idx2, 1)], coords[(idx2, 2)]);
217 let r_eq = crate::distgeom::get_bond_length(mol, edge.source(), edge.target()) as f32;
218 energy += bond_stretch_energy(&p1, &p2, params.kb, r_eq);
219 }
220
221 for i in 0..n {
223 let ni = petgraph::graph::NodeIndex::new(i);
224 let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
225 for j in 0..nbs.len() {
226 for k in (j + 1)..nbs.len() {
227 let n1 = nbs[j];
228 let n2 = nbs[k];
229 let p1 = Vector3::new(
230 coords[(n1.index(), 0)],
231 coords[(n1.index(), 1)],
232 coords[(n1.index(), 2)],
233 );
234 let pc = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
235 let p2 = Vector3::new(
236 coords[(n2.index(), 0)],
237 coords[(n2.index(), 1)],
238 coords[(n2.index(), 2)],
239 );
240 let ideal = crate::graph::get_corrected_ideal_angle(mol, ni, n1, n2) as f32;
241 energy += angle_bend_energy(&p1, &pc, &p2, params.k_theta, ideal);
242 }
243 }
244 }
245
246 for i in 0..n {
248 for j in (i + 1)..n {
249 let upper = bounds_matrix[(i, j)] as f32;
250 let lower = bounds_matrix[(j, i)] as f32;
251 let p1 = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
252 let p2 = Vector3::new(coords[(j, 0)], coords[(j, 1)], coords[(j, 2)]);
253 energy += bounds_energy(&p1, &p2, lower, upper, params.k_bounds);
254 }
255 }
256
257 if params.k_oop.abs() > 1e-8 {
259 for i in 0..n {
260 let ni = petgraph::graph::NodeIndex::new(i);
261 if mol.graph[ni].hybridization != crate::graph::Hybridization::SP2 {
262 continue;
263 }
264 let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
265 if nbs.len() != 3 {
266 continue;
267 }
268 let pc = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
269 let p1 = Vector3::new(
270 coords[(nbs[0].index(), 0)],
271 coords[(nbs[0].index(), 1)],
272 coords[(nbs[0].index(), 2)],
273 );
274 let p2 = Vector3::new(
275 coords[(nbs[1].index(), 0)],
276 coords[(nbs[1].index(), 1)],
277 coords[(nbs[1].index(), 2)],
278 );
279 let p3 = Vector3::new(
280 coords[(nbs[2].index(), 0)],
281 coords[(nbs[2].index(), 1)],
282 coords[(nbs[2].index(), 2)],
283 );
284 let v1 = p1 - pc;
286 let v2 = p2 - pc;
287 let v3 = p3 - pc;
288 let vol = v1.dot(&v2.cross(&v3));
289 energy += params.k_oop * vol * vol;
290 }
291 }
292
293 if n >= 4 && params.k_omega.abs() > 1e-8 {
295 for edge in mol.graph.edge_references() {
296 let u = edge.source();
297 let v = edge.target();
298 let hyb_u = mol.graph[u].hybridization;
299 let hyb_v = mol.graph[v].hybridization;
300 if hyb_u == crate::graph::Hybridization::SP || hyb_v == crate::graph::Hybridization::SP
301 {
302 continue;
303 }
304
305 let (n_fold, gamma, weight) = torsion_params(hyb_u, hyb_v);
307
308 let neighbors_u: Vec<_> = mol.graph.neighbors(u).filter(|&x| x != v).collect();
309 let neighbors_v: Vec<_> = mol.graph.neighbors(v).filter(|&x| x != u).collect();
310
311 for &nu in &neighbors_u {
312 for &nv in &neighbors_v {
313 let (p1, p2, p3, p4) = (
314 Vector3::new(
315 coords[(nu.index(), 0)],
316 coords[(nu.index(), 1)],
317 coords[(nu.index(), 2)],
318 ),
319 Vector3::new(
320 coords[(u.index(), 0)],
321 coords[(u.index(), 1)],
322 coords[(u.index(), 2)],
323 ),
324 Vector3::new(
325 coords[(v.index(), 0)],
326 coords[(v.index(), 1)],
327 coords[(v.index(), 2)],
328 ),
329 Vector3::new(
330 coords[(nv.index(), 0)],
331 coords[(nv.index(), 1)],
332 coords[(nv.index(), 2)],
333 ),
334 );
335 energy += torsional_energy(
336 &p1,
337 &p2,
338 &p3,
339 &p4,
340 params.k_omega * weight,
341 n_fold,
342 gamma,
343 );
344 }
345 }
346 }
347 }
348
349 if n >= 4 {
351 for edge in mol.graph.edge_references() {
352 let u = edge.source();
353 let v = edge.target();
354 if crate::graph::min_path_excluding2(mol, u, v, u, v, 7).is_some() {
356 continue;
357 }
358 let m6 =
359 crate::forcefield::etkdg_lite::infer_etkdg_parameters(mol, u.index(), v.index());
360 if m6.v.iter().all(|&x| x.abs() < 1e-6) {
362 continue;
363 }
364
365 let neighbors_u: Vec<_> = mol.graph.neighbors(u).filter(|&x| x != v).collect();
366 let neighbors_v: Vec<_> = mol.graph.neighbors(v).filter(|&x| x != u).collect();
367 if neighbors_u.is_empty() || neighbors_v.is_empty() {
368 continue;
369 }
370 let nu = neighbors_u[0];
372 let nv = neighbors_v[0];
373 let (p1, p2, p3, p4) = (
374 Vector3::new(
375 coords[(nu.index(), 0)],
376 coords[(nu.index(), 1)],
377 coords[(nu.index(), 2)],
378 ),
379 Vector3::new(
380 coords[(u.index(), 0)],
381 coords[(u.index(), 1)],
382 coords[(u.index(), 2)],
383 ),
384 Vector3::new(
385 coords[(v.index(), 0)],
386 coords[(v.index(), 1)],
387 coords[(v.index(), 2)],
388 ),
389 Vector3::new(
390 coords[(nv.index(), 0)],
391 coords[(nv.index(), 1)],
392 coords[(nv.index(), 2)],
393 ),
394 );
395 energy +=
396 crate::forcefield::etkdg_lite::calc_torsion_energy_m6(&p1, &p2, &p3, &p4, &m6);
397 }
398 }
399
400 if params.k_vdw.abs() > 1e-8 {
402 let mut excluded = std::collections::HashSet::new();
404 for edge in mol.graph.edge_references() {
405 let a = edge.source().index();
406 let b = edge.target().index();
407 let (lo, hi) = if a < b { (a, b) } else { (b, a) };
408 excluded.insert((lo, hi));
409 }
410 for center in 0..n {
412 let nc = petgraph::graph::NodeIndex::new(center);
413 let nbs: Vec<_> = mol.graph.neighbors(nc).collect();
414 for j in 0..nbs.len() {
415 for k in (j + 1)..nbs.len() {
416 let a = nbs[j].index();
417 let b = nbs[k].index();
418 let (lo, hi) = if a < b { (a, b) } else { (b, a) };
419 excluded.insert((lo, hi));
420 }
421 }
422 }
423 let mut is_14 = std::collections::HashSet::new();
425 for edge in mol.graph.edge_references() {
426 let u = edge.source();
427 let v = edge.target();
428 let neighbors_u: Vec<_> = mol.graph.neighbors(u).filter(|&x| x != v).collect();
429 let neighbors_v: Vec<_> = mol.graph.neighbors(v).filter(|&x| x != u).collect();
430 for &nu in &neighbors_u {
431 for &nv in &neighbors_v {
432 let a = nu.index();
433 let b = nv.index();
434 if a == b {
435 continue;
436 }
437 let (lo, hi) = if a < b { (a, b) } else { (b, a) };
438 if !excluded.contains(&(lo, hi)) {
439 is_14.insert((lo, hi));
440 }
441 }
442 }
443 }
444
445 for i in 0..n {
446 let ei = mol.graph[petgraph::graph::NodeIndex::new(i)].element;
447 let (xi, di) = uff_vdw_params(ei);
448 let pi = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
449 for j in (i + 1)..n {
450 if excluded.contains(&(i, j)) {
451 continue;
452 }
453 let ej = mol.graph[petgraph::graph::NodeIndex::new(j)].element;
454 let (xj, dj) = uff_vdw_params(ej);
455 let r_star = (xi + xj) * 0.5;
456 let eps_full = (di * dj).sqrt();
457 let scale = if is_14.contains(&(i, j)) { 0.5 } else { 1.0 };
458 let pj = Vector3::new(coords[(j, 0)], coords[(j, 1)], coords[(j, 2)]);
459 energy += params.k_vdw * scale * vdw_energy(&pi, &pj, r_star, eps_full);
460 }
461 }
462 }
463
464 energy
465}
466
467pub fn torsion_params(
469 hyb_u: crate::graph::Hybridization,
470 hyb_v: crate::graph::Hybridization,
471) -> (f32, f32, f32) {
472 use crate::graph::Hybridization::*;
473 let pi = std::f32::consts::PI;
474 match (hyb_u, hyb_v) {
475 (SP3, SP3) => (3.0, 0.0, 1.0), (SP2, SP2) => (2.0, pi, 5.0), (SP2, SP3) | (SP3, SP2) => (6.0, pi, 0.5), _ => (3.0, 0.0, 1.0),
479 }
480}