1use super::descriptors::MolecularDescriptors;
11use serde::{Deserialize, Serialize};
12
/// Aggregated physicochemical property predictions for one molecule.
///
/// Produced by [`predict_ensemble`]. Every value is a heuristic estimate derived
/// from [`MolecularDescriptors`] plus the molecular graph (used for TPSA).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleResult {
    /// Consensus logP: arithmetic mean of the three internal logP models.
    pub logp: f64,
    /// Population standard deviation of the three logP model outputs.
    /// This is a model-disagreement (uncertainty) signal.
    pub logp_std: f64,
    /// Aqueous solubility from a linear ESOL-form equation.
    /// Presumably log10(mol/L), as in ESOL; confirm.
    pub log_solubility: f64,
    /// Topological polar surface area summed over N, O, P and S atoms.
    /// Presumably in square angstroms, following Ertl's contributions; confirm.
    pub tpsa: f64,
    /// Heuristic acidic pKa. `None` when the molecule has no H-bond donors.
    pub pka_acidic: Option<f64>,
    /// Heuristic basic pKa. `None` when the basicity heuristic does not fire.
    pub pka_basic: Option<f64>,
    /// Veber oral-bioavailability rule checks.
    pub veber: VeberResult,
    /// `true` when `bbb_score` is strictly greater than 0.5.
    pub bbb_permeable: bool,
    /// Blood-brain-barrier permeability score, clamped to `[0, 1]`.
    /// Higher means more likely to be permeable.
    pub bbb_score: f64,
    /// Overall prediction confidence, clamped to `[0, 1]`.
    /// Mixes logP model agreement (weight 0.7) with a molecule-size term (weight 0.3).
    pub confidence: f64,
}
37
/// Outcome of the Veber oral-bioavailability rules.
///
/// [`predict_ensemble`] fills this using the thresholds TPSA <= 140 and at most
/// 10 rotatable bonds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VeberResult {
    /// `true` when TPSA is at most 140.
    pub tpsa_ok: bool,
    /// `true` when the rotatable-bond count is at most 10.
    pub rotb_ok: bool,
    /// `true` only when both `tpsa_ok` and `rotb_ok` hold.
    pub passes: bool,
}
48
/// Returns the polar-surface-area contribution of a single atom.
///
/// Values follow the fragment table of Ertl, Rohde & Selzer (2000). Each value is
/// keyed only by element, heavy-atom degree and attached-hydrogen count. Bond
/// orders, aromaticity and formal charges are not available here, so distinct
/// environments that share a key get the value of the most common one.
///
/// # Arguments
/// * `z` - atomic number of the atom.
/// * `n_heavy_neighbors` - number of bonded non-hydrogen atoms.
/// * `n_h_neighbors` - number of bonded (explicit) hydrogen atoms.
///
/// # Returns
/// The contribution, or `0.0` for elements other than N, O, P and S.
/// Environments not listed fall back to a rough per-element default.
fn tpsa_contribution(z: u8, n_heavy_neighbors: usize, n_h_neighbors: usize) -> f64 {
    match z {
        // Nitrogen
        7 => match (n_heavy_neighbors, n_h_neighbors) {
            (1, 2) => 26.02, // -NH2
            (1, 1) => 23.85, // =NH: one heavy neighbour plus one H implies a double bond
            (2, 1) => 12.03, // -NH-
            (2, 0) => 12.36, // -N= (aromatic n also lands here)
            (3, 0) => 3.24,  // tertiary amine
            (1, 0) => 23.79, // nitrile N#
            _ => 12.0,
        },
        // Oxygen
        8 => match (n_heavy_neighbors, n_h_neighbors) {
            (0, 2) => 20.23, // water, treated like a hydroxyl
            (1, 1) => 20.23, // -OH
            (1, 0) => 17.07, // =O
            (2, 0) => 9.23,  // ether -O-
            _ => 15.0,
        },
        // Sulfur
        16 => match (n_heavy_neighbors, n_h_neighbors) {
            (1, 1) => 38.80, // -SH
            (1, 0) => 32.09, // =S
            (2, 0) => 25.30, // thioether -S-
            (3, 0) => 19.21, // sulfoxide S
            (4, 0) => 8.38,  // sulfone / sulfonamide S
            _ => 28.0,
        },
        // Phosphorus
        15 => match (n_heavy_neighbors, n_h_neighbors) {
            (3, 0) => 13.59, // phosphine
            (4, 0) => 9.81,  // phosphate / phosphine oxide P
            (3, 1) => 23.47, // H-phosphonate-like P
            _ => 34.14,      // -P= and anything unlisted (previous flat value)
        },
        _ => 0.0,
    }
}
83
84pub fn compute_tpsa(elements: &[u8], bonds: &[(usize, usize, u8)]) -> f64 {
88 let n = elements.len();
89 let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
90 for &(i, j, _) in bonds {
91 if i < n && j < n {
92 adj[i].push(j);
93 adj[j].push(i);
94 }
95 }
96
97 let mut tpsa = 0.0;
98 for i in 0..n {
99 if !matches!(elements[i], 7 | 8 | 15 | 16) {
100 continue;
101 }
102 let n_heavy = adj[i].iter().filter(|&&j| elements[j] != 1).count();
103 let n_h = adj[i].iter().filter(|&&j| elements[j] == 1).count();
104 tpsa += tpsa_contribution(elements[i], n_heavy, n_h);
105 }
106 tpsa
107}
108
109fn logp_model_1(desc: &MolecularDescriptors) -> f64 {
113 let base = 0.120 * desc.n_heavy_atoms as f64;
114 let h_corr = -0.230 * desc.n_hbd as f64;
115 let ring_corr = 0.150 * desc.n_rings as f64;
116 let arom_corr = 0.080 * desc.n_aromatic as f64;
117 let polar_corr = -0.310 * desc.n_hba as f64;
118 let sp3_corr = -0.180 * desc.fsp3;
119 let mw_term = 0.005 * (desc.molecular_weight - 100.0);
120 base + h_corr + ring_corr + arom_corr + polar_corr + sp3_corr + mw_term
121}
122
123fn logp_model_2(desc: &MolecularDescriptors, tpsa: f64) -> f64 {
125 let polarity_term = -0.015 * tpsa;
127 let size_term = 0.008 * desc.molecular_weight;
128 let hbond_term = -0.45 * (desc.n_hbd as f64 + 0.5 * desc.n_hba as f64);
129 let lipophilic = 0.22 * (desc.n_heavy_atoms as f64 - desc.n_hba as f64);
130
131 let nl_correction = if tpsa > 80.0 {
133 -0.003 * (tpsa - 80.0).powi(2) / 100.0
134 } else {
135 0.0
136 };
137
138 size_term + polarity_term + hbond_term + lipophilic + nl_correction - 1.5
139}
140
141fn logp_model_3(desc: &MolecularDescriptors) -> f64 {
143 let chi_approx = if desc.n_bonds > 0 {
144 (desc.n_bonds as f64).sqrt() / (desc.n_heavy_atoms as f64).max(1.0)
145 } else {
146 0.0
147 };
148
149 let wiener_term = if desc.wiener_index > 0.0 {
150 0.25 * desc.wiener_index.ln()
151 } else {
152 0.0
153 };
154
155 let polar_penalty = -0.35 * (desc.n_hbd + desc.n_hba) as f64;
156 let arom_bonus = 0.12 * desc.n_aromatic as f64;
157
158 chi_approx + wiener_term + polar_penalty + arom_bonus - 0.8
159}
160
161fn predict_pka_acidic(desc: &MolecularDescriptors, tpsa: f64) -> Option<f64> {
166 if desc.n_hbd == 0 {
167 return None;
168 }
169
170 let base_pka = if desc.n_hba >= 2 && desc.n_hbd >= 1 {
173 4.5
175 } else {
176 14.0
178 };
179
180 let ew_correction =
182 -0.3 * (desc.sum_electronegativity / desc.n_heavy_atoms.max(1) as f64 - 2.5);
183 let arom_correction = if desc.n_aromatic > 0 { -1.5 } else { 0.0 };
184 let tpsa_correction = -0.02 * (tpsa - 60.0);
185
186 Some((base_pka + ew_correction + arom_correction + tpsa_correction).clamp(0.0, 25.0))
187}
188
189fn predict_pka_basic(desc: &MolecularDescriptors) -> Option<f64> {
191 let has_nitrogen =
193 desc.n_hba > 0 && desc.sum_electronegativity / desc.n_heavy_atoms.max(1) as f64 > 2.6;
194 if !has_nitrogen {
195 return None;
196 }
197
198 let base_pka = if desc.n_aromatic > 0 {
200 5.2 } else {
202 10.6 };
204
205 let sp3_correction = 0.5 * desc.fsp3;
206 Some((base_pka + sp3_correction).clamp(0.0, 14.0))
207}
208
209fn predict_bbb(desc: &MolecularDescriptors, logp: f64, tpsa: f64) -> (bool, f64) {
215 let mut score = 1.0;
216
217 if desc.molecular_weight > 450.0 {
218 score -= 0.3 * ((desc.molecular_weight - 450.0) / 100.0).min(1.0);
219 }
220 if tpsa > 90.0 {
221 score -= 0.35 * ((tpsa - 90.0) / 50.0).min(1.0);
222 }
223 if logp < 1.0 {
224 score -= 0.2 * (1.0 - logp).min(1.0);
225 }
226 if logp > 5.0 {
227 score -= 0.2 * ((logp - 5.0) / 2.0).min(1.0);
228 }
229 if desc.n_hbd > 3 {
230 score -= 0.15 * (desc.n_hbd as f64 - 3.0).min(2.0) / 2.0;
231 }
232
233 let score = score.clamp(0.0, 1.0);
234 (score > 0.5, score)
235}
236
237pub fn predict_ensemble(
242 desc: &MolecularDescriptors,
243 elements: &[u8],
244 bonds: &[(usize, usize, u8)],
245) -> EnsembleResult {
246 let tpsa = compute_tpsa(elements, bonds);
247
248 let lp1 = logp_model_1(desc);
250 let lp2 = logp_model_2(desc, tpsa);
251 let lp3 = logp_model_3(desc);
252 let logp = (lp1 + lp2 + lp3) / 3.0;
253
254 let logp_std = {
256 let mean = logp;
257 let var = ((lp1 - mean).powi(2) + (lp2 - mean).powi(2) + (lp3 - mean).powi(2)) / 3.0;
258 var.sqrt()
259 };
260
261 let frac_aromatic = if desc.n_heavy_atoms > 0 {
263 desc.n_aromatic as f64 / desc.n_heavy_atoms as f64
264 } else {
265 0.0
266 };
267 let log_solubility = 0.16 - 0.63 * logp - 0.0062 * desc.molecular_weight
268 + 0.066 * desc.n_rotatable_bonds as f64
269 - 0.74 * frac_aromatic;
270
271 let pka_acidic = predict_pka_acidic(desc, tpsa);
273 let pka_basic = predict_pka_basic(desc);
274
275 let tpsa_ok = tpsa <= 140.0;
277 let rotb_ok = desc.n_rotatable_bonds <= 10;
278 let veber = VeberResult {
279 tpsa_ok,
280 rotb_ok,
281 passes: tpsa_ok && rotb_ok,
282 };
283
284 let (bbb_permeable, bbb_score) = predict_bbb(desc, logp, tpsa);
286
287 let model_agreement = 1.0 - (logp_std / 2.0).min(1.0);
289 let size_confidence = if desc.n_heavy_atoms >= 3 && desc.n_heavy_atoms <= 50 {
290 1.0
291 } else {
292 0.5
293 };
294 let confidence = (model_agreement * 0.7 + size_confidence * 0.3).clamp(0.0, 1.0);
295
296 EnsembleResult {
297 logp,
298 logp_std,
299 log_solubility,
300 tpsa,
301 pka_acidic,
302 pka_basic,
303 veber,
304 bbb_permeable,
305 bbb_score,
306 confidence,
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ml::descriptors::compute_descriptors;

    /// Explicit-hydrogen graph of propane (CH3-CH2-CH3): carbons 0..3, hydrogens 3..11.
    fn propane() -> (Vec<u8>, Vec<(usize, usize, u8)>) {
        let elements = vec![6u8, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1];
        let bonds = vec![
            (0, 1, 1),
            (1, 2, 1),
            (0, 3, 1),
            (0, 4, 1),
            (0, 5, 1),
            (1, 6, 1),
            (1, 7, 1),
            (2, 8, 1),
            (2, 9, 1),
            (2, 10, 1),
        ];
        (elements, bonds)
    }

    #[test]
    fn test_tpsa_water() {
        let elements = [8u8, 1, 1];
        let bonds = [(0usize, 1usize, 1u8), (0, 2, 1)];
        let tpsa = compute_tpsa(&elements, &bonds);
        assert!(tpsa > 15.0 && tpsa < 25.0, "Water TPSA: {tpsa}");
    }

    #[test]
    fn test_tpsa_ignores_out_of_range_bonds() {
        let elements = [8u8, 1, 1];
        let clean = [(0usize, 1usize, 1u8), (0, 2, 1)];
        let with_junk = [(0usize, 1usize, 1u8), (0, 2, 1), (0, 7, 1), (9, 2, 1)];
        assert_eq!(
            compute_tpsa(&elements, &clean),
            compute_tpsa(&elements, &with_junk)
        );
    }

    #[test]
    fn test_tpsa_benzene() {
        let elements = [6u8, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1];
        let bonds: Vec<(usize, usize, u8)> = vec![
            (0, 1, 2),
            (1, 2, 1),
            (2, 3, 2),
            (3, 4, 1),
            (4, 5, 2),
            (5, 0, 1),
            (0, 6, 1),
            (1, 7, 1),
            (2, 8, 1),
            (3, 9, 1),
            (4, 10, 1),
            (5, 11, 1),
        ];
        let tpsa = compute_tpsa(&elements, &bonds);
        assert!(tpsa.abs() < 1e-6, "Benzene TPSA should be 0: {tpsa}");
    }

    #[test]
    fn test_ensemble_ethanol() {
        let elements = [6u8, 6, 8, 1, 1, 1, 1, 1, 1];
        let bonds: Vec<(usize, usize, u8)> = vec![
            (0, 1, 1),
            (1, 2, 1),
            (0, 3, 1),
            (0, 4, 1),
            (0, 5, 1),
            (1, 6, 1),
            (1, 7, 1),
            (2, 8, 1),
        ];
        let desc = compute_descriptors(&elements, &bonds, &[], &[]);
        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(result.tpsa > 15.0, "Ethanol has polar O-H: {}", result.tpsa);
        assert!(result.logp < 2.0, "Ethanol is hydrophilic: {}", result.logp);
        assert!(result.logp_std >= 0.0, "Uncertainty must be non-negative");
        assert!(result.confidence > 0.0 && result.confidence <= 1.0);
        assert!(result.veber.passes);
    }

    #[test]
    fn test_ensemble_logp_consistency() {
        let (elements, bonds) = propane();
        let desc = compute_descriptors(&elements, &bonds, &[], &[]);
        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(
            result.logp_std < 2.0,
            "Models should broadly agree: std={}",
            result.logp_std
        );
    }

    #[test]
    fn test_veber_large_molecule() {
        let desc = MolecularDescriptors {
            molecular_weight: 600.0,
            n_heavy_atoms: 45,
            n_hydrogens: 20,
            n_bonds: 60,
            n_rotatable_bonds: 15,
            n_hbd: 5,
            n_hba: 10,
            fsp3: 0.3,
            total_abs_charge: 3.0,
            max_charge: 0.4,
            min_charge: -0.4,
            wiener_index: 3000.0,
            n_rings: 4,
            n_aromatic: 8,
            balaban_j: 1.8,
            sum_electronegativity: 120.0,
            sum_polarizability: 65.0,
        };
        let elements = [6u8; 45];
        let bonds: Vec<(usize, usize, u8)> = (0..44).map(|i| (i, i + 1, 1u8)).collect();
        let result = predict_ensemble(&desc, &elements, &bonds);

        assert!(
            !result.veber.rotb_ok,
            "Too many rotatable bonds: {}",
            desc.n_rotatable_bonds
        );
    }

    #[test]
    fn test_bbb_small_lipophilic() {
        let (elements, bonds) = propane();
        let desc = compute_descriptors(&elements, &bonds, &[], &[]);
        let result = predict_ensemble(&desc, &elements, &bonds);
        assert!(
            result.bbb_score > 0.0,
            "Small lipophilic molecule should have positive BBB score"
        );
    }

    #[test]
    fn test_pka_with_acid() {
        // Formic acid, HC(=O)OH: 3 heavy atoms, 2 hydrogens, 4 bonds, Wiener index 4.
        let desc = MolecularDescriptors {
            molecular_weight: 46.03,
            n_heavy_atoms: 3,
            n_hydrogens: 2,
            n_bonds: 4,
            n_rotatable_bonds: 0,
            n_hbd: 1,
            n_hba: 2,
            fsp3: 0.0,
            total_abs_charge: 0.5,
            max_charge: 0.2,
            min_charge: -0.3,
            wiener_index: 4.0,
            n_rings: 0,
            n_aromatic: 0,
            balaban_j: 1.0,
            sum_electronegativity: 8.0,
            sum_polarizability: 3.0,
        };
        // C0 =O1, C0 -O2, O2 -H3, C0 -H4 (carbon is properly tetravalent).
        let elements = [6u8, 8, 8, 1, 1];
        let bonds: Vec<(usize, usize, u8)> = vec![(0, 1, 2), (0, 2, 1), (2, 3, 1), (0, 4, 1)];
        let result = predict_ensemble(&desc, &elements, &bonds);
        assert!(
            result.pka_acidic.is_some(),
            "Carboxylic acid should have pKa"
        );
    }
}