// SPDX-License-Identifier: MIT OR Apache-2.0
//! Histogram-based Gradient Boosted Trees — O(n) split finding.
//!
//! This module implements the core innovation behind XGBoost, LightGBM, and
//! CatBoost: instead of sorting features at each split, the data is pre-binned
//! into 256 `u8` bins and gradients are accumulated into fixed-size histograms.
//! Split finding becomes O(256) per feature per leaf, regardless of dataset size.
//!
//! ## Key Optimizations
//!
//! - **Histogram subtraction trick**: parent − left = right (halves histogram
//!   construction cost; see the worked note after this list).
//! - **Leaf-wise (best-first) growth**: grows the leaf with the highest gain,
//!   matching LightGBM's strategy for deeper, more accurate trees.
//! - **SIMD-friendly layout**: histograms are contiguous `[HistBin; 256]` arrays,
//!   enabling auto-vectorization.
//! - **Rayon parallelism** for histogram construction across features.
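//!
//! As a worked instance of the subtraction trick: after a leaf is split, only
//! the smaller child's histograms are accumulated by scanning its samples; the
//! larger child's follow bin-by-bin as `right[b] = parent[b] − left[b]`, so
//! one of the two O(n_leaf) scans is replaced by an O(256 · n_features)
//! subtraction.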
//!
//! # Example
//! ```
//! use scry_learn::dataset::Dataset;
//! use scry_learn::tree::HistGradientBoostingRegressor;
//!
//! let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
//! let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
//! let data = Dataset::new(features, target, vec!["x".into()], "y");
//!
//! let mut model = HistGradientBoostingRegressor::new()
//!     .n_estimators(100)
//!     .learning_rate(0.1)
//!     .max_leaf_nodes(31);
//! model.fit(&data).unwrap();
//!
//! let preds = model.predict(&[vec![3.0]]).unwrap();
//! assert!((preds[0] - 6.0).abs() < 1.0);
//! ```

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::tree::binning::FeatureBinner;

use rayon::prelude::*;

// ═══════════════════════════════════════════════════════════════════════════
// Histogram data structures
// ═══════════════════════════════════════════════════════════════════════════

/// Number of histogram bins.
const NUM_BINS: usize = 256;

/// A single histogram bin accumulating gradient statistics.
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct HistBin {
    grad_sum: f64,
    hess_sum: f64,
    count: u32,
}

/// Histogram for one feature: 256 bins of gradient/hessian sums.
///
/// Contiguous layout for SIMD-friendly access during split search.
type FeatureHistogram = [HistBin; NUM_BINS];

/// Build histograms for all features from the binned data.
///
/// When the `scry-gpu` feature is enabled and a GPU backend is available, this
/// delegates to [`ComputeBackend::build_histograms`] for acceleration.
/// Otherwise it falls back to the Rayon-parallel CPU path.
fn build_histograms(
    binned: &[Vec<u8>], // [feature][sample]
    gradients: &[f64],
    hessians: &[f64],
    sample_indices: &[usize],
    n_features: usize,
) -> Vec<FeatureHistogram> {
    // Try GPU-accelerated path when the feature is enabled.
    #[cfg(feature = "scry-gpu")]
    {
        if let Ok(gpu) = crate::accel::ScryGpuBackend::new() {
            use crate::accel::ComputeBackend;
            let accel_hists = gpu.build_histograms(
                binned,
                gradients,
                hessians,
                sample_indices,
                n_features,
                NUM_BINS,
            );
            return accel_hists
                .into_iter()
                .map(|feat_bins| {
                    let mut hist: FeatureHistogram = [HistBin::default(); NUM_BINS];
                    for (b, &(g, h, c)) in feat_bins.iter().enumerate() {
                        if b < NUM_BINS {
                            hist[b].grad_sum = g;
                            hist[b].hess_sum = h;
                            hist[b].count = c as u32;
                        }
                    }
                    hist
                })
                .collect();
        }
    }

    // CPU fallback: parallel across features via Rayon.
    (0..n_features)
        .into_par_iter()
        .map(|f| {
            let col = &binned[f];
            let mut hist: FeatureHistogram = [HistBin::default(); NUM_BINS];
            for &idx in sample_indices {
                let bin = col[idx] as usize;
                hist[bin].grad_sum += gradients[idx];
                hist[bin].hess_sum += hessians[idx];
                hist[bin].count += 1;
            }
            hist
        })
        .collect()
}

/// Histogram subtraction: parent − left = right.
fn subtract_histograms(
    parent: &[FeatureHistogram],
    left: &[FeatureHistogram],
) -> Vec<FeatureHistogram> {
    parent
        .iter()
        .zip(left.iter())
        .map(|(p, l)| {
            let mut right = [HistBin::default(); NUM_BINS];
            for b in 0..NUM_BINS {
                right[b].grad_sum = p[b].grad_sum - l[b].grad_sum;
                right[b].hess_sum = p[b].hess_sum - l[b].hess_sum;
                right[b].count = p[b].count.saturating_sub(l[b].count);
            }
            right
        })
        .collect()
}
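
// A minimal sketch verifying the subtraction identity on hand-built
// histograms; illustrative only, kept next to the code it exercises
// (this test module is an addition, not part of the original suite).
#[cfg(test)]
mod histogram_subtraction_sketch {
    use super::*;

    #[test]
    fn parent_minus_left_recovers_right() {
        let mut parent = [HistBin::default(); NUM_BINS];
        let mut left = [HistBin::default(); NUM_BINS];
        parent[3] = HistBin { grad_sum: 5.0, hess_sum: 4.0, count: 4 };
        left[3] = HistBin { grad_sum: 2.0, hess_sum: 1.5, count: 2 };

        let right = subtract_histograms(&[parent], &[left]);
        assert_eq!(right[0][3].grad_sum, 3.0);
        assert_eq!(right[0][3].hess_sum, 2.5);
        assert_eq!(right[0][3].count, 2);
    }
}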

// ═══════════════════════════════════════════════════════════════════════════
// Internal tree representation
// ═══════════════════════════════════════════════════════════════════════════

/// A node in the histogram-based tree.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum HistNode {
    /// Leaf node with a prediction value.
    Leaf { value: f64 },
    /// Internal split node.
    Split {
        feature: usize,
        bin_threshold: u8,
        left: usize, // index into HistTree::nodes
        right: usize,
        gain: f64,
    },
}

/// Public view of a HistNode for ONNX export, with bin thresholds converted
/// to raw feature value thresholds.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum HistNodeView {
    /// Leaf node.
    Leaf {
        /// Prediction value.
        value: f64,
    },
    /// Internal split node.
    Split {
        /// Feature index.
        feature: usize,
        /// Raw feature threshold (≤ goes left).
        threshold: f64,
        /// Left child index.
        left: usize,
        /// Right child index.
        right: usize,
    },
}

/// A tree built from histogram-based splits.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct HistTree {
    nodes: Vec<HistNode>,
}

impl HistTree {
    /// Predict for a single sample.
    fn predict_one(&self, sample_binned: &[u8]) -> f64 {
        let mut node_idx = 0;
        loop {
            match &self.nodes[node_idx] {
                HistNode::Leaf { value } => return *value,
                HistNode::Split {
                    feature,
                    bin_threshold,
                    left,
                    right,
                    ..
                } => {
                    if sample_binned[*feature] <= *bin_threshold {
                        node_idx = *left;
                    } else {
                        node_idx = *right;
                    }
                }
            }
        }
    }

    /// Predict for a single raw (unbinned) sample using bin edges.
    fn predict_one_raw(&self, sample: &[f64], binner: &FeatureBinner) -> f64 {
        let mut node_idx = 0;
        loop {
            match &self.nodes[node_idx] {
                HistNode::Leaf { value } => return *value,
                HistNode::Split {
                    feature,
                    bin_threshold,
                    left,
                    right,
                    ..
                } => {
                    let val = sample[*feature];
                    let bin = if val.is_nan() {
                        // Bin 0 is reserved for NaN / missing values.
                        0u8
                    } else {
                        // Value bins start at 1: binary-search the edges for
                        // the insertion point, then shift by one.
                        let edges = &binner.bin_edges()[*feature];
                        let pos = match edges.binary_search_by(|edge| {
                            edge.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
                        }) {
                            Ok(p) => p + 1,
                            Err(p) => p,
                        };
                        (pos + 1).min(255) as u8
                    };
                    if bin <= *bin_threshold {
                        node_idx = *left;
                    } else {
                        node_idx = *right;
                    }
                }
            }
        }
    }

    /// Collect feature importance (total gain) from this tree.
    fn feature_importances(&self, n_features: usize) -> Vec<f64> {
        let mut imp = vec![0.0; n_features];
        for node in &self.nodes {
            if let HistNode::Split { feature, gain, .. } = node {
                if *feature < n_features {
                    imp[*feature] += gain;
                }
            }
        }
        imp
    }

    /// Convert to public HistNodeView representation, translating bin
    /// thresholds to raw feature value thresholds using the binner.
    fn to_node_views(&self, binner: &FeatureBinner) -> Vec<HistNodeView> {
        let edges = binner.bin_edges();
        self.nodes
            .iter()
            .map(|node| match node {
                HistNode::Leaf { value } => HistNodeView::Leaf { value: *value },
                HistNode::Split {
                    feature,
                    bin_threshold,
                    left,
                    right,
                    ..
                } => {
                    // Convert bin threshold to raw value threshold.
                    // bin k corresponds to values in [edges[k-2], edges[k-1]).
                    // bin <= bin_threshold means val < edges[bin_threshold - 1].
                    // For ONNX BRANCH_LEQ (val <= T), use edges[bin_threshold - 1].
                    let threshold = if *bin_threshold == 0 || *feature >= edges.len() {
                        f64::NEG_INFINITY
                    } else {
                        let feat_edges = &edges[*feature];
                        let idx = (*bin_threshold as usize).saturating_sub(1);
                        if idx < feat_edges.len() {
                            feat_edges[idx]
                        } else if !feat_edges.is_empty() {
                            feat_edges[feat_edges.len() - 1]
                        } else {
                            0.0
                        }
                    };
                    HistNodeView::Split {
                        feature: *feature,
                        threshold,
                        left: *left,
                        right: *right,
                    }
                }
            })
            .collect()
    }
}

/// Candidate leaf for best-first (leaf-wise) growth.
struct LeafCandidate {
    /// Index into the tree's nodes Vec (this is a Leaf node).
    node_idx: usize,
    /// Sample indices falling into this leaf.
    sample_indices: Vec<usize>,
    /// Pre-computed histograms for this leaf.
    histograms: Vec<FeatureHistogram>,
    /// Total gradient sum in this leaf.
    grad_sum: f64,
    /// Total hessian sum in this leaf.
    hess_sum: f64,
    /// Depth of this leaf.
    depth: usize,
}

/// Result of scanning one leaf for the best split.
struct SplitResult {
    feature: usize,
    bin_threshold: u8,
    gain: f64,
    left_indices: Vec<usize>,
    right_indices: Vec<usize>,
    left_value: f64,
    right_value: f64,
    left_grad_sum: f64,
    left_hess_sum: f64,
    right_grad_sum: f64,
    right_hess_sum: f64,
}

/// L2-regularized leaf value: −G / (H + λ).
///
/// Guards against near-zero denominator which would produce extreme
/// leaf values that destabilize the boosting ensemble.
#[inline]
fn leaf_value(grad_sum: f64, hess_sum: f64, l2_reg: f64) -> f64 {
    let denom = hess_sum + l2_reg;
    if denom.abs() < 1e-10 {
        0.0
    } else {
        -grad_sum / denom
    }
}

/// Split gain: ½·[G_L²/(H_L+λ) + G_R²/(H_R+λ) − G²/(H+λ)].
#[inline]
fn split_gain(
    grad_left: f64,
    hess_left: f64,
    grad_right: f64,
    hess_right: f64,
    l2_reg: f64,
) -> f64 {
    let left_term = grad_left * grad_left / (hess_left + l2_reg);
    let right_term = grad_right * grad_right / (hess_right + l2_reg);
    let parent_grad = grad_left + grad_right;
    let parent_hess = hess_left + hess_right;
    let parent_term = parent_grad * parent_grad / (parent_hess + l2_reg);
    0.5 * (left_term + right_term - parent_term)
}
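
// Hand-worked sanity checks for the two formulas above; a small sketch with
// numbers chosen so the expected results are exact in f64 (this test module
// is an addition, not part of the original suite).
#[cfg(test)]
mod gain_formula_sketch {
    use super::*;

    #[test]
    fn leaf_value_is_a_newton_step() {
        // −G / (H + λ) = −(−4) / (2 + 0) = 2.
        assert_eq!(leaf_value(-4.0, 2.0, 0.0), 2.0);
        // A near-zero denominator is guarded to 0.0 instead of exploding.
        assert_eq!(leaf_value(1.0, 0.0, 0.0), 0.0);
    }

    #[test]
    fn split_gain_hand_computed() {
        // ½·(G_L²/H_L + G_R²/H_R − G²/H) = ½·(4/2 + 4/2 − 0/4) = 2 with λ = 0.
        assert_eq!(split_gain(-2.0, 2.0, 2.0, 2.0, 0.0), 2.0);
    }
}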

/// Find the best split across all features for a given leaf.
#[allow(clippy::too_many_arguments)]
fn find_best_split(
    histograms: &[FeatureHistogram],
    binned: &[Vec<u8>],
    sample_indices: &[usize],
    grad_sum: f64,
    hess_sum: f64,
    min_samples_leaf: usize,
    l2_reg: f64,
    n_features: usize,
) -> Option<SplitResult> {
    let mut best_gain = 0.0; // only accept positive gains
    let mut best_feature = 0;
    let mut best_threshold: u8 = 0;
    let mut best_left_grad = 0.0;
    let mut best_left_hess = 0.0;

    for (f, hist) in histograms.iter().enumerate().take(n_features) {
        let mut running_grad = 0.0;
        let mut running_hess = 0.0;
        let mut running_count: u32 = 0;
        let total_count = sample_indices.len() as u32;

        // Scan bins left-to-right (including bin 0 for NaN).
        for bin in 0..255u8 {
            let b = bin as usize;
            running_grad += hist[b].grad_sum;
            running_hess += hist[b].hess_sum;
            running_count += hist[b].count;

            let right_count = total_count.saturating_sub(running_count);
            if (running_count as usize) < min_samples_leaf
                || (right_count as usize) < min_samples_leaf
            {
                continue;
            }

            let right_grad = grad_sum - running_grad;
            let right_hess = hess_sum - running_hess;

            let gain = split_gain(running_grad, running_hess, right_grad, right_hess, l2_reg);

            if gain > best_gain {
                best_gain = gain;
                best_feature = f;
                best_threshold = bin;
                best_left_grad = running_grad;
                best_left_hess = running_hess;
            }
        }
    }

    if best_gain <= 0.0 {
        return None;
    }

    // Split sample indices.
    let col = &binned[best_feature];
    let mut left_indices = Vec::new();
    let mut right_indices = Vec::new();
    for &idx in sample_indices {
        if col[idx] <= best_threshold {
            left_indices.push(idx);
        } else {
            right_indices.push(idx);
        }
    }

    let best_right_grad = grad_sum - best_left_grad;
    let best_right_hess = hess_sum - best_left_hess;

    // Child histograms are not built here: the caller builds the smaller
    // child's histogram and derives the larger one via the subtraction trick.

    Some(SplitResult {
        feature: best_feature,
        bin_threshold: best_threshold,
        gain: best_gain,
        left_indices,
        right_indices,
        left_value: leaf_value(best_left_grad, best_left_hess, l2_reg),
        right_value: leaf_value(best_right_grad, best_right_hess, l2_reg),
        left_grad_sum: best_left_grad,
        left_hess_sum: best_left_hess,
        right_grad_sum: best_right_grad,
        right_hess_sum: best_right_hess,
    })
}
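
// A tiny end-to-end sketch of split finding: one feature, two occupied bins,
// perfectly separated gradients, so the best split must land between them.
// (Toy data invented for illustration; the function calls are this file's own.)
#[cfg(test)]
mod split_search_sketch {
    use super::*;

    #[test]
    fn finds_the_obvious_boundary() {
        // One feature, four samples binned as [1, 1, 2, 2].
        let binned = vec![vec![1u8, 1, 2, 2]];
        let gradients = [-1.0, -1.0, 1.0, 1.0];
        let hessians = [1.0; 4];
        let indices = [0usize, 1, 2, 3];
        let hists = build_histograms(&binned, &gradients, &hessians, &indices, 1);

        let split = find_best_split(&hists, &binned, &indices, 0.0, 4.0, 1, 0.0, 1)
            .expect("a positive-gain split exists");
        assert_eq!(split.feature, 0);
        assert_eq!(split.bin_threshold, 1); // bin ≤ 1 goes left
        assert_eq!(split.left_indices, vec![0, 1]);
        assert_eq!(split.right_indices, vec![2, 3]);
    }
}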

/// Build a single tree using leaf-wise (best-first) growth.
#[allow(clippy::too_many_arguments)]
fn build_tree_leaf_wise(
    binned: &[Vec<u8>],
    gradients: &[f64],
    hessians: &[f64],
    sample_indices: &[usize],
    max_leaf_nodes: usize,
    min_samples_leaf: usize,
    max_depth: usize,
    l2_reg: f64,
    n_features: usize,
) -> HistTree {
    let mut nodes: Vec<HistNode> = Vec::new();

    // Compute initial sums.
    let total_grad: f64 = sample_indices.iter().map(|&i| gradients[i]).sum();
    let total_hess: f64 = sample_indices.iter().map(|&i| hessians[i]).sum();

    let root_value = leaf_value(total_grad, total_hess, l2_reg);
    nodes.push(HistNode::Leaf { value: root_value });

    // Build root histograms.
    let root_histograms = build_histograms(binned, gradients, hessians, sample_indices, n_features);

    // Pool of splittable leaves; the best-gain candidate is selected by a
    // linear scan each round.
    let mut candidates: Vec<LeafCandidate> = Vec::new();
    candidates.push(LeafCandidate {
        node_idx: 0,
        sample_indices: sample_indices.to_vec(),
        histograms: root_histograms,
        grad_sum: total_grad,
        hess_sum: total_hess,
        depth: 0,
    });

    let mut n_leaves = 1usize;

    while n_leaves < max_leaf_nodes && !candidates.is_empty() {
        // Find the candidate with the best split gain.
        let mut best_cand_idx = 0;
        let mut best_gain = f64::NEG_INFINITY;

        for (c_idx, cand) in candidates.iter().enumerate() {
            if cand.depth >= max_depth {
                continue;
            }
            if cand.sample_indices.len() < 2 * min_samples_leaf {
                continue;
            }
            // Find best split for this candidate.
            if let Some(split) = find_best_split(
                &cand.histograms,
                binned,
                &cand.sample_indices,
                cand.grad_sum,
                cand.hess_sum,
                min_samples_leaf,
                l2_reg,
                n_features,
            ) {
                if split.gain > best_gain {
                    best_gain = split.gain;
                    best_cand_idx = c_idx;
                }
            }
        }

        if best_gain <= 0.0 {
            break;
        }

        let cand = candidates.remove(best_cand_idx);

        // Re-compute best split for the chosen candidate (we need the full result).
        let split = find_best_split(
            &cand.histograms,
            binned,
            &cand.sample_indices,
            cand.grad_sum,
            cand.hess_sum,
            min_samples_leaf,
            l2_reg,
            n_features,
        );

        let Some(split) = split else {
            continue;
        };

        // Create child leaf nodes.
        let left_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: split.left_value,
        });
        let right_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: split.right_value,
        });

        // Convert parent leaf → split node.
        nodes[cand.node_idx] = HistNode::Split {
            feature: split.feature,
            bin_threshold: split.bin_threshold,
            left: left_idx,
            right: right_idx,
            gain: split.gain,
        };

        n_leaves += 1; // one leaf became two (net +1)

        // Build histograms for children using subtraction trick:
        // smaller child gets full histogram build, larger gets subtraction.
        let (small_indices, _large_indices, small_is_left) =
            if split.left_indices.len() <= split.right_indices.len() {
                (&split.left_indices, &split.right_indices, true)
            } else {
                (&split.right_indices, &split.left_indices, false)
            };

        let small_histograms =
            build_histograms(binned, gradients, hessians, small_indices, n_features);
        let large_histograms = subtract_histograms(&cand.histograms, &small_histograms);

        let (left_hist, right_hist) = if small_is_left {
            (small_histograms, large_histograms)
        } else {
            (large_histograms, small_histograms)
        };

        let new_depth = cand.depth + 1;

        // Add children as new candidates.
        if split.left_indices.len() >= 2 * min_samples_leaf && new_depth < max_depth {
            candidates.push(LeafCandidate {
                node_idx: left_idx,
                sample_indices: split.left_indices,
                histograms: left_hist,
                grad_sum: split.left_grad_sum,
                hess_sum: split.left_hess_sum,
                depth: new_depth,
            });
        }

        if split.right_indices.len() >= 2 * min_samples_leaf && new_depth < max_depth {
            candidates.push(LeafCandidate {
                node_idx: right_idx,
                sample_indices: split.right_indices,
                histograms: right_hist,
                grad_sum: split.right_grad_sum,
                hess_sum: split.right_hess_sum,
                depth: new_depth,
            });
        }
    }

    HistTree { nodes }
}
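
// A leaf-wise growth sketch on the same toy data: with max_leaf_nodes = 2 the
// builder makes exactly one split, and the leaves carry the Newton-step
// values −G/H computed from each side's gradient sums.
#[cfg(test)]
mod tree_growth_sketch {
    use super::*;

    #[test]
    fn single_split_tree_predicts_leaf_values() {
        let binned = vec![vec![1u8, 1, 2, 2]];
        let gradients = [-1.0, -1.0, 1.0, 1.0];
        let hessians = [1.0; 4];
        let indices = [0usize, 1, 2, 3];

        let tree =
            build_tree_leaf_wise(&binned, &gradients, &hessians, &indices, 2, 1, 4, 0.0, 1);

        // Left leaf: −(−2)/2 = 1.0; right leaf: −(2)/2 = −1.0.
        assert_eq!(tree.predict_one(&[1u8]), 1.0);
        assert_eq!(tree.predict_one(&[2u8]), -1.0);
    }
}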

// ═══════════════════════════════════════════════════════════════════════════
// Histogram Gradient Boosting Regressor
// ═══════════════════════════════════════════════════════════════════════════

/// Histogram-based Gradient Boosting for regression.
///
/// Uses pre-binned features and O(256) histogram scans for split finding,
/// delivering 5-10× speedup over standard GBT on large datasets. This is
/// the same algorithmic approach as LightGBM/XGBoost/CatBoost, implemented
/// in pure Rust with no external BLAS dependency.
///
/// # Example
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::tree::HistGradientBoostingRegressor;
///
/// let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
/// let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
/// let data = Dataset::new(features, target, vec!["x".into()], "y");
///
/// let mut model = HistGradientBoostingRegressor::new()
///     .n_estimators(100)
///     .learning_rate(0.1)
///     .max_leaf_nodes(31);
/// model.fit(&data).unwrap();
///
/// let preds = model.predict(&[vec![3.0]]).unwrap();
/// assert!((preds[0] - 6.0).abs() < 1.0);
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct HistGradientBoostingRegressor {
    n_estimators: usize,
    learning_rate: f64,
    max_leaf_nodes: usize,
    min_samples_leaf: usize,
    max_depth: usize,
    max_bins: usize,
    l2_regularization: f64,
    seed: u64,
    // Fitted state
    trees: Vec<HistTree>,
    binner: FeatureBinner,
    init_prediction: f64,
    n_features: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl HistGradientBoostingRegressor {
    /// Create a new regressor with default parameters.
    ///
    /// # Example
    /// ```
    /// use scry_learn::tree::HistGradientBoostingRegressor;
    ///
    /// let model = HistGradientBoostingRegressor::new()
    ///     .n_estimators(200)
    ///     .learning_rate(0.05);
    /// ```
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_leaf_nodes: 31,
            min_samples_leaf: 20,
            max_depth: 8,
            max_bins: NUM_BINS,
            l2_regularization: 0.0,
            seed: 42,
            trees: Vec::new(),
            binner: FeatureBinner::new(),
            init_prediction: 0.0,
            n_features: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set number of boosting rounds (default: 100).
    pub fn n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    /// Set learning rate / shrinkage (default: 0.1).
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum number of leaf nodes per tree (default: 31).
    ///
    /// This controls tree complexity. LightGBM default is 31.
    pub fn max_leaf_nodes(mut self, n: usize) -> Self {
        self.max_leaf_nodes = n;
        self
    }

    /// Set minimum samples required in a leaf (default: 20).
    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Set maximum tree depth (default: 8). Acts as a secondary limit on top
    /// of `max_leaf_nodes`.
    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    /// Set maximum number of bins (2..=256, default: 256).
    pub fn max_bins(mut self, bins: usize) -> Self {
        self.max_bins = bins.clamp(2, NUM_BINS);
        self
    }

    /// Set L2 regularization (default: 0.0).
    pub fn l2_regularization(mut self, l2: f64) -> Self {
        self.l2_regularization = l2;
        self
    }

    /// Set random seed (default: 42).
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Train the histogram-based gradient boosting regressor.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_no_inf()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "learning_rate must be in (0, 1]".into(),
            ));
        }

        self.n_features = data.n_features();

        // Bin features.
        self.binner = FeatureBinner::new().max_bins(self.max_bins);
        let binned = self.binner.fit_transform(data)?;

        // Initial prediction: mean of targets.
        let mean: f64 = data.target.iter().sum::<f64>() / n as f64;
        self.init_prediction = mean;

        let mut predictions = vec![mean; n];
        let all_indices: Vec<usize> = (0..n).collect();

        self.trees = Vec::with_capacity(self.n_estimators);

        // Adjust min_samples_leaf for small datasets.
        let effective_min_leaf = self.min_samples_leaf.min(n / 4).max(1);

        for _ in 0..self.n_estimators {
            // Compute gradients (negative residuals) and hessians.
            let gradients: Vec<f64> = (0..n).map(|i| -(data.target[i] - predictions[i])).collect();
            let hessians = vec![1.0; n]; // squared error: hessian = 1

            let tree = build_tree_leaf_wise(
                &binned,
                &gradients,
                &hessians,
                &all_indices,
                self.max_leaf_nodes,
                effective_min_leaf,
                self.max_depth,
                self.l2_regularization,
                self.n_features,
            );

            // Update predictions.
            for &i in &all_indices {
                let sample: Vec<u8> = binned.iter().map(|col| col[i]).collect();
                predictions[i] += self.learning_rate * tree.predict_one(&sample);
            }

            self.trees.push(tree);
        }

        self.fitted = true;
        Ok(())
    }

    /// Predict values for new samples.
    ///
    /// `features` is row-major: `features[sample_idx][feature_idx]`.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = features.len();
        let mut preds = vec![self.init_prediction; n];

        for tree in &self.trees {
            for (i, sample) in features.iter().enumerate() {
                preds[i] += self.learning_rate * tree.predict_one_raw(sample, &self.binner);
            }
        }

        Ok(preds)
    }

    /// Feature importances (total gain, normalized).
    pub fn feature_importances(&self) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.n_features;
        let mut imp = vec![0.0; m];
        for tree in &self.trees {
            let ti = tree.feature_importances(m);
            for (i, &v) in ti.iter().enumerate() {
                imp[i] += v;
            }
        }
        let total: f64 = imp.iter().sum();
        if total > 0.0 {
            for v in &mut imp {
                *v /= total;
            }
        }
        Ok(imp)
    }

    /// Number of trees in the ensemble.
    pub fn n_trees(&self) -> usize {
        self.trees.len()
    }

    /// Number of features the model was trained on.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Learning rate value.
    pub fn learning_rate_val(&self) -> f64 {
        self.learning_rate
    }

    /// Initial (base) prediction value.
    pub fn init_prediction_val(&self) -> f64 {
        self.init_prediction
    }

    /// Convert internal HistTree nodes to public HistNodeView arrays for ONNX export.
    /// Bin thresholds are converted to raw feature thresholds using the binner.
    pub fn tree_node_views(&self) -> Vec<Vec<HistNodeView>> {
        self.trees
            .iter()
            .map(|tree| tree.to_node_views(&self.binner))
            .collect()
    }
}
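
// A structural sketch of the export path: node views from a fitted model
// should form a well-formed tree (child indices in range). The thresholds
// themselves depend on the fitted bin edges, so only the shape is checked.
#[cfg(test)]
mod node_view_sketch {
    use super::*;

    #[test]
    fn views_are_well_formed() {
        let x: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|&v| 3.0 * v).collect();
        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");

        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(2)
            .min_samples_leaf(1);
        model.fit(&data).unwrap();

        for nodes in model.tree_node_views() {
            assert!(!nodes.is_empty());
            for node in &nodes {
                if let HistNodeView::Split { left, right, .. } = node {
                    assert!(*left < nodes.len() && *right < nodes.len());
                }
            }
        }
    }
}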

impl Default for HistGradientBoostingRegressor {
    fn default() -> Self {
        Self::new()
    }
}

// ═══════════════════════════════════════════════════════════════════════════
// Histogram Gradient Boosting Classifier
// ═══════════════════════════════════════════════════════════════════════════

/// Histogram-based Gradient Boosting for classification (binary + multiclass).
///
/// Uses the same O(256) histogram approach as the regressor, with log-loss
/// for binary classification and softmax for multiclass. Leaf-wise tree growth
/// with Newton-Raphson leaf correction.
///
/// # Example
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::tree::HistGradientBoostingClassifier;
///
/// let features = vec![
///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
///     vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
/// ];
/// let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
/// let data = Dataset::new(features, target, vec!["x1".into(), "x2".into()], "class");
///
/// let mut model = HistGradientBoostingClassifier::new()
///     .n_estimators(50)
///     .learning_rate(0.1)
///     .max_leaf_nodes(31);
/// model.fit(&data).unwrap();
///
/// let preds = model.predict(&[vec![1.5, 0.15], vec![5.5, 0.55]]).unwrap();
/// assert_eq!(preds[0], 0.0);
/// assert_eq!(preds[1], 1.0);
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct HistGradientBoostingClassifier {
    n_estimators: usize,
    learning_rate: f64,
    max_leaf_nodes: usize,
    min_samples_leaf: usize,
    max_depth: usize,
    max_bins: usize,
    l2_regularization: f64,
    seed: u64,
    // Fitted state — trees[class_idx][estimator_idx]
    trees: Vec<Vec<HistTree>>,
    binner: FeatureBinner,
    init_predictions: Vec<f64>,
    n_classes: usize,
    n_features: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl HistGradientBoostingClassifier {
    /// Create a new classifier with default parameters.
    ///
    /// # Example
    /// ```
    /// use scry_learn::tree::HistGradientBoostingClassifier;
    ///
    /// let model = HistGradientBoostingClassifier::new()
    ///     .n_estimators(200)
    ///     .learning_rate(0.05);
    /// ```
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_leaf_nodes: 31,
            min_samples_leaf: 20,
            max_depth: 8,
            max_bins: NUM_BINS,
            l2_regularization: 0.0,
            seed: 42,
            trees: Vec::new(),
            binner: FeatureBinner::new(),
            init_predictions: Vec::new(),
            n_classes: 0,
            n_features: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set number of boosting rounds (default: 100).
    pub fn n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    /// Set learning rate / shrinkage (default: 0.1).
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum leaf nodes per tree (default: 31).
    pub fn max_leaf_nodes(mut self, n: usize) -> Self {
        self.max_leaf_nodes = n;
        self
    }

    /// Set minimum samples per leaf (default: 20).
    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Set maximum tree depth (default: 8).
    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    /// Set maximum bins (2..=256, default: 256).
    pub fn max_bins(mut self, bins: usize) -> Self {
        self.max_bins = bins.clamp(2, NUM_BINS);
        self
    }

    /// Set L2 regularization (default: 0.0).
    pub fn l2_regularization(mut self, l2: f64) -> Self {
        self.l2_regularization = l2;
        self
    }

    /// Set random seed (default: 42).
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Train the histogram-based gradient boosting classifier.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_no_inf()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "learning_rate must be in (0, 1]".into(),
            ));
        }

        self.n_features = data.n_features();
        self.n_classes = data.n_classes();
        let k = self.n_classes;

        if k < 2 {
            return Err(ScryLearnError::InvalidParameter(
                "need at least 2 classes for classification".into(),
            ));
        }

        // Bin features.
        self.binner = FeatureBinner::new().max_bins(self.max_bins);
        let binned = self.binner.fit_transform(data)?;

        let all_indices: Vec<usize> = (0..n).collect();

        // Adjust min_samples_leaf for small datasets.
        let effective_min_leaf = self.min_samples_leaf.min(n / 4).max(1);

        if k == 2 {
            self.fit_binary(data, n, &binned, &all_indices, effective_min_leaf)
        } else {
            self.fit_multiclass(data, n, k, &binned, &all_indices, effective_min_leaf)
        }
    }

    /// Binary classification via log-loss.
    #[allow(clippy::unnecessary_wraps)]
    fn fit_binary(
        &mut self,
        data: &Dataset,
        n: usize,
        binned: &[Vec<u8>],
        all_indices: &[usize],
        min_leaf: usize,
    ) -> Result<()> {
        // Initial prediction: log-odds of positive class.
        let pos_count = data.target.iter().filter(|&&y| y > 0.5).count();
        let p = (pos_count as f64 / n as f64).clamp(1e-7, 1.0 - 1e-7);
        let f0 = (p / (1.0 - p)).ln();
        self.init_predictions = vec![f0];

        let mut f_vals = vec![f0; n];
        let mut trees_seq = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            // Log-loss derivatives w.r.t. the raw score f:
            // gradient = p − y, hessian = p·(1 − p), floored at 1e-10.
            let probs: Vec<f64> = f_vals.iter().map(|&f| sigmoid(f)).collect();
            let gradients: Vec<f64> = (0..n).map(|i| probs[i] - data.target[i]).collect();
            let hessians: Vec<f64> = probs.iter().map(|&p| (p * (1.0 - p)).max(1e-10)).collect();

            let tree = build_tree_leaf_wise(
                binned,
                &gradients,
                &hessians,
                all_indices,
                self.max_leaf_nodes,
                min_leaf,
                self.max_depth,
                self.l2_regularization,
                self.n_features,
            );

            // Update predictions.
            for &i in all_indices {
                let sample: Vec<u8> = binned.iter().map(|col| col[i]).collect();
                f_vals[i] += self.learning_rate * tree.predict_one(&sample);
            }

            trees_seq.push(tree);
        }

        self.trees = vec![trees_seq];
        self.fitted = true;
        Ok(())
    }

    /// Multiclass via softmax (K tree sequences).
    #[allow(clippy::unnecessary_wraps)]
    fn fit_multiclass(
        &mut self,
        data: &Dataset,
        n: usize,
        k: usize,
        binned: &[Vec<u8>],
        all_indices: &[usize],
        min_leaf: usize,
    ) -> Result<()> {
        // One-hot targets.
        let y_onehot: Vec<Vec<f64>> = (0..k)
            .map(|cls| {
                data.target
                    .iter()
                    .map(|&y| if (y as usize) == cls { 1.0 } else { 0.0 })
                    .collect()
            })
            .collect();

        // Initial predictions: log of class priors.
        let class_counts: Vec<usize> = (0..k)
            .map(|cls| data.target.iter().filter(|&&y| (y as usize) == cls).count())
            .collect();
        let init_preds: Vec<f64> = class_counts
            .iter()
            .map(|&c| (c as f64 / n as f64).clamp(1e-7, 1.0 - 1e-7).ln())
            .collect();
        self.init_predictions.clone_from(&init_preds);

        // f_vals[class][sample]
        let mut f_vals: Vec<Vec<f64>> = (0..k).map(|c| vec![init_preds[c]; n]).collect();
        let mut trees_all: Vec<Vec<HistTree>> = (0..k)
            .map(|_| Vec::with_capacity(self.n_estimators))
            .collect();

        for _ in 0..self.n_estimators {
            // Softmax probabilities.
            let probs = softmax_matrix(&f_vals, n, k);

            for cls in 0..k {
                // Gradients: p_k - y_k; Hessians: p_k * (1 - p_k).
                let gradients: Vec<f64> =
                    (0..n).map(|i| probs[cls][i] - y_onehot[cls][i]).collect();
                let hessians: Vec<f64> = (0..n)
                    .map(|i| (probs[cls][i] * (1.0 - probs[cls][i])).max(1e-10))
                    .collect();

                let tree = build_tree_leaf_wise(
                    binned,
                    &gradients,
                    &hessians,
                    all_indices,
                    self.max_leaf_nodes,
                    min_leaf,
                    self.max_depth,
                    self.l2_regularization,
                    self.n_features,
                );

                for &i in all_indices {
                    let sample: Vec<u8> = binned.iter().map(|col| col[i]).collect();
                    f_vals[cls][i] += self.learning_rate * tree.predict_one(&sample);
                }

                trees_all[cls].push(tree);
            }
        }

        self.trees = trees_all;
        self.fitted = true;
        Ok(())
    }

    /// Predict class labels for new samples.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let proba = self.predict_proba(features)?;
        Ok(proba
            .iter()
            .map(|row| {
                row.iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect())
    }

    /// Predict class probabilities for new samples.
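    ///
    /// # Example
    /// ```
    /// use scry_learn::dataset::Dataset;
    /// use scry_learn::tree::HistGradientBoostingClassifier;
    ///
    /// let features = vec![
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    /// ];
    /// let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    /// let data = Dataset::new(features, target, vec!["x1".into(), "x2".into()], "class");
    ///
    /// let mut model = HistGradientBoostingClassifier::new().n_estimators(20);
    /// model.fit(&data).unwrap();
    ///
    /// // Each row is one sample's per-class probabilities and sums to 1.
    /// let proba = model.predict_proba(&[vec![1.5, 0.15]]).unwrap();
    /// let sum: f64 = proba[0].iter().sum();
    /// assert!((sum - 1.0).abs() < 1e-9);
    /// ```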
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = features.len();
        let k = self.n_classes;

        if k == 2 {
            // Binary: single tree sequence, use sigmoid.
            let mut f_vals = vec![self.init_predictions[0]; n];
            for tree in &self.trees[0] {
                for (i, sample) in features.iter().enumerate() {
                    f_vals[i] += self.learning_rate * tree.predict_one_raw(sample, &self.binner);
                }
            }
            Ok(f_vals
                .iter()
                .map(|&f| {
                    let p = sigmoid(f);
                    vec![1.0 - p, p]
                })
                .collect())
        } else {
            // Multiclass: K tree sequences, use softmax.
            let mut f_vals: Vec<Vec<f64>> =
                (0..k).map(|c| vec![self.init_predictions[c]; n]).collect();

            for (cls_vals, cls_trees) in f_vals.iter_mut().zip(self.trees.iter()).take(k) {
                for tree in cls_trees {
                    for (i, sample) in features.iter().enumerate() {
                        cls_vals[i] +=
                            self.learning_rate * tree.predict_one_raw(sample, &self.binner);
                    }
                }
            }

            let probs = softmax_matrix(&f_vals, n, k);
            // Transpose from [class][sample] to [sample][class].
            Ok((0..n)
                .map(|i| (0..k).map(|c| probs[c][i]).collect())
                .collect())
        }
    }

    /// Feature importances (total gain, normalized).
    pub fn feature_importances(&self) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.n_features;
        let mut imp = vec![0.0; m];
        for tree_seq in &self.trees {
            for tree in tree_seq {
                let ti = tree.feature_importances(m);
                for (i, &v) in ti.iter().enumerate() {
                    imp[i] += v;
                }
            }
        }
        let total: f64 = imp.iter().sum();
        if total > 0.0 {
            for v in &mut imp {
                *v /= total;
            }
        }
        Ok(imp)
    }

    /// Number of trees in the ensemble.
    pub fn n_trees(&self) -> usize {
        self.trees.iter().map(Vec::len).sum()
    }

    /// Number of classes.
    pub fn n_classes(&self) -> usize {
        self.n_classes
    }

    /// Number of features the model was trained on.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Learning rate value.
    pub fn learning_rate_val(&self) -> f64 {
        self.learning_rate
    }

    /// Initial predictions per class.
    pub fn init_predictions_val(&self) -> &[f64] {
        &self.init_predictions
    }

    /// Convert internal HistTree nodes to public HistNodeView arrays for ONNX export.
    /// Returns `class_tree_views[class_idx][tree_idx]` = Vec of HistNodeView.
    pub fn class_tree_node_views(&self) -> Vec<Vec<Vec<HistNodeView>>> {
        self.trees
            .iter()
            .map(|class_trees| {
                class_trees
                    .iter()
                    .map(|tree| tree.to_node_views(&self.binner))
                    .collect()
            })
            .collect()
    }
}

impl Default for HistGradientBoostingClassifier {
    fn default() -> Self {
        Self::new()
    }
}

// ═══════════════════════════════════════════════════════════════════════════
// Utility functions
// ═══════════════════════════════════════════════════════════════════════════

/// Sigmoid function.
#[inline]
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Softmax over class×sample matrix. Input/output: `[class][sample]`.
fn softmax_matrix(f_vals: &[Vec<f64>], n: usize, k: usize) -> Vec<Vec<f64>> {
    let mut result: Vec<Vec<f64>> = vec![vec![0.0; n]; k];

    for i in 0..n {
        let max_f = (0..k)
            .map(|c| f_vals[c][i])
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = (0..k).map(|c| (f_vals[c][i] - max_f).exp()).sum();
        for c in 0..k {
            result[c][i] = (f_vals[c][i] - max_f).exp() / exp_sum;
        }
    }

    result
}
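
// A small sketch of the softmax invariants the classifier relies on: each
// sample's probabilities sum to 1, and the max-shift keeps the computation
// finite even for very large raw scores.
#[cfg(test)]
mod softmax_sketch {
    use super::*;

    #[test]
    fn sums_to_one_and_survives_large_scores() {
        // Two classes, one sample, deliberately huge scores.
        let f_vals = vec![vec![1000.0], vec![998.0]];
        let probs = softmax_matrix(&f_vals, 1, 2);

        let sum = probs[0][0] + probs[1][0];
        assert!((sum - 1.0).abs() < 1e-12);
        assert!(probs[0][0] > probs[1][0]);
        assert!(probs[0][0].is_finite() && probs[1][0].is_finite());
    }
}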

// ═══════════════════════════════════════════════════════════════════════════
// Tests
// ═══════════════════════════════════════════════════════════════════════════

#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{accuracy, r2_score};

    fn simple_regression_data() -> Dataset {
        // y = 2x + 1
        let x: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&v| 2.0 * v + 1.0).collect();
        Dataset::new(vec![x], y, vec!["x".into()], "y")
    }

    fn simple_classification_data() -> Dataset {
        let n = 200;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        let mut rng = crate::rng::FastRng::new(42);

        for _ in 0..n / 2 {
            f1.push(rng.f64() * 2.0);
            f2.push(rng.f64() * 2.0);
            target.push(0.0);
        }
        for _ in 0..n / 2 {
            f1.push(5.0 + rng.f64() * 2.0);
            f2.push(5.0 + rng.f64() * 2.0);
            target.push(1.0);
        }

        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn test_hist_gbr_fit_predict() {
        let data = simple_regression_data();
        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_leaf_nodes(15)
            .min_samples_leaf(5);
        model.fit(&data).unwrap();

        let test_x = vec![vec![3.0], vec![5.0], vec![7.0]];
        let preds = model.predict(&test_x).unwrap();
        assert_eq!(preds.len(), 3);

        // Should approximate y = 2x + 1.
        assert!((preds[0] - 7.0).abs() < 1.5, "got {}", preds[0]);
        assert!((preds[1] - 11.0).abs() < 1.5, "got {}", preds[1]);
    }

    #[test]
    fn test_hist_gbr_r2() {
        let data = simple_regression_data();
        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(100)
            .learning_rate(0.1)
            .max_leaf_nodes(31)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let preds = model.predict(&features).unwrap();
        let r2 = r2_score(&data.target, &preds);
        assert!(r2 > 0.95, "R² should be > 0.95, got {r2:.4}");
    }

    #[test]
    fn test_hist_gbc_binary() {
        let data = simple_classification_data();
        let mut model = HistGradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_leaf_nodes(15)
            .min_samples_leaf(5);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let preds = model.predict(&features).unwrap();
        let acc = accuracy(&data.target, &preds);
        assert!(
            acc > 0.90,
            "accuracy should be > 90%, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_hist_gbc_multiclass() {
        let n_per_class = 50;
        let mut rng = crate::rng::FastRng::new(42);
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();

        for cls in 0..3 {
            let offset = cls as f64 * 5.0;
            for _ in 0..n_per_class {
                f1.push(offset + rng.f64() * 2.0);
                f2.push(offset + rng.f64() * 2.0);
                target.push(cls as f64);
            }
        }

        let data = Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        );

        let mut model = HistGradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_leaf_nodes(15)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let preds = model.predict(&features).unwrap();
        let acc = accuracy(&data.target, &preds);
        assert!(
            acc > 0.90,
            "multiclass accuracy > 90%, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_hist_gbc_predict_proba() {
        let data = simple_classification_data();
        let mut model = HistGradientBoostingClassifier::new()
            .n_estimators(30)
            .learning_rate(0.1)
            .min_samples_leaf(5);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let proba = model.predict_proba(&features).unwrap();
        for row in &proba {
            let sum: f64 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "probabilities should sum to 1.0");
            for &p in row {
                assert!((0.0..=1.0).contains(&p), "probability out of range: {p}");
            }
        }
    }

    #[test]
    fn test_hist_gbr_not_fitted() {
        let model = HistGradientBoostingRegressor::new();
        let result = model.predict(&[vec![1.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn test_hist_gbc_not_fitted() {
        let model = HistGradientBoostingClassifier::new();
        let result = model.predict(&[vec![1.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn test_hist_gbr_feature_importances() {
        let data = simple_regression_data();
        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(50)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let imp = model.feature_importances().unwrap();
        assert_eq!(imp.len(), 1);
        let sum: f64 = imp.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6 || sum == 0.0);
    }

    #[test]
    fn test_hist_gbr_with_nan() {
        let x: Vec<f64> = (0..100)
            .map(|i| {
                if i % 10 == 0 {
                    f64::NAN
                } else {
                    i as f64 * 0.1
                }
            })
            .collect();
        let y: Vec<f64> = (0..100).map(|i| i as f64 * 0.2 + 1.0).collect();
        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");

        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(50)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        // Predict with NaN — should not panic.
        let preds = model.predict(&[vec![f64::NAN], vec![5.0]]).unwrap();
        assert_eq!(preds.len(), 2);
        assert!(
            !preds[0].is_nan(),
            "NaN input should produce a finite prediction"
        );
        assert!(!preds[1].is_nan());
    }
}