1use crate::graph::{BondOrder, Molecule};
7use petgraph::graph::NodeIndex;
8use petgraph::visit::EdgeRef;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeSet;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct HoseCode {
15 pub atom_index: usize,
17 pub element: u8,
19 pub spheres: Vec<String>,
21 pub full_code: String,
23}
24
25fn bond_symbol(order: BondOrder) -> &'static str {
26 match order {
27 BondOrder::Single => "",
28 BondOrder::Double => "=",
29 BondOrder::Triple => "#",
30 BondOrder::Aromatic => "*",
31 BondOrder::Unknown => "",
32 }
33}
34
/// Maps an atomic number to its element symbol.
///
/// Only the elements listed in the table below are recognised; any other
/// atomic number is rendered as the placeholder `"X"`.
fn element_symbol(z: u8) -> &'static str {
    /// Supported elements as `(atomic number, symbol)` pairs.
    const KNOWN_ELEMENTS: [(u8, &str); 12] = [
        (1, "H"),
        (5, "B"),
        (6, "C"),
        (7, "N"),
        (8, "O"),
        (9, "F"),
        (14, "Si"),
        (15, "P"),
        (16, "S"),
        (17, "Cl"),
        (35, "Br"),
        (53, "I"),
    ];

    KNOWN_ELEMENTS
        .iter()
        .find(|&&(number, _)| number == z)
        .map_or("X", |&(_, symbol)| symbol)
}
52
53pub fn generate_hose_codes(mol: &Molecule, max_radius: usize) -> Vec<HoseCode> {
58 let n = mol.graph.node_count();
59 let mut codes = Vec::with_capacity(n);
60
61 for atom_idx in 0..n {
62 let center = NodeIndex::new(atom_idx);
63 let center_element = mol.graph[center].element;
64 let center_sym = element_symbol(center_element);
65
66 let mut spheres = Vec::with_capacity(max_radius + 1);
67 spheres.push(center_sym.to_string());
68
69 let mut visited = vec![false; n];
71 visited[atom_idx] = true;
72 let mut current_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
73
74 for edge in mol.graph.edges(center) {
76 let neighbor = if edge.source() == center {
77 edge.target()
78 } else {
79 edge.source()
80 };
81 current_frontier.push((neighbor, edge.weight().order));
82 }
83
84 for _radius in 1..=max_radius {
85 if current_frontier.is_empty() {
86 spheres.push(String::new());
87 continue;
88 }
89
90 let mut sphere_parts: BTreeSet<String> = BTreeSet::new();
92 let mut next_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
93
94 for (node, bond_order) in ¤t_frontier {
95 let node_idx = node.index();
96 if visited[node_idx] {
97 continue;
98 }
99 visited[node_idx] = true;
100
101 let elem = mol.graph[*node].element;
102 let sym = element_symbol(elem);
103 let bond_sym = bond_symbol(*bond_order);
104 sphere_parts.insert(format!("{}{}", bond_sym, sym));
105
106 for edge in mol.graph.edges(*node) {
108 let next = if edge.source() == *node {
109 edge.target()
110 } else {
111 edge.source()
112 };
113 if !visited[next.index()] {
114 next_frontier.push((next, edge.weight().order));
115 }
116 }
117 }
118
119 spheres.push(sphere_parts.into_iter().collect::<Vec<_>>().join(","));
120 current_frontier = next_frontier;
121 }
122
123 let full_code = format!("{}/{}", spheres[0], spheres[1..].join("/"));
124
125 codes.push(HoseCode {
126 atom_index: atom_idx,
127 element: center_element,
128 spheres,
129 full_code,
130 });
131 }
132
133 codes
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Radius used by every test in this module.
    const RADIUS: usize = 3;

    /// Asserts the structural invariants that hold for any generated result:
    /// one code per atom, in atom order, each with `RADIUS + 1` spheres and a
    /// full code that begins with the centre symbol.
    fn assert_well_formed(codes: &[HoseCode], mol: &Molecule) {
        assert_eq!(codes.len(), mol.graph.node_count());

        for (position, code) in codes.iter().enumerate() {
            assert_eq!(code.atom_index, position);
            assert_eq!(code.spheres.len(), RADIUS + 1);
            assert!(!code.spheres[0].is_empty());
            assert!(!code.full_code.is_empty());
            assert!(code.full_code.starts_with(code.spheres[0].as_str()));
        }
    }

    #[test]
    fn test_hose_codes_ethanol() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        let codes = generate_hose_codes(&mol, RADIUS);

        assert_well_formed(&codes, &mol);
    }

    #[test]
    fn test_hose_codes_benzene() {
        let mol = Molecule::from_smiles("c1ccccc1").unwrap();
        let codes = generate_hose_codes(&mol, RADIUS);

        assert_well_formed(&codes, &mol);

        let carbon_codes: Vec<&HoseCode> = codes.iter().filter(|c| c.element == 6).collect();
        assert!(!carbon_codes.is_empty());
    }

    #[test]
    fn test_hose_codes_deterministic() {
        let mol = Molecule::from_smiles("CC(=O)O").unwrap();
        let codes1 = generate_hose_codes(&mol, RADIUS);
        let codes2 = generate_hose_codes(&mol, RADIUS);

        // `zip` stops at the shorter input, so without these checks the loop
        // below would pass vacuously on empty or mismatched results.
        assert!(!codes1.is_empty());
        assert_eq!(codes1.len(), codes2.len());

        for (c1, c2) in codes1.iter().zip(codes2.iter()) {
            assert_eq!(c1.atom_index, c2.atom_index);
            assert_eq!(c1.spheres, c2.spheres);
            assert_eq!(c1.full_code, c2.full_code);
        }
    }
}