1use crate::dataset::Dataset;
8use crate::error::{Result, ScryLearnError};
9use crate::tree::cart::{DecisionTreeClassifier, DecisionTreeRegressor};
10use crate::weights::ClassWeight;
11use rayon::prelude::*;
12
/// Strategy for how many candidate features each tree split may consider.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum MaxFeatures {
    /// `ceil(sqrt(n_features))` candidates per split.
    Sqrt,
    /// `ceil(log2(n_features))` candidates per split.
    Log2,
    /// Consider every feature at every split.
    All,
    /// An explicit count, capped at the number of available features.
    Fixed(usize),
}

impl MaxFeatures {
    /// Translate the strategy into a concrete feature count.
    /// The result is clamped to at least 1 so a split always has a candidate.
    fn resolve(self, n_features: usize) -> usize {
        let raw = match self {
            MaxFeatures::Sqrt => (n_features as f64).sqrt().ceil() as usize,
            MaxFeatures::Log2 => (n_features as f64).log2().ceil() as usize,
            MaxFeatures::All => n_features,
            MaxFeatures::Fixed(n) => n.min(n_features),
        };
        raw.max(1)
    }
}
39
/// Random forest classifier: an ensemble of CART decision trees, each fitted
/// on a bootstrap sample (by default), predicting by majority vote.
/// Configure via the fluent builder methods, then call `fit`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RandomForestClassifier {
    // Number of trees to fit (default 100).
    n_estimators: usize,
    // Per-tree depth cap; `None` = unbounded.
    max_depth: Option<usize>,
    // Split-candidate strategy (default `Sqrt`).
    max_features: MaxFeatures,
    // Minimum samples required to split an internal node.
    min_samples_split: usize,
    // Minimum samples required in a leaf.
    min_samples_leaf: usize,
    // Sample with replacement per tree; also enables OOB scoring.
    bootstrap: bool,
    // Base RNG seed; tree i uses seed.wrapping_add(i).
    seed: u64,
    // Class weighting forwarded to each tree.
    class_weight: ClassWeight,
    // Fitted trees (empty until `fit` succeeds).
    trees: Vec<DecisionTreeClassifier>,
    // Set from the dataset during `fit`.
    n_classes: usize,
    n_features: usize,
    // Mean of per-tree importances, computed in `fit`.
    feature_importances_: Vec<f64>,
    // Out-of-bag accuracy; `Some` only when fitted with bootstrap.
    oob_score_: Option<f64>,
    // Serialization compatibility tag; checked on `predict`.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
68
69impl RandomForestClassifier {
70 pub fn new() -> Self {
72 Self {
73 n_estimators: 100,
74 max_depth: None,
75 max_features: MaxFeatures::Sqrt,
76 min_samples_split: 2,
77 min_samples_leaf: 1,
78 bootstrap: true,
79 seed: 42,
80 class_weight: ClassWeight::Uniform,
81 trees: Vec::new(),
82 n_classes: 0,
83 n_features: 0,
84 feature_importances_: Vec::new(),
85 oob_score_: None,
86 _schema_version: crate::version::SCHEMA_VERSION,
87 }
88 }
89
90 pub fn n_estimators(mut self, n: usize) -> Self {
92 self.n_estimators = n;
93 self
94 }
95
96 pub fn max_depth(mut self, d: usize) -> Self {
98 self.max_depth = Some(d);
99 self
100 }
101
102 pub fn max_features(mut self, mf: MaxFeatures) -> Self {
104 self.max_features = mf;
105 self
106 }
107
108 pub fn min_samples_split(mut self, n: usize) -> Self {
110 self.min_samples_split = n;
111 self
112 }
113
114 pub fn min_samples_leaf(mut self, n: usize) -> Self {
116 self.min_samples_leaf = n;
117 self
118 }
119
120 pub fn bootstrap(mut self, b: bool) -> Self {
122 self.bootstrap = b;
123 self
124 }
125
126 pub fn seed(mut self, s: u64) -> Self {
128 self.seed = s;
129 self
130 }
131
132 pub fn class_weight(mut self, cw: ClassWeight) -> Self {
134 self.class_weight = cw;
135 self
136 }
137
    /// Fit the forest on `data`.
    ///
    /// Trees are trained in parallel (rayon); each tree seeds its own RNG
    /// from `self.seed + tree_index`, so fits are reproducible for a fixed
    /// seed and estimator count. When bootstrap sampling is enabled,
    /// out-of-bag votes are accumulated during fitting and turned into an
    /// OOB accuracy estimate.
    ///
    /// # Errors
    /// Returns `ScryLearnError::EmptyDataset` when `data` has no samples, or
    /// the error from `validate_finite` when features are NaN/infinite.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        use std::sync::atomic::{AtomicU32, Ordering};

        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        self.n_features = data.n_features();
        self.n_classes = data.n_classes();
        let max_feats = self.max_features.resolve(self.n_features);
        let do_bootstrap = self.bootstrap;
        let n_samples = data.n_samples();
        let n_classes = self.n_classes;
        let feature_matrix = data.feature_matrix();
        let n_features = data.n_features();

        // Sort sample indices by each feature once, shared by all trees, so
        // each tree does not have to re-sort the columns itself. NaN compares
        // fall back to Equal (inputs were validated finite above).
        let global_sorted: Vec<Vec<usize>> = (0..n_features)
            .map(|feat_idx| {
                let col = &data.features[feat_idx];
                let mut sorted: Vec<usize> = (0..n_samples).collect();
                sorted.sort_unstable_by(|&a, &b| {
                    col[a]
                        .partial_cmp(&col[b])
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
                sorted
            })
            .collect();
        let global_sorted_ref = &global_sorted;

        // Flat (sample x class) OOB vote matrix. Atomics let the parallel
        // tree fits accumulate votes without a lock.
        let oob_votes: Vec<AtomicU32> = (0..n_samples * n_classes)
            .map(|_| AtomicU32::new(0))
            .collect();
        let oob_votes_ref = &oob_votes;

        let mut trees: Vec<DecisionTreeClassifier> = (0..self.n_estimators)
            .into_par_iter()
            .map(|tree_idx| {
                // Deterministic per-tree RNG: (seed, tree index) -> stream.
                let mut rng = crate::rng::FastRng::new(self.seed.wrapping_add(tree_idx as u64));
                let n = n_samples;

                // Bootstrap: n draws with replacement; otherwise all rows.
                let indices: Vec<usize> = if do_bootstrap {
                    (0..n).map(|_| rng.usize(0..n)).collect()
                } else {
                    (0..n).collect()
                };

                let mut tree = DecisionTreeClassifier::new()
                    .max_features(max_feats)
                    .min_samples_split(self.min_samples_split)
                    .min_samples_leaf(self.min_samples_leaf)
                    .class_weight(self.class_weight.clone());

                if let Some(d) = self.max_depth {
                    tree = tree.max_depth(d);
                }

                // Best-effort fit: a failed tree is kept (without a
                // flat_tree) rather than aborting the whole forest.
                // NOTE(review): failed trees still count in the averaging
                // denominators below — confirm this is intended.
                tree.fit_on_indices_presorted(data, &indices, global_sorted_ref)
                    .ok();

                if do_bootstrap {
                    if let Some(ref ft) = tree.flat_tree {
                        // Bitset of in-bag samples (1 bit per sample).
                        let n_words = n.div_ceil(64);
                        let mut in_bag = vec![0u64; n_words];
                        for &idx in &indices {
                            in_bag[idx / 64] |= 1u64 << (idx % 64);
                        }

                        // Cast this tree's vote for every out-of-bag sample.
                        for sample_idx in 0..n {
                            if in_bag[sample_idx / 64] & (1u64 << (sample_idx % 64)) != 0 {
                                continue;
                            }
                            let pred = ft.predict_sample(&feature_matrix[sample_idx]) as usize;
                            if pred < n_classes {
                                oob_votes_ref[sample_idx * n_classes + pred]
                                    .fetch_add(1, Ordering::Relaxed);
                            }
                        }
                    }
                }

                tree
            })
            .collect();

        // Forest importances = mean of per-tree importances.
        self.feature_importances_ = vec![0.0; self.n_features];
        for tree in &trees {
            if let Ok(imp) = tree.feature_importances() {
                for (i, &v) in imp.iter().enumerate() {
                    self.feature_importances_[i] += v;
                }
            }
        }
        let n_trees = trees.len() as f64;
        for imp in &mut self.feature_importances_ {
            *imp /= n_trees;
        }

        // Turn accumulated votes into an OOB accuracy (bootstrap only —
        // without bootstrap no sample is ever out of bag).
        self.oob_score_ = if do_bootstrap {
            let totals: Vec<u32> = oob_votes
                .iter()
                .map(|a| a.load(Ordering::Relaxed))
                .collect();
            Self::oob_accuracy_from_votes(&totals, n_samples, n_classes, &data.target)
        } else {
            None
        };

        // Drop per-tree training scratch; the forest-level aggregates above
        // are what callers read.
        for tree in &mut trees {
            tree.sample_weights = None;
            tree.feature_importances_ = Vec::new();
        }

        self.trees = trees;
        Ok(())
    }
277
278 fn oob_accuracy_from_votes(
280 oob_total: &[u32],
281 n_samples: usize,
282 n_classes: usize,
283 target: &[f64],
284 ) -> Option<f64> {
285 let mut correct = 0usize;
286 let mut total = 0usize;
287 for sample_idx in 0..n_samples {
288 let row = &oob_total[sample_idx * n_classes..(sample_idx + 1) * n_classes];
289 let vote_count: u32 = row.iter().sum();
290 if vote_count == 0 {
291 continue;
292 }
293 let predicted_class = row
294 .iter()
295 .enumerate()
296 .max_by_key(|&(_, &v)| v)
297 .map_or(0, |(idx, _)| idx);
298 let true_class = target[sample_idx] as usize;
299 if predicted_class == true_class {
300 correct += 1;
301 }
302 total += 1;
303 }
304
305 if total > 0 {
306 Some(correct as f64 / total as f64)
307 } else {
308 None
309 }
310 }
311
312 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
317 crate::version::check_schema_version(self._schema_version)?;
318 if self.trees.is_empty() {
319 return Err(ScryLearnError::NotFitted);
320 }
321
322 let n_classes = self.n_classes;
323 let predictions: Vec<f64> = features
324 .par_iter()
325 .map(|sample| {
326 let mut votes = vec![0usize; n_classes];
327 for tree in &self.trees {
328 if let Some(ref ft) = tree.flat_tree {
329 let class = ft.predict_sample(sample) as usize;
330 if class < n_classes {
331 votes[class] += 1;
332 }
333 }
334 }
335 votes
336 .iter()
337 .enumerate()
338 .max_by_key(|&(_, &v)| v)
339 .map_or(0.0, |(idx, _)| idx as f64)
340 })
341 .collect();
342
343 Ok(predictions)
344 }
345
346 pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
350 if self.trees.is_empty() {
351 return Err(ScryLearnError::NotFitted);
352 }
353
354 let n_classes = self.n_classes;
355 let n_trees = self.trees.len() as f64;
356
357 let probas: Vec<Vec<f64>> = features
358 .par_iter()
359 .map(|sample| {
360 let mut proba = vec![0.0; n_classes];
361 for tree in &self.trees {
362 if let Some(ref ft) = tree.flat_tree {
363 let tree_proba = ft.predict_proba_sample(sample, n_classes);
364 for (j, p) in tree_proba.into_iter().enumerate() {
365 if j < n_classes {
366 proba[j] += p;
367 }
368 }
369 }
370 }
371 for p in &mut proba {
372 *p /= n_trees;
373 }
374 proba
375 })
376 .collect();
377
378 Ok(probas)
379 }
380
381 pub fn feature_importances(&self) -> Result<Vec<f64>> {
383 if self.trees.is_empty() {
384 return Err(ScryLearnError::NotFitted);
385 }
386 Ok(self.feature_importances_.clone())
387 }
388
    /// Out-of-bag accuracy estimate computed during `fit`; `None` unless the
    /// forest was fitted with bootstrap enabled and some sample was left out
    /// of at least one bag.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score_
    }

    /// Number of trees currently held by the forest.
    pub fn n_trees(&self) -> usize {
        self.trees.len()
    }

    /// Borrow the underlying fitted trees.
    pub fn trees(&self) -> &[DecisionTreeClassifier] {
        &self.trees
    }

    /// Number of classes seen during `fit` (0 before fitting).
    pub fn n_classes(&self) -> usize {
        self.n_classes
    }

    /// Number of input features seen during `fit` (0 before fitting).
    pub fn n_features(&self) -> usize {
        self.n_features
    }
413}
414
// `Default` delegates to `new()` so the default hyperparameters live in one place.
impl Default for RandomForestClassifier {
    fn default() -> Self {
        Self::new()
    }
}
420
/// Random forest regressor: an ensemble of CART regression trees, each fitted
/// on a bootstrap sample (by default), predicting the mean of the trees'
/// outputs. Configure via the builder methods, then call `fit`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RandomForestRegressor {
    // Number of trees to fit (default 100).
    n_estimators: usize,
    // Per-tree depth cap; `None` = unbounded.
    max_depth: Option<usize>,
    // Split-candidate strategy (default `All`).
    max_features: MaxFeatures,
    // Minimum samples required to split an internal node.
    min_samples_split: usize,
    // Minimum samples required in a leaf.
    min_samples_leaf: usize,
    // Sample with replacement per tree.
    bootstrap: bool,
    // Base RNG seed; tree i uses seed.wrapping_add(i).
    seed: u64,
    // Fitted trees (empty until `fit` succeeds).
    trees: Vec<DecisionTreeRegressor>,
    // Set from the dataset during `fit`.
    n_features: usize,
    // Mean of per-tree importances, computed in `fit`.
    feature_importances_: Vec<f64>,
    // Serialization compatibility tag; checked on `predict`.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
443
444impl RandomForestRegressor {
445 pub fn new() -> Self {
447 Self {
448 n_estimators: 100,
449 max_depth: None,
450 max_features: MaxFeatures::All,
451 min_samples_split: 2,
452 min_samples_leaf: 1,
453 bootstrap: true,
454 seed: 42,
455 trees: Vec::new(),
456 n_features: 0,
457 feature_importances_: Vec::new(),
458 _schema_version: crate::version::SCHEMA_VERSION,
459 }
460 }
461
462 pub fn n_estimators(mut self, n: usize) -> Self {
464 self.n_estimators = n;
465 self
466 }
467
468 pub fn max_depth(mut self, d: usize) -> Self {
470 self.max_depth = Some(d);
471 self
472 }
473
474 pub fn max_features(mut self, mf: MaxFeatures) -> Self {
476 self.max_features = mf;
477 self
478 }
479
480 pub fn seed(mut self, s: u64) -> Self {
482 self.seed = s;
483 self
484 }
485
    /// Fit the forest on `data`. Trees are trained in parallel (rayon) with
    /// deterministic per-tree RNG seeds (`seed + tree_index`), so fits are
    /// reproducible for a fixed seed and estimator count.
    ///
    /// # Errors
    /// Returns `ScryLearnError::EmptyDataset` when `data` has no samples, or
    /// the error from `validate_finite` when features are NaN/infinite.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        self.n_features = data.n_features();
        let max_feats = self.max_features.resolve(self.n_features);

        let mut trees: Vec<DecisionTreeRegressor> = (0..self.n_estimators)
            .into_par_iter()
            .map(|tree_idx| {
                // Deterministic per-tree RNG: (seed, tree index) -> stream.
                let mut rng = crate::rng::FastRng::new(self.seed.wrapping_add(tree_idx as u64));
                let n = data.n_samples();

                // Bootstrap: n draws with replacement; otherwise all rows.
                let indices: Vec<usize> = if self.bootstrap {
                    (0..n).map(|_| rng.usize(0..n)).collect()
                } else {
                    (0..n).collect()
                };

                let mut tree = DecisionTreeRegressor::new()
                    .max_features(max_feats)
                    .min_samples_split(self.min_samples_split)
                    .min_samples_leaf(self.min_samples_leaf);

                if let Some(d) = self.max_depth {
                    tree = tree.max_depth(d);
                }

                // Best-effort fit: a failed tree is kept (unfitted) rather
                // than aborting the whole forest.
                tree.fit_on_indices(data, &indices).ok();
                tree
            })
            .collect();

        // Forest importances = mean of per-tree importances.
        self.feature_importances_ = vec![0.0; self.n_features];
        for tree in &trees {
            if let Ok(imp) = tree.feature_importances() {
                for (i, &v) in imp.iter().enumerate() {
                    self.feature_importances_[i] += v;
                }
            }
        }
        let n_trees = trees.len() as f64;
        for imp in &mut self.feature_importances_ {
            *imp /= n_trees;
        }

        // Per-tree importances are no longer needed once aggregated.
        for tree in &mut trees {
            tree.feature_importances_ = Vec::new();
        }

        self.trees = trees;
        Ok(())
    }
544
545 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
549 crate::version::check_schema_version(self._schema_version)?;
550 if self.trees.is_empty() {
551 return Err(ScryLearnError::NotFitted);
552 }
553
554 let n_trees = self.trees.len() as f64;
555
556 let predictions: Vec<f64> = features
557 .par_iter()
558 .map(|sample| {
559 let mut sum = 0.0;
560 for tree in &self.trees {
561 if let Some(ref ft) = tree.flat_tree {
562 sum += ft.predict_sample(sample);
563 }
564 }
565 sum / n_trees
566 })
567 .collect();
568
569 Ok(predictions)
570 }
571
572 pub fn feature_importances(&self) -> Result<Vec<f64>> {
574 if self.trees.is_empty() {
575 return Err(ScryLearnError::NotFitted);
576 }
577 Ok(self.feature_importances_.clone())
578 }
579
    /// Borrow the underlying fitted trees.
    pub fn trees(&self) -> &[DecisionTreeRegressor] {
        &self.trees
    }

    /// Number of input features seen during `fit` (0 before fitting).
    pub fn n_features(&self) -> usize {
        self.n_features
    }
589}
590
// `Default` delegates to `new()` so the default hyperparameters live in one place.
impl Default for RandomForestRegressor {
    fn default() -> Self {
        Self::new()
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated 2-D blobs: class 0 near the origin, class 1
    /// shifted by +5 on both features — trivially separable, so the forest
    /// should reach high accuracy.
    fn make_classification_data() -> Dataset {
        let n = 100;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        let mut rng = crate::rng::FastRng::new(42);

        for _ in 0..n / 2 {
            f1.push(rng.f64() * 3.0);
            f2.push(rng.f64() * 3.0);
            target.push(0.0);
        }
        for _ in 0..n / 2 {
            f1.push(rng.f64() * 3.0 + 5.0);
            f2.push(rng.f64() * 3.0 + 5.0);
            target.push(1.0);
        }

        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn test_random_forest_classification() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new()
            .n_estimators(20)
            .max_depth(5)
            .seed(42);
        rf.fit(&data).unwrap();

        // Training accuracy on easily separable blobs should be high.
        let matrix = data.feature_matrix();
        let preds = rf.predict(&matrix).unwrap();
        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.90,
            "expected ≥90% accuracy, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_feature_importances_valid() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new().n_estimators(10).seed(42);
        rf.fit(&data).unwrap();

        let imp = rf.feature_importances().unwrap();
        assert_eq!(imp.len(), 2);
        assert!(imp.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_predict_proba() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new().n_estimators(10).seed(42);
        rf.fit(&data).unwrap();

        // (1.0, 1.0) sits well inside the class-0 blob.
        let sample = vec![1.0, 1.0];
        let proba = rf.predict_proba(&[sample]).unwrap();
        assert!(proba[0][0] > 0.5, "should predict class 0 with >50%");
    }

    #[test]
    fn test_oob_score_with_bootstrap() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new()
            .n_estimators(50)
            .max_depth(5)
            .bootstrap(true)
            .seed(42);
        rf.fit(&data).unwrap();

        let oob = rf.oob_score();
        assert!(
            oob.is_some(),
            "OOB score should be available with bootstrap=true"
        );
        let score = oob.unwrap();
        assert!(score >= 0.80, "expected OOB score ≥ 0.80, got {:.3}", score);
        assert!(score <= 1.0, "OOB score should be ≤ 1.0, got {:.3}", score);
    }

    #[test]
    fn test_oob_score_without_bootstrap() {
        let data = make_classification_data();
        let mut rf = RandomForestClassifier::new()
            .n_estimators(10)
            .bootstrap(false)
            .seed(42);
        rf.fit(&data).unwrap();

        assert!(
            rf.oob_score().is_none(),
            "OOB score should be None when bootstrap=false"
        );
    }
}