Skip to main content

sci_form/ml/
getaway.rs

1//! GETAWAY (GEometry, Topology, and Atom-Weights AssemblY) descriptors.
2//!
3//! Combine 3D geometry (molecular influence matrix) with topological
4//! information and atomic properties. The molecular influence matrix H
5//! encodes how each atom influences the overall molecular shape.
6//!
//! Reference: Consonni, Todeschini and Pavan, J. Chem. Inf. Comput. Sci. 42, 682–692 (2002).
8
9use serde::{Deserialize, Serialize};
10
/// GETAWAY descriptor result.
///
/// Returned by `compute_getaway`. All pair sums below run over unordered atom
/// pairs `(i, j)` with `i < j` whose topological distance `d` satisfies
/// `1 <= d <= max_lag`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetawayDescriptors {
    /// Leverage values (diagonal of influence matrix H_ii), one per atom.
    /// Empty when fewer than two atoms were supplied.
    pub leverages: Vec<f64>,
    /// H autocorrelation of lag k (unweighted): HATk.
    /// Index `k - 1` holds the sum of `sqrt(h_ii) * sqrt(h_jj)` over pairs at
    /// topological distance `k`; the vector has `max_lag` entries.
    pub hat_autocorrelation: Vec<f64>,
    /// R autocorrelation (geometric distance weighted): Rk.
    /// Same layout as `hat_autocorrelation`, with each term divided by the
    /// Euclidean distance r_ij (pairs with r_ij <= 1e-12 contribute nothing).
    pub r_autocorrelation: Vec<f64>,
    /// H total index: HATt = sum of |H_ij| over all atom pairs within
    /// `max_lag` bonds of each other (not only directly bonded pairs).
    pub hat_total: f64,
    /// R total index: Rt = sum of |H_ij| / r_ij over the same pairs
    /// (pairs with r_ij <= 1e-12 are skipped).
    pub r_total: f64,
    /// Information content of leverages (Shannon entropy, natural logarithm)
    /// of the leverages normalised by their sum; 0 when that sum is <= 1e-12.
    pub h_information: f64,
    /// Maximum leverage.
    pub h_max: f64,
    /// Mean leverage.
    pub h_mean: f64,
    /// Maximum topological distance considered (echo of the `max_lag` argument).
    pub max_lag: usize,
}
33
34/// Compute GETAWAY descriptors from 3D coordinates and connectivity.
35///
36/// # Arguments
37/// * `elements` - Atomic numbers
38/// * `positions` - 3D coordinates
39/// * `bonds` - Bond list as (atom_i, atom_j) pairs
40/// * `max_lag` - Maximum topological distance for autocorrelation (default: 8)
41pub fn compute_getaway(
42    elements: &[u8],
43    positions: &[[f64; 3]],
44    bonds: &[(usize, usize)],
45    max_lag: usize,
46) -> GetawayDescriptors {
47    let n = elements.len().min(positions.len());
48    if n < 2 {
49        return GetawayDescriptors {
50            leverages: vec![],
51            hat_autocorrelation: vec![0.0; max_lag],
52            r_autocorrelation: vec![0.0; max_lag],
53            hat_total: 0.0,
54            r_total: 0.0,
55            h_information: 0.0,
56            h_max: 0.0,
57            h_mean: 0.0,
58            max_lag,
59        };
60    }
61
62    // Compute centered coordinate matrix X (n × 3)
63    let mut centroid = [0.0f64; 3];
64    for i in 0..n {
65        for d in 0..3 {
66            centroid[d] += positions[i][d];
67        }
68    }
69    for d in 0..3 {
70        centroid[d] /= n as f64;
71    }
72
73    // X^T X (3×3 Gram matrix)
74    let mut xtx = [[0.0f64; 3]; 3];
75    for i in 0..n {
76        let dx = [
77            positions[i][0] - centroid[0],
78            positions[i][1] - centroid[1],
79            positions[i][2] - centroid[2],
80        ];
81        for r in 0..3 {
82            for c in 0..3 {
83                xtx[r][c] += dx[r] * dx[c];
84            }
85        }
86    }
87
88    // Influence matrix: H = X (X^T X)^{-1} X^T
89    // Leverages: h_ii = X_i (X^T X)^{-1} X_i^T
90    let xtx_inv = invert_3x3(&xtx);
91    let mut leverages = Vec::with_capacity(n);
92    for i in 0..n {
93        let dx = [
94            positions[i][0] - centroid[0],
95            positions[i][1] - centroid[1],
96            positions[i][2] - centroid[2],
97        ];
98        // h_ii = dx^T * xtx_inv * dx
99        let mut h = 0.0;
100        for r in 0..3 {
101            for c in 0..3 {
102                h += dx[r] * xtx_inv[r][c] * dx[c];
103            }
104        }
105        leverages.push(h);
106    }
107
108    // Build topological distance matrix via BFS
109    let topo_dist = build_topo_distance(n, bonds);
110
111    // Compute H_ij for off-diagonal elements
112    // H_ij = X_i (X^T X)^{-1} X_j^T / sqrt(h_ii * h_jj)
113    let mut h_matrix = vec![vec![0.0f64; n]; n];
114    for i in 0..n {
115        h_matrix[i][i] = leverages[i];
116        let dxi = [
117            positions[i][0] - centroid[0],
118            positions[i][1] - centroid[1],
119            positions[i][2] - centroid[2],
120        ];
121        for j in (i + 1)..n {
122            let dxj = [
123                positions[j][0] - centroid[0],
124                positions[j][1] - centroid[1],
125                positions[j][2] - centroid[2],
126            ];
127            let mut hij = 0.0;
128            for r in 0..3 {
129                for c in 0..3 {
130                    hij += dxi[r] * xtx_inv[r][c] * dxj[c];
131                }
132            }
133            h_matrix[i][j] = hij;
134            h_matrix[j][i] = hij;
135        }
136    }
137
138    // Geometric distance matrix
139    let mut geo_dist = vec![vec![0.0f64; n]; n];
140    for i in 0..n {
141        for j in (i + 1)..n {
142            let dx = positions[i][0] - positions[j][0];
143            let dy = positions[i][1] - positions[j][1];
144            let dz = positions[i][2] - positions[j][2];
145            let r = (dx * dx + dy * dy + dz * dz).sqrt();
146            geo_dist[i][j] = r;
147            geo_dist[j][i] = r;
148        }
149    }
150
151    // Autocorrelation at lag k
152    let mut hat_autocorrelation = vec![0.0f64; max_lag];
153    let mut r_autocorrelation = vec![0.0f64; max_lag];
154    let mut hat_total = 0.0;
155    let mut r_total = 0.0;
156
157    for i in 0..n {
158        for j in (i + 1)..n {
159            let d = topo_dist[i][j];
160            if d == 0 || d > max_lag {
161                continue;
162            }
163            let k = d - 1;
164            let hi = leverages[i].max(0.0).sqrt();
165            let hj = leverages[j].max(0.0).sqrt();
166            hat_autocorrelation[k] += hi * hj;
167
168            let rij = geo_dist[i][j];
169            if rij > 1e-12 {
170                r_autocorrelation[k] += hi * hj / rij;
171            }
172
173            hat_total += h_matrix[i][j].abs();
174            if rij > 1e-12 {
175                r_total += h_matrix[i][j].abs() / rij;
176            }
177        }
178    }
179
180    // Information content of leverages
181    let h_sum: f64 = leverages.iter().sum();
182    let h_information = if h_sum > 1e-12 {
183        let mut entropy = 0.0;
184        for &h in &leverages {
185            let p = h / h_sum;
186            if p > 1e-12 {
187                entropy -= p * p.ln();
188            }
189        }
190        entropy
191    } else {
192        0.0
193    };
194
195    let h_max = leverages.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
196    let h_mean = if n > 0 { h_sum / n as f64 } else { 0.0 };
197
198    GetawayDescriptors {
199        leverages,
200        hat_autocorrelation,
201        r_autocorrelation,
202        hat_total,
203        r_total,
204        h_information,
205        h_max,
206        h_mean,
207        max_lag,
208    }
209}
210
/// Build the all-pairs topological (bond-count) distance matrix.
///
/// Runs one breadth-first search per atom over the undirected bond graph.
/// Bonds that reference an index `>= n` are ignored. Entry `[i][j]` is the
/// number of bonds on a shortest path from `i` to `j`; it is 0 both for
/// `i == j` and for atoms with no connecting path.
fn build_topo_distance(n: usize, bonds: &[(usize, usize)]) -> Vec<Vec<usize>> {
    // Undirected adjacency lists, keeping only in-range bonds.
    let mut neighbours: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (a, b) in bonds.iter().copied().filter(|&(a, b)| a < n && b < n) {
        neighbours[a].push(b);
        neighbours[b].push(a);
    }

    (0..n)
        .map(|source| {
            let mut row = vec![0usize; n];
            let mut seen = vec![false; n];
            seen[source] = true;

            let mut frontier = std::collections::VecDeque::with_capacity(n);
            frontier.push_back(source);
            while let Some(current) = frontier.pop_front() {
                // Every newly discovered neighbour is one bond further out.
                let next_depth = row[current] + 1;
                for &next in &neighbours[current] {
                    if !seen[next] {
                        seen[next] = true;
                        row[next] = next_depth;
                        frontier.push_back(next);
                    }
                }
            }
            row
        })
        .collect()
}
240
/// Invert a 3×3 matrix using Cramer's rule, falling back to the
/// Moore–Penrose pseudo-inverse when the matrix is numerically singular.
///
/// A singular input (|det| < 1e-12) previously produced the zero matrix.
/// For a Gram matrix of planar or linear coordinates that wiped out every
/// value derived from it. The pseudo-inverse instead inverts the matrix on
/// the subspace it actually spans and is zero on its null space. The
/// pseudo-inverse of the zero matrix is still the zero matrix.
fn invert_3x3(m: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
    let det = m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);

    if det.abs() < 1e-12 {
        return pseudo_invert_3x3(m);
    }

    let inv_det = 1.0 / det;
    let mut inv = [[0.0f64; 3]; 3];

    // Adjugate (transposed cofactor matrix) scaled by 1/det.
    inv[0][0] = (m[1][1] * m[2][2] - m[1][2] * m[2][1]) * inv_det;
    inv[0][1] = (m[0][2] * m[2][1] - m[0][1] * m[2][2]) * inv_det;
    inv[0][2] = (m[0][1] * m[1][2] - m[0][2] * m[1][1]) * inv_det;
    inv[1][0] = (m[1][2] * m[2][0] - m[1][0] * m[2][2]) * inv_det;
    inv[1][1] = (m[0][0] * m[2][2] - m[0][2] * m[2][0]) * inv_det;
    inv[1][2] = (m[0][2] * m[1][0] - m[0][0] * m[1][2]) * inv_det;
    inv[2][0] = (m[1][0] * m[2][1] - m[1][1] * m[2][0]) * inv_det;
    inv[2][1] = (m[0][1] * m[2][0] - m[0][0] * m[2][1]) * inv_det;
    inv[2][2] = (m[0][0] * m[1][1] - m[0][1] * m[1][0]) * inv_det;

    inv
}

/// Moore–Penrose pseudo-inverse of a 3×3 matrix via `M⁺ = (MᵀM)⁺ Mᵀ`.
///
/// `MᵀM` is symmetric, so its pseudo-inverse follows from an
/// eigendecomposition in which eigenvalues below a relative cutoff are
/// treated as zero.
fn pseudo_invert_3x3(m: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
    let mut gram = [[0.0f64; 3]; 3];
    for r in 0..3 {
        for c in 0..3 {
            for k in 0..3 {
                gram[r][c] += m[k][r] * m[k][c];
            }
        }
    }

    let (eigenvalues, eigenvectors) = symmetric_eigen_3x3(&gram);
    let largest = eigenvalues.iter().fold(0.0f64, |acc, &l| acc.max(l.abs()));
    if !(largest > 0.0) {
        // Zero (or non-finite) spectrum: nothing can be inverted.
        return [[0.0; 3]; 3];
    }
    let cutoff = largest * 1e-12;

    // (MᵀM)⁺ = Σ_k v_k v_kᵀ / λ_k over the eigenpairs above the cutoff.
    let mut gram_pinv = [[0.0f64; 3]; 3];
    for k in 0..3 {
        if eigenvalues[k] > cutoff {
            let inv_lambda = 1.0 / eigenvalues[k];
            for r in 0..3 {
                for c in 0..3 {
                    gram_pinv[r][c] += eigenvectors[r][k] * inv_lambda * eigenvectors[c][k];
                }
            }
        }
    }

    let mut pinv = [[0.0f64; 3]; 3];
    for r in 0..3 {
        for c in 0..3 {
            for k in 0..3 {
                pinv[r][c] += gram_pinv[r][k] * m[c][k];
            }
        }
    }
    pinv
}

/// Eigendecomposition of a symmetric 3×3 matrix by cyclic Jacobi rotations.
///
/// Returns `(eigenvalues, eigenvectors)`, where column `k` of `eigenvectors`
/// belongs to `eigenvalues[k]`. Eigenvalues are not sorted.
fn symmetric_eigen_3x3(sym: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    let mut a = *sym;
    let mut v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    // Sweep cap is a safety net for inputs that never meet the tolerance.
    for _sweep in 0..64 {
        let off = a[0][1] * a[0][1] + a[0][2] * a[0][2] + a[1][2] * a[1][2];
        let diag = a[0][0] * a[0][0] + a[1][1] * a[1][1] + a[2][2] * a[2][2];
        if off <= f64::EPSILON * f64::EPSILON * diag {
            break;
        }

        for &(p, q) in &[(0usize, 1usize), (0, 2), (1, 2)] {
            if a[p][q] == 0.0 {
                continue;
            }
            // Rotation angle that annihilates a[p][q]; the smaller root of
            // t² + 2θt − 1 = 0 keeps the rotation below 45°.
            let theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q]);
            let t = theta.signum() / (theta.abs() + (theta * theta + 1.0).sqrt());
            let c = 1.0 / (t * t + 1.0).sqrt();
            let s = t * c;

            // A ← Pᵀ A P, applied as a column update then a row update.
            for k in 0..3 {
                let akp = a[k][p];
                let akq = a[k][q];
                a[k][p] = c * akp - s * akq;
                a[k][q] = s * akp + c * akq;
            }
            for k in 0..3 {
                let apk = a[p][k];
                let aqk = a[q][k];
                a[p][k] = c * apk - s * aqk;
                a[q][k] = s * apk + c * aqk;
            }
            // Remove the rounding residue of the annihilated element.
            a[p][q] = 0.0;
            a[q][p] = 0.0;

            // Accumulate eigenvectors: V ← V P.
            for k in 0..3 {
                let vkp = v[k][p];
                let vkq = v[k][q];
                v[k][p] = c * vkp - s * vkq;
                v[k][q] = s * vkp + c * vkq;
            }
        }
    }

    ([a[0][0], a[1][1], a[2][2]], v)
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Ethane with an idealised (non-planar) geometry: checks output sizes
    /// and the invariants of a full-rank influence matrix.
    #[test]
    fn test_getaway_ethane() {
        let elements = vec![6, 6, 1, 1, 1, 1, 1, 1];
        let positions = vec![
            [0.0, 0.0, 0.0],   // C
            [1.54, 0.0, 0.0],  // C
            [-0.5, 0.9, 0.0],  // H
            [-0.5, -0.9, 0.0], // H
            [-0.5, 0.0, 0.9],  // H
            [2.04, 0.9, 0.0],  // H
            [2.04, -0.9, 0.0], // H
            [2.04, 0.0, 0.9],  // H
        ];
        let bonds = vec![(0, 1), (0, 2), (0, 3), (0, 4), (1, 5), (1, 6), (1, 7)];

        let g = compute_getaway(&elements, &positions, &bonds, 8);
        assert_eq!(g.leverages.len(), 8);
        assert_eq!(g.hat_autocorrelation.len(), 8);
        assert_eq!(g.r_autocorrelation.len(), 8);
        assert_eq!(g.max_lag, 8);
        assert!(g.hat_total > 0.0);
        assert!(g.r_total > 0.0);
        assert!(g.h_max > 0.0);

        // H is an orthogonal projection onto the 3D column space of the
        // centred coordinates, so its trace (the leverage sum) equals 3 and
        // every leverage lies in [0, 1].
        let trace: f64 = g.leverages.iter().sum();
        assert!((trace - 3.0).abs() < 1e-9);
        assert!((g.h_mean - 3.0 / 8.0).abs() < 1e-9);
        assert!(g.leverages.iter().all(|&h| h > -1e-9 && h < 1.0 + 1e-9));
        assert!(g.h_information > 0.0);

        // The longest bond path in ethane is H-C-C-H (3 bonds), so every lag
        // above 3 stays empty.
        assert!(g.hat_autocorrelation[..3].iter().all(|&v| v > 0.0));
        assert!(g.hat_autocorrelation[3..].iter().all(|&v| v == 0.0));
        assert!(g.r_autocorrelation[3..].iter().all(|&v| v == 0.0));
    }

    /// Fewer than two atoms yields the zeroed result.
    #[test]
    fn test_getaway_empty() {
        let g = compute_getaway(&[], &[], &[], 8);
        assert!(g.leverages.is_empty());
        assert_eq!(g.hat_autocorrelation, vec![0.0; 8]);
        assert_eq!(g.r_autocorrelation, vec![0.0; 8]);
        assert_eq!(g.hat_total, 0.0);
        assert_eq!(g.r_total, 0.0);
        assert_eq!(g.h_information, 0.0);
        assert_eq!(g.h_max, 0.0);
        assert_eq!(g.h_mean, 0.0);
        assert_eq!(g.max_lag, 8);
    }
}