Skip to main content

sci_form/nmr/
hose.rs

1//! HOSE (Hierarchically Ordered Spherical description of Environment) code generation.
2//!
3//! Generates unique environment descriptors for each atom by traversing
4//! the molecular graph in concentric spheres (radius 0 to max_radius).
5
6use crate::graph::{BondOrder, Molecule};
7use petgraph::graph::NodeIndex;
8use petgraph::visit::EdgeRef;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeSet;
11
12/// A HOSE code descriptor for a specific atom.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct HoseCode {
15    /// Atom index in the molecule.
16    pub atom_index: usize,
17    /// Atomic number of the center atom.
18    pub element: u8,
19    /// HOSE string at each sphere radius (0..=max_radius).
20    pub spheres: Vec<String>,
21    /// Full concatenated HOSE code.
22    pub full_code: String,
23}
24
25fn bond_symbol(order: BondOrder) -> &'static str {
26    match order {
27        BondOrder::Single => "",
28        BondOrder::Double => "=",
29        BondOrder::Triple => "#",
30        BondOrder::Aromatic => "*",
31        BondOrder::Unknown => "",
32    }
33}
34
/// Map an atomic number to its element symbol.
///
/// Every element of the periodic table (Z = 1..=118) gets its own symbol, so different
/// heteroatoms (e.g. Se vs. Sn vs. Li) never collapse onto the same HOSE entry.
/// Atomic numbers outside that range, including 0, are rendered as `"X"`.
fn element_symbol(z: u8) -> &'static str {
    // Indexed by atomic number; slot 0 is the placeholder for "no/unknown element".
    const SYMBOLS: [&str; 119] = [
        "X", // 0
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", // 1-10
        "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", // 11-20
        "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", // 21-30
        "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", // 31-40
        "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", // 41-50
        "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", // 51-60
        "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", // 61-70
        "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", // 71-80
        "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", // 81-90
        "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", // 91-100
        "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", // 101-110
        "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og", // 111-118
    ];
    SYMBOLS.get(usize::from(z)).copied().unwrap_or("X")
}
52
53/// Generate HOSE codes for all atoms in a molecule.
54///
55/// `mol`: parsed molecular graph
56/// `max_radius`: maximum sphere radius (typically 3–5)
57pub fn generate_hose_codes(mol: &Molecule, max_radius: usize) -> Vec<HoseCode> {
58    let n = mol.graph.node_count();
59    let mut codes = Vec::with_capacity(n);
60
61    for atom_idx in 0..n {
62        let center = NodeIndex::new(atom_idx);
63        let center_element = mol.graph[center].element;
64        let center_sym = element_symbol(center_element);
65
66        let mut spheres = Vec::with_capacity(max_radius + 1);
67        spheres.push(center_sym.to_string());
68
69        // BFS by sphere depth
70        let mut visited = vec![false; n];
71        visited[atom_idx] = true;
72        let mut current_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
73
74        // Initial frontier: direct neighbors
75        for edge in mol.graph.edges(center) {
76            let neighbor = if edge.source() == center {
77                edge.target()
78            } else {
79                edge.source()
80            };
81            current_frontier.push((neighbor, edge.weight().order));
82        }
83
84        for _radius in 1..=max_radius {
85            if current_frontier.is_empty() {
86                spheres.push(String::new());
87                continue;
88            }
89
90            // Sort frontier entries for deterministic output
91            let mut sphere_parts: BTreeSet<String> = BTreeSet::new();
92            let mut next_frontier: Vec<(NodeIndex, BondOrder)> = Vec::new();
93
94            for (node, bond_order) in &current_frontier {
95                let node_idx = node.index();
96                if visited[node_idx] {
97                    continue;
98                }
99                visited[node_idx] = true;
100
101                let elem = mol.graph[*node].element;
102                let sym = element_symbol(elem);
103                let bond_sym = bond_symbol(*bond_order);
104                sphere_parts.insert(format!("{}{}", bond_sym, sym));
105
106                // Collect next frontier
107                for edge in mol.graph.edges(*node) {
108                    let next = if edge.source() == *node {
109                        edge.target()
110                    } else {
111                        edge.source()
112                    };
113                    if !visited[next.index()] {
114                        next_frontier.push((next, edge.weight().order));
115                    }
116                }
117            }
118
119            spheres.push(sphere_parts.into_iter().collect::<Vec<_>>().join(","));
120            current_frontier = next_frontier;
121        }
122
123        let full_code = format!("{}/{}", spheres[0], spheres[1..].join("/"));
124
125        codes.push(HoseCode {
126            atom_index: atom_idx,
127            element: center_element,
128            spheres,
129            full_code,
130        });
131    }
132
133    codes
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hose_codes_ethanol() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        let max_radius = 3;
        let codes = generate_hose_codes(&mol, max_radius);

        // Should have codes for all atoms (including H), in atom-index order
        assert_eq!(codes.len(), mol.graph.node_count());

        for (i, code) in codes.iter().enumerate() {
            assert_eq!(code.atom_index, i);
            // One sphere per radius 0..=max_radius, even if some are empty
            assert_eq!(code.spheres.len(), max_radius + 1);
            assert!(!code.spheres[0].is_empty());
            assert!(!code.full_code.is_empty());
            assert!(code.full_code.starts_with(code.spheres[0].as_str()));
        }
    }

    #[test]
    fn test_hose_codes_benzene() {
        let mol = Molecule::from_smiles("c1ccccc1").unwrap();
        let codes = generate_hose_codes(&mol, 3);

        assert_eq!(codes.len(), mol.graph.node_count());

        let carbon_codes: Vec<&HoseCode> = codes.iter().filter(|c| c.element == 6).collect();
        assert_eq!(carbon_codes.len(), 6);

        // By symmetry every ring carbon sees the same first sphere
        let first = &carbon_codes[0].spheres[1];
        assert!(carbon_codes.iter().all(|c| &c.spheres[1] == first));
    }

    #[test]
    fn test_hose_codes_radius_zero() {
        let mol = Molecule::from_smiles("CCO").unwrap();
        let codes = generate_hose_codes(&mol, 0);

        assert_eq!(codes.len(), mol.graph.node_count());
        for code in &codes {
            // Only the centre sphere is produced
            assert_eq!(code.spheres.len(), 1);
            assert_eq!(code.spheres[0], element_symbol(code.element));
        }
    }

    #[test]
    fn test_hose_codes_deterministic() {
        let mol = Molecule::from_smiles("CC(=O)O").unwrap();
        let codes1 = generate_hose_codes(&mol, 3);
        let codes2 = generate_hose_codes(&mol, 3);

        // `zip` would silently ignore a length mismatch
        assert_eq!(codes1.len(), codes2.len());
        for (c1, c2) in codes1.iter().zip(codes2.iter()) {
            assert_eq!(c1.spheres, c2.spheres);
            assert_eq!(c1.full_code, c2.full_code);
        }
    }
}