Skip to main content

sci_form/
smirks.rs

1//! SMIRKS reaction transform support.
2//!
3//! SMIRKS is an extension of SMARTS that describes chemical reactions
4//! as atom-mapped reactant>>product transforms. This module parses
5//! SMIRKS patterns and applies them to molecular graphs.
6
7use crate::graph::Molecule;
8use crate::smarts::{parse_smarts, substruct_match};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// A parsed SMIRKS reaction transform.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SmirksTransform {
15    /// Reactant SMARTS pattern(s).
16    pub reactant_smarts: Vec<String>,
17    /// Product SMARTS pattern(s).
18    pub product_smarts: Vec<String>,
19    /// Atom map: reactant_atom_map_num → product_atom_map_num.
20    pub atom_map: HashMap<usize, usize>,
21    /// Bond changes: (atom_map1, atom_map2, old_order, new_order).
22    pub bond_changes: Vec<BondChange>,
23    /// Original SMIRKS string.
24    pub smirks: String,
25}
26
27/// A bond change specified by a SMIRKS transform.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct BondChange {
30    /// Atom map number of first atom.
31    pub atom1_map: usize,
32    /// Atom map number of second atom.
33    pub atom2_map: usize,
34    /// Bond order in reactant (None = bond doesn't exist).
35    pub old_order: Option<String>,
36    /// Bond order in product (None = bond broken).
37    pub new_order: Option<String>,
38}
39
40/// Result of applying a SMIRKS transform.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct SmirksResult {
43    /// Product SMILES strings.
44    pub products: Vec<String>,
45    /// Atom mapping from reactant to product indices.
46    pub atom_mapping: HashMap<usize, usize>,
47    /// Number of transforms applied.
48    pub n_transforms: usize,
49    /// Whether the transform was successfully applied.
50    pub success: bool,
51    /// Error or warning messages.
52    pub messages: Vec<String>,
53}
54
55/// Parse a SMIRKS reaction string.
56///
57/// Format: `reactant_smarts>>product_smarts`
58/// Atom maps: `[C:1]`, `[N:2]`, etc.
59///
60/// # Example
61/// ```ignore
62/// let transform = parse_smirks("[C:1](=O)[OH:2]>>[C:1](=O)[O-:2]").unwrap();
63/// ```
64pub fn parse_smirks(smirks: &str) -> Result<SmirksTransform, String> {
65    let parts: Vec<&str> = smirks.split(">>").collect();
66    if parts.len() != 2 {
67        return Err("SMIRKS must contain exactly one '>>' separator".to_string());
68    }
69
70    let reactant_part = parts[0].trim();
71    let product_part = parts[1].trim();
72
73    if reactant_part.is_empty() || product_part.is_empty() {
74        return Err("SMIRKS reactant and product parts must be non-empty".to_string());
75    }
76
77    // Split by '.' for multi-component reactions
78    let reactant_smarts: Vec<String> = reactant_part.split('.').map(|s| s.to_string()).collect();
79    let product_smarts: Vec<String> = product_part.split('.').map(|s| s.to_string()).collect();
80
81    // Extract atom maps from both sides
82    let reactant_maps = extract_atom_maps(reactant_part)?;
83    let product_maps = extract_atom_maps(product_part)?;
84
85    // Build the atom map (reactant map num → product map num)
86    let mut atom_map = HashMap::new();
87    for map_num in reactant_maps.keys() {
88        if product_maps.contains_key(map_num) {
89            atom_map.insert(*map_num, *map_num);
90        }
91    }
92
93    // Validate atom map bijectivity: each mapped atom in reactant must appear
94    // exactly once in product and vice versa
95    let mapped_in_reactant: std::collections::HashSet<usize> = atom_map.keys().copied().collect();
96    let mapped_in_product: std::collections::HashSet<usize> = atom_map.values().copied().collect();
97    if mapped_in_reactant != mapped_in_product {
98        return Err(format!(
99            "SMIRKS atom maps are not bijective: reactant maps {:?} vs product maps {:?}",
100            mapped_in_reactant, mapped_in_product
101        ));
102    }
103
104    // Detect bond changes
105    let bond_changes =
106        detect_bond_changes(reactant_part, product_part, &reactant_maps, &product_maps);
107
108    Ok(SmirksTransform {
109        reactant_smarts,
110        product_smarts,
111        atom_map,
112        bond_changes,
113        smirks: smirks.to_string(),
114    })
115}
116
117/// Apply a SMIRKS transform to a molecule (represented as SMILES).
118///
119/// Returns the product SMILES if the reactant pattern matches.
120pub fn apply_smirks(smirks: &str, smiles: &str) -> Result<SmirksResult, String> {
121    let transform = parse_smirks(smirks)?;
122
123    // Parse the input molecule
124    if transform.reactant_smarts.len() > 1 || transform.product_smarts.len() > 1 {
125        return Ok(SmirksResult {
126            products: vec![],
127            atom_mapping: HashMap::new(),
128            n_transforms: 0,
129            success: false,
130            messages: vec![
131                "Multi-component SMIRKS are not supported by apply_smirks because Molecule::from_smiles retains only the largest fragment".to_string(),
132            ],
133        });
134    }
135
136    let mol = Molecule::from_smiles(smiles)?;
137
138    // Try to match the reactant pattern using substructure matching
139    let matches = match_smarts_pattern(&mol, &transform.reactant_smarts[0])?;
140
141    if matches.is_empty() {
142        return Ok(SmirksResult {
143            products: vec![],
144            atom_mapping: HashMap::new(),
145            n_transforms: 0,
146            success: false,
147            messages: vec!["No match found for reactant pattern".to_string()],
148        });
149    }
150
151    // Apply the first match
152    let atom_mapping = &matches[0];
153
154    // For now, return a simplified result indicating the transform was recognized
155    Ok(SmirksResult {
156        products: transform.product_smarts.clone(),
157        atom_mapping: atom_mapping.clone(),
158        n_transforms: 1,
159        success: true,
160        messages: vec![format!(
161            "Transform applied: {} atoms mapped",
162            atom_mapping.len()
163        )],
164    })
165}
166
/// Extract atom map numbers from a SMARTS/SMIRKS string.
/// Returns map_number → pattern_atom_index.
///
/// The pattern is scanned left to right while a running atom index is kept.
/// The index is a lightweight approximation rather than a full parse:
/// every `[...]` group counts as one atom, and outside brackets each ASCII
/// uppercase letter and each of `c`, `n`, `o`, `s` counts as one atom.
///
/// Inside a bracket group, the text after the last `:` is taken as the map
/// number if it parses as an unsigned integer; otherwise the atom is treated
/// as unmapped.
///
/// # Errors
/// Returns `Err` when the same map number occurs twice, or when a `[` has no
/// closing `]`.
fn extract_atom_maps(pattern: &str) -> Result<HashMap<usize, usize>, String> {
    let mut maps = HashMap::new();
    let bytes = pattern.as_bytes();
    let mut pos = 0;
    let mut atom_idx = 0;

    while pos < bytes.len() {
        let byte = bytes[pos];
        if byte == b'[' {
            // Locate the closing bracket. An unterminated bracket is reported
            // instead of slicing past the end of the content.
            let close = match bytes[pos..].iter().position(|&b| b == b']') {
                Some(offset) => pos + offset,
                None => {
                    return Err(format!(
                        "unterminated '[' in pattern '{}'",
                        pattern
                    ))
                }
            };

            // '[' and ']' are single-byte ASCII, so both offsets are valid
            // char boundaries even if the content holds non-ASCII text.
            let inner = &pattern[pos + 1..close];

            // Look for :N atom map
            if let Some((_, map_str)) = inner.rsplit_once(':') {
                if let Ok(map_num) = map_str.parse::<usize>() {
                    if maps.insert(map_num, atom_idx).is_some() {
                        return Err(format!(
                            "duplicate atom map :{} in pattern '{}'",
                            map_num, pattern
                        ));
                    }
                }
            }
            atom_idx += 1;
            pos = close + 1;
        } else {
            // Bare atom outside brackets: uppercase letter or aromatic c/n/o/s.
            if byte.is_ascii_uppercase() || matches!(byte, b'c' | b'n' | b'o' | b's') {
                atom_idx += 1;
            }
            pos += 1;
        }
    }

    Ok(maps)
}
210
211/// Detect bond changes between reactant and product patterns.
212fn detect_bond_changes(
213    _reactant: &str,
214    _product: &str,
215    reactant_maps: &HashMap<usize, usize>,
216    product_maps: &HashMap<usize, usize>,
217) -> Vec<BondChange> {
218    let mut changes = Vec::new();
219
220    // Atoms that appear in reactant but not product → bonds broken
221    for map_num in reactant_maps.keys() {
222        if !product_maps.contains_key(map_num) {
223            changes.push(BondChange {
224                atom1_map: *map_num,
225                atom2_map: 0,
226                old_order: Some("SINGLE".to_string()),
227                new_order: None,
228            });
229        }
230    }
231
232    // Atoms that appear in product but not reactant → bonds formed
233    for map_num in product_maps.keys() {
234        if !reactant_maps.contains_key(map_num) {
235            changes.push(BondChange {
236                atom1_map: *map_num,
237                atom2_map: 0,
238                old_order: None,
239                new_order: Some("SINGLE".to_string()),
240            });
241        }
242    }
243
244    changes
245}
246
247/// Substructure matching for SMARTS patterns.
248/// Returns list of atom-map-number → molecule-index mappings.
249fn match_smarts_pattern(
250    mol: &Molecule,
251    pattern: &str,
252) -> Result<Vec<HashMap<usize, usize>>, String> {
253    let parsed = parse_smarts(pattern)?;
254    let mapped_atoms: Vec<(usize, usize)> = parsed
255        .atoms
256        .iter()
257        .enumerate()
258        .filter_map(|(idx, atom)| atom.map_idx.map(|map_idx| (idx, map_idx as usize)))
259        .collect();
260
261    Ok(substruct_match(mol, &parsed)
262        .into_iter()
263        .map(|matched_atoms| {
264            mapped_atoms
265                .iter()
266                .map(|(pattern_idx, map_num)| (*map_num, matched_atoms[*pattern_idx]))
267                .collect()
268        })
269        .collect())
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Carboxylic acid deprotonation, shared by the parsing and application tests.
    const DEPROTONATION: &str = "[C:1](=O)[OH:2]>>[C:1](=O)[O-:2]";

    /// A well-formed single-component SMIRKS parses into one reactant, one
    /// product, and a non-empty atom map.
    #[test]
    fn test_parse_smirks_basic() {
        let t = parse_smirks(DEPROTONATION).expect("well-formed SMIRKS should parse");
        assert_eq!(t.reactant_smarts.len(), 1);
        assert_eq!(t.product_smarts.len(), 1);
        assert!(!t.atom_map.is_empty());
    }

    /// A missing separator and empty sides are both rejected.
    #[test]
    fn test_parse_smirks_invalid() {
        for bad in ["no_separator", ">>"] {
            assert!(parse_smirks(bad).is_err());
        }
    }

    /// Both map numbers of a two-atom-mapped pattern are found.
    #[test]
    fn test_extract_atom_maps() {
        let maps = extract_atom_maps("[C:1](=O)[OH:2]").unwrap();
        for map_num in [1usize, 2] {
            assert!(maps.contains_key(&map_num));
        }
    }

    /// Reusing a map number within one pattern is an error.
    #[test]
    fn test_extract_atom_maps_rejects_duplicates() {
        let err = extract_atom_maps("[C:1][O:1]").unwrap_err();
        assert!(err.contains("duplicate atom map"));
    }

    /// Acetic acid matches the carboxylic acid pattern with both mapped atoms.
    #[test]
    fn test_apply_smirks() {
        let result = apply_smirks(DEPROTONATION, "CC(=O)O").unwrap();
        assert!(result.success);
        assert_eq!(result.n_transforms, 1);
        assert_eq!(result.atom_mapping.len(), 2);
    }

    /// A nitrogen pattern must not "match" a molecule without nitrogen.
    #[test]
    fn test_apply_smirks_requires_real_match() {
        let result = apply_smirks("[N:1]>>[N:1]", "CCO").unwrap();
        assert!(!result.success);
        assert_eq!(result.n_transforms, 0);
    }

    /// Multi-component transforms are reported as unsupported, not applied.
    #[test]
    fn test_apply_smirks_rejects_multicomponent_transform() {
        let result = apply_smirks("[O:1].[Na+:2]>>[O:1][Na+:2]", "CC(=O)O").unwrap();
        assert!(!result.success);
        assert!(result.messages[0].contains("Multi-component SMIRKS"));
    }
}