Skip to main content

sci_form/ml/
descriptors.rs

1//! Molecular descriptors for ML property prediction.
2//!
3//! Computes a vector of constitutional, topological, and electronic
4//! descriptors from molecular graph and (optionally) 3D coordinates.
5
6use serde::{Deserialize, Serialize};
7
8/// Molecular descriptor vector.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct MolecularDescriptors {
11    /// Molecular weight (Daltons).
12    pub molecular_weight: f64,
13    /// Number of heavy atoms (non-H).
14    pub n_heavy_atoms: usize,
15    /// Number of hydrogen atoms.
16    pub n_hydrogens: usize,
17    /// Number of bonds.
18    pub n_bonds: usize,
19    /// Number of rotatable bonds (single bonds between non-H heavy atoms, not in rings).
20    pub n_rotatable_bonds: usize,
21    /// Number of H-bond donors (N-H, O-H).
22    pub n_hbd: usize,
23    /// Number of H-bond acceptors (N, O).
24    pub n_hba: usize,
25    /// Fraction of sp3 carbons.
26    pub fsp3: f64,
27    /// Total partial charge magnitude (sum of |q_i|).
28    pub total_abs_charge: f64,
29    /// Max partial charge.
30    pub max_charge: f64,
31    /// Min partial charge.
32    pub min_charge: f64,
33    /// Wiener index (sum of shortest-path distances for all pairs).
34    pub wiener_index: f64,
35    /// Number of rings (from graph cycles, approximate).
36    pub n_rings: usize,
37    /// Number of aromatic atoms.
38    pub n_aromatic: usize,
39    /// Balaban J index (approximation).
40    pub balaban_j: f64,
41    /// Sum of atomic electronegativities (Pauling).
42    pub sum_electronegativity: f64,
43    /// Sum of atomic polarizabilities (empirical, ų).
44    pub sum_polarizability: f64,
45}
46
47/// 3D molecular shape descriptors computed from atomic coordinates.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct Descriptors3D {
50    /// Radius of gyration (Å).
51    pub radius_of_gyration: f64,
52    /// Asphericity (0 = sphere, 1 = rod).
53    pub asphericity: f64,
54    /// Eccentricity (0 = sphere).
55    pub eccentricity: f64,
56    /// Principal moments of inertia ratios (NPR1 = I1/I3, NPR2 = I2/I3).
57    pub npr1: f64,
58    pub npr2: f64,
59    /// Sphericity (0 = asymmetric, 1 = sphere).
60    pub sphericity: f64,
61    /// Molecular span: max distance between any two atoms (Å).
62    pub span: f64,
63}
64
/// Standard atomic weight (g/mol) of element `z`.
///
/// Tabulated for H through Kr plus Mo, Ru, Rh, Pd, Ag, Sn, I, Pt, Au and Hg, so common
/// counter-ions and metals get real masses. Any other atomic number falls back to the rough
/// estimate `1.5 · z`, which typically underestimates the true mass (and is 0.0 for `z = 0`).
fn atomic_weight(z: u8) -> f64 {
    match z {
        1 => 1.008,
        2 => 4.0026,
        3 => 6.94,
        4 => 9.0122,
        5 => 10.81,
        6 => 12.011,
        7 => 14.007,
        8 => 15.999,
        9 => 18.998,
        10 => 20.180,
        11 => 22.990,
        12 => 24.305,
        13 => 26.982,
        14 => 28.086,
        15 => 30.974,
        16 => 32.06,
        17 => 35.45,
        18 => 39.948,
        19 => 39.098,
        20 => 40.078,
        21 => 44.956,
        22 => 47.867,
        23 => 50.942,
        24 => 51.996,
        25 => 54.938,
        26 => 55.845,
        27 => 58.933,
        28 => 58.693,
        29 => 63.546,
        30 => 65.38,
        31 => 69.723,
        32 => 72.630,
        33 => 74.922,
        34 => 78.971,
        35 => 79.904,
        36 => 83.798,
        42 => 95.95,
        44 => 101.07,
        45 => 102.91,
        46 => 106.42,
        47 => 107.87,
        50 => 118.71,
        53 => 126.904,
        78 => 195.08,
        79 => 196.97,
        80 => 200.59,
        _ => f64::from(z) * 1.5, // rough fallback for untabulated elements
    }
}
83
/// Pauling electronegativity of element `z`.
///
/// Elements missing from the table default to 1.80.
fn electronegativity(z: u8) -> f64 {
    const PAULING: [(u8, f64); 12] = [
        (1, 2.20),
        (5, 2.04),
        (6, 2.55),
        (7, 3.04),
        (8, 3.44),
        (9, 3.98),
        (14, 1.90),
        (15, 2.19),
        (16, 2.58),
        (17, 3.16),
        (35, 2.96),
        (53, 2.66),
    ];
    PAULING
        .iter()
        .find(|&&(number, _)| number == z)
        .map_or(1.80, |&(_, chi)| chi)
}
102
/// Empirical atomic polarizability (Å³) of element `z`.
///
/// Elements missing from the table default to 2.0.
fn atomic_polarizability(z: u8) -> f64 {
    const POLARIZABILITY: [(u8, f64); 12] = [
        (1, 0.667),
        (5, 3.03),
        (6, 1.76),
        (7, 1.10),
        (8, 0.802),
        (9, 0.557),
        (14, 5.38),
        (15, 3.63),
        (16, 2.90),
        (17, 2.18),
        (35, 3.05),
        (53, 5.35),
    ];
    POLARIZABILITY
        .iter()
        .find(|&&(number, _)| number == z)
        .map_or(2.0, |&(_, alpha)| alpha)
}
121
122/// Compute molecular descriptors from elements, bonds, and partial charges.
123///
124/// `elements`: atomic numbers.
125/// `bonds`: (atom_i, atom_j, bond_order) list.
126/// `charges`: partial charges (same length as elements), or empty.
127/// `aromatic_atoms`: boolean flags per atom, or empty.
128pub fn compute_descriptors(
129    elements: &[u8],
130    bonds: &[(usize, usize, u8)],
131    charges: &[f64],
132    aromatic_atoms: &[bool],
133) -> MolecularDescriptors {
134    let n = elements.len();
135    let n_heavy = elements.iter().filter(|&&z| z != 1).count();
136    let n_h = n - n_heavy;
137
138    let mw: f64 = elements.iter().map(|&z| atomic_weight(z)).sum();
139
140    // Build adjacency
141    let mut adj = vec![vec![]; n];
142    for &(i, j, _) in bonds {
143        if i < n && j < n {
144            adj[i].push(j);
145            adj[j].push(i);
146        }
147    }
148
149    // Rotatable bonds: single bonds between two heavy atoms with ≥2 neighbors each
150    let n_rot = bonds
151        .iter()
152        .filter(|&&(i, j, ord)| {
153            ord == 1
154                && elements[i] != 1
155                && elements[j] != 1
156                && adj[i].len() >= 2
157                && adj[j].len() >= 2
158        })
159        .count();
160
161    // H-bond donors: N-H or O-H
162    let n_hbd = (0..n)
163        .filter(|&i| {
164            (elements[i] == 7 || elements[i] == 8) && adj[i].iter().any(|&j| elements[j] == 1)
165        })
166        .count();
167
168    // H-bond acceptors: N or O
169    let n_hba = elements.iter().filter(|&&z| z == 7 || z == 8).count();
170
171    // Fsp3: fraction of sp3 carbons (4 neighbors is sp3 proxy)
172    let n_c = elements.iter().filter(|&&z| z == 6).count();
173    let n_c_sp3 = (0..n)
174        .filter(|&i| elements[i] == 6 && adj[i].len() == 4)
175        .count();
176    let fsp3 = if n_c > 0 {
177        n_c_sp3 as f64 / n_c as f64
178    } else {
179        0.0
180    };
181
182    // Charges
183    let (total_abs, max_q, min_q) = if !charges.is_empty() && charges.len() == n {
184        let abs_sum: f64 = charges.iter().map(|q| q.abs()).sum();
185        let max = charges.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
186        let min = charges.iter().cloned().fold(f64::INFINITY, f64::min);
187        (abs_sum, max, min)
188    } else {
189        (0.0, 0.0, 0.0)
190    };
191
192    // Wiener index: BFS from each node
193    let mut wiener = 0.0;
194    for start in 0..n {
195        let mut dist = vec![u32::MAX; n];
196        dist[start] = 0;
197        let mut queue = std::collections::VecDeque::new();
198        queue.push_back(start);
199        while let Some(u) = queue.pop_front() {
200            for &v in &adj[u] {
201                if dist[v] == u32::MAX {
202                    dist[v] = dist[u] + 1;
203                    queue.push_back(v);
204                }
205            }
206        }
207        wiener += dist
208            .iter()
209            .filter(|&&d| d != u32::MAX)
210            .map(|&d| d as f64)
211            .sum::<f64>();
212    }
213    wiener /= 2.0; // each pair counted twice
214
215    // Ring count (approximate): n_bonds - n_atoms + n_components
216    let n_components = {
217        let mut visited = vec![false; n];
218        let mut count = 0usize;
219        for start in 0..n {
220            if visited[start] {
221                continue;
222            }
223            count += 1;
224            let mut stack = vec![start];
225            while let Some(u) = stack.pop() {
226                if visited[u] {
227                    continue;
228                }
229                visited[u] = true;
230                for &v in &adj[u] {
231                    if !visited[v] {
232                        stack.push(v);
233                    }
234                }
235            }
236        }
237        count
238    };
239    let n_rings = if bonds.len() + n_components > n {
240        bonds.len() + n_components - n
241    } else {
242        0
243    };
244
245    let n_aromatic = if !aromatic_atoms.is_empty() {
246        aromatic_atoms.iter().filter(|&&a| a).count()
247    } else {
248        0
249    };
250
251    // Balaban J (simplified)
252    let balaban_j = if !bonds.is_empty() && n > 1 {
253        let mu = bonds.len() as f64 / (n as f64 - 1.0);
254        mu.ln().abs() + 1.0
255    } else {
256        0.0
257    };
258
259    let sum_en: f64 = elements.iter().map(|&z| electronegativity(z)).sum();
260    let sum_pol: f64 = elements.iter().map(|&z| atomic_polarizability(z)).sum();
261
262    MolecularDescriptors {
263        molecular_weight: mw,
264        n_heavy_atoms: n_heavy,
265        n_hydrogens: n_h,
266        n_bonds: bonds.len(),
267        n_rotatable_bonds: n_rot,
268        n_hbd,
269        n_hba,
270        fsp3,
271        total_abs_charge: total_abs,
272        max_charge: max_q,
273        min_charge: min_q,
274        wiener_index: wiener,
275        n_rings,
276        n_aromatic,
277        balaban_j,
278        sum_electronegativity: sum_en,
279        sum_polarizability: sum_pol,
280    }
281}
282
283/// Compute 3D shape descriptors from atomic coordinates and masses.
284///
285/// `elements`: atomic numbers (used for mass weighting).
286/// `positions`: 3D coordinates as flat [x0,y0,z0,x1,y1,z1,...].
287pub fn compute_3d_descriptors(elements: &[u8], positions: &[f64]) -> Descriptors3D {
288    let n = elements.len();
289    if n == 0 || positions.len() < n * 3 {
290        return Descriptors3D {
291            radius_of_gyration: 0.0,
292            asphericity: 0.0,
293            eccentricity: 0.0,
294            npr1: 0.0,
295            npr2: 0.0,
296            sphericity: 0.0,
297            span: 0.0,
298        };
299    }
300
301    let masses: Vec<f64> = elements.iter().map(|&z| atomic_weight(z)).collect();
302    let total_mass: f64 = masses.iter().sum();
303
304    // Center of mass
305    let mut com = [0.0f64; 3];
306    for i in 0..n {
307        for k in 0..3 {
308            com[k] += masses[i] * positions[i * 3 + k];
309        }
310    }
311    for k in 0..3 {
312        com[k] /= total_mass;
313    }
314
315    // Radius of gyration
316    let mut rg2 = 0.0f64;
317    for i in 0..n {
318        let mut d2 = 0.0;
319        for k in 0..3 {
320            let d = positions[i * 3 + k] - com[k];
321            d2 += d * d;
322        }
323        rg2 += masses[i] * d2;
324    }
325    rg2 /= total_mass;
326    let rg = rg2.sqrt();
327
328    // Gyration tensor (3x3 symmetric)
329    let mut gt = [[0.0f64; 3]; 3];
330    for i in 0..n {
331        let r = [
332            positions[i * 3] - com[0],
333            positions[i * 3 + 1] - com[1],
334            positions[i * 3 + 2] - com[2],
335        ];
336        for a in 0..3 {
337            for b in 0..3 {
338                gt[a][b] += masses[i] * r[a] * r[b];
339            }
340        }
341    }
342    for a in 0..3 {
343        for b in 0..3 {
344            gt[a][b] /= total_mass;
345        }
346    }
347
348    // Eigenvalues of 3x3 symmetric matrix (analytical Cardano's method)
349    let evals = eigenvalues_3x3_symmetric(&gt);
350    let mut sorted = evals;
351    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
352    let (l1, l2, l3) = (sorted[0].max(0.0), sorted[1].max(0.0), sorted[2].max(0.0));
353    let sum_l = l1 + l2 + l3;
354
355    let asphericity = if sum_l > 1e-14 {
356        ((l1 - l2).powi(2) + (l2 - l3).powi(2) + (l1 - l3).powi(2)) / (2.0 * sum_l * sum_l)
357    } else {
358        0.0
359    };
360
361    let eccentricity = if l3 > 1e-14 {
362        (1.0 - l1 / l3).sqrt().clamp(0.0, 1.0)
363    } else {
364        0.0
365    };
366
367    let (npr1, npr2) = if l3 > 1e-14 {
368        (l1 / l3, l2 / l3)
369    } else {
370        (0.0, 0.0)
371    };
372
373    let sphericity = if l3 > 1e-14 {
374        let prod = (l1 * l2 * l3).powf(1.0 / 3.0);
375        (prod / l3).min(1.0)
376    } else {
377        0.0
378    };
379
380    // Molecular span: max pairwise distance
381    let mut span = 0.0f64;
382    for i in 0..n {
383        for j in (i + 1)..n {
384            let mut d2 = 0.0;
385            for k in 0..3 {
386                let d = positions[i * 3 + k] - positions[j * 3 + k];
387                d2 += d * d;
388            }
389            span = span.max(d2.sqrt());
390        }
391    }
392
393    Descriptors3D {
394        radius_of_gyration: rg,
395        asphericity,
396        eccentricity,
397        npr1,
398        npr2,
399        sphericity,
400        span,
401    }
402}
403
404/// Eigenvalues of a 3x3 symmetric matrix using Cardano's formula.
405fn eigenvalues_3x3_symmetric(m: &[[f64; 3]; 3]) -> [f64; 3] {
406    let p1 = m[0][1] * m[0][1] + m[0][2] * m[0][2] + m[1][2] * m[1][2];
407
408    if p1.abs() < 1e-30 {
409        // Already diagonal
410        return [m[0][0], m[1][1], m[2][2]];
411    }
412
413    let q = (m[0][0] + m[1][1] + m[2][2]) / 3.0;
414    let p2 = (m[0][0] - q).powi(2) + (m[1][1] - q).powi(2) + (m[2][2] - q).powi(2) + 2.0 * p1;
415    let p = (p2 / 6.0).sqrt();
416
417    // B = (1/p) * (A - q*I)
418    let b = [
419        [(m[0][0] - q) / p, m[0][1] / p, m[0][2] / p],
420        [m[1][0] / p, (m[1][1] - q) / p, m[1][2] / p],
421        [m[2][0] / p, m[2][1] / p, (m[2][2] - q) / p],
422    ];
423
424    let det_b = b[0][0] * (b[1][1] * b[2][2] - b[1][2] * b[2][1])
425        - b[0][1] * (b[1][0] * b[2][2] - b[1][2] * b[2][0])
426        + b[0][2] * (b[1][0] * b[2][1] - b[1][1] * b[2][0]);
427
428    let r = det_b / 2.0;
429    let r_clamped = r.clamp(-1.0, 1.0);
430    let phi = r_clamped.acos() / 3.0;
431
432    let eig1 = q + 2.0 * p * phi.cos();
433    let eig3 = q + 2.0 * p * (phi + 2.0 * std::f64::consts::FRAC_PI_3).cos();
434    let eig2 = 3.0 * q - eig1 - eig3;
435
436    [eig1, eig2, eig3]
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_descriptors_water() {
        // H2O: oxygen (index 0) singly bonded to both hydrogens.
        let atoms = [8u8, 1, 1];
        let oh_bonds = [(0, 1, 1u8), (0, 2, 1)];
        let d = compute_descriptors(&atoms, &oh_bonds, &[], &[]);

        assert_eq!(d.n_heavy_atoms, 1);
        assert_eq!(d.n_hydrogens, 2);
        assert_eq!(d.n_bonds, 2);
        assert_eq!(d.n_hbd, 1, "the O-H oxygen is a donor");
        assert_eq!(d.n_hba, 1, "the oxygen is an acceptor");
        let expected_mw = 18.015;
        assert!((d.molecular_weight - expected_mw).abs() < 0.01);
    }

    #[test]
    fn test_descriptors_methane() {
        // CH4: carbon (index 0) bonded to four hydrogens.
        let atoms = [6u8, 1, 1, 1, 1];
        let ch_bonds: Vec<(usize, usize, u8)> = (1..=4).map(|h| (0, h, 1)).collect();
        let d = compute_descriptors(&atoms, &ch_bonds, &[], &[]);

        assert_eq!(d.n_heavy_atoms, 1);
        assert_eq!(d.fsp3, 1.0, "the only carbon is sp3");
        assert_eq!(d.n_rotatable_bonds, 0);
    }
}