Skip to main content

sci_form/smarts/
parser.rs

1//! SMARTS parser: converts a SMARTS string into a pattern graph.
2
/// A parsed SMARTS pattern as a small graph of atom/bond queries.
///
/// Atoms are stored in the order they appear in the SMARTS string; bonds refer to them by
/// index into `atoms`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmartsPattern {
    /// Atom queries, in order of appearance in the SMARTS string.
    pub atoms: Vec<SmartsAtom>,
    /// Bond queries; `from`/`to` are indices into `atoms`.
    pub bonds: Vec<SmartsBond>,
}

/// One atom of a [`SmartsPattern`]: its query plus an optional atom-map number.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmartsAtom {
    /// Predicate the matched atom must satisfy.
    pub query: AtomQuery,
    /// Atom-map number from a bracket atom's `:N` suffix, if one was written.
    pub map_idx: Option<u8>,
}

/// A bond query between two atoms of a [`SmartsPattern`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmartsBond {
    /// Index of the first atom in `SmartsPattern::atoms`.
    pub from: usize,
    /// Index of the second atom in `SmartsPattern::atoms`.
    pub to: usize,
    /// Predicate the matched bond must satisfy.
    pub query: BondQuery,
}

/// Atom predicate built from SMARTS atomic primitives and logical operators.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AtomQuery {
    /// `*`: matches any atom.
    True,
    /// Aliphatic element, by atomic number (e.g. `C` = 6, `Cl` = 17).
    Element(u8),
    /// Aromatic element, by atomic number (`c` = 6, `n` = 7, `o` = 8, `p` = 15, `s` = 16).
    AromaticElem(u8),
    /// `a`: any aromatic atom.
    AnyAromatic,
    /// `A`: any aliphatic atom.
    AnyAliphatic,
    /// `#N`: atomic number N.
    AtomicNum(u8),
    /// Atomic number other than N. The parser in this file does not produce this variant;
    /// it encodes `!#N` as `Not(AtomicNum(N))`.
    NotAtomicNum(u8),
    /// `HN`: total hydrogen count N.
    TotalH(u8),
    /// `XN`: total connections, including implicit hydrogens.
    TotalDegree(u8),
    /// `DN`: connections to non-hydrogen atoms.
    HeavyDegree(u8),
    /// `xN`: number of ring bonds.
    RingBondCount(u8),
    /// `R` or `r` without a count: the atom is in some ring.
    InRing,
    /// `rN` or `r{N}`: the atom is in a ring of exactly N atoms.
    RingSize(u8),
    /// `r{N-M}`: ring size between N and M (`r{-M}` is stored with N = 3).
    RingSizeRange(u8, u8),
    /// `r{N-}`: ring size of at least N.
    RingSizeMin(u8),
    /// `+N` / `-N`: formal charge.
    FormalCharge(i8),
    /// `^N`: hybridization code N.
    Hybridization(u8),
    /// `RN`: number of SSSR rings containing this atom.
    RingCount(u8),
    /// `$(...)`: recursive SMARTS; the nested pattern is parsed independently.
    Recursive(Box<SmartsPattern>),
    /// Every sub-query must match (`&`, `;`, or juxtaposition).
    And(Vec<AtomQuery>),
    /// At least one sub-query must match (`,`).
    Or(Vec<AtomQuery>),
    /// `!`: the sub-query must not match.
    Not(Box<AtomQuery>),
}

/// Bond predicate built from SMARTS bond primitives and logical operators.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BondQuery {
    /// `-`: single bond.
    Single,
    /// `=`: double bond.
    Double,
    /// `#`: triple bond.
    Triple,
    /// `:`: aromatic bond.
    Aromatic,
    /// `~`: any bond.
    Any,
    /// `@`: ring bond.
    Ring,
    /// `!@`: non-ring bond.
    NotRing,
    /// No bond symbol written: single or aromatic.
    Implicit,
    /// Every sub-query must match.
    And(Vec<BondQuery>),
    /// `!`: the sub-query must not match.
    Not(Box<BondQuery>),
}
62
63/// Parse a SMARTS string into a SmartsPattern.
64pub fn parse_smarts(smarts: &str) -> Result<SmartsPattern, String> {
65    let mut parser = SmartsParser::new(smarts);
66    parser.parse_chain(None)?;
67    Ok(SmartsPattern {
68        atoms: parser.atoms,
69        bonds: parser.bonds,
70    })
71}
72
73struct SmartsParser<'a> {
74    input: &'a [u8],
75    pos: usize,
76    atoms: Vec<SmartsAtom>,
77    bonds: Vec<SmartsBond>,
78    ring_opens: [Option<usize>; 10], // ring closure digits 0-9
79}
80
81impl<'a> SmartsParser<'a> {
82    fn new(s: &'a str) -> Self {
83        Self {
84            input: s.as_bytes(),
85            pos: 0,
86            atoms: Vec::new(),
87            bonds: Vec::new(),
88            ring_opens: [None; 10],
89        }
90    }
91
92    fn peek(&self) -> Option<u8> {
93        self.input.get(self.pos).copied()
94    }
95
96    fn advance(&mut self) -> Option<u8> {
97        let c = self.input.get(self.pos).copied();
98        if c.is_some() {
99            self.pos += 1;
100        }
101        c
102    }
103
104    fn expect(&mut self, ch: u8) -> Result<(), String> {
105        if self.advance() == Some(ch) {
106            Ok(())
107        } else {
108            Err(format!("expected '{}' at pos {}", ch as char, self.pos - 1))
109        }
110    }
111
112    /// Parse a chain of atoms/bonds, optionally connected to `prev_atom`.
113    fn parse_chain(&mut self, prev_atom: Option<usize>) -> Result<(), String> {
114        let mut prev = prev_atom;
115        while self.pos < self.input.len() {
116            let c = match self.peek() {
117                Some(c) => c,
118                None => break,
119            };
120
121            match c {
122                b')' => break, // end of branch
123                b'(' => {
124                    // Branch
125                    self.advance();
126                    self.parse_chain(prev)?;
127                    self.expect(b')')?;
128                }
129                b'[' | b'*' | b'c' | b'n' | b'o' | b's' | b'p' | b'C' | b'N' | b'O' | b'S'
130                | b'P' | b'F' | b'B' | b'I' | b'a' | b'A' | b'H' => {
131                    // Parse optional bond before atom
132                    let bond_q = self.parse_bond_if_present();
133                    let atom_idx = self.parse_atom()?;
134                    if let Some(p) = prev {
135                        self.bonds.push(SmartsBond {
136                            from: p,
137                            to: atom_idx,
138                            query: bond_q.unwrap_or(BondQuery::Implicit),
139                        });
140                    }
141                    prev = Some(atom_idx);
142                }
143                b'-' | b'=' | b'#' | b'~' | b'/' | b'\\' | b':' | b'!' | b'@' => {
144                    // Bond followed by atom
145                    let bond_q = self.parse_bond()?;
146                    let atom_idx = self.parse_atom()?;
147                    if let Some(p) = prev {
148                        self.bonds.push(SmartsBond {
149                            from: p,
150                            to: atom_idx,
151                            query: bond_q,
152                        });
153                    }
154                    prev = Some(atom_idx);
155                }
156                b'0'..=b'9' => {
157                    // Ring closure
158                    let digit = (self.advance().unwrap() - b'0') as usize;
159                    if let Some(open_atom) = self.ring_opens[digit] {
160                        self.bonds.push(SmartsBond {
161                            from: open_atom,
162                            to: prev.unwrap_or(0),
163                            query: BondQuery::Implicit,
164                        });
165                        self.ring_opens[digit] = None;
166                    } else {
167                        self.ring_opens[digit] = prev;
168                    }
169                }
170                _ => break,
171            }
172        }
173        Ok(())
174    }
175
176    /// Try to parse a bond query if one is present (without consuming atom chars).
177    fn parse_bond_if_present(&mut self) -> Option<BondQuery> {
178        match self.peek() {
179            Some(b'-') | Some(b'=') | Some(b'#') | Some(b'~') | Some(b'!') | Some(b'@')
180            | Some(b':') => self.parse_bond().ok(),
181            _ => None,
182        }
183    }
184
185    /// Parse a bond query.
186    fn parse_bond(&mut self) -> Result<BondQuery, String> {
187        let mut parts = Vec::new();
188        loop {
189            match self.peek() {
190                Some(b'-') => {
191                    self.advance();
192                    parts.push(BondQuery::Single);
193                }
194                Some(b'=') => {
195                    self.advance();
196                    parts.push(BondQuery::Double);
197                }
198                Some(b'#') => {
199                    self.advance();
200                    parts.push(BondQuery::Triple);
201                }
202                Some(b'~') => {
203                    self.advance();
204                    parts.push(BondQuery::Any);
205                }
206                Some(b':') => {
207                    self.advance();
208                    parts.push(BondQuery::Aromatic);
209                }
210                Some(b'@') => {
211                    self.advance();
212                    parts.push(BondQuery::Ring);
213                }
214                Some(b'!') => {
215                    self.advance();
216                    if self.peek() == Some(b'@') {
217                        self.advance();
218                        parts.push(BondQuery::NotRing);
219                    } else {
220                        let inner = self.parse_bond()?;
221                        parts.push(BondQuery::Not(Box::new(inner)));
222                    }
223                }
224                Some(b';') => {
225                    self.advance();
226                } // AND separator, continue
227                Some(b',') => {
228                    // OR — not typically used in torsion SMARTS bonds
229                    self.advance();
230                }
231                _ => break,
232            }
233        }
234        match parts.len() {
235            0 => Ok(BondQuery::Implicit),
236            1 => Ok(parts.pop().unwrap()),
237            _ => Ok(BondQuery::And(parts)),
238        }
239    }
240
241    /// Parse an atom (either bracket or organic subset).
242    fn parse_atom(&mut self) -> Result<usize, String> {
243        let atom = match self.peek() {
244            Some(b'[') => self.parse_bracket_atom()?,
245            Some(b'*') => {
246                self.advance();
247                SmartsAtom {
248                    query: AtomQuery::True,
249                    map_idx: None,
250                }
251            }
252            _ => self.parse_organic_atom()?,
253        };
254        let idx = self.atoms.len();
255        self.atoms.push(atom);
256        Ok(idx)
257    }
258
259    /// Parse an organic subset atom (single uppercase letter, possibly followed by lowercase).
260    fn parse_organic_atom(&mut self) -> Result<SmartsAtom, String> {
261        let c = self.advance().ok_or("unexpected end")?;
262        let query = match c {
263            b'C' if self.peek() == Some(b'l') => {
264                self.advance();
265                AtomQuery::Element(17)
266            }
267            b'B' if self.peek() == Some(b'r') => {
268                self.advance();
269                AtomQuery::Element(35)
270            }
271            b'C' => AtomQuery::Element(6),
272            b'N' => AtomQuery::Element(7),
273            b'O' => AtomQuery::Element(8),
274            b'S' => AtomQuery::Element(16),
275            b'P' => AtomQuery::Element(15),
276            b'F' => AtomQuery::Element(9),
277            b'B' => AtomQuery::Element(5),
278            b'I' => AtomQuery::Element(53),
279            b'H' => AtomQuery::Element(1),
280            b'c' => AtomQuery::AromaticElem(6),
281            b'n' => AtomQuery::AromaticElem(7),
282            b'o' => AtomQuery::AromaticElem(8),
283            b's' => AtomQuery::AromaticElem(16),
284            b'p' => AtomQuery::AromaticElem(15),
285            b'a' => AtomQuery::AnyAromatic,
286            b'A' => AtomQuery::AnyAliphatic,
287            _ => {
288                return Err(format!(
289                    "unexpected atom char '{}' at pos {}",
290                    c as char,
291                    self.pos - 1
292                ))
293            }
294        };
295        Ok(SmartsAtom {
296            query,
297            map_idx: None,
298        })
299    }
300
301    /// Parse a bracket atom [...]
302    fn parse_bracket_atom(&mut self) -> Result<SmartsAtom, String> {
303        self.expect(b'[')?;
304        let query = self.parse_atom_spec()?;
305        // Check for map class :N before closing bracket
306        let map_idx = if self.peek() == Some(b':') {
307            self.advance();
308            Some(self.parse_number()? as u8)
309        } else {
310            None
311        };
312        self.expect(b']')?;
313        Ok(SmartsAtom { query, map_idx })
314    }
315
316    /// Top-level atom spec: semicolon-separated low-priority AND groups.
317    /// Precedence (lowest to highest): ; (low AND) < , (OR) < implicit/& (high AND) < ! (NOT)
318    fn parse_atom_spec(&mut self) -> Result<AtomQuery, String> {
319        let mut parts = vec![self.parse_atom_query_or()?];
320        while self.peek() == Some(b';') {
321            self.advance();
322            parts.push(self.parse_atom_query_or()?);
323        }
324        if parts.len() == 1 {
325            Ok(parts.pop().unwrap())
326        } else {
327            Ok(AtomQuery::And(parts))
328        }
329    }
330
331    /// Parse an atom query with OR (comma-separated).
332    fn parse_atom_query_or(&mut self) -> Result<AtomQuery, String> {
333        let mut parts = vec![self.parse_atom_query_and()?];
334        while self.peek() == Some(b',') {
335            self.advance();
336            parts.push(self.parse_atom_query_and()?);
337        }
338        if parts.len() == 1 {
339            Ok(parts.pop().unwrap())
340        } else {
341            Ok(AtomQuery::Or(parts))
342        }
343    }
344
345    /// Parse an atom query with high-priority AND (implicit juxtaposition or &).
346    fn parse_atom_query_and(&mut self) -> Result<AtomQuery, String> {
347        let mut parts = Vec::new();
348        loop {
349            match self.peek() {
350                Some(b']') | Some(b',') | Some(b':') | Some(b';') | None => break,
351                Some(b'&') => {
352                    self.advance();
353                } // explicit high-priority AND
354                _ => parts.push(self.parse_atom_primitive()?),
355            }
356        }
357        match parts.len() {
358            0 => Ok(AtomQuery::True),
359            1 => Ok(parts.pop().unwrap()),
360            _ => Ok(AtomQuery::And(parts)),
361        }
362    }
363
364    /// Parse a single atom primitive.
365    fn parse_atom_primitive(&mut self) -> Result<AtomQuery, String> {
366        let c = self.peek().ok_or("unexpected end in atom spec")?;
367        match c {
368            b'!' => {
369                self.advance();
370                let inner = self.parse_atom_primitive()?;
371                Ok(AtomQuery::Not(Box::new(inner)))
372            }
373            b'#' => {
374                self.advance();
375                let n = self.parse_number()? as u8;
376                Ok(AtomQuery::AtomicNum(n))
377            }
378            b'$' => {
379                // Recursive SMARTS: $(smarts)
380                self.advance();
381                self.expect(b'(')?;
382                let start = self.pos;
383                // Find matching closing paren, handling nested parens
384                let mut depth = 1;
385                while depth > 0 && self.pos < self.input.len() {
386                    match self.input[self.pos] {
387                        b'(' => depth += 1,
388                        b')' => depth -= 1,
389                        _ => {}
390                    }
391                    if depth > 0 {
392                        self.pos += 1;
393                    }
394                }
395                let inner_str = std::str::from_utf8(&self.input[start..self.pos])
396                    .map_err(|_| "invalid utf8 in recursive SMARTS")?;
397                self.expect(b')')?;
398                let inner = parse_smarts(inner_str)?;
399                Ok(AtomQuery::Recursive(Box::new(inner)))
400            }
401            b'X' => {
402                self.advance();
403                let n = self.parse_number()? as u8;
404                Ok(AtomQuery::TotalDegree(n))
405            }
406            b'x' => {
407                self.advance();
408                let n = self.parse_number()? as u8;
409                Ok(AtomQuery::RingBondCount(n))
410            }
411            b'H' => {
412                self.advance();
413                // H followed by digit = hydrogen count; otherwise H0 is "no H"
414                if self.peek().is_some_and(|c| c.is_ascii_digit()) {
415                    let n = self.parse_number()? as u8;
416                    Ok(AtomQuery::TotalH(n))
417                } else {
418                    // H without number means H >= 1 (at least one hydrogen)
419                    Ok(AtomQuery::TotalH(1))
420                }
421            }
422            b'D' => {
423                self.advance();
424                if self.peek().is_some_and(|c| c.is_ascii_digit()) {
425                    let n = self.parse_number()? as u8;
426                    Ok(AtomQuery::HeavyDegree(n))
427                } else {
428                    Ok(AtomQuery::HeavyDegree(1))
429                }
430            }
431            b'R' => {
432                self.advance();
433                if self.peek().is_some_and(|c| c.is_ascii_digit()) {
434                    let n = self.parse_number()? as u8;
435                    Ok(AtomQuery::RingCount(n))
436                } else {
437                    Ok(AtomQuery::InRing)
438                }
439            }
440            b'r' => {
441                self.advance();
442                if self.peek() == Some(b'{') {
443                    // r{N-M} or r{N-}
444                    self.advance(); // '{'
445                    if self.peek() == Some(b'-') {
446                        // r{-M} → ring size ≤ M
447                        self.advance();
448                        let m = self.parse_number()? as u8;
449                        self.expect(b'}')?;
450                        Ok(AtomQuery::RingSizeRange(3, m))
451                    } else {
452                        let n = self.parse_number()? as u8;
453                        if self.peek() == Some(b'-') {
454                            self.advance();
455                            if self.peek() == Some(b'}') {
456                                self.advance();
457                                Ok(AtomQuery::RingSizeMin(n))
458                            } else {
459                                let m = self.parse_number()? as u8;
460                                self.expect(b'}')?;
461                                Ok(AtomQuery::RingSizeRange(n, m))
462                            }
463                        } else {
464                            self.expect(b'}')?;
465                            Ok(AtomQuery::RingSize(n))
466                        }
467                    }
468                } else if self.peek().is_some_and(|c| c.is_ascii_digit()) {
469                    let n = self.parse_number()? as u8;
470                    Ok(AtomQuery::RingSize(n))
471                } else {
472                    Ok(AtomQuery::InRing)
473                }
474            }
475            b'+' => {
476                self.advance();
477                if self.peek().is_some_and(|c| c.is_ascii_digit()) {
478                    let n = self.parse_number()? as i8;
479                    Ok(AtomQuery::FormalCharge(n))
480                } else {
481                    Ok(AtomQuery::FormalCharge(1))
482                }
483            }
484            b'-' => {
485                // Careful: '-' can also be a bond. Inside bracket atom, it's charge.
486                self.advance();
487                if self.peek().is_some_and(|c| c.is_ascii_digit()) {
488                    let n = self.parse_number()? as i8;
489                    Ok(AtomQuery::FormalCharge(-n))
490                } else {
491                    Ok(AtomQuery::FormalCharge(-1))
492                }
493            }
494            b'^' => {
495                self.advance();
496                let n = self.parse_number()? as u8;
497                Ok(AtomQuery::Hybridization(n))
498            }
499            b'*' => {
500                self.advance();
501                Ok(AtomQuery::True)
502            }
503            b'a' => {
504                self.advance();
505                Ok(AtomQuery::AnyAromatic)
506            }
507            b'A' => {
508                self.advance();
509                // Check if followed by more letters (element symbol like Al, As, etc.)
510                // For CSD patterns, 'A' alone means aliphatic
511                Ok(AtomQuery::AnyAliphatic)
512            }
513            b'C' => {
514                self.advance();
515                if self.peek() == Some(b'l') {
516                    self.advance();
517                    Ok(AtomQuery::Element(17))
518                } else {
519                    Ok(AtomQuery::Element(6))
520                }
521            }
522            b'N' => {
523                self.advance();
524                Ok(AtomQuery::Element(7))
525            }
526            b'O' => {
527                self.advance();
528                Ok(AtomQuery::Element(8))
529            }
530            b'S' => {
531                self.advance();
532                Ok(AtomQuery::Element(16))
533            }
534            b'P' => {
535                self.advance();
536                Ok(AtomQuery::Element(15))
537            }
538            b'F' => {
539                self.advance();
540                Ok(AtomQuery::Element(9))
541            }
542            b'B' => {
543                self.advance();
544                if self.peek() == Some(b'r') {
545                    self.advance();
546                    Ok(AtomQuery::Element(35))
547                } else {
548                    Ok(AtomQuery::Element(5))
549                }
550            }
551            b'I' => {
552                self.advance();
553                Ok(AtomQuery::Element(53))
554            }
555            b'c' => {
556                self.advance();
557                Ok(AtomQuery::AromaticElem(6))
558            }
559            b'n' => {
560                self.advance();
561                Ok(AtomQuery::AromaticElem(7))
562            }
563            b'o' => {
564                self.advance();
565                Ok(AtomQuery::AromaticElem(8))
566            }
567            b's' => {
568                self.advance();
569                Ok(AtomQuery::AromaticElem(16))
570            }
571            b'p' => {
572                self.advance();
573                Ok(AtomQuery::AromaticElem(15))
574            }
575            _ => Err(format!(
576                "unexpected '{}' at pos {} in atom spec",
577                c as char, self.pos
578            )),
579        }
580    }
581
582    fn parse_number(&mut self) -> Result<i32, String> {
583        let start = self.pos;
584        while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() {
585            self.pos += 1;
586        }
587        if self.pos == start {
588            return Err(format!("expected number at pos {}", self.pos));
589        }
590        let s = std::str::from_utf8(&self.input[start..self.pos]).map_err(|_| "invalid utf8")?;
591        s.parse::<i32>().map_err(|e| e.to_string())
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_pattern() {
        let p = parse_smarts("[O:1]=[C:2]!@;-[O:3]~[CH0:4]").unwrap();
        assert_eq!(p.atoms.len(), 4);
        assert_eq!(p.bonds.len(), 3);
        assert_eq!(p.atoms[0].map_idx, Some(1));
        assert_eq!(p.atoms[3].map_idx, Some(4));
        // `!@;-` is the low-precedence AND of "not a ring bond" and "single".
        assert!(matches!(
            &p.bonds[1].query,
            BondQuery::And(parts)
                if matches!(parts.as_slice(), [BondQuery::NotRing, BondQuery::Single])
        ));
    }

    #[test]
    fn test_recursive_smarts() {
        let p = parse_smarts("[$([CX3]=O):1][NX3H1:2]!@;-[c:3][cH1:4]").unwrap();
        assert_eq!(p.atoms.len(), 4);
        match &p.atoms[0].query {
            AtomQuery::Recursive(inner) => assert_eq!(inner.atoms.len(), 2), // CX3 and O
            other => panic!("expected recursive, got {:?}", other),
        }
    }

    #[test]
    fn test_branch() {
        let p = parse_smarts("[a:1][c:2]([a])!@;-[O:3][C:4]").unwrap();
        assert_eq!(p.atoms.len(), 5); // a, c, a_branch, O, C
        assert_eq!(p.bonds.len(), 4);
        // The branch hangs off atom 1 and does not advance the main chain, so the bond
        // written after `)` also starts at atom 1.
        assert_eq!((p.bonds[1].from, p.bonds[1].to), (1, 2));
        assert_eq!((p.bonds[2].from, p.bonds[2].to), (1, 3));
    }

    #[test]
    fn test_ring_closure() {
        let p = parse_smarts("C1CCCCC1").unwrap();
        assert_eq!(p.atoms.len(), 6);
        assert_eq!(p.bonds.len(), 6);
        assert_eq!((p.bonds[5].from, p.bonds[5].to), (0, 5));
        assert!(matches!(p.bonds[5].query, BondQuery::Implicit));
    }

    #[test]
    fn test_ring_size_range() {
        let p = parse_smarts("[c;r{9-}:2]").unwrap();
        assert_eq!(p.atoms.len(), 1);
        assert_eq!(p.atoms[0].map_idx, Some(2));
        // Fail loudly if the query shape changes instead of passing vacuously.
        match &p.atoms[0].query {
            AtomQuery::And(parts) => {
                assert!(parts.iter().any(|q| matches!(q, AtomQuery::RingSizeMin(9))))
            }
            other => panic!("expected And query, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_all_csd_patterns() {
        let data = include_str!("../../tests/fixtures/smarts_patterns.txt");
        let mut ok = 0;
        let mut fail = 0;
        let mut failures = Vec::new();
        for line in data.lines() {
            // Only the first tab-separated column holds the SMARTS.
            let smarts = line.split('\t').next().unwrap().trim();
            if smarts.is_empty() {
                continue;
            }
            match parse_smarts(smarts) {
                Ok(p) => {
                    ok += 1;
                    // Torsion patterns are expected to map exactly four atoms; report (but do
                    // not fail on) patterns that do not.
                    let mapped = p.atoms.iter().filter(|a| a.map_idx.is_some()).count();
                    if mapped != 4 {
                        failures.push(format!("WARN mapped={}: {}", mapped, smarts));
                    }
                }
                Err(e) => {
                    fail += 1;
                    failures.push(format!("FAIL: {} → {}", smarts, e));
                }
            }
        }
        for f in &failures {
            eprintln!("{}", f);
        }
        eprintln!("\nParsed: {} ok, {} failed out of {}", ok, fail, ok + fail);
        assert_eq!(fail, 0, "{} patterns failed to parse", fail);
    }
}