1use crate::backend::BackendType;
4#[cfg(any(feature = "cuda", feature = "gpu"))]
5use crate::backend::GpuMode;
6use crate::booster::GBDTConfig;
7use crate::dataset::{
8 split_holdout, BinnedDataset, ColumnPermutation, FeatureInfo, FeatureType, QuantileBinner,
9};
10use crate::loss::{sigmoid, softmax, MultiClassLogLoss};
11use crate::tree::{InteractionConstraints, Tree, TreeGrower};
12use crate::tuner::ModelFormat;
13use crate::{Result, TreeBoostError};
14use rand::seq::SliceRandom;
15use rand::SeedableRng;
16use rayon::prelude::*;
17use rkyv::{Archive, Deserialize, Serialize};
18use std::path::Path;
19
20#[cfg(feature = "cuda")]
21use crate::backend::cuda::FullCudaTreeBuilder;
22
23#[cfg(feature = "gpu")]
24use crate::backend::wgpu::FullGpuTreeBuilder;
25
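/// A trained gradient-boosted decision tree ensemble.
///
/// Stores the training configuration, the learned trees, the base prediction
/// (or per-class base predictions for multiclass models), per-feature metadata
/// (names, types, bin boundaries), an optional conformal residual quantile for
/// prediction intervals, and an optional importance-based column permutation.
/// Serializable via both `rkyv` and `serde`.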
26#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
28pub struct GBDTModel {
29 config: GBDTConfig,
31 base_prediction: f32,
33 base_predictions_multiclass: Vec<f32>,
36 trees: Vec<Tree>,
49 num_classes: usize,
51 conformal_q: Option<f32>,
53 feature_info: Vec<FeatureInfo>,
55 column_permutation: Option<ColumnPermutation>,
57}
58
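/// Early stopping fires once the validation loss has not improved for
/// `early_stopping_rounds` consecutive rounds *and* at least
/// `min_early_stopping` trees have already been grown.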
59#[inline]
65pub(crate) fn should_early_stop(
66 rounds_without_improvement: usize,
67 current_count: usize,
68 early_stopping_rounds: usize,
69 min_early_stopping: usize,
70) -> bool {
71 rounds_without_improvement >= early_stopping_rounds && current_count >= min_early_stopping
72}
73
74#[inline]
76pub(crate) fn early_stop_keep_count(best_count: usize, min_early_stopping: usize) -> usize {
77 best_count.max(min_early_stopping)
78}
79
80impl GBDTModel {
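    /// Trains a model from a dense, row-major `f32` feature matrix.
    ///
    /// Each column is quantile-binned (columns are processed in parallel) into a
    /// `BinnedDataset`, which is then handed to [`Self::train_binned`]. Missing
    /// `feature_names` default to `feature_{i}`.
    ///
    /// A minimal usage sketch (illustrative only; the data setup is assumed):
    ///
    /// ```ignore
    /// let config = GBDTConfig::new().with_num_rounds(100).with_max_depth(4);
    /// let model = GBDTModel::train(&features, num_features, &targets, config, None)?;
    /// ```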
81 pub fn train(
102 features: &[f32],
103 num_features: usize,
104 targets: &[f32],
105 config: GBDTConfig,
106 feature_names: Option<Vec<String>>,
107 ) -> Result<Self> {
108 let num_rows = if num_features > 0 {
109 features.len() / num_features
110 } else {
111 0
112 };
113
114 if num_rows == 0 || num_features == 0 {
115 return Err(TreeBoostError::Config("Empty dataset".to_string()));
116 }
117
118 if features.len() != num_rows * num_features {
119 return Err(TreeBoostError::Config(format!(
120 "Feature array length {} doesn't match num_rows * num_features ({} * {} = {})",
121 features.len(),
122 num_rows,
123 num_features,
124 num_rows * num_features
125 )));
126 }
127
128 if targets.len() != num_rows {
129 return Err(TreeBoostError::Config(format!(
130 "Target length {} doesn't match num_rows {}",
131 targets.len(),
132 num_rows
133 )));
134 }
135
136 let binner = QuantileBinner::new(config.num_bins);
138
139 let binned_results: Vec<(Vec<u8>, FeatureInfo)> = (0..num_features)
141 .into_par_iter()
142 .map(|f| {
143 let column: Vec<f64> = (0..num_rows)
145 .map(|r| features[r * num_features + f] as f64)
146 .collect();
147
148 let boundaries = binner.compute_boundaries(&column);
150 let binned = binner.bin_column(&column, &boundaries);
151
152 let name = feature_names
154 .as_ref()
155 .and_then(|names| names.get(f).cloned())
156 .unwrap_or_else(|| format!("feature_{}", f));
157
158 let info = FeatureInfo {
159 name,
160 feature_type: FeatureType::Numeric,
161 num_bins: (boundaries.len() + 1).min(255) as u8,
162 bin_boundaries: boundaries,
163 };
164
165 (binned, info)
166 })
167 .collect();
168
169 let mut binned_data = Vec::with_capacity(num_rows * num_features);
171 let mut feature_info = Vec::with_capacity(num_features);
172
173 for (binned_col, info) in binned_results {
174 binned_data.extend(binned_col);
175 feature_info.push(info);
176 }
177
178 let dataset = BinnedDataset::new(num_rows, binned_data, targets.to_vec(), feature_info);
180
181 Self::train_binned(&dataset, config)
182 }
183
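    /// Trains exactly like [`Self::train`], then writes the model and its config
    /// to `output_dir` in every requested [`ModelFormat`] via
    /// [`Self::save_to_directory`].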
184 pub fn train_with_output(
213 features: &[f32],
214 num_features: usize,
215 targets: &[f32],
216 config: GBDTConfig,
217 feature_names: Option<Vec<String>>,
218 output_dir: impl AsRef<Path>,
219 formats: &[ModelFormat],
220 ) -> Result<Self> {
221 let model = Self::train(
223 features,
224 num_features,
225 targets,
226 config.clone(),
227 feature_names,
228 )?;
229
230 model.save_to_directory(output_dir, &config, formats)?;
232
233 Ok(model)
234 }
235
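    /// Saves this model and its configuration into `output_dir`.
    ///
    /// Creates the directory if needed, writes the config as pretty-printed JSON
    /// to `config.json`, and writes one `model.<ext>` file per requested format
    /// (e.g. `model.rkyv`, `model.bin`). Errors if `formats` is empty.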
236 pub fn save_to_directory(
250 &self,
251 output_dir: impl AsRef<Path>,
252 config: &GBDTConfig,
253 formats: &[ModelFormat],
254 ) -> Result<()> {
255 use std::fs;
256 use std::io::Write;
257
258 if formats.is_empty() {
260 return Err(TreeBoostError::Config(
261 "formats must not be empty - specify at least one model format".to_string(),
262 ));
263 }
264
265 let dir = output_dir.as_ref();
266
267 fs::create_dir_all(dir)?;
269
270 let config_path = dir.join("config.json");
272 let config_json = serde_json::to_string_pretty(config).map_err(|e| {
273 TreeBoostError::Serialization(format!("Failed to serialize config: {}", e))
274 })?;
275 let mut file = fs::File::create(&config_path)?;
276 file.write_all(config_json.as_bytes())?;
277
278 for format in formats {
280 let model_path = dir.join(format!("model.{}", format.extension()));
281 match format {
282 ModelFormat::Rkyv => {
283 crate::serialize::save_model(self, &model_path)?;
284 }
285 ModelFormat::Bincode => {
286 crate::serialize::save_model_bincode(self, &model_path)?;
287 }
288 }
289 }
290
291 Ok(())
292 }
293
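    /// Trains from raw features together with one era index per row.
    ///
    /// Same pipeline as [`Self::train`], except the era indices are attached to
    /// the binned dataset for era-aware splitting. Errors unless
    /// `config.era_splitting` is enabled and all input lengths are consistent.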
294 pub fn train_with_eras(
319 features: &[f32],
320 num_features: usize,
321 targets: &[f32],
322 era_indices: &[u16],
323 config: GBDTConfig,
324 feature_names: Option<Vec<String>>,
325 ) -> Result<Self> {
326 let num_rows = if num_features > 0 {
327 features.len() / num_features
328 } else {
329 0
330 };
331
332 if num_rows == 0 || num_features == 0 {
333 return Err(TreeBoostError::Config("Empty dataset".to_string()));
334 }
335
336 if features.len() != num_rows * num_features {
337 return Err(TreeBoostError::Config(format!(
338 "Feature array length {} doesn't match num_rows * num_features ({} * {} = {})",
339 features.len(),
340 num_rows,
341 num_features,
342 num_rows * num_features
343 )));
344 }
345
346 if targets.len() != num_rows {
347 return Err(TreeBoostError::Config(format!(
348 "Target length {} doesn't match num_rows {}",
349 targets.len(),
350 num_rows
351 )));
352 }
353
354 if era_indices.len() != num_rows {
355 return Err(TreeBoostError::Config(format!(
356 "era_indices length {} doesn't match num_rows {}",
357 era_indices.len(),
358 num_rows
359 )));
360 }
361
362 if !config.era_splitting {
363 return Err(TreeBoostError::Config(
364 "era_splitting must be enabled in config when using train_with_eras".to_string(),
365 ));
366 }
367
368 let binner = QuantileBinner::new(config.num_bins);
370
371 let binned_results: Vec<(Vec<u8>, FeatureInfo)> = (0..num_features)
373 .into_par_iter()
374 .map(|f| {
375 let column: Vec<f64> = (0..num_rows)
377 .map(|r| features[r * num_features + f] as f64)
378 .collect();
379
380 let boundaries = binner.compute_boundaries(&column);
382 let binned = binner.bin_column(&column, &boundaries);
383
384 let name = feature_names
386 .as_ref()
387 .and_then(|names| names.get(f).cloned())
388 .unwrap_or_else(|| format!("feature_{}", f));
389
390 let info = FeatureInfo {
391 name,
392 feature_type: FeatureType::Numeric,
393 num_bins: (boundaries.len() + 1).min(255) as u8,
394 bin_boundaries: boundaries,
395 };
396
397 (binned, info)
398 })
399 .collect();
400
401 let mut binned_data = Vec::with_capacity(num_rows * num_features);
403 let mut feature_info = Vec::with_capacity(num_features);
404
405 for (binned_col, info) in binned_results {
406 binned_data.extend(binned_col);
407 feature_info.push(info);
408 }
409
410 let mut dataset = BinnedDataset::new(num_rows, binned_data, targets.to_vec(), feature_info);
412 dataset.set_era_indices(era_indices.to_vec());
413
414 Self::train_binned(&dataset, config)
415 }
416
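    /// Core training loop over an already-binned dataset.
    ///
    /// Splits rows into train / validation / calibration holdouts, grows one tree
    /// per boosting round (using a CUDA or wgpu full-GPU builder when those
    /// features are enabled and selected, otherwise the CPU grower), applies GOSS
    /// or row subsampling when configured, early-stops on validation loss, and
    /// finally derives a conformal residual quantile from the calibration rows.
    /// Multiclass losses are routed to the per-class trainer.
    ///
    /// A usage sketch (illustrative; dataset construction is assumed):
    ///
    /// ```ignore
    /// let config = GBDTConfig::new().with_num_rounds(50).with_max_depth(4);
    /// let model = GBDTModel::train_binned(&dataset, config)?;
    /// let preds = model.predict(&dataset);
    /// ```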
417 pub fn train_binned(dataset: &BinnedDataset, config: GBDTConfig) -> Result<Self> {
424 if let Some(num_classes) = config.loss_type.num_classes() {
426 return Self::train_binned_multiclass(dataset, config, num_classes);
427 }
428
429 config.validate().map_err(TreeBoostError::Config)?;
430
431 let loss_fn = config.loss_type.create();
432 let targets = dataset.targets();
433
434 let split = split_holdout(
436 dataset.num_rows(),
437 config.validation_ratio,
438 config.calibration_ratio,
439 config.seed,
440 );
441 let (train_indices, validation_indices, calibration_indices) =
442 (split.train, split.validation, split.calibration);
443
444 let train_targets: Vec<f32> = train_indices.iter().map(|&i| targets[i]).collect();
446 let base_prediction = loss_fn.initial_prediction(&train_targets);
447
448 let mut predictions = vec![base_prediction; dataset.num_rows()];
450
451 let mut gradients = vec![0.0f32; dataset.num_rows()];
453 let mut hessians = vec![0.0f32; dataset.num_rows()];
454
455 let interaction_constraints = if config.interaction_groups.is_empty() {
457 InteractionConstraints::new()
458 } else {
459 InteractionConstraints::from_groups(
460 config.interaction_groups.clone(),
461 dataset.num_features(),
462 )
463 };
464
465 let tree_grower = TreeGrower::new()
467 .with_max_depth(config.max_depth)
468 .with_max_leaves(config.max_leaves)
469 .with_lambda(config.lambda)
470 .with_min_samples_leaf(config.min_samples_leaf)
471 .with_min_hessian_leaf(config.min_hessian_leaf)
472 .with_entropy_weight(config.entropy_weight)
473 .with_min_gain(config.min_gain)
474 .with_learning_rate(config.learning_rate)
475 .with_colsample(config.colsample)
476 .with_monotonic_constraints(config.monotonic_constraints.clone())
477 .with_interaction_constraints(interaction_constraints)
478 .with_backend(config.backend_type)
479 .with_gpu_subgroups(config.use_gpu_subgroups)
480 .with_era_splitting(config.era_splitting);
481
482 let mut trees = Vec::with_capacity(config.num_rounds);
        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
484
485 let early_stopping_enabled =
487 config.early_stopping_rounds > 0 && !validation_indices.is_empty();
488 let mut best_val_loss = f32::MAX;
489 let mut rounds_without_improvement = 0;
490 let mut best_num_trees = 0;
491
492 let mut sample_indices: Vec<usize> = Vec::with_capacity(train_indices.len());
        let mut shuffle_buffer: Vec<usize> = if config.subsample < 1.0 && !config.goss_enabled {
            train_indices.clone()
        } else {
            Vec::new()
        };
499 let mut goss_indexed: Vec<(usize, f32)> = if config.goss_enabled {
500 Vec::with_capacity(train_indices.len())
501 } else {
502 Vec::new()
503 };
504
505 let use_fused = !config.goss_enabled && config.subsample >= 1.0;
507
508 #[cfg(feature = "cuda")]
511 let mut cuda_builder: Option<FullCudaTreeBuilder> =
512 if use_fused && matches!(config.backend_type, BackendType::Cuda | BackendType::Auto) {
513 use crate::backend::cuda::CudaDevice;
514 CudaDevice::new().and_then(|d| {
515 let resolved = config.gpu_mode.resolve(BackendType::Cuda);
517 if matches!(resolved, GpuMode::Full) {
518 Some(FullCudaTreeBuilder::new(std::sync::Arc::new(d)))
519 } else {
520 None
521 }
522 })
523 } else {
524 None
525 };
526
527 #[cfg(feature = "gpu")]
528 let mut wgpu_builder: Option<FullGpuTreeBuilder> = if use_fused
529 && matches!(config.backend_type, BackendType::Wgpu | BackendType::Auto)
530 && {
531 #[cfg(feature = "cuda")]
532 {
                cuda_builder.is_none()
            }
535 #[cfg(not(feature = "cuda"))]
536 {
537 true
538 }
539 } {
540 use crate::backend::wgpu::GpuDevice;
541 GpuDevice::new().and_then(|d| {
542 let resolved = config.gpu_mode.resolve(BackendType::Wgpu);
544 if matches!(resolved, GpuMode::Full) {
545 Some(FullGpuTreeBuilder::new(std::sync::Arc::new(d)))
546 } else {
547 None
548 }
549 })
550 } else {
551 None
552 };
553
554 for _round in 0..config.num_rounds {
555 #[allow(unused_mut, unused_assignments)]
557 let mut tree: Option<Tree> = None;
558
559 #[cfg(feature = "cuda")]
561 if tree.is_none() {
562 if let Some(ref mut builder) = cuda_builder {
563 for &idx in &train_indices {
565 let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
566 gradients[idx] = g;
567 hessians[idx] = h;
568 }
569 tree = Some(builder.build_tree(
570 dataset,
571 &gradients,
572 &hessians,
573 &train_indices,
574 config.max_depth,
575 config.max_leaves,
576 config.lambda,
577 config.min_samples_leaf,
578 config.min_hessian_leaf,
579 config.min_gain,
580 config.learning_rate,
581 ));
582 }
583 }
584
585 #[cfg(feature = "gpu")]
586 if tree.is_none() {
587 if let Some(ref mut builder) = wgpu_builder {
588 for &idx in &train_indices {
590 let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
591 gradients[idx] = g;
592 hessians[idx] = h;
593 }
594 tree = Some(builder.build_tree(
595 dataset,
596 &gradients,
597 &hessians,
598 &train_indices,
599 config.max_depth,
600 config.max_leaves,
601 config.lambda,
602 config.min_samples_leaf,
603 config.min_hessian_leaf,
604 config.min_gain,
605 config.learning_rate,
606 ));
607 }
608 }
609
610 let tree = tree.unwrap_or_else(|| {
612 if use_fused {
613 tree_grower.grow_fused(
616 dataset,
617 &train_indices,
618 targets,
619 &predictions,
620 loss_fn.as_ref(),
621 &mut gradients,
622 &mut hessians,
623 )
624 } else {
625 if config.parallel_gradient {
630 train_indices.par_iter().for_each(|&idx| {
631 let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
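                            // SAFETY (assumed invariant): `train_indices` holds unique row
                            // indices, so each parallel iteration writes to a distinct element
                            // of `gradients`/`hessians` and the raw-pointer writes below never
                            // alias each other.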
632 unsafe {
634 let grad_ptr = gradients.as_ptr() as *mut f32;
635 let hess_ptr = hessians.as_ptr() as *mut f32;
636 *grad_ptr.add(idx) = g;
637 *hess_ptr.add(idx) = h;
638 }
639 });
640 } else {
641 for &idx in &train_indices {
642 let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
643 gradients[idx] = g;
644 hessians[idx] = h;
645 }
646 }
647
648 let tree_indices: &[usize] = if config.goss_enabled {
650 sample_indices.clear();
652 Self::goss_sample_into(
653 &train_indices,
654 &mut gradients,
655 &mut hessians,
656 config.goss_top_rate,
657 config.goss_other_rate,
658 &mut rng,
659 &mut goss_indexed,
660 &mut sample_indices,
661 );
662 &sample_indices
663 } else if config.subsample < 1.0 {
664 sample_indices.clear();
666 let n_samples =
667 ((train_indices.len() as f32) * config.subsample).ceil() as usize;
668 shuffle_buffer.shuffle(&mut rng);
669 sample_indices.extend_from_slice(&shuffle_buffer[..n_samples]);
670 &sample_indices
671 } else {
672 &train_indices
673 };
674
675 tree_grower.grow_with_indices(dataset, &gradients, &hessians, tree_indices)
677 }
678 });
679
680 tree.predict_batch_add(dataset, &mut predictions);
683
684 trees.push(tree);
685
686 if early_stopping_enabled {
688 let val_loss: f32 = if validation_indices.len() >= 10000 {
691 validation_indices
692 .par_iter()
693 .map(|&idx| loss_fn.loss(targets[idx], predictions[idx]))
694 .sum::<f32>()
695 } else {
696 validation_indices
697 .iter()
698 .map(|&idx| loss_fn.loss(targets[idx], predictions[idx]))
699 .sum::<f32>()
700 } / validation_indices.len() as f32;
701
702 if val_loss < best_val_loss {
703 best_val_loss = val_loss;
704 best_num_trees = trees.len();
705 rounds_without_improvement = 0;
706 } else {
707 rounds_without_improvement += 1;
708 if should_early_stop(
709 rounds_without_improvement,
710 trees.len(),
711 config.early_stopping_rounds,
712 config.min_early_stopping_trees,
713 ) {
714 trees.truncate(early_stop_keep_count(
715 best_num_trees,
716 config.min_early_stopping_trees,
717 ));
718 break;
719 }
720 }
721 }
722 }
723
724 if early_stopping_enabled && best_num_trees > 0 && best_num_trees < trees.len() {
726 trees.truncate(early_stop_keep_count(
727 best_num_trees,
728 config.min_early_stopping_trees,
729 ));
730 }
731
732 let column_permutation = if config.column_reordering && !trees.is_empty() {
734 let importances = Self::compute_importances_from_trees(&trees, dataset.num_features());
735 Some(ColumnPermutation::from_importances(&importances))
736 } else {
737 None
738 };
739
740 let conformal_q = if !calibration_indices.is_empty() {
742 let calib_residuals: Vec<f32> = if calibration_indices.len() >= 10000 {
743 calibration_indices
744 .par_iter()
745 .map(|&idx| (targets[idx] - predictions[idx]).abs())
746 .collect()
747 } else {
748 calibration_indices
749 .iter()
750 .map(|&idx| (targets[idx] - predictions[idx]).abs())
751 .collect()
752 };
753
754 Some(Self::compute_quantile(
755 &calib_residuals,
756 config.conformal_quantile,
757 ))
758 } else {
759 None
760 };
761
762 Ok(Self {
763 config,
764 base_prediction,
765 base_predictions_multiclass: Vec::new(),
766 trees,
767 num_classes: 0,
768 conformal_q,
769 feature_info: dataset.all_feature_info().to_vec(),
770 column_permutation,
771 })
772 }
773
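    /// Trains on `train_dataset` while monitoring an explicit validation set.
    ///
    /// Unlike [`Self::train_binned`], no internal holdout is carved out: every
    /// training row is used for fitting, `val_dataset`/`val_targets` drive early
    /// stopping, and the conformal quantile is computed from validation residuals.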
774 pub fn train_binned_with_validation(
791 train_dataset: &BinnedDataset,
792 val_dataset: &BinnedDataset,
793 val_targets: &[f32],
794 config: GBDTConfig,
795 ) -> Result<Self> {
796 config.validate().map_err(TreeBoostError::Config)?;
797
798 let loss_fn = config.loss_type.create();
799 let targets = train_dataset.targets();
800
801 let train_indices: Vec<usize> = (0..train_dataset.num_rows()).collect();
803
804 let base_prediction = loss_fn.initial_prediction(targets);
806
807 let mut predictions = vec![base_prediction; train_dataset.num_rows()];
809 let mut val_predictions = vec![base_prediction; val_dataset.num_rows()];
810
811 let mut gradients = vec![0.0f32; train_dataset.num_rows()];
813 let mut hessians = vec![0.0f32; train_dataset.num_rows()];
814
815 let interaction_constraints = if config.interaction_groups.is_empty() {
817 InteractionConstraints::new()
818 } else {
819 InteractionConstraints::from_groups(
820 config.interaction_groups.clone(),
821 train_dataset.num_features(),
822 )
823 };
824
825 let tree_grower = TreeGrower::new()
827 .with_max_depth(config.max_depth)
828 .with_max_leaves(config.max_leaves)
829 .with_lambda(config.lambda)
830 .with_min_samples_leaf(config.min_samples_leaf)
831 .with_min_hessian_leaf(config.min_hessian_leaf)
832 .with_entropy_weight(config.entropy_weight)
833 .with_min_gain(config.min_gain)
834 .with_learning_rate(config.learning_rate)
835 .with_colsample(config.colsample)
836 .with_monotonic_constraints(config.monotonic_constraints.clone())
837 .with_interaction_constraints(interaction_constraints)
838 .with_backend(config.backend_type)
839 .with_gpu_subgroups(config.use_gpu_subgroups)
840 .with_era_splitting(config.era_splitting);
841
842 let mut trees = Vec::with_capacity(config.num_rounds);
843 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
844
845 let early_stopping_enabled = config.early_stopping_rounds > 0;
847 let mut best_val_loss = f32::MAX;
848 let mut rounds_without_improvement = 0;
849 let mut best_num_trees = 0;
850
851 let mut sample_indices: Vec<usize> = Vec::with_capacity(train_indices.len());
853 let mut shuffle_buffer: Vec<usize> = if config.subsample < 1.0 && !config.goss_enabled {
854 train_indices.clone()
855 } else {
856 Vec::new()
857 };
858 let mut goss_indexed: Vec<(usize, f32)> = if config.goss_enabled {
859 Vec::with_capacity(train_indices.len())
860 } else {
861 Vec::new()
862 };
863
864 let use_fused = !config.goss_enabled && config.subsample >= 1.0;
865
866 for _round in 0..config.num_rounds {
867 let tree = if use_fused {
869 tree_grower.grow_fused(
870 train_dataset,
871 &train_indices,
872 targets,
873 &predictions,
874 loss_fn.as_ref(),
875 &mut gradients,
876 &mut hessians,
877 )
878 } else {
879 for &idx in &train_indices {
881 let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
882 gradients[idx] = g;
883 hessians[idx] = h;
884 }
885
886 let tree_indices: &[usize] = if config.goss_enabled {
888 sample_indices.clear();
889 Self::goss_sample_into(
890 &train_indices,
891 &mut gradients,
892 &mut hessians,
893 config.goss_top_rate,
894 config.goss_other_rate,
895 &mut rng,
896 &mut goss_indexed,
897 &mut sample_indices,
898 );
899 &sample_indices
900 } else if config.subsample < 1.0 {
901 sample_indices.clear();
902 let n_samples =
903 ((train_indices.len() as f32) * config.subsample).ceil() as usize;
904 shuffle_buffer.shuffle(&mut rng);
905 sample_indices.extend_from_slice(&shuffle_buffer[..n_samples]);
906 &sample_indices
907 } else {
908 &train_indices
909 };
910
911 tree_grower.grow_with_indices(train_dataset, &gradients, &hessians, tree_indices)
912 };
913
914 tree.predict_batch_add(train_dataset, &mut predictions);
916
917 for (i, pred) in val_predictions.iter_mut().enumerate() {
919 *pred += tree.predict_row(val_dataset, i);
920 }
921
922 trees.push(tree);
923
924 if early_stopping_enabled {
927 let val_loss: f32 = val_targets
928 .iter()
929 .zip(val_predictions.iter())
930 .map(|(&target, &pred)| loss_fn.loss(target, pred))
931 .sum::<f32>()
932 / val_targets.len() as f32;
933
934 if val_loss < best_val_loss {
935 best_val_loss = val_loss;
936 best_num_trees = trees.len();
937 rounds_without_improvement = 0;
938 } else {
939 rounds_without_improvement += 1;
940 if should_early_stop(
941 rounds_without_improvement,
942 trees.len(),
943 config.early_stopping_rounds,
944 config.min_early_stopping_trees,
945 ) {
946 trees.truncate(early_stop_keep_count(
947 best_num_trees,
948 config.min_early_stopping_trees,
949 ));
950 break;
951 }
952 }
953 }
954 }
955
956 if early_stopping_enabled && best_num_trees > 0 && best_num_trees < trees.len() {
958 trees.truncate(early_stop_keep_count(
959 best_num_trees,
960 config.min_early_stopping_trees,
961 ));
962 }
963
964 let column_permutation = if config.column_reordering && !trees.is_empty() {
966 let importances =
967 Self::compute_importances_from_trees(&trees, train_dataset.num_features());
968 Some(ColumnPermutation::from_importances(&importances))
969 } else {
970 None
971 };
972
973 let conformal_q = if !val_targets.is_empty() {
975 let residuals: Vec<f32> = val_targets
976 .iter()
977 .zip(val_predictions.iter())
978 .map(|(&target, &pred)| (target - pred).abs())
979 .collect();
980 Some(Self::compute_quantile(
981 &residuals,
982 config.conformal_quantile,
983 ))
984 } else {
985 None
986 };
987
988 Ok(Self {
989 config,
990 base_prediction,
991 base_predictions_multiclass: Vec::new(),
992 trees,
993 num_classes: 0,
994 conformal_q,
995 feature_info: train_dataset.all_feature_info().to_vec(),
996 column_permutation,
997 })
998 }
999
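    /// Multiclass training: grows `num_classes` trees per boosting round (one per
    /// class) against softmax cross-entropy gradients, with validation log-loss
    /// driving early stopping. Trees are stored round-major, i.e. at index
    /// `round * num_classes + class_idx`.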
1000 fn train_binned_multiclass(
1005 dataset: &BinnedDataset,
1006 config: GBDTConfig,
1007 num_classes: usize,
1008 ) -> Result<Self> {
1009 config.validate().map_err(TreeBoostError::Config)?;
1010
1011 let targets = dataset.targets();
1012 let multiclass_loss = MultiClassLogLoss::new(num_classes);
1013
1014 let split = split_holdout(
1016 dataset.num_rows(),
1017 config.validation_ratio,
1018 config.calibration_ratio,
1019 config.seed,
1020 );
1021 let (train_indices, validation_indices, _calibration_indices) =
1022 (split.train, split.validation, split.calibration);
1023
1024 let train_targets: Vec<f32> = train_indices.iter().map(|&i| targets[i]).collect();
1026 let base_predictions = multiclass_loss.initial_predictions(&train_targets);
1027
1028 let num_rows = dataset.num_rows();
1030 let mut predictions: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1031 for _ in 0..num_rows {
1032 predictions.extend_from_slice(&base_predictions);
1033 }
1034
1035 let mut gradients = vec![0.0f32; num_rows];
1037 let mut hessians = vec![0.0f32; num_rows];
1038
1039 let interaction_constraints = if config.interaction_groups.is_empty() {
1041 InteractionConstraints::new()
1042 } else {
1043 InteractionConstraints::from_groups(
1044 config.interaction_groups.clone(),
1045 dataset.num_features(),
1046 )
1047 };
1048
1049 let tree_grower = TreeGrower::new()
1051 .with_max_depth(config.max_depth)
1052 .with_max_leaves(config.max_leaves)
1053 .with_lambda(config.lambda)
1054 .with_min_samples_leaf(config.min_samples_leaf)
1055 .with_min_hessian_leaf(config.min_hessian_leaf)
1056 .with_entropy_weight(config.entropy_weight)
1057 .with_min_gain(config.min_gain)
1058 .with_learning_rate(config.learning_rate)
1059 .with_colsample(config.colsample)
1060 .with_monotonic_constraints(config.monotonic_constraints.clone())
1061 .with_interaction_constraints(interaction_constraints)
1062 .with_backend(config.backend_type)
1063 .with_gpu_subgroups(config.use_gpu_subgroups)
1064 .with_era_splitting(config.era_splitting);
1065
1066 let mut trees = Vec::with_capacity(config.num_rounds * num_classes);
1068
1069 let early_stopping_enabled =
1071 config.early_stopping_rounds > 0 && !validation_indices.is_empty();
1072 let mut best_val_loss = f32::MAX;
1073 let mut rounds_without_improvement = 0;
1074 let mut best_num_rounds = 0;
1075
1076 for round in 0..config.num_rounds {
1077 for class_idx in 0..num_classes {
1079 multiclass_loss.compute_gradients_batch(
1081 class_idx,
1082 targets,
1083 &predictions,
1084 &train_indices,
1085 &mut gradients,
1086 &mut hessians,
1087 );
1088
1089 let tree =
1091 tree_grower.grow_with_indices(dataset, &gradients, &hessians, &train_indices);
1092
1093 for idx in 0..num_rows {
1095 let delta = tree.predict(|f| dataset.get_bin(idx, f));
1096 predictions[idx * num_classes + class_idx] += delta;
1097 }
1098
1099 trees.push(tree);
1100 }
1101
1102 if early_stopping_enabled {
1104 let mut val_loss = 0.0f32;
1106 for &idx in &validation_indices {
1107 let target_class = targets[idx] as usize;
1108 let row_preds = &predictions[idx * num_classes..(idx + 1) * num_classes];
1109 let probs = softmax(row_preds);
1110 val_loss -= probs[target_class].max(1e-15).ln();
1112 }
1113 val_loss /= validation_indices.len() as f32;
1114
1115 if val_loss < best_val_loss {
1116 best_val_loss = val_loss;
1117 best_num_rounds = round + 1;
1118 rounds_without_improvement = 0;
1119 } else {
1120 rounds_without_improvement += 1;
1121 if should_early_stop(
1123 rounds_without_improvement,
1124 trees.len(),
1125 config.early_stopping_rounds,
1126 config.min_early_stopping_trees,
1127 ) {
1128 let keep_rounds = early_stop_keep_count(
1129 best_num_rounds,
1130 config.min_early_stopping_trees / num_classes.max(1),
1131 );
1132 trees.truncate(keep_rounds * num_classes);
1133 break;
1134 }
1135 }
1136 }
1137 }
1138
1139 if early_stopping_enabled
1141 && best_num_rounds > 0
1142 && best_num_rounds * num_classes < trees.len()
1143 {
1144 let keep_rounds = early_stop_keep_count(
1145 best_num_rounds,
1146 config.min_early_stopping_trees / num_classes.max(1),
1147 );
1148 trees.truncate(keep_rounds * num_classes);
1149 }
1150
1151 let column_permutation = if config.column_reordering && !trees.is_empty() {
1153 let importances = Self::compute_importances_from_trees(&trees, dataset.num_features());
1154 Some(ColumnPermutation::from_importances(&importances))
1155 } else {
1156 None
1157 };
1158
1159 Ok(Self {
1160 config,
            base_prediction: 0.0,
            base_predictions_multiclass: base_predictions,
1163 trees,
1164 num_classes,
            conformal_q: None,
            feature_info: dataset.all_feature_info().to_vec(),
1167 column_permutation,
1168 })
1169 }
1170
1171 fn compute_importances_from_trees(trees: &[Tree], num_features: usize) -> Vec<f32> {
1173 let mut importances = vec![0.0f32; num_features];
1174
1175 for tree in trees {
1176 for (_, node) in tree.internal_nodes() {
1177 if let Some((feature_idx, _, _, _, _)) = node.split_info() {
1178 importances[feature_idx] += node.sum_hessians;
1179 }
1180 }
1181 }
1182
1183 let total: f32 = importances.iter().sum();
1185 if total > 0.0 {
1186 for imp in &mut importances {
1187 *imp /= total;
1188 }
1189 }
1190
1191 importances
1192 }
1193
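    /// Gradient-based One-Side Sampling (GOSS) into `result`.
    ///
    /// Keeps the `top_rate` fraction of `train_indices` with the largest absolute
    /// gradients, randomly samples an `other_rate` fraction of the remainder, and
    /// scales the sampled remainder's gradients and hessians by
    /// `(1 - top_rate) / other_rate` so their contribution stays approximately
    /// unbiased. The buffers are reused across rounds to avoid reallocation.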
1194 #[allow(clippy::too_many_arguments)]
1206 fn goss_sample_into(
1207 train_indices: &[usize],
1208 gradients: &mut [f32],
1209 hessians: &mut [f32],
1210 top_rate: f32,
1211 other_rate: f32,
1212 rng: &mut rand::rngs::StdRng,
1213 indexed_buffer: &mut Vec<(usize, f32)>,
1214 result: &mut Vec<usize>,
1215 ) {
1216 let n = train_indices.len();
1217 if n == 0 {
1218 return;
1219 }
1220
1221 let n_top = ((n as f32) * top_rate).ceil() as usize;
1223 let n_top = n_top.min(n);
1224 let n_other = ((n as f32) * other_rate).ceil() as usize;
1226
1227 indexed_buffer.clear();
1229 indexed_buffer.extend(train_indices.iter().map(|&idx| (idx, gradients[idx].abs())));
1230
1231 if n_top < n {
1233 indexed_buffer.select_nth_unstable_by(n_top, |a, b| {
1234 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
1235 });
1236 }
1237
1238 result.extend(indexed_buffer[..n_top].iter().map(|(idx, _)| *idx));
1240
1241 let rest_slice = &mut indexed_buffer[n_top..];
1243 rest_slice.shuffle(rng);
1244 let n_rest = rest_slice.len().min(n_other);
1245
1246 let weight = (1.0 - top_rate) / other_rate;
1248
1249 for &(idx, _) in &rest_slice[..n_rest] {
1251 gradients[idx] *= weight;
1252 hessians[idx] *= weight;
1253 result.push(idx);
1254 }
1255 }
1256
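    /// Empirical quantile of `values` at level `q`: NaNs are dropped, the rest is
    /// sorted ascending, and the element at index `ceil(n * q)` (clamped to the
    /// last element) is returned. Returns `0.0` for an empty input.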
1257 fn compute_quantile(values: &[f32], q: f32) -> f32 {
1259 if values.is_empty() {
1260 return 0.0;
1261 }
1262
1263 let mut sorted: Vec<f32> = values.iter().copied().filter(|v| !v.is_nan()).collect();
1265
1266 if sorted.is_empty() {
1267 return 0.0;
1268 }
1269
1270 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1271
1272 let idx = ((sorted.len() as f32) * q).ceil() as usize;
1273 let idx = idx.min(sorted.len() - 1);
1274 sorted[idx]
1275 }
1276
1277 pub fn predict_row(&self, dataset: &BinnedDataset, row_idx: usize) -> f32 {
1279 let mut pred = self.base_prediction;
1280 for tree in &self.trees {
1281 pred += tree.predict_row(dataset, row_idx);
1282 }
1283 pred
1284 }
1285
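    /// Predicts one value per row of `dataset` (base prediction plus the sum of
    /// all tree outputs), using the parallel or sequential path according to
    /// `config.parallel_prediction`.
    ///
    /// ```ignore
    /// // Illustrative sketch.
    /// let preds = model.predict(&dataset);
    /// assert_eq!(preds.len(), dataset.num_rows());
    /// ```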
1286 pub fn predict(&self, dataset: &BinnedDataset) -> Vec<f32> {
1293 if self.config.parallel_prediction {
1294 self.predict_parallel(dataset)
1295 } else {
1296 self.predict_sequential(dataset)
1297 }
1298 }
1299
1300 pub fn predict_sequential(&self, dataset: &BinnedDataset) -> Vec<f32> {
1305 let num_rows = dataset.num_rows();
1306
1307 let mut predictions = vec![self.base_prediction; num_rows];
1309
1310 for tree in &self.trees {
1312 tree.predict_batch_add(dataset, &mut predictions);
1313 }
1314
1315 predictions
1316 }
1317
1318 pub fn predict_parallel(&self, dataset: &BinnedDataset) -> Vec<f32> {
1323 let num_rows = dataset.num_rows();
1324
1325 if num_rows < 1000 || self.trees.is_empty() {
1327 return self.predict_sequential(dataset);
1328 }
1329
1330 let mut predictions = vec![self.base_prediction; num_rows];
1332
1333 let num_threads = rayon::current_num_threads();
1335 let chunk_size = (num_rows / (num_threads * 4)).max(256);
1336
1337 predictions
1339 .par_chunks_mut(chunk_size)
1340 .enumerate()
1341 .for_each(|(chunk_idx, chunk)| {
1342 let start_row = chunk_idx * chunk_size;
1343
1344 for tree in &self.trees {
1346 for (i, pred) in chunk.iter_mut().enumerate() {
1347 let row_idx = start_row + i;
1348 *pred += tree.predict(|f| dataset.get_bin(row_idx, f));
1349 }
1350 }
1351 });
1352
1353 predictions
1354 }
1355
1356 #[doc(hidden)]
1358 pub fn predict_row_wise(&self, dataset: &BinnedDataset) -> Vec<f32> {
1359 let num_rows = dataset.num_rows();
1360 let num_features = dataset.num_features();
1361
1362 let mut predictions = Vec::with_capacity(num_rows);
1363 let mut row_bins = vec![0u8; num_features];
1364
1365 for row_idx in 0..num_rows {
1366 for (f, bin) in row_bins.iter_mut().enumerate() {
1368 *bin = dataset.get_bin(row_idx, f);
1369 }
1370
1371 let mut pred = self.base_prediction;
1373 for tree in &self.trees {
1374 pred += tree.predict(|f| row_bins[f]);
1375 }
1376 predictions.push(pred);
1377 }
1378
1379 predictions
1380 }
1381
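    /// Predictions with symmetric conformal intervals.
    ///
    /// Returns `(predictions, lower, upper)` where `lower = pred - q` and
    /// `upper = pred + q`, with `q` the calibration residual quantile (0.0 when
    /// the model was trained without a calibration split).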
1382 pub fn predict_with_intervals(
1386 &self,
1387 dataset: &BinnedDataset,
1388 ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
1389 let predictions = self.predict(dataset);
1390
1391 let q = self.conformal_q.unwrap_or(0.0);
1392 let lower: Vec<f32> = predictions.iter().map(|&p| p - q).collect();
1393 let upper: Vec<f32> = predictions.iter().map(|&p| p + q).collect();
1394
1395 (predictions, lower, upper)
1396 }
1397
1398 pub fn predict_proba(&self, dataset: &BinnedDataset) -> Vec<f32> {
1410 let raw = self.predict(dataset);
1411 raw.iter().map(|&r| sigmoid(r)).collect()
1412 }
1413
1414 pub fn predict_class(&self, dataset: &BinnedDataset, threshold: f32) -> Vec<u32> {
1426 let proba = self.predict_proba(dataset);
1427 proba
1428 .iter()
1429 .map(|&p| if p >= threshold { 1 } else { 0 })
1430 .collect()
1431 }
1432
1433 pub fn is_multiclass(&self) -> bool {
1439 self.num_classes > 0
1440 }
1441
1442 pub fn get_num_classes(&self) -> usize {
1444 self.num_classes
1445 }
1446
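    /// Per-row class probabilities.
    ///
    /// For a binary model (`num_classes == 0`) this returns `[1 - p, p]` from the
    /// sigmoid of the raw score; otherwise it accumulates each class's trees
    /// (stored round-major) and applies a per-row softmax.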
1447 pub fn predict_proba_multiclass(&self, dataset: &BinnedDataset) -> Vec<Vec<f32>> {
1455 if self.num_classes == 0 {
1456 return self
1458 .predict_proba(dataset)
1459 .into_iter()
1460 .map(|p| vec![1.0 - p, p])
1461 .collect();
1462 }
1463
1464 let num_rows = dataset.num_rows();
1465 let num_classes = self.num_classes;
1466 let num_rounds = self.trees.len() / num_classes;
1467
1468 let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1470 for _ in 0..num_rows {
1471 raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1472 }
1473
1474 for round in 0..num_rounds {
1477 for class_idx in 0..num_classes {
1478 let tree_idx = round * num_classes + class_idx;
1479 let tree = &self.trees[tree_idx];
1480
1481 for row_idx in 0..num_rows {
1482 let delta = tree.predict(|f| dataset.get_bin(row_idx, f));
1483 raw_preds[row_idx * num_classes + class_idx] += delta;
1484 }
1485 }
1486 }
1487
1488 let mut result = Vec::with_capacity(num_rows);
1490 for row_idx in 0..num_rows {
1491 let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1492 result.push(softmax(row_preds));
1493 }
1494
1495 result
1496 }
1497
1498 pub fn predict_class_multiclass(&self, dataset: &BinnedDataset) -> Vec<u32> {
1506 let proba = self.predict_proba_multiclass(dataset);
1507 proba
1508 .iter()
1509 .map(|p| {
1510 p.iter()
1511 .enumerate()
1512 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1513 .map(|(idx, _)| idx as u32)
1514 .unwrap_or(0)
1515 })
1516 .collect()
1517 }
1518
1519 pub fn predict_raw_multiclass(&self, dataset: &BinnedDataset) -> Vec<Vec<f32>> {
1524 if self.num_classes == 0 {
1525 return self.predict(dataset).into_iter().map(|p| vec![p]).collect();
1527 }
1528
1529 let num_rows = dataset.num_rows();
1530 let num_classes = self.num_classes;
1531 let num_rounds = self.trees.len() / num_classes;
1532
1533 let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1535 for _ in 0..num_rows {
1536 raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1537 }
1538
1539 for round in 0..num_rounds {
1541 for class_idx in 0..num_classes {
1542 let tree_idx = round * num_classes + class_idx;
1543 let tree = &self.trees[tree_idx];
1544
1545 for row_idx in 0..num_rows {
1546 let delta = tree.predict(|f| dataset.get_bin(row_idx, f));
1547 raw_preds[row_idx * num_classes + class_idx] += delta;
1548 }
1549 }
1550 }
1551
1552 let mut result = Vec::with_capacity(num_rows);
1554 for row_idx in 0..num_rows {
1555 let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1556 result.push(row_preds.to_vec());
1557 }
1558
1559 result
1560 }
1561
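    /// Predicts from a dense, row-major `f64` feature matrix without pre-binning.
    ///
    /// The slice length must be a multiple of [`Self::num_features`]; large inputs
    /// take the parallel path when `config.parallel_prediction` is set.
    ///
    /// ```ignore
    /// // Illustrative: one prediction per row of a row-major matrix.
    /// let preds = model.predict_raw(&features_f64);
    /// assert_eq!(preds.len(), features_f64.len() / model.num_features());
    /// ```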
1562 pub fn predict_raw(&self, features: &[f64]) -> Vec<f32> {
1579 let num_features = self.num_features();
1580 if num_features == 0 {
1581 return vec![];
1582 }
1583
1584 let num_rows = features.len() / num_features;
1585 debug_assert_eq!(features.len(), num_rows * num_features);
1586
1587 if self.config.parallel_prediction && num_rows >= 1000 {
1588 self.predict_raw_parallel(features, num_features)
1589 } else {
1590 self.predict_raw_sequential(features, num_features)
1591 }
1592 }
1593
1594 fn predict_raw_sequential(&self, features: &[f64], num_features: usize) -> Vec<f32> {
1596 let num_rows = features.len() / num_features;
1597
1598 let mut predictions = vec![self.base_prediction; num_rows];
1600
1601 for tree in &self.trees {
1603 tree.predict_batch_add_raw(features, num_features, &mut predictions);
1604 }
1605
1606 predictions
1607 }
1608
1609 fn predict_raw_parallel(&self, features: &[f64], num_features: usize) -> Vec<f32> {
1611 let num_rows = features.len() / num_features;
1612
1613 if num_rows < 1000 || self.trees.is_empty() {
1615 return self.predict_raw_sequential(features, num_features);
1616 }
1617
1618 let mut predictions = vec![self.base_prediction; num_rows];
1620
1621 let num_threads = rayon::current_num_threads();
1623 let chunk_size = (num_rows / (num_threads * 4)).max(256);
1624
1625 predictions
1627 .par_chunks_mut(chunk_size)
1628 .enumerate()
1629 .for_each(|(chunk_idx, chunk)| {
1630 let start_row = chunk_idx * chunk_size;
1631 let chunk_features_start = start_row * num_features;
1632
1633 for tree in &self.trees {
1635 for (i, pred) in chunk.iter_mut().enumerate() {
1636 let row_offset = chunk_features_start + i * num_features;
1637 *pred += tree.predict_raw(|f| features[row_offset + f]);
1638 }
1639 }
1640 });
1641
1642 predictions
1643 }
1644
1645 pub fn predict_raw_with_intervals(&self, features: &[f64]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
1649 let predictions = self.predict_raw(features);
1650
1651 let q = self.conformal_q.unwrap_or(0.0);
1652 let lower: Vec<f32> = predictions.iter().map(|&p| p - q).collect();
1653 let upper: Vec<f32> = predictions.iter().map(|&p| p + q).collect();
1654
1655 (predictions, lower, upper)
1656 }
1657
1658 pub fn predict_proba_raw(&self, features: &[f64]) -> Vec<f32> {
1663 let raw = self.predict_raw(features);
1664 raw.iter().map(|&r| sigmoid(r)).collect()
1665 }
1666
1667 pub fn predict_class_raw(&self, features: &[f64], threshold: f32) -> Vec<u32> {
1672 let proba = self.predict_proba_raw(features);
1673 proba
1674 .iter()
1675 .map(|&p| if p >= threshold { 1 } else { 0 })
1676 .collect()
1677 }
1678
1679 pub fn predict_proba_multiclass_raw(&self, features: &[f64]) -> Vec<Vec<f32>> {
1695 if self.num_classes == 0 {
1696 return self
1698 .predict_proba_raw(features)
1699 .into_iter()
1700 .map(|p| vec![1.0 - p, p])
1701 .collect();
1702 }
1703
1704 let num_features = self.num_features();
1705 if num_features == 0 {
1706 return vec![];
1707 }
1708
1709 let num_rows = features.len() / num_features;
1710 let num_classes = self.num_classes;
1711 let num_rounds = self.trees.len() / num_classes;
1712
1713 let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1715 for _ in 0..num_rows {
1716 raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1717 }
1718
1719 for round in 0..num_rounds {
1722 for class_idx in 0..num_classes {
1723 let tree_idx = round * num_classes + class_idx;
1724 let tree = &self.trees[tree_idx];
1725
1726 for row_idx in 0..num_rows {
1727 let row_offset = row_idx * num_features;
1728 let delta = tree.predict_raw(|f| features[row_offset + f]);
1729 raw_preds[row_idx * num_classes + class_idx] += delta;
1730 }
1731 }
1732 }
1733
1734 let mut result = Vec::with_capacity(num_rows);
1736 for row_idx in 0..num_rows {
1737 let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1738 result.push(softmax(row_preds));
1739 }
1740
1741 result
1742 }
1743
1744 pub fn predict_class_multiclass_raw(&self, features: &[f64]) -> Vec<u32> {
1755 let proba = self.predict_proba_multiclass_raw(features);
1756 proba
1757 .iter()
1758 .map(|p| {
1759 p.iter()
1760 .enumerate()
1761 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1762 .map(|(idx, _)| idx as u32)
1763 .unwrap_or(0)
1764 })
1765 .collect()
1766 }
1767
1768 pub fn predict_raw_multiclass_raw(&self, features: &[f64]) -> Vec<Vec<f32>> {
1778 if self.num_classes == 0 {
1779 return self
1781 .predict_raw(features)
1782 .into_iter()
1783 .map(|p| vec![p])
1784 .collect();
1785 }
1786
1787 let num_features = self.num_features();
1788 if num_features == 0 {
1789 return vec![];
1790 }
1791
1792 let num_rows = features.len() / num_features;
1793 let num_classes = self.num_classes;
1794 let num_rounds = self.trees.len() / num_classes;
1795
1796 let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1798 for _ in 0..num_rows {
1799 raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1800 }
1801
1802 for round in 0..num_rounds {
1804 for class_idx in 0..num_classes {
1805 let tree_idx = round * num_classes + class_idx;
1806 let tree = &self.trees[tree_idx];
1807
1808 for row_idx in 0..num_rows {
1809 let row_offset = row_idx * num_features;
1810 let delta = tree.predict_raw(|f| features[row_offset + f]);
1811 raw_preds[row_idx * num_classes + class_idx] += delta;
1812 }
1813 }
1814 }
1815
1816 let mut result = Vec::with_capacity(num_rows);
1818 for row_idx in 0..num_rows {
1819 let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1820 result.push(row_preds.to_vec());
1821 }
1822
1823 result
1824 }
1825
1826 pub fn num_trees(&self) -> usize {
1828 self.trees.len()
1829 }
1830
1831 pub fn config(&self) -> &GBDTConfig {
1833 &self.config
1834 }
1835
1836 pub fn base_prediction(&self) -> f32 {
1838 self.base_prediction
1839 }
1840
1841 pub fn conformal_quantile(&self) -> Option<f32> {
1843 self.conformal_q
1844 }
1845
1846 pub fn trees(&self) -> &[Tree] {
1848 &self.trees
1849 }
1850
1851 pub fn feature_info(&self) -> &[FeatureInfo] {
1853 &self.feature_info
1854 }
1855
1856 pub fn num_features(&self) -> usize {
1858 self.feature_info.len()
1859 }
1860
1861 pub fn column_permutation(&self) -> Option<&ColumnPermutation> {
1863 self.column_permutation.as_ref()
1864 }
1865
1866 pub fn feature_importance(&self) -> Vec<f32> {
1868 let mut importances = vec![0.0f32; self.num_features()];
1869
        for tree in &self.trees {
            for (_, node) in tree.internal_nodes() {
                if let Some((feature_idx, _, _, _, _)) = node.split_info() {
                    importances[feature_idx] += node.sum_hessians;
                }
            }
        }
1878
1879 let total: f32 = importances.iter().sum();
1881 if total > 0.0 {
1882 for imp in &mut importances {
1883 *imp /= total;
1884 }
1885 }
1886
1887 importances
1888 }
1889
1890 pub fn optimize_dataset_layout(
1897 &self,
1898 dataset: &BinnedDataset,
1899 ) -> (BinnedDataset, crate::dataset::ColumnPermutation) {
1900 let importances = self.feature_importance();
1901 let permutation = crate::dataset::ColumnPermutation::from_importances(&importances);
1902 let optimized = crate::dataset::reorder_dataset(dataset, &permutation);
1903 (optimized, permutation)
1904 }
1905
1906 pub fn create_packed_dataset(&self, dataset: &BinnedDataset) -> crate::dataset::PackedDataset {
1911 crate::dataset::PackedDataset::from_binned(dataset)
1912 }
1913
1914 pub fn num_rounds(&self) -> usize {
1923 if self.num_classes == 0 {
1924 self.trees.len()
1925 } else {
1926 self.trees.len() / self.num_classes.max(1)
1927 }
1928 }
1929
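    /// Appends externally grown trees to the ensemble, e.g. when continuing
    /// boosting on new data; the caller is responsible for growing them against a
    /// compatible feature layout (see [`Self::is_compatible_for_update`]).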
1930 pub fn append_trees(&mut self, new_trees: Vec<Tree>) {
1955 self.trees.extend(new_trees);
1956 }
1957
1958 pub fn append_tree(&mut self, tree: Tree) {
1962 self.trees.push(tree);
1963 }
1964
1965 pub fn compute_residuals(&self, dataset: &BinnedDataset, targets: &[f32]) -> Vec<f32> {
1977 let predictions = self.predict(dataset);
1978 predictions
1979 .iter()
1980 .zip(targets)
1981 .map(|(p, t)| t - p)
1982 .collect()
1983 }
1984
1985 pub fn compute_residuals_raw(&self, features: &[f64], targets: &[f32]) -> Vec<f32> {
1994 let predictions = self.predict_raw(features);
1995 predictions
1996 .iter()
1997 .zip(targets)
1998 .map(|(p, t)| t - p)
1999 .collect()
2000 }
2001
2002 pub fn is_compatible_for_update(&self, num_features: usize) -> bool {
2006 self.num_features() == num_features
2007 }
2008
2009 pub fn trees_mut(&mut self) -> &mut Vec<Tree> {
2013 &mut self.trees
2014 }
2015
2016 pub fn truncate_to_rounds(&mut self, num_rounds: usize) {
2026 let trees_per_round = if self.num_classes == 0 {
2027 1
2028 } else {
2029 self.num_classes
2030 };
2031 let target_trees = num_rounds * trees_per_round;
2032 if target_trees < self.trees.len() {
2033 self.trees.truncate(target_trees);
2034 }
2035 }
2036}
2037
2038use crate::tuner::{ParamValue, TunableModel};
2043use std::collections::HashMap;
2044
2045impl TunableModel for GBDTModel {
2046 type Config = GBDTConfig;
2047
2048 fn train(dataset: &BinnedDataset, config: &Self::Config) -> Result<Self> {
2049 Self::train_binned(dataset, config.clone())
2050 }
2051
2052 fn train_with_validation(
2053 train_data: &BinnedDataset,
2054 val_data: &BinnedDataset,
2055 val_targets: &[f32],
2056 config: &Self::Config,
2057 ) -> Result<Self> {
2058 Self::train_binned_with_validation(train_data, val_data, val_targets, config.clone())
2059 }
2060
2061 fn predict(&self, dataset: &BinnedDataset) -> Vec<f32> {
2062 GBDTModel::predict(self, dataset)
2063 }
2064
2065 fn num_trees(&self) -> usize {
2066 self.trees.len()
2067 }
2068
2069 fn apply_params(config: &mut Self::Config, params: &HashMap<String, ParamValue>) {
2070 for (name, value) in params {
2071 match (name.as_str(), value) {
2072 ("max_depth", ParamValue::Numeric(v)) => config.max_depth = *v as usize,
2073 ("learning_rate", ParamValue::Numeric(v)) => config.learning_rate = *v,
2074 ("subsample", ParamValue::Numeric(v)) => config.subsample = *v,
2075 ("colsample", ParamValue::Numeric(v)) => config.colsample = *v,
2076 ("lambda", ParamValue::Numeric(v)) => config.lambda = *v,
2077 ("entropy_weight", ParamValue::Numeric(v)) => config.entropy_weight = *v,
2078 ("min_samples_leaf", ParamValue::Numeric(v)) => {
2079 config.min_samples_leaf = *v as usize
2080 }
2081 ("min_hessian_leaf", ParamValue::Numeric(v)) => config.min_hessian_leaf = *v,
2082 ("min_gain", ParamValue::Numeric(v)) => config.min_gain = *v,
2083 ("num_rounds", ParamValue::Numeric(v)) => config.num_rounds = *v as usize,
2084 ("goss_top_rate", ParamValue::Numeric(v)) => config.goss_top_rate = *v,
2085 ("goss_other_rate", ParamValue::Numeric(v)) => config.goss_other_rate = *v,
                _ => {}
            }
2088 }
2089 }
2090
2091 fn valid_params() -> &'static [&'static str] {
2092 &[
2093 "max_depth",
2094 "learning_rate",
2095 "subsample",
2096 "colsample",
2097 "lambda",
2098 "entropy_weight",
2099 "min_samples_leaf",
2100 "min_hessian_leaf",
2101 "min_gain",
2102 "num_rounds",
2103 "goss_top_rate",
2104 "goss_other_rate",
2105 ]
2106 }
2107
2108 fn default_config() -> Self::Config {
2109 GBDTConfig::default()
2110 }
2111
2112 fn is_gpu_config(config: &Self::Config) -> bool {
2113 matches!(
2114 config.backend_type,
2115 BackendType::Wgpu | BackendType::Cuda | BackendType::Auto
2116 )
2117 }
2118
2119 fn get_learning_rate(config: &Self::Config) -> f32 {
2120 config.learning_rate
2121 }
2122
2123 fn configure_validation(
2124 config: &mut Self::Config,
2125 validation_ratio: f32,
2126 early_stopping_rounds: usize,
2127 ) {
2128 config.validation_ratio = validation_ratio;
2129 config.early_stopping_rounds = early_stopping_rounds;
2130 }
2131
2132 fn set_num_rounds(config: &mut Self::Config, num_rounds: usize) {
2133 config.num_rounds = num_rounds;
2134 }
2135
2136 fn save_rkyv(&self, path: &std::path::Path) -> Result<()> {
2137 crate::serialize::save_model(self, path)
2138 }
2139
2140 fn save_bincode(&self, path: &std::path::Path) -> Result<()> {
2141 crate::serialize::save_model_bincode(self, path)
2142 }
2143
2144 fn supports_conformal() -> bool {
2145 true
2146 }
2147
2148 fn conformal_quantile(&self) -> Option<f32> {
2149 self.conformal_q
2150 }
2151
2152 fn configure_conformal(config: &mut Self::Config, calibration_ratio: f32, quantile: f32) {
2153 config.calibration_ratio = calibration_ratio;
2154 config.conformal_quantile = quantile;
2155 }
2156}
2157
2158#[cfg(test)]
2159mod tests {
2160 use super::*;
2161 use crate::dataset::{FeatureInfo, FeatureType};
2162
2163 fn create_regression_dataset(num_rows: usize, noise: f32) -> BinnedDataset {
2164 let num_features = 3;
2165
2166 let mut features = Vec::with_capacity(num_rows * num_features);
2168 for f in 0..num_features {
2169 for r in 0..num_rows {
2170 features.push(((r * (f + 1) * 17) % 256) as u8);
2171 }
2172 }
2173
2174 let targets: Vec<f32> = (0..num_rows)
2176 .map(|i| {
2177 let f0 = features[i] as f32 / 255.0;
2178 let f1 = features[num_rows + i] as f32 / 255.0;
2179 f0 * 10.0 + f1 * 5.0 + noise * (i as f32 % 10.0 - 5.0) / 5.0
2180 })
2181 .collect();
2182
2183 let feature_info = (0..num_features)
2184 .map(|i| FeatureInfo {
2185 name: format!("feature_{}", i),
2186 feature_type: FeatureType::Numeric,
2187 num_bins: 255,
2188 bin_boundaries: vec![],
2189 })
2190 .collect();
2191
2192 BinnedDataset::new(num_rows, features, targets, feature_info)
2193 }
2194
2195 #[test]
2196 fn test_train_basic() {
2197 let dataset = create_regression_dataset(500, 0.1);
2198
2199 let config = GBDTConfig::new()
2200 .with_num_rounds(10)
2201 .with_max_depth(3)
2202 .with_learning_rate(0.1);
2203
2204 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2205
2206 assert_eq!(model.num_trees(), 10);
2207
2208 let predictions = model.predict(&dataset);
2210 assert_eq!(predictions.len(), 500);
2211 }
2212
2213 #[test]
2214 fn test_train_with_pseudo_huber() {
2215 let dataset = create_regression_dataset(500, 1.0);
2216
2217 let config = GBDTConfig::new()
2218 .with_num_rounds(10)
2219 .with_pseudo_huber_loss(1.0);
2220
2221 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2222 assert_eq!(model.num_trees(), 10);
2223 }
2224
2225 #[test]
2226 fn test_train_with_conformal() {
2227 let dataset = create_regression_dataset(500, 0.5);
2228
2229 let config = GBDTConfig::new()
2230 .with_num_rounds(10)
2231 .with_conformal(0.2, 0.9);
2232
2233 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2234
2235 assert!(model.conformal_quantile().is_some());
2236 assert!(model.conformal_quantile().unwrap() > 0.0);
2237
2238 let (preds, lower, upper) = model.predict_with_intervals(&dataset);
2240 assert_eq!(preds.len(), dataset.num_rows());
2241 assert_eq!(lower.len(), dataset.num_rows());
2242 assert_eq!(upper.len(), dataset.num_rows());
2243
2244 for i in 0..preds.len() {
2246 assert!((preds[i] - lower[i] - (upper[i] - preds[i])).abs() < 1e-6);
2247 }
2248 }
2249
2250 #[test]
2251 fn test_train_with_early_stopping() {
2252 let dataset = create_regression_dataset(1000, 0.1);
2253
2254 let config = GBDTConfig::new()
            .with_num_rounds(100)
            .with_max_depth(4)
            .with_early_stopping(5, 0.2);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2260
2261 assert!(model.num_trees() < 100);
2264 assert!(model.num_trees() > 0);
2265 }
2266
2267 #[test]
2268 fn test_train_with_subsampling() {
2269 let dataset = create_regression_dataset(1000, 0.1);
2270
2271 let config = GBDTConfig::new()
2272 .with_num_rounds(10)
2273 .with_max_depth(4)
            .with_subsample(0.8)
            .with_colsample(0.8);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2278
2279 assert_eq!(model.num_trees(), 10);
2280
2281 let predictions = model.predict(&dataset);
2283 assert_eq!(predictions.len(), 1000);
2284 }
2285
2286 #[test]
2287 fn test_train_with_goss() {
2288 let dataset = create_regression_dataset(1000, 0.1);
2289
2290 let config = GBDTConfig::new()
2292 .with_num_rounds(10)
2293 .with_max_depth(4)
2294 .with_goss(true);
2295
2296 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2297
2298 assert_eq!(model.num_trees(), 10);
2299
2300 let predictions = model.predict(&dataset);
2302 assert_eq!(predictions.len(), 1000);
2303 }
2304
2305 #[test]
2306 fn test_train_with_goss_custom_rates() {
2307 let dataset = create_regression_dataset(1000, 0.1);
2308
2309 let config = GBDTConfig::new()
2311 .with_num_rounds(10)
2312 .with_max_depth(4)
            .with_goss_rates(0.3, 0.15);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2316
2317 assert_eq!(model.num_trees(), 10);
2318
2319 let predictions = model.predict(&dataset);
2320 assert_eq!(predictions.len(), 1000);
2321 }
2322
2323 #[test]
2324 fn test_auto_column_reordering() {
2325 let dataset = create_regression_dataset(500, 0.1);
2326
2327 let config = GBDTConfig::new().with_num_rounds(10).with_max_depth(4);
2329
2330 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2331
2332 assert!(model.column_permutation().is_some());
2334 let permutation = model.column_permutation().unwrap();
        assert_eq!(permutation.new_to_original().len(), 3);
    }
2337
2338 #[test]
2339 fn test_column_reordering_disabled() {
2340 let dataset = create_regression_dataset(500, 0.1);
2341
2342 let config = GBDTConfig::new()
2344 .with_num_rounds(10)
2345 .with_max_depth(4)
2346 .with_column_reordering(false);
2347
2348 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2349
2350 assert!(model.column_permutation().is_none());
2352 }
2353
2354 #[test]
2355 fn test_feature_importance() {
2356 let dataset = create_regression_dataset(500, 0.1);
2357
2358 let config = GBDTConfig::new().with_num_rounds(20).with_max_depth(4);
2359
2360 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2361 let importances = model.feature_importance();
2362
2363 assert_eq!(importances.len(), 3);
2364
2365 let total: f32 = importances.iter().sum();
2367 assert!((total - 1.0).abs() < 0.01);
2368 }
2369
2370 #[test]
2371 fn test_train_with_monotonic_constraints() {
2372 use crate::tree::MonotonicConstraint;
2373
2374 let dataset = create_regression_dataset(500, 0.1);
2375
2376 let config = GBDTConfig::new()
2378 .with_num_rounds(10)
2379 .with_max_depth(4)
2380 .with_monotonic_constraints(vec![
2381 MonotonicConstraint::Increasing,
2382 MonotonicConstraint::None,
2383 MonotonicConstraint::None,
2384 ]);
2385
2386 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2387
2388 assert!(model.num_trees() > 0);
2390
2391 let predictions = model.predict(&dataset);
2393 assert_eq!(predictions.len(), 500);
2394 }
2395
2396 #[test]
2397 fn test_train_with_interaction_constraints() {
2398 let dataset = create_regression_dataset(500, 0.1);
2399
2400 let config = GBDTConfig::new()
2402 .with_num_rounds(10)
2403 .with_max_depth(4)
2404 .with_interaction_groups(vec![vec![0, 1]]);
2405
2406 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2407
2408 assert!(model.num_trees() > 0);
2410
2411 let predictions = model.predict(&dataset);
2413 assert_eq!(predictions.len(), 500);
2414 }
2415
2416 #[test]
2417 fn test_train_with_era_splitting() {
2418 let num_rows = 600;
2419 let num_eras = 3;
2420
2421 let mut dataset = create_regression_dataset(num_rows, 0.1);
2423
2424 let era_indices: Vec<u16> = (0..num_rows).map(|i| (i % num_eras) as u16).collect();
2426 dataset.set_era_indices(era_indices);
2427
2428 let config = GBDTConfig::new()
2430 .with_num_rounds(10)
2431 .with_max_depth(3)
2432 .with_learning_rate(0.1)
2433 .with_era_splitting(true);
2434
2435 let model = GBDTModel::train_binned(&dataset, config).unwrap();
2436
2437 assert!(model.num_trees() > 0);
2439
2440 let predictions = model.predict(&dataset);
2442 assert_eq!(predictions.len(), num_rows);
2443 }
2444
2445 #[test]
2446 fn test_train_with_eras_high_level_api() {
2447 let num_rows = 600;
2448 let num_features = 5;
2449 let num_eras = 3;
2450
2451 use rand::SeedableRng;
2453 let mut rng = rand::rngs::StdRng::seed_from_u64(42);
2454 let features: Vec<f32> = (0..num_rows * num_features)
2455 .map(|_| rand::Rng::gen_range(&mut rng, 0.0..1.0))
2456 .collect();
2457
2458 let targets: Vec<f32> = (0..num_rows)
2460 .map(|i| {
2461 let f0 = features[i * num_features];
2462 let f1 = features[i * num_features + 1];
2463 f0 * 2.0 + f1 * 3.0 + rand::Rng::gen_range(&mut rng, -0.1..0.1)
2464 })
2465 .collect();
2466
2467 let era_indices: Vec<u16> = (0..num_rows).map(|i| (i % num_eras) as u16).collect();
2469
2470 let config = GBDTConfig::new()
2472 .with_num_rounds(10)
2473 .with_max_depth(3)
2474 .with_learning_rate(0.1)
2475 .with_era_splitting(true);
2476
2477 let model = GBDTModel::train_with_eras(
2478 &features,
2479 num_features,
2480 &targets,
2481 &era_indices,
2482 config,
2483 None,
2484 )
2485 .unwrap();
2486
2487 assert!(model.num_trees() > 0);
2489 assert_eq!(model.num_features(), num_features);
2490
2491 let features_f64: Vec<f64> = features.iter().map(|&v| v as f64).collect();
2493 let predictions = model.predict_raw(&features_f64);
2494 assert_eq!(predictions.len(), num_rows);
2495 }
2496
2497 fn create_multiclass_dataset(num_rows: usize, num_classes: usize) -> BinnedDataset {
2499 let num_features = 4;
2500
2501 let mut features = Vec::with_capacity(num_rows * num_features);
2503 for f in 0..num_features {
2504 for r in 0..num_rows {
2505 features.push(((r * (f + 1) * 17 + r % num_classes * 50) % 256) as u8);
2506 }
2507 }
2508
2509 let targets: Vec<f32> = (0..num_rows).map(|i| (i % num_classes) as f32).collect();
2511
2512 let feature_info = (0..num_features)
2513 .map(|i| FeatureInfo {
2514 name: format!("feature_{}", i),
2515 feature_type: FeatureType::Numeric,
2516 num_bins: 255,
2517 bin_boundaries: vec![],
2518 })
2519 .collect();
2520
2521 BinnedDataset::new(num_rows, features, targets, feature_info)
2522 }
2523
    #[test]
    fn test_multiclass_training() {
        let num_classes = 3;
        let dataset = create_multiclass_dataset(300, num_classes);

        let config = GBDTConfig::new()
            .with_num_rounds(10)
            .with_max_depth(3)
            .with_learning_rate(0.1)
            .with_multiclass_logloss(num_classes);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();

        assert_eq!(model.num_trees(), 10 * num_classes);
        assert!(model.is_multiclass());
        assert_eq!(model.get_num_classes(), num_classes);
    }

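    // Each row's class probabilities should form a valid distribution, predicted
    // classes should be in range, and accuracy should beat the 1/3 random
    // baseline on this synthetic dataset.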
    #[test]
    fn test_multiclass_prediction() {
        let num_classes = 3;
        let dataset = create_multiclass_dataset(150, num_classes);

        let config = GBDTConfig::new()
            .with_num_rounds(20)
            .with_max_depth(4)
            .with_learning_rate(0.1)
            .with_multiclass_logloss(num_classes);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();

        let proba = model.predict_proba_multiclass(&dataset);
        assert_eq!(proba.len(), 150);

        for row_proba in &proba {
            assert_eq!(row_proba.len(), num_classes);
            let sum: f32 = row_proba.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "Probabilities should sum to 1");

            for &p in row_proba {
                assert!((0.0..=1.0).contains(&p), "Probability should be in [0, 1]");
            }
        }

        let classes = model.predict_class_multiclass(&dataset);
        assert_eq!(classes.len(), 150);

        for &c in &classes {
            assert!(
                (c as usize) < num_classes,
                "Predicted class should be < num_classes"
            );
        }

        let targets = dataset.targets();
        let correct: usize = classes
            .iter()
            .zip(targets.iter())
            .filter(|(&pred, &target)| pred == target as u32)
            .count();
        let accuracy = correct as f32 / 150.0;

        assert!(
            accuracy > 0.4,
            "Multi-class accuracy {} should be better than random",
            accuracy
        );
    }

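    // `save_to_directory` writes config.json plus one model file per requested
    // format; the rkyv artifact should round-trip through `load_model`.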
    #[test]
    fn test_save_to_directory() {
        use tempfile::tempdir;

        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();

        let dir = tempdir().unwrap();
        let output_path = dir.path().join("model_output");

        model
            .save_to_directory(&output_path, &config, &[ModelFormat::Rkyv])
            .unwrap();

        assert!(output_path.join("config.json").exists());
        assert!(output_path.join("model.rkyv").exists());

        let loaded = crate::serialize::load_model(output_path.join("model.rkyv")).unwrap();
        assert_eq!(loaded.num_trees(), model.num_trees());

        let config_content = std::fs::read_to_string(output_path.join("config.json")).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&config_content).unwrap();
        assert!(parsed.get("num_rounds").is_some());
        assert_eq!(parsed["num_rounds"], 5);
    }

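    // `train_with_output` trains on flat in-memory data and persists the model in
    // every requested format (rkyv and bincode here) in a single call.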
    #[test]
    fn test_train_with_output() {
        use tempfile::tempdir;

        let features = vec![
            1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let targets = vec![1.0f32, 2.0, 3.0, 4.0];
        let num_features = 3;

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(2);

        let dir = tempdir().unwrap();
        let output_path = dir.path().join("train_output");

        let model = GBDTModel::train_with_output(
            &features,
            num_features,
            &targets,
            config,
            None,
            &output_path,
            &[ModelFormat::Rkyv, ModelFormat::Bincode],
        )
        .unwrap();

        assert_eq!(model.num_trees(), 5);

        assert!(output_path.join("config.json").exists());
        assert!(output_path.join("model.rkyv").exists());
        assert!(output_path.join("model.bin").exists());

        let loaded_rkyv = crate::serialize::load_model(output_path.join("model.rkyv")).unwrap();
        let loaded_bincode =
            crate::serialize::load_model_bincode(output_path.join("model.bin")).unwrap();
        assert_eq!(loaded_rkyv.num_trees(), 5);
        assert_eq!(loaded_bincode.num_trees(), 5);
    }

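    // Requesting zero output formats is rejected with an explicit error rather
    // than silently writing nothing.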
    #[test]
    fn test_save_to_directory_empty_formats_error() {
        use tempfile::tempdir;

        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();

        let dir = tempdir().unwrap();
        let output_path = dir.path().join("model_output");

        let result = model.save_to_directory(&output_path, &config, &[]);
        assert!(result.is_err());

        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("formats must not be empty"),
            "Error message: {}",
            err_msg
        );
    }

    #[test]
    fn test_save_to_directory_creates_parent_dirs() {
        use tempfile::tempdir;

        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();

        let dir = tempdir().unwrap();
        let output_path = dir
            .path()
            .join("deeply")
            .join("nested")
            .join("path")
            .join("model");

        model
            .save_to_directory(&output_path, &config, &[ModelFormat::Rkyv])
            .unwrap();

        assert!(output_path.join("config.json").exists());
        assert!(output_path.join("model.rkyv").exists());
    }

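    // The exported config.json should record every tuned hyperparameter, not just
    // the defaults.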
    #[test]
    fn test_save_to_directory_config_json_completeness() {
        use tempfile::tempdir;

        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new()
            .with_num_rounds(42)
            .with_max_depth(7)
            .with_learning_rate(0.05)
            .with_subsample(0.8)
            .with_lambda(2.0)
            .with_entropy_weight(0.1);

        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();

        let dir = tempdir().unwrap();
        let output_path = dir.path().join("model_output");
        model
            .save_to_directory(&output_path, &config, &[ModelFormat::Rkyv])
            .unwrap();

        let config_content = std::fs::read_to_string(output_path.join("config.json")).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&config_content).unwrap();

        assert_eq!(parsed["num_rounds"], 42);
        assert_eq!(parsed["max_depth"], 7);
        assert!((parsed["learning_rate"].as_f64().unwrap() - 0.05).abs() < 0.001);
        assert!((parsed["subsample"].as_f64().unwrap() - 0.8).abs() < 0.001);
        assert!((parsed["lambda"].as_f64().unwrap() - 2.0).abs() < 0.001);
        assert!((parsed["entropy_weight"].as_f64().unwrap() - 0.1).abs() < 0.001);
    }

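    // Continued boosting workflow: compute residuals from a trained model, append
    // the trees of a second model, and verify the grown ensemble still predicts
    // every row with finite error.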
    #[test]
    fn test_tree_residual_appending() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new()
            .with_num_rounds(5)
            .with_max_depth(3)
            .with_learning_rate(0.1);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        assert_eq!(model.num_trees(), 5);

        let initial_preds = model.predict(&dataset);
        let initial_mse: f32 = initial_preds
            .iter()
            .zip(dataset.targets())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f32>()
            / 100.0;

        let residuals = model.compute_residuals(&dataset, dataset.targets());
        assert_eq!(residuals.len(), 100);

        let second_model = GBDTModel::train_binned(&dataset, config).unwrap();
        let trees_to_append: Vec<_> = second_model.trees().to_vec();

        model.append_trees(trees_to_append);
        assert_eq!(model.num_trees(), 10);

        let new_preds = model.predict(&dataset);
        assert_eq!(new_preds.len(), 100);

        let new_mse: f32 = new_preds
            .iter()
            .zip(dataset.targets())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f32>()
            / 100.0;

        assert!(initial_mse.is_finite());
        assert!(new_mse.is_finite());
    }

    #[test]
    fn test_tree_ensemble_growth() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        assert_eq!(model.num_trees(), 5);
        assert_eq!(model.num_rounds(), 5);

        let second_model = GBDTModel::train_binned(&dataset, config).unwrap();
        model.append_trees(second_model.trees().to_vec());

        assert_eq!(model.num_trees(), 10);
        assert_eq!(model.num_rounds(), 10);

        let predictions = model.predict(&dataset);
        assert_eq!(predictions.len(), 100);
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    #[test]
    fn test_append_single_tree() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(1).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        assert_eq!(model.num_trees(), 1);

        let second_model = GBDTModel::train_binned(&dataset, config).unwrap();
        model.append_tree(second_model.trees()[0].clone());

        assert_eq!(model.num_trees(), 2);
    }

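    // For this regression setup each residual should equal target minus current
    // prediction, so `compute_residuals` must agree with `predict` element-wise.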
    #[test]
    fn test_compute_residuals_correctness() {
        let dataset = create_regression_dataset(50, 0.1);

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();

        let predictions = model.predict(&dataset);
        let residuals = model.compute_residuals(&dataset, dataset.targets());

        for (i, (r, (p, t))) in residuals
            .iter()
            .zip(predictions.iter().zip(dataset.targets()))
            .enumerate()
        {
            let expected = t - p;
            assert!(
                (r - expected).abs() < 1e-5,
                "Residual {} mismatch: got {}, expected {}",
                i,
                r,
                expected
            );
        }
    }

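    // Truncation keeps only the first N rounds; truncating to more rounds than
    // the model has is a no-op.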
    #[test]
    fn test_truncate_to_rounds() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(10).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config).unwrap();
        assert_eq!(model.num_trees(), 10);

        model.truncate_to_rounds(5);
        assert_eq!(model.num_trees(), 5);
        assert_eq!(model.num_rounds(), 5);

        model.truncate_to_rounds(20);
        assert_eq!(model.num_trees(), 5);
    }

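    // A model is only compatible for an incremental update when the supplied
    // count matches what it was trained with; mismatches in either direction are
    // rejected.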
    #[test]
    fn test_is_compatible_for_update() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(3);
        let model = GBDTModel::train_binned(&dataset, config).unwrap();

        assert!(model.is_compatible_for_update(3));

        assert!(!model.is_compatible_for_update(5));
        assert!(!model.is_compatible_for_update(2));
    }
}