Skip to main content

sci_form/distgeom/
validation.rs

1//! Validation checks matching RDKit's embedding validation pipeline.
2//! Part of the retry-on-failure loop in embedPoints().
3
4use crate::forcefield::bounds_ff::ChiralSet;
5use crate::graph::Molecule;
6use nalgebra::{DMatrix, Vector3};
7
/// Lower bound on |triple product| of the normalised center-minus-neighbor vectors
/// accepted by `volume_test` (multiplied by 0.25 when that test runs relaxed).
const MIN_TETRAHEDRAL_CHIRAL_VOL: f64 = 0.5;
/// Tolerance passed to `same_side` by `check_tetrahedral_centers`: plane offsets with
/// magnitude below this count as "on the face", so the center is not considered inside.
const TETRAHEDRAL_CENTERINVOLUME_TOL: f64 = 0.3;
/// Exported energy threshold; not used in this file.
/// NOTE(review): presumably compared by callers against minimised energy / atom count — confirm.
pub const MAX_MINIMIZED_E_PER_ATOM: f32 = 0.05;
11
/// A four-coordinate atom whose neighbor geometry is validated by the
/// tetrahedral checks (minimum volume and center-in-volume).
///
/// Being listed here does not imply the atom is a stereocenter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TetrahedralCenter {
    /// Index of the central atom.
    pub center: usize,
    /// Indices of the four bonded neighbors.
    pub neighbors: [usize; 4],
    /// When `true`, the volume test uses a relaxed (0.25x) threshold.
    /// `identify_tetrahedral_centers` sets this when the atom lies in more than
    /// one perceived ring of size < 5.
    pub in_small_ring: bool,
}
18
19/// Identify tetrahedral centers for volume checks.
20/// Matches RDKit's findChiralSets logic for tetrahedralCarbons:
21///   - C or N atoms only, with exactly 4 neighbors
22///   - Must be in 2+ rings (ring junction atoms)
23///   - Must NOT be in any 3-membered ring
24pub fn identify_tetrahedral_centers(mol: &Molecule) -> Vec<TetrahedralCenter> {
25    let n = mol.graph.node_count();
26    // Compute SSSR rings, then derive per-atom ring count and 3-ring membership
27    let rings = find_sssr(mol);
28    let mut ring_count = vec![0usize; n];
29    let mut in_3_ring = vec![false; n];
30    let mut small_ring_count = vec![0usize; n]; // rings of size < 5
31    for ring in &rings {
32        for &atom_idx in ring {
33            ring_count[atom_idx] += 1;
34            if ring.len() == 3 {
35                in_3_ring[atom_idx] = true;
36            }
37            if ring.len() < 5 {
38                small_ring_count[atom_idx] += 1;
39            }
40        }
41    }
42
43    let mut centers = Vec::new();
44    for i in 0..n {
45        let ni = petgraph::graph::NodeIndex::new(i);
46        let atom = &mol.graph[ni];
47        // RDKit: only C (6) or N (7) with degree == 4
48        let elem = atom.element;
49        if elem != 6 && elem != 7 {
50            continue;
51        }
52        let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
53        if nbs.len() != 4 {
54            continue;
55        }
56
57        // RDKit: only add if in 2+ rings AND not in any 3-ring
58        if ring_count[i] < 2 || in_3_ring[i] {
59            continue;
60        }
61
62        centers.push(TetrahedralCenter {
63            center: i,
64            neighbors: [
65                nbs[0].index(),
66                nbs[1].index(),
67                nbs[2].index(),
68                nbs[3].index(),
69            ],
70            in_small_ring: small_ring_count[i] > 1,
71        });
72    }
73    centers
74}
75
76/// Find the Smallest Set of Smallest Rings (SSSR) using Horton's algorithm.
77/// Returns a list of rings, each ring being a vector of atom indices.
78pub fn find_sssr_pub(mol: &Molecule) -> Vec<Vec<usize>> {
79    find_sssr(mol)
80}
81fn find_sssr(mol: &Molecule) -> Vec<Vec<usize>> {
82    use std::collections::VecDeque;
83    let n = mol.graph.node_count();
84    if n == 0 {
85        return vec![];
86    }
87
88    // Number of independent cycles = edges - vertices + connected_components
89    let num_edges = mol.graph.edge_count();
90    // Count connected components via BFS
91    let mut visited = vec![false; n];
92    let mut num_components = 0;
93    for start in 0..n {
94        if visited[start] {
95            continue;
96        }
97        num_components += 1;
98        let mut queue = VecDeque::new();
99        queue.push_back(start);
100        visited[start] = true;
101        while let Some(curr) = queue.pop_front() {
102            for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(curr)) {
103                if !visited[nb.index()] {
104                    visited[nb.index()] = true;
105                    queue.push_back(nb.index());
106                }
107            }
108        }
109    }
110    let cycle_rank = (num_edges + num_components).saturating_sub(n);
111    if cycle_rank == 0 {
112        return vec![];
113    }
114
115    // For each vertex, BFS to compute shortest path tree
116    // Then for each non-tree edge found at vertex v, form the candidate ring
117    let mut candidates: Vec<Vec<usize>> = Vec::new();
118
119    for root in 0..n {
120        let mut dist = vec![usize::MAX; n];
121        let mut parent = vec![usize::MAX; n];
122        dist[root] = 0;
123        let mut queue = VecDeque::new();
124        queue.push_back(root);
125
126        while let Some(curr) = queue.pop_front() {
127            for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(curr)) {
128                let nb_idx = nb.index();
129                if dist[nb_idx] == usize::MAX {
130                    dist[nb_idx] = dist[curr] + 1;
131                    parent[nb_idx] = curr;
132                    queue.push_back(nb_idx);
133                }
134            }
135        }
136
137        // For each neighbor of root, if they share neighbors that are equidistant or close,
138        // form candidate rings. Specifically, look for pairs (u, v) where edge (u,v) exists
139        // and dist[u] + dist[v] + 1 gives an odd ring, or dist[u] == dist[v] for even ring.
140        for u in 0..n {
141            for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(u)) {
142                let v = nb.index();
143                if u >= v {
144                    continue;
145                } // avoid duplicates
146                let ring_len = dist[u] + dist[v] + 1;
147                if ring_len > 8 {
148                    continue;
149                } // skip very large rings
150                if dist[u] == usize::MAX || dist[v] == usize::MAX {
151                    continue;
152                }
153
154                // Build the ring: path from root to u + edge (u,v) + path from v to root
155                let path_u = trace_path(&parent, root, u);
156                let path_v = trace_path(&parent, root, v);
157
158                // Check that paths don't share intermediate vertices (would make it not a simple cycle)
159                let mut ring = path_u.clone();
160                // path_v goes root→...→v, we need to reverse it and skip the root
161                let mut path_v_rev: Vec<usize> = path_v.into_iter().rev().collect();
162                if !path_v_rev.is_empty() && !ring.is_empty() && path_v_rev.last() == ring.first() {
163                    path_v_rev.pop(); // remove duplicate root
164                }
165                ring.extend(path_v_rev);
166
167                // Check it's a valid simple cycle (no repeated vertices)
168                let mut seen = std::collections::HashSet::new();
169                let is_simple = ring.iter().all(|&x| seen.insert(x));
170                if is_simple && ring.len() >= 3 {
171                    // Normalize the ring for deduplication
172                    let normalized = normalize_ring(&ring);
173                    candidates.push(normalized);
174                }
175            }
176        }
177    }
178
179    // Deduplicate candidates
180    candidates.sort();
181    candidates.dedup();
182
183    // Filter: keep only rings that are "relevant" — not the XOR of two smaller rings.
184    // For the purposes of ring counting, we keep all unique smallest rings.
185    // Sort by size so smallest come first.
186    candidates.sort_by_key(|r| r.len());
187
188    // A ring is "relevant" if it cannot be expressed as the symmetric difference of
189    // two strictly smaller rings. This matches RDKit's ring perception behavior.
190    let edge_sets: Vec<std::collections::HashSet<(usize, usize)>> =
191        candidates.iter().map(|r| ring_edges(r).collect()).collect();
192
193    let mut relevant = Vec::new();
194    for (i, ring) in candidates.iter().enumerate() {
195        let mut is_xor_of_smaller = false;
196        // Check all pairs of strictly smaller rings
197        for j in 0..i {
198            if candidates[j].len() >= ring.len() {
199                continue;
200            }
201            for k in (j + 1)..i {
202                if candidates[k].len() >= ring.len() {
203                    continue;
204                }
205                // Symmetric difference of edge sets j and k
206                let sym_diff: std::collections::HashSet<(usize, usize)> = edge_sets[j]
207                    .symmetric_difference(&edge_sets[k])
208                    .copied()
209                    .collect();
210                if sym_diff == edge_sets[i] {
211                    is_xor_of_smaller = true;
212                    break;
213                }
214            }
215            if is_xor_of_smaller {
216                break;
217            }
218        }
219        if !is_xor_of_smaller {
220            relevant.push(ring.clone());
221        }
222    }
223
224    relevant
225}
226
/// Reconstruct the BFS-tree path from `root` to `target` using `parent` links.
///
/// Starting at `target`, follows `parent` until `root` is reached and returns the path
/// ordered root → target (`[root]` when `target == root`). Walking stops at a
/// `usize::MAX` parent. In that case `root` is absent, and the result holds only the
/// visited atoms, ordered from the farthest ancestor to `target`.
fn trace_path(parent: &[usize], root: usize, target: usize) -> Vec<usize> {
    let mut backwards = Vec::new();
    let mut node = target;
    loop {
        if node == root {
            backwards.push(root);
            break;
        }
        if node == usize::MAX {
            break;
        }
        backwards.push(node);
        node = parent[node];
    }
    backwards.into_iter().rev().collect()
}
240
/// Canonical form of a cyclic atom sequence, used to deduplicate rings.
///
/// Rotates the ring so its (first) smallest atom index comes first. Of the two traversal
/// directions from that atom, returns the lexicographically smaller one. An empty input
/// yields an empty vector.
fn normalize_ring(ring: &[usize]) -> Vec<usize> {
    let smallest = match ring.iter().min() {
        Some(&value) => value,
        None => return Vec::new(),
    };
    let start = ring.iter().position(|&atom| atom == smallest).unwrap_or(0);

    // Same direction as stored: ring[start], ring[start + 1], ..., wrapping around.
    let clockwise: Vec<usize> = ring[start..]
        .iter()
        .chain(&ring[..start])
        .copied()
        .collect();
    // Reverse direction: ring[start], ring[start - 1], ..., wrapping around.
    let counter_clockwise: Vec<usize> = ring[..=start]
        .iter()
        .rev()
        .chain(ring[start + 1..].iter().rev())
        .copied()
        .collect();

    std::cmp::min(clockwise, counter_clockwise)
}
253
/// Iterate over the bonds of a cyclic atom sequence as `(smaller, larger)` index pairs.
///
/// Yields one pair per consecutive atom pair, including the closing pair
/// (last atom, first atom), in ring order.
fn ring_edges(ring: &[usize]) -> impl Iterator<Item = (usize, usize)> + '_ {
    // Pair each atom with its successor; the cycled iterator supplies the wrap-around.
    let successors = ring.iter().cycle().skip(1);
    ring.iter().zip(successors).map(|(&from, &to)| {
        if from <= to {
            (from, to)
        } else {
            (to, from)
        }
    })
}
262
263/// Volume test: check that a tetrahedral center has minimum volume.
264/// Uses NORMALIZED direction vectors from center to each neighbor.
265/// Checks all C(4,3)=4 combinations of 3 vectors.
266/// Uses f64 to match RDKit's Point3D (double) precision.
267fn volume_test(
268    center: usize,
269    neighbors: &[usize; 4],
270    coords: &DMatrix<f64>,
271    relaxed: bool,
272) -> bool {
273    let dim = coords.ncols().min(3);
274    let p0 = Vector3::new(
275        coords[(center, 0)],
276        coords[(center, 1)],
277        if dim >= 3 { coords[(center, 2)] } else { 0.0 },
278    );
279    let mut vecs = [Vector3::<f64>::zeros(); 4];
280    for (k, &nb) in neighbors.iter().enumerate() {
281        let pk = Vector3::new(
282            coords[(nb, 0)],
283            coords[(nb, 1)],
284            if dim >= 3 { coords[(nb, 2)] } else { 0.0 },
285        );
286        let v = p0 - pk; // RDKit: center - neighbor
287        let norm = v.norm();
288        vecs[k] = if norm > 1e-8 { v / norm } else { v };
289    }
290
291    let vol_scale: f64 = if relaxed { 0.25 } else { 1.0 };
292    let threshold = vol_scale * MIN_TETRAHEDRAL_CHIRAL_VOL;
293
294    // RDKit checks: (v1×v2)·v3, (v1×v2)·v4, (v1×v3)·v4, (v2×v3)·v4
295    let combos: [(usize, usize, usize); 4] = [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)];
296    for (a, b, c) in combos {
297        let cross = vecs[a].cross(&vecs[b]);
298        let vol = cross.dot(&vecs[c]).abs();
299        if vol < threshold {
300            return false;
301        }
302    }
303    true
304}
305
306/// Center-in-volume test: center atom must be inside the tetrahedron formed by its 4 neighbors.
307fn same_side(
308    v1: &Vector3<f64>,
309    v2: &Vector3<f64>,
310    v3: &Vector3<f64>,
311    v4: &Vector3<f64>,
312    p0: &Vector3<f64>,
313    tol: f64,
314) -> bool {
315    let normal = (v2 - v1).cross(&(v3 - v1));
316    let d1 = normal.dot(&(v4 - v1));
317    let d2 = normal.dot(&(p0 - v1));
318    if d1.abs() < tol || d2.abs() < tol {
319        return false;
320    }
321    (d1 < 0.0) == (d2 < 0.0)
322}
323
324fn center_in_volume(
325    center: usize,
326    neighbors: &[usize; 4],
327    coords: &DMatrix<f64>,
328    tol: f64,
329) -> bool {
330    let dim = coords.ncols().min(3);
331    let get_p3d = |idx: usize| -> Vector3<f64> {
332        Vector3::new(
333            coords[(idx, 0)],
334            coords[(idx, 1)],
335            if dim >= 3 { coords[(idx, 2)] } else { 0.0 },
336        )
337    };
338    let p0 = get_p3d(center);
339    let p = [
340        get_p3d(neighbors[0]),
341        get_p3d(neighbors[1]),
342        get_p3d(neighbors[2]),
343        get_p3d(neighbors[3]),
344    ];
345
346    same_side(&p[0], &p[1], &p[2], &p[3], &p0, tol)
347        && same_side(&p[1], &p[2], &p[3], &p[0], &p0, tol)
348        && same_side(&p[2], &p[3], &p[0], &p[1], &p0, tol)
349        && same_side(&p[3], &p[0], &p[1], &p[2], &p0, tol)
350}
351
352/// Check all tetrahedral centers for minimum volume and center-in-volume.
353/// Matches RDKit's checkTetrahedralCenters. Uses f64 coords matching RDKit's Point3D.
354pub fn check_tetrahedral_centers(coords: &DMatrix<f64>, centers: &[TetrahedralCenter]) -> bool {
355    for tc in centers {
356        if !volume_test(tc.center, &tc.neighbors, coords, tc.in_small_ring) {
357            return false;
358        }
359        if !center_in_volume(
360            tc.center,
361            &tc.neighbors,
362            coords,
363            TETRAHEDRAL_CENTERINVOLUME_TOL,
364        ) {
365            return false;
366        }
367    }
368    true
369}
370
371/// Check chiral center volumes have correct sign.
372/// Matches RDKit's checkChiralCenters — intentionally permissive (allows 20% undershoot if sign matches).
373/// Uses f64 coords matching RDKit's Point3D.
374pub fn check_chiral_centers(coords: &DMatrix<f64>, chiral_sets: &[ChiralSet]) -> bool {
375    for cs in chiral_sets {
376        let vol = crate::distgeom::calc_chiral_volume_f64(
377            cs.neighbors[0],
378            cs.neighbors[1],
379            cs.neighbors[2],
380            cs.neighbors[3],
381            coords,
382        );
383        let lb = cs.lower_vol as f64;
384        let ub = cs.upper_vol as f64;
385        if lb > 0.0 && vol < lb && (vol / lb < 0.8 || have_opposite_sign(vol, lb)) {
386            return false;
387        }
388        if ub < 0.0 && vol > ub && (vol / ub < 0.8 || have_opposite_sign(vol, ub)) {
389            return false;
390        }
391    }
392    true
393}
394
/// `true` when exactly one of `a`, `b` is strictly negative.
///
/// `-0.0` and NaN are not negative here, because `< 0.0` is false for both.
fn have_opposite_sign(a: f64, b: f64) -> bool {
    let a_negative = a < 0.0;
    let b_negative = b < 0.0;
    a_negative ^ b_negative
}
398
399/// Planarity check: compute OOP (improper torsion) energy for SP2 centers.
400/// Reject if energy > n_impropers * tolerance.
401/// Matches RDKit's planarity check in minimizeWithExpTorsions.
402pub fn check_planarity(mol: &Molecule, coords: &DMatrix<f32>, oop_k: f32, tolerance: f32) -> bool {
403    let n = mol.graph.node_count();
404    let mut n_impropers = 0usize;
405    let mut improper_energy = 0.0f32;
406
407    // SP2 improper (out-of-plane) terms only
408    for i in 0..n {
409        let ni = petgraph::graph::NodeIndex::new(i);
410        if mol.graph[ni].hybridization != crate::graph::Hybridization::SP2 {
411            continue;
412        }
413        let nbs: Vec<_> = mol.graph.neighbors(ni).collect();
414        if nbs.len() != 3 {
415            continue;
416        }
417        n_impropers += 1;
418
419        let pc = Vector3::new(coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]);
420        let p1 = Vector3::new(
421            coords[(nbs[0].index(), 0)],
422            coords[(nbs[0].index(), 1)],
423            coords[(nbs[0].index(), 2)],
424        );
425        let p2 = Vector3::new(
426            coords[(nbs[1].index(), 0)],
427            coords[(nbs[1].index(), 1)],
428            coords[(nbs[1].index(), 2)],
429        );
430        let p3 = Vector3::new(
431            coords[(nbs[2].index(), 0)],
432            coords[(nbs[2].index(), 1)],
433            coords[(nbs[2].index(), 2)],
434        );
435        let v1 = p1 - pc;
436        let v2 = p2 - pc;
437        let v3 = p3 - pc;
438        let vol = v1.dot(&v2.cross(&v3));
439        improper_energy += oop_k * vol * vol;
440    }
441
442    // SP linearity is enforced by the ETKDG 3D FF distance constraints (k=100),
443    // not by the validation check. Including SP angle penalties here caused
444    // false rejections for molecules with both SP and SP2 atoms.
445
446    if n_impropers == 0 {
447        return true;
448    }
449    improper_energy <= n_impropers as f32 * tolerance
450}
451
452/// Double bond geometry check: reject if substituent-double_bond_atom-other is nearly linear.
453/// Matches RDKit's doubleBondGeometryChecks with doubleBondEnds filtering.
454/// Uses f64 coords matching RDKit's Point3D.
455pub fn check_double_bond_geometry(mol: &Molecule, coords: &DMatrix<f64>) -> bool {
456    use petgraph::visit::EdgeRef;
457    for edge in mol.graph.edge_references() {
458        if mol.graph[edge.id()].order != crate::graph::BondOrder::Double {
459            continue;
460        }
461        let u = edge.source();
462        let v = edge.target();
463
464        // Check neighbors of u (substituents around the double bond end)
465        let u_deg = mol.graph.neighbors(u).count();
466        if u_deg >= 2 {
467            for nb in mol.graph.neighbors(u) {
468                if nb == v {
469                    continue;
470                }
471                // RDKit filter: skip if bond to neighbor is NOT single and atom has degree 2
472                if u_deg == 2 {
473                    if let Some(eid) = mol.graph.find_edge(u, nb) {
474                        if mol.graph[eid].order != crate::graph::BondOrder::Single {
475                            continue;
476                        }
477                    }
478                }
479                if !check_linearity(nb.index(), u.index(), v.index(), coords) {
480                    return false;
481                }
482            }
483        }
484        // Check neighbors of v
485        let v_deg = mol.graph.neighbors(v).count();
486        if v_deg >= 2 {
487            for nb in mol.graph.neighbors(v) {
488                if nb == u {
489                    continue;
490                }
491                // RDKit filter: skip if bond to neighbor is NOT single and atom has degree 2
492                if v_deg == 2 {
493                    if let Some(eid) = mol.graph.find_edge(v, nb) {
494                        if mol.graph[eid].order != crate::graph::BondOrder::Single {
495                            continue;
496                        }
497                    }
498                }
499                if !check_linearity(nb.index(), v.index(), u.index(), coords) {
500                    return false;
501                }
502            }
503        }
504    }
505    true
506}
507
508/// Returns false if a0-a1-a2 is nearly linear (angle ≈ 180°).
509fn check_linearity(a0: usize, a1: usize, a2: usize, coords: &DMatrix<f64>) -> bool {
510    let p0 = Vector3::new(coords[(a0, 0)], coords[(a0, 1)], coords[(a0, 2)]);
511    let p1 = Vector3::new(coords[(a1, 0)], coords[(a1, 1)], coords[(a1, 2)]);
512    let p2 = Vector3::new(coords[(a2, 0)], coords[(a2, 1)], coords[(a2, 2)]);
513    let mut v1 = p1 - p0;
514    let n1 = v1.norm();
515    if n1 < 1e-8 {
516        return true;
517    }
518    v1 /= n1;
519    let mut v2 = p1 - p2;
520    let n2 = v2.norm();
521    if n2 < 1e-8 {
522        return true;
523    }
524    v2 /= n2;
525    // dot ≈ -1 means linear; reject if dot + 1 < 1e-3
526    v1.dot(&v2) + 1.0 >= 1e-3
527}
528
529/// Check if coordinates are quasi-planar (2D) and perturb the z-axis if so.
530/// Returns true if coordinates were perturbed (caller should re-minimise).
531pub fn perturb_if_planar(coords: &mut DMatrix<f64>, rng: &mut crate::distgeom::MinstdRand) -> bool {
532    let n = coords.nrows();
533    if n < 4 || coords.ncols() < 3 {
534        return false;
535    }
536    // Compute spread along z-axis (column 2)
537    let mut z_min = f64::INFINITY;
538    let mut z_max = f64::NEG_INFINITY;
539    for i in 0..n {
540        let z = coords[(i, 2)];
541        if z < z_min {
542            z_min = z;
543        }
544        if z > z_max {
545            z_max = z;
546        }
547    }
548    let z_spread = z_max - z_min;
549    // If z-spread is tiny compared to x/y spread, coordinates are quasi-planar
550    let mut xy_max_spread = 0.0f64;
551    for d in 0..2 {
552        let mut lo = f64::INFINITY;
553        let mut hi = f64::NEG_INFINITY;
554        for i in 0..n {
555            let v = coords[(i, d)];
556            if v < lo {
557                lo = v;
558            }
559            if v > hi {
560                hi = v;
561            }
562        }
563        xy_max_spread = xy_max_spread.max(hi - lo);
564    }
565    if xy_max_spread < 1e-8 {
566        return false;
567    }
568    // Quasi-planar if z spread < 1% of max xy spread
569    if z_spread < 0.01 * xy_max_spread {
570        for i in 0..n {
571            coords[(i, 2)] += 0.3 * (rng.next_double() - 0.5);
572        }
573        return true;
574    }
575    false
576}