1use crate::graph::Molecule;
8use crate::smarts::{parse_smarts, substruct_match};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// A SMIRKS reaction transform as produced by [`parse_smirks`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmirksTransform {
    /// Reactant-side SMARTS, one entry per `.`-separated component.
    pub reactant_smarts: Vec<String>,
    /// Product-side SMARTS, one entry per `.`-separated component.
    pub product_smarts: Vec<String>,
    /// Atom-map numbers that occur on both sides; `parse_smirks` maps each
    /// such number to itself.
    pub atom_map: HashMap<usize, usize>,
    /// Output of `detect_bond_changes`. NOTE(review): currently one entry per
    /// atom-map number that appears on only one side, not a real bond diff.
    pub bond_changes: Vec<BondChange>,
    /// The SMIRKS string exactly as passed to `parse_smirks` (not trimmed).
    pub smirks: String,
}
26
/// One entry of [`SmirksTransform::bond_changes`].
///
/// NOTE(review): `detect_bond_changes` currently emits one of these per
/// atom-map number that appears on only one side of the SMIRKS, with
/// `atom2_map = 0` and the literal `"SINGLE"` on the side where the number
/// exists. It does not yet describe actual bonds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BondChange {
    /// Atom-map number of the first atom.
    pub atom1_map: usize,
    /// Atom-map number of the second atom. `detect_bond_changes` always sets
    /// this to 0 (presumably "no partner atom" — confirm with consumers).
    pub atom2_map: usize,
    /// Intended bond order before the reaction (`None` = no bond); currently
    /// only ever `Some("SINGLE")` or `None`.
    pub old_order: Option<String>,
    /// Intended bond order after the reaction (`None` = no bond); currently
    /// only ever `Some("SINGLE")` or `None`.
    pub new_order: Option<String>,
}
39
/// Outcome of [`apply_smirks`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmirksResult {
    /// Product patterns. NOTE(review): `apply_smirks` currently fills this with
    /// the product-side SMARTS copied from the SMIRKS, not transformed SMILES.
    pub products: Vec<String>,
    /// Atom-map number -> atom matched in the input molecule for the first
    /// match (presumably a `Molecule` atom index — confirm in `substruct_match`).
    pub atom_mapping: HashMap<usize, usize>,
    /// Number of transforms applied; `apply_smirks` reports 0 or 1.
    pub n_transforms: usize,
    /// `true` when the reactant pattern matched and a result was produced.
    pub success: bool,
    /// Human-readable status or failure messages.
    pub messages: Vec<String>,
}
54
55pub fn parse_smirks(smirks: &str) -> Result<SmirksTransform, String> {
65 let parts: Vec<&str> = smirks.split(">>").collect();
66 if parts.len() != 2 {
67 return Err("SMIRKS must contain exactly one '>>' separator".to_string());
68 }
69
70 let reactant_part = parts[0].trim();
71 let product_part = parts[1].trim();
72
73 if reactant_part.is_empty() || product_part.is_empty() {
74 return Err("SMIRKS reactant and product parts must be non-empty".to_string());
75 }
76
77 let reactant_smarts: Vec<String> = reactant_part.split('.').map(|s| s.to_string()).collect();
79 let product_smarts: Vec<String> = product_part.split('.').map(|s| s.to_string()).collect();
80
81 let reactant_maps = extract_atom_maps(reactant_part)?;
83 let product_maps = extract_atom_maps(product_part)?;
84
85 let mut atom_map = HashMap::new();
87 for map_num in reactant_maps.keys() {
88 if product_maps.contains_key(map_num) {
89 atom_map.insert(*map_num, *map_num);
90 }
91 }
92
93 let mapped_in_reactant: std::collections::HashSet<usize> = atom_map.keys().copied().collect();
96 let mapped_in_product: std::collections::HashSet<usize> = atom_map.values().copied().collect();
97 if mapped_in_reactant != mapped_in_product {
98 return Err(format!(
99 "SMIRKS atom maps are not bijective: reactant maps {:?} vs product maps {:?}",
100 mapped_in_reactant, mapped_in_product
101 ));
102 }
103
104 let bond_changes =
106 detect_bond_changes(reactant_part, product_part, &reactant_maps, &product_maps);
107
108 Ok(SmirksTransform {
109 reactant_smarts,
110 product_smarts,
111 atom_map,
112 bond_changes,
113 smirks: smirks.to_string(),
114 })
115}
116
117pub fn apply_smirks(smirks: &str, smiles: &str) -> Result<SmirksResult, String> {
121 let transform = parse_smirks(smirks)?;
122
123 if transform.reactant_smarts.len() > 1 || transform.product_smarts.len() > 1 {
125 return Ok(SmirksResult {
126 products: vec![],
127 atom_mapping: HashMap::new(),
128 n_transforms: 0,
129 success: false,
130 messages: vec![
131 "Multi-component SMIRKS are not supported by apply_smirks because Molecule::from_smiles retains only the largest fragment".to_string(),
132 ],
133 });
134 }
135
136 let mol = Molecule::from_smiles(smiles)?;
137
138 let matches = match_smarts_pattern(&mol, &transform.reactant_smarts[0])?;
140
141 if matches.is_empty() {
142 return Ok(SmirksResult {
143 products: vec![],
144 atom_mapping: HashMap::new(),
145 n_transforms: 0,
146 success: false,
147 messages: vec!["No match found for reactant pattern".to_string()],
148 });
149 }
150
151 let atom_mapping = &matches[0];
153
154 Ok(SmirksResult {
156 products: transform.product_smarts.clone(),
157 atom_mapping: atom_mapping.clone(),
158 n_transforms: 1,
159 success: true,
160 messages: vec![format!(
161 "Transform applied: {} atoms mapped",
162 atom_mapping.len()
163 )],
164 })
165}
166
/// Collects the atom-map numbers (`:N` inside bracket atoms) of one SMIRKS side.
///
/// Returns a map from map number to the 0-based position of its atom. Atoms
/// are counted in order of appearance:
/// - each bracket atom `[...]`, including any brackets nested inside a
///   recursive SMARTS `$()`, which belong to the enclosing atom;
/// - each uppercase letter outside brackets, so `Cl`/`Br` count once;
/// - each aromatic organic-subset letter `b c n o p s`;
/// - the aromatic wildcard `a` and the wildcard `*`.
///
/// Map number 0 is treated as "no map", following the OpenSMILES convention
/// that atom class 0 is the default.
///
/// # Errors
///
/// Returns `Err` if a `[` is never closed, or if a nonzero map number occurs
/// more than once in `pattern`.
fn extract_atom_maps(pattern: &str) -> Result<HashMap<usize, usize>, String> {
    let bytes = pattern.as_bytes();
    let mut maps = HashMap::new();
    let mut atom_idx = 0;
    let mut pos = 0;

    while pos < bytes.len() {
        let byte = bytes[pos];
        if byte == b'[' {
            // Match the closing `]` with depth tracking so that recursive SMARTS
            // such as `[$([C:2]=O):1]` are read as one atom carrying map 1.
            let close = smarts_bracket_end(bytes, pos).ok_or_else(|| {
                format!("unterminated '[' at byte {} in pattern '{}'", pos, pattern)
            })?;
            // `pos` and `close` index ASCII brackets, so both slice ends are
            // valid char boundaries even when the pattern contains non-ASCII text.
            if let Some(map_num) = smarts_bracket_map(&pattern[pos + 1..close]) {
                if maps.insert(map_num, atom_idx).is_some() {
                    return Err(format!(
                        "duplicate atom map :{} in pattern '{}'",
                        map_num, pattern
                    ));
                }
            }
            atom_idx += 1;
            pos = close;
        } else if byte.is_ascii_uppercase()
            || matches!(byte, b'*' | b'a' | b'b' | b'c' | b'n' | b'o' | b'p' | b's')
        {
            // Unbracketed atom: an organic-subset element, an aromatic atom or a
            // wildcard. The second letters of `Cl`/`Br` are not in this set, so
            // they are not counted.
            atom_idx += 1;
        }
        pos += 1;
    }

    Ok(maps)
}

/// Returns the index of the `]` that closes the `[` at `bytes[open]`,
/// honouring nested brackets, or `None` if it is never closed.
fn smarts_bracket_end(bytes: &[u8], open: usize) -> Option<usize> {
    let mut depth = 0usize;
    for (idx, &byte) in bytes.iter().enumerate().skip(open) {
        match byte {
            b'[' => depth += 1,
            b']' => {
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    return Some(idx);
                }
            }
            _ => {}
        }
    }
    None
}

/// Parses the trailing `:N` map number of a bracket atom's interior (the text
/// between `[` and `]`). Returns `None` when there is no map, or when it is 0.
fn smarts_bracket_map(inner: &str) -> Option<usize> {
    let (_, digits) = inner.rsplit_once(':')?;
    digits.parse::<usize>().ok().filter(|&map_num| map_num != 0)
}
210
211fn detect_bond_changes(
213 _reactant: &str,
214 _product: &str,
215 reactant_maps: &HashMap<usize, usize>,
216 product_maps: &HashMap<usize, usize>,
217) -> Vec<BondChange> {
218 let mut changes = Vec::new();
219
220 for map_num in reactant_maps.keys() {
222 if !product_maps.contains_key(map_num) {
223 changes.push(BondChange {
224 atom1_map: *map_num,
225 atom2_map: 0,
226 old_order: Some("SINGLE".to_string()),
227 new_order: None,
228 });
229 }
230 }
231
232 for map_num in product_maps.keys() {
234 if !reactant_maps.contains_key(map_num) {
235 changes.push(BondChange {
236 atom1_map: *map_num,
237 atom2_map: 0,
238 old_order: None,
239 new_order: Some("SINGLE".to_string()),
240 });
241 }
242 }
243
244 changes
245}
246
247fn match_smarts_pattern(
250 mol: &Molecule,
251 pattern: &str,
252) -> Result<Vec<HashMap<usize, usize>>, String> {
253 let parsed = parse_smarts(pattern)?;
254 let mapped_atoms: Vec<(usize, usize)> = parsed
255 .atoms
256 .iter()
257 .enumerate()
258 .filter_map(|(idx, atom)| atom.map_idx.map(|map_idx| (idx, map_idx as usize)))
259 .collect();
260
261 Ok(substruct_match(mol, &parsed)
262 .into_iter()
263 .map(|matched_atoms| {
264 mapped_atoms
265 .iter()
266 .map(|(pattern_idx, map_num)| (*map_num, matched_atoms[*pattern_idx]))
267 .collect()
268 })
269 .collect())
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Carboxylic-acid deprotonation shared by the parse and apply tests.
    const ACID_TO_CARBOXYLATE: &str = "[C:1](=O)[OH:2]>>[C:1](=O)[O-:2]";

    #[test]
    fn test_parse_smirks_basic() {
        let transform = parse_smirks(ACID_TO_CARBOXYLATE).unwrap();
        assert_eq!(1, transform.reactant_smarts.len());
        assert_eq!(1, transform.product_smarts.len());
        assert!(!transform.atom_map.is_empty());
    }

    #[test]
    fn test_parse_smirks_invalid() {
        for bad in ["no_separator", ">>"] {
            assert!(parse_smirks(bad).is_err());
        }
    }

    #[test]
    fn test_extract_atom_maps() {
        let maps = extract_atom_maps("[C:1](=O)[OH:2]").unwrap();
        assert!([1, 2].iter().all(|key| maps.contains_key(key)));
    }

    #[test]
    fn test_extract_atom_maps_rejects_duplicates() {
        let message = extract_atom_maps("[C:1][O:1]").unwrap_err();
        assert!(message.contains("duplicate atom map"));
    }

    #[test]
    fn test_apply_smirks() {
        let outcome = apply_smirks(ACID_TO_CARBOXYLATE, "CC(=O)O").unwrap();
        assert!(outcome.success);
        assert_eq!(1, outcome.n_transforms);
        assert_eq!(2, outcome.atom_mapping.len());
    }

    #[test]
    fn test_apply_smirks_requires_real_match() {
        let outcome = apply_smirks("[N:1]>>[N:1]", "CCO").unwrap();
        assert!(!outcome.success);
        assert_eq!(0, outcome.n_transforms);
    }

    #[test]
    fn test_apply_smirks_rejects_multicomponent_transform() {
        let outcome = apply_smirks("[O:1].[Na+:2]>>[O:1][Na+:2]", "CC(=O)O").unwrap();
        assert!(!outcome.success);
        assert!(outcome.messages[0].contains("Multi-component SMIRKS"));
    }
}