Skip to main content

sci_form/ml/
ensemble.rs

//! Ensemble ML models with non-linear predictions and uncertainty estimates.
//!
//! Extends the base linear models with:
//! - Decision-tree-like cascading rules for LogP, solubility, pKa
//! - TPSA (Topological Polar Surface Area) descriptor
//! - Veber rule analysis for oral bioavailability
//! - BBB (blood–brain barrier) permeability prediction
//! - Consensus scoring with prediction confidence
9
10use super::descriptors::MolecularDescriptors;
11use serde::{Deserialize, Serialize};
12
/// Extended ML property predictions with uncertainty and additional properties.
///
/// This is the value returned by `predict_ensemble`. Every field is a heuristic
/// estimate derived from molecular descriptors and atom connectivity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleResult {
    /// Predicted LogP: arithmetic mean of the three LogP models.
    pub logp: f64,
    /// Population standard deviation across the three LogP models (uncertainty proxy).
    pub logp_std: f64,
    /// Predicted aqueous solubility (log S, mol/L), computed from the consensus LogP.
    pub log_solubility: f64,
    /// Predicted TPSA (Ų).
    pub tpsa: f64,
    /// Predicted pKa for the most acidic group; `None` when no acidic group is detected.
    pub pka_acidic: Option<f64>,
    /// Predicted pKa for the most basic group; `None` when no basic group is detected.
    pub pka_basic: Option<f64>,
    /// Veber oral bioavailability rules.
    pub veber: VeberResult,
    /// BBB permeability prediction (`true` when `bbb_score` exceeds 0.5).
    pub bbb_permeable: bool,
    /// BBB permeability score (0–1).
    pub bbb_score: f64,
    /// Overall prediction confidence (0–1), blending LogP model agreement with molecule size.
    pub confidence: f64,
}
37
/// Veber's rules for oral bioavailability.
///
/// Both criteria are evaluated independently so callers can see which one failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VeberResult {
    /// `true` when TPSA ≤ 140 Ų.
    pub tpsa_ok: bool,
    /// `true` when the number of rotatable bonds ≤ 10.
    pub rotb_ok: bool,
    /// Passes Veber rules (both criteria met).
    pub passes: bool,
}
48
/// Polar-surface-area contribution (Ų) of a single atom.
///
/// Simplified form of the fragment table from Ertl, P., Rohde, B., Selzer, P. (2000):
/// the fragment is identified only by the atomic number `z` and by how many heavy-atom
/// and hydrogen neighbours the atom has. Bond orders, aromaticity and formal charge are
/// not considered, so each `(heavy, H)` pair maps to the neutral, non-aromatic fragment
/// named in the comment beside it.
///
/// Neighbour patterns that are not listed fall back to a per-element default, and
/// elements other than N, O, S and P contribute `0.0`.
fn tpsa_contribution(z: u8, n_heavy_neighbors: usize, n_h_neighbors: usize) -> f64 {
    match z {
        // Nitrogen contributions
        7 => match (n_heavy_neighbors, n_h_neighbors) {
            (1, 2) => 26.02, // -NH₂
            (1, 1) => 23.85, // =NH  (published value; previously reused the -NH₂ entry)
            (2, 1) => 12.03, // >NH  (published value; previously reused the =N- entry)
            (2, 0) => 12.36, // =N-
            (3, 0) => 3.24,  // >N-
            (1, 0) => 23.79, // ≡N
            _ => 12.0,       // unlisted pattern: generic nitrogen estimate
        },
        // Oxygen contributions
        8 => match (n_heavy_neighbors, n_h_neighbors) {
            // Water is not a tabulated fragment; it is given the hydroxyl value.
            (0, 2) | (1, 1) => 20.23, // H₂O, -OH
            (1, 0) => 17.07,          // =O
            (2, 0) => 9.23,           // -O-
            _ => 15.0,                // unlisted pattern: generic oxygen estimate
        },
        // Sulfur contributions (these are the largest per-atom values in the table)
        16 => match (n_heavy_neighbors, n_h_neighbors) {
            (1, 1) => 38.80, // -SH
            (2, 0) => 25.30, // -S-
            (1, 0) => 32.09, // =S
            _ => 28.0,       // unlisted pattern: generic sulfur estimate
        },
        // Phosphorus: a single value is used for every substitution pattern.
        15 => 34.14,
        _ => 0.0,
    }
}
83
84/// Compute TPSA from elements and bond connectivity.
85///
86/// Uses Ertl fragment-based TPSA calculation.
87pub fn compute_tpsa(elements: &[u8], bonds: &[(usize, usize, u8)]) -> f64 {
88    let n = elements.len();
89    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
90    for &(i, j, _) in bonds {
91        if i < n && j < n {
92            adj[i].push(j);
93            adj[j].push(i);
94        }
95    }
96
97    let mut tpsa = 0.0;
98    for i in 0..n {
99        if !matches!(elements[i], 7 | 8 | 15 | 16) {
100            continue;
101        }
102        let n_heavy = adj[i].iter().filter(|&&j| elements[j] != 1).count();
103        let n_h = adj[i].iter().filter(|&&j| elements[j] == 1).count();
104        tpsa += tpsa_contribution(elements[i], n_heavy, n_h);
105    }
106    tpsa
107}
108
109// ─── Ensemble LogP models ────────────────────────────────────────────────────
110
111/// Model 1: Wildman-Crippen additive (same as base model).
112fn logp_model_1(desc: &MolecularDescriptors) -> f64 {
113    let base = 0.120 * desc.n_heavy_atoms as f64;
114    let h_corr = -0.230 * desc.n_hbd as f64;
115    let ring_corr = 0.150 * desc.n_rings as f64;
116    let arom_corr = 0.080 * desc.n_aromatic as f64;
117    let polar_corr = -0.310 * desc.n_hba as f64;
118    let sp3_corr = -0.180 * desc.fsp3;
119    let mw_term = 0.005 * (desc.molecular_weight - 100.0);
120    base + h_corr + ring_corr + arom_corr + polar_corr + sp3_corr + mw_term
121}
122
123/// Model 2: ALOGP-like fragment-based with non-linear corrections.
124fn logp_model_2(desc: &MolecularDescriptors, tpsa: f64) -> f64 {
125    // Use TPSA as a proxy for overall polarity
126    let polarity_term = -0.015 * tpsa;
127    let size_term = 0.008 * desc.molecular_weight;
128    let hbond_term = -0.45 * (desc.n_hbd as f64 + 0.5 * desc.n_hba as f64);
129    let lipophilic = 0.22 * (desc.n_heavy_atoms as f64 - desc.n_hba as f64);
130
131    // Non-linear correction: penalty for very high polarity
132    let nl_correction = if tpsa > 80.0 {
133        -0.003 * (tpsa - 80.0).powi(2) / 100.0
134    } else {
135        0.0
136    };
137
138    size_term + polarity_term + hbond_term + lipophilic + nl_correction - 1.5
139}
140
141/// Model 3: Topological descriptor based.
142fn logp_model_3(desc: &MolecularDescriptors) -> f64 {
143    let chi_approx = if desc.n_bonds > 0 {
144        (desc.n_bonds as f64).sqrt() / (desc.n_heavy_atoms as f64).max(1.0)
145    } else {
146        0.0
147    };
148
149    let wiener_term = if desc.wiener_index > 0.0 {
150        0.25 * desc.wiener_index.ln()
151    } else {
152        0.0
153    };
154
155    let polar_penalty = -0.35 * (desc.n_hbd + desc.n_hba) as f64;
156    let arom_bonus = 0.12 * desc.n_aromatic as f64;
157
158    chi_approx + wiener_term + polar_penalty + arom_bonus - 0.8
159}
160
161// ─── pKa prediction ──────────────────────────────────────────────────────────
162
163/// Predict acidic pKa based on functional groups.
164/// Returns None if no acidic group detected.
165fn predict_pka_acidic(desc: &MolecularDescriptors, tpsa: f64) -> Option<f64> {
166    if desc.n_hbd == 0 {
167        return None;
168    }
169
170    // Base pKa value depends on donor type
171    // Approximate: carboxylic acid ~4.5, phenol ~10, aliphatic OH ~16, NH ~35
172    let base_pka = if desc.n_hba >= 2 && desc.n_hbd >= 1 {
173        // Likely has carboxylic acid (O-H + C=O pattern)
174        4.5
175    } else {
176        // Generic O-H or N-H
177        14.0
178    };
179
180    // Corrections
181    let ew_correction =
182        -0.3 * (desc.sum_electronegativity / desc.n_heavy_atoms.max(1) as f64 - 2.5);
183    let arom_correction = if desc.n_aromatic > 0 { -1.5 } else { 0.0 };
184    let tpsa_correction = -0.02 * (tpsa - 60.0);
185
186    Some((base_pka + ew_correction + arom_correction + tpsa_correction).clamp(0.0, 25.0))
187}
188
189/// Predict basic pKa. Returns None if no basic group detected.
190fn predict_pka_basic(desc: &MolecularDescriptors) -> Option<f64> {
191    // Check for nitrogen atoms that could be basic
192    let has_nitrogen =
193        desc.n_hba > 0 && desc.sum_electronegativity / desc.n_heavy_atoms.max(1) as f64 > 2.6;
194    if !has_nitrogen {
195        return None;
196    }
197
198    // Amine base pKa: primary ~10.6, aromatic ~5
199    let base_pka = if desc.n_aromatic > 0 {
200        5.2 // pyridine-like
201    } else {
202        10.6 // aliphatic amine
203    };
204
205    let sp3_correction = 0.5 * desc.fsp3;
206    Some((base_pka + sp3_correction).clamp(0.0, 14.0))
207}
208
209// ─── BBB permeability ────────────────────────────────────────────────────────
210
211/// Predict blood-brain barrier permeability.
212///
213/// Lipinski-like heuristic: MW < 450, TPSA < 90, LogP ∈ (1, 5), HBD ≤ 3.
214fn predict_bbb(desc: &MolecularDescriptors, logp: f64, tpsa: f64) -> (bool, f64) {
215    let mut score = 1.0;
216
217    if desc.molecular_weight > 450.0 {
218        score -= 0.3 * ((desc.molecular_weight - 450.0) / 100.0).min(1.0);
219    }
220    if tpsa > 90.0 {
221        score -= 0.35 * ((tpsa - 90.0) / 50.0).min(1.0);
222    }
223    if logp < 1.0 {
224        score -= 0.2 * (1.0 - logp).min(1.0);
225    }
226    if logp > 5.0 {
227        score -= 0.2 * ((logp - 5.0) / 2.0).min(1.0);
228    }
229    if desc.n_hbd > 3 {
230        score -= 0.15 * (desc.n_hbd as f64 - 3.0).min(2.0) / 2.0;
231    }
232
233    let score = score.clamp(0.0, 1.0);
234    (score > 0.5, score)
235}
236
237/// Predict molecular properties using an ensemble of models.
238///
239/// Combines three LogP models via consensus, adds TPSA, pKa estimates,
240/// BBB permeability, and Veber bioavailability rules.
241pub fn predict_ensemble(
242    desc: &MolecularDescriptors,
243    elements: &[u8],
244    bonds: &[(usize, usize, u8)],
245) -> EnsembleResult {
246    let tpsa = compute_tpsa(elements, bonds);
247
248    // Ensemble LogP: average of 3 models
249    let lp1 = logp_model_1(desc);
250    let lp2 = logp_model_2(desc, tpsa);
251    let lp3 = logp_model_3(desc);
252    let logp = (lp1 + lp2 + lp3) / 3.0;
253
254    // Standard deviation as uncertainty proxy
255    let logp_std = {
256        let mean = logp;
257        let var = ((lp1 - mean).powi(2) + (lp2 - mean).powi(2) + (lp3 - mean).powi(2)) / 3.0;
258        var.sqrt()
259    };
260
261    // Solubility (ESOL-like, using consensus LogP)
262    let frac_aromatic = if desc.n_heavy_atoms > 0 {
263        desc.n_aromatic as f64 / desc.n_heavy_atoms as f64
264    } else {
265        0.0
266    };
267    let log_solubility = 0.16 - 0.63 * logp - 0.0062 * desc.molecular_weight
268        + 0.066 * desc.n_rotatable_bonds as f64
269        - 0.74 * frac_aromatic;
270
271    // pKa predictions
272    let pka_acidic = predict_pka_acidic(desc, tpsa);
273    let pka_basic = predict_pka_basic(desc);
274
275    // Veber rules
276    let tpsa_ok = tpsa <= 140.0;
277    let rotb_ok = desc.n_rotatable_bonds <= 10;
278    let veber = VeberResult {
279        tpsa_ok,
280        rotb_ok,
281        passes: tpsa_ok && rotb_ok,
282    };
283
284    // BBB
285    let (bbb_permeable, bbb_score) = predict_bbb(desc, logp, tpsa);
286
287    // Confidence: based on model agreement and descriptor coverage
288    let model_agreement = 1.0 - (logp_std / 2.0).min(1.0);
289    let size_confidence = if desc.n_heavy_atoms >= 3 && desc.n_heavy_atoms <= 50 {
290        1.0
291    } else {
292        0.5
293    };
294    let confidence = (model_agreement * 0.7 + size_confidence * 0.3).clamp(0.0, 1.0);
295
296    EnsembleResult {
297        logp,
298        logp_std,
299        log_solubility,
300        tpsa,
301        pka_acidic,
302        pka_basic,
303        veber,
304        bbb_permeable,
305        bbb_score,
306        confidence,
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ml::descriptors::compute_descriptors;

    /// Propane (C₃H₈) with explicit hydrogens: atoms 0–2 are carbon, 3–10 hydrogen.
    fn propane() -> ([u8; 11], Vec<(usize, usize, u8)>) {
        let elements = [6u8, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1];
        let bonds = vec![
            // carbon backbone
            (0, 1, 1),
            (1, 2, 1),
            // hydrogens on C0
            (0, 3, 1),
            (0, 4, 1),
            (0, 5, 1),
            // hydrogens on C1
            (1, 6, 1),
            (1, 7, 1),
            // hydrogens on C2
            (2, 8, 1),
            (2, 9, 1),
            (2, 10, 1),
        ];
        (elements, bonds)
    }

    #[test]
    fn test_tpsa_water() {
        // A single oxygen bonded to two hydrogens.
        let elements = [8u8, 1, 1];
        let bonds = [(0usize, 1usize, 1u8), (0, 2, 1)];

        let tpsa = compute_tpsa(&elements, &bonds);

        assert!(tpsa > 15.0 && tpsa < 25.0, "Water TPSA: {tpsa}");
    }

    #[test]
    fn test_tpsa_benzene() {
        // Six carbons followed by six hydrogens: nothing polar, so TPSA must be zero.
        let elements = [6u8, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1];
        // Ring bonds alternate double/single; each carbon i carries hydrogen i + 6.
        let ring = (0..6).map(|i| (i, (i + 1) % 6, if i % 2 == 0 { 2u8 } else { 1u8 }));
        let hydrogens = (0..6).map(|i| (i, i + 6, 1u8));
        let bonds: Vec<(usize, usize, u8)> = ring.chain(hydrogens).collect();

        let tpsa = compute_tpsa(&elements, &bonds);

        assert!(tpsa.abs() < 1e-6, "Benzene TPSA should be 0: {tpsa}");
    }

    #[test]
    fn test_ensemble_ethanol() {
        // C0–C1–O2 with hydrogens 3–8.
        let elements = [6u8, 6, 8, 1, 1, 1, 1, 1, 1];
        let bonds: Vec<(usize, usize, u8)> = vec![
            (0, 1, 1),
            (1, 2, 1),
            (0, 3, 1),
            (0, 4, 1),
            (0, 5, 1),
            (1, 6, 1),
            (1, 7, 1),
            (2, 8, 1),
        ];
        let desc = compute_descriptors(&elements, &bonds, &[], &[]);

        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(result.tpsa > 15.0, "Ethanol has polar O-H: {}", result.tpsa);
        assert!(result.logp < 2.0, "Ethanol is hydrophilic: {}", result.logp);
        assert!(result.logp_std >= 0.0, "Uncertainty must be non-negative");
        assert!(result.confidence > 0.0 && result.confidence <= 1.0);
        assert!(result.veber.passes);
    }

    #[test]
    fn test_ensemble_logp_consistency() {
        // The three LogP models should not diverge wildly on a small alkane.
        let (elements, bonds) = propane();
        let desc = compute_descriptors(&elements, &bonds, &[], &[]);

        let result = predict_ensemble(&desc, &elements, &bonds);

        // Spread of less than ~2 log units is expected here.
        assert!(
            result.logp_std < 2.0,
            "Models should broadly agree: std={}",
            result.logp_std
        );
    }

    #[test]
    fn test_veber_large_molecule() {
        // Hand-built descriptors with 15 rotatable bonds (above the Veber limit of 10).
        let desc = MolecularDescriptors {
            molecular_weight: 600.0,
            n_heavy_atoms: 45,
            n_hydrogens: 20,
            n_bonds: 60,
            n_rotatable_bonds: 15,
            n_hbd: 5,
            n_hba: 10,
            fsp3: 0.3,
            total_abs_charge: 3.0,
            max_charge: 0.4,
            min_charge: -0.4,
            wiener_index: 3000.0,
            n_rings: 4,
            n_aromatic: 8,
            balaban_j: 1.8,
            sum_electronegativity: 120.0,
            sum_polarizability: 65.0,
        };
        // Connectivity is a plain 45-carbon chain; only the descriptors matter here.
        let elements = [6u8; 45];
        let bonds: Vec<(usize, usize, u8)> = (1..45).map(|next| (next - 1, next, 1u8)).collect();

        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(
            !result.veber.rotb_ok,
            "Too many rotatable bonds: {}",
            desc.n_rotatable_bonds
        );
    }

    #[test]
    fn test_bbb_small_lipophilic() {
        // Propane is small and lipophilic, so its BBB score should be positive.
        let (elements, bonds) = propane();
        let desc = compute_descriptors(&elements, &bonds, &[], &[]);

        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(
            result.bbb_score > 0.0,
            "Small lipophilic molecule should have positive BBB score"
        );
    }

    #[test]
    fn test_pka_with_acid() {
        // Descriptors chosen to hit the carboxylic-acid branch: one donor, two acceptors.
        let desc = MolecularDescriptors {
            molecular_weight: 60.0,
            n_heavy_atoms: 3,
            n_hydrogens: 4,
            n_bonds: 6,
            n_rotatable_bonds: 0,
            n_hbd: 1,
            n_hba: 2,
            fsp3: 0.0,
            total_abs_charge: 0.5,
            max_charge: 0.2,
            min_charge: -0.3,
            wiener_index: 4.0,
            n_rings: 0,
            n_aromatic: 0,
            balaban_j: 1.0,
            sum_electronegativity: 8.0,
            sum_polarizability: 3.0,
        };
        // Carbon 0 with a double-bonded oxygen (1) and a hydroxyl oxygen (2).
        let elements = [6u8, 8, 8, 1, 1, 1, 1];
        let bonds: Vec<(usize, usize, u8)> = vec![
            (0, 1, 2),
            (0, 2, 1),
            (2, 3, 1),
            (0, 4, 1),
            (0, 5, 1),
            (0, 6, 1),
        ];

        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(
            result.pka_acidic.is_some(),
            "Carboxylic acid should have pKa"
        );
    }
}