// scry_learn/tree/random_forest.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Random Forest — parallel ensemble of CART decision trees.
3//!
4//! Uses bootstrap sampling and random feature subsets for each tree,
5//! trained in parallel via rayon.
6
7use crate::dataset::Dataset;
8use crate::error::{Result, ScryLearnError};
9use crate::tree::cart::{DecisionTreeClassifier, DecisionTreeRegressor};
10use crate::weights::ClassWeight;
11use rayon::prelude::*;
12
/// Strategy for selecting the number of features per split.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum MaxFeatures {
    /// `√n_features` (default for classification).
    Sqrt,
    /// `log₂(n_features)`.
    Log2,
    /// Use all features (no bagging).
    All,
    /// A fixed count.
    Fixed(usize),
}

impl MaxFeatures {
    /// Resolve the strategy into a concrete feature count.
    ///
    /// The result is clamped to at least 1 (so `Log2` on a single feature,
    /// or `Fixed(0)`, still yields one candidate feature) and `Fixed(n)` is
    /// capped at `n_features`.
    fn resolve(self, n_features: usize) -> usize {
        let requested = match self {
            Self::Sqrt => (n_features as f64).sqrt().ceil() as usize,
            Self::Log2 => (n_features as f64).log2().ceil() as usize,
            Self::All => n_features,
            Self::Fixed(n) => n.min(n_features),
        };
        requested.max(1)
    }
}
39
40// ---------------------------------------------------------------------------
41// Random Forest Classifier
42// ---------------------------------------------------------------------------
43
/// Random Forest for classification.
///
/// Trains an ensemble of decision trees in parallel, each on a bootstrap
/// sample with a random subset of features. Predictions are by majority vote.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RandomForestClassifier {
    /// Number of trees to train.
    n_estimators: usize,
    /// Maximum depth per tree; `None` means unlimited.
    max_depth: Option<usize>,
    /// Strategy for the number of candidate features per split.
    max_features: MaxFeatures,
    /// Minimum samples required to split an internal node.
    min_samples_split: usize,
    /// Minimum samples required at a leaf.
    min_samples_leaf: usize,
    /// Whether each tree trains on a bootstrap resample of the data.
    bootstrap: bool,
    /// Base RNG seed; tree `i` is seeded with `seed.wrapping_add(i)`.
    seed: u64,
    /// Class weighting strategy for imbalanced datasets.
    class_weight: ClassWeight,
    /// Trained trees; empty until `fit` succeeds.
    trees: Vec<DecisionTreeClassifier>,
    /// Number of classes observed at fit time.
    n_classes: usize,
    /// Number of features observed at fit time.
    n_features: usize,
    /// Feature importances averaged across trees (filled by `fit`).
    feature_importances_: Vec<f64>,
    /// Out-of-bag accuracy; `Some` only after `fit` with `bootstrap = true`.
    oob_score_: Option<f64>,
    /// Serialization schema version, checked on `predict`.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
68
69impl RandomForestClassifier {
70    /// Create a new random forest with default parameters.
71    pub fn new() -> Self {
72        Self {
73            n_estimators: 100,
74            max_depth: None,
75            max_features: MaxFeatures::Sqrt,
76            min_samples_split: 2,
77            min_samples_leaf: 1,
78            bootstrap: true,
79            seed: 42,
80            class_weight: ClassWeight::Uniform,
81            trees: Vec::new(),
82            n_classes: 0,
83            n_features: 0,
84            feature_importances_: Vec::new(),
85            oob_score_: None,
86            _schema_version: crate::version::SCHEMA_VERSION,
87        }
88    }
89
90    /// Set number of trees.
91    pub fn n_estimators(mut self, n: usize) -> Self {
92        self.n_estimators = n;
93        self
94    }
95
96    /// Set maximum depth per tree.
97    pub fn max_depth(mut self, d: usize) -> Self {
98        self.max_depth = Some(d);
99        self
100    }
101
102    /// Set feature selection strategy.
103    pub fn max_features(mut self, mf: MaxFeatures) -> Self {
104        self.max_features = mf;
105        self
106    }
107
108    /// Set minimum samples to split.
109    pub fn min_samples_split(mut self, n: usize) -> Self {
110        self.min_samples_split = n;
111        self
112    }
113
114    /// Set minimum samples per leaf.
115    pub fn min_samples_leaf(mut self, n: usize) -> Self {
116        self.min_samples_leaf = n;
117        self
118    }
119
120    /// Enable/disable bootstrap sampling.
121    pub fn bootstrap(mut self, b: bool) -> Self {
122        self.bootstrap = b;
123        self
124    }
125
126    /// Set the random seed.
127    pub fn seed(mut self, s: u64) -> Self {
128        self.seed = s;
129        self
130    }
131
132    /// Set class weighting strategy for imbalanced datasets.
133    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
134        self.class_weight = cw;
135        self
136    }
137
138    /// Train the random forest.
139    ///
140    /// OOB votes are accumulated into a shared atomic array during parallel build,
141    /// avoiding retention of per-tree vote arrays or bootstrap indices.
142    /// Dataset indices are pre-sorted once and shared across all trees to avoid
143    /// per-tree `sorted_by_feature` allocation (~6 MB savings with 16 threads).
144    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
145        data.validate_finite()?;
146        use std::sync::atomic::{AtomicU32, Ordering};
147
148        if data.n_samples() == 0 {
149            return Err(ScryLearnError::EmptyDataset);
150        }
151
152        self.n_features = data.n_features();
153        self.n_classes = data.n_classes();
154        let max_feats = self.max_features.resolve(self.n_features);
155        let do_bootstrap = self.bootstrap;
156        let n_samples = data.n_samples();
157        let n_classes = self.n_classes;
158        let feature_matrix = data.feature_matrix();
159        let n_features = data.n_features();
160
161        // Pre-sort ALL dataset indices by each feature once (shared read-only).
162        // Each tree filters via membership bitset for its bootstrap sample.
163        let global_sorted: Vec<Vec<usize>> = (0..n_features)
164            .map(|feat_idx| {
165                let col = &data.features[feat_idx];
166                let mut sorted: Vec<usize> = (0..n_samples).collect();
167                sorted.sort_unstable_by(|&a, &b| {
168                    col[a]
169                        .partial_cmp(&col[b])
170                        .unwrap_or(std::cmp::Ordering::Equal)
171                });
172                sorted
173            })
174            .collect();
175        let global_sorted_ref = &global_sorted;
176
177        // Shared OOB accumulator: oob_votes[sample * n_classes + class].
178        // Atomic u32 so multiple threads can update without locking.
179        let oob_votes: Vec<AtomicU32> = (0..n_samples * n_classes)
180            .map(|_| AtomicU32::new(0))
181            .collect();
182        let oob_votes_ref = &oob_votes;
183
184        // Train trees in parallel. OOB votes are merged directly into the
185        // shared accumulator — no per-tree vote arrays are ever stored.
186        let mut trees: Vec<DecisionTreeClassifier> = (0..self.n_estimators)
187            .into_par_iter()
188            .map(|tree_idx| {
189                let mut rng = crate::rng::FastRng::new(self.seed.wrapping_add(tree_idx as u64));
190                let n = n_samples;
191
192                // Bootstrap sample.
193                let indices: Vec<usize> = if do_bootstrap {
194                    (0..n).map(|_| rng.usize(0..n)).collect()
195                } else {
196                    (0..n).collect()
197                };
198
199                let mut tree = DecisionTreeClassifier::new()
200                    .max_features(max_feats)
201                    .min_samples_split(self.min_samples_split)
202                    .min_samples_leaf(self.min_samples_leaf)
203                    .class_weight(self.class_weight.clone());
204
205                if let Some(d) = self.max_depth {
206                    tree = tree.max_depth(d);
207                }
208
209                // Train using shared pre-sorted indices — no per-tree sort allocation.
210                tree.fit_on_indices_presorted(data, &indices, global_sorted_ref)
211                    .ok();
212
213                // Compute OOB votes inline and merge into shared accumulator.
214                // Bootstrap indices and bitset are dropped at end of closure.
215                if do_bootstrap {
216                    if let Some(ref ft) = tree.flat_tree {
217                        // Build compact bitset of in-bag samples.
218                        let n_words = n.div_ceil(64);
219                        let mut in_bag = vec![0u64; n_words];
220                        for &idx in &indices {
221                            in_bag[idx / 64] |= 1u64 << (idx % 64);
222                        }
223
224                        // Vote for OOB samples, merging directly into shared accumulator.
225                        for sample_idx in 0..n {
226                            if in_bag[sample_idx / 64] & (1u64 << (sample_idx % 64)) != 0 {
227                                continue;
228                            }
229                            let pred = ft.predict_sample(&feature_matrix[sample_idx]) as usize;
230                            if pred < n_classes {
231                                oob_votes_ref[sample_idx * n_classes + pred]
232                                    .fetch_add(1, Ordering::Relaxed);
233                            }
234                        }
235                    }
236                }
237
238                tree
239            })
240            .collect();
241
242        // Aggregate feature importances.
243        self.feature_importances_ = vec![0.0; self.n_features];
244        for tree in &trees {
245            if let Ok(imp) = tree.feature_importances() {
246                for (i, &v) in imp.iter().enumerate() {
247                    self.feature_importances_[i] += v;
248                }
249            }
250        }
251        let n_trees = trees.len() as f64;
252        for imp in &mut self.feature_importances_ {
253            *imp /= n_trees;
254        }
255
256        // Compute OOB accuracy from accumulated atomic votes.
257        self.oob_score_ = if do_bootstrap {
258            // Convert atomics to plain u32 for scoring.
259            let totals: Vec<u32> = oob_votes
260                .iter()
261                .map(|a| a.load(Ordering::Relaxed))
262                .collect();
263            Self::oob_accuracy_from_votes(&totals, n_samples, n_classes, &data.target)
264        } else {
265            None
266        };
267
268        // Clear per-tree training-only data to save memory.
269        for tree in &mut trees {
270            tree.sample_weights = None;
271            tree.feature_importances_ = Vec::new();
272        }
273
274        self.trees = trees;
275        Ok(())
276    }
277
278    /// Compute OOB accuracy from flat vote accumulation array.
279    fn oob_accuracy_from_votes(
280        oob_total: &[u32],
281        n_samples: usize,
282        n_classes: usize,
283        target: &[f64],
284    ) -> Option<f64> {
285        let mut correct = 0usize;
286        let mut total = 0usize;
287        for sample_idx in 0..n_samples {
288            let row = &oob_total[sample_idx * n_classes..(sample_idx + 1) * n_classes];
289            let vote_count: u32 = row.iter().sum();
290            if vote_count == 0 {
291                continue;
292            }
293            let predicted_class = row
294                .iter()
295                .enumerate()
296                .max_by_key(|&(_, &v)| v)
297                .map_or(0, |(idx, _)| idx);
298            let true_class = target[sample_idx] as usize;
299            if predicted_class == true_class {
300                correct += 1;
301            }
302            total += 1;
303        }
304
305        if total > 0 {
306            Some(correct as f64 / total as f64)
307        } else {
308            None
309        }
310    }
311
312    /// Predict class labels by majority vote.
313    ///
314    /// Uses `FlatTree::predict_sample` for cache-optimal traversal.
315    /// Parallelized across samples via rayon.
316    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
317        crate::version::check_schema_version(self._schema_version)?;
318        if self.trees.is_empty() {
319            return Err(ScryLearnError::NotFitted);
320        }
321
322        let n_classes = self.n_classes;
323        let predictions: Vec<f64> = features
324            .par_iter()
325            .map(|sample| {
326                let mut votes = vec![0usize; n_classes];
327                for tree in &self.trees {
328                    if let Some(ref ft) = tree.flat_tree {
329                        let class = ft.predict_sample(sample) as usize;
330                        if class < n_classes {
331                            votes[class] += 1;
332                        }
333                    }
334                }
335                votes
336                    .iter()
337                    .enumerate()
338                    .max_by_key(|&(_, &v)| v)
339                    .map_or(0.0, |(idx, _)| idx as f64)
340            })
341            .collect();
342
343        Ok(predictions)
344    }
345
346    /// Predict class probabilities (average across trees).
347    ///
348    /// Parallelized across samples via rayon.
349    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
350        if self.trees.is_empty() {
351            return Err(ScryLearnError::NotFitted);
352        }
353
354        let n_classes = self.n_classes;
355        let n_trees = self.trees.len() as f64;
356
357        let probas: Vec<Vec<f64>> = features
358            .par_iter()
359            .map(|sample| {
360                let mut proba = vec![0.0; n_classes];
361                for tree in &self.trees {
362                    if let Some(ref ft) = tree.flat_tree {
363                        let tree_proba = ft.predict_proba_sample(sample, n_classes);
364                        for (j, p) in tree_proba.into_iter().enumerate() {
365                            if j < n_classes {
366                                proba[j] += p;
367                            }
368                        }
369                    }
370                }
371                for p in &mut proba {
372                    *p /= n_trees;
373                }
374                proba
375            })
376            .collect();
377
378        Ok(probas)
379    }
380
381    /// Feature importances averaged across all trees.
382    pub fn feature_importances(&self) -> Result<Vec<f64>> {
383        if self.trees.is_empty() {
384            return Err(ScryLearnError::NotFitted);
385        }
386        Ok(self.feature_importances_.clone())
387    }
388
389    /// Out-of-bag accuracy score (available after fit with bootstrap=true).
390    pub fn oob_score(&self) -> Option<f64> {
391        self.oob_score_
392    }
393
394    /// Number of trained trees.
395    pub fn n_trees(&self) -> usize {
396        self.trees.len()
397    }
398
399    /// Get individual trees (for visualization or inspection).
400    pub fn trees(&self) -> &[DecisionTreeClassifier] {
401        &self.trees
402    }
403
404    /// Number of classes the model was trained on.
405    pub fn n_classes(&self) -> usize {
406        self.n_classes
407    }
408
409    /// Number of features the model was trained on.
410    pub fn n_features(&self) -> usize {
411        self.n_features
412    }
413}
414
impl Default for RandomForestClassifier {
    /// Equivalent to [`RandomForestClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
420
421// ---------------------------------------------------------------------------
422// Random Forest Regressor
423// ---------------------------------------------------------------------------
424
/// Random Forest for regression (mean of tree predictions).
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RandomForestRegressor {
    /// Number of trees to train.
    n_estimators: usize,
    /// Maximum depth per tree; `None` means unlimited.
    max_depth: Option<usize>,
    /// Strategy for the number of candidate features per split.
    max_features: MaxFeatures,
    /// Minimum samples required to split an internal node.
    min_samples_split: usize,
    /// Minimum samples required at a leaf.
    min_samples_leaf: usize,
    /// Whether each tree trains on a bootstrap resample of the data.
    bootstrap: bool,
    /// Base RNG seed; tree `i` is seeded with `seed.wrapping_add(i)`.
    seed: u64,
    /// Trained trees; empty until `fit` succeeds.
    trees: Vec<DecisionTreeRegressor>,
    /// Number of features observed at fit time.
    n_features: usize,
    /// Feature importances averaged across trees (filled by `fit`).
    feature_importances_: Vec<f64>,
    /// Serialization schema version, checked on `predict`.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
443
444impl RandomForestRegressor {
445    /// Create a new regressor forest.
446    pub fn new() -> Self {
447        Self {
448            n_estimators: 100,
449            max_depth: None,
450            max_features: MaxFeatures::All,
451            min_samples_split: 2,
452            min_samples_leaf: 1,
453            bootstrap: true,
454            seed: 42,
455            trees: Vec::new(),
456            n_features: 0,
457            feature_importances_: Vec::new(),
458            _schema_version: crate::version::SCHEMA_VERSION,
459        }
460    }
461
462    /// Set number of trees.
463    pub fn n_estimators(mut self, n: usize) -> Self {
464        self.n_estimators = n;
465        self
466    }
467
468    /// Set maximum depth.
469    pub fn max_depth(mut self, d: usize) -> Self {
470        self.max_depth = Some(d);
471        self
472    }
473
474    /// Set feature selection strategy.
475    pub fn max_features(mut self, mf: MaxFeatures) -> Self {
476        self.max_features = mf;
477        self
478    }
479
480    /// Set random seed.
481    pub fn seed(mut self, s: u64) -> Self {
482        self.seed = s;
483        self
484    }
485
486    /// Train the forest.
487    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
488        data.validate_finite()?;
489        if data.n_samples() == 0 {
490            return Err(ScryLearnError::EmptyDataset);
491        }
492
493        self.n_features = data.n_features();
494        let max_feats = self.max_features.resolve(self.n_features);
495
496        let mut trees: Vec<DecisionTreeRegressor> = (0..self.n_estimators)
497            .into_par_iter()
498            .map(|tree_idx| {
499                let mut rng = crate::rng::FastRng::new(self.seed.wrapping_add(tree_idx as u64));
500                let n = data.n_samples();
501
502                let indices: Vec<usize> = if self.bootstrap {
503                    (0..n).map(|_| rng.usize(0..n)).collect()
504                } else {
505                    (0..n).collect()
506                };
507
508                let mut tree = DecisionTreeRegressor::new()
509                    .max_features(max_feats)
510                    .min_samples_split(self.min_samples_split)
511                    .min_samples_leaf(self.min_samples_leaf);
512
513                if let Some(d) = self.max_depth {
514                    tree = tree.max_depth(d);
515                }
516
517                // Train directly on indices — no data copy.
518                tree.fit_on_indices(data, &indices).ok();
519                tree
520            })
521            .collect();
522
523        self.feature_importances_ = vec![0.0; self.n_features];
524        for tree in &trees {
525            if let Ok(imp) = tree.feature_importances() {
526                for (i, &v) in imp.iter().enumerate() {
527                    self.feature_importances_[i] += v;
528                }
529            }
530        }
531        let n_trees = trees.len() as f64;
532        for imp in &mut self.feature_importances_ {
533            *imp /= n_trees;
534        }
535
536        // Clear per-tree training-only data to save memory.
537        for tree in &mut trees {
538            tree.feature_importances_ = Vec::new();
539        }
540
541        self.trees = trees;
542        Ok(())
543    }
544
545    /// Predict values (mean across trees).
546    ///
547    /// Parallelized across samples via rayon.
548    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
549        crate::version::check_schema_version(self._schema_version)?;
550        if self.trees.is_empty() {
551            return Err(ScryLearnError::NotFitted);
552        }
553
554        let n_trees = self.trees.len() as f64;
555
556        let predictions: Vec<f64> = features
557            .par_iter()
558            .map(|sample| {
559                let mut sum = 0.0;
560                for tree in &self.trees {
561                    if let Some(ref ft) = tree.flat_tree {
562                        sum += ft.predict_sample(sample);
563                    }
564                }
565                sum / n_trees
566            })
567            .collect();
568
569        Ok(predictions)
570    }
571
572    /// Feature importances.
573    pub fn feature_importances(&self) -> Result<Vec<f64>> {
574        if self.trees.is_empty() {
575            return Err(ScryLearnError::NotFitted);
576        }
577        Ok(self.feature_importances_.clone())
578    }
579
580    /// Get individual trees (for inspection or ONNX export).
581    pub fn trees(&self) -> &[DecisionTreeRegressor] {
582        &self.trees
583    }
584
585    /// Number of features the model was trained on.
586    pub fn n_features(&self) -> usize {
587        self.n_features
588    }
589}
590
impl Default for RandomForestRegressor {
    /// Equivalent to [`RandomForestRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 100-sample, 2-feature dataset with two well-separated
    /// clusters: class 0 in [0, 3)², class 1 in [5, 8)².
    fn make_classification_data() -> Dataset {
        // Two features, clear separation.
        let n = 100;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        let mut rng = crate::rng::FastRng::new(42);

        for _ in 0..n / 2 {
            f1.push(rng.f64() * 3.0);
            f2.push(rng.f64() * 3.0);
            target.push(0.0);
        }
        for _ in 0..n / 2 {
            f1.push(rng.f64() * 3.0 + 5.0);
            f2.push(rng.f64() * 3.0 + 5.0);
            target.push(1.0);
        }

        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    /// Training accuracy on linearly separable clusters should be high.
    #[test]
    fn test_random_forest_classification() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new()
            .n_estimators(20)
            .max_depth(5)
            .seed(42);
        rf.fit(&data).unwrap();

        let matrix = data.feature_matrix();
        let preds = rf.predict(&matrix).unwrap();
        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.90,
            "expected ≥90% accuracy, got {:.1}%",
            acc * 100.0
        );
    }

    /// Importances must be one-per-feature and non-negative.
    #[test]
    fn test_feature_importances_valid() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new().n_estimators(10).seed(42);
        rf.fit(&data).unwrap();

        let imp = rf.feature_importances().unwrap();
        assert_eq!(imp.len(), 2);
        assert!(imp.iter().all(|&v| v >= 0.0));
    }

    /// A point deep inside the class-0 cluster should get > 50% class-0 probability.
    #[test]
    fn test_predict_proba() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new().n_estimators(10).seed(42);
        rf.fit(&data).unwrap();

        let sample = vec![1.0, 1.0]; // should be class 0
        let proba = rf.predict_proba(&[sample]).unwrap();
        assert!(proba[0][0] > 0.5, "should predict class 0 with >50%");
    }

    /// With bootstrap on, the OOB score is populated and within a sane range.
    #[test]
    fn test_oob_score_with_bootstrap() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new()
            .n_estimators(50)
            .max_depth(5)
            .bootstrap(true)
            .seed(42);
        rf.fit(&data).unwrap();

        let oob = rf.oob_score();
        assert!(
            oob.is_some(),
            "OOB score should be available with bootstrap=true"
        );
        let score = oob.unwrap();
        assert!(score >= 0.80, "expected OOB score ≥ 0.80, got {:.3}", score);
        assert!(score <= 1.0, "OOB score should be ≤ 1.0, got {:.3}", score);
    }

    /// With bootstrap off, no sample is ever out-of-bag, so no OOB score exists.
    #[test]
    fn test_oob_score_without_bootstrap() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new()
            .n_estimators(10)
            .bootstrap(false)
            .seed(42);
        rf.fit(&data).unwrap();

        assert!(
            rf.oob_score().is_none(),
            "OOB score should be None when bootstrap=false"
        );
    }
}