
scry_learn/tree/gradient_boosting.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Gradient Boosted Trees — sequential ensemble for classification and regression.
3//!
4//! Each boosting round fits a shallow regression tree to the negative gradient
5//! (pseudo-residuals) of the loss function. Prediction is the sum of all trees
6//! scaled by `learning_rate` plus the initial prediction.
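//! In symbols, the ensemble prediction is `F_M(x) = F₀ + η · Σₘ hₘ(x)`, where
//! each `hₘ` is a shallow regression tree and `η` is the learning rate.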
7//!
8//! Uses **Newton-Raphson leaf correction** for classification (second-order
9//! gradient step), matching sklearn's `GradientBoostingClassifier` behavior.
10//!
11//! Internally reuses [`DecisionTreeRegressor`] with [`FlatTree`] for
12//! cache-optimal prediction of each weak learner.
13
14use crate::dataset::Dataset;
15use crate::error::{Result, ScryLearnError};
16use crate::neural::callback::{CallbackAction, EpochMetrics, TrainingCallback, TrainingHistory};
17use crate::tree::cart::{presort_indices, DecisionTreeRegressor};
18use crate::weights::{compute_sample_weights, ClassWeight};
19
20// ═══════════════════════════════════════════════════════════════════════════
21// Regression Loss Functions
22// ═══════════════════════════════════════════════════════════════════════════
23
24/// Loss function for gradient boosting regression.
25///
26/// Controls how pseudo-residuals are computed and how leaf values are
27/// determined. Different losses provide different robustness properties.
28///
29/// - `SquaredError` (default): standard MSE loss, optimal for Gaussian noise.
30/// - `AbsoluteError`: L1 loss (MAE), more robust to outliers.
31/// - `Huber { alpha }`: hybrid of squared and absolute error; `alpha` is
32///   the quantile at which the transition occurs (default 0.9).
33/// - `Quantile { alpha }`: predicts the `alpha`-quantile of the conditional
34///   distribution (default 0.5 = median).
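///
/// # Example
/// ```
/// use scry_learn::tree::{GradientBoostingRegressor, RegressionLoss};
///
/// // Illustrative: fit the 0.9-quantile instead of the conditional mean.
/// let gbr = GradientBoostingRegressor::new()
///     .loss(RegressionLoss::Quantile { alpha: 0.9 });
/// ```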
35#[derive(Clone, Debug, Default)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37#[non_exhaustive]
38pub enum RegressionLoss {
39    /// Least-squares loss (MSE). Default.
40    #[default]
41    SquaredError,
42    /// Least-absolute-deviation loss (MAE).
43    AbsoluteError,
44    /// Huber loss — squared error for small residuals, absolute for large.
45    /// `alpha` is the quantile threshold (typically 0.9).
46    Huber {
47        /// Quantile threshold for switching between squared and absolute.
48        alpha: f64,
49    },
50    /// Quantile loss — predicts the `alpha`-quantile.
51    /// `alpha` = 0.5 gives the median.
52    Quantile {
53        /// Target quantile in (0, 1).
54        alpha: f64,
55    },
56}
57
58impl RegressionLoss {
59    /// Compute the initial (constant) prediction F₀.
60    fn initial_prediction(&self, y: &[f64]) -> f64 {
61        match self {
62            Self::SquaredError => {
63                let sum: f64 = y.iter().sum();
64                sum / y.len() as f64
65            }
66            Self::AbsoluteError | Self::Huber { .. } => median(y),
67            Self::Quantile { alpha } => quantile(y, *alpha),
68        }
69    }
70
    /// Compute the negative gradient (pseudo-residual) for a single sample.
    ///
    /// `y` is the true target, `f` is the current prediction, and `delta` is
    /// the Huber transition threshold (ignored by the other losses).
74    fn negative_gradient(&self, y: f64, f: f64, delta: f64) -> f64 {
75        match self {
76            Self::SquaredError => y - f,
77            Self::AbsoluteError => {
78                if y > f {
79                    1.0
80                } else if y < f {
81                    -1.0
82                } else {
83                    0.0
84                }
85            }
86            Self::Huber { .. } => {
87                let r = y - f;
88                if r.abs() <= delta {
89                    r
90                } else {
91                    delta * r.signum()
92                }
93            }
94            Self::Quantile { alpha } => {
95                if y > f {
96                    *alpha
97                } else if y < f {
98                    -(1.0 - alpha)
99                } else {
100                    0.0
101                }
102            }
103        }
104    }
105
    /// Compute the optimal leaf value for a terminal region.
    ///
    /// For `SquaredError` the tree's mean-of-residuals leaf value is already
    /// optimal; for the other losses the leaf value is recomputed from the raw
    /// residuals `y - F` of the samples in the leaf.
110    fn update_terminal_value(
111        &self,
112        residuals: &[f64],
113        y_in_leaf: &[f64],
114        f_in_leaf: &[f64],
115        delta: f64,
116    ) -> f64 {
117        match self {
118            Self::SquaredError => {
119                // Tree already computes mean — no override needed.
120                if residuals.is_empty() {
121                    0.0
122                } else {
123                    residuals.iter().sum::<f64>() / residuals.len() as f64
124                }
125            }
            Self::AbsoluteError => {
                // LAD update: median of the raw residuals y - F in the leaf
                // (the pseudo-residuals are only ±1 and would give a wrong value).
                let diffs: Vec<f64> = y_in_leaf
                    .iter()
                    .zip(f_in_leaf.iter())
                    .map(|(&y, &f)| y - f)
                    .collect();
                median(&diffs)
            }
            Self::Huber { .. } => {
                // Huber update: median of the raw residuals plus the mean of
                // the clipped deviations from that median (Friedman 2001).
                let diffs: Vec<f64> = y_in_leaf
                    .iter()
                    .zip(f_in_leaf.iter())
                    .map(|(&y, &f)| y - f)
                    .collect();
                let med = median(&diffs);
                let correction: f64 = diffs
                    .iter()
                    .map(|&r| (r - med).clamp(-delta, delta))
                    .sum::<f64>()
                    / diffs.len().max(1) as f64;
                med + correction
            }
140            Self::Quantile { alpha } => {
141                // Compute residuals from current predictions.
142                let diffs: Vec<f64> = y_in_leaf
143                    .iter()
144                    .zip(f_in_leaf.iter())
145                    .map(|(&y, &f)| y - f)
146                    .collect();
147                quantile(&diffs, *alpha)
148            }
149        }
150    }
151
152    /// Whether this loss needs terminal region updates (overriding tree leaf values).
153    fn needs_terminal_update(&self) -> bool {
154        !matches!(self, Self::SquaredError)
155    }
156}
157
158/// Compute the median of a slice.
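///
/// For an even-length slice the two middle values are averaged, e.g.
/// `median(&[1.0, 2.0, 3.0, 4.0])` is `2.5`.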
159fn median(data: &[f64]) -> f64 {
160    if data.is_empty() {
161        return 0.0;
162    }
163    let mut sorted: Vec<f64> = data.to_vec();
164    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
165    let n = sorted.len();
166    if n % 2 == 1 {
167        sorted[n / 2]
168    } else {
169        f64::midpoint(sorted[n / 2 - 1], sorted[n / 2])
170    }
171}
172
173/// Compute the `alpha`-quantile of a slice (linear interpolation).
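///
/// For example, `quantile(&[1.0, 2.0, 3.0, 4.0], 0.5)` interpolates between the
/// two middle values and returns `2.5`.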
174fn quantile(data: &[f64], alpha: f64) -> f64 {
175    if data.is_empty() {
176        return 0.0;
177    }
178    let mut sorted: Vec<f64> = data.to_vec();
179    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
180    let n = sorted.len();
181    if n == 1 {
182        return sorted[0];
183    }
184    let pos = alpha * (n - 1) as f64;
185    let lo = pos.floor() as usize;
186    let hi = pos.ceil() as usize;
187    if lo == hi {
188        sorted[lo]
189    } else {
190        let frac = pos - lo as f64;
191        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
192    }
193}
194
195// ═══════════════════════════════════════════════════════════════════════════
196// Gradient Boosting Regressor
197// ═══════════════════════════════════════════════════════════════════════════
198
199/// Gradient Boosting for regression.
200///
201/// Builds an additive ensemble of shallow decision trees, each fitting the
202/// negative gradient (pseudo-residuals) of the loss function. Supports
203/// stochastic subsampling and multiple loss functions.
204///
205/// # Example
206/// ```
207/// use scry_learn::dataset::Dataset;
208/// use scry_learn::tree::GradientBoostingRegressor;
209///
210/// let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
211/// let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
212/// let data = Dataset::new(features, target, vec!["x".into()], "y");
213///
214/// let mut gbr = GradientBoostingRegressor::new()
215///     .n_estimators(50)
216///     .learning_rate(0.1)
217///     .max_depth(3);
218/// gbr.fit(&data).unwrap();
219///
220/// let preds = gbr.predict(&[vec![3.0]]).unwrap();
221/// assert!((preds[0] - 6.0).abs() < 1.0);
222/// ```
223#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
224#[non_exhaustive]
225pub struct GradientBoostingRegressor {
226    n_estimators: usize,
227    learning_rate: f64,
228    max_depth: usize,
229    min_samples_split: usize,
230    min_samples_leaf: usize,
231    subsample: f64,
232    seed: u64,
233    loss: RegressionLoss,
234    validation_fraction: f64,
235    n_iter_no_change: Option<usize>,
236    tol: f64,
237    // Fitted state
238    trees: Vec<DecisionTreeRegressor>,
239    init_prediction: f64,
240    n_features: usize,
241    fitted: bool,
242    n_estimators_used: usize,
243    history: Option<TrainingHistory>,
244    /// User-supplied training callbacks (not cloned or serialized).
245    #[cfg_attr(feature = "serde", serde(skip))]
246    callbacks: Vec<Box<dyn TrainingCallback>>,
247    #[cfg_attr(feature = "serde", serde(default))]
248    _schema_version: u32,
249}
250
251impl Clone for GradientBoostingRegressor {
252    fn clone(&self) -> Self {
253        Self {
254            n_estimators: self.n_estimators,
255            learning_rate: self.learning_rate,
256            max_depth: self.max_depth,
257            min_samples_split: self.min_samples_split,
258            min_samples_leaf: self.min_samples_leaf,
259            subsample: self.subsample,
260            seed: self.seed,
261            loss: self.loss.clone(),
262            validation_fraction: self.validation_fraction,
263            n_iter_no_change: self.n_iter_no_change,
264            tol: self.tol,
265            trees: self.trees.clone(),
266            init_prediction: self.init_prediction,
267            n_features: self.n_features,
268            fitted: self.fitted,
269            n_estimators_used: self.n_estimators_used,
270            history: self.history.clone(),
271            callbacks: Vec::new(),
272            _schema_version: self._schema_version,
273        }
274    }
275}
276
277impl GradientBoostingRegressor {
278    /// Create a new regressor with default parameters.
279    pub fn new() -> Self {
280        Self {
281            n_estimators: 100,
282            learning_rate: 0.1,
283            max_depth: 3,
284            min_samples_split: 2,
285            min_samples_leaf: 1,
286            subsample: 1.0,
287            seed: 42,
288            loss: RegressionLoss::SquaredError,
289            validation_fraction: 0.1,
290            n_iter_no_change: None,
291            tol: crate::constants::DEFAULT_TOL,
292            trees: Vec::new(),
293            init_prediction: 0.0,
294            n_features: 0,
295            fitted: false,
296            n_estimators_used: 0,
297            history: None,
298            callbacks: Vec::new(),
299            _schema_version: crate::version::SCHEMA_VERSION,
300        }
301    }
302
303    /// Set number of boosting rounds.
304    pub fn n_estimators(mut self, n: usize) -> Self {
305        self.n_estimators = n;
306        self
307    }
308
309    /// Set learning rate (shrinkage). Lower values need more estimators.
310    pub fn learning_rate(mut self, lr: f64) -> Self {
311        self.learning_rate = lr;
312        self
313    }
314
    /// Set maximum depth per tree (default: 3, yielding shallow trees).
316    pub fn max_depth(mut self, d: usize) -> Self {
317        self.max_depth = d;
318        self
319    }
320
321    /// Set minimum samples required to split an internal node.
322    pub fn min_samples_split(mut self, n: usize) -> Self {
323        self.min_samples_split = n;
324        self
325    }
326
327    /// Set minimum samples required in a leaf node.
328    pub fn min_samples_leaf(mut self, n: usize) -> Self {
329        self.min_samples_leaf = n;
330        self
331    }
332
333    /// Set subsample fraction (0.0, 1.0] for stochastic GBT.
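    ///
    /// Values below 1.0 train each boosting round on a random row subset drawn
    /// without replacement, which adds randomness and can reduce overfitting.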
334    pub fn subsample(mut self, s: f64) -> Self {
335        self.subsample = s;
336        self
337    }
338
339    /// Set random seed.
340    pub fn seed(mut self, s: u64) -> Self {
341        self.seed = s;
342        self
343    }
344
345    /// Enable early stopping. Training stops when validation loss does not
346    /// improve for `n` consecutive rounds.
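    ///
    /// # Example
    /// ```
    /// use scry_learn::tree::GradientBoostingRegressor;
    ///
    /// // Illustrative: stop after 5 rounds without validation improvement,
    /// // holding out 20% of the training data for validation.
    /// let gbr = GradientBoostingRegressor::new()
    ///     .n_estimators(500)
    ///     .n_iter_no_change(5)
    ///     .validation_fraction(0.2);
    /// ```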
347    pub fn n_iter_no_change(mut self, n: usize) -> Self {
348        self.n_iter_no_change = Some(n);
349        self
350    }
351
352    /// Set fraction of training data to use as validation for early stopping
353    /// (default: 0.1).
354    pub fn validation_fraction(mut self, frac: f64) -> Self {
355        self.validation_fraction = frac;
356        self
357    }
358
359    /// Set tolerance for early stopping (default: 1e-4).
360    pub fn tol(mut self, t: f64) -> Self {
361        self.tol = t;
362        self
363    }
364
365    /// Add a training callback (invoked after each boosting round).
366    pub fn callback(mut self, cb: Box<dyn TrainingCallback>) -> Self {
367        self.callbacks.push(cb);
368        self
369    }
370
371    /// Number of estimators actually used (may be less than `n_estimators`
372    /// if early stopping triggered).
373    pub fn n_estimators_used(&self) -> usize {
374        self.n_estimators_used
375    }
376
377    /// Set the regression loss function.
378    ///
379    /// # Example
380    /// ```
381    /// use scry_learn::tree::{GradientBoostingRegressor, RegressionLoss};
382    ///
383    /// let gbr = GradientBoostingRegressor::new()
384    ///     .loss(RegressionLoss::Huber { alpha: 0.9 });
385    /// ```
386    pub fn loss(mut self, l: RegressionLoss) -> Self {
387        self.loss = l;
388        self
389    }
390
391    /// Train the gradient boosting ensemble.
392    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
393        data.validate_finite()?;
394        let n = data.n_samples();
395        if n == 0 {
396            return Err(ScryLearnError::EmptyDataset);
397        }
398        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
399            return Err(ScryLearnError::InvalidParameter(
400                "learning_rate must be in (0, 1]".into(),
401            ));
402        }
403        if self.subsample <= 0.0 || self.subsample > 1.0 {
404            return Err(ScryLearnError::InvalidParameter(
405                "subsample must be in (0, 1]".into(),
406            ));
407        }
408
409        self.n_features = data.n_features();
410
411        // ── Early stopping: split into train / validation ──
412        let (train_data, val_data) = if self.n_iter_no_change.is_some() {
413            let (t, v) = crate::split::train_test_split(data, self.validation_fraction, self.seed);
414            (t, Some(v))
415        } else {
416            (data.clone(), None)
417        };
418        let n_train = train_data.n_samples();
419
420        // F₀ = loss-specific initial prediction
421        let init = self.loss.initial_prediction(&train_data.target);
422        self.init_prediction = init;
423
424        // Current predictions for each training sample.
425        let mut f_vals = vec![init; n_train];
426
427        // Compute Huber delta (quantile of |y - F₀|) — used throughout training.
428        let delta = match &self.loss {
429            RegressionLoss::Huber { alpha } => {
430                let abs_resid: Vec<f64> = train_data
431                    .target
432                    .iter()
433                    .zip(f_vals.iter())
434                    .map(|(&y, &f)| (y - f).abs())
435                    .collect();
436                quantile(&abs_resid, *alpha)
437            }
438            _ => 0.0, // unused for other losses
439        };
440
441        let mut rng = crate::rng::FastRng::new(self.seed);
442        let all_indices: Vec<usize> = (0..n_train).collect();
443        self.trees = Vec::with_capacity(self.n_estimators);
444
445        // Reusable dataset — share features, replace target each round.
446        let mut temp_data = Dataset::new(
447            train_data.features.clone(),
448            vec![0.0; n_train],
449            train_data.feature_names.clone(),
450            "residual",
451        );
452        let row_major = train_data.feature_matrix();
453
454        // Pre-sort indices once — reused across all boosting rounds.
455        // Feature values don't change between rounds (only targets/residuals do),
456        // so the sorted order is valid throughout training.
457        let global_sorted = presort_indices(&temp_data, &all_indices);
458
459        // Early stopping state.
460        let mut best_val_loss = f64::INFINITY;
461        let mut no_improve_count = 0usize;
462        let patience = self.n_iter_no_change.unwrap_or(usize::MAX);
463
464        let mut history = TrainingHistory::new();
465        let mut callbacks = std::mem::take(&mut self.callbacks);
466
467        for round in 0..self.n_estimators {
468            let round_start = std::time::Instant::now();
469
470            // Compute negative gradient (pseudo-residuals).
471            for (i, fv) in f_vals.iter().enumerate().take(n_train) {
472                temp_data.target[i] = self
473                    .loss
474                    .negative_gradient(train_data.target[i], *fv, delta);
475            }
476
477            // Subsample indices.
478            let indices = subsample_indices(n_train, self.subsample, &mut rng, &all_indices);
479
480            // Fit a shallow regression tree to the pseudo-residuals.
481            let mut tree = DecisionTreeRegressor::new()
482                .max_depth(self.max_depth)
483                .min_samples_split(self.min_samples_split)
484                .min_samples_leaf(self.min_samples_leaf);
485            tree.fit_on_indices_presorted(&temp_data, &indices, &global_sorted)?;
486
487            // For non-squared-error losses, override leaf values with
488            // the loss-specific optimal terminal region update.
489            if self.loss.needs_terminal_update() {
490                if let Some(ref mut flat) = tree.flat_tree {
491                    // Compute leaf assignments for training samples.
492                    let leaf_ids = flat.apply(&row_major);
493                    let n_nodes = flat.n_nodes();
494                    let mut leaf_residuals: Vec<Vec<f64>> = vec![Vec::new(); n_nodes];
495                    let mut leaf_y: Vec<Vec<f64>> = vec![Vec::new(); n_nodes];
496                    let mut leaf_f: Vec<Vec<f64>> = vec![Vec::new(); n_nodes];
497                    for (i, &lid) in leaf_ids.iter().enumerate() {
498                        leaf_residuals[lid].push(temp_data.target[i]);
499                        leaf_y[lid].push(train_data.target[i]);
500                        leaf_f[lid].push(f_vals[i]);
501                    }
502                    for node_id in 0..n_nodes {
503                        if !leaf_residuals[node_id].is_empty() {
504                            let new_val = self.loss.update_terminal_value(
505                                &leaf_residuals[node_id],
506                                &leaf_y[node_id],
507                                &leaf_f[node_id],
508                                delta,
509                            );
510                            flat.set_leaf_prediction(node_id, new_val);
511                        }
512                    }
513                }
514            }
515
516            // Update predictions: F(x_i) += η × tree.predict(x_i)
517            let tree_preds = tree.predict(&row_major)?;
518            for (f_val, &tp) in f_vals.iter_mut().zip(tree_preds.iter()) {
519                *f_val += self.learning_rate * tp;
520            }
521
522            self.trees.push(tree);
523
524            // Compute training loss (MSE on training set).
525            let train_mse: f64 = train_data
526                .target
527                .iter()
528                .zip(f_vals.iter())
529                .map(|(&y, &f)| (y - f).powi(2))
530                .sum::<f64>()
531                / n_train as f64;
532
533            // Gradient norm: L2 norm of pseudo-residuals (approximation for trees).
534            let grad_norm: f64 = temp_data
535                .target
536                .iter()
537                .take(n_train)
538                .map(|&r| r * r)
539                .sum::<f64>()
540                .sqrt();
541
542            let elapsed = round_start.elapsed().as_millis() as u64;
543
544            let metrics = EpochMetrics {
545                epoch: round,
546                train_loss: train_mse,
547                val_loss: None, // updated below if early stopping
548                train_metric: None,
549                val_metric: None,
550                learning_rate: self.learning_rate,
551                grad_norm,
552                elapsed_ms: elapsed,
553            };
554
555            let mut cb_stop = false;
556            for cb in &mut callbacks {
557                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
558                    cb_stop = true;
559                }
560            }
561
562            history.push(metrics);
563
564            if cb_stop {
565                self.n_estimators_used = round + 1;
566                self.fitted = true;
567                for cb in &mut callbacks {
568                    cb.on_training_end();
569                }
570                self.callbacks = callbacks;
571                self.history = Some(history);
572                return Ok(());
573            }
574
575            // ── Check early stopping ──
576            if let Some(ref val) = val_data {
577                let val_features = val.feature_matrix();
578                let mut val_preds = vec![self.init_prediction; val_features.len()];
579                for t in &self.trees {
580                    if let Ok(tp) = t.predict(&val_features) {
581                        for (p, &v) in val_preds.iter_mut().zip(tp.iter()) {
582                            *p += self.learning_rate * v;
583                        }
584                    }
585                }
586                let val_mse: f64 = val
587                    .target
588                    .iter()
589                    .zip(val_preds.iter())
590                    .map(|(&y, &p)| (y - p).powi(2))
591                    .sum::<f64>()
592                    / val.target.len() as f64;
593
594                // Record val_loss in history.
595                if let Some(last) = history.epochs.last_mut() {
596                    last.val_loss = Some(val_mse);
597                }
598
599                if val_mse + self.tol < best_val_loss {
600                    best_val_loss = val_mse;
601                    no_improve_count = 0;
602                } else {
603                    no_improve_count += 1;
604                    if no_improve_count >= patience {
605                        self.n_estimators_used = round + 1;
606                        self.fitted = true;
607                        for cb in &mut callbacks {
608                            cb.on_training_end();
609                        }
610                        self.callbacks = callbacks;
611                        self.history = Some(history);
612                        return Ok(());
613                    }
614                }
615            }
616        }
617
618        self.n_estimators_used = self.trees.len();
619        self.fitted = true;
620        for cb in &mut callbacks {
621            cb.on_training_end();
622        }
623        self.callbacks = callbacks;
624        self.history = Some(history);
625        Ok(())
626    }
627
628    /// Predict values for new samples.
629    ///
630    /// `features` is row-major: `features[sample_idx][feature_idx]`.
631    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
632        crate::version::check_schema_version(self._schema_version)?;
633        if !self.fitted {
634            return Err(ScryLearnError::NotFitted);
635        }
636        let n = features.len();
637        let mut preds = vec![self.init_prediction; n];
638        for tree in &self.trees {
639            let tp = tree.predict(features)?;
640            for (p, &t) in preds.iter_mut().zip(tp.iter()) {
641                *p += self.learning_rate * t;
642            }
643        }
644        Ok(preds)
645    }
646
647    /// Feature importances averaged across all trees.
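    ///
    /// Importances are averaged over all trees and then normalized to sum to 1
    /// (left as all zeros when every tree reports zero importance).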
648    pub fn feature_importances(&self) -> Result<Vec<f64>> {
649        if !self.fitted {
650            return Err(ScryLearnError::NotFitted);
651        }
652        let m = self.n_features;
653        let mut importances = vec![0.0; m];
654        let n_trees = self.trees.len() as f64;
655        for tree in &self.trees {
656            if let Ok(imp) = tree.feature_importances() {
657                for (i, &v) in imp.iter().enumerate() {
658                    if i < m {
659                        importances[i] += v / n_trees;
660                    }
661                }
662            }
663        }
664        // Normalize.
665        let total: f64 = importances.iter().sum();
666        if total > 0.0 {
667            for v in &mut importances {
668                *v /= total;
669            }
670        }
671        Ok(importances)
672    }
673
674    /// Number of estimators (trees) in the ensemble.
675    pub fn n_trees(&self) -> usize {
676        self.trees.len()
677    }
678
679    /// Whether early stopping was triggered.
680    pub fn early_stopped(&self) -> bool {
681        self.n_iter_no_change.is_some() && self.n_estimators_used < self.n_estimators
682    }
683
684    /// Return training history (populated after `fit()`).
685    pub fn history(&self) -> Option<&TrainingHistory> {
686        self.history.as_ref()
687    }
688
689    /// Get individual trees (for inspection or ONNX export).
690    pub fn trees(&self) -> &[DecisionTreeRegressor] {
691        &self.trees
692    }
693
694    /// Number of features the model was trained on.
695    pub fn n_features(&self) -> usize {
696        self.n_features
697    }
698
699    /// Learning rate value.
700    pub fn learning_rate_val(&self) -> f64 {
701        self.learning_rate
702    }
703
704    /// Initial (base) prediction value.
705    pub fn init_prediction_val(&self) -> f64 {
706        self.init_prediction
707    }
708}
709
710impl Default for GradientBoostingRegressor {
711    fn default() -> Self {
712        Self::new()
713    }
714}
715
716// ═══════════════════════════════════════════════════════════════════════════
717// Gradient Boosting Classifier
718// ═══════════════════════════════════════════════════════════════════════════
719
720/// Gradient Boosting for classification (binary + multiclass).
721///
722/// - Binary: fits a single sequence of trees to log-loss pseudo-residuals.
/// - Multiclass (K > 2): fits one sequence of trees per class on softmax pseudo-residuals.
724///
725/// Uses **Newton-Raphson leaf correction** (second-order gradient step) for
726/// optimal leaf values, matching sklearn's `GradientBoostingClassifier`.
727///
728/// # Example
729/// ```
730/// use scry_learn::dataset::Dataset;
731/// use scry_learn::tree::GradientBoostingClassifier;
732///
733/// // Simple linearly separable data.
734/// let features = vec![
735///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
736///     vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
737/// ];
738/// let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
739/// let data = Dataset::new(features, target, vec!["x1".into(), "x2".into()], "class");
740///
741/// let mut gbc = GradientBoostingClassifier::new()
742///     .n_estimators(50)
743///     .learning_rate(0.1)
744///     .max_depth(2);
745/// gbc.fit(&data).unwrap();
746///
747/// let preds = gbc.predict(&[vec![1.5, 0.15], vec![5.5, 0.55]]).unwrap();
748/// assert_eq!(preds[0], 0.0);
749/// assert_eq!(preds[1], 1.0);
750/// ```
751#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
752#[non_exhaustive]
753pub struct GradientBoostingClassifier {
754    n_estimators: usize,
755    learning_rate: f64,
756    max_depth: usize,
757    min_samples_split: usize,
758    min_samples_leaf: usize,
759    subsample: f64,
760    seed: u64,
761    class_weight: ClassWeight,
762    // Fitted state — trees[class_idx][estimator_idx]
763    trees: Vec<Vec<DecisionTreeRegressor>>,
764    init_predictions: Vec<f64>,
765    n_classes: usize,
766    n_features: usize,
767    fitted: bool,
768    history: Option<TrainingHistory>,
769    /// User-supplied training callbacks (not cloned or serialized).
770    #[cfg_attr(feature = "serde", serde(skip))]
771    callbacks: Vec<Box<dyn TrainingCallback>>,
772    #[cfg_attr(feature = "serde", serde(default))]
773    _schema_version: u32,
774}
775
776impl Clone for GradientBoostingClassifier {
777    fn clone(&self) -> Self {
778        Self {
779            n_estimators: self.n_estimators,
780            learning_rate: self.learning_rate,
781            max_depth: self.max_depth,
782            min_samples_split: self.min_samples_split,
783            min_samples_leaf: self.min_samples_leaf,
784            subsample: self.subsample,
785            seed: self.seed,
786            class_weight: self.class_weight.clone(),
787            trees: self.trees.clone(),
788            init_predictions: self.init_predictions.clone(),
789            n_classes: self.n_classes,
790            n_features: self.n_features,
791            fitted: self.fitted,
792            history: self.history.clone(),
793            callbacks: Vec::new(),
794            _schema_version: self._schema_version,
795        }
796    }
797}
798
799impl GradientBoostingClassifier {
800    /// Create a new classifier with default parameters.
801    pub fn new() -> Self {
802        Self {
803            n_estimators: 100,
804            learning_rate: 0.1,
805            max_depth: 3,
806            min_samples_split: 2,
807            min_samples_leaf: 1,
808            subsample: 1.0,
809            seed: 42,
810            class_weight: ClassWeight::Uniform,
811            trees: Vec::new(),
812            init_predictions: Vec::new(),
813            n_classes: 0,
814            n_features: 0,
815            fitted: false,
816            history: None,
817            callbacks: Vec::new(),
818            _schema_version: crate::version::SCHEMA_VERSION,
819        }
820    }
821
822    /// Set number of boosting rounds.
823    pub fn n_estimators(mut self, n: usize) -> Self {
824        self.n_estimators = n;
825        self
826    }
827
828    /// Set learning rate (shrinkage).
829    pub fn learning_rate(mut self, lr: f64) -> Self {
830        self.learning_rate = lr;
831        self
832    }
833
834    /// Set maximum depth per tree.
835    pub fn max_depth(mut self, d: usize) -> Self {
836        self.max_depth = d;
837        self
838    }
839
840    /// Set minimum samples required to split.
841    pub fn min_samples_split(mut self, n: usize) -> Self {
842        self.min_samples_split = n;
843        self
844    }
845
846    /// Set minimum samples required in a leaf.
847    pub fn min_samples_leaf(mut self, n: usize) -> Self {
848        self.min_samples_leaf = n;
849        self
850    }
851
852    /// Set subsample fraction for stochastic GBT.
853    pub fn subsample(mut self, s: f64) -> Self {
854        self.subsample = s;
855        self
856    }
857
858    /// Set random seed.
859    pub fn seed(mut self, s: u64) -> Self {
860        self.seed = s;
861        self
862    }
863
864    /// Set class weighting strategy for imbalanced datasets.
865    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
866        self.class_weight = cw;
867        self
868    }
869
870    /// Add a training callback (invoked after each boosting round).
871    pub fn callback(mut self, cb: Box<dyn TrainingCallback>) -> Self {
872        self.callbacks.push(cb);
873        self
874    }
875
876    /// Train the gradient boosting classifier.
877    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
878        data.validate_finite()?;
879        let n = data.n_samples();
880        if n == 0 {
881            return Err(ScryLearnError::EmptyDataset);
882        }
883        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
884            return Err(ScryLearnError::InvalidParameter(
885                "learning_rate must be in (0, 1]".into(),
886            ));
887        }
888        if self.subsample <= 0.0 || self.subsample > 1.0 {
889            return Err(ScryLearnError::InvalidParameter(
890                "subsample must be in (0, 1]".into(),
891            ));
892        }
893
894        self.n_features = data.n_features();
895        self.n_classes = data.n_classes();
896        let k = self.n_classes;
897
898        if k < 2 {
899            return Err(ScryLearnError::InvalidParameter(
900                "need at least 2 classes for classification".into(),
901            ));
902        }
903
904        let mut rng = crate::rng::FastRng::new(self.seed);
905        let all_indices: Vec<usize> = (0..n).collect();
906        let row_major = data.feature_matrix();
907        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
908
909        if k == 2 {
910            // ── Binary classification via log-loss ──
911            self.fit_binary(data, n, &mut rng, &all_indices, &row_major, &sample_weights)?;
912        } else {
            // ── Multiclass via softmax (one tree sequence per class) ──
914            self.fit_multiclass(
915                data,
916                n,
917                k,
918                &mut rng,
919                &all_indices,
920                &row_major,
921                &sample_weights,
922            )?;
923        }
924        Ok(())
925    }
926
927    /// Binary classification: single sequence of trees with Newton leaf correction.
928    fn fit_binary(
929        &mut self,
930        data: &Dataset,
931        n: usize,
932        rng: &mut crate::rng::FastRng,
933        all_indices: &[usize],
934        row_major: &[Vec<f64>],
935        sample_weights: &[f64],
936    ) -> Result<()> {
937        // Class prior: p = count(y=1) / n
938        let pos_count = data.target.iter().filter(|&&y| y > 0.5).count();
939        let p = (pos_count as f64) / (n as f64);
940        let p_clamped = p.clamp(
941            crate::constants::GBT_PROB_CLAMP,
942            1.0 - crate::constants::GBT_PROB_CLAMP,
943        );
944        let f0 = (p_clamped / (1.0 - p_clamped)).ln(); // log-odds
945        self.init_predictions = vec![f0];
946
947        let mut f_vals = vec![f0; n];
948        let mut trees_seq = Vec::with_capacity(self.n_estimators);
949        let mut history = TrainingHistory::new();
950        let mut callbacks = std::mem::take(&mut self.callbacks);
951
952        // Reusable dataset — share features, replace target each round.
953        let mut temp_data = Dataset::new(
954            data.features.clone(),
955            vec![0.0; n],
956            data.feature_names.clone(),
957            "residual",
958        );
959
960        // Pre-sort indices once — reused across all boosting rounds.
961        let global_sorted = presort_indices(&temp_data, all_indices);
962
963        for round in 0..self.n_estimators {
964            let round_start = std::time::Instant::now();
965
966            // Compute probabilities and pseudo-residuals.
967            let probs: Vec<f64> = f_vals.iter().map(|&f| sigmoid(f)).collect();
968
969            // pseudo-residuals: r_i = weight_i * (y_i - sigmoid(F(x_i)))
970            for i in 0..n {
971                temp_data.target[i] = sample_weights[i] * (data.target[i] - probs[i]);
972            }
973
974            let indices = subsample_indices(n, self.subsample, rng, all_indices);
975
976            let mut tree = DecisionTreeRegressor::new()
977                .max_depth(self.max_depth)
978                .min_samples_split(self.min_samples_split)
979                .min_samples_leaf(self.min_samples_leaf);
980            tree.fit_on_indices_presorted(&temp_data, &indices, &global_sorted)?;
981
982            // ── Newton-Raphson leaf correction ──
983            // For each leaf, replace the mean residual with:
984            //   leaf_value = Σ(residual_i) / Σ(p_i × (1 - p_i))
985            // This is the optimal Newton step for log-loss.
986            if let Some(ref mut flat) = tree.flat_tree {
987                let leaf_indices = flat.apply(row_major);
988                newton_correct_binary_leaves(
989                    flat,
990                    &leaf_indices,
991                    &temp_data.target, // residuals
992                    &probs,
993                );
994            }
995
996            let tp = tree.predict(row_major)?;
997            for (f_val, &t) in f_vals.iter_mut().zip(tp.iter()) {
998                *f_val += self.learning_rate * t;
999            }
1000
1001            trees_seq.push(tree);
1002
1003            // Binary cross-entropy loss: -mean(y*log(p) + (1-y)*log(1-p))
1004            let probs_after: Vec<f64> = f_vals.iter().map(|&f| sigmoid(f)).collect();
1005            let train_loss: f64 = data
1006                .target
1007                .iter()
1008                .zip(probs_after.iter())
1009                .map(|(&y, &p)| {
1010                    let p_c = p.clamp(
1011                        crate::constants::NEAR_ZERO,
1012                        1.0 - crate::constants::NEAR_ZERO,
1013                    );
1014                    -(y * p_c.ln() + (1.0 - y) * (1.0 - p_c).ln())
1015                })
1016                .sum::<f64>()
1017                / n as f64;
1018
1019            // Gradient norm: L2 norm of residuals.
1020            let grad_norm: f64 = temp_data
1021                .target
1022                .iter()
1023                .take(n)
1024                .map(|&r| r * r)
1025                .sum::<f64>()
1026                .sqrt();
1027
1028            let elapsed = round_start.elapsed().as_millis() as u64;
1029
1030            let metrics = EpochMetrics {
1031                epoch: round,
1032                train_loss,
1033                val_loss: None,
1034                train_metric: None,
1035                val_metric: None,
1036                learning_rate: self.learning_rate,
1037                grad_norm,
1038                elapsed_ms: elapsed,
1039            };
1040
1041            let mut cb_stop = false;
1042            for cb in &mut callbacks {
1043                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
1044                    cb_stop = true;
1045                }
1046            }
1047
1048            history.push(metrics);
1049
1050            if cb_stop {
1051                break;
1052            }
1053        }
1054
1055        self.trees = vec![trees_seq];
1056        self.fitted = true;
1057        for cb in &mut callbacks {
1058            cb.on_training_end();
1059        }
1060        self.callbacks = callbacks;
1061        self.history = Some(history);
1062        Ok(())
1063    }
1064
1065    /// Multiclass classification: K parallel tree sequences (softmax) with Newton correction.
1066    #[allow(clippy::too_many_arguments)]
1067    fn fit_multiclass(
1068        &mut self,
1069        data: &Dataset,
1070        n: usize,
1071        k: usize,
1072        rng: &mut crate::rng::FastRng,
1073        all_indices: &[usize],
1074        row_major: &[Vec<f64>],
1075        sample_weights: &[f64],
1076    ) -> Result<()> {
1077        // Build one-hot targets: y_k[i] = 1 if target[i] == k, else 0.
1078        let y_onehot: Vec<Vec<f64>> = (0..k)
1079            .map(|cls| {
1080                data.target
1081                    .iter()
1082                    .map(|&y| if (y as usize) == cls { 1.0 } else { 0.0 })
1083                    .collect()
1084            })
1085            .collect();
1086
1087        // Initial predictions: log of class priors.
1088        let class_counts: Vec<usize> = (0..k)
1089            .map(|cls| data.target.iter().filter(|&&y| (y as usize) == cls).count())
1090            .collect();
1091        let init_preds: Vec<f64> = class_counts
1092            .iter()
1093            .map(|&c| {
1094                let p = (c as f64 / n as f64).clamp(
1095                    crate::constants::GBT_PROB_CLAMP,
1096                    1.0 - crate::constants::GBT_PROB_CLAMP,
1097                );
1098                p.ln()
1099            })
1100            .collect();
1101        self.init_predictions.clone_from(&init_preds);
1102
1103        // f_vals[class][sample]
1104        let mut f_vals: Vec<Vec<f64>> = (0..k).map(|c| vec![init_preds[c]; n]).collect();
1105
1106        let mut trees_all: Vec<Vec<DecisionTreeRegressor>> = (0..k)
1107            .map(|_| Vec::with_capacity(self.n_estimators))
1108            .collect();
1109        let mut history = TrainingHistory::new();
1110        let mut callbacks = std::mem::take(&mut self.callbacks);
1111
1112        // Reusable dataset — share features, replace target each round.
1113        let mut temp_data = Dataset::new(
1114            data.features.clone(),
1115            vec![0.0; n],
1116            data.feature_names.clone(),
1117            "residual",
1118        );
1119
1120        // Pre-sort indices once — reused across all boosting rounds.
1121        let global_sorted = presort_indices(&temp_data, all_indices);
1122
1123        for round in 0..self.n_estimators {
1124            let round_start = std::time::Instant::now();
1125            // Compute softmax probabilities.
1126            let probs = softmax_matrix(&f_vals, n, k);
1127
1128            let indices = subsample_indices(n, self.subsample, rng, all_indices);
1129
1130            // Fit one tree per class.
1131            for cls in 0..k {
1132                // pseudo-residuals: r_i = weight_i * (y_k[i] - p_k[i])
1133                for i in 0..n {
1134                    temp_data.target[i] = sample_weights[i] * (y_onehot[cls][i] - probs[cls][i]);
1135                }
1136
1137                let mut tree = DecisionTreeRegressor::new()
1138                    .max_depth(self.max_depth)
1139                    .min_samples_split(self.min_samples_split)
1140                    .min_samples_leaf(self.min_samples_leaf);
1141                tree.fit_on_indices_presorted(&temp_data, &indices, &global_sorted)?;
1142
1143                // ── Newton-Raphson leaf correction for multiclass ──
1144                // For each leaf:
1145                //   leaf_value = (K-1)/K × Σ(residual_i) / Σ(p_i × (1 - p_i))
1146                if let Some(ref mut flat) = tree.flat_tree {
1147                    let leaf_indices = flat.apply(row_major);
1148                    newton_correct_multiclass_leaves(
1149                        flat,
1150                        &leaf_indices,
1151                        &temp_data.target, // residuals
1152                        &probs[cls],       // softmax probabilities for this class
1153                        k,
1154                    );
1155                }
1156
1157                let tp = tree.predict(row_major)?;
1158                for (f_val, &t) in f_vals[cls].iter_mut().zip(tp.iter()) {
1159                    *f_val += self.learning_rate * t;
1160                }
1161
1162                trees_all[cls].push(tree);
1163            }
1164
1165            // Cross-entropy loss: -mean(sum_k y_k * log(p_k))
1166            let probs_after = softmax_matrix(&f_vals, n, k);
1167            let train_loss: f64 = (0..n)
1168                .map(|i| {
1169                    let cls_i = data.target[i] as usize;
1170                    let p = probs_after[cls_i][i].clamp(
1171                        crate::constants::NEAR_ZERO,
1172                        1.0 - crate::constants::NEAR_ZERO,
1173                    );
1174                    -p.ln()
1175                })
1176                .sum::<f64>()
1177                / n as f64;
1178
1179            // Gradient norm: L2 norm of last class's residuals (representative).
1180            let grad_norm: f64 = temp_data
1181                .target
1182                .iter()
1183                .take(n)
1184                .map(|&r| r * r)
1185                .sum::<f64>()
1186                .sqrt();
1187
1188            let elapsed = round_start.elapsed().as_millis() as u64;
1189
1190            let metrics = EpochMetrics {
1191                epoch: round,
1192                train_loss,
1193                val_loss: None,
1194                train_metric: None,
1195                val_metric: None,
1196                learning_rate: self.learning_rate,
1197                grad_norm,
1198                elapsed_ms: elapsed,
1199            };
1200
1201            let mut cb_stop = false;
1202            for cb in &mut callbacks {
1203                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
1204                    cb_stop = true;
1205                }
1206            }
1207
1208            history.push(metrics);
1209
1210            if cb_stop {
1211                break;
1212            }
1213        }
1214
1215        self.trees = trees_all;
1216        self.fitted = true;
1217        for cb in &mut callbacks {
1218            cb.on_training_end();
1219        }
1220        self.callbacks = callbacks;
1221        self.history = Some(history);
1222        Ok(())
1223    }
1224
1225    /// Predict class labels for new samples.
1226    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
1227        crate::version::check_schema_version(self._schema_version)?;
1228        if !self.fitted {
1229            return Err(ScryLearnError::NotFitted);
1230        }
1231        let proba = self.predict_proba(features)?;
1232        Ok(proba
1233            .iter()
1234            .map(|row| {
1235                row.iter()
1236                    .enumerate()
1237                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1238                    .map_or(0.0, |(idx, _)| idx as f64)
1239            })
1240            .collect())
1241    }
1242
1243    /// Predict class probabilities for new samples.
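    ///
    /// Returns one row per input sample; each row holds `n_classes()`
    /// probabilities summing to 1 (sigmoid for binary, softmax for multiclass).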
1244    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
1245        if !self.fitted {
1246            return Err(ScryLearnError::NotFitted);
1247        }
1248        let n = features.len();
1249        let k = self.n_classes;
1250
1251        if k == 2 {
1252            // Binary: single tree sequence.
1253            let mut f_vals = vec![self.init_predictions[0]; n];
1254            for tree in &self.trees[0] {
1255                let tp = tree.predict(features)?;
1256                for (f, &t) in f_vals.iter_mut().zip(tp.iter()) {
1257                    *f += self.learning_rate * t;
1258                }
1259            }
1260            Ok(f_vals
1261                .iter()
1262                .map(|&f| {
1263                    let p1 = sigmoid(f);
1264                    vec![1.0 - p1, p1]
1265                })
1266                .collect())
1267        } else {
1268            // Multiclass: K tree sequences → softmax.
1269            let mut f_vals: Vec<Vec<f64>> =
1270                (0..k).map(|c| vec![self.init_predictions[c]; n]).collect();
1271            for (cls_fvals, cls_trees) in f_vals.iter_mut().zip(self.trees.iter()).take(k) {
1272                for tree in cls_trees {
1273                    let tp = tree.predict(features)?;
1274                    for (f, &t) in cls_fvals.iter_mut().zip(tp.iter()) {
1275                        *f += self.learning_rate * t;
1276                    }
1277                }
1278            }
1279            // Softmax across classes for each sample.
1280            let mut result = Vec::with_capacity(n);
1281            #[allow(clippy::needless_range_loop)]
1282            for i in 0..n {
1283                let logits: Vec<f64> = (0..k).map(|c| f_vals[c][i]).collect();
1284                let max_l = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
1285                let exps: Vec<f64> = logits.iter().map(|&l| (l - max_l).exp()).collect();
1286                let sum: f64 = exps.iter().sum();
1287                result.push(exps.iter().map(|&e| e / sum).collect());
1288            }
1289            Ok(result)
1290        }
1291    }
1292
1293    /// Feature importances averaged across all trees.
1294    pub fn feature_importances(&self) -> Result<Vec<f64>> {
1295        if !self.fitted {
1296            return Err(ScryLearnError::NotFitted);
1297        }
1298        let m = self.n_features;
1299        let mut importances = vec![0.0; m];
1300        let mut total_trees = 0.0;
1301        for class_trees in &self.trees {
1302            for tree in class_trees {
1303                if let Ok(imp) = tree.feature_importances() {
1304                    for (i, &v) in imp.iter().enumerate() {
1305                        if i < m {
1306                            importances[i] += v;
1307                        }
1308                    }
1309                }
1310                total_trees += 1.0;
1311            }
1312        }
1313        if total_trees > 0.0 {
1314            for v in &mut importances {
1315                *v /= total_trees;
1316            }
1317        }
1318        let total: f64 = importances.iter().sum();
1319        if total > 0.0 {
1320            for v in &mut importances {
1321                *v /= total;
1322            }
1323        }
1324        Ok(importances)
1325    }
1326
1327    /// Number of classes.
1328    pub fn n_classes(&self) -> usize {
1329        self.n_classes
1330    }
1331
1332    /// Total number of trees across all class sequences.
1333    pub fn n_trees(&self) -> usize {
1334        self.trees.iter().map(Vec::len).sum()
1335    }
1336
1337    /// Return training history (populated after `fit()`).
1338    pub fn history(&self) -> Option<&TrainingHistory> {
1339        self.history.as_ref()
1340    }
1341
1342    /// Get tree sequences per class (for inspection or ONNX export).
1343    /// `class_trees()[class_idx][estimator_idx]` is the tree for that class/round.
1344    pub fn class_trees(&self) -> &[Vec<DecisionTreeRegressor>] {
1345        &self.trees
1346    }
1347
1348    /// Number of features the model was trained on.
1349    pub fn n_features(&self) -> usize {
1350        self.n_features
1351    }
1352
1353    /// Learning rate value.
1354    pub fn learning_rate_val(&self) -> f64 {
1355        self.learning_rate
1356    }
1357
1358    /// Initial predictions per class.
1359    pub fn init_predictions_val(&self) -> &[f64] {
1360        &self.init_predictions
1361    }
1362}
1363
1364impl Default for GradientBoostingClassifier {
1365    fn default() -> Self {
1366        Self::new()
1367    }
1368}
1369
1370// ═══════════════════════════════════════════════════════════════════════════
1371// Newton-Raphson leaf correction
1372// ═══════════════════════════════════════════════════════════════════════════
1373
1374/// Newton-Raphson correction for binary log-loss.
1375///
1376/// For each leaf, replace the mean residual with:
1377///   leaf_value = Σ(residual_i) / Σ(p_i × (1 - p_i))
1378///
1379/// where p_i = sigmoid(F(x_i)) and residual_i = y_i - p_i.
1380///
1381/// This is the optimal second-order correction step from Friedman (2001).
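///
/// For example, a leaf holding residuals {0.3, -0.2} with probabilities
/// p = {0.7, 0.2} is assigned (0.3 - 0.2) / (0.21 + 0.16) ≈ 0.27.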
1382fn newton_correct_binary_leaves(
1383    flat: &mut crate::tree::cart::FlatTree,
1384    leaf_indices: &[usize],
1385    residuals: &[f64],
1386    probs: &[f64],
1387) {
1388    use std::collections::HashMap;
1389
1390    // Accumulate numerator (Σresid) and denominator (Σp*(1-p)) per leaf.
1391    let mut leaf_num: HashMap<usize, f64> = HashMap::new();
1392    let mut leaf_den: HashMap<usize, f64> = HashMap::new();
1393
1394    for (i, &leaf_idx) in leaf_indices.iter().enumerate() {
1395        let r = residuals[i];
1396        let p = probs[i];
1397        let hessian = p * (1.0 - p);
1398        *leaf_num.entry(leaf_idx).or_insert(0.0) += r;
1399        *leaf_den.entry(leaf_idx).or_insert(0.0) += hessian;
1400    }
1401
1402    // Overwrite leaf predictions with Newton-corrected values.
1403    for (&leaf_idx, &num) in &leaf_num {
1404        let den = leaf_den[&leaf_idx];
        // Avoid division by zero: if the Hessian sum is ~0, keep the tree's
        // mean-residual leaf value.
1406        if den.abs() > crate::constants::SINGULAR_THRESHOLD {
1407            flat.set_leaf_prediction(leaf_idx, num / den);
1408        }
1409    }
1410}
1411
1412/// Newton-Raphson correction for multiclass softmax.
1413///
1414/// For each leaf, replace the mean residual with:
1415///   leaf_value = (K-1)/K × Σ(residual_i) / Σ(p_i × (1 - p_i))
1416///
1417/// where p_i is the softmax probability for the current class.
1418/// Uses the exact diagonal Hessian p(1-p), matching sklearn, XGBoost,
1419/// and LightGBM (not the Friedman 2001 |r|(1-|r|) approximation).
1420fn newton_correct_multiclass_leaves(
1421    flat: &mut crate::tree::cart::FlatTree,
1422    leaf_indices: &[usize],
1423    residuals: &[f64],
1424    probs: &[f64],
1425    k: usize,
1426) {
1427    use std::collections::HashMap;
1428
1429    let factor = (k - 1) as f64 / k as f64;
1430
1431    let mut leaf_num: HashMap<usize, f64> = HashMap::new();
1432    let mut leaf_den: HashMap<usize, f64> = HashMap::new();
1433
1434    for (i, &leaf_idx) in leaf_indices.iter().enumerate() {
1435        let r = residuals[i];
1436        let p = probs[i];
1437        let hessian = (p * (1.0 - p)).max(crate::constants::SINGULAR_THRESHOLD);
1438        *leaf_num.entry(leaf_idx).or_insert(0.0) += r;
1439        *leaf_den.entry(leaf_idx).or_insert(0.0) += hessian;
1440    }
1441
1442    for (&leaf_idx, &num) in &leaf_num {
1443        let den = leaf_den[&leaf_idx];
1444        if den.abs() > crate::constants::SINGULAR_THRESHOLD {
1445            flat.set_leaf_prediction(leaf_idx, factor * num / den);
1446        }
1447    }
1448}
1449
1450// ═══════════════════════════════════════════════════════════════════════════
1451// Helper functions
1452// ═══════════════════════════════════════════════════════════════════════════
1453
1454#[inline]
1455fn sigmoid(x: f64) -> f64 {
1456    1.0 / (1.0 + (-x).exp())
1457}
1458
1459/// Compute softmax probabilities: probs[class][sample].
1460fn softmax_matrix(f_vals: &[Vec<f64>], n: usize, k: usize) -> Vec<Vec<f64>> {
1461    let mut probs = vec![vec![0.0; n]; k];
1462    for i in 0..n {
1463        let logits: Vec<f64> = (0..k).map(|c| f_vals[c][i]).collect();
1464        let max_l = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
1465        let exps: Vec<f64> = logits.iter().map(|&l| (l - max_l).exp()).collect();
1466        let sum: f64 = exps.iter().sum();
1467        for c in 0..k {
1468            probs[c][i] = exps[c] / sum;
1469        }
1470    }
1471    probs
1472}
1473
1474/// Subsample indices using Fisher-Yates partial shuffle.
1475fn subsample_indices(
1476    n: usize,
1477    subsample: f64,
1478    rng: &mut crate::rng::FastRng,
1479    all_indices: &[usize],
1480) -> Vec<usize> {
1481    if subsample >= 1.0 {
1482        return all_indices.to_vec();
1483    }
1484    let k = ((n as f64) * subsample).ceil() as usize;
1485    let mut idx = all_indices.to_vec();
1486    for i in 0..k.min(n) {
1487        let j = rng.usize(i..n);
1488        idx.swap(i, j);
1489    }
1490    idx.truncate(k);
1491    idx
1492}
1493
1494// ═══════════════════════════════════════════════════════════════════════════
1495// Unit tests
1496// ═══════════════════════════════════════════════════════════════════════════
1497
1498#[cfg(test)]
1499#[allow(clippy::float_cmp)]
1500mod tests {
1501    use super::*;
1502
1503    fn make_linear_data(n: usize) -> Dataset {
1504        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
1505        let y: Vec<f64> = x.iter().map(|&v| 2.0 * v + 1.0).collect(); // y = 2x + 1
1506        Dataset::new(vec![x], y, vec!["x".into()], "y")
1507    }
1508
1509    fn make_binary_data() -> Dataset {
1510        // Two linearly separable clusters.
1511        let mut f1 = Vec::new();
1512        let mut f2 = Vec::new();
1513        let mut target = Vec::new();
1514        for i in 0..50 {
1515            let v = i as f64 / 50.0;
1516            f1.push(v);
1517            f2.push(v * 0.5);
1518            target.push(0.0);
1519        }
1520        for i in 0..50 {
1521            let v = 1.0 + i as f64 / 50.0;
1522            f1.push(v);
1523            f2.push(v * 0.5);
1524            target.push(1.0);
1525        }
1526        Dataset::new(vec![f1, f2], target, vec!["f1".into(), "f2".into()], "cls")
1527    }
1528
1529    fn make_multiclass_data() -> Dataset {
1530        let mut f1 = Vec::new();
1531        let mut f2 = Vec::new();
1532        let mut target = Vec::new();
1533        for i in 0..30 {
1534            f1.push(i as f64 / 30.0);
1535            f2.push(0.0);
1536            target.push(0.0);
1537        }
1538        for i in 0..30 {
1539            f1.push(2.0 + i as f64 / 30.0);
1540            f2.push(0.0);
1541            target.push(1.0);
1542        }
1543        for i in 0..30 {
1544            f1.push(4.0 + i as f64 / 30.0);
1545            f2.push(0.0);
1546            target.push(2.0);
1547        }
1548        Dataset::new(vec![f1, f2], target, vec!["f1".into(), "f2".into()], "cls")
1549    }
1550
1551    // ─── Regressor tests ───
1552
1553    #[test]
1554    fn regressor_learns_linear() {
1555        let data = make_linear_data(100);
1556        let mut gbr = GradientBoostingRegressor::new()
1557            .n_estimators(100)
1558            .learning_rate(0.1)
1559            .max_depth(3);
1560        gbr.fit(&data).unwrap();
1561
1562        let preds = gbr.predict(&[vec![50.0], vec![75.0]]).unwrap();
1563        // y = 2x + 1 → 101, 151
1564        assert!((preds[0] - 101.0).abs() < 10.0, "pred={}", preds[0]);
1565        assert!((preds[1] - 151.0).abs() < 15.0, "pred={}", preds[1]);
1566    }
1567
1568    #[test]
1569    fn regressor_not_fitted_error() {
1570        let gbr = GradientBoostingRegressor::new();
1571        assert!(gbr.predict(&[vec![1.0]]).is_err());
1572        assert!(gbr.feature_importances().is_err());
1573    }
1574
1575    #[test]
1576    fn regressor_subsample() {
1577        let data = make_linear_data(100);
1578        let mut gbr = GradientBoostingRegressor::new()
1579            .n_estimators(50)
1580            .subsample(0.7)
1581            .learning_rate(0.1)
1582            .max_depth(3);
1583        gbr.fit(&data).unwrap();
1584        let preds = gbr.predict(&[vec![25.0]]).unwrap();
1585        // Should still learn something reasonable.
1586        assert!((preds[0] - 51.0).abs() < 15.0, "pred={}", preds[0]);
1587    }
1588
1589    #[test]
1590    fn regressor_feature_importances() {
1591        let data = make_linear_data(100);
1592        let mut gbr = GradientBoostingRegressor::new()
1593            .n_estimators(20)
1594            .max_depth(2);
1595        gbr.fit(&data).unwrap();
1596        let imp = gbr.feature_importances().unwrap();
1597        assert_eq!(imp.len(), 1);
1598        assert!(
1599            (imp[0] - 1.0).abs() < 1e-6,
1600            "single feature should have importance 1.0"
1601        );
1602    }
1603
1604    #[test]
1605    fn regressor_invalid_params() {
1606        let data = make_linear_data(10);
1607        let mut gbr = GradientBoostingRegressor::new().learning_rate(0.0);
1608        assert!(gbr.fit(&data).is_err());
1609
1610        let mut gbr = GradientBoostingRegressor::new().subsample(1.5);
1611        assert!(gbr.fit(&data).is_err());
1612    }
1613
1614    #[test]
1615    fn regressor_early_stopping() {
1616        // Use noisy data where overfitting will occur with aggressive settings.
1617        let mut rng = crate::rng::FastRng::new(42);
1618        let n = 50;
1619        let x: Vec<f64> = (0..n).map(|_| rng.f64() * 10.0).collect();
1620        // y = sin(x) + heavy noise — tree will overfit noise.
1621        let y: Vec<f64> = x.iter().map(|&v| v.sin() + rng.f64() * 5.0).collect();
1622        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");
1623
1624        let mut gbr = GradientBoostingRegressor::new()
1625            .n_estimators(1000)
1626            .learning_rate(0.5)
1627            .max_depth(5)
1628            .n_iter_no_change(5)
1629            .validation_fraction(0.3)
1630            .tol(0.0);
1631        gbr.fit(&data).unwrap();
1632
1633        // With 1000 max estimators, heavy noise, and patience of 5,
1634        // early stopping should kick in well before 1000.
1635        assert!(
1636            gbr.n_trees() < 1000,
1637            "Expected early stopping, but used all {} estimators",
1638            gbr.n_trees()
1639        );
1640        assert!(gbr.early_stopped(), "early_stopped() should be true");
1641        assert!(gbr.n_estimators_used() < 1000);
1642    }
1643
1644    // ─── Classifier tests ───
1645
1646    #[test]
1647    fn classifier_binary() {
1648        let data = make_binary_data();
1649        let mut gbc = GradientBoostingClassifier::new()
1650            .n_estimators(50)
1651            .learning_rate(0.1)
1652            .max_depth(2);
1653        gbc.fit(&data).unwrap();
1654
1655        let test = vec![vec![0.2, 0.1], vec![1.5, 0.75]];
1656        let preds = gbc.predict(&test).unwrap();
1657        assert_eq!(preds[0], 0.0, "low values -> class 0");
1658        assert_eq!(preds[1], 1.0, "high values -> class 1");
1659    }
1660
1661    #[test]
1662    fn classifier_binary_proba() {
1663        let data = make_binary_data();
1664        let mut gbc = GradientBoostingClassifier::new()
1665            .n_estimators(50)
1666            .learning_rate(0.1)
1667            .max_depth(2);
1668        gbc.fit(&data).unwrap();
1669
1670        let probas = gbc.predict_proba(&[vec![0.2, 0.1]]).unwrap();
1671        assert_eq!(probas[0].len(), 2);
1672        let sum: f64 = probas[0].iter().sum();
1673        assert!((sum - 1.0).abs() < 1e-6, "probabilities should sum to 1");
1674        assert!(probas[0][0] > probas[0][1], "class 0 should be more likely");
1675    }
1676
1677    #[test]
1678    fn classifier_multiclass() {
1679        let data = make_multiclass_data();
1680        let mut gbc = GradientBoostingClassifier::new()
1681            .n_estimators(100)
1682            .learning_rate(0.1)
1683            .max_depth(3);
1684        gbc.fit(&data).unwrap();
1685
1686        let test = vec![vec![0.5, 0.0], vec![2.5, 0.0], vec![4.5, 0.0]];
1687        let preds = gbc.predict(&test).unwrap();
1688        assert_eq!(preds[0], 0.0, "should be class 0");
1689        assert_eq!(preds[1], 1.0, "should be class 1");
1690        assert_eq!(preds[2], 2.0, "should be class 2");
1691    }
1692
1693    #[test]
1694    fn classifier_multiclass_proba() {
1695        let data = make_multiclass_data();
1696        let mut gbc = GradientBoostingClassifier::new()
1697            .n_estimators(50)
1698            .learning_rate(0.1)
1699            .max_depth(2);
1700        gbc.fit(&data).unwrap();
1701
1702        let probas = gbc.predict_proba(&[vec![0.5, 0.0]]).unwrap();
1703        assert_eq!(probas[0].len(), 3);
1704        let sum: f64 = probas[0].iter().sum();
1705        assert!((sum - 1.0).abs() < 1e-6, "probabilities should sum to 1");
1706    }
1707
1708    #[test]
1709    fn classifier_subsample() {
1710        let data = make_binary_data();
1711        let mut gbc = GradientBoostingClassifier::new()
1712            .n_estimators(50)
1713            .subsample(0.8)
1714            .learning_rate(0.1)
1715            .max_depth(2);
1716        gbc.fit(&data).unwrap();
1717
1718        let test = vec![vec![0.2, 0.1], vec![1.5, 0.75]];
1719        let preds = gbc.predict(&test).unwrap();
1720        assert_eq!(preds[0], 0.0);
1721        assert_eq!(preds[1], 1.0);
1722    }
1723
1724    #[test]
1725    fn classifier_feature_importances() {
1726        let data = make_binary_data();
1727        let mut gbc = GradientBoostingClassifier::new()
1728            .n_estimators(20)
1729            .max_depth(2);
1730        gbc.fit(&data).unwrap();
1731        let imp = gbc.feature_importances().unwrap();
1732        assert_eq!(imp.len(), 2);
1733        let sum: f64 = imp.iter().sum();
1734        assert!((sum - 1.0).abs() < 1e-4, "importances should sum to 1");
1735    }
1736
1737    #[test]
1738    fn classifier_not_fitted_error() {
1739        let gbc = GradientBoostingClassifier::new();
1740        assert!(gbc.predict(&[vec![1.0, 2.0]]).is_err());
1741        assert!(gbc.predict_proba(&[vec![1.0, 2.0]]).is_err());
1742        assert!(gbc.feature_importances().is_err());
1743    }
1744
1745    #[test]
1746    fn classifier_n_trees_binary() {
1747        let data = make_binary_data();
1748        let mut gbc = GradientBoostingClassifier::new()
1749            .n_estimators(25)
1750            .max_depth(2);
1751        gbc.fit(&data).unwrap();
1752        assert_eq!(gbc.n_trees(), 25, "binary: one tree per round × 25 rounds");
1753    }
1754
1755    #[test]
1756    fn classifier_n_trees_multiclass() {
1757        let data = make_multiclass_data();
1758        let mut gbc = GradientBoostingClassifier::new()
1759            .n_estimators(10)
1760            .max_depth(2);
1761        gbc.fit(&data).unwrap();
1762        assert_eq!(gbc.n_trees(), 30, "multiclass: 3 classes × 10 rounds");
1763    }
1764
1765    // ─── Loss function tests ───
1766
1767    #[test]
1768    fn regressor_loss_squared_error_default() {
1769        // Explicitly selecting SquaredError should match the default behavior.
1770        let data = make_linear_data(100);
1771        let mut gbr = GradientBoostingRegressor::new()
1772            .n_estimators(100)
1773            .loss(RegressionLoss::SquaredError)
1774            .learning_rate(0.1)
1775            .max_depth(3);
1776        gbr.fit(&data).unwrap();
1777        let preds = gbr.predict(&[vec![50.0]]).unwrap();
1778        assert!(
1779            (preds[0] - 101.0).abs() < 10.0,
1780            "SquaredError pred={}",
1781            preds[0]
1782        );
1783    }
1784
1785    #[test]
1786    fn regressor_loss_absolute_error() {
1787        let data = make_linear_data(100);
1788        let mut gbr = GradientBoostingRegressor::new()
1789            .n_estimators(200)
1790            .loss(RegressionLoss::AbsoluteError)
1791            .learning_rate(0.1)
1792            .max_depth(3);
1793        gbr.fit(&data).unwrap();
1794        let preds = gbr.predict(&[vec![50.0]]).unwrap();
1795        // y = 2x + 1 → 101
1796        assert!(
1797            (preds[0] - 101.0).abs() < 20.0,
1798            "AbsoluteError pred={}",
1799            preds[0]
1800        );
1801    }
1802
1803    #[test]
1804    fn regressor_loss_huber() {
1805        let data = make_linear_data(100);
1806        let mut gbr = GradientBoostingRegressor::new()
1807            .n_estimators(200)
1808            .loss(RegressionLoss::Huber { alpha: 0.9 })
1809            .learning_rate(0.1)
1810            .max_depth(3);
1811        gbr.fit(&data).unwrap();
1812        let preds = gbr.predict(&[vec![50.0]]).unwrap();
1813        assert!((preds[0] - 101.0).abs() < 20.0, "Huber pred={}", preds[0]);
1814    }
1815
1816    #[test]
1817    fn regressor_loss_quantile_median() {
1818        let data = make_linear_data(100);
1819        let mut gbr = GradientBoostingRegressor::new()
1820            .n_estimators(200)
1821            .loss(RegressionLoss::Quantile { alpha: 0.5 })
1822            .learning_rate(0.1)
1823            .max_depth(3);
1824        gbr.fit(&data).unwrap();
1825        let preds = gbr.predict(&[vec![50.0]]).unwrap();
1826        assert!(
1827            (preds[0] - 101.0).abs() < 25.0,
1828            "Quantile(0.5) pred={}",
1829            preds[0]
1830        );
1831    }
1832
1833    #[test]
1834    fn test_median_helper() {
1835        assert!((median(&[1.0, 3.0, 5.0]) - 3.0).abs() < 1e-12);
1836        assert!((median(&[1.0, 3.0, 5.0, 7.0]) - 4.0).abs() < 1e-12);
1837        assert!((median(&[42.0]) - 42.0).abs() < 1e-12);
1838        assert!((median(&[]) - 0.0).abs() < 1e-12);
1839    }
1840
1841    #[test]
1842    fn test_quantile_helper() {
1843        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1844        assert!((quantile(&data, 0.5) - 3.0).abs() < 1e-12);
1845        assert!((quantile(&data, 0.0) - 1.0).abs() < 1e-12);
1846        assert!((quantile(&data, 1.0) - 5.0).abs() < 1e-12);
1847        assert!((quantile(&data, 0.25) - 2.0).abs() < 1e-12);
1848    }
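
    // ─── Helper-function tests ───
    //
    // Illustrative sketches of invariants for the private module helpers.
    // They assume `sigmoid`, `softmax_matrix`, and `subsample_indices` keep
    // the signatures defined above and that `FastRng::new` seeds the
    // generator deterministically.

    #[test]
    fn test_sigmoid_helper() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
        // sigmoid(x) + sigmoid(-x) == 1 for any x.
        assert!((sigmoid(3.0) + sigmoid(-3.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_softmax_matrix_columns_sum_to_one() {
        // Three classes, two samples; logits laid out as f_vals[class][sample].
        let f_vals = vec![vec![1.0, -2.0], vec![0.5, 0.0], vec![-1.0, 3.0]];
        let probs = softmax_matrix(&f_vals, 2, 3);
        for i in 0..2 {
            let sum: f64 = (0..3).map(|c| probs[c][i]).sum();
            assert!((sum - 1.0).abs() < 1e-12, "sample {} should sum to 1", i);
        }
        // The largest logit per sample should receive the largest probability.
        assert!(probs[0][0] > probs[1][0] && probs[0][0] > probs[2][0]);
        assert!(probs[2][1] > probs[0][1] && probs[2][1] > probs[1][1]);
    }

    #[test]
    fn test_subsample_indices_size_and_bounds() {
        let mut rng = crate::rng::FastRng::new(7);
        let all: Vec<usize> = (0..10).collect();

        let idx = subsample_indices(10, 0.5, &mut rng, &all);
        assert_eq!(idx.len(), 5, "ceil(10 * 0.5) = 5 indices expected");
        assert!(idx.iter().all(|&i| i < 10));

        // subsample >= 1.0 returns every index unchanged.
        let full = subsample_indices(10, 1.0, &mut rng, &all);
        assert_eq!(full, all);
    }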
1849
1850    // ─── Training history tests ───
1851
1852    #[test]
1853    fn regressor_history_populated() {
1854        let data = make_linear_data(50);
1855        let mut gbr = GradientBoostingRegressor::new()
1856            .n_estimators(10)
1857            .learning_rate(0.1)
1858            .max_depth(3);
1859        gbr.fit(&data).unwrap();
1860
1861        let history = gbr.history().expect("history should be populated");
1862        assert_eq!(history.len(), 10);
1863        // Loss should decrease over rounds.
1864        assert!(history.epochs[0].train_loss > history.epochs[9].train_loss);
1865        // Grad norms should be positive.
1866        assert!(history.epochs[0].grad_norm > 0.0);
1867    }
1868
1869    #[test]
1870    fn classifier_binary_history_populated() {
1871        let data = make_binary_data();
1872        let mut gbc = GradientBoostingClassifier::new()
1873            .n_estimators(10)
1874            .learning_rate(0.1)
1875            .max_depth(2);
1876        gbc.fit(&data).unwrap();
1877
1878        let history = gbc.history().expect("history should be populated");
1879        assert_eq!(history.len(), 10);
1880        assert!(history.epochs[0].train_loss > 0.0);
1881    }
1882
1883    #[test]
1884    fn classifier_multiclass_history_populated() {
1885        let data = make_multiclass_data();
1886        let mut gbc = GradientBoostingClassifier::new()
1887            .n_estimators(10)
1888            .learning_rate(0.1)
1889            .max_depth(2);
1890        gbc.fit(&data).unwrap();
1891
1892        let history = gbc.history().expect("history should be populated");
1893        assert_eq!(history.len(), 10);
1894        assert!(history.epochs[0].train_loss > 0.0);
1895    }
1896
1897    #[test]
1898    fn regressor_early_stopping_history() {
1899        let mut rng = crate::rng::FastRng::new(42);
1900        let n = 50;
1901        let x: Vec<f64> = (0..n).map(|_| rng.f64() * 10.0).collect();
1902        let y: Vec<f64> = x.iter().map(|&v| v.sin() + rng.f64() * 5.0).collect();
1903        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");
1904
1905        let mut gbr = GradientBoostingRegressor::new()
1906            .n_estimators(1000)
1907            .learning_rate(0.5)
1908            .max_depth(5)
1909            .n_iter_no_change(5)
1910            .validation_fraction(0.3)
1911            .tol(0.0);
1912        gbr.fit(&data).unwrap();
1913
1914        let history = gbr.history().expect("history should be populated");
1915        // History length should match n_estimators_used.
1916        assert_eq!(history.len(), gbr.n_estimators_used());
1917        // Some epochs should have val_loss.
1918        assert!(history.epochs.last().unwrap().val_loss.is_some());
1919    }
1920    }
1921}