// sci_form/ml/advanced_models.rs
1//! Advanced ML models: Random Forest, Gradient Boosting, and cross-validation.
2//!
3//! Pure-Rust implementations for molecular property prediction beyond
4//! simple linear models. Includes:
5//! - Random Forest (bagged decision trees)
6//! - Gradient Boosted Trees (GBM / GBRT)
7//! - K-fold cross-validation
8//! - Model recalibration via isotonic regression
9
10use serde::{Deserialize, Serialize};
11
12// ─── Decision Tree ───────────────────────────────────────────────────────────
13
/// A single decision tree node (binary split).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeNode {
    /// Terminal node: predicts a constant value (the mean of the training
    /// targets that reached it).
    Leaf {
        value: f64,
    },
    /// Internal node: routes a sample to `left` when
    /// `features[feature] <= threshold`, otherwise to `right`.
    Split {
        feature: usize,
        threshold: f64,
        left: Box<TreeNode>,
        right: Box<TreeNode>,
    },
}
27
28impl TreeNode {
29    /// Predict a single sample.
30    pub fn predict(&self, features: &[f64]) -> f64 {
31        match self {
32            TreeNode::Leaf { value } => *value,
33            TreeNode::Split {
34                feature,
35                threshold,
36                left,
37                right,
38            } => {
39                if features.get(*feature).copied().unwrap_or(0.0) <= *threshold {
40                    left.predict(features)
41                } else {
42                    right.predict(features)
43                }
44            }
45        }
46    }
47}
48
49/// Configuration for tree building.
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
51pub struct TreeConfig {
52    /// Maximum depth of the tree.
53    pub max_depth: usize,
54    /// Minimum samples to split a node.
55    pub min_samples_split: usize,
56    /// Minimum samples in a leaf.
57    pub min_samples_leaf: usize,
58    /// Number of features to consider at each split (0 = sqrt(n_features)).
59    pub max_features: usize,
60}
61
62impl Default for TreeConfig {
63    fn default() -> Self {
64        Self {
65            max_depth: 10,
66            min_samples_split: 5,
67            min_samples_leaf: 2,
68            max_features: 0,
69        }
70    }
71}
72
73/// Build a decision tree from data.
74pub fn build_tree(
75    features: &[Vec<f64>],
76    targets: &[f64],
77    config: &TreeConfig,
78    rng_seed: u64,
79) -> TreeNode {
80    let indices: Vec<usize> = (0..targets.len()).collect();
81    let n_features = features.first().map_or(0, |f| f.len());
82    let max_feat = if config.max_features == 0 {
83        (n_features as f64).sqrt().ceil() as usize
84    } else {
85        config.max_features.min(n_features)
86    }
87    .max(1);
88    build_tree_recursive(features, targets, &indices, config, max_feat, 0, rng_seed)
89}
90
/// Recursively grow a regression tree over the samples in `indices`.
///
/// Emits a mean-value leaf when the node is too small, too deep, or no
/// valid split exists; otherwise greedily picks the (feature, threshold)
/// pair minimising the summed squared error of the two children, searching
/// only a random subset of `max_features` features (random-forest style).
fn build_tree_recursive(
    features: &[Vec<f64>],
    targets: &[f64],
    indices: &[usize],
    config: &TreeConfig,
    max_features: usize,
    depth: usize,
    seed: u64,
) -> TreeNode {
    let n = indices.len();

    // Leaf conditions: too few samples to split, depth limit reached, or
    // not enough samples to satisfy min_samples_leaf on both children.
    if n < config.min_samples_split || depth >= config.max_depth || n < 2 * config.min_samples_leaf
    {
        // n.max(1) guards the division when `indices` is empty.
        let mean = indices.iter().map(|&i| targets[i]).sum::<f64>() / n.max(1) as f64;
        return TreeNode::Leaf { value: mean };
    }

    let n_features = features.first().map_or(0, |f| f.len());
    if n_features == 0 {
        // No features at all: nothing to split on.
        let mean = indices.iter().map(|&i| targets[i]).sum::<f64>() / n as f64;
        return TreeNode::Leaf { value: mean };
    }

    // Select random subset of features (deterministic for a given seed).
    let feature_subset = select_features(n_features, max_features, seed);

    // Find best split via exhaustive threshold search per candidate feature.
    let mut best_score = f64::INFINITY;
    let mut best_feature = 0;
    let mut best_threshold = 0.0;
    let mut best_left = vec![];
    let mut best_right = vec![];

    for &feat in &feature_subset {
        // Get unique sorted values for this feature
        let mut vals: Vec<f64> = indices.iter().map(|&i| features[i][feat]).collect();
        vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        vals.dedup();

        if vals.len() < 2 {
            // Feature is constant within this node: cannot split on it.
            continue;
        }

        // Try midpoints between consecutive unique values as thresholds.
        for w in vals.windows(2) {
            let threshold = (w[0] + w[1]) / 2.0;
            let (left, right): (Vec<usize>, Vec<usize>) = indices
                .iter()
                .partition(|&&i| features[i][feat] <= threshold);

            // Reject splits that would produce an undersized leaf.
            if left.len() < config.min_samples_leaf || right.len() < config.min_samples_leaf {
                continue;
            }

            // Split quality: total squared error of the two children
            // (lower is better; equivalent to maximising variance reduction).
            let score = mse_score(&left, targets) + mse_score(&right, targets);
            if score < best_score {
                best_score = score;
                best_feature = feat;
                best_threshold = threshold;
                best_left = left;
                best_right = right;
            }
        }
    }

    // No admissible split found anywhere: fall back to a leaf.
    if best_left.is_empty() || best_right.is_empty() {
        let mean = indices.iter().map(|&i| targets[i]).sum::<f64>() / n as f64;
        return TreeNode::Leaf { value: mean };
    }

    // Recurse with decorrelated child seeds (LCG-style multiplier, distinct
    // increments for the two children).
    let left_node = build_tree_recursive(
        features,
        targets,
        &best_left,
        config,
        max_features,
        depth + 1,
        seed.wrapping_mul(6364136223846793005).wrapping_add(1),
    );
    let right_node = build_tree_recursive(
        features,
        targets,
        &best_right,
        config,
        max_features,
        depth + 1,
        seed.wrapping_mul(6364136223846793005).wrapping_add(3),
    );

    TreeNode::Split {
        feature: best_feature,
        threshold: best_threshold,
        left: Box::new(left_node),
        right: Box::new(right_node),
    }
}
188
/// Total squared error of `targets[indices]` about their mean.
///
/// NOTE(review): despite the name this is the *sum* of squared deviations
/// (SSE), not the mean; callers only compare scores between candidate
/// splits, so the missing 1/n factor does not affect split selection.
fn mse_score(indices: &[usize], targets: &[f64]) -> f64 {
    if indices.is_empty() {
        return 0.0;
    }
    let count = indices.len() as f64;
    let mean = indices.iter().map(|&i| targets[i]).sum::<f64>() / count;
    indices
        .iter()
        .map(|&i| targets[i] - mean)
        .map(|d| d * d)
        .sum()
}
203
/// Draw `max_features` distinct feature indices pseudo-randomly, without
/// replacement, using an LCG stepped from `seed`.
///
/// When the request covers every feature, all indices are returned in
/// order without consuming any randomness.
fn select_features(n_features: usize, max_features: usize, seed: u64) -> Vec<usize> {
    if max_features >= n_features {
        return (0..n_features).collect();
    }

    let mut pool: Vec<usize> = (0..n_features).collect();
    let mut picked = Vec::with_capacity(max_features);
    let mut state = seed;
    while picked.len() < max_features && !pool.is_empty() {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Use the high bits: the LCG's low bits have short periods.
        let slot = (state >> 33) as usize % pool.len();
        picked.push(pool.swap_remove(slot));
    }
    picked
}
225
226// ─── Random Forest ───────────────────────────────────────────────────────────
227
/// Random Forest model: ensemble of bagged decision trees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomForest {
    /// The fitted trees; predictions average over all of them.
    pub trees: Vec<TreeNode>,
    /// Number of trees requested at training time.
    pub n_trees: usize,
    /// Out-of-bag score, if computed. NOTE(review): `train_random_forest`
    /// currently always leaves this as `None`.
    pub oob_score: Option<f64>,
}
235
236/// Configuration for Random Forest.
237#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
238pub struct RandomForestConfig {
239    pub n_trees: usize,
240    pub tree_config: TreeConfig,
241    /// Fraction of samples for each bootstrap (0.0–1.0). Default: 1.0
242    pub sample_fraction: f64,
243    pub seed: u64,
244}
245
246impl Default for RandomForestConfig {
247    fn default() -> Self {
248        Self {
249            n_trees: 100,
250            tree_config: TreeConfig::default(),
251            sample_fraction: 1.0,
252            seed: 42,
253        }
254    }
255}
256
257/// Train a Random Forest regressor.
258pub fn train_random_forest(
259    features: &[Vec<f64>],
260    targets: &[f64],
261    config: &RandomForestConfig,
262) -> RandomForest {
263    let n = targets.len();
264    let n_sample = ((n as f64 * config.sample_fraction).ceil() as usize).max(1);
265    let mut trees = Vec::with_capacity(config.n_trees);
266
267    for t in 0..config.n_trees {
268        let seed = config.seed.wrapping_add(t as u64 * 1000003);
269        // Bootstrap sample
270        let bootstrap = bootstrap_indices(n, n_sample, seed);
271        let boot_features: Vec<Vec<f64>> = bootstrap.iter().map(|&i| features[i].clone()).collect();
272        let boot_targets: Vec<f64> = bootstrap.iter().map(|&i| targets[i]).collect();
273        let tree = build_tree(&boot_features, &boot_targets, &config.tree_config, seed);
274        trees.push(tree);
275    }
276
277    RandomForest {
278        n_trees: config.n_trees,
279        trees,
280        oob_score: None,
281    }
282}
283
284impl RandomForest {
285    /// Predict a single sample (mean of all trees).
286    pub fn predict(&self, features: &[f64]) -> f64 {
287        let sum: f64 = self.trees.iter().map(|t| t.predict(features)).sum();
288        sum / self.trees.len().max(1) as f64
289    }
290
291    /// Predict and return individual tree predictions (for uncertainty).
292    pub fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
293        let preds: Vec<f64> = self.trees.iter().map(|t| t.predict(features)).collect();
294        let mean = preds.iter().sum::<f64>() / preds.len().max(1) as f64;
295        let var =
296            preds.iter().map(|p| (p - mean) * (p - mean)).sum::<f64>() / preds.len().max(1) as f64;
297        (mean, var)
298    }
299}
300
/// Draw `n_sample` indices uniformly from `0..n` with replacement, using
/// an LCG stepped from `seed` (high bits taken for better distribution).
///
/// Returns an empty vector when `n == 0`; the previous `% n` would panic
/// with a divide-by-zero.
fn bootstrap_indices(n: usize, n_sample: usize, seed: u64) -> Vec<usize> {
    if n == 0 {
        return Vec::new();
    }
    let mut indices = Vec::with_capacity(n_sample);
    let mut s = seed;
    for _ in 0..n_sample {
        s = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        indices.push((s >> 33) as usize % n);
    }
    indices
}
312
313// ─── Gradient Boosting ───────────────────────────────────────────────────────
314
/// Gradient Boosted Trees regressor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientBoosting {
    /// One regression tree per boosting stage, each fit to residuals.
    pub trees: Vec<TreeNode>,
    /// Shrinkage applied to every stage's contribution.
    pub learning_rate: f64,
    /// Base prediction (the training-target mean) before any stage.
    pub initial_value: f64,
    /// Number of boosting stages requested at training time.
    pub n_estimators: usize,
}
323
324/// Configuration for Gradient Boosting.
325#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
326pub struct GradientBoostingConfig {
327    pub n_estimators: usize,
328    pub learning_rate: f64,
329    pub tree_config: TreeConfig,
330    pub subsample: f64,
331    pub seed: u64,
332}
333
334impl Default for GradientBoostingConfig {
335    fn default() -> Self {
336        Self {
337            n_estimators: 100,
338            learning_rate: 0.1,
339            tree_config: TreeConfig {
340                max_depth: 5,
341                min_samples_split: 5,
342                min_samples_leaf: 2,
343                max_features: 0,
344            },
345            subsample: 0.8,
346            seed: 42,
347        }
348    }
349}
350
351/// Train a Gradient Boosting regressor (L2 loss / least squares).
352pub fn train_gradient_boosting(
353    features: &[Vec<f64>],
354    targets: &[f64],
355    config: &GradientBoostingConfig,
356) -> GradientBoosting {
357    let n = targets.len();
358    let initial_value = targets.iter().sum::<f64>() / n.max(1) as f64;
359    let mut predictions = vec![initial_value; n];
360    let mut trees = Vec::with_capacity(config.n_estimators);
361
362    for t in 0..config.n_estimators {
363        // Compute negative gradient (residuals for L2 loss)
364        let residuals: Vec<f64> = (0..n).map(|i| targets[i] - predictions[i]).collect();
365
366        // Subsample
367        let seed = config.seed.wrapping_add(t as u64 * 999983);
368        let n_sub = ((n as f64 * config.subsample).ceil() as usize).max(1);
369        let sub_idx = bootstrap_indices(n, n_sub, seed);
370
371        let sub_features: Vec<Vec<f64>> = sub_idx.iter().map(|&i| features[i].clone()).collect();
372        let sub_residuals: Vec<f64> = sub_idx.iter().map(|&i| residuals[i]).collect();
373
374        let tree = build_tree(&sub_features, &sub_residuals, &config.tree_config, seed);
375
376        // Update predictions
377        for i in 0..n {
378            predictions[i] += config.learning_rate * tree.predict(&features[i]);
379        }
380
381        trees.push(tree);
382    }
383
384    GradientBoosting {
385        trees,
386        learning_rate: config.learning_rate,
387        initial_value,
388        n_estimators: config.n_estimators,
389    }
390}
391
392impl GradientBoosting {
393    /// Predict a single sample.
394    pub fn predict(&self, features: &[f64]) -> f64 {
395        self.initial_value
396            + self.learning_rate * self.trees.iter().map(|t| t.predict(features)).sum::<f64>()
397    }
398}
399
400// ─── Cross-Validation ────────────────────────────────────────────────────────
401
/// Result of k-fold cross-validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResult {
    /// Mean absolute error per fold.
    pub fold_mae: Vec<f64>,
    /// Root mean squared error per fold.
    pub fold_rmse: Vec<f64>,
    /// R² per fold.
    pub fold_r2: Vec<f64>,
    /// Overall MAE (unweighted mean across folds).
    pub mean_mae: f64,
    /// Overall RMSE (unweighted mean across folds).
    pub mean_rmse: f64,
    /// Overall R² (unweighted mean across folds).
    pub mean_r2: f64,
    /// Number of folds actually used (after clamping).
    pub k: usize,
}
420
/// Model type for cross-validation.
pub enum ModelType {
    /// Bagged decision trees, trained via [`train_random_forest`].
    RandomForest(RandomForestConfig),
    /// Boosted residual trees, trained via [`train_gradient_boosting`].
    GradientBoosting(GradientBoostingConfig),
}
426
427/// Perform k-fold cross-validation.
428pub fn cross_validate(
429    features: &[Vec<f64>],
430    targets: &[f64],
431    model_type: &ModelType,
432    k: usize,
433    seed: u64,
434) -> CrossValidationResult {
435    let n = targets.len();
436    let k = k.max(2).min(n);
437
438    // Create fold assignments (shuffled)
439    let mut indices: Vec<usize> = (0..n).collect();
440    // Fisher-Yates shuffle
441    let mut s = seed;
442    for i in (1..n).rev() {
443        s = s
444            .wrapping_mul(6364136223846793005)
445            .wrapping_add(1442695040888963407);
446        let j = (s >> 33) as usize % (i + 1);
447        indices.swap(i, j);
448    }
449
450    let fold_size = n / k;
451    let mut fold_mae = Vec::with_capacity(k);
452    let mut fold_rmse = Vec::with_capacity(k);
453    let mut fold_r2 = Vec::with_capacity(k);
454
455    for fold in 0..k {
456        let test_start = fold * fold_size;
457        let test_end = if fold == k - 1 {
458            n
459        } else {
460            (fold + 1) * fold_size
461        };
462
463        let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
464        let train_indices: Vec<usize> = indices[..test_start]
465            .iter()
466            .chain(indices[test_end..].iter())
467            .copied()
468            .collect();
469
470        let train_features: Vec<Vec<f64>> =
471            train_indices.iter().map(|&i| features[i].clone()).collect();
472        let train_targets: Vec<f64> = train_indices.iter().map(|&i| targets[i]).collect();
473
474        // Train model
475        let predictions: Vec<f64> = match model_type {
476            ModelType::RandomForest(config) => {
477                let model = train_random_forest(&train_features, &train_targets, config);
478                test_indices
479                    .iter()
480                    .map(|&i| model.predict(&features[i]))
481                    .collect()
482            }
483            ModelType::GradientBoosting(config) => {
484                let model = train_gradient_boosting(&train_features, &train_targets, config);
485                test_indices
486                    .iter()
487                    .map(|&i| model.predict(&features[i]))
488                    .collect()
489            }
490        };
491
492        let test_targets: Vec<f64> = test_indices.iter().map(|&i| targets[i]).collect();
493        let (mae, rmse, r2) = compute_metrics(&predictions, &test_targets);
494        fold_mae.push(mae);
495        fold_rmse.push(rmse);
496        fold_r2.push(r2);
497    }
498
499    let mean_mae = fold_mae.iter().sum::<f64>() / k as f64;
500    let mean_rmse = fold_rmse.iter().sum::<f64>() / k as f64;
501    let mean_r2 = fold_r2.iter().sum::<f64>() / k as f64;
502
503    CrossValidationResult {
504        fold_mae,
505        fold_rmse,
506        fold_r2,
507        mean_mae,
508        mean_rmse,
509        mean_r2,
510        k,
511    }
512}
513
/// MAE, RMSE and R² of `predictions` against `targets`, in that order.
///
/// Returns (0.0, 0.0, 0.0) for empty input. R² is reported as 0.0 when
/// the target variance is numerically zero (constant targets).
fn compute_metrics(predictions: &[f64], targets: &[f64]) -> (f64, f64, f64) {
    if predictions.is_empty() {
        return (0.0, 0.0, 0.0);
    }
    let n = predictions.len() as f64;
    let mean_t = targets.iter().sum::<f64>() / n;

    let mut sum_ae = 0.0;
    let mut sum_se = 0.0;
    let mut ss_tot = 0.0;
    for i in 0..predictions.len() {
        let err = predictions[i] - targets[i];
        let dev = targets[i] - mean_t;
        sum_ae += err.abs();
        sum_se += err * err;
        ss_tot += dev * dev;
    }

    let r2 = if ss_tot > 1e-12 {
        1.0 - sum_se / ss_tot
    } else {
        0.0
    };
    (sum_ae / n, (sum_se / n).sqrt(), r2)
}
542
543// ─── Isotonic Regression Recalibration ───────────────────────────────────────
544
/// Isotonic regression for model recalibration.
/// Fits a monotone non-decreasing function to prediction-target pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IsotonicCalibrator {
    /// Knot x-positions: sorted prediction values, one per pooled block.
    pub knots_x: Vec<f64>,
    /// Knot y-values: the (non-decreasing) pooled target means.
    pub knots_y: Vec<f64>,
}
552
impl IsotonicCalibrator {
    /// Fit isotonic regression from (prediction, target) pairs.
    ///
    /// Uses the pool-adjacent-violators algorithm (PAVA): pairs are sorted
    /// by prediction, then adjacent blocks whose means decrease are merged
    /// (weighted by block size) until the block means are non-decreasing.
    /// If the slices differ in length, the extra tail is ignored.
    pub fn fit(predictions: &[f64], targets: &[f64]) -> Self {
        let n = predictions.len().min(targets.len());
        if n == 0 {
            return Self {
                knots_x: vec![],
                knots_y: vec![],
            };
        }

        // Sort by prediction. NaN predictions compare as Equal and keep
        // their relative order.
        let mut pairs: Vec<(f64, f64)> = predictions[..n]
            .iter()
            .zip(targets[..n].iter())
            .map(|(&p, &t)| (p, t))
            .collect();
        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        // Pool adjacent violators (PAVA)
        let mut blocks: Vec<(f64, f64, usize)> = pairs.iter().map(|&(x, y)| (x, y, 1)).collect(); // (x, y_mean, count)

        // Invariant on exit: blocks[i].1 <= blocks[i+1].1 for all i.
        // A merged block keeps the left block's x as its knot position.
        let mut i = 0;
        while i < blocks.len() - 1 {
            if blocks[i].1 > blocks[i + 1].1 {
                // Violation: merge blocks i and i+1 (count-weighted mean).
                let n1 = blocks[i].2 as f64;
                let n2 = blocks[i + 1].2 as f64;
                let merged_y = (n1 * blocks[i].1 + n2 * blocks[i + 1].1) / (n1 + n2);
                blocks[i].1 = merged_y;
                blocks[i].2 += blocks[i + 1].2;
                blocks.remove(i + 1);
                // Step back: the merge may have created a new violation
                // against the previous block, which must be re-checked.
                if i > 0 {
                    i = i.saturating_sub(1);
                }
            } else {
                i += 1;
            }
        }

        let knots_x: Vec<f64> = blocks.iter().map(|b| b.0).collect();
        let knots_y: Vec<f64> = blocks.iter().map(|b| b.1).collect();

        Self { knots_x, knots_y }
    }

    /// Calibrate a prediction using the fitted isotonic function.
    ///
    /// Clamps to the first/last knot value outside the fitted range and
    /// linearly interpolates between the bracketing knots inside it. With
    /// no knots (empty fit) the prediction is returned unchanged.
    pub fn calibrate(&self, prediction: f64) -> f64 {
        if self.knots_x.is_empty() {
            return prediction;
        }
        if prediction <= self.knots_x[0] {
            return self.knots_y[0];
        }
        if prediction >= *self.knots_x.last().unwrap() {
            return *self.knots_y.last().unwrap();
        }

        // Linear interpolation between knots
        for i in 0..self.knots_x.len() - 1 {
            if prediction >= self.knots_x[i] && prediction <= self.knots_x[i + 1] {
                let t = (prediction - self.knots_x[i]) / (self.knots_x[i + 1] - self.knots_x[i]);
                return self.knots_y[i] + t * (self.knots_y[i + 1] - self.knots_y[i]);
            }
        }

        // Only reachable for inputs that defeat the bracket search
        // (e.g. NaN); fall back to the raw prediction.
        prediction
    }
}
622
#[cfg(test)]
mod tests {
    use super::*;

    // A single tree on a 1-D monotone dataset should roughly memorise it.
    #[test]
    fn test_decision_tree() {
        let features = vec![
            vec![1.0],
            vec![2.0],
            vec![3.0],
            vec![4.0],
            vec![5.0],
            vec![6.0],
            vec![7.0],
            vec![8.0],
        ];
        let targets = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let config = TreeConfig {
            max_depth: 5,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: 0,
        };
        let tree = build_tree(&features, &targets, &config, 42);

        // Should predict close to the training values
        let pred = tree.predict(&[4.0]);
        assert!((pred - 4.0).abs() < 2.0);
    }

    // Forest training: correct tree count, sane mean/variance output.
    #[test]
    fn test_random_forest() {
        let features: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64]).collect();
        let targets: Vec<f64> = (0..20).map(|i| (i as f64) * 2.0 + 1.0).collect();

        let config = RandomForestConfig {
            n_trees: 10,
            seed: 42,
            ..Default::default()
        };
        let model = train_random_forest(&features, &targets, &config);
        assert_eq!(model.trees.len(), 10);

        let (pred, var) = model.predict_with_variance(&[10.0]);
        assert!(pred > 0.0);
        assert!(var >= 0.0);
    }

    // Boosting on a linear target: prediction should be positive in-range.
    #[test]
    fn test_gradient_boosting() {
        let features: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64]).collect();
        let targets: Vec<f64> = (0..20).map(|i| (i as f64) * 2.0 + 1.0).collect();

        let config = GradientBoostingConfig {
            n_estimators: 20,
            learning_rate: 0.3,
            seed: 42,
            ..Default::default()
        };
        let model = train_gradient_boosting(&features, &targets, &config);
        let pred = model.predict(&[10.0]);
        assert!(pred > 0.0);
    }

    // 5-fold CV over 30 samples: one metric entry per fold.
    #[test]
    fn test_cross_validation() {
        let features: Vec<Vec<f64>> = (0..30).map(|i| vec![i as f64]).collect();
        let targets: Vec<f64> = (0..30).map(|i| (i as f64) * 2.0).collect();

        let config = RandomForestConfig {
            n_trees: 5,
            seed: 42,
            ..Default::default()
        };
        let cv = cross_validate(&features, &targets, &ModelType::RandomForest(config), 5, 42);
        assert_eq!(cv.k, 5);
        assert_eq!(cv.fold_mae.len(), 5);
    }

    // PAVA fit on nearly-monotone data; interpolated value near target.
    #[test]
    fn test_isotonic_calibrator() {
        let preds = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let targets = vec![1.1, 2.5, 2.8, 4.2, 5.1];

        let cal = IsotonicCalibrator::fit(&preds, &targets);
        let calibrated = cal.calibrate(3.0);
        // Should return a value near 2.8 (the fitted value at x=3)
        assert!((calibrated - 2.8).abs() < 1.0);
    }
}