Skip to main content

sci_form/smarts/
matcher.rs

1//! Substructure matching: match a SmartsPattern against a Molecule graph.
2
3use super::parser::*;
4use crate::graph::{BondOrder, ChiralType, Hybridization, Molecule};
5use petgraph::graph::NodeIndex;
6use petgraph::visit::EdgeRef;
7
8/// Find all substructure matches of `pattern` in `mol`.
9/// Returns a vec of mappings: each mapping is pattern_atom_idx → molecule_atom_idx.
10pub fn substruct_match(mol: &Molecule, pattern: &SmartsPattern) -> Vec<Vec<usize>> {
11    if pattern.atoms.is_empty() {
12        return vec![];
13    }
14
15    // Pre-compute ring info (lazily, only if needed)
16    let needs_ring = pattern_needs_ring(pattern);
17    let ring_info = if needs_ring {
18        Some(compute_ring_info(mol))
19    } else {
20        None
21    };
22
23    substruct_match_inner(mol, pattern, ring_info.as_ref())
24}
25
26/// Find all substructure matches using pre-computed ring info.
27/// This avoids recomputing SSSR when matching many patterns against the same molecule.
28pub fn substruct_match_with_ring_info(
29    mol: &Molecule,
30    pattern: &SmartsPattern,
31    ring_info: &RingInfo,
32) -> Vec<Vec<usize>> {
33    if pattern.atoms.is_empty() {
34        return vec![];
35    }
36    substruct_match_inner(mol, pattern, Some(ring_info))
37}
38
39/// Compute ring info for a molecule (SSSR + ring membership + ring sizes).
40/// Call once per molecule, then pass to `substruct_match_with_ring_info`.
41pub fn precompute_ring_info(mol: &Molecule) -> RingInfo {
42    compute_ring_info(mol)
43}
44
45fn substruct_match_inner(
46    mol: &Molecule,
47    pattern: &SmartsPattern,
48    ring_info: Option<&RingInfo>,
49) -> Vec<Vec<usize>> {
50    let n_pat = pattern.atoms.len();
51    let n_mol = mol.graph.node_count();
52
53    // Build adjacency for the pattern
54    let mut pat_adj: Vec<Vec<(usize, usize)>> = vec![vec![]; n_pat]; // (neighbor, bond_idx)
55    for (bi, bond) in pattern.bonds.iter().enumerate() {
56        pat_adj[bond.from].push((bond.to, bi));
57        pat_adj[bond.to].push((bond.from, bi));
58    }
59
60    let mut results = Vec::new();
61    let mut mapping = vec![usize::MAX; n_pat];
62    let mut used = vec![false; n_mol];
63
64    // Try each molecule atom as starting point for pattern atom 0
65    for start in 0..n_mol {
66        mapping[0] = start;
67        used[start] = true;
68        if atom_matches(
69            mol,
70            NodeIndex::new(start),
71            &pattern.atoms[0].query,
72            ring_info,
73        ) {
74            backtrack(
75                mol,
76                pattern,
77                &pat_adj,
78                &mut mapping,
79                &mut used,
80                1,
81                ring_info,
82                &mut results,
83            );
84        }
85        used[start] = false;
86        mapping[0] = usize::MAX;
87    }
88    results
89}
90
91fn backtrack(
92    mol: &Molecule,
93    pattern: &SmartsPattern,
94    pat_adj: &[Vec<(usize, usize)>],
95    mapping: &mut Vec<usize>,
96    used: &mut Vec<bool>,
97    depth: usize,
98    ring_info: Option<&RingInfo>,
99    results: &mut Vec<Vec<usize>>,
100) {
101    if depth == pattern.atoms.len() {
102        results.push(mapping.clone());
103        return;
104    }
105
106    // Find the next unmapped pattern atom that connects to an already-mapped atom.
107    // This ensures we extend the match along the pattern's connectivity.
108    let (pat_atom, connected_to, bond_idx) = match find_next_atom(pattern, pat_adj, mapping, depth)
109    {
110        Some(x) => x,
111        None => return,
112    };
113
114    let mol_anchor = NodeIndex::new(mapping[connected_to]);
115
116    // Try each neighbor of the molecule anchor atom
117    for nb in mol.graph.neighbors(mol_anchor) {
118        let ni = nb.index();
119        if used[ni] {
120            continue;
121        }
122
123        // Check atom match
124        if !atom_matches(mol, nb, &pattern.atoms[pat_atom].query, ring_info) {
125            continue;
126        }
127
128        // Check bond match
129        let edge = mol.graph.find_edge(mol_anchor, nb).unwrap();
130        let mol_bond = &mol.graph[edge];
131        if !bond_matches(
132            mol_bond.order,
133            &pattern.bonds[bond_idx].query,
134            mol,
135            mol_anchor,
136            nb,
137            ring_info,
138        ) {
139            continue;
140        }
141
142        // Also verify that all other already-mapped neighbors of pat_atom have matching bonds
143        let mut all_bonds_ok = true;
144        for &(other_pat, bi) in &pat_adj[pat_atom] {
145            if other_pat == connected_to {
146                continue;
147            }
148            if mapping[other_pat] == usize::MAX {
149                continue;
150            }
151            // There should be a bond from nb to mapping[other_pat] in the molecule
152            let other_mol = NodeIndex::new(mapping[other_pat]);
153            if let Some(e) = mol.graph.find_edge(nb, other_mol) {
154                if !bond_matches(
155                    mol.graph[e].order,
156                    &pattern.bonds[bi].query,
157                    mol,
158                    nb,
159                    other_mol,
160                    ring_info,
161                ) {
162                    all_bonds_ok = false;
163                    break;
164                }
165            } else {
166                all_bonds_ok = false;
167                break;
168            }
169        }
170        if !all_bonds_ok {
171            continue;
172        }
173
174        mapping[pat_atom] = ni;
175        used[ni] = true;
176        backtrack(
177            mol,
178            pattern,
179            pat_adj,
180            mapping,
181            used,
182            depth + 1,
183            ring_info,
184            results,
185        );
186        used[ni] = false;
187        mapping[pat_atom] = usize::MAX;
188    }
189}
190
191/// Find the next pattern atom to map: must be unmapped but adjacent to a mapped atom.
192fn find_next_atom(
193    pattern: &SmartsPattern,
194    pat_adj: &[Vec<(usize, usize)>],
195    mapping: &[usize],
196    _depth: usize,
197) -> Option<(usize, usize, usize)> {
198    // BFS order: find first unmapped atom connected to a mapped atom
199    for pat_idx in 0..pattern.atoms.len() {
200        if mapping[pat_idx] != usize::MAX {
201            continue;
202        }
203        for &(neighbor, bond_idx) in &pat_adj[pat_idx] {
204            if mapping[neighbor] != usize::MAX {
205                return Some((pat_idx, neighbor, bond_idx));
206            }
207        }
208    }
209    None
210}
211
212/// Check if a molecule atom matches an atom query.
213fn atom_matches(
214    mol: &Molecule,
215    atom: NodeIndex,
216    query: &AtomQuery,
217    ring_info: Option<&RingInfo>,
218) -> bool {
219    let a = &mol.graph[atom];
220    match query {
221        AtomQuery::True => true,
222        AtomQuery::Element(z) => a.element == *z && !is_aromatic_atom(mol, atom),
223        AtomQuery::AromaticElem(z) => a.element == *z && is_aromatic_atom(mol, atom),
224        AtomQuery::AnyAromatic => is_aromatic_atom(mol, atom),
225        AtomQuery::AnyAliphatic => !is_aromatic_atom(mol, atom),
226        AtomQuery::AtomicNum(z) => a.element == *z,
227        AtomQuery::NotAtomicNum(z) => a.element != *z,
228        AtomQuery::TotalH(n) => count_h(mol, atom) == *n as usize,
229        AtomQuery::TotalDegree(n) => mol.graph.neighbors(atom).count() == *n as usize,
230        AtomQuery::HeavyDegree(n) => {
231            mol.graph
232                .neighbors(atom)
233                .filter(|&nb| mol.graph[nb].element != 1)
234                .count()
235                == *n as usize
236        }
237        AtomQuery::RingBondCount(n) => {
238            if let Some(ri) = ring_info {
239                count_ring_bonds(mol, atom, ri) == *n as usize
240            } else {
241                false
242            }
243        }
244        AtomQuery::InRing => {
245            if let Some(ri) = ring_info {
246                ri.atom_in_ring[atom.index()]
247            } else {
248                false
249            }
250        }
251        AtomQuery::RingCount(n) => {
252            if let Some(ri) = ring_info {
253                let count = ri
254                    .rings
255                    .iter()
256                    .filter(|r| r.contains(&atom.index()))
257                    .count();
258                count == *n as usize
259            } else {
260                *n == 0
261            }
262        }
263        AtomQuery::RingSize(n) => {
264            if let Some(ri) = ring_info {
265                ri.atom_ring_sizes[atom.index()].contains(&(*n as usize))
266            } else {
267                false
268            }
269        }
270        AtomQuery::RingSizeRange(lo, hi) => {
271            if let Some(ri) = ring_info {
272                ri.atom_ring_sizes[atom.index()]
273                    .iter()
274                    .any(|&s| s >= *lo as usize && s <= *hi as usize)
275            } else {
276                false
277            }
278        }
279        AtomQuery::RingSizeMin(lo) => {
280            if let Some(ri) = ring_info {
281                ri.atom_ring_sizes[atom.index()]
282                    .iter()
283                    .any(|&s| s >= *lo as usize)
284            } else {
285                false
286            }
287        }
288        AtomQuery::FormalCharge(c) => a.formal_charge == *c,
289        AtomQuery::Hybridization(n) => matches!(
290            (n, &a.hybridization),
291            (1, Hybridization::SP) | (2, Hybridization::SP2) | (3, Hybridization::SP3)
292        ),
293        AtomQuery::Chiral(chiral) => matches!(
294            (chiral, &a.chiral_tag),
295            (ChiralType::TetrahedralCW, ChiralType::TetrahedralCW)
296                | (ChiralType::TetrahedralCCW, ChiralType::TetrahedralCCW)
297        ),
298        AtomQuery::Recursive(inner) => {
299            // The atom must match as atom 0 of the inner pattern
300            let matches = substruct_match_from(mol, inner, atom, ring_info);
301            !matches.is_empty()
302        }
303        AtomQuery::And(parts) => parts.iter().all(|q| atom_matches(mol, atom, q, ring_info)),
304        AtomQuery::Or(parts) => parts.iter().any(|q| atom_matches(mol, atom, q, ring_info)),
305        AtomQuery::Not(inner) => !atom_matches(mol, atom, inner, ring_info),
306    }
307}
308
309/// Check if a bond matches a bond query.
310fn bond_matches(
311    order: BondOrder,
312    query: &BondQuery,
313    _mol: &Molecule,
314    from: NodeIndex,
315    to: NodeIndex,
316    ring_info: Option<&RingInfo>,
317) -> bool {
318    match query {
319        BondQuery::Single => order == BondOrder::Single,
320        BondQuery::Double => order == BondOrder::Double,
321        BondQuery::Triple => order == BondOrder::Triple,
322        BondQuery::Aromatic => order == BondOrder::Aromatic,
323        BondQuery::Any => true,
324        BondQuery::Ring => {
325            if let Some(ri) = ring_info {
326                is_ring_bond(from, to, ri)
327            } else {
328                false
329            }
330        }
331        BondQuery::NotRing => {
332            if let Some(ri) = ring_info {
333                !is_ring_bond(from, to, ri)
334            } else {
335                true
336            }
337        }
338        BondQuery::Implicit => {
339            // Default bond: single or aromatic
340            order == BondOrder::Single || order == BondOrder::Aromatic
341        }
342        BondQuery::And(parts) => parts
343            .iter()
344            .all(|q| bond_matches(order, q, _mol, from, to, ring_info)),
345        BondQuery::Not(inner) => !bond_matches(order, inner, _mol, from, to, ring_info),
346    }
347}
348
349/// Match a pattern starting from a specific molecule atom (for recursive SMARTS).
350fn substruct_match_from(
351    mol: &Molecule,
352    pattern: &SmartsPattern,
353    start_atom: NodeIndex,
354    ring_info: Option<&RingInfo>,
355) -> Vec<Vec<usize>> {
356    if pattern.atoms.is_empty() {
357        return vec![];
358    }
359
360    let n_pat = pattern.atoms.len();
361    let n_mol = mol.graph.node_count();
362
363    let mut pat_adj: Vec<Vec<(usize, usize)>> = vec![vec![]; n_pat];
364    for (bi, bond) in pattern.bonds.iter().enumerate() {
365        pat_adj[bond.from].push((bond.to, bi));
366        pat_adj[bond.to].push((bond.from, bi));
367    }
368
369    let mut results = Vec::new();
370    let mut mapping = vec![usize::MAX; n_pat];
371    let mut used = vec![false; n_mol];
372
373    // Start from the specified atom
374    mapping[0] = start_atom.index();
375    used[start_atom.index()] = true;
376    if atom_matches(mol, start_atom, &pattern.atoms[0].query, ring_info) {
377        backtrack(
378            mol,
379            pattern,
380            &pat_adj,
381            &mut mapping,
382            &mut used,
383            1,
384            ring_info,
385            &mut results,
386        );
387    }
388    used[start_atom.index()] = false;
389    results
390}
391
392// ── Ring info helpers ──
393
/// Pre-computed ring information for a molecule, used to speed up SMARTS matching.
///
/// Compute once with [`precompute_ring_info`] and reuse for many pattern matches.
/// All three vectors use molecule atom indices.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RingInfo {
    /// `atom_in_ring[i]` is `true` when atom `i` was found to lie in a ring.
    pub atom_in_ring: Vec<bool>,
    /// `atom_ring_sizes[i]` lists the distinct ring sizes recorded for atom `i`.
    pub atom_ring_sizes: Vec<Vec<usize>>,
    /// The perceived rings, each given as a list of atom indices.
    pub rings: Vec<Vec<usize>>,
}
402
403fn compute_ring_info(mol: &Molecule) -> RingInfo {
404    let n = mol.graph.node_count();
405    let rings = crate::distgeom::find_sssr_pub(mol);
406
407    let mut atom_in_ring = vec![false; n];
408    let mut atom_ring_sizes: Vec<Vec<usize>> = vec![vec![]; n];
409
410    for ring in &rings {
411        for &a in ring {
412            atom_in_ring[a] = true;
413            let size = ring.len();
414            if !atom_ring_sizes[a].contains(&size) {
415                atom_ring_sizes[a].push(size);
416            }
417        }
418    }
419
420    // For macrocycle detection (r{9-}), we also need to detect large rings.
421    // The SSSR may not include all rings. For atoms not in SSSR but in a ring,
422    // check using BFS shortest alternative path.
423    // For now, also detect rings up to size 20 for any bond in the molecule.
424    for edge in mol.graph.edge_references() {
425        let u = edge.source().index();
426        let v = edge.target().index();
427        // If already in a detected ring, skip
428        if atom_in_ring[u] && atom_in_ring[v] {
429            // Check if they share a ring (both in same ring)
430            let shared = rings.iter().any(|r| r.contains(&u) && r.contains(&v));
431            if shared {
432                continue;
433            }
434        }
435        // BFS for alternative path (ring detection)
436        if let Some(alt_len) = crate::graph::min_path_excluding2(
437            mol,
438            NodeIndex::new(u),
439            NodeIndex::new(v),
440            NodeIndex::new(u),
441            NodeIndex::new(v),
442            19,
443        ) {
444            let ring_size = alt_len + 1;
445            if ring_size >= 3 {
446                atom_in_ring[u] = true;
447                atom_in_ring[v] = true;
448                if !atom_ring_sizes[u].contains(&ring_size) {
449                    atom_ring_sizes[u].push(ring_size);
450                }
451                if !atom_ring_sizes[v].contains(&ring_size) {
452                    atom_ring_sizes[v].push(ring_size);
453                }
454            }
455        }
456    }
457
458    RingInfo {
459        atom_in_ring,
460        atom_ring_sizes,
461        rings,
462    }
463}
464
465fn is_ring_bond(a: NodeIndex, b: NodeIndex, ring_info: &RingInfo) -> bool {
466    ring_info.rings.iter().any(|ring| {
467        let ai = a.index();
468        let bi = b.index();
469        if !ring.contains(&ai) || !ring.contains(&bi) {
470            return false;
471        }
472        // Check if a and b are adjacent in the ring
473        let len = ring.len();
474        for i in 0..len {
475            let j = (i + 1) % len;
476            if (ring[i] == ai && ring[j] == bi) || (ring[i] == bi && ring[j] == ai) {
477                return true;
478            }
479        }
480        false
481    })
482}
483
484fn count_ring_bonds(mol: &Molecule, atom: NodeIndex, ring_info: &RingInfo) -> usize {
485    mol.graph
486        .neighbors(atom)
487        .filter(|&nb| is_ring_bond(atom, nb, ring_info))
488        .count()
489}
490
491fn is_aromatic_atom(mol: &Molecule, atom: NodeIndex) -> bool {
492    mol.graph
493        .edges(atom)
494        .any(|e| mol.graph[e.id()].order == BondOrder::Aromatic)
495}
496
497fn count_h(mol: &Molecule, atom: NodeIndex) -> usize {
498    mol.graph
499        .neighbors(atom)
500        .filter(|&nb| mol.graph[nb].element == 1)
501        .count()
502}
503
504fn pattern_needs_ring(pattern: &SmartsPattern) -> bool {
505    for atom in &pattern.atoms {
506        if query_needs_ring(&atom.query) {
507            return true;
508        }
509    }
510    for bond in &pattern.bonds {
511        if bond_query_needs_ring(&bond.query) {
512            return true;
513        }
514    }
515    false
516}
517
518fn query_needs_ring(q: &AtomQuery) -> bool {
519    match q {
520        AtomQuery::InRing
521        | AtomQuery::RingSize(_)
522        | AtomQuery::RingSizeRange(..)
523        | AtomQuery::RingSizeMin(_)
524        | AtomQuery::RingBondCount(_)
525        | AtomQuery::RingCount(_) => true,
526        AtomQuery::And(parts) | AtomQuery::Or(parts) => parts.iter().any(query_needs_ring),
527        AtomQuery::Not(inner) => query_needs_ring(inner),
528        AtomQuery::Recursive(inner) => pattern_needs_ring(inner),
529        _ => false,
530    }
531}
532
533fn bond_query_needs_ring(q: &BondQuery) -> bool {
534    match q {
535        BondQuery::Ring | BondQuery::NotRing => true,
536        BondQuery::And(parts) => parts.iter().any(bond_query_needs_ring),
537        BondQuery::Not(inner) => bond_query_needs_ring(inner),
538        _ => false,
539    }
540}
541
542/// Batch substructure matching: match a single pattern against many molecules.
543/// Returns one Vec<Vec<usize>> per molecule (empty if no match).
544pub fn substruct_match_batch(
545    molecules: &[&Molecule],
546    pattern: &SmartsPattern,
547) -> Vec<Vec<Vec<usize>>> {
548    molecules
549        .iter()
550        .map(|mol| substruct_match(mol, pattern))
551        .collect()
552}
553
/// Match one pattern against a list of molecules using rayon's parallel iterators.
///
/// Only compiled with the `parallel` feature. The output has one entry per
/// input molecule, each produced by [`substruct_match`].
#[cfg(feature = "parallel")]
pub fn substruct_match_batch_parallel(
    molecules: &[&Molecule],
    pattern: &SmartsPattern,
) -> Vec<Vec<Vec<usize>>> {
    use rayon::prelude::*;

    let per_molecule = molecules
        .par_iter()
        .map(|&mol| substruct_match(mol, pattern));
    per_molecule.collect()
}
566
567/// Check if a molecule contains a substructure (boolean, faster than full match).
568pub fn has_substruct_match(mol: &Molecule, pattern: &SmartsPattern) -> bool {
569    !substruct_match(mol, pattern).is_empty()
570}
571
/// Check a list of molecules for `pattern` using rayon's parallel iterators.
///
/// Only compiled with the `parallel` feature. The output has one boolean per
/// input molecule, each produced by [`has_substruct_match`].
#[cfg(feature = "parallel")]
pub fn has_substruct_match_batch_parallel(
    molecules: &[&Molecule],
    pattern: &SmartsPattern,
) -> Vec<bool> {
    use rayon::prelude::*;

    let flags = molecules
        .par_iter()
        .map(|&mol| has_substruct_match(mol, pattern));
    flags.collect()
}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Molecule;

    /// A stereocenter written with `@` must match the `@` query and must not
    /// match the `@@` query.
    #[test]
    fn test_tetrahedral_chirality_matches_explicit_query() {
        let chiral_mol = Molecule::from_smiles("C[C@H](F)Cl").unwrap();
        let same_handedness = parse_smarts("[C@H]").unwrap();
        let opposite_handedness = parse_smarts("[C@@H]").unwrap();

        let matches_same = has_substruct_match(&chiral_mol, &same_handedness);
        assert!(matches_same);

        let matches_opposite = has_substruct_match(&chiral_mol, &opposite_handedness);
        assert!(!matches_opposite);
    }
}