1use crate::forcefield::bounds_ff::ChiralSet;
5use crate::graph::Molecule;
6use nalgebra::{DMatrix, Vector3};
7
/// Smallest acceptable absolute triple product of three unit (neighbour -> center) bond vectors
/// in `volume_test`; multiplied by 0.25 when the relaxed check is requested.
const MIN_TETRAHEDRAL_CHIRAL_VOL: f64 = 0.5;
/// Tolerance that `check_tetrahedral_centers` hands to `center_in_volume`: a signed offset from a
/// face plane (measured against the unnormalised face normal) smaller than this counts as
/// "on the plane" and fails the check.
const TETRAHEDRAL_CENTERINVOLUME_TOL: f64 = 0.3;
/// NOTE(review): not referenced in this section; presumably the per-atom energy ceiling for
/// accepting a minimized conformer — confirm at the call sites.
pub const MAX_MINIMIZED_E_PER_ATOM: f32 = 0.05;
11
/// An atom whose embedded geometry must look properly tetrahedral.
///
/// As produced by `identify_tetrahedral_centers`, this is a carbon or nitrogen with exactly four
/// neighbours that sits in at least two perceived rings and in no three-membered ring.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TetrahedralCenter {
    /// Index of the central atom.
    pub center: usize,
    /// Indices of the four neighbours, in graph adjacency order.
    pub neighbors: [usize; 4],
    /// `true` when the center belongs to more than one ring of fewer than five atoms; the
    /// chiral-volume threshold is relaxed for such strained centers.
    pub in_small_ring: bool,
}
18
19pub fn identify_tetrahedral_centers(mol: &Molecule) -> Vec<TetrahedralCenter> {
25 let n = mol.graph.node_count();
26 let rings = find_sssr(mol);
28 let mut ring_count = vec![0usize; n];
29 let mut in_3_ring = vec![false; n];
30 let mut small_ring_count = vec![0usize; n]; for ring in &rings {
32 for &atom_idx in ring {
33 ring_count[atom_idx] += 1;
34 if ring.len() == 3 {
35 in_3_ring[atom_idx] = true;
36 }
37 if ring.len() < 5 {
38 small_ring_count[atom_idx] += 1;
39 }
40 }
41 }
42
43 let mut centers = Vec::new();
44 for i in 0..n {
45 let ni = petgraph::graph::NodeIndex::new(i);
46 let atom = &mol.graph[ni];
47 let elem = atom.element;
49 if elem != 6 && elem != 7 {
50 continue;
51 }
52 let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
53 if nbs.len() != 4 {
54 continue;
55 }
56
57 if ring_count[i] < 2 || in_3_ring[i] {
59 continue;
60 }
61
62 centers.push(TetrahedralCenter {
63 center: i,
64 neighbors: [
65 nbs[0].index(),
66 nbs[1].index(),
67 nbs[2].index(),
68 nbs[3].index(),
69 ],
70 in_small_ring: small_ring_count[i] > 1,
71 });
72 }
73 centers
74}
75
76pub fn find_sssr_pub(mol: &Molecule) -> Vec<Vec<usize>> {
79 find_sssr(mol)
80}
81fn find_sssr(mol: &Molecule) -> Vec<Vec<usize>> {
82 use std::collections::VecDeque;
83 let n = mol.graph.node_count();
84 if n == 0 {
85 return vec![];
86 }
87
88 let num_edges = mol.graph.edge_count();
90 let mut visited = vec![false; n];
92 let mut num_components = 0;
93 for start in 0..n {
94 if visited[start] {
95 continue;
96 }
97 num_components += 1;
98 let mut queue = VecDeque::new();
99 queue.push_back(start);
100 visited[start] = true;
101 while let Some(curr) = queue.pop_front() {
102 for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(curr)) {
103 if !visited[nb.index()] {
104 visited[nb.index()] = true;
105 queue.push_back(nb.index());
106 }
107 }
108 }
109 }
110 let cycle_rank = (num_edges + num_components).saturating_sub(n);
111 if cycle_rank == 0 {
112 return vec![];
113 }
114
115 let mut candidates: Vec<Vec<usize>> = Vec::new();
118
119 for root in 0..n {
120 let mut dist = vec![usize::MAX; n];
121 let mut parent = vec![usize::MAX; n];
122 dist[root] = 0;
123 let mut queue = VecDeque::new();
124 queue.push_back(root);
125
126 while let Some(curr) = queue.pop_front() {
127 for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(curr)) {
128 let nb_idx = nb.index();
129 if dist[nb_idx] == usize::MAX {
130 dist[nb_idx] = dist[curr] + 1;
131 parent[nb_idx] = curr;
132 queue.push_back(nb_idx);
133 }
134 }
135 }
136
137 for u in 0..n {
141 for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(u)) {
142 let v = nb.index();
143 if u >= v {
144 continue;
145 } let ring_len = dist[u] + dist[v] + 1;
147 if ring_len > 8 {
148 continue;
149 } if dist[u] == usize::MAX || dist[v] == usize::MAX {
151 continue;
152 }
153
154 let path_u = trace_path(&parent, root, u);
156 let path_v = trace_path(&parent, root, v);
157
158 let mut ring = path_u.clone();
160 let mut path_v_rev: Vec<usize> = path_v.into_iter().rev().collect();
162 if !path_v_rev.is_empty() && !ring.is_empty() && path_v_rev.last() == ring.first() {
163 path_v_rev.pop(); }
165 ring.extend(path_v_rev);
166
167 let mut seen = std::collections::HashSet::new();
169 let is_simple = ring.iter().all(|&x| seen.insert(x));
170 if is_simple && ring.len() >= 3 {
171 let normalized = normalize_ring(&ring);
173 candidates.push(normalized);
174 }
175 }
176 }
177 }
178
179 candidates.sort();
181 candidates.dedup();
182
183 candidates.sort_by_key(|r| r.len());
187
188 let edge_sets: Vec<std::collections::HashSet<(usize, usize)>> =
191 candidates.iter().map(|r| ring_edges(r).collect()).collect();
192
193 let mut relevant = Vec::new();
194 for (i, ring) in candidates.iter().enumerate() {
195 let mut is_xor_of_smaller = false;
196 for j in 0..i {
198 if candidates[j].len() >= ring.len() {
199 continue;
200 }
201 for k in (j + 1)..i {
202 if candidates[k].len() >= ring.len() {
203 continue;
204 }
205 let sym_diff: std::collections::HashSet<(usize, usize)> = edge_sets[j]
207 .symmetric_difference(&edge_sets[k])
208 .copied()
209 .collect();
210 if sym_diff == edge_sets[i] {
211 is_xor_of_smaller = true;
212 break;
213 }
214 }
215 if is_xor_of_smaller {
216 break;
217 }
218 }
219 if !is_xor_of_smaller {
220 relevant.push(ring.clone());
221 }
222 }
223
224 relevant
225}
226
/// Walks BFS parent pointers from `target` back to `root` and returns the path `root ..= target`.
///
/// `usize::MAX` in `parent` marks "no parent". If that sentinel is reached before `root`, the
/// walk stops and the partial chain, without `root`, is returned in root-to-target order. For
/// example, an unreachable `target` yields `[target]`.
fn trace_path(parent: &[usize], root: usize, target: usize) -> Vec<usize> {
    let mut chain = Vec::new();
    let mut node = target;
    loop {
        if node == root {
            chain.push(root);
            break;
        }
        if node == usize::MAX {
            break;
        }
        chain.push(node);
        node = parent[node];
    }
    // Collected target-first; callers want root-first.
    chain.reverse();
    chain
}
240
/// Canonicalises a cyclic atom sequence so that equal rings compare equal.
///
/// The ring is rotated to start at (the first occurrence of) its smallest index. Of the two
/// traversal directions from there, the lexicographically smaller one is returned. An empty
/// input yields an empty vector.
fn normalize_ring(ring: &[usize]) -> Vec<usize> {
    let start = match ring.iter().min() {
        Some(&lowest) => ring.iter().position(|&atom| atom == lowest).unwrap_or(0),
        None => return Vec::new(),
    };

    // forward = ring[start], ring[start + 1], ...
    let mut forward = ring.to_vec();
    forward.rotate_left(start);

    // backward = ring[start], ring[start - 1], ... (reverse `forward`, then bring its head back).
    let mut backward = forward.clone();
    backward.reverse();
    backward.rotate_right(1);

    std::cmp::min(forward, backward)
}
253
/// Yields the bonds of a cyclic atom sequence as `(smaller, larger)` index pairs.
///
/// Consecutive atoms are paired, and the last atom is paired with the first. A one-atom ring
/// yields a single `(a, a)` pair; an empty ring yields nothing.
fn ring_edges(ring: &[usize]) -> impl Iterator<Item = (usize, usize)> + '_ {
    let successors = ring.iter().cycle().skip(1);
    ring.iter()
        .zip(successors)
        .map(|(&a, &b)| if a <= b { (a, b) } else { (b, a) })
}
262
263fn volume_test(
268 center: usize,
269 neighbors: &[usize; 4],
270 coords: &DMatrix<f64>,
271 relaxed: bool,
272) -> bool {
273 let dim = coords.ncols().min(3);
274 let p0 = Vector3::new(
275 coords[(center, 0)],
276 coords[(center, 1)],
277 if dim >= 3 { coords[(center, 2)] } else { 0.0 },
278 );
279 let mut vecs = [Vector3::<f64>::zeros(); 4];
280 for (k, &nb) in neighbors.iter().enumerate() {
281 let pk = Vector3::new(
282 coords[(nb, 0)],
283 coords[(nb, 1)],
284 if dim >= 3 { coords[(nb, 2)] } else { 0.0 },
285 );
286 let v = p0 - pk; let norm = v.norm();
288 vecs[k] = if norm > 1e-8 { v / norm } else { v };
289 }
290
291 let vol_scale: f64 = if relaxed { 0.25 } else { 1.0 };
292 let threshold = vol_scale * MIN_TETRAHEDRAL_CHIRAL_VOL;
293
294 let combos: [(usize, usize, usize); 4] = [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)];
296 for (a, b, c) in combos {
297 let cross = vecs[a].cross(&vecs[b]);
298 let vol = cross.dot(&vecs[c]).abs();
299 if vol < threshold {
300 return false;
301 }
302 }
303 true
304}
305
306fn same_side(
308 v1: &Vector3<f64>,
309 v2: &Vector3<f64>,
310 v3: &Vector3<f64>,
311 v4: &Vector3<f64>,
312 p0: &Vector3<f64>,
313 tol: f64,
314) -> bool {
315 let normal = (v2 - v1).cross(&(v3 - v1));
316 let d1 = normal.dot(&(v4 - v1));
317 let d2 = normal.dot(&(p0 - v1));
318 if d1.abs() < tol || d2.abs() < tol {
319 return false;
320 }
321 (d1 < 0.0) == (d2 < 0.0)
322}
323
324fn center_in_volume(
325 center: usize,
326 neighbors: &[usize; 4],
327 coords: &DMatrix<f64>,
328 tol: f64,
329) -> bool {
330 let dim = coords.ncols().min(3);
331 let get_p3d = |idx: usize| -> Vector3<f64> {
332 Vector3::new(
333 coords[(idx, 0)],
334 coords[(idx, 1)],
335 if dim >= 3 { coords[(idx, 2)] } else { 0.0 },
336 )
337 };
338 let p0 = get_p3d(center);
339 let p = [
340 get_p3d(neighbors[0]),
341 get_p3d(neighbors[1]),
342 get_p3d(neighbors[2]),
343 get_p3d(neighbors[3]),
344 ];
345
346 same_side(&p[0], &p[1], &p[2], &p[3], &p0, tol)
347 && same_side(&p[1], &p[2], &p[3], &p[0], &p0, tol)
348 && same_side(&p[2], &p[3], &p[0], &p[1], &p0, tol)
349 && same_side(&p[3], &p[0], &p[1], &p[2], &p0, tol)
350}
351
352pub fn check_tetrahedral_centers(coords: &DMatrix<f64>, centers: &[TetrahedralCenter]) -> bool {
355 for tc in centers {
356 if !volume_test(tc.center, &tc.neighbors, coords, tc.in_small_ring) {
357 return false;
358 }
359 if !center_in_volume(
360 tc.center,
361 &tc.neighbors,
362 coords,
363 TETRAHEDRAL_CENTERINVOLUME_TOL,
364 ) {
365 return false;
366 }
367 }
368 true
369}
370
371pub fn check_chiral_centers(coords: &DMatrix<f64>, chiral_sets: &[ChiralSet]) -> bool {
375 for cs in chiral_sets {
376 let vol = crate::distgeom::calc_chiral_volume_f64(
377 cs.neighbors[0],
378 cs.neighbors[1],
379 cs.neighbors[2],
380 cs.neighbors[3],
381 coords,
382 );
383 let lb = cs.lower_vol as f64;
384 let ub = cs.upper_vol as f64;
385 if lb > 0.0 && vol < lb && (vol / lb < 0.8 || have_opposite_sign(vol, lb)) {
386 return false;
387 }
388 if ub < 0.0 && vol > ub && (vol / ub < 0.8 || have_opposite_sign(vol, ub)) {
389 return false;
390 }
391 }
392 true
393}
394
/// Returns `true` when exactly one of `a` and `b` is strictly negative.
///
/// Only `x < 0.0` is tested, so zero (including `-0.0`) and NaN count as non-negative.
fn have_opposite_sign(a: f64, b: f64) -> bool {
    let a_negative = a < 0.0;
    let b_negative = b < 0.0;
    a_negative ^ b_negative
}
398
399pub fn check_planarity(mol: &Molecule, coords: &DMatrix<f32>, oop_k: f32, tolerance: f32) -> bool {
403 let n = mol.graph.node_count();
404 let mut n_impropers = 0usize;
405 let mut improper_energy = 0.0f32;
406
407 for i in 0..n {
409 let ni = petgraph::graph::NodeIndex::new(i);
410 if mol.graph[ni].hybridization != crate::graph::Hybridization::SP2 {
411 continue;
412 }
413 let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
414 if nbs.len() != 3 {
415 continue;
416 }
417 n_impropers += 1;
418
419 let pc = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
420 let p1 = Vector3::new(
421 coords[(nbs[0].index(), 0)],
422 coords[(nbs[0].index(), 1)],
423 coords[(nbs[0].index(), 2)],
424 );
425 let p2 = Vector3::new(
426 coords[(nbs[1].index(), 0)],
427 coords[(nbs[1].index(), 1)],
428 coords[(nbs[1].index(), 2)],
429 );
430 let p3 = Vector3::new(
431 coords[(nbs[2].index(), 0)],
432 coords[(nbs[2].index(), 1)],
433 coords[(nbs[2].index(), 2)],
434 );
435 let v1 = p1 - pc;
436 let v2 = p2 - pc;
437 let v3 = p3 - pc;
438 let vol = v1.dot(&v2.cross(&v3));
439 improper_energy += oop_k * vol * vol;
440 }
441
442 if n_impropers == 0 {
447 return true;
448 }
449 improper_energy <= n_impropers as f32 * tolerance
450}
451
452pub fn check_double_bond_geometry(mol: &Molecule, coords: &DMatrix<f64>) -> bool {
456 use petgraph::visit::EdgeRef;
457 for edge in mol.graph.edge_references() {
458 if mol.graph[edge.id()].order != crate::graph::BondOrder::Double {
459 continue;
460 }
461 let u = edge.source();
462 let v = edge.target();
463
464 let u_deg = mol.graph.neighbors(u).count();
466 if u_deg >= 2 {
467 for nb in mol.graph.neighbors(u) {
468 if nb == v {
469 continue;
470 }
471 if u_deg == 2 {
473 if let Some(eid) = mol.graph.find_edge(u, nb) {
474 if mol.graph[eid].order != crate::graph::BondOrder::Single {
475 continue;
476 }
477 }
478 }
479 if !check_linearity(nb.index(), u.index(), v.index(), coords) {
480 return false;
481 }
482 }
483 }
484 let v_deg = mol.graph.neighbors(v).count();
486 if v_deg >= 2 {
487 for nb in mol.graph.neighbors(v) {
488 if nb == u {
489 continue;
490 }
491 if v_deg == 2 {
493 if let Some(eid) = mol.graph.find_edge(v, nb) {
494 if mol.graph[eid].order != crate::graph::BondOrder::Single {
495 continue;
496 }
497 }
498 }
499 if !check_linearity(nb.index(), v.index(), u.index(), coords) {
500 return false;
501 }
502 }
503 }
504 }
505 true
506}
507
508fn check_linearity(a0: usize, a1: usize, a2: usize, coords: &DMatrix<f64>) -> bool {
510 let p0 = Vector3::new(coords[(a0, 0)], coords[(a0, 1)], coords[(a0, 2)]);
511 let p1 = Vector3::new(coords[(a1, 0)], coords[(a1, 1)], coords[(a1, 2)]);
512 let p2 = Vector3::new(coords[(a2, 0)], coords[(a2, 1)], coords[(a2, 2)]);
513 let mut v1 = p1 - p0;
514 let n1 = v1.norm();
515 if n1 < 1e-8 {
516 return true;
517 }
518 v1 /= n1;
519 let mut v2 = p1 - p2;
520 let n2 = v2.norm();
521 if n2 < 1e-8 {
522 return true;
523 }
524 v2 /= n2;
525 v1.dot(&v2) + 1.0 >= 1e-3
527}
528
529pub fn perturb_if_planar(coords: &mut DMatrix<f64>, rng: &mut crate::distgeom::MinstdRand) -> bool {
532 let n = coords.nrows();
533 if n < 4 || coords.ncols() < 3 {
534 return false;
535 }
536 let mut z_min = f64::INFINITY;
538 let mut z_max = f64::NEG_INFINITY;
539 for i in 0..n {
540 let z = coords[(i, 2)];
541 if z < z_min {
542 z_min = z;
543 }
544 if z > z_max {
545 z_max = z;
546 }
547 }
548 let z_spread = z_max - z_min;
549 let mut xy_max_spread = 0.0f64;
551 for d in 0..2 {
552 let mut lo = f64::INFINITY;
553 let mut hi = f64::NEG_INFINITY;
554 for i in 0..n {
555 let v = coords[(i, d)];
556 if v < lo {
557 lo = v;
558 }
559 if v > hi {
560 hi = v;
561 }
562 }
563 xy_max_spread = xy_max_spread.max(hi - lo);
564 }
565 if xy_max_spread < 1e-8 {
566 return false;
567 }
568 if z_spread < 0.01 * xy_max_spread {
570 for i in 0..n {
571 coords[(i, 2)] += 0.3 * (rng.next_double() - 0.5);
572 }
573 return true;
574 }
575 false
576}