1use crate::graph::{BondOrder, Molecule};
9use petgraph::graph::NodeIndex;
10use petgraph::visit::EdgeRef;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeSet;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct HoseCode {
17 pub atom_index: usize,
19 pub element: u8,
21 pub spheres: Vec<String>,
23 pub full_code: String,
25}
26
27fn bond_symbol(order: BondOrder) -> &'static str {
28 match order {
29 BondOrder::Single => "",
30 BondOrder::Double => "=",
31 BondOrder::Triple => "#",
32 BondOrder::Aromatic => "*",
33 BondOrder::Unknown => "",
34 }
35}
36
/// Maps an atomic number to its element symbol for the elements this module
/// recognises; every other atomic number is rendered as `"X"`.
fn element_symbol(z: u8) -> &'static str {
    const KNOWN: [(u8, &str); 12] = [
        (1, "H"),
        (5, "B"),
        (6, "C"),
        (7, "N"),
        (8, "O"),
        (9, "F"),
        (14, "Si"),
        (15, "P"),
        (16, "S"),
        (17, "Cl"),
        (35, "Br"),
        (53, "I"),
    ];

    KNOWN
        .iter()
        .find(|&&(number, _)| number == z)
        .map_or("X", |&(_, symbol)| symbol)
}
54
55pub fn generate_hose_codes(mol: &Molecule, max_radius: usize) -> Vec<HoseCode> {
60 let n = mol.graph.node_count();
61 let mut codes = Vec::with_capacity(n);
62
63 for atom_idx in 0..n {
64 let center = NodeIndex::new(atom_idx);
65 let center_element = mol.graph[center].element;
66 let center_sym = element_symbol(center_element);
67
68 let mut spheres = Vec::with_capacity(max_radius + 1);
69 spheres.push(center_sym.to_string());
70
71 let mut visited = vec![false; n];
73 visited[atom_idx] = true;
74 let mut current_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
75
76 for edge in mol.graph.edges(center) {
78 let neighbor = if edge.source() == center {
79 edge.target()
80 } else {
81 edge.source()
82 };
83 current_frontier.push((neighbor, edge.weight().order));
84 }
85
86 for _radius in 1..=max_radius {
87 if current_frontier.is_empty() {
88 spheres.push(String::new());
89 continue;
90 }
91
92 let mut sphere_parts: BTreeSet<String> = BTreeSet::new();
94 let mut next_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
95
96 for (node, bond_order) in ¤t_frontier {
97 let node_idx = node.index();
98 if visited[node_idx] {
99 continue;
100 }
101 visited[node_idx] = true;
102
103 let elem = mol.graph[*node].element;
104 let sym = element_symbol(elem);
105 let bond_sym = bond_symbol(*bond_order);
106 sphere_parts.insert(format!("{}{}", bond_sym, sym));
107
108 for edge in mol.graph.edges(*node) {
110 let next = if edge.source() == *node {
111 edge.target()
112 } else {
113 edge.source()
114 };
115 if !visited[next.index()] {
116 next_frontier.push((next, edge.weight().order));
117 }
118 }
119 }
120
121 spheres.push(sphere_parts.into_iter().collect::<Vec<_>>().join(","));
122 current_frontier = next_frontier;
123 }
124
125 let full_code = format!("{}/{}", spheres[0], spheres[1..].join("/"));
126
127 codes.push(HoseCode {
128 atom_index: atom_idx,
129 element: center_element,
130 spheres,
131 full_code,
132 });
133 }
134
135 codes
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct HoseShiftLookup {
147 pub atom_index: usize,
149 pub element: u8,
151 pub shift_ppm: f64,
153 pub matched_hose: String,
155 pub match_radius: usize,
157 pub confidence: f64,
159}
160
/// Built-in ¹H lookup table of `(HOSE pattern, shift in ppm)` pairs, in the
/// order [`predict_shift_from_hose`] scans them.
///
/// NOTE(review): the keys appear to describe the heavy atom that carries the
/// proton (e.g. `C/H,H,H,C` at 0.90 ppm), not the proton itself — confirm.
fn h1_hose_database() -> Vec<(&'static str, f64)> {
    const TABLE: &[(&str, f64)] = &[
        ("C/H,H,H,C", 0.90),
        ("C/H,H,C,C", 1.30),
        ("C/H,C,C,C", 1.50),
        ("C/H,H,H,O", 3.40),
        ("C/H,H,O", 3.50),
        ("C/H,H,H,N", 2.30),
        ("C/H,H,N", 2.60),
        ("C/H,H,H,=O", 2.10),
        ("C/H,H,=O", 2.50),
        ("C/H,H,H,*C", 2.30),
        ("C/*C,*C,H", 7.27),
        ("C/*C,*C,*C", 7.50),
        ("C/*C,*N,H", 7.80),
        ("C/=C,H,H", 5.25),
        ("C/=C,H,C", 5.40),
        ("C/=C,=C,H", 6.30),
        ("C/#C,H", 2.50),
        ("C/=O,H,C", 9.50),
        ("C/=O,H,H", 9.60),
        ("O/H,C", 2.50),
        ("O/H,*C", 5.50),
        ("N/H,H,C", 1.50),
        ("N/H,C,C", 2.20),
    ];
    TABLE.to_vec()
}
197
/// Built-in ¹³C lookup table of `(HOSE pattern, shift in ppm)` pairs, in the
/// order [`predict_shift_from_hose`] scans them. Every key has a carbon centre.
fn c13_hose_database() -> Vec<(&'static str, f64)> {
    const TABLE: &[(&str, f64)] = &[
        ("C/H,H,H,C", 15.0),
        ("C/H,H,C,C", 25.0),
        ("C/H,C,C,C", 35.0),
        ("C/C,C,C,C", 40.0),
        ("C/H,H,H,O", 55.0),
        ("C/H,H,O,C", 65.0),
        ("C/H,H,H,N", 32.0),
        ("C/H,H,N", 45.0),
        ("C/H,H,H,=O", 30.0),
        ("C/*C,*C,H", 128.0),
        ("C/*C,*C,C", 137.0),
        ("C/*C,*C,O", 155.0),
        ("C/*C,*C,N", 148.0),
        ("C/*C,*C,F", 163.0),
        ("C/*C,*C,Cl", 134.0),
        ("C/=C,H,H", 115.0),
        ("C/=C,H,C", 130.0),
        ("C/=C,C,C", 140.0),
        ("C/=O,O,C", 175.0),
        ("C/=O,N,C", 170.0),
        ("C/=O,C,C", 205.0),
        ("C/=O,H,C", 200.0),
        ("C/#C,H", 70.0),
        ("C/#C,C", 85.0),
    ];
    TABLE.to_vec()
}
232
233pub fn predict_shift_from_hose(hose_code: &HoseCode, nucleus: u8) -> Option<HoseShiftLookup> {
238 let database = match nucleus {
239 1 => h1_hose_database(),
240 6 => c13_hose_database(),
241 _ => return None,
242 };
243
244 for radius in (1..=hose_code.spheres.len().saturating_sub(1)).rev() {
246 let prefix = format!(
248 "{}/{}",
249 hose_code.spheres[0],
250 hose_code.spheres[1..=radius].join("/")
251 );
252
253 for &(pattern, shift) in &database {
254 if prefix.contains(pattern)
255 || pattern.contains(&prefix)
256 || fuzzy_hose_match(&prefix, pattern)
257 {
258 return Some(HoseShiftLookup {
259 atom_index: hose_code.atom_index,
260 element: nucleus,
261 shift_ppm: shift,
262 matched_hose: pattern.to_string(),
263 match_radius: radius,
264 confidence: 0.5 + 0.1 * radius as f64,
265 });
266 }
267 }
268 }
269
270 None
271}
272
/// Loose comparison of the first coordination sphere of two HOSE strings.
///
/// Both strings are split on `/`. They match only if all of these hold:
/// - each has at least a centre and a first sphere,
/// - the centre symbols are equal, and
/// - the *sets* of comma-separated first-sphere entries (duplicates collapsed)
///   have a Jaccard similarity strictly greater than 0.5.
///
/// Deeper spheres are ignored.
fn fuzzy_hose_match(hose: &str, pattern: &str) -> bool {
    let mut hose_spheres = hose.split('/');
    let mut pattern_spheres = pattern.split('/');

    let same_centre = hose_spheres.next() == pattern_spheres.next();
    let (hose_first, pattern_first) = match (hose_spheres.next(), pattern_spheres.next()) {
        (Some(h), Some(p)) => (h, p),
        _ => return false,
    };
    if !same_centre {
        return false;
    }

    let hose_set: BTreeSet<&str> = hose_first.split(',').collect();
    let pattern_set: BTreeSet<&str> = pattern_first.split(',').collect();

    let shared = hose_set.intersection(&pattern_set).count();
    let total = hose_set.union(&pattern_set).count();

    // shared / total > 1/2, done in integers; also false when `total` is 0.
    2 * shared > total
}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// One code per node, in node order, with the centre plus `max_radius`
    /// spheres.
    #[test]
    fn test_hose_codes_ethanol() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        let codes = generate_hose_codes(&mol, 3);

        assert_eq!(codes.len(), mol.graph.node_count());

        for (idx, code) in codes.iter().enumerate() {
            assert_eq!(code.atom_index, idx);
            assert_eq!(code.spheres.len(), 4);
            assert!(!code.spheres[0].is_empty());
            assert!(!code.full_code.is_empty());
        }
    }

    #[test]
    fn test_hose_codes_benzene() {
        let mol = Molecule::from_smiles("c1ccccc1").unwrap();
        let codes = generate_hose_codes(&mol, 3);

        assert_eq!(codes.len(), mol.graph.node_count());

        let carbon_codes: Vec<&HoseCode> = codes.iter().filter(|c| c.element == 6).collect();
        assert!(!carbon_codes.is_empty());
    }

    #[test]
    fn test_hose_codes_deterministic() {
        let mol = Molecule::from_smiles("CC(=O)O").unwrap();
        let codes1 = generate_hose_codes(&mol, 3);
        let codes2 = generate_hose_codes(&mol, 3);

        // `zip` stops at the shorter input, so compare lengths explicitly.
        assert_eq!(codes1.len(), codes2.len());
        for (c1, c2) in codes1.iter().zip(codes2.iter()) {
            assert_eq!(c1.full_code, c2.full_code);
        }
    }

    #[test]
    fn test_predict_shift_unsupported_nucleus() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        for code in generate_hose_codes(&mol, 3) {
            assert!(predict_shift_from_hose(&code, 7).is_none());
        }
    }

    #[test]
    fn test_predict_shift_result_fields() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        for code in generate_hose_codes(&mol, 3) {
            if let Some(hit) = predict_shift_from_hose(&code, 6) {
                assert_eq!(hit.atom_index, code.atom_index);
                assert_eq!(hit.element, 6);
                assert!((1..=3).contains(&hit.match_radius));
                assert!(hit.confidence > 0.0 && hit.confidence <= 1.0);
            }
        }
    }
}