use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::tree::binning::FeatureBinner;

use rayon::prelude::*;

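/// Fixed histogram width: bin indices are stored as `u8`, so each feature
/// gets at most 256 bins.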
const NUM_BINS: usize = 256;

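/// Accumulated gradient sum, hessian sum, and sample count for one
/// histogram bin.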
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct HistBin {
    grad_sum: f64,
    hess_sum: f64,
    count: u32,
}

type FeatureHistogram = [HistBin; NUM_BINS];

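/// Builds one gradient/hessian histogram per feature over `sample_indices`.
/// `binned` is column-major (`binned[feature][sample]` is a bin index). With
/// the `scry-gpu` feature enabled and a backend available, the work is
/// offloaded to the GPU; otherwise features are histogrammed in parallel via
/// rayon.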
fn build_histograms(
    binned: &[Vec<u8>],
    gradients: &[f64],
    hessians: &[f64],
    sample_indices: &[usize],
    n_features: usize,
) -> Vec<FeatureHistogram> {
    #[cfg(feature = "scry-gpu")]
    {
        if let Ok(gpu) = crate::accel::ScryGpuBackend::new() {
            use crate::accel::ComputeBackend;
            let accel_hists = gpu.build_histograms(
                binned,
                gradients,
                hessians,
                sample_indices,
                n_features,
                NUM_BINS,
            );
            return accel_hists
                .into_iter()
                .map(|feat_bins| {
                    let mut hist: FeatureHistogram = [HistBin::default(); NUM_BINS];
                    for (b, &(g, h, c)) in feat_bins.iter().enumerate() {
                        if b < NUM_BINS {
                            hist[b].grad_sum = g;
                            hist[b].hess_sum = h;
                            hist[b].count = c as u32;
                        }
                    }
                    hist
                })
                .collect();
        }
    }

    (0..n_features)
        .into_par_iter()
        .map(|f| {
            let col = &binned[f];
            let mut hist: FeatureHistogram = [HistBin::default(); NUM_BINS];
            for &idx in sample_indices {
                let bin = col[idx] as usize;
                hist[bin].grad_sum += gradients[idx];
                hist[bin].hess_sum += hessians[idx];
                hist[bin].count += 1;
            }
            hist
        })
        .collect()
}

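/// Histogram-subtraction trick: a sibling's histograms are derived as
/// `parent - child`, so only the smaller child of a split needs a direct
/// pass over its samples.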
fn subtract_histograms(
    parent: &[FeatureHistogram],
    left: &[FeatureHistogram],
) -> Vec<FeatureHistogram> {
    parent
        .iter()
        .zip(left.iter())
        .map(|(p, l)| {
            let mut right = [HistBin::default(); NUM_BINS];
            for b in 0..NUM_BINS {
                right[b].grad_sum = p[b].grad_sum - l[b].grad_sum;
                right[b].hess_sum = p[b].hess_sum - l[b].hess_sum;
                right[b].count = p[b].count.saturating_sub(l[b].count);
            }
            right
        })
        .collect()
}

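/// Internal tree node. Split nodes compare a sample's bin index against
/// `bin_threshold`; indices less than or equal to it go left.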
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum HistNode {
    Leaf { value: f64 },
    Split {
        feature: usize,
        bin_threshold: u8,
        left: usize,
        right: usize,
        gain: f64,
    },
}

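/// Read-only view of a tree node with the bin threshold mapped back to a raw
/// feature-value threshold, as exposed by `tree_node_views` and
/// `class_tree_node_views`.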
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum HistNodeView {
    Leaf {
        value: f64,
    },
    Split {
        feature: usize,
        threshold: f64,
        left: usize,
        right: usize,
    },
}

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct HistTree {
    nodes: Vec<HistNode>,
}

impl HistTree {
    fn predict_one(&self, sample_binned: &[u8]) -> f64 {
        let mut node_idx = 0;
        loop {
            match &self.nodes[node_idx] {
                HistNode::Leaf { value } => return *value,
                HistNode::Split {
                    feature,
                    bin_threshold,
                    left,
                    right,
                    ..
                } => {
                    if sample_binned[*feature] <= *bin_threshold {
                        node_idx = *left;
                    } else {
                        node_idx = *right;
                    }
                }
            }
        }
    }

    fn predict_one_raw(&self, sample: &[f64], binner: &FeatureBinner) -> f64 {
        let mut node_idx = 0;
        loop {
            match &self.nodes[node_idx] {
                HistNode::Leaf { value } => return *value,
                HistNode::Split {
                    feature,
                    bin_threshold,
                    left,
                    right,
                    ..
                } => {
                    let val = sample[*feature];
                    let bin = if val.is_nan() {
                        0u8
                    } else {
                        let edges = &binner.bin_edges()[*feature];
                        let pos = match edges.binary_search_by(|edge| {
                            edge.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
                        }) {
                            Ok(p) => p + 1,
                            Err(p) => p,
                        };
                        (pos + 1).min(255) as u8
                    };
                    if bin <= *bin_threshold {
                        node_idx = *left;
                    } else {
                        node_idx = *right;
                    }
                }
            }
        }
    }

    fn feature_importances(&self, n_features: usize) -> Vec<f64> {
        let mut imp = vec![0.0; n_features];
        for node in &self.nodes {
            if let HistNode::Split { feature, gain, .. } = node {
                if *feature < n_features {
                    imp[*feature] += gain;
                }
            }
        }
        imp
    }

    fn to_node_views(&self, binner: &FeatureBinner) -> Vec<HistNodeView> {
        let edges = binner.bin_edges();
        self.nodes
            .iter()
            .map(|node| match node {
                HistNode::Leaf { value } => HistNodeView::Leaf { value: *value },
                HistNode::Split {
                    feature,
                    bin_threshold,
                    left,
                    right,
                    ..
                } => {
                    let threshold = if *bin_threshold == 0 || *feature >= edges.len() {
                        f64::NEG_INFINITY
                    } else {
                        let feat_edges = &edges[*feature];
                        let idx = (*bin_threshold as usize).saturating_sub(1);
                        if idx < feat_edges.len() {
                            feat_edges[idx]
                        } else if !feat_edges.is_empty() {
                            feat_edges[feat_edges.len() - 1]
                        } else {
                            0.0
                        }
                    };
                    HistNodeView::Split {
                        feature: *feature,
                        threshold,
                        left: *left,
                        right: *right,
                    }
                }
            })
            .collect()
    }
}

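/// A leaf still eligible for splitting, carrying its sample set, per-feature
/// histograms, and gradient/hessian totals.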
struct LeafCandidate {
    node_idx: usize,
    sample_indices: Vec<usize>,
    histograms: Vec<FeatureHistogram>,
    grad_sum: f64,
    hess_sum: f64,
    depth: usize,
}

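/// The best split found for a leaf: the chosen feature and bin threshold,
/// the resulting sample partition, and the child statistics needed to keep
/// growing the tree.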
struct SplitResult {
    feature: usize,
    bin_threshold: u8,
    gain: f64,
    left_indices: Vec<usize>,
    right_indices: Vec<usize>,
    left_value: f64,
    right_value: f64,
    left_grad_sum: f64,
    left_hess_sum: f64,
    right_grad_sum: f64,
    right_hess_sum: f64,
}

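/// Optimal leaf weight for a second-order objective: `-G / (H + lambda)`,
/// guarded against a near-zero denominator.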
#[inline]
fn leaf_value(grad_sum: f64, hess_sum: f64, l2_reg: f64) -> f64 {
    let denom = hess_sum + l2_reg;
    if denom.abs() < 1e-10 {
        0.0
    } else {
        -grad_sum / denom
    }
}

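/// Gain of a split under L2 regularization:
/// `0.5 * (G_L^2 / (H_L + lambda) + G_R^2 / (H_R + lambda) - G^2 / (H + lambda))`.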
#[inline]
fn split_gain(
    grad_left: f64,
    hess_left: f64,
    grad_right: f64,
    hess_right: f64,
    l2_reg: f64,
) -> f64 {
    let left_term = grad_left * grad_left / (hess_left + l2_reg);
    let right_term = grad_right * grad_right / (hess_right + l2_reg);
    let parent_grad = grad_left + grad_right;
    let parent_hess = hess_left + hess_right;
    let parent_term = parent_grad * parent_grad / (parent_hess + l2_reg);
    0.5 * (left_term + right_term - parent_term)
}

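/// Scans each feature's histogram left to right with running gradient and
/// hessian sums, returning the best (feature, bin threshold) pair, or `None`
/// if no split beats the parent or satisfies `min_samples_leaf`.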
#[allow(clippy::too_many_arguments)]
fn find_best_split(
    histograms: &[FeatureHistogram],
    binned: &[Vec<u8>],
    sample_indices: &[usize],
    grad_sum: f64,
    hess_sum: f64,
    min_samples_leaf: usize,
    l2_reg: f64,
    n_features: usize,
) -> Option<SplitResult> {
    let mut best_gain = 0.0;
    let mut best_feature = 0;
    let mut best_threshold: u8 = 0;
    let mut best_left_grad = 0.0;
    let mut best_left_hess = 0.0;

    for (f, hist) in histograms.iter().enumerate().take(n_features) {
        let mut running_grad = 0.0;
        let mut running_hess = 0.0;
        let mut running_count: u32 = 0;
        let total_count = sample_indices.len() as u32;

        for bin in 0..255u8 {
            let b = bin as usize;
            running_grad += hist[b].grad_sum;
            running_hess += hist[b].hess_sum;
            running_count += hist[b].count;

            let right_count = total_count.saturating_sub(running_count);
            if (running_count as usize) < min_samples_leaf
                || (right_count as usize) < min_samples_leaf
            {
                continue;
            }

            let right_grad = grad_sum - running_grad;
            let right_hess = hess_sum - running_hess;

            let gain = split_gain(running_grad, running_hess, right_grad, right_hess, l2_reg);

            if gain > best_gain {
                best_gain = gain;
                best_feature = f;
                best_threshold = bin;
                best_left_grad = running_grad;
                best_left_hess = running_hess;
            }
        }
    }

    if best_gain <= 0.0 {
        return None;
    }

    let col = &binned[best_feature];
    let mut left_indices = Vec::new();
    let mut right_indices = Vec::new();
    for &idx in sample_indices {
        if col[idx] <= best_threshold {
            left_indices.push(idx);
        } else {
            right_indices.push(idx);
        }
    }

    let best_right_grad = grad_sum - best_left_grad;
    let best_right_hess = hess_sum - best_left_hess;

    Some(SplitResult {
        feature: best_feature,
        bin_threshold: best_threshold,
        gain: best_gain,
        left_indices,
        right_indices,
        left_value: leaf_value(best_left_grad, best_left_hess, l2_reg),
        right_value: leaf_value(best_right_grad, best_right_hess, l2_reg),
        left_grad_sum: best_left_grad,
        left_hess_sum: best_left_hess,
        right_grad_sum: best_right_grad,
        right_hess_sum: best_right_hess,
    })
}

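/// Grows a tree best-first (leaf-wise): the candidate leaf with the highest
/// gain is split next until `max_leaf_nodes`, `max_depth`, or the minimum
/// leaf-size constraints stop growth. The smaller child's histograms are
/// built directly; the larger child's are obtained by subtraction from the
/// parent.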
#[allow(clippy::too_many_arguments)]
fn build_tree_leaf_wise(
    binned: &[Vec<u8>],
    gradients: &[f64],
    hessians: &[f64],
    sample_indices: &[usize],
    max_leaf_nodes: usize,
    min_samples_leaf: usize,
    max_depth: usize,
    l2_reg: f64,
    n_features: usize,
) -> HistTree {
    let mut nodes: Vec<HistNode> = Vec::new();

    let total_grad: f64 = sample_indices.iter().map(|&i| gradients[i]).sum();
    let total_hess: f64 = sample_indices.iter().map(|&i| hessians[i]).sum();

    let root_value = leaf_value(total_grad, total_hess, l2_reg);
    nodes.push(HistNode::Leaf { value: root_value });

    let root_histograms = build_histograms(binned, gradients, hessians, sample_indices, n_features);

    let mut candidates: Vec<LeafCandidate> = Vec::new();
    candidates.push(LeafCandidate {
        node_idx: 0,
        sample_indices: sample_indices.to_vec(),
        histograms: root_histograms,
        grad_sum: total_grad,
        hess_sum: total_hess,
        depth: 0,
    });

    let mut n_leaves = 1usize;

    while n_leaves < max_leaf_nodes && !candidates.is_empty() {
        let mut best_cand_idx = 0;
        let mut best_gain = f64::NEG_INFINITY;

        for (c_idx, cand) in candidates.iter().enumerate() {
            if cand.depth >= max_depth {
                continue;
            }
            if cand.sample_indices.len() < 2 * min_samples_leaf {
                continue;
            }
            if let Some(split) = find_best_split(
                &cand.histograms,
                binned,
                &cand.sample_indices,
                cand.grad_sum,
                cand.hess_sum,
                min_samples_leaf,
                l2_reg,
                n_features,
            ) {
                if split.gain > best_gain {
                    best_gain = split.gain;
                    best_cand_idx = c_idx;
                }
            }
        }

        if best_gain <= 0.0 {
            break;
        }

        let cand = candidates.remove(best_cand_idx);

        let split = find_best_split(
            &cand.histograms,
            binned,
            &cand.sample_indices,
            cand.grad_sum,
            cand.hess_sum,
            min_samples_leaf,
            l2_reg,
            n_features,
        );

        let Some(split) = split else {
            continue;
        };

        let left_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: split.left_value,
        });
        let right_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: split.right_value,
        });

        nodes[cand.node_idx] = HistNode::Split {
            feature: split.feature,
            bin_threshold: split.bin_threshold,
            left: left_idx,
            right: right_idx,
            gain: split.gain,
        };

        n_leaves += 1;

        let (small_indices, _large_indices, small_is_left) =
            if split.left_indices.len() <= split.right_indices.len() {
                (&split.left_indices, &split.right_indices, true)
            } else {
                (&split.right_indices, &split.left_indices, false)
            };

        let small_histograms =
            build_histograms(binned, gradients, hessians, small_indices, n_features);
        let large_histograms = subtract_histograms(&cand.histograms, &small_histograms);

        let (left_hist, right_hist) = if small_is_left {
            (small_histograms, large_histograms)
        } else {
            (large_histograms, small_histograms)
        };

        let new_depth = cand.depth + 1;

        if split.left_indices.len() >= 2 * min_samples_leaf && new_depth < max_depth {
            candidates.push(LeafCandidate {
                node_idx: left_idx,
                sample_indices: split.left_indices,
                histograms: left_hist,
                grad_sum: split.left_grad_sum,
                hess_sum: split.left_hess_sum,
                depth: new_depth,
            });
        }

        if split.right_indices.len() >= 2 * min_samples_leaf && new_depth < max_depth {
            candidates.push(LeafCandidate {
                node_idx: right_idx,
                sample_indices: split.right_indices,
                histograms: right_hist,
                grad_sum: split.right_grad_sum,
                hess_sum: split.right_hess_sum,
                depth: new_depth,
            });
        }
    }

    HistTree { nodes }
}

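/// Histogram-based gradient boosting regressor with squared-error loss.
/// Features are quantized into at most `NUM_BINS` bins up front, so split
/// finding works on histograms instead of raw feature values.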
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct HistGradientBoostingRegressor {
    n_estimators: usize,
    learning_rate: f64,
    max_leaf_nodes: usize,
    min_samples_leaf: usize,
    max_depth: usize,
    max_bins: usize,
    l2_regularization: f64,
    seed: u64,
    trees: Vec<HistTree>,
    binner: FeatureBinner,
    init_prediction: f64,
    n_features: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl HistGradientBoostingRegressor {
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_leaf_nodes: 31,
            min_samples_leaf: 20,
            max_depth: 8,
            max_bins: NUM_BINS,
            l2_regularization: 0.0,
            seed: 42,
            trees: Vec::new(),
            binner: FeatureBinner::new(),
            init_prediction: 0.0,
            n_features: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    pub fn n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_leaf_nodes(mut self, n: usize) -> Self {
        self.max_leaf_nodes = n;
        self
    }

    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    pub fn max_bins(mut self, bins: usize) -> Self {
        self.max_bins = bins.clamp(2, NUM_BINS);
        self
    }

    pub fn l2_regularization(mut self, l2: f64) -> Self {
        self.l2_regularization = l2;
        self
    }

    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

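    /// Fits the ensemble: bins the features, initializes predictions at the
    /// target mean, then adds `n_estimators` trees, each fit to the current
    /// residuals (unit hessians, i.e. squared-error loss).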
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_no_inf()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "learning_rate must be in (0, 1]".into(),
            ));
        }

        self.n_features = data.n_features();

        self.binner = FeatureBinner::new().max_bins(self.max_bins);
        let binned = self.binner.fit_transform(data)?;

        let mean: f64 = data.target.iter().sum::<f64>() / n as f64;
        self.init_prediction = mean;

        let mut predictions = vec![mean; n];
        let all_indices: Vec<usize> = (0..n).collect();

        self.trees = Vec::with_capacity(self.n_estimators);

        let effective_min_leaf = self.min_samples_leaf.min(n / 4).max(1);

        for _ in 0..self.n_estimators {
            let gradients: Vec<f64> = (0..n).map(|i| -(data.target[i] - predictions[i])).collect();
            let hessians = vec![1.0; n];

            let tree = build_tree_leaf_wise(
                &binned,
                &gradients,
                &hessians,
                &all_indices,
                self.max_leaf_nodes,
                effective_min_leaf,
                self.max_depth,
                self.l2_regularization,
                self.n_features,
            );

            for &i in &all_indices {
                let sample: Vec<u8> = binned.iter().map(|col| col[i]).collect();
                predictions[i] += self.learning_rate * tree.predict_one(&sample);
            }

            self.trees.push(tree);
        }

        self.fitted = true;
        Ok(())
    }

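    /// Predicts targets for raw (unbinned) samples; each inner `Vec<f64>` is
    /// one sample's feature values, in training feature order.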
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = features.len();
        let mut preds = vec![self.init_prediction; n];

        for tree in &self.trees {
            for (i, sample) in features.iter().enumerate() {
                preds[i] += self.learning_rate * tree.predict_one_raw(sample, &self.binner);
            }
        }

        Ok(preds)
    }

    pub fn feature_importances(&self) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.n_features;
        let mut imp = vec![0.0; m];
        for tree in &self.trees {
            let ti = tree.feature_importances(m);
            for (i, &v) in ti.iter().enumerate() {
                imp[i] += v;
            }
        }
        let total: f64 = imp.iter().sum();
        if total > 0.0 {
            for v in &mut imp {
                *v /= total;
            }
        }
        Ok(imp)
    }

    pub fn n_trees(&self) -> usize {
        self.trees.len()
    }

    pub fn n_features(&self) -> usize {
        self.n_features
    }

    pub fn learning_rate_val(&self) -> f64 {
        self.learning_rate
    }

    pub fn init_prediction_val(&self) -> f64 {
        self.init_prediction
    }

    pub fn tree_node_views(&self) -> Vec<Vec<HistNodeView>> {
        self.trees
            .iter()
            .map(|tree| tree.to_node_views(&self.binner))
            .collect()
    }
}

impl Default for HistGradientBoostingRegressor {
    fn default() -> Self {
        Self::new()
    }
}

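/// Histogram-based gradient boosting classifier. Binary targets are fit with
/// log-loss through a single tree sequence; multiclass targets get one tree
/// sequence per class trained on softmax gradients.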
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct HistGradientBoostingClassifier {
    n_estimators: usize,
    learning_rate: f64,
    max_leaf_nodes: usize,
    min_samples_leaf: usize,
    max_depth: usize,
    max_bins: usize,
    l2_regularization: f64,
    seed: u64,
    trees: Vec<Vec<HistTree>>,
    binner: FeatureBinner,
    init_predictions: Vec<f64>,
    n_classes: usize,
    n_features: usize,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl HistGradientBoostingClassifier {
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_leaf_nodes: 31,
            min_samples_leaf: 20,
            max_depth: 8,
            max_bins: NUM_BINS,
            l2_regularization: 0.0,
            seed: 42,
            trees: Vec::new(),
            binner: FeatureBinner::new(),
            init_predictions: Vec::new(),
            n_classes: 0,
            n_features: 0,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    pub fn n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_leaf_nodes(mut self, n: usize) -> Self {
        self.max_leaf_nodes = n;
        self
    }

    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    pub fn max_bins(mut self, bins: usize) -> Self {
        self.max_bins = bins.clamp(2, NUM_BINS);
        self
    }

    pub fn l2_regularization(mut self, l2: f64) -> Self {
        self.l2_regularization = l2;
        self
    }

    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_no_inf()?;
        let n = data.n_samples();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.learning_rate <= 0.0 || self.learning_rate > 1.0 {
            return Err(ScryLearnError::InvalidParameter(
                "learning_rate must be in (0, 1]".into(),
            ));
        }

        self.n_features = data.n_features();
        self.n_classes = data.n_classes();
        let k = self.n_classes;

        if k < 2 {
            return Err(ScryLearnError::InvalidParameter(
                "need at least 2 classes for classification".into(),
            ));
        }

        self.binner = FeatureBinner::new().max_bins(self.max_bins);
        let binned = self.binner.fit_transform(data)?;

        let all_indices: Vec<usize> = (0..n).collect();

        let effective_min_leaf = self.min_samples_leaf.min(n / 4).max(1);

        if k == 2 {
            self.fit_binary(data, n, &binned, &all_indices, effective_min_leaf)
        } else {
            self.fit_multiclass(data, n, k, &binned, &all_indices, effective_min_leaf)
        }
    }

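    /// Binary case: the initial score is the log-odds of the positive class;
    /// each round fits one tree to the log-loss gradient `p - y` with
    /// hessian `p * (1 - p)`.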
    #[allow(clippy::unnecessary_wraps)]
    fn fit_binary(
        &mut self,
        data: &Dataset,
        n: usize,
        binned: &[Vec<u8>],
        all_indices: &[usize],
        min_leaf: usize,
    ) -> Result<()> {
        let pos_count = data.target.iter().filter(|&&y| y > 0.5).count();
        let p = (pos_count as f64 / n as f64).clamp(1e-7, 1.0 - 1e-7);
        let f0 = (p / (1.0 - p)).ln();
        self.init_predictions = vec![f0];

        let mut f_vals = vec![f0; n];
        let mut trees_seq = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            let probs: Vec<f64> = f_vals.iter().map(|&f| sigmoid(f)).collect();
            let gradients: Vec<f64> = (0..n).map(|i| probs[i] - data.target[i]).collect();
            let hessians: Vec<f64> = probs.iter().map(|&p| (p * (1.0 - p)).max(1e-10)).collect();

            let tree = build_tree_leaf_wise(
                binned,
                &gradients,
                &hessians,
                all_indices,
                self.max_leaf_nodes,
                min_leaf,
                self.max_depth,
                self.l2_regularization,
                self.n_features,
            );

            for &i in all_indices {
                let sample: Vec<u8> = binned.iter().map(|col| col[i]).collect();
                f_vals[i] += self.learning_rate * tree.predict_one(&sample);
            }

            trees_seq.push(tree);
        }

        self.trees = vec![trees_seq];
        self.fitted = true;
        Ok(())
    }

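    /// Multiclass case: one-hot encodes the targets, initializes each class
    /// score at its log prior, and per round fits one tree per class to the
    /// softmax gradient `p_c - y_c`.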
    #[allow(clippy::unnecessary_wraps)]
    fn fit_multiclass(
        &mut self,
        data: &Dataset,
        n: usize,
        k: usize,
        binned: &[Vec<u8>],
        all_indices: &[usize],
        min_leaf: usize,
    ) -> Result<()> {
        let y_onehot: Vec<Vec<f64>> = (0..k)
            .map(|cls| {
                data.target
                    .iter()
                    .map(|&y| if (y as usize) == cls { 1.0 } else { 0.0 })
                    .collect()
            })
            .collect();

        let class_counts: Vec<usize> = (0..k)
            .map(|cls| data.target.iter().filter(|&&y| (y as usize) == cls).count())
            .collect();
        let init_preds: Vec<f64> = class_counts
            .iter()
            .map(|&c| (c as f64 / n as f64).clamp(1e-7, 1.0 - 1e-7).ln())
            .collect();
        self.init_predictions.clone_from(&init_preds);

        let mut f_vals: Vec<Vec<f64>> = (0..k).map(|c| vec![init_preds[c]; n]).collect();
        let mut trees_all: Vec<Vec<HistTree>> = (0..k)
            .map(|_| Vec::with_capacity(self.n_estimators))
            .collect();

        for _ in 0..self.n_estimators {
            let probs = softmax_matrix(&f_vals, n, k);

            for cls in 0..k {
                let gradients: Vec<f64> =
                    (0..n).map(|i| probs[cls][i] - y_onehot[cls][i]).collect();
                let hessians: Vec<f64> = (0..n)
                    .map(|i| (probs[cls][i] * (1.0 - probs[cls][i])).max(1e-10))
                    .collect();

                let tree = build_tree_leaf_wise(
                    binned,
                    &gradients,
                    &hessians,
                    all_indices,
                    self.max_leaf_nodes,
                    min_leaf,
                    self.max_depth,
                    self.l2_regularization,
                    self.n_features,
                );

                for &i in all_indices {
                    let sample: Vec<u8> = binned.iter().map(|col| col[i]).collect();
                    f_vals[cls][i] += self.learning_rate * tree.predict_one(&sample);
                }

                trees_all[cls].push(tree);
            }
        }

        self.trees = trees_all;
        self.fitted = true;
        Ok(())
    }

    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let proba = self.predict_proba(features)?;
        Ok(proba
            .iter()
            .map(|row| {
                row.iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect())
    }

    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n = features.len();
        let k = self.n_classes;

        if k == 2 {
            let mut f_vals = vec![self.init_predictions[0]; n];
            for tree in &self.trees[0] {
                for (i, sample) in features.iter().enumerate() {
                    f_vals[i] += self.learning_rate * tree.predict_one_raw(sample, &self.binner);
                }
            }
            Ok(f_vals
                .iter()
                .map(|&f| {
                    let p = sigmoid(f);
                    vec![1.0 - p, p]
                })
                .collect())
        } else {
            let mut f_vals: Vec<Vec<f64>> =
                (0..k).map(|c| vec![self.init_predictions[c]; n]).collect();

            for (cls_vals, cls_trees) in f_vals.iter_mut().zip(self.trees.iter()).take(k) {
                for tree in cls_trees {
                    for (i, sample) in features.iter().enumerate() {
                        cls_vals[i] +=
                            self.learning_rate * tree.predict_one_raw(sample, &self.binner);
                    }
                }
            }

            let probs = softmax_matrix(&f_vals, n, k);
            Ok((0..n)
                .map(|i| (0..k).map(|c| probs[c][i]).collect())
                .collect())
        }
    }

    pub fn feature_importances(&self) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let m = self.n_features;
        let mut imp = vec![0.0; m];
        for tree_seq in &self.trees {
            for tree in tree_seq {
                let ti = tree.feature_importances(m);
                for (i, &v) in ti.iter().enumerate() {
                    imp[i] += v;
                }
            }
        }
        let total: f64 = imp.iter().sum();
        if total > 0.0 {
            for v in &mut imp {
                *v /= total;
            }
        }
        Ok(imp)
    }

    pub fn n_trees(&self) -> usize {
        self.trees.iter().map(Vec::len).sum()
    }

    pub fn n_classes(&self) -> usize {
        self.n_classes
    }

    pub fn n_features(&self) -> usize {
        self.n_features
    }

    pub fn learning_rate_val(&self) -> f64 {
        self.learning_rate
    }

    pub fn init_predictions_val(&self) -> &[f64] {
        &self.init_predictions
    }

    pub fn class_tree_node_views(&self) -> Vec<Vec<Vec<HistNodeView>>> {
        self.trees
            .iter()
            .map(|class_trees| {
                class_trees
                    .iter()
                    .map(|tree| tree.to_node_views(&self.binner))
                    .collect()
            })
            .collect()
    }
}

impl Default for HistGradientBoostingClassifier {
    fn default() -> Self {
        Self::new()
    }
}

#[inline]
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

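/// Column-wise softmax over class scores (`f_vals[c][i]` is the score of
/// class `c` for sample `i`), shifted by the per-sample maximum for
/// numerical stability.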
fn softmax_matrix(f_vals: &[Vec<f64>], n: usize, k: usize) -> Vec<Vec<f64>> {
    let mut result: Vec<Vec<f64>> = vec![vec![0.0; n]; k];

    for i in 0..n {
        let max_f = (0..k)
            .map(|c| f_vals[c][i])
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = (0..k).map(|c| (f_vals[c][i] - max_f).exp()).sum();
        for c in 0..k {
            result[c][i] = (f_vals[c][i] - max_f).exp() / exp_sum;
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{accuracy, r2_score};

    fn simple_regression_data() -> Dataset {
        let x: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&v| 2.0 * v + 1.0).collect();
        Dataset::new(vec![x], y, vec!["x".into()], "y")
    }

    fn simple_classification_data() -> Dataset {
        let n = 200;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        let mut rng = crate::rng::FastRng::new(42);

        for _ in 0..n / 2 {
            f1.push(rng.f64() * 2.0);
            f2.push(rng.f64() * 2.0);
            target.push(0.0);
        }
        for _ in 0..n / 2 {
            f1.push(5.0 + rng.f64() * 2.0);
            f2.push(5.0 + rng.f64() * 2.0);
            target.push(1.0);
        }

        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn test_hist_gbr_fit_predict() {
        let data = simple_regression_data();
        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_leaf_nodes(15)
            .min_samples_leaf(5);
        model.fit(&data).unwrap();

        let test_x = vec![vec![3.0], vec![5.0], vec![7.0]];
        let preds = model.predict(&test_x).unwrap();
        assert_eq!(preds.len(), 3);

        assert!((preds[0] - 7.0).abs() < 1.5, "got {}", preds[0]);
        assert!((preds[1] - 11.0).abs() < 1.5, "got {}", preds[1]);
    }

    #[test]
    fn test_hist_gbr_r2() {
        let data = simple_regression_data();
        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(100)
            .learning_rate(0.1)
            .max_leaf_nodes(31)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let preds = model.predict(&features).unwrap();
        let r2 = r2_score(&data.target, &preds);
        assert!(r2 > 0.95, "R² should be > 0.95, got {r2:.4}");
    }

    #[test]
    fn test_hist_gbc_binary() {
        let data = simple_classification_data();
        let mut model = HistGradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_leaf_nodes(15)
            .min_samples_leaf(5);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let preds = model.predict(&features).unwrap();
        let acc = accuracy(&data.target, &preds);
        assert!(
            acc > 0.90,
            "accuracy should be > 90%, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_hist_gbc_multiclass() {
        let n_per_class = 50;
        let mut rng = crate::rng::FastRng::new(42);
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();

        for cls in 0..3 {
            let offset = cls as f64 * 5.0;
            for _ in 0..n_per_class {
                f1.push(offset + rng.f64() * 2.0);
                f2.push(offset + rng.f64() * 2.0);
                target.push(cls as f64);
            }
        }

        let data = Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        );

        let mut model = HistGradientBoostingClassifier::new()
            .n_estimators(50)
            .learning_rate(0.1)
            .max_leaf_nodes(15)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let preds = model.predict(&features).unwrap();
        let acc = accuracy(&data.target, &preds);
        assert!(
            acc > 0.90,
            "multiclass accuracy > 90%, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_hist_gbc_predict_proba() {
        let data = simple_classification_data();
        let mut model = HistGradientBoostingClassifier::new()
            .n_estimators(30)
            .learning_rate(0.1)
            .min_samples_leaf(5);
        model.fit(&data).unwrap();

        let features = data.feature_matrix();
        let proba = model.predict_proba(&features).unwrap();
        for row in &proba {
            let sum: f64 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-6, "probabilities should sum to 1.0");
            for &p in row {
                assert!((0.0..=1.0).contains(&p), "probability out of range: {p}");
            }
        }
    }

    #[test]
    fn test_hist_gbr_not_fitted() {
        let model = HistGradientBoostingRegressor::new();
        let result = model.predict(&[vec![1.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn test_hist_gbc_not_fitted() {
        let model = HistGradientBoostingClassifier::new();
        let result = model.predict(&[vec![1.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn test_hist_gbr_feature_importances() {
        let data = simple_regression_data();
        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(50)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let imp = model.feature_importances().unwrap();
        assert_eq!(imp.len(), 1);
        let sum: f64 = imp.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6 || sum == 0.0);
    }

    #[test]
    fn test_hist_gbr_with_nan() {
        let x: Vec<f64> = (0..100)
            .map(|i| {
                if i % 10 == 0 {
                    f64::NAN
                } else {
                    i as f64 * 0.1
                }
            })
            .collect();
        let y: Vec<f64> = (0..100).map(|i| i as f64 * 0.2 + 1.0).collect();
        let data = Dataset::new(vec![x], y, vec!["x".into()], "y");

        let mut model = HistGradientBoostingRegressor::new()
            .n_estimators(50)
            .min_samples_leaf(3);
        model.fit(&data).unwrap();

        let preds = model.predict(&[vec![f64::NAN], vec![5.0]]).unwrap();
        assert_eq!(preds.len(), 2);
        assert!(
            !preds[0].is_nan(),
            "NaN input should produce a finite prediction"
        );
        assert!(!preds[1].is_nan());
    }
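
    // Added sanity check (not in the original suite) for the
    // histogram-subtraction trick: a sibling histogram derived via
    // `subtract_histograms` must match one built directly over the
    // complementary sample subset.
    #[test]
    fn test_histogram_subtraction_consistency() {
        let n = 64usize;
        // One feature, bin indices cycling through 0..7.
        let binned: Vec<Vec<u8>> = vec![(0..n).map(|i| (i % 7) as u8).collect()];
        let gradients: Vec<f64> = (0..n).map(|i| i as f64 * 0.5 - 3.0).collect();
        let hessians = vec![1.0; n];
        let all: Vec<usize> = (0..n).collect();
        let left: Vec<usize> = (0..n / 2).collect();
        let right: Vec<usize> = (n / 2..n).collect();

        let parent = build_histograms(&binned, &gradients, &hessians, &all, 1);
        let left_hist = build_histograms(&binned, &gradients, &hessians, &left, 1);
        let derived_right = subtract_histograms(&parent, &left_hist);
        let direct_right = build_histograms(&binned, &gradients, &hessians, &right, 1);

        for b in 0..NUM_BINS {
            assert!((derived_right[0][b].grad_sum - direct_right[0][b].grad_sum).abs() < 1e-9);
            assert!((derived_right[0][b].hess_sum - direct_right[0][b].hess_sum).abs() < 1e-9);
            assert_eq!(derived_right[0][b].count, direct_right[0][b].count);
        }
    }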
}