use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::neural::callback::{CallbackAction, EpochMetrics, TrainingCallback, TrainingHistory};
use crate::tree::cart::{presort_indices, DecisionTreeRegressor};
use crate::weights::{compute_sample_weights, ClassWeight};

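/// Loss functions for gradient-boosted regression.
///
/// `SquaredError` (the default) fits mean residuals; `AbsoluteError` fits
/// sign residuals with median leaf values; `Huber { alpha }` blends the two,
/// with the transition point set to the `alpha`-quantile of absolute
/// residuals; `Quantile { alpha }` uses the pinball loss to estimate the
/// `alpha`-quantile of the target.
///
/// A minimal selection sketch (the `use` path is illustrative and depends on
/// how this module is exported from the crate):
///
/// ```ignore
/// // Hypothetical import path; adjust to the crate's module layout.
/// use crate::ensemble::{GradientBoostingRegressor, RegressionLoss};
///
/// // Fit the conditional 90th percentile instead of the mean.
/// let gbr = GradientBoostingRegressor::new()
///     .n_estimators(200)
///     .loss(RegressionLoss::Quantile { alpha: 0.9 });
/// ```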
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum RegressionLoss {
    #[default]
    SquaredError,
    AbsoluteError,
    Huber { alpha: f64 },
    Quantile { alpha: f64 },
}

impl RegressionLoss {
    fn initial_prediction(&self, y: &[f64]) -> f64 {
        match self {
            Self::SquaredError => {
                let sum: f64 = y.iter().sum();
                sum / y.len() as f64
            }
            Self::AbsoluteError | Self::Huber { .. } => median(y),
            Self::Quantile { alpha } => quantile(y, *alpha),
        }
    }

    fn negative_gradient(&self, y: f64, f: f64, delta: f64) -> f64 {
        match self {
            Self::SquaredError => y - f,
            Self::AbsoluteError => {
                if y > f {
                    1.0
                } else if y < f {
                    -1.0
                } else {
                    0.0
                }
            }
            Self::Huber { .. } => {
                let r = y - f;
                if r.abs() <= delta {
                    r
                } else {
                    delta * r.signum()
                }
            }
            Self::Quantile { alpha } => {
                if y > f {
                    *alpha
                } else if y < f {
                    -(1.0 - alpha)
                } else {
                    0.0
                }
            }
        }
    }

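    /// Recomputes a leaf's value after the tree is grown, since for
    /// non-squared losses the mean of the negative gradients is not the
    /// loss-minimizing constant: absolute error uses the median residual,
    /// Huber uses the median plus a clipped mean correction, and quantile
    /// uses the `alpha`-quantile of `y - f` within the leaf.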
    fn update_terminal_value(
        &self,
        residuals: &[f64],
        y_in_leaf: &[f64],
        f_in_leaf: &[f64],
        delta: f64,
    ) -> f64 {
        match self {
            Self::SquaredError => {
                if residuals.is_empty() {
                    0.0
                } else {
                    residuals.iter().sum::<f64>() / residuals.len() as f64
                }
            }
            Self::AbsoluteError => median(residuals),
            Self::Huber { .. } => {
                let med = median(residuals);
                let correction: f64 = residuals
                    .iter()
                    .map(|&r| {
                        let diff = r - med;
                        diff.clamp(-delta, delta)
                    })
                    .sum::<f64>()
                    / residuals.len().max(1) as f64;
                med + correction
            }
            Self::Quantile { alpha } => {
                let diffs: Vec<f64> = y_in_leaf
                    .iter()
                    .zip(f_in_leaf.iter())
                    .map(|(&y, &f)| y - f)
                    .collect();
                quantile(&diffs, *alpha)
            }
        }
    }

    fn needs_terminal_update(&self) -> bool {
        !matches!(self, Self::SquaredError)
    }
}

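/// Median of `data`, averaging the two middle elements for even lengths;
/// returns 0.0 for an empty slice.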
fn median(data: &[f64]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let mut sorted: Vec<f64> = data.to_vec();
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = sorted.len();
    if n % 2 == 1 {
        sorted[n / 2]
    } else {
        f64::midpoint(sorted[n / 2 - 1], sorted[n / 2])
    }
}

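/// Linearly interpolated `alpha`-quantile of `data`: the fractional position
/// is `pos = alpha * (n - 1)`, and the result interpolates between the two
/// nearest order statistics. For example, on `[1, 2, 3, 4, 5]` with
/// `alpha = 0.25`, `pos = 1.0`, so the result is exactly `2.0` (see
/// `test_quantile_helper` below).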
fn quantile(data: &[f64], alpha: f64) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let mut sorted: Vec<f64> = data.to_vec();
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = sorted.len();
    if n == 1 {
        return sorted[0];
    }
    let pos = alpha * (n - 1) as f64;
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    if lo == hi {
        sorted[lo]
    } else {
        let frac = pos - lo as f64;
        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
    }
}

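/// Gradient-boosted regression trees with shrinkage, row subsampling,
/// pluggable losses, and optional validation-based early stopping.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (the `use` path is illustrative and depends on the crate's module layout):
///
/// ```ignore
/// // Hypothetical import path.
/// use crate::ensemble::GradientBoostingRegressor;
///
/// let mut gbr = GradientBoostingRegressor::new()
///     .n_estimators(100)
///     .learning_rate(0.1)
///     .max_depth(3);
/// gbr.fit(&data)?; // `data` is a `Dataset`
/// let preds = gbr.predict(&[vec![50.0]])?;
/// ```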
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct GradientBoostingRegressor {
    n_estimators: usize,
    learning_rate: f64,
    max_depth: usize,
    min_samples_split: usize,
    min_samples_leaf: usize,
    subsample: f64,
    seed: u64,
    loss: RegressionLoss,
    validation_fraction: f64,
    n_iter_no_change: Option<usize>,
    tol: f64,
    trees: Vec<DecisionTreeRegressor>,
    init_prediction: f64,
    n_features: usize,
    fitted: bool,
    n_estimators_used: usize,
    history: Option<TrainingHistory>,
    #[cfg_attr(feature = "serde", serde(skip))]
    callbacks: Vec<Box<dyn TrainingCallback>>,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl Clone for GradientBoostingRegressor {
    fn clone(&self) -> Self {
        Self {
            n_estimators: self.n_estimators,
            learning_rate: self.learning_rate,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            subsample: self.subsample,
            seed: self.seed,
            loss: self.loss.clone(),
            validation_fraction: self.validation_fraction,
            n_iter_no_change: self.n_iter_no_change,
            tol: self.tol,
            trees: self.trees.clone(),
            init_prediction: self.init_prediction,
            n_features: self.n_features,
            fitted: self.fitted,
            n_estimators_used: self.n_estimators_used,
            history: self.history.clone(),
            callbacks: Vec::new(),
            _schema_version: self._schema_version,
        }
    }
}

impl GradientBoostingRegressor {
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: 3,
            min_samples_split: 2,
            min_samples_leaf: 1,
            subsample: 1.0,
            seed: 42,
            loss: RegressionLoss::SquaredError,
            validation_fraction: 0.1,
            n_iter_no_change: None,
            tol: crate::constants::DEFAULT_TOL,
            trees: Vec::new(),
            init_prediction: 0.0,
            n_features: 0,
            fitted: false,
            n_estimators_used: 0,
            history: None,
            callbacks: Vec::new(),
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    pub fn n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    pub fn min_samples_split(mut self, n: usize) -> Self {
        self.min_samples_split = n;
        self
    }

    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    pub fn subsample(mut self, s: f64) -> Self {
        self.subsample = s;
        self
    }

    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    pub fn n_iter_no_change(mut self, n: usize) -> Self {
        self.n_iter_no_change = Some(n);
        self
    }

    pub fn validation_fraction(mut self, frac: f64) -> Self {
        self.validation_fraction = frac;
        self
    }

    pub fn tol(mut self, t: f64) -> Self {
        self.tol = t;
        self
    }

    pub fn callback(mut self, cb: Box<dyn TrainingCallback>) -> Self {
        self.callbacks.push(cb);
        self
    }

    pub fn n_estimators_used(&self) -> usize {
        self.n_estimators_used
    }

    pub fn loss(mut self, l: RegressionLoss) -> Self {
        self.loss = l;
        self
    }

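    /// Fits the ensemble: starts from the loss's optimal constant prediction,
    /// then on each round fits a regression tree to the negative gradients
    /// (pseudo-residuals), optionally re-optimizes leaf values for
    /// non-squared losses, and adds `learning_rate * tree(x)` to the running
    /// prediction. For `Huber`, the transition point `delta` is taken as the
    /// `alpha`-quantile of absolute residuals against the initial prediction.
    /// When `n_iter_no_change` is set, a validation split of
    /// `validation_fraction` is held out and boosting stops once validation
    /// MSE fails to improve by more than `tol` for that many rounds.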
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "learning_rate must be in (0, 1]".into(),
            ));
        }
        if self.subsample <= 0.0 || self.subsample > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "subsample must be in (0, 1]".into(),
            ));
        }

        self.n_features = data.n_features();

        let (train_data, val_data) = if self.n_iter_no_change.is_some() {
            let (t, v) = crate::split::train_test_split(data, self.validation_fraction, self.seed);
            (t, Some(v))
        } else {
            (data.clone(), None)
        };
        let n_train = train_data.n_samples();

        let init = self.loss.initial_prediction(&train_data.target);
        self.init_prediction = init;

        let mut f_vals = vec![init; n_train];

        let delta = match &self.loss {
            RegressionLoss::Huber { alpha } => {
                let abs_resid: Vec<f64> = train_data
                    .target
                    .iter()
                    .zip(f_vals.iter())
                    .map(|(&y, &f)| (y - f).abs())
                    .collect();
                quantile(&abs_resid, *alpha)
            }
            _ => 0.0,
        };

        let mut rng = crate::rng::FastRng::new(self.seed);
        let all_indices: Vec<usize> = (0..n_train).collect();
        self.trees = Vec::with_capacity(self.n_estimators);

        let mut temp_data = Dataset::new(
            train_data.features.clone(),
            vec![0.0; n_train],
            train_data.feature_names.clone(),
            "residual",
        );
        let row_major = train_data.feature_matrix();

        let global_sorted = presort_indices(&temp_data, &all_indices);

        let mut best_val_loss = f64::INFINITY;
        let mut no_improve_count = 0usize;
        let patience = self.n_iter_no_change.unwrap_or(usize::MAX);

        let mut history = TrainingHistory::new();
        let mut callbacks = std::mem::take(&mut self.callbacks);

        for round in 0..self.n_estimators {
            let round_start = std::time::Instant::now();

            for (i, fv) in f_vals.iter().enumerate().take(n_train) {
                temp_data.target[i] = self
                    .loss
                    .negative_gradient(train_data.target[i], *fv, delta);
            }

            let indices = subsample_indices(n_train, self.subsample, &mut rng, &all_indices);

            let mut tree = DecisionTreeRegressor::new()
                .max_depth(self.max_depth)
                .min_samples_split(self.min_samples_split)
                .min_samples_leaf(self.min_samples_leaf);
            tree.fit_on_indices_presorted(&temp_data, &indices, &global_sorted)?;

            if self.loss.needs_terminal_update() {
                if let Some(ref mut flat) = tree.flat_tree {
                    let leaf_ids = flat.apply(&row_major);
                    let n_nodes = flat.n_nodes();
                    let mut leaf_residuals: Vec<Vec<f64>> = vec![Vec::new(); n_nodes];
                    let mut leaf_y: Vec<Vec<f64>> = vec![Vec::new(); n_nodes];
                    let mut leaf_f: Vec<Vec<f64>> = vec![Vec::new(); n_nodes];
                    for (i, &lid) in leaf_ids.iter().enumerate() {
                        leaf_residuals[lid].push(temp_data.target[i]);
                        leaf_y[lid].push(train_data.target[i]);
                        leaf_f[lid].push(f_vals[i]);
                    }
                    for node_id in 0..n_nodes {
                        if !leaf_residuals[node_id].is_empty() {
                            let new_val = self.loss.update_terminal_value(
                                &leaf_residuals[node_id],
                                &leaf_y[node_id],
                                &leaf_f[node_id],
                                delta,
                            );
                            flat.set_leaf_prediction(node_id, new_val);
                        }
                    }
                }
            }

            let tree_preds = tree.predict(&row_major)?;
            for (f_val, &tp) in f_vals.iter_mut().zip(tree_preds.iter()) {
                *f_val += self.learning_rate * tp;
            }

            self.trees.push(tree);

            let train_mse: f64 = train_data
                .target
                .iter()
                .zip(f_vals.iter())
                .map(|(&y, &f)| (y - f).powi(2))
                .sum::<f64>()
                / n_train as f64;

            let grad_norm: f64 = temp_data
                .target
                .iter()
                .take(n_train)
                .map(|&r| r * r)
                .sum::<f64>()
                .sqrt();

            let elapsed = round_start.elapsed().as_millis() as u64;

            let metrics = EpochMetrics {
                epoch: round,
                train_loss: train_mse,
                val_loss: None,
                train_metric: None,
                val_metric: None,
                learning_rate: self.learning_rate,
                grad_norm,
                elapsed_ms: elapsed,
            };

            let mut cb_stop = false;
            for cb in &mut callbacks {
                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
                    cb_stop = true;
                }
            }

            history.push(metrics);

            if cb_stop {
                self.n_estimators_used = round + 1;
                self.fitted = true;
                for cb in &mut callbacks {
                    cb.on_training_end();
                }
                self.callbacks = callbacks;
                self.history = Some(history);
                return Ok(());
            }

            if let Some(ref val) = val_data {
                let val_features = val.feature_matrix();
                let mut val_preds = vec![self.init_prediction; val_features.len()];
                for t in &self.trees {
                    if let Ok(tp) = t.predict(&val_features) {
                        for (p, &v) in val_preds.iter_mut().zip(tp.iter()) {
                            *p += self.learning_rate * v;
                        }
                    }
                }
                let val_mse: f64 = val
                    .target
                    .iter()
                    .zip(val_preds.iter())
                    .map(|(&y, &p)| (y - p).powi(2))
                    .sum::<f64>()
                    / val.target.len() as f64;

                if let Some(last) = history.epochs.last_mut() {
                    last.val_loss = Some(val_mse);
                }

                if val_mse + self.tol < best_val_loss {
                    best_val_loss = val_mse;
                    no_improve_count = 0;
                } else {
                    no_improve_count += 1;
                    if no_improve_count >= patience {
                        self.n_estimators_used = round + 1;
                        self.fitted = true;
                        for cb in &mut callbacks {
                            cb.on_training_end();
                        }
                        self.callbacks = callbacks;
                        self.history = Some(history);
                        return Ok(());
                    }
                }
            }
        }

        self.n_estimators_used = self.trees.len();
        self.fitted = true;
        for cb in &mut callbacks {
            cb.on_training_end();
        }
        self.callbacks = callbacks;
        self.history = Some(history);
        Ok(())
    }

    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = features.len();
        let mut preds = vec![self.init_prediction; n];
        for tree in &self.trees {
            let tp = tree.predict(features)?;
            for (p, &t) in preds.iter_mut().zip(tp.iter()) {
                *p += self.learning_rate * t;
            }
        }
        Ok(preds)
    }

    pub fn feature_importances(&self) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.n_features;
        let mut importances = vec![0.0; m];
        let n_trees = self.trees.len() as f64;
        for tree in &self.trees {
            if let Ok(imp) = tree.feature_importances() {
                for (i, &v) in imp.iter().enumerate() {
                    if i < m {
                        importances[i] += v / n_trees;
                    }
                }
            }
        }
        let total: f64 = importances.iter().sum();
        if total > 0.0 {
            for v in &mut importances {
                *v /= total;
            }
        }
        Ok(importances)
    }

    pub fn n_trees(&self) -> usize {
        self.trees.len()
    }

    pub fn early_stopped(&self) -> bool {
        self.n_iter_no_change.is_some() && self.n_estimators_used < self.n_estimators
    }

    pub fn history(&self) -> Option<&TrainingHistory> {
        self.history.as_ref()
    }

    pub fn trees(&self) -> &[DecisionTreeRegressor] {
        &self.trees
    }

    pub fn n_features(&self) -> usize {
        self.n_features
    }

    pub fn learning_rate_val(&self) -> f64 {
        self.learning_rate
    }

    pub fn init_prediction_val(&self) -> f64 {
        self.init_prediction
    }
}

impl Default for GradientBoostingRegressor {
    fn default() -> Self {
        Self::new()
    }
}

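/// Gradient-boosted classification trees. Binary problems fit one tree per
/// round on logistic-loss gradients; multiclass problems fit one tree per
/// class per round on softmax gradients, with Newton-corrected leaf values.
///
/// A minimal usage sketch, mirroring the tests below (the `use` path is
/// illustrative and depends on the crate's module layout):
///
/// ```ignore
/// // Hypothetical import path.
/// use crate::ensemble::GradientBoostingClassifier;
///
/// let mut gbc = GradientBoostingClassifier::new()
///     .n_estimators(50)
///     .learning_rate(0.1)
///     .max_depth(2);
/// gbc.fit(&data)?; // `data` is a `Dataset` with integer-coded classes
/// let labels = gbc.predict(&[vec![0.2, 0.1]])?;
/// let probas = gbc.predict_proba(&[vec![0.2, 0.1]])?; // rows sum to 1
/// ```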
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct GradientBoostingClassifier {
    n_estimators: usize,
    learning_rate: f64,
    max_depth: usize,
    min_samples_split: usize,
    min_samples_leaf: usize,
    subsample: f64,
    seed: u64,
    class_weight: ClassWeight,
    trees: Vec<Vec<DecisionTreeRegressor>>,
    init_predictions: Vec<f64>,
    n_classes: usize,
    n_features: usize,
    fitted: bool,
    history: Option<TrainingHistory>,
    #[cfg_attr(feature = "serde", serde(skip))]
    callbacks: Vec<Box<dyn TrainingCallback>>,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl Clone for GradientBoostingClassifier {
    fn clone(&self) -> Self {
        Self {
            n_estimators: self.n_estimators,
            learning_rate: self.learning_rate,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            subsample: self.subsample,
            seed: self.seed,
            class_weight: self.class_weight.clone(),
            trees: self.trees.clone(),
            init_predictions: self.init_predictions.clone(),
            n_classes: self.n_classes,
            n_features: self.n_features,
            fitted: self.fitted,
            history: self.history.clone(),
            callbacks: Vec::new(),
            _schema_version: self._schema_version,
        }
    }
}

impl GradientBoostingClassifier {
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: 3,
            min_samples_split: 2,
            min_samples_leaf: 1,
            subsample: 1.0,
            seed: 42,
            class_weight: ClassWeight::Uniform,
            trees: Vec::new(),
            init_predictions: Vec::new(),
            n_classes: 0,
            n_features: 0,
            fitted: false,
            history: None,
            callbacks: Vec::new(),
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    pub fn n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    pub fn min_samples_split(mut self, n: usize) -> Self {
        self.min_samples_split = n;
        self
    }

    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    pub fn subsample(mut self, s: f64) -> Self {
        self.subsample = s;
        self
    }

    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
        self.class_weight = cw;
        self
    }

    pub fn callback(mut self, cb: Box<dyn TrainingCallback>) -> Self {
        self.callbacks.push(cb);
        self
    }

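    /// Validates inputs, computes per-sample class weights, and dispatches to
    /// `fit_binary` for two classes or `fit_multiclass` (one tree per class
    /// per round under a shared softmax) for more.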
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "learning_rate must be in (0, 1]".into(),
            ));
        }
        if self.subsample <= 0.0 || self.subsample > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "subsample must be in (0, 1]".into(),
            ));
        }

        self.n_features = data.n_features();
        self.n_classes = data.n_classes();
        let k = self.n_classes;

        if k < 2 {
            return Err(ScryLearnError::InvalidParameter(
                "need at least 2 classes for classification".into(),
            ));
        }

        let mut rng = crate::rng::FastRng::new(self.seed);
        let all_indices: Vec<usize> = (0..n).collect();
        let row_major = data.feature_matrix();
        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);

        if k == 2 {
            self.fit_binary(data, n, &mut rng, &all_indices, &row_major, &sample_weights)?;
        } else {
            self.fit_multiclass(
                data,
                n,
                k,
                &mut rng,
                &all_indices,
                &row_major,
                &sample_weights,
            )?;
        }
        Ok(())
    }

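    /// Binary boosting on the logistic loss. The initial score is the
    /// log-odds of the positive class, `f0 = ln(p / (1 - p))` with `p`
    /// clamped away from 0 and 1; each round fits a tree to the weighted
    /// residuals `w_i * (y_i - sigmoid(f_i))` and applies a one-step Newton
    /// correction to its leaves.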
    fn fit_binary(
        &mut self,
        data: &Dataset,
        n: usize,
        rng: &mut crate::rng::FastRng,
        all_indices: &[usize],
        row_major: &[Vec<f64>],
        sample_weights: &[f64],
    ) -> Result<()> {
        let pos_count = data.target.iter().filter(|&&y| y > 0.5).count();
        let p = (pos_count as f64) / (n as f64);
        let p_clamped = p.clamp(
            crate::constants::GBT_PROB_CLAMP,
            1.0 - crate::constants::GBT_PROB_CLAMP,
        );
        let f0 = (p_clamped / (1.0 - p_clamped)).ln();
        self.init_predictions = vec![f0];

        let mut f_vals = vec![f0; n];
        let mut trees_seq = Vec::with_capacity(self.n_estimators);
        let mut history = TrainingHistory::new();
        let mut callbacks = std::mem::take(&mut self.callbacks);

        let mut temp_data = Dataset::new(
            data.features.clone(),
            vec![0.0; n],
            data.feature_names.clone(),
            "residual",
        );

        let global_sorted = presort_indices(&temp_data, all_indices);

        for round in 0..self.n_estimators {
            let round_start = std::time::Instant::now();

            let probs: Vec<f64> = f_vals.iter().map(|&f| sigmoid(f)).collect();

            for i in 0..n {
                temp_data.target[i] = sample_weights[i] * (data.target[i] - probs[i]);
            }

            let indices = subsample_indices(n, self.subsample, rng, all_indices);

            let mut tree = DecisionTreeRegressor::new()
                .max_depth(self.max_depth)
                .min_samples_split(self.min_samples_split)
                .min_samples_leaf(self.min_samples_leaf);
            tree.fit_on_indices_presorted(&temp_data, &indices, &global_sorted)?;

            if let Some(ref mut flat) = tree.flat_tree {
                let leaf_indices = flat.apply(row_major);
                newton_correct_binary_leaves(flat, &leaf_indices, &temp_data.target, &probs);
            }

            let tp = tree.predict(row_major)?;
            for (f_val, &t) in f_vals.iter_mut().zip(tp.iter()) {
                *f_val += self.learning_rate * t;
            }

            trees_seq.push(tree);

            let probs_after: Vec<f64> = f_vals.iter().map(|&f| sigmoid(f)).collect();
            let train_loss: f64 = data
                .target
                .iter()
                .zip(probs_after.iter())
                .map(|(&y, &p)| {
                    let p_c = p.clamp(
                        crate::constants::NEAR_ZERO,
                        1.0 - crate::constants::NEAR_ZERO,
                    );
                    -(y * p_c.ln() + (1.0 - y) * (1.0 - p_c).ln())
                })
                .sum::<f64>()
                / n as f64;

            let grad_norm: f64 = temp_data
                .target
                .iter()
                .take(n)
                .map(|&r| r * r)
                .sum::<f64>()
                .sqrt();

            let elapsed = round_start.elapsed().as_millis() as u64;

            let metrics = EpochMetrics {
                epoch: round,
                train_loss,
                val_loss: None,
                train_metric: None,
                val_metric: None,
                learning_rate: self.learning_rate,
                grad_norm,
                elapsed_ms: elapsed,
            };

            let mut cb_stop = false;
            for cb in &mut callbacks {
                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
                    cb_stop = true;
                }
            }

            history.push(metrics);

            if cb_stop {
                break;
            }
        }

        self.trees = vec![trees_seq];
        self.fitted = true;
        for cb in &mut callbacks {
            cb.on_training_end();
        }
        self.callbacks = callbacks;
        self.history = Some(history);
        Ok(())
    }

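    /// Multiclass boosting: one tree per class per round, each fit to the
    /// weighted softmax residuals `w_i * (y_onehot[cls][i] - p[cls][i])`.
    /// Class scores are initialized to the (clamped) log prior
    /// `ln(count / n)`, and leaf values get the `(k - 1) / k`-scaled Newton
    /// correction.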
    #[allow(clippy::too_many_arguments)]
    fn fit_multiclass(
        &mut self,
        data: &Dataset,
        n: usize,
        k: usize,
        rng: &mut crate::rng::FastRng,
        all_indices: &[usize],
        row_major: &[Vec<f64>],
        sample_weights: &[f64],
    ) -> Result<()> {
        let y_onehot: Vec<Vec<f64>> = (0..k)
            .map(|cls| {
                data.target
                    .iter()
                    .map(|&y| if (y as usize) == cls { 1.0 } else { 0.0 })
                    .collect()
            })
            .collect();

        let class_counts: Vec<usize> = (0..k)
            .map(|cls| data.target.iter().filter(|&&y| (y as usize) == cls).count())
            .collect();
        let init_preds: Vec<f64> = class_counts
            .iter()
            .map(|&c| {
                let p = (c as f64 / n as f64).clamp(
                    crate::constants::GBT_PROB_CLAMP,
                    1.0 - crate::constants::GBT_PROB_CLAMP,
                );
                p.ln()
            })
            .collect();
        self.init_predictions.clone_from(&init_preds);

        let mut f_vals: Vec<Vec<f64>> = (0..k).map(|c| vec![init_preds[c]; n]).collect();

        let mut trees_all: Vec<Vec<DecisionTreeRegressor>> = (0..k)
            .map(|_| Vec::with_capacity(self.n_estimators))
            .collect();
        let mut history = TrainingHistory::new();
        let mut callbacks = std::mem::take(&mut self.callbacks);

        let mut temp_data = Dataset::new(
            data.features.clone(),
            vec![0.0; n],
            data.feature_names.clone(),
            "residual",
        );

        let global_sorted = presort_indices(&temp_data, all_indices);

        for round in 0..self.n_estimators {
            let round_start = std::time::Instant::now();
            let probs = softmax_matrix(&f_vals, n, k);

            let indices = subsample_indices(n, self.subsample, rng, all_indices);

            for cls in 0..k {
                for i in 0..n {
                    temp_data.target[i] = sample_weights[i] * (y_onehot[cls][i] - probs[cls][i]);
                }

                let mut tree = DecisionTreeRegressor::new()
                    .max_depth(self.max_depth)
                    .min_samples_split(self.min_samples_split)
                    .min_samples_leaf(self.min_samples_leaf);
                tree.fit_on_indices_presorted(&temp_data, &indices, &global_sorted)?;

                if let Some(ref mut flat) = tree.flat_tree {
                    let leaf_indices = flat.apply(row_major);
                    newton_correct_multiclass_leaves(
                        flat,
                        &leaf_indices,
                        &temp_data.target,
                        &probs[cls],
                        k,
                    );
                }

                let tp = tree.predict(row_major)?;
                for (f_val, &t) in f_vals[cls].iter_mut().zip(tp.iter()) {
                    *f_val += self.learning_rate * t;
                }

                trees_all[cls].push(tree);
            }

            let probs_after = softmax_matrix(&f_vals, n, k);
            let train_loss: f64 = (0..n)
                .map(|i| {
                    let cls_i = data.target[i] as usize;
                    let p = probs_after[cls_i][i].clamp(
                        crate::constants::NEAR_ZERO,
                        1.0 - crate::constants::NEAR_ZERO,
                    );
                    -p.ln()
                })
                .sum::<f64>()
                / n as f64;

            let grad_norm: f64 = temp_data
                .target
                .iter()
                .take(n)
                .map(|&r| r * r)
                .sum::<f64>()
                .sqrt();

            let elapsed = round_start.elapsed().as_millis() as u64;

            let metrics = EpochMetrics {
                epoch: round,
                train_loss,
                val_loss: None,
                train_metric: None,
                val_metric: None,
                learning_rate: self.learning_rate,
                grad_norm,
                elapsed_ms: elapsed,
            };

            let mut cb_stop = false;
            for cb in &mut callbacks {
                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
                    cb_stop = true;
                }
            }

            history.push(metrics);

            if cb_stop {
                break;
            }
        }

        self.trees = trees_all;
        self.fitted = true;
        for cb in &mut callbacks {
            cb.on_training_end();
        }
        self.callbacks = callbacks;
        self.history = Some(history);
        Ok(())
    }

    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let proba = self.predict_proba(features)?;
        Ok(proba
            .iter()
            .map(|row| {
                row.iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect())
    }

    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = features.len();
        let k = self.n_classes;

        if k == 2 {
            let mut f_vals = vec![self.init_predictions[0]; n];
            for tree in &self.trees[0] {
                let tp = tree.predict(features)?;
                for (f, &t) in f_vals.iter_mut().zip(tp.iter()) {
                    *f += self.learning_rate * t;
                }
            }
            Ok(f_vals
                .iter()
                .map(|&f| {
                    let p1 = sigmoid(f);
                    vec![1.0 - p1, p1]
                })
                .collect())
        } else {
            let mut f_vals: Vec<Vec<f64>> =
                (0..k).map(|c| vec![self.init_predictions[c]; n]).collect();
            for (cls_fvals, cls_trees) in f_vals.iter_mut().zip(self.trees.iter()).take(k) {
                for tree in cls_trees {
                    let tp = tree.predict(features)?;
                    for (f, &t) in cls_fvals.iter_mut().zip(tp.iter()) {
                        *f += self.learning_rate * t;
                    }
                }
            }
            let mut result = Vec::with_capacity(n);
            #[allow(clippy::needless_range_loop)]
            for i in 0..n {
                let logits: Vec<f64> = (0..k).map(|c| f_vals[c][i]).collect();
                let max_l = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let exps: Vec<f64> = logits.iter().map(|&l| (l - max_l).exp()).collect();
                let sum: f64 = exps.iter().sum();
                result.push(exps.iter().map(|&e| e / sum).collect());
            }
            Ok(result)
        }
    }

    pub fn feature_importances(&self) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.n_features;
        let mut importances = vec![0.0; m];
        let mut total_trees = 0.0;
        for class_trees in &self.trees {
            for tree in class_trees {
                if let Ok(imp) = tree.feature_importances() {
                    for (i, &v) in imp.iter().enumerate() {
                        if i < m {
                            importances[i] += v;
                        }
                    }
                }
                total_trees += 1.0;
            }
        }
        if total_trees > 0.0 {
            for v in &mut importances {
                *v /= total_trees;
            }
        }
        let total: f64 = importances.iter().sum();
        if total > 0.0 {
            for v in &mut importances {
                *v /= total;
            }
        }
        Ok(importances)
    }

    pub fn n_classes(&self) -> usize {
        self.n_classes
    }

    pub fn n_trees(&self) -> usize {
        self.trees.iter().map(Vec::len).sum()
    }

    pub fn history(&self) -> Option<&TrainingHistory> {
        self.history.as_ref()
    }

    pub fn class_trees(&self) -> &[Vec<DecisionTreeRegressor>] {
        &self.trees
    }

    pub fn n_features(&self) -> usize {
        self.n_features
    }

    pub fn learning_rate_val(&self) -> f64 {
        self.learning_rate
    }

    pub fn init_predictions_val(&self) -> &[f64] {
        &self.init_predictions
    }
}

impl Default for GradientBoostingClassifier {
    fn default() -> Self {
        Self::new()
    }
}

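/// One-step Newton correction for binary logistic boosting: each leaf's
/// value becomes the gradient/Hessian ratio
/// `sum(residuals) / sum(p * (1 - p))` over the samples routed to it,
/// skipping leaves whose Hessian sum is numerically singular.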
fn newton_correct_binary_leaves(
    flat: &mut crate::tree::cart::FlatTree,
    leaf_indices: &[usize],
    residuals: &[f64],
    probs: &[f64],
) {
    use std::collections::HashMap;

    let mut leaf_num: HashMap<usize, f64> = HashMap::new();
    let mut leaf_den: HashMap<usize, f64> = HashMap::new();

    for (i, &leaf_idx) in leaf_indices.iter().enumerate() {
        let r = residuals[i];
        let p = probs[i];
        let hessian = p * (1.0 - p);
        *leaf_num.entry(leaf_idx).or_insert(0.0) += r;
        *leaf_den.entry(leaf_idx).or_insert(0.0) += hessian;
    }

    for (&leaf_idx, &num) in &leaf_num {
        let den = leaf_den[&leaf_idx];
        if den.abs() > crate::constants::SINGULAR_THRESHOLD {
            flat.set_leaf_prediction(leaf_idx, num / den);
        }
    }
}

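/// Newton correction for multiclass boosting, using the `(k - 1) / k`
/// scaling from Friedman's K-class gradient boosting: leaf value =
/// `(k - 1) / k * sum(residuals) / sum(h)`, where each per-sample Hessian
/// `h = p * (1 - p)` is floored to keep the denominator well-conditioned.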
fn newton_correct_multiclass_leaves(
    flat: &mut crate::tree::cart::FlatTree,
    leaf_indices: &[usize],
    residuals: &[f64],
    probs: &[f64],
    k: usize,
) {
    use std::collections::HashMap;

    let factor = (k - 1) as f64 / k as f64;

    let mut leaf_num: HashMap<usize, f64> = HashMap::new();
    let mut leaf_den: HashMap<usize, f64> = HashMap::new();

    for (i, &leaf_idx) in leaf_indices.iter().enumerate() {
        let r = residuals[i];
        let p = probs[i];
        let hessian = (p * (1.0 - p)).max(crate::constants::SINGULAR_THRESHOLD);
        *leaf_num.entry(leaf_idx).or_insert(0.0) += r;
        *leaf_den.entry(leaf_idx).or_insert(0.0) += hessian;
    }

    for (&leaf_idx, &num) in &leaf_num {
        let den = leaf_den[&leaf_idx];
        if den.abs() > crate::constants::SINGULAR_THRESHOLD {
            flat.set_leaf_prediction(leaf_idx, factor * num / den);
        }
    }
}

#[inline]
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

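/// Column-wise softmax over per-class score rows: `probs[c][i]` is the
/// probability of class `c` for sample `i`. Logits are shifted by their
/// per-sample maximum before exponentiation for numerical stability.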
fn softmax_matrix(f_vals: &[Vec<f64>], n: usize, k: usize) -> Vec<Vec<f64>> {
    let mut probs = vec![vec![0.0; n]; k];
    for i in 0..n {
        let logits: Vec<f64> = (0..k).map(|c| f_vals[c][i]).collect();
        let max_l = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = logits.iter().map(|&l| (l - max_l).exp()).collect();
        let sum: f64 = exps.iter().sum();
        for c in 0..k {
            probs[c][i] = exps[c] / sum;
        }
    }
    probs
}

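/// Draws `ceil(n * subsample)` distinct row indices via a partial
/// Fisher-Yates shuffle; returns all indices unshuffled when
/// `subsample >= 1.0`.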
fn subsample_indices(
    n: usize,
    subsample: f64,
    rng: &mut crate::rng::FastRng,
    all_indices: &[usize],
) -> Vec<usize> {
    if subsample >= 1.0 {
        return all_indices.to_vec();
    }
    let k = ((n as f64) * subsample).ceil() as usize;
    let mut idx = all_indices.to_vec();
    for i in 0..k.min(n) {
        let j = rng.usize(i..n);
        idx.swap(i, j);
    }
    idx.truncate(k);
    idx
}

#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    fn make_linear_data(n: usize) -> Dataset {
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|&v| 2.0 * v + 1.0).collect();
        Dataset::new(vec![x], y, vec!["x".into()], "y")
    }

    fn make_binary_data() -> Dataset {
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();
        for i in 0..50 {
            let v = i as f64 / 50.0;
            f1.push(v);
            f2.push(v * 0.5);
            target.push(0.0);
        }
        for i in 0..50 {
            let v = 1.0 + i as f64 / 50.0;
            f1.push(v);
            f2.push(v * 0.5);
            target.push(1.0);
        }
        Dataset::new(vec![f1, f2], target, vec!["f1".into(), "f2".into()], "cls")
    }

    fn make_multiclass_data() -> Dataset {
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();
        for i in 0..30 {
            f1.push(i as f64 / 30.0);
            f2.push(0.0);
            target.push(0.0);
        }
        for i in 0..30 {
            f1.push(2.0 + i as f64 / 30.0);
            f2.push(0.0);
            target.push(1.0);
        }
        for i in 0..30 {
            f1.push(4.0 + i as f64 / 30.0);
            f2.push(0.0);
            target.push(2.0);
        }
        Dataset::new(vec![f1, f2], target, vec!["f1".into(), "f2".into()], "cls")
    }

    #[test]
    fn regressor_learns_linear() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(100)
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();

        let preds = gbr.predict(&[vec![50.0], vec![75.0]]).unwrap();
        assert!((preds[0] - 101.0).abs() < 10.0, "pred={}", preds[0]);
        assert!((preds[1] - 151.0).abs() < 15.0, "pred={}", preds[1]);
    }

    #[test]
    fn regressor_not_fitted_error() {
        let gbr = GradientBoostingRegressor::new();
        assert!(gbr.predict(&[vec![1.0]]).is_err());
        assert!(gbr.feature_importances().is_err());
    }

    #[test]
    fn regressor_subsample() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(50)
            .subsample(0.7)
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();
        let preds = gbr.predict(&[vec![25.0]]).unwrap();
        assert!((preds[0] - 51.0).abs() < 15.0, "pred={}", preds[0]);
    }

    #[test]
    fn regressor_feature_importances() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(20)
            .max_depth(2);
        gbr.fit(&data).unwrap();
        let imp = gbr.feature_importances().unwrap();
        assert_eq!(imp.len(), 1);
        assert!(
            (imp[0] - 1.0).abs() < 1e-6,
            "single feature should have importance 1.0"
        );
    }

    #[test]
    fn regressor_invalid_params() {
        let data = make_linear_data(10);
        let mut gbr = GradientBoostingRegressor::new().learning_rate(0.0);
        assert!(gbr.fit(&data).is_err());

        let mut gbr = GradientBoostingRegressor::new().subsample(1.5);
        assert!(gbr.fit(&data).is_err());
    }

    #[test]
    fn regressor_early_stopping() {
        let mut rng = crate::rng::FastRng::new(42);
        let n = 50;
        let x: Vec<f64> = (0..n).map(|_| rng.f64() * 10.0).collect();
        let y: Vec<f64> = x.iter().map(|&v| v.sin() + rng.f64() * 5.0).collect();
        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");

        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(1000)
            .learning_rate(0.5)
            .max_depth(5)
            .n_iter_no_change(5)
            .validation_fraction(0.3)
            .tol(0.0);
        gbr.fit(&data).unwrap();

        assert!(
            gbr.n_trees() < 1000,
            "Expected early stopping, but used all {} estimators",
            gbr.n_trees()
        );
        assert!(gbr.early_stopped(), "early_stopped() should be true");
        assert!(gbr.n_estimators_used() < 1000);
    }

    #[test]
    fn classifier_binary() {
        let data = make_binary_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_depth(2);
        gbc.fit(&data).unwrap();

        let test = vec![vec![0.2, 0.1], vec![1.5, 0.75]];
        let preds = gbc.predict(&test).unwrap();
        assert_eq!(preds[0], 0.0, "low values -> class 0");
        assert_eq!(preds[1], 1.0, "high values -> class 1");
    }

    #[test]
    fn classifier_binary_proba() {
        let data = make_binary_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_depth(2);
        gbc.fit(&data).unwrap();

        let probas = gbc.predict_proba(&[vec![0.2, 0.1]]).unwrap();
        assert_eq!(probas[0].len(), 2);
        let sum: f64 = probas[0].iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "probabilities should sum to 1");
        assert!(probas[0][0] > probas[0][1], "class 0 should be more likely");
    }

    #[test]
    fn classifier_multiclass() {
        let data = make_multiclass_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(100)
            .learning_rate(0.1)
            .max_depth(3);
        gbc.fit(&data).unwrap();

        let test = vec![vec![0.5, 0.0], vec![2.5, 0.0], vec![4.5, 0.0]];
        let preds = gbc.predict(&test).unwrap();
        assert_eq!(preds[0], 0.0, "should be class 0");
        assert_eq!(preds[1], 1.0, "should be class 1");
        assert_eq!(preds[2], 2.0, "should be class 2");
    }

    #[test]
    fn classifier_multiclass_proba() {
        let data = make_multiclass_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_depth(2);
        gbc.fit(&data).unwrap();

        let probas = gbc.predict_proba(&[vec![0.5, 0.0]]).unwrap();
        assert_eq!(probas[0].len(), 3);
        let sum: f64 = probas[0].iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "probabilities should sum to 1");
    }

    #[test]
    fn classifier_subsample() {
        let data = make_binary_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(50)
            .subsample(0.8)
            .learning_rate(0.1)
            .max_depth(2);
        gbc.fit(&data).unwrap();

        let test = vec![vec![0.2, 0.1], vec![1.5, 0.75]];
        let preds = gbc.predict(&test).unwrap();
        assert_eq!(preds[0], 0.0);
        assert_eq!(preds[1], 1.0);
    }

    #[test]
    fn classifier_feature_importances() {
        let data = make_binary_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(20)
            .max_depth(2);
        gbc.fit(&data).unwrap();
        let imp = gbc.feature_importances().unwrap();
        assert_eq!(imp.len(), 2);
        let sum: f64 = imp.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "importances should sum to 1");
    }

    #[test]
    fn classifier_not_fitted_error() {
        let gbc = GradientBoostingClassifier::new();
        assert!(gbc.predict(&[vec![1.0, 2.0]]).is_err());
        assert!(gbc.predict_proba(&[vec![1.0, 2.0]]).is_err());
        assert!(gbc.feature_importances().is_err());
    }

    #[test]
    fn classifier_n_trees_binary() {
        let data = make_binary_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(25)
            .max_depth(2);
        gbc.fit(&data).unwrap();
        assert_eq!(gbc.n_trees(), 25, "binary: 1 class × 25 rounds");
    }

    #[test]
    fn classifier_n_trees_multiclass() {
        let data = make_multiclass_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(10)
            .max_depth(2);
        gbc.fit(&data).unwrap();
        assert_eq!(gbc.n_trees(), 30, "multiclass: 3 classes × 10 rounds");
    }

    #[test]
    fn regressor_loss_squared_error_default() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(100)
            .loss(RegressionLoss::SquaredError)
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();
        let preds = gbr.predict(&[vec![50.0]]).unwrap();
        assert!(
            (preds[0] - 101.0).abs() < 10.0,
            "SquaredError pred={}",
            preds[0]
        );
    }

    #[test]
    fn regressor_loss_absolute_error() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(200)
            .loss(RegressionLoss::AbsoluteError)
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();
        let preds = gbr.predict(&[vec![50.0]]).unwrap();
        assert!(
            (preds[0] - 101.0).abs() < 20.0,
            "AbsoluteError pred={}",
            preds[0]
        );
    }

    #[test]
    fn regressor_loss_huber() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(200)
            .loss(RegressionLoss::Huber { alpha: 0.9 })
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();
        let preds = gbr.predict(&[vec![50.0]]).unwrap();
        assert!((preds[0] - 101.0).abs() < 20.0, "Huber pred={}", preds[0]);
    }

    #[test]
    fn regressor_loss_quantile_median() {
        let data = make_linear_data(100);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(200)
            .loss(RegressionLoss::Quantile { alpha: 0.5 })
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();
        let preds = gbr.predict(&[vec![50.0]]).unwrap();
        assert!(
            (preds[0] - 101.0).abs() < 25.0,
            "Quantile(0.5) pred={}",
            preds[0]
        );
    }

    #[test]
    fn test_median_helper() {
        assert!((median(&[1.0, 3.0, 5.0]) - 3.0).abs() < 1e-12);
        assert!((median(&[1.0, 3.0, 5.0, 7.0]) - 4.0).abs() < 1e-12);
        assert!((median(&[42.0]) - 42.0).abs() < 1e-12);
        assert!((median(&[]) - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_quantile_helper() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!((quantile(&data, 0.5) - 3.0).abs() < 1e-12);
        assert!((quantile(&data, 0.0) - 1.0).abs() < 1e-12);
        assert!((quantile(&data, 1.0) - 5.0).abs() < 1e-12);
        assert!((quantile(&data, 0.25) - 2.0).abs() < 1e-12);
    }

    #[test]
    fn regressor_history_populated() {
        let data = make_linear_data(50);
        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(10)
            .learning_rate(0.1)
            .max_depth(3);
        gbr.fit(&data).unwrap();

        let history = gbr.history().expect("history should be populated");
        assert_eq!(history.len(), 10);
        assert!(history.epochs[0].train_loss > history.epochs[9].train_loss);
        assert!(history.epochs[0].grad_norm > 0.0);
    }

    #[test]
    fn classifier_binary_history_populated() {
        let data = make_binary_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(10)
            .learning_rate(0.1)
            .max_depth(2);
        gbc.fit(&data).unwrap();

        let history = gbc.history().expect("history should be populated");
        assert_eq!(history.len(), 10);
        assert!(history.epochs[0].train_loss > 0.0);
    }

    #[test]
    fn classifier_multiclass_history_populated() {
        let data = make_multiclass_data();
        let mut gbc = GradientBoostingClassifier::new()
            .n_estimators(10)
            .learning_rate(0.1)
            .max_depth(2);
        gbc.fit(&data).unwrap();

        let history = gbc.history().expect("history should be populated");
        assert_eq!(history.len(), 10);
        assert!(history.epochs[0].train_loss > 0.0);
    }

    #[test]
    fn regressor_early_stopping_history() {
        let mut rng = crate::rng::FastRng::new(42);
        let n = 50;
        let x: Vec<f64> = (0..n).map(|_| rng.f64() * 10.0).collect();
        let y: Vec<f64> = x.iter().map(|&v| v.sin() + rng.f64() * 5.0).collect();
        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");

        let mut gbr = GradientBoostingRegressor::new()
            .n_estimators(1000)
            .learning_rate(0.5)
            .max_depth(5)
            .n_iter_no_change(5)
            .validation_fraction(0.3)
            .tol(0.0);
        gbr.fit(&data).unwrap();

        let history = gbr.history().expect("history should be populated");
        assert_eq!(history.len(), gbr.n_estimators_used());
        assert!(history.epochs.last().unwrap().val_loss.is_some());
    }
}