1use crate::graph::{BondOrder, Molecule};
9use petgraph::graph::NodeIndex;
10use petgraph::visit::EdgeRef;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeSet;
13
/// Environment code for a single atom, as produced by `generate_hose_codes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoseCode {
    /// Index of the described atom (its node index in the molecule graph).
    pub atom_index: usize,
    /// Atomic number of the described atom.
    pub element: u8,
    /// `spheres[0]` is the element symbol of the atom itself; each following
    /// entry holds the comma-separated, sorted tokens (bond prefix + element
    /// symbol) of the atoms first reached at that bond distance. An entry is
    /// empty when no new atom is reached at that distance.
    pub spheres: Vec<String>,
    /// `spheres[0]`, a `/`, then the remaining spheres joined with `/`.
    pub full_code: String,
}
26
27fn bond_symbol(order: BondOrder) -> &'static str {
28 match order {
29 BondOrder::Single => "",
30 BondOrder::Double => "=",
31 BondOrder::Triple => "#",
32 BondOrder::Aromatic => "*",
33 BondOrder::Unknown => "",
34 }
35}
36
/// Maps an atomic number to its element symbol.
///
/// All elements from hydrogen (1) to oganesson (118) are covered, so distinct
/// elements never share a symbol in generated codes. Atomic numbers outside
/// that range (including 0) yield the placeholder `"X"`.
fn element_symbol(z: u8) -> &'static str {
    // Indexed by `z - 1`; ten symbols per row so a row starts at 1, 11, 21, ...
    const SYMBOLS: [&str; 118] = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
        "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
        "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
        "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
        "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
        "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
        "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
        "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
        "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
        "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
        "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
        "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
    ];

    // `checked_sub` rejects z == 0; `get` rejects z > 118.
    usize::from(z)
        .checked_sub(1)
        .and_then(|idx| SYMBOLS.get(idx))
        .copied()
        .unwrap_or("X")
}
100
101pub fn generate_hose_codes(mol: &Molecule, max_radius: usize) -> Vec<HoseCode> {
106 let n = mol.graph.node_count();
107 let mut codes = Vec::with_capacity(n);
108
109 for atom_idx in 0..n {
110 let center = NodeIndex::new(atom_idx);
111 let center_element = mol.graph[center].element;
112 let center_sym = element_symbol(center_element);
113
114 let mut spheres = Vec::with_capacity(max_radius + 1);
115 spheres.push(center_sym.to_string());
116
117 let mut visited = vec![false; n];
119 visited[atom_idx] = true;
120 let mut current_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
121
122 for edge in mol.graph.edges(center) {
124 let neighbor = if edge.source() == center {
125 edge.target()
126 } else {
127 edge.source()
128 };
129 current_frontier.push((neighbor, edge.weight().order));
130 }
131
132 for _radius in 1..=max_radius {
133 if current_frontier.is_empty() {
134 spheres.push(String::new());
135 continue;
136 }
137
138 let mut sphere_parts: BTreeSet<String> = BTreeSet::new();
140 let mut next_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
141
142 for (node, bond_order) in ¤t_frontier {
143 let node_idx = node.index();
144 if visited[node_idx] {
145 continue;
146 }
147 visited[node_idx] = true;
148
149 let elem = mol.graph[*node].element;
150 let sym = element_symbol(elem);
151 let bond_sym = bond_symbol(*bond_order);
152 sphere_parts.insert(format!("{}{}", bond_sym, sym));
153
154 for edge in mol.graph.edges(*node) {
156 let next = if edge.source() == *node {
157 edge.target()
158 } else {
159 edge.source()
160 };
161 if !visited[next.index()] {
162 next_frontier.push((next, edge.weight().order));
163 }
164 }
165 }
166
167 spheres.push(sphere_parts.into_iter().collect::<Vec<_>>().join(","));
168 current_frontier = next_frontier;
169 }
170
171 let full_code = format!("{}/{}", spheres[0], spheres[1..].join("/"));
172
173 codes.push(HoseCode {
174 atom_index: atom_idx,
175 element: center_element,
176 spheres,
177 full_code,
178 });
179 }
180
181 codes
182}
183
/// Result of a database lookup performed by `predict_shift_from_hose`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoseShiftLookup {
    /// Index of the atom the lookup was made for (copied from the HOSE code).
    pub atom_index: usize,
    /// The nucleus selector passed to the lookup (1 or 6); note this is the
    /// requested nucleus, not the element stored in the HOSE code.
    pub element: u8,
    /// Shift value of the matched database entry, in ppm.
    pub shift_ppm: f64,
    /// The database pattern that matched.
    pub matched_hose: String,
    /// Number of spheres included in the code prefix that produced the match.
    pub match_radius: usize,
    /// Heuristic score derived from `match_radius`.
    pub confidence: f64,
}
206
/// Returns the built-in pattern table used for nucleus `1` lookups.
///
/// Each entry pairs a first-sphere pattern with a shift value in ppm. Entry
/// order is significant: the lookup returns the first entry that matches.
fn h1_hose_database() -> Vec<(&'static str, f64)> {
    const ENTRIES: [(&str, f64); 23] = [
        ("C/H,H,H,C", 0.90),
        ("C/H,H,C,C", 1.30),
        ("C/H,C,C,C", 1.50),
        ("C/H,H,H,O", 3.40),
        ("C/H,H,O", 3.50),
        ("C/H,H,H,N", 2.30),
        ("C/H,H,N", 2.60),
        ("C/H,H,H,=O", 2.10),
        ("C/H,H,=O", 2.50),
        ("C/H,H,H,*C", 2.30),
        ("C/*C,*C,H", 7.27),
        ("C/*C,*C,*C", 7.50),
        ("C/*C,*N,H", 7.80),
        ("C/=C,H,H", 5.25),
        ("C/=C,H,C", 5.40),
        ("C/=C,=C,H", 6.30),
        ("C/#C,H", 2.50),
        ("C/=O,H,C", 9.50),
        ("C/=O,H,H", 9.60),
        ("O/H,C", 2.50),
        ("O/H,*C", 5.50),
        ("N/H,H,C", 1.50),
        ("N/H,C,C", 2.20),
    ];

    ENTRIES.to_vec()
}
243
/// Returns the built-in pattern table used for nucleus `6` lookups.
///
/// Each entry pairs a first-sphere pattern with a shift value in ppm. Entry
/// order is significant: the lookup returns the first entry that matches.
fn c13_hose_database() -> Vec<(&'static str, f64)> {
    const ENTRIES: [(&str, f64); 24] = [
        ("C/H,H,H,C", 15.0),
        ("C/H,H,C,C", 25.0),
        ("C/H,C,C,C", 35.0),
        ("C/C,C,C,C", 40.0),
        ("C/H,H,H,O", 55.0),
        ("C/H,H,O,C", 65.0),
        ("C/H,H,H,N", 32.0),
        ("C/H,H,N", 45.0),
        ("C/H,H,H,=O", 30.0),
        ("C/*C,*C,H", 128.0),
        ("C/*C,*C,C", 137.0),
        ("C/*C,*C,O", 155.0),
        ("C/*C,*C,N", 148.0),
        ("C/*C,*C,F", 163.0),
        ("C/*C,*C,Cl", 134.0),
        ("C/=C,H,H", 115.0),
        ("C/=C,H,C", 130.0),
        ("C/=C,C,C", 140.0),
        ("C/=O,O,C", 175.0),
        ("C/=O,N,C", 170.0),
        ("C/=O,C,C", 205.0),
        ("C/=O,H,C", 200.0),
        ("C/#C,H", 70.0),
        ("C/#C,C", 85.0),
    ];

    ENTRIES.to_vec()
}
278
279pub fn predict_shift_from_hose(hose_code: &HoseCode, nucleus: u8) -> Option<HoseShiftLookup> {
284 let database = match nucleus {
285 1 => h1_hose_database(),
286 6 => c13_hose_database(),
287 _ => return None,
288 };
289
290 for radius in (1..=hose_code.spheres.len().saturating_sub(1)).rev() {
292 let prefix = format!(
294 "{}/{}",
295 hose_code.spheres[0],
296 hose_code.spheres[1..=radius].join("/")
297 );
298
299 for &(pattern, shift) in &database {
300 if prefix.contains(pattern)
301 || pattern.contains(&prefix)
302 || fuzzy_hose_match(&prefix, pattern)
303 {
304 return Some(HoseShiftLookup {
305 atom_index: hose_code.atom_index,
306 element: nucleus,
307 shift_ppm: shift,
308 matched_hose: pattern.to_string(),
309 match_radius: radius,
310 confidence: 0.5 + 0.1 * radius as f64,
311 });
312 }
313 }
314 }
315
316 None
317}
318
/// Loosely compares the first sphere of a HOSE code with that of a pattern.
///
/// Both strings are split on `/`. The comparison fails when either has no
/// first sphere or when the center symbols differ. Otherwise the first
/// spheres are split on `,` into sets of distinct tokens, and the result is
/// `true` when the shared tokens make up strictly more than half of the
/// combined set (Jaccard index above 0.5). Deeper spheres are ignored.
fn fuzzy_hose_match(hose: &str, pattern: &str) -> bool {
    let mut hose_spheres = hose.split('/');
    let mut pattern_spheres = pattern.split('/');

    // Tuple elements are evaluated left to right: center first, then sphere 1.
    let (hose_center, hose_first) = match (hose_spheres.next(), hose_spheres.next()) {
        (Some(center), Some(first)) => (center, first),
        _ => return false,
    };
    let (pattern_center, pattern_first) = match (pattern_spheres.next(), pattern_spheres.next()) {
        (Some(center), Some(first)) => (center, first),
        _ => return false,
    };

    if hose_center != pattern_center {
        return false;
    }

    let hose_tokens: BTreeSet<&str> = hose_first.split(',').collect();
    let pattern_tokens: BTreeSet<&str> = pattern_first.split(',').collect();

    let shared = hose_tokens.intersection(&pattern_tokens).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let combined = hose_tokens.len() + pattern_tokens.len() - shared;

    // shared / combined > 0.5, written without floating point.
    2 * shared > combined
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Every atom of ethanol gets a code, in node order, with one sphere
    /// entry per requested radius plus the center.
    #[test]
    fn test_hose_codes_ethanol() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        let max_radius = 3;
        let codes = generate_hose_codes(&mol, max_radius);

        assert_eq!(codes.len(), mol.graph.node_count());

        for (idx, code) in codes.iter().enumerate() {
            assert_eq!(code.atom_index, idx);
            assert_eq!(code.spheres.len(), max_radius + 1);
            assert!(!code.spheres[0].is_empty());
            // The full code is the center symbol, a slash, then the spheres.
            assert!(code.full_code.starts_with(&format!("{}/", code.spheres[0])));
        }
    }

    /// Benzene yields one code per atom, six of them for carbon.
    #[test]
    fn test_hose_codes_benzene() {
        let mol = Molecule::from_smiles("c1ccccc1").unwrap();
        let codes = generate_hose_codes(&mol, 3);

        assert_eq!(codes.len(), mol.graph.node_count());

        let carbon_count = codes.iter().filter(|c| c.element == 6).count();
        assert_eq!(carbon_count, 6);
    }

    /// Two runs over the same molecule produce identical codes.
    #[test]
    fn test_hose_codes_deterministic() {
        let mol = Molecule::from_smiles("CC(=O)O").unwrap();
        let codes1 = generate_hose_codes(&mol, 3);
        let codes2 = generate_hose_codes(&mol, 3);

        // `zip` stops at the shorter input, so compare lengths explicitly.
        assert_eq!(codes1.len(), codes2.len());

        for (c1, c2) in codes1.iter().zip(codes2.iter()) {
            assert_eq!(c1.atom_index, c2.atom_index);
            assert_eq!(c1.spheres, c2.spheres);
            assert_eq!(c1.full_code, c2.full_code);
        }
    }
}