treeboost/booster/gbdt.rs

//! GBDT model and training

use crate::backend::BackendType;
#[cfg(any(feature = "cuda", feature = "gpu"))]
use crate::backend::GpuMode;
use crate::booster::GBDTConfig;
use crate::dataset::{
    split_holdout, BinnedDataset, ColumnPermutation, FeatureInfo, FeatureType, QuantileBinner,
};
use crate::loss::{sigmoid, softmax, MultiClassLogLoss};
use crate::tree::{InteractionConstraints, Tree, TreeGrower};
use crate::tuner::ModelFormat;
use crate::{Result, TreeBoostError};
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rayon::prelude::*;
use rkyv::{Archive, Deserialize, Serialize};
use std::path::Path;

#[cfg(feature = "cuda")]
use crate::backend::cuda::FullCudaTreeBuilder;

#[cfg(feature = "gpu")]
use crate::backend::wgpu::FullGpuTreeBuilder;

/// Trained GBDT model
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct GBDTModel {
    /// Training configuration
    config: GBDTConfig,
    /// Base prediction (initial value) - for regression and binary classification
    base_prediction: f32,
    /// Base predictions per class (for multi-class classification)
    /// Empty for regression/binary classification
    base_predictions_multiclass: Vec<f32>,
    /// Ensemble of trees
    ///
    /// ## Storage Order
    ///
    /// **Regression/Binary**: One tree per round, stored sequentially.
    /// - `trees[round]` = tree for round `round`
    ///
    /// **Multi-class (K classes)**: K trees per round (one per class), stored round-major.
    /// - `trees[round * K + class_idx]` = tree for round `round`, class `class_idx`
    /// - Example with 3 classes, 2 rounds: `[r0_c0, r0_c1, r0_c2, r1_c0, r1_c1, r1_c2]`
    ///
    /// Total trees = `num_rounds` (regression/binary) or `num_rounds * K` (multi-class)
    trees: Vec<Tree>,
    /// Number of classes (for multi-class classification, 0 otherwise)
    num_classes: usize,
    /// Conformal quantile for prediction intervals (if calibrated)
    conformal_q: Option<f32>,
    /// Feature info from training (bin boundaries for consistent prediction)
    feature_info: Vec<FeatureInfo>,
    /// Column permutation for cache-optimized prediction (if enabled)
    column_permutation: Option<ColumnPermutation>,
}
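
// Illustrative sketch of the round-major multi-class indexing documented on
// `trees` above (`trees` and `num_classes` are private fields; this is shown
// for layout only, not as public API):
//
//     let k = model.num_classes;                      // e.g. 3 classes
//     let tree = &model.trees[round * k + class_idx]; // tree for (round, class)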

// =============================================================================
// Early Stopping Helpers
// =============================================================================

/// Check if early stopping should trigger
#[inline]
pub(crate) fn should_early_stop(
    rounds_without_improvement: usize,
    current_count: usize,
    early_stopping_rounds: usize,
    min_early_stopping: usize,
) -> bool {
    rounds_without_improvement >= early_stopping_rounds && current_count >= min_early_stopping
}

/// Calculate how many trees/rounds to keep after early stopping
#[inline]
pub(crate) fn early_stop_keep_count(best_count: usize, min_early_stopping: usize) -> usize {
    best_count.max(min_early_stopping)
}
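
// Worked example of the helpers above (values chosen for illustration): with
// `early_stopping_rounds = 10` and `min_early_stopping = 20`, stopping requires
// both the patience and the minimum tree count to be satisfied:
//
//     assert!(!should_early_stop(10, 15, 10, 20)); // patience hit, too few trees
//     assert!(should_early_stop(10, 25, 10, 20));  // both conditions met
//     assert_eq!(early_stop_keep_count(12, 20), 20); // never keep fewer than the floor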

impl GBDTModel {
    /// Train a GBDT model from raw feature data (high-level API)
    ///
    /// This is the primary training API that handles binning automatically.
    /// Features are discretized in parallel using T-Digest quantile binning.
    ///
    /// # Arguments
    /// * `features` - Row-major feature matrix: `features[row * num_features + feature]`
    ///   Shape: `(num_rows, num_features)` flattened to 1D
    /// * `num_features` - Number of features (columns)
    /// * `targets` - Target values, one per row
    /// * `config` - Training configuration
    /// * `feature_names` - Optional feature names (defaults to "feature_0", "feature_1", ...)
    ///
    /// # Example
    /// ```ignore
    /// let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 rows × 3 features
    /// let targets = vec![0.5, 1.5];
    /// let config = GBDTConfig::new().with_num_rounds(100);
    /// let model = GBDTModel::train(&features, 3, &targets, config, None)?;
    /// ```
    pub fn train(
        features: &[f32],
        num_features: usize,
        targets: &[f32],
        config: GBDTConfig,
        feature_names: Option<Vec<String>>,
    ) -> Result<Self> {
        let num_rows = if num_features > 0 {
            features.len() / num_features
        } else {
            0
        };

        if num_rows == 0 || num_features == 0 {
            return Err(TreeBoostError::Config("Empty dataset".to_string()));
        }

        if features.len() != num_rows * num_features {
            return Err(TreeBoostError::Config(format!(
                "Feature array length {} doesn't match num_rows * num_features ({} * {} = {})",
                features.len(),
                num_rows,
                num_features,
                num_rows * num_features
            )));
        }

        if targets.len() != num_rows {
            return Err(TreeBoostError::Config(format!(
                "Target length {} doesn't match num_rows {}",
                targets.len(),
                num_rows
            )));
        }

        // Create binner
        let binner = QuantileBinner::new(config.num_bins);

        // Parallel binning: process each feature column in parallel
        let binned_results: Vec<(Vec<u8>, FeatureInfo)> = (0..num_features)
            .into_par_iter()
            .map(|f| {
                // Extract column (row-major to column values)
                let column: Vec<f64> = (0..num_rows)
                    .map(|r| features[r * num_features + f] as f64)
                    .collect();

                // Compute boundaries and bin
                let boundaries = binner.compute_boundaries(&column);
                let binned = binner.bin_column(&column, &boundaries);

                // Create feature info
                let name = feature_names
                    .as_ref()
                    .and_then(|names| names.get(f).cloned())
                    .unwrap_or_else(|| format!("feature_{}", f));

                let info = FeatureInfo {
                    name,
                    feature_type: FeatureType::Numeric,
                    num_bins: (boundaries.len() + 1).min(255) as u8,
                    bin_boundaries: boundaries,
                };

                (binned, info)
            })
            .collect();

        // Combine results into column-major storage
        let mut binned_data = Vec::with_capacity(num_rows * num_features);
        let mut feature_info = Vec::with_capacity(num_features);

        for (binned_col, info) in binned_results {
            binned_data.extend(binned_col);
            feature_info.push(info);
        }

        // Create BinnedDataset and train
        let dataset = BinnedDataset::new(num_rows, binned_data, targets.to_vec(), feature_info);

        Self::train_binned(&dataset, config)
    }

    /// Train a GBDT model and save to output directory
    ///
    /// This is a convenience method that trains a model and automatically saves:
    /// - The trained model in the specified format(s)
    /// - `config.json` with the training configuration for reproducibility
    ///
    /// # Arguments
    /// * `features` - Row-major feature matrix: `features[row * num_features + feature]`
    /// * `num_features` - Number of features (columns)
    /// * `targets` - Target values, one per row
    /// * `config` - Training configuration
    /// * `feature_names` - Optional feature names
    /// * `output_dir` - Directory to save the model and config
    /// * `formats` - Model formats to save (e.g., `[ModelFormat::Rkyv]`)
    ///
    /// # Example
    /// ```ignore
    /// let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 rows × 3 features
    /// let targets = vec![0.5, 1.5];
    /// let config = GBDTConfig::new().with_num_rounds(100);
    ///
    /// let model = GBDTModel::train_with_output(
    ///     &features, 3, &targets, config, None,
    ///     "output/my_model",
    ///     &[ModelFormat::Rkyv],
    /// )?;
    /// // Creates: output/my_model/model.rkyv and output/my_model/config.json
    /// ```
    pub fn train_with_output(
        features: &[f32],
        num_features: usize,
        targets: &[f32],
        config: GBDTConfig,
        feature_names: Option<Vec<String>>,
        output_dir: impl AsRef<Path>,
        formats: &[ModelFormat],
    ) -> Result<Self> {
        // Train the model
        let model = Self::train(
            features,
            num_features,
            targets,
            config.clone(),
            feature_names,
        )?;

        // Save to output directory
        model.save_to_directory(output_dir, &config, formats)?;

        Ok(model)
    }

    /// Save a trained model to a directory
    ///
    /// Creates the directory if it doesn't exist and saves:
    /// - The model in each specified format
    /// - `config.json` with the training configuration
    ///
    /// # Arguments
    /// * `output_dir` - Directory to save the model
    /// * `config` - Training configuration (for config.json)
    /// * `formats` - Model formats to save (must not be empty)
    ///
    /// # Errors
    /// Returns an error if `formats` is empty or if I/O operations fail.
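    ///
    /// # Example
    /// ```ignore
    /// // Illustrative paths; mirrors what `train_with_output` produces.
    /// model.save_to_directory("output/my_model", &config, &[ModelFormat::Rkyv])?;
    /// // Creates: output/my_model/model.rkyv and output/my_model/config.json
    /// ```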
    pub fn save_to_directory(
        &self,
        output_dir: impl AsRef<Path>,
        config: &GBDTConfig,
        formats: &[ModelFormat],
    ) -> Result<()> {
        use std::fs;
        use std::io::Write;

        // Validate formats is not empty
        if formats.is_empty() {
            return Err(TreeBoostError::Config(
                "formats must not be empty - specify at least one model format".to_string(),
            ));
        }

        let dir = output_dir.as_ref();

        // Create directory if it doesn't exist
        fs::create_dir_all(dir)?;

        // Save config.json for reproducibility
        let config_path = dir.join("config.json");
        let config_json = serde_json::to_string_pretty(config).map_err(|e| {
            TreeBoostError::Serialization(format!("Failed to serialize config: {}", e))
        })?;
        let mut file = fs::File::create(&config_path)?;
        file.write_all(config_json.as_bytes())?;

        // Save model in each format
        for format in formats {
            let model_path = dir.join(format!("model.{}", format.extension()));
            match format {
                ModelFormat::Rkyv => {
                    crate::serialize::save_model(self, &model_path)?;
                }
                ModelFormat::Bincode => {
                    crate::serialize::save_model_bincode(self, &model_path)?;
                }
            }
        }

        Ok(())
    }

    /// Train a GBDT model with Directional Era Splitting (DES)
    ///
    /// Era splitting filters out spurious correlations by requiring all eras
    /// to agree on split direction. This is useful for time-series or financial
    /// data where patterns may not generalize across time periods.
    ///
    /// # Arguments
    /// * `features` - Row-major feature matrix: `features[row * num_features + feature]`
    /// * `num_features` - Number of features (columns)
    /// * `targets` - Target values, one per row
    /// * `era_indices` - Era index (0-based) for each row
    /// * `config` - Training configuration (era_splitting must be enabled)
    /// * `feature_names` - Optional feature names
    ///
    /// # Example
    /// ```ignore
    /// let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 rows × 3 features
    /// let targets = vec![0.5, 1.5];
    /// let era_indices = vec![0, 1]; // Row 0 in era 0, row 1 in era 1
    /// let config = GBDTConfig::new()
    ///     .with_num_rounds(100)
    ///     .with_era_splitting(true);
    /// let model = GBDTModel::train_with_eras(&features, 3, &targets, &era_indices, config, None)?;
    /// ```
    pub fn train_with_eras(
        features: &[f32],
        num_features: usize,
        targets: &[f32],
        era_indices: &[u16],
        config: GBDTConfig,
        feature_names: Option<Vec<String>>,
    ) -> Result<Self> {
        let num_rows = if num_features > 0 {
            features.len() / num_features
        } else {
            0
        };

        if num_rows == 0 || num_features == 0 {
            return Err(TreeBoostError::Config("Empty dataset".to_string()));
        }

        if features.len() != num_rows * num_features {
            return Err(TreeBoostError::Config(format!(
                "Feature array length {} doesn't match num_rows * num_features ({} * {} = {})",
                features.len(),
                num_rows,
                num_features,
                num_rows * num_features
            )));
        }

        if targets.len() != num_rows {
            return Err(TreeBoostError::Config(format!(
                "Target length {} doesn't match num_rows {}",
                targets.len(),
                num_rows
            )));
        }

        if era_indices.len() != num_rows {
            return Err(TreeBoostError::Config(format!(
                "era_indices length {} doesn't match num_rows {}",
                era_indices.len(),
                num_rows
            )));
        }

        if !config.era_splitting {
            return Err(TreeBoostError::Config(
                "era_splitting must be enabled in config when using train_with_eras".to_string(),
            ));
        }

        // Create binner
        let binner = QuantileBinner::new(config.num_bins);

        // Parallel binning: process each feature column in parallel
        let binned_results: Vec<(Vec<u8>, FeatureInfo)> = (0..num_features)
            .into_par_iter()
            .map(|f| {
                // Extract column (row-major to column values)
                let column: Vec<f64> = (0..num_rows)
                    .map(|r| features[r * num_features + f] as f64)
                    .collect();

                // Compute boundaries and bin
                let boundaries = binner.compute_boundaries(&column);
                let binned = binner.bin_column(&column, &boundaries);

                // Create feature info
                let name = feature_names
                    .as_ref()
                    .and_then(|names| names.get(f).cloned())
                    .unwrap_or_else(|| format!("feature_{}", f));

                let info = FeatureInfo {
                    name,
                    feature_type: FeatureType::Numeric,
                    num_bins: (boundaries.len() + 1).min(255) as u8,
                    bin_boundaries: boundaries,
                };

                (binned, info)
            })
            .collect();

        // Combine results into column-major storage
        let mut binned_data = Vec::with_capacity(num_rows * num_features);
        let mut feature_info = Vec::with_capacity(num_features);

        for (binned_col, info) in binned_results {
            binned_data.extend(binned_col);
            feature_info.push(info);
        }

        // Create BinnedDataset with era indices
        let mut dataset = BinnedDataset::new(num_rows, binned_data, targets.to_vec(), feature_info);
        dataset.set_era_indices(era_indices.to_vec());

        Self::train_binned(&dataset, config)
    }

    /// Train a GBDT model from pre-binned data (low-level API)
    ///
    /// Use this when you have already binned your data (e.g., for repeated training
    /// with different hyperparameters on the same binned dataset).
    ///
    /// For most use cases, prefer `train()` which handles binning automatically.
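    ///
    /// # Example
    /// ```ignore
    /// // Illustrative sketch: bin once, then sweep hyperparameters on the same
    /// // dataset (`binned_data` etc. are placeholders for pre-binned inputs).
    /// let dataset = BinnedDataset::new(num_rows, binned_data, targets, feature_info);
    /// for rounds in [100, 500] {
    ///     let config = GBDTConfig::new().with_num_rounds(rounds);
    ///     let model = GBDTModel::train_binned(&dataset, config)?;
    /// }
    /// ```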
    pub fn train_binned(dataset: &BinnedDataset, config: GBDTConfig) -> Result<Self> {
        // Dispatch to multi-class training if using multi-class loss
        if let Some(num_classes) = config.loss_type.num_classes() {
            return Self::train_binned_multiclass(dataset, config, num_classes);
        }

        config.validate().map_err(TreeBoostError::Config)?;

        let loss_fn = config.loss_type.create();
        let targets = dataset.targets();

        // Split data for validation (early stopping) and conformal calibration
        let split = split_holdout(
            dataset.num_rows(),
            config.validation_ratio,
            config.calibration_ratio,
            config.seed,
        );
        let (train_indices, validation_indices, calibration_indices) =
            (split.train, split.validation, split.calibration);

        // Compute base prediction from training data only
        let train_targets: Vec<f32> = train_indices.iter().map(|&i| targets[i]).collect();
        let base_prediction = loss_fn.initial_prediction(&train_targets);

        // Initialize predictions for all rows
        let mut predictions = vec![base_prediction; dataset.num_rows()];

        // Gradient and hessian buffers
        let mut gradients = vec![0.0f32; dataset.num_rows()];
        let mut hessians = vec![0.0f32; dataset.num_rows()];

        // Build interaction constraints from groups
        let interaction_constraints = if config.interaction_groups.is_empty() {
            InteractionConstraints::new()
        } else {
            InteractionConstraints::from_groups(
                config.interaction_groups.clone(),
                dataset.num_features(),
            )
        };

        // Create tree grower
        let tree_grower = TreeGrower::new()
            .with_max_depth(config.max_depth)
            .with_max_leaves(config.max_leaves)
            .with_lambda(config.lambda)
            .with_min_samples_leaf(config.min_samples_leaf)
            .with_min_hessian_leaf(config.min_hessian_leaf)
            .with_entropy_weight(config.entropy_weight)
            .with_min_gain(config.min_gain)
            .with_learning_rate(config.learning_rate)
            .with_colsample(config.colsample)
            .with_monotonic_constraints(config.monotonic_constraints.clone())
            .with_interaction_constraints(interaction_constraints)
            .with_backend(config.backend_type)
            .with_gpu_subgroups(config.use_gpu_subgroups)
            .with_era_splitting(config.era_splitting);

        let mut trees = Vec::with_capacity(config.num_rounds);
        // Seed the subsampling RNG from the configured seed for reproducibility
        // (mirrors `train_binned_with_validation`; previously hardcoded to 42).
        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);

        // Early stopping state
        let early_stopping_enabled =
            config.early_stopping_rounds > 0 && !validation_indices.is_empty();
        let mut best_val_loss = f32::MAX;
        let mut rounds_without_improvement = 0;
        let mut best_num_trees = 0;

        // Pre-allocate reusable buffers for subsampling (avoid per-round allocation)
        let mut sample_indices: Vec<usize> = Vec::with_capacity(train_indices.len());
        let mut shuffle_buffer: Vec<usize> = if config.subsample < 1.0 && !config.goss_enabled {
            train_indices.clone() // Pre-allocate for random subsampling
        } else {
            Vec::new()
        };
        let mut goss_indexed: Vec<(usize, f32)> = if config.goss_enabled {
            Vec::with_capacity(train_indices.len())
        } else {
            Vec::new()
        };

        // Determine if we can use fused gradient+histogram (no subsampling)
        let use_fused = !config.goss_enabled && config.subsample >= 1.0;

        // Create Full GPU builders if applicable
        // For BackendType::Auto, we try CUDA first, then WGPU
        #[cfg(feature = "cuda")]
        let mut cuda_builder: Option<FullCudaTreeBuilder> =
            if use_fused && matches!(config.backend_type, BackendType::Cuda | BackendType::Auto) {
                use crate::backend::cuda::CudaDevice;
                CudaDevice::new().and_then(|d| {
                    // Resolve gpu_mode knowing we have CUDA available
                    let resolved = config.gpu_mode.resolve(BackendType::Cuda);
                    if matches!(resolved, GpuMode::Full) {
                        Some(FullCudaTreeBuilder::new(std::sync::Arc::new(d)))
                    } else {
                        None
                    }
                })
            } else {
                None
            };

        #[cfg(feature = "gpu")]
        let mut wgpu_builder: Option<FullGpuTreeBuilder> = if use_fused
            && matches!(config.backend_type, BackendType::Wgpu | BackendType::Auto)
            && {
                #[cfg(feature = "cuda")]
                {
                    cuda_builder.is_none() // Only use WGPU if CUDA not available/chosen
                }
                #[cfg(not(feature = "cuda"))]
                {
                    true
                }
            } {
            use crate::backend::wgpu::GpuDevice;
            GpuDevice::new().and_then(|d| {
                // Resolve gpu_mode knowing we have WGPU
                let resolved = config.gpu_mode.resolve(BackendType::Wgpu);
                if matches!(resolved, GpuMode::Full) {
                    Some(FullGpuTreeBuilder::new(std::sync::Arc::new(d)))
                } else {
                    None
                }
            })
        } else {
            None
        };

        for _round in 0..config.num_rounds {
            // Grow tree - either fused, Full GPU, or separate gradient+histogram paths
            #[allow(unused_mut, unused_assignments)]
            let mut tree: Option<Tree> = None;

            // Try Full GPU builders first (level-wise growth, all-GPU pipeline)
            #[cfg(feature = "cuda")]
            if tree.is_none() {
                if let Some(ref mut builder) = cuda_builder {
                    // Compute gradients for this round
                    for &idx in &train_indices {
                        let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
                        gradients[idx] = g;
                        hessians[idx] = h;
                    }
                    tree = Some(builder.build_tree(
                        dataset,
                        &gradients,
                        &hessians,
                        &train_indices,
                        config.max_depth,
                        config.max_leaves,
                        config.lambda,
                        config.min_samples_leaf,
                        config.min_hessian_leaf,
                        config.min_gain,
                        config.learning_rate,
                    ));
                }
            }

            #[cfg(feature = "gpu")]
            if tree.is_none() {
                if let Some(ref mut builder) = wgpu_builder {
                    // Compute gradients for this round
                    for &idx in &train_indices {
                        let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
                        gradients[idx] = g;
                        hessians[idx] = h;
                    }
                    tree = Some(builder.build_tree(
                        dataset,
                        &gradients,
                        &hessians,
                        &train_indices,
                        config.max_depth,
                        config.max_leaves,
                        config.lambda,
                        config.min_samples_leaf,
                        config.min_hessian_leaf,
                        config.min_gain,
                        config.learning_rate,
                    ));
                }
            }

            // Fall back to TreeGrower (Hybrid mode or CPU)
            let tree = tree.unwrap_or_else(|| {
                if use_fused {
                    // FUSED PATH: Compute gradients AND build root histogram in single pass
                    // This eliminates cache pollution for ~40-80% speedup on large datasets
                    tree_grower.grow_fused(
                        dataset,
                        &train_indices,
                        targets,
                        &predictions,
                        loss_fn.as_ref(),
                        &mut gradients,
                        &mut hessians,
                    )
                } else {
                    // SEPARATE PATH: Compute gradients first, then build histogram
                    // Required for GOSS (needs all gradients for sampling) and random subsampling

                    // Compute gradients and hessians
                    if config.parallel_gradient {
                        train_indices.par_iter().for_each(|&idx| {
                            let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
                            // SAFETY: Each idx is unique within train_indices, so no data races
                            unsafe {
                                let grad_ptr = gradients.as_ptr() as *mut f32;
                                let hess_ptr = hessians.as_ptr() as *mut f32;
                                *grad_ptr.add(idx) = g;
                                *hess_ptr.add(idx) = h;
                            }
                        });
                    } else {
                        for &idx in &train_indices {
                            let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
                            gradients[idx] = g;
                            hessians[idx] = h;
                        }
                    }

                    // GOSS or random subsampling
                    let tree_indices: &[usize] = if config.goss_enabled {
                        // GOSS: Gradient-based One-Side Sampling
                        sample_indices.clear();
                        Self::goss_sample_into(
                            &train_indices,
                            &mut gradients,
                            &mut hessians,
                            config.goss_top_rate,
                            config.goss_other_rate,
                            &mut rng,
                            &mut goss_indexed,
                            &mut sample_indices,
                        );
                        &sample_indices
                    } else if config.subsample < 1.0 {
                        // Random subsampling (Stochastic Gradient Boosting)
                        sample_indices.clear();
                        let n_samples =
                            ((train_indices.len() as f32) * config.subsample).ceil() as usize;
                        shuffle_buffer.shuffle(&mut rng);
                        sample_indices.extend_from_slice(&shuffle_buffer[..n_samples]);
                        &sample_indices
                    } else {
                        &train_indices
                    };

                    // Grow tree using the selected training indices
                    tree_grower.grow_with_indices(dataset, &gradients, &hessians, tree_indices)
                }
            });

            // Update predictions using tree-wise batch prediction
            // This is more cache-friendly than row-wise and avoids intermediate allocation
            tree.predict_batch_add(dataset, &mut predictions);

            trees.push(tree);

            // Check for early stopping on validation set
            if early_stopping_enabled {
                // Compute validation loss using the ACTUAL loss function
                // This ensures correct early stopping for classification (log loss) and regression (MSE/Huber)
                let val_loss: f32 = if validation_indices.len() >= 10000 {
                    validation_indices
                        .par_iter()
                        .map(|&idx| loss_fn.loss(targets[idx], predictions[idx]))
                        .sum::<f32>()
                } else {
                    validation_indices
                        .iter()
                        .map(|&idx| loss_fn.loss(targets[idx], predictions[idx]))
                        .sum::<f32>()
                } / validation_indices.len() as f32;

                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    best_num_trees = trees.len();
                    rounds_without_improvement = 0;
                } else {
                    rounds_without_improvement += 1;
                    if should_early_stop(
                        rounds_without_improvement,
                        trees.len(),
                        config.early_stopping_rounds,
                        config.min_early_stopping_trees,
                    ) {
                        trees.truncate(early_stop_keep_count(
                            best_num_trees,
                            config.min_early_stopping_trees,
                        ));
                        break;
                    }
                }
            }
        }

        // If early stopping was used but we finished all rounds, still check if we should truncate
        if early_stopping_enabled && best_num_trees > 0 && best_num_trees < trees.len() {
            trees.truncate(early_stop_keep_count(
                best_num_trees,
                config.min_early_stopping_trees,
            ));
        }

        // Auto-apply column reordering by feature importance if enabled
        let column_permutation = if config.column_reordering && !trees.is_empty() {
            let importances = Self::compute_importances_from_trees(&trees, dataset.num_features());
            Some(ColumnPermutation::from_importances(&importances))
        } else {
            None
        };

        // Compute conformal quantile if calibration set exists
        let conformal_q = if !calibration_indices.is_empty() {
            let calib_residuals: Vec<f32> = if calibration_indices.len() >= 10000 {
                calibration_indices
                    .par_iter()
                    .map(|&idx| (targets[idx] - predictions[idx]).abs())
                    .collect()
            } else {
                calibration_indices
                    .iter()
                    .map(|&idx| (targets[idx] - predictions[idx]).abs())
                    .collect()
            };

            Some(Self::compute_quantile(
                &calib_residuals,
                config.conformal_quantile,
            ))
        } else {
            None
        };

        Ok(Self {
            config,
            base_prediction,
            base_predictions_multiclass: Vec::new(),
            trees,
            num_classes: 0,
            conformal_q,
            feature_info: dataset.all_feature_info().to_vec(),
            column_permutation,
        })
    }

    /// Train a GBDT model with an external validation set for early stopping
    ///
    /// Use this when you have a separate validation set that was properly prepared
    /// (e.g., encoded separately to avoid target leakage). The external validation
    /// set is used for early stopping decisions while training uses the full train set.
    ///
    /// # Note on Implementation
    /// This method shares ~90% of logic with `train_binned()`. The duplication is
    /// intentional due to the complexity of the training loop (CUDA/WGPU backends,
    /// fused gradient paths, GOSS sampling). Extracting shared logic would require
    /// significant refactoring and testing to maintain correctness.
    ///
    /// Key differences from `train_binned()`:
    /// - Uses external validation set instead of internal split
    /// - Maintains separate validation predictions array
    /// - No calibration set for conformal prediction
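    ///
    /// # Example
    /// ```ignore
    /// // Illustrative: `train_ds` and `val_ds` are placeholder names for datasets
    /// // binned with the same bin boundaries; `val_targets` holds raw validation targets.
    /// let model = GBDTModel::train_binned_with_validation(
    ///     &train_ds, &val_ds, &val_targets, config,
    /// )?;
    /// ```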
    pub fn train_binned_with_validation(
        train_dataset: &BinnedDataset,
        val_dataset: &BinnedDataset,
        val_targets: &[f32],
        config: GBDTConfig,
    ) -> Result<Self> {
        config.validate().map_err(TreeBoostError::Config)?;

        let loss_fn = config.loss_type.create();
        let targets = train_dataset.targets();

        // Use ALL training data (no internal split)
        let train_indices: Vec<usize> = (0..train_dataset.num_rows()).collect();

        // Compute base prediction from training data
        let base_prediction = loss_fn.initial_prediction(targets);

        // Predictions for training and validation
        let mut predictions = vec![base_prediction; train_dataset.num_rows()];
        let mut val_predictions = vec![base_prediction; val_dataset.num_rows()];

        // Gradient and hessian buffers
        let mut gradients = vec![0.0f32; train_dataset.num_rows()];
        let mut hessians = vec![0.0f32; train_dataset.num_rows()];

        // Build interaction constraints
        let interaction_constraints = if config.interaction_groups.is_empty() {
            InteractionConstraints::new()
        } else {
            InteractionConstraints::from_groups(
                config.interaction_groups.clone(),
                train_dataset.num_features(),
            )
        };

        // Create tree grower
        let tree_grower = TreeGrower::new()
            .with_max_depth(config.max_depth)
            .with_max_leaves(config.max_leaves)
            .with_lambda(config.lambda)
            .with_min_samples_leaf(config.min_samples_leaf)
            .with_min_hessian_leaf(config.min_hessian_leaf)
            .with_entropy_weight(config.entropy_weight)
            .with_min_gain(config.min_gain)
            .with_learning_rate(config.learning_rate)
            .with_colsample(config.colsample)
            .with_monotonic_constraints(config.monotonic_constraints.clone())
            .with_interaction_constraints(interaction_constraints)
            .with_backend(config.backend_type)
            .with_gpu_subgroups(config.use_gpu_subgroups)
            .with_era_splitting(config.era_splitting);

        let mut trees = Vec::with_capacity(config.num_rounds);
        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);

        // Early stopping with external validation
        let early_stopping_enabled = config.early_stopping_rounds > 0;
        let mut best_val_loss = f32::MAX;
        let mut rounds_without_improvement = 0;
        let mut best_num_trees = 0;

        // Pre-allocate buffers for subsampling
        let mut sample_indices: Vec<usize> = Vec::with_capacity(train_indices.len());
        let mut shuffle_buffer: Vec<usize> = if config.subsample < 1.0 && !config.goss_enabled {
            train_indices.clone()
        } else {
            Vec::new()
        };
        let mut goss_indexed: Vec<(usize, f32)> = if config.goss_enabled {
            Vec::with_capacity(train_indices.len())
        } else {
            Vec::new()
        };

        let use_fused = !config.goss_enabled && config.subsample >= 1.0;

        for _round in 0..config.num_rounds {
            // Grow tree using same logic as train_binned
            let tree = if use_fused {
                tree_grower.grow_fused(
                    train_dataset,
                    &train_indices,
                    targets,
                    &predictions,
                    loss_fn.as_ref(),
                    &mut gradients,
                    &mut hessians,
                )
            } else {
                // Compute gradients
                for &idx in &train_indices {
                    let (g, h) = loss_fn.gradient_hessian(targets[idx], predictions[idx]);
                    gradients[idx] = g;
                    hessians[idx] = h;
                }

                // Handle subsampling
                let tree_indices: &[usize] = if config.goss_enabled {
                    sample_indices.clear();
                    Self::goss_sample_into(
                        &train_indices,
                        &mut gradients,
                        &mut hessians,
                        config.goss_top_rate,
                        config.goss_other_rate,
                        &mut rng,
                        &mut goss_indexed,
                        &mut sample_indices,
                    );
                    &sample_indices
                } else if config.subsample < 1.0 {
                    sample_indices.clear();
                    let n_samples =
                        ((train_indices.len() as f32) * config.subsample).ceil() as usize;
                    shuffle_buffer.shuffle(&mut rng);
                    sample_indices.extend_from_slice(&shuffle_buffer[..n_samples]);
                    &sample_indices
                } else {
                    &train_indices
                };

                tree_grower.grow_with_indices(train_dataset, &gradients, &hessians, tree_indices)
            };

            // Update training predictions
            tree.predict_batch_add(train_dataset, &mut predictions);

            // Update validation predictions using external val_dataset
            for (i, pred) in val_predictions.iter_mut().enumerate() {
                *pred += tree.predict_row(val_dataset, i);
            }

            trees.push(tree);

            // Early stopping check on EXTERNAL validation set
            // Use the actual loss function, not MSE
            if early_stopping_enabled {
                let val_loss: f32 = val_targets
                    .iter()
                    .zip(val_predictions.iter())
                    .map(|(&target, &pred)| loss_fn.loss(target, pred))
                    .sum::<f32>()
                    / val_targets.len() as f32;

                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    best_num_trees = trees.len();
                    rounds_without_improvement = 0;
                } else {
                    rounds_without_improvement += 1;
                    if should_early_stop(
                        rounds_without_improvement,
                        trees.len(),
                        config.early_stopping_rounds,
                        config.min_early_stopping_trees,
                    ) {
                        trees.truncate(early_stop_keep_count(
                            best_num_trees,
                            config.min_early_stopping_trees,
                        ));
                        break;
                    }
                }
            }
        }

        // Truncate if we finished all rounds but best was earlier
        if early_stopping_enabled && best_num_trees > 0 && best_num_trees < trees.len() {
            trees.truncate(early_stop_keep_count(
                best_num_trees,
                config.min_early_stopping_trees,
            ));
        }

        // Column reordering
        let column_permutation = if config.column_reordering && !trees.is_empty() {
            let importances =
                Self::compute_importances_from_trees(&trees, train_dataset.num_features());
            Some(ColumnPermutation::from_importances(&importances))
        } else {
            None
        };

        // Compute conformal quantile from validation set residuals
        let conformal_q = if !val_targets.is_empty() {
            let residuals: Vec<f32> = val_targets
                .iter()
                .zip(val_predictions.iter())
                .map(|(&target, &pred)| (target - pred).abs())
                .collect();
            Some(Self::compute_quantile(
                &residuals,
                config.conformal_quantile,
            ))
        } else {
            None
        };

        Ok(Self {
            config,
            base_prediction,
            base_predictions_multiclass: Vec::new(),
            trees,
            num_classes: 0,
            conformal_q,
            feature_info: train_dataset.all_feature_info().to_vec(),
            column_permutation,
        })
    }

    /// Train a multi-class classification model from pre-binned data
    ///
    /// This trains K trees per round (one per class) and combines predictions
    /// via softmax for final class probabilities.
    fn train_binned_multiclass(
        dataset: &BinnedDataset,
        config: GBDTConfig,
        num_classes: usize,
    ) -> Result<Self> {
        config.validate().map_err(TreeBoostError::Config)?;

        let targets = dataset.targets();
        let multiclass_loss = MultiClassLogLoss::new(num_classes);

        // Split data for validation and calibration
        let split = split_holdout(
            dataset.num_rows(),
            config.validation_ratio,
            config.calibration_ratio,
            config.seed,
        );
        let (train_indices, validation_indices, _calibration_indices) =
            (split.train, split.validation, split.calibration);

        // Compute initial predictions per class from training data
        let train_targets: Vec<f32> = train_indices.iter().map(|&i| targets[i]).collect();
        let base_predictions = multiclass_loss.initial_predictions(&train_targets);

        // Initialize predictions for all rows: predictions[row * num_classes + class]
        let num_rows = dataset.num_rows();
        let mut predictions: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
        for _ in 0..num_rows {
            predictions.extend_from_slice(&base_predictions);
        }

        // Gradient and hessian buffers (per sample, used for one class at a time)
        let mut gradients = vec![0.0f32; num_rows];
        let mut hessians = vec![0.0f32; num_rows];

        // Build interaction constraints
        let interaction_constraints = if config.interaction_groups.is_empty() {
            InteractionConstraints::new()
        } else {
            InteractionConstraints::from_groups(
                config.interaction_groups.clone(),
                dataset.num_features(),
            )
        };

        // Create tree grower
        let tree_grower = TreeGrower::new()
            .with_max_depth(config.max_depth)
            .with_max_leaves(config.max_leaves)
            .with_lambda(config.lambda)
            .with_min_samples_leaf(config.min_samples_leaf)
            .with_min_hessian_leaf(config.min_hessian_leaf)
            .with_entropy_weight(config.entropy_weight)
            .with_min_gain(config.min_gain)
            .with_learning_rate(config.learning_rate)
            .with_colsample(config.colsample)
            .with_monotonic_constraints(config.monotonic_constraints.clone())
            .with_interaction_constraints(interaction_constraints)
            .with_backend(config.backend_type)
            .with_gpu_subgroups(config.use_gpu_subgroups)
            .with_era_splitting(config.era_splitting);

        // Trees stored as: [round0_class0, round0_class1, ..., round0_classK, round1_class0, ...]
        let mut trees = Vec::with_capacity(config.num_rounds * num_classes);

        // Early stopping state
        let early_stopping_enabled =
            config.early_stopping_rounds > 0 && !validation_indices.is_empty();
        let mut best_val_loss = f32::MAX;
        let mut rounds_without_improvement = 0;
        let mut best_num_rounds = 0;

        for round in 0..config.num_rounds {
            // Train K trees for this round (one per class)
            for class_idx in 0..num_classes {
                // Compute gradients and hessians for this class using batch method
                multiclass_loss.compute_gradients_batch(
                    class_idx,
                    targets,
                    &predictions,
                    &train_indices,
                    &mut gradients,
                    &mut hessians,
                );

                // Grow tree for this class
                let tree =
                    tree_grower.grow_with_indices(dataset, &gradients, &hessians, &train_indices);

                // Update predictions for this class
                for idx in 0..num_rows {
                    let delta = tree.predict(|f| dataset.get_bin(idx, f));
                    predictions[idx * num_classes + class_idx] += delta;
                }

                trees.push(tree);
            }

            // Early stopping check on validation set
            if early_stopping_enabled {
                // Compute multi-class log loss on validation set
                let mut val_loss = 0.0f32;
                for &idx in &validation_indices {
                    let target_class = targets[idx] as usize;
                    let row_preds = &predictions[idx * num_classes..(idx + 1) * num_classes];
                    let probs = softmax(row_preds);
                    // Negative log likelihood for true class
                    val_loss -= probs[target_class].max(1e-15).ln();
                }
                val_loss /= validation_indices.len() as f32;

                if val_loss < best_val_loss {
                    best_val_loss = val_loss;
                    best_num_rounds = round + 1;
                    rounds_without_improvement = 0;
                } else {
                    rounds_without_improvement += 1;
                    // Use actual tree count (not round count) for consistency with binary/regression
                    if should_early_stop(
                        rounds_without_improvement,
                        trees.len(),
                        config.early_stopping_rounds,
                        config.min_early_stopping_trees,
                    ) {
                        let keep_rounds = early_stop_keep_count(
                            best_num_rounds,
                            config.min_early_stopping_trees / num_classes.max(1),
                        );
                        trees.truncate(keep_rounds * num_classes);
                        break;
                    }
                }
            }
        }

        // Truncate if early stopping finished all rounds but best was earlier
        if early_stopping_enabled
            && best_num_rounds > 0
            && best_num_rounds * num_classes < trees.len()
        {
            let keep_rounds = early_stop_keep_count(
                best_num_rounds,
                config.min_early_stopping_trees / num_classes.max(1),
            );
            trees.truncate(keep_rounds * num_classes);
        }

        // Compute column permutation if enabled
        let column_permutation = if config.column_reordering && !trees.is_empty() {
            let importances = Self::compute_importances_from_trees(&trees, dataset.num_features());
            Some(ColumnPermutation::from_importances(&importances))
        } else {
            None
        };

        Ok(Self {
            config,
            base_prediction: 0.0, // Not used for multi-class
            base_predictions_multiclass: base_predictions,
            trees,
            num_classes,
            conformal_q: None, // Conformal not supported for multi-class yet
            feature_info: dataset.all_feature_info().to_vec(),
            column_permutation,
        })
    }

    /// Compute feature importances from a collection of trees (internal helper)
    ///
    /// Importance is the sum of hessians ("cover") at each feature's split
    /// nodes, accumulated across trees and normalized to sum to 1.
    fn compute_importances_from_trees(trees: &[Tree], num_features: usize) -> Vec<f32> {
        let mut importances = vec![0.0f32; num_features];

        for tree in trees {
            for (_, node) in tree.internal_nodes() {
                if let Some((feature_idx, _, _, _, _)) = node.split_info() {
                    importances[feature_idx] += node.sum_hessians;
                }
            }
        }

        // Normalize
        let total: f32 = importances.iter().sum();
        if total > 0.0 {
            for imp in &mut importances {
                *imp /= total;
            }
        }

        importances
    }

    /// GOSS (Gradient-based One-Side Sampling) with buffer reuse
    ///
    /// Selects samples based on gradient magnitude:
    /// 1. Keep all top `top_rate` samples with largest |gradient|
    /// 2. Randomly sample `other_rate` from the remaining samples
    /// 3. Apply weight correction `(1 - top_rate) / other_rate` to sampled small-gradient samples
    ///
    /// Weight correction is applied in-place to gradients and hessians.
    /// Uses partial sorting (`select_nth_unstable_by`) for O(n) instead of O(n log n).
    ///
    /// This version reuses pre-allocated buffers to avoid per-round allocation.
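    ///
    /// Worked example: with `top_rate = 0.2` and `other_rate = 0.1`, the top 20%
    /// of samples by |gradient| are kept unweighted, a random 10% of the remainder
    /// is sampled, and that subset's gradients and hessians are scaled by
    /// `(1.0 - 0.2) / 0.1 = 8.0` so the histogram sums stay approximately unbiased.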
    #[allow(clippy::too_many_arguments)]
    fn goss_sample_into(
        train_indices: &[usize],
        gradients: &mut [f32],
        hessians: &mut [f32],
        top_rate: f32,
        other_rate: f32,
        rng: &mut rand::rngs::StdRng,
        indexed_buffer: &mut Vec<(usize, f32)>,
        result: &mut Vec<usize>,
    ) {
        let n = train_indices.len();
        if n == 0 {
            return;
        }

        // Number of top-gradient samples to keep
        let n_top = ((n as f32) * top_rate).ceil() as usize;
        let n_top = n_top.min(n);
        // Number of other samples to randomly select
        let n_other = ((n as f32) * other_rate).ceil() as usize;

        // Reuse indexed buffer - clear and repopulate
        indexed_buffer.clear();
        indexed_buffer.extend(train_indices.iter().map(|&idx| (idx, gradients[idx].abs())));

        // Partition around the n_top-th largest element (descending order)
        if n_top < n {
            indexed_buffer.select_nth_unstable_by(n_top, |a, b| {
                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
            });
        }

        // Add top n_top samples (large gradients) - no weight modification needed
        result.extend(indexed_buffer[..n_top].iter().map(|(idx, _)| *idx));

        // For small gradients: shuffle the rest portion in-place and take n_other
        let rest_slice = &mut indexed_buffer[n_top..];
        rest_slice.shuffle(rng);
        let n_rest = rest_slice.len().min(n_other);

        // Weight correction factor for small-gradient samples
        let weight = (1.0 - top_rate) / other_rate;

        // Apply weight correction and add to result
        for &(idx, _) in &rest_slice[..n_rest] {
            gradients[idx] *= weight;
            hessians[idx] *= weight;
            result.push(idx);
        }
    }

    /// Compute the `q`-quantile of a slice (the input need not be sorted;
    /// NaN values are filtered out before sorting)
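    ///
    /// For example, with 100 residuals and `q = 0.9`, the index is
    /// `ceil(100 * 0.9) = 90`, clamped to the last element, so the
    /// 91st-smallest value is returned.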
1258    fn compute_quantile(values: &[f32], q: f32) -> f32 {
1259        if values.is_empty() {
1260            return 0.0;
1261        }
1262
1263        // Filter out NaN values before sorting
1264        let mut sorted: Vec<f32> = values.iter().copied().filter(|v| !v.is_nan()).collect();
1265
1266        if sorted.is_empty() {
1267            return 0.0;
1268        }
1269
1270        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1271
1272        let idx = ((sorted.len() as f32) * q).ceil() as usize;
1273        let idx = idx.min(sorted.len() - 1);
1274        sorted[idx]
1275    }
1276
1277    /// Predict for a single row
1278    pub fn predict_row(&self, dataset: &BinnedDataset, row_idx: usize) -> f32 {
1279        let mut pred = self.base_prediction;
1280        for tree in &self.trees {
1281            pred += tree.predict_row(dataset, row_idx);
1282        }
1283        pred
1284    }
1285
1286    /// Predict for all rows using tree-wise batch prediction
1287    ///
1288    /// This approach traverses one tree for ALL rows before moving to the next tree,
1289    /// which is more cache-friendly than row-wise traversal.
1290    ///
1291    /// Routes to parallel or sequential based on config.parallel_prediction
1292    pub fn predict(&self, dataset: &BinnedDataset) -> Vec<f32> {
1293        if self.config.parallel_prediction {
1294            self.predict_parallel(dataset)
1295        } else {
1296            self.predict_sequential(dataset)
1297        }
1298    }
1299
1300    /// Single-threaded tree-wise batch prediction
1301    ///
1302    /// Traverses each tree for all rows before moving to the next tree.
1303    /// More cache-friendly than row-wise traversal.
1304    pub fn predict_sequential(&self, dataset: &BinnedDataset) -> Vec<f32> {
1305        let num_rows = dataset.num_rows();
1306
1307        // Initialize predictions with base value
1308        let mut predictions = vec![self.base_prediction; num_rows];
1309
1310        // Tree-wise: traverse each tree for all rows
1311        for tree in &self.trees {
1312            tree.predict_batch_add(dataset, &mut predictions);
1313        }
1314
1315        predictions
1316    }
1317
1318    /// Parallel tree-wise batch prediction
1319    ///
1320    /// Splits rows into chunks and processes each chunk in parallel.
1321    /// Each chunk uses tree-wise traversal internally.
1322    pub fn predict_parallel(&self, dataset: &BinnedDataset) -> Vec<f32> {
1323        let num_rows = dataset.num_rows();
1324
1325        // For small datasets, use sequential
1326        if num_rows < 1000 || self.trees.is_empty() {
1327            return self.predict_sequential(dataset);
1328        }
1329
1330        // Initialize predictions with base value
1331        let mut predictions = vec![self.base_prediction; num_rows];
1332
1333        // Determine chunk size for parallelism (target ~4 chunks per thread)
1334        let num_threads = rayon::current_num_threads();
1335        let chunk_size = (num_rows / (num_threads * 4)).max(256);
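        // e.g. 100_000 rows on 8 threads: 100_000 / (8 * 4) = 3_125 rows per chunk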
1336
1337        // Process chunks in parallel, each chunk does tree-wise traversal
1338        predictions
1339            .par_chunks_mut(chunk_size)
1340            .enumerate()
1341            .for_each(|(chunk_idx, chunk)| {
1342                let start_row = chunk_idx * chunk_size;
1343
1344                // For each tree, process this chunk of rows
1345                for tree in &self.trees {
1346                    for (i, pred) in chunk.iter_mut().enumerate() {
1347                        let row_idx = start_row + i;
1348                        *pred += tree.predict(|f| dataset.get_bin(row_idx, f));
1349                    }
1350                }
1351            });
1352
1353        predictions
1354    }
1355
1356    /// Legacy row-wise prediction (kept for comparison/testing)
1357    #[doc(hidden)]
1358    pub fn predict_row_wise(&self, dataset: &BinnedDataset) -> Vec<f32> {
1359        let num_rows = dataset.num_rows();
1360        let num_features = dataset.num_features();
1361
1362        let mut predictions = Vec::with_capacity(num_rows);
1363        let mut row_bins = vec![0u8; num_features];
1364
1365        for row_idx in 0..num_rows {
1366            // Cache all bins for this row
1367            for (f, bin) in row_bins.iter_mut().enumerate() {
1368                *bin = dataset.get_bin(row_idx, f);
1369            }
1370
1371            // Traverse all trees with cached bins
1372            let mut pred = self.base_prediction;
1373            for tree in &self.trees {
1374                pred += tree.predict(|f| row_bins[f]);
1375            }
1376            predictions.push(pred);
1377        }
1378
1379        predictions
1380    }
1381
1382    /// Predict with conformal intervals
1383    ///
1384    /// Returns (predictions, lower_bounds, upper_bounds). If the model was not
    /// calibrated, the conformal quantile defaults to 0 and both bounds equal the predictions.
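    ///
    /// A minimal sketch, assuming a model trained with conformal calibration:
    ///
    /// ```ignore
    /// let (preds, lower, upper) = model.predict_with_intervals(&dataset);
    /// // with conformal quantile q, each interval is [pred - q, pred + q]
    /// assert!(lower[0] <= preds[0] && preds[0] <= upper[0]);
    /// ```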
1385    pub fn predict_with_intervals(
1386        &self,
1387        dataset: &BinnedDataset,
1388    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
1389        let predictions = self.predict(dataset);
1390
1391        let q = self.conformal_q.unwrap_or(0.0);
1392        let lower: Vec<f32> = predictions.iter().map(|&p| p - q).collect();
1393        let upper: Vec<f32> = predictions.iter().map(|&p| p + q).collect();
1394
1395        (predictions, lower, upper)
1396    }
1397
1398    // ============================================================================
1399    // Classification prediction methods
1400    // ============================================================================
1401
1402    /// Predict class probabilities for binary classification
1403    ///
1404    /// Applies sigmoid to raw predictions to get probabilities in [0, 1].
1405    /// Only meaningful when trained with `with_binary_logloss()`.
1406    ///
1407    /// # Returns
1408    /// Vector of probabilities (probability of class 1)
1409    pub fn predict_proba(&self, dataset: &BinnedDataset) -> Vec<f32> {
1410        let raw = self.predict(dataset);
1411        raw.iter().map(|&r| sigmoid(r)).collect()
1412    }
1413
1414    /// Predict class labels for binary classification
1415    ///
1416    /// Applies sigmoid to raw predictions and thresholds at the given value (typically 0.5).
1417    /// Only meaningful when trained with `with_binary_logloss()`.
1418    ///
1419    /// # Arguments
1420    /// * `dataset` - The binned dataset to predict on
1421    /// * `threshold` - Classification threshold (0.5 is the conventional choice)
1422    ///
1423    /// # Returns
1424    /// Vector of class labels (0 or 1)
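    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model trained with binary log-loss:
    ///
    /// ```ignore
    /// let labels = model.predict_class(&dataset, 0.5);
    /// assert!(labels.iter().all(|&l| l == 0 || l == 1));
    /// ```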
1425    pub fn predict_class(&self, dataset: &BinnedDataset, threshold: f32) -> Vec<u32> {
1426        let proba = self.predict_proba(dataset);
1427        proba
1428            .iter()
1429            .map(|&p| if p >= threshold { 1 } else { 0 })
1430            .collect()
1431    }
1432
1433    // ============================================================================
1434    // Multi-class classification prediction methods
1435    // ============================================================================
1436
1437    /// Check if this is a multi-class model
1438    pub fn is_multiclass(&self) -> bool {
1439        self.num_classes > 0
1440    }
1441
1442    /// Get number of classes (0 for regression/binary)
1443    pub fn get_num_classes(&self) -> usize {
1444        self.num_classes
1445    }
1446
1447    /// Predict class probabilities for multi-class classification
1448    ///
1449    /// Applies softmax to raw predictions to get probabilities for each class.
1450    /// Only meaningful when trained with `with_multiclass_logloss()`.
1451    ///
1452    /// # Returns
1453    /// Vector of probability vectors: result[sample][class]
1454    pub fn predict_proba_multiclass(&self, dataset: &BinnedDataset) -> Vec<Vec<f32>> {
1455        if self.num_classes == 0 {
1456            // Not a multi-class model, fall back to binary
1457            return self
1458                .predict_proba(dataset)
1459                .into_iter()
1460                .map(|p| vec![1.0 - p, p])
1461                .collect();
1462        }
1463
1464        let num_rows = dataset.num_rows();
1465        let num_classes = self.num_classes;
1466        let num_rounds = self.trees.len() / num_classes;
1467
1468        // Initialize raw predictions with base values
1469        let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1470        for _ in 0..num_rows {
1471            raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1472        }
1473
1474        // Add tree predictions
1475        // Trees are stored as: [round0_class0, round0_class1, ..., round0_classK, round1_class0, ...]
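        // e.g. with num_classes = 3, the tree for round 1, class 2 sits at
        // index 1 * 3 + 2 = 5.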
1476        for round in 0..num_rounds {
1477            for class_idx in 0..num_classes {
1478                let tree_idx = round * num_classes + class_idx;
1479                let tree = &self.trees[tree_idx];
1480
1481                for row_idx in 0..num_rows {
1482                    let delta = tree.predict(|f| dataset.get_bin(row_idx, f));
1483                    raw_preds[row_idx * num_classes + class_idx] += delta;
1484                }
1485            }
1486        }
1487
1488        // Apply softmax to each row
1489        let mut result = Vec::with_capacity(num_rows);
1490        for row_idx in 0..num_rows {
1491            let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1492            result.push(softmax(row_preds));
1493        }
1494
1495        result
1496    }
1497
1498    /// Predict class labels for multi-class classification
1499    ///
1500    /// Returns the class with highest probability (argmax of softmax).
1501    /// Only meaningful when trained with `with_multiclass_logloss()`.
1502    ///
1503    /// # Returns
1504    /// Vector of class labels (0, 1, 2, ..., K-1)
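    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a multi-class model:
    ///
    /// ```ignore
    /// let labels = model.predict_class_multiclass(&dataset);
    /// assert!(labels.iter().all(|&c| (c as usize) < model.get_num_classes()));
    /// ```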
1505    pub fn predict_class_multiclass(&self, dataset: &BinnedDataset) -> Vec<u32> {
1506        let proba = self.predict_proba_multiclass(dataset);
1507        proba
1508            .iter()
1509            .map(|p| {
1510                p.iter()
1511                    .enumerate()
1512                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1513                    .map(|(idx, _)| idx as u32)
1514                    .unwrap_or(0)
1515            })
1516            .collect()
1517    }
1518
1519    /// Predict raw scores for multi-class classification (before softmax)
1520    ///
1521    /// Returns raw predictions for each class (not probabilities).
1522    /// Shape: result[sample][class]
1523    pub fn predict_raw_multiclass(&self, dataset: &BinnedDataset) -> Vec<Vec<f32>> {
1524        if self.num_classes == 0 {
1525            // Not a multi-class model
1526            return self.predict(dataset).into_iter().map(|p| vec![p]).collect();
1527        }
1528
1529        let num_rows = dataset.num_rows();
1530        let num_classes = self.num_classes;
1531        let num_rounds = self.trees.len() / num_classes;
1532
1533        // Initialize raw predictions with base values
1534        let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1535        for _ in 0..num_rows {
1536            raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1537        }
1538
1539        // Add tree predictions
1540        for round in 0..num_rounds {
1541            for class_idx in 0..num_classes {
1542                let tree_idx = round * num_classes + class_idx;
1543                let tree = &self.trees[tree_idx];
1544
1545                for row_idx in 0..num_rows {
1546                    let delta = tree.predict(|f| dataset.get_bin(row_idx, f));
1547                    raw_preds[row_idx * num_classes + class_idx] += delta;
1548                }
1549            }
1550        }
1551
1552        // Convert to Vec<Vec<f32>>
1553        let mut result = Vec::with_capacity(num_rows);
1554        for row_idx in 0..num_rows {
1555            let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1556            result.push(row_preds.to_vec());
1557        }
1558
1559        result
1560    }
1561
1562    // ============================================================================
1563    // Raw prediction methods (no binning required)
1564    // ============================================================================
1565
1566    /// Predict using raw feature values (no binning needed)
1567    ///
1568    /// This is the primary prediction method for external use (e.g., Python bindings).
1569    /// Uses the split_value stored in tree nodes to compare directly against raw values,
1570    /// avoiding the overhead of binning on every prediction call.
1571    ///
1572    /// # Arguments
1573    /// * `features` - Row-major feature matrix: features[row * num_features + feature]
1574    ///   Shape: (num_rows, num_features)
1575    ///
1576    /// # Returns
1577    /// Vector of predictions for each row
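    ///
    /// # Example
    ///
    /// A minimal sketch of the row-major layout, using hypothetical values for
    /// a model trained on 3 features:
    ///
    /// ```ignore
    /// let features = vec![
    ///     0.1, 0.2, 0.3, // row 0
    ///     0.4, 0.5, 0.6, // row 1
    /// ];
    /// let preds = model.predict_raw(&features);
    /// assert_eq!(preds.len(), 2);
    /// ```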
1578    pub fn predict_raw(&self, features: &[f64]) -> Vec<f32> {
1579        let num_features = self.num_features();
1580        if num_features == 0 {
1581            return vec![];
1582        }
1583
1584        let num_rows = features.len() / num_features;
1585        debug_assert_eq!(features.len(), num_rows * num_features);
1586
1587        if self.config.parallel_prediction && num_rows >= 1000 {
1588            self.predict_raw_parallel(features, num_features)
1589        } else {
1590            self.predict_raw_sequential(features, num_features)
1591        }
1592    }
1593
1594    /// Single-threaded raw prediction using tree-wise traversal
1595    fn predict_raw_sequential(&self, features: &[f64], num_features: usize) -> Vec<f32> {
1596        let num_rows = features.len() / num_features;
1597
1598        // Initialize predictions with base value
1599        let mut predictions = vec![self.base_prediction; num_rows];
1600
1601        // Tree-wise: traverse each tree for all rows
1602        for tree in &self.trees {
1603            tree.predict_batch_add_raw(features, num_features, &mut predictions);
1604        }
1605
1606        predictions
1607    }
1608
1609    /// Parallel raw prediction using tree-wise traversal
1610    fn predict_raw_parallel(&self, features: &[f64], num_features: usize) -> Vec<f32> {
1611        let num_rows = features.len() / num_features;
1612
1613        // For small datasets, use sequential
1614        if num_rows < 1000 || self.trees.is_empty() {
1615            return self.predict_raw_sequential(features, num_features);
1616        }
1617
1618        // Initialize predictions with base value
1619        let mut predictions = vec![self.base_prediction; num_rows];
1620
1621        // Determine chunk size for parallelism
1622        let num_threads = rayon::current_num_threads();
1623        let chunk_size = (num_rows / (num_threads * 4)).max(256);
1624
1625        // Process chunks in parallel
1626        predictions
1627            .par_chunks_mut(chunk_size)
1628            .enumerate()
1629            .for_each(|(chunk_idx, chunk)| {
1630                let start_row = chunk_idx * chunk_size;
1631                let chunk_features_start = start_row * num_features;
1632
1633                // Each thread processes its chunk through all trees
1634                for tree in &self.trees {
1635                    for (i, pred) in chunk.iter_mut().enumerate() {
1636                        let row_offset = chunk_features_start + i * num_features;
1637                        *pred += tree.predict_raw(|f| features[row_offset + f]);
1638                    }
1639                }
1640            });
1641
1642        predictions
1643    }
1644
1645    /// Predict raw with conformal intervals
1646    ///
1647    /// Returns (predictions, lower_bounds, upper_bounds). If the model was not
    /// calibrated, the conformal quantile defaults to 0 and both bounds equal the predictions.
1648    pub fn predict_raw_with_intervals(&self, features: &[f64]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
1649        let predictions = self.predict_raw(features);
1650
1651        let q = self.conformal_q.unwrap_or(0.0);
1652        let lower: Vec<f32> = predictions.iter().map(|&p| p - q).collect();
1653        let upper: Vec<f32> = predictions.iter().map(|&p| p + q).collect();
1654
1655        (predictions, lower, upper)
1656    }
1657
1658    /// Predict class probabilities from raw features (for binary classification)
1659    ///
1660    /// Applies sigmoid to raw predictions to get probabilities in [0, 1].
1661    /// Only meaningful when trained with `with_binary_logloss()`.
1662    pub fn predict_proba_raw(&self, features: &[f64]) -> Vec<f32> {
1663        let raw = self.predict_raw(features);
1664        raw.iter().map(|&r| sigmoid(r)).collect()
1665    }
1666
1667    /// Predict class labels from raw features (for binary classification)
1668    ///
1669    /// Applies sigmoid to raw predictions and thresholds.
1670    /// Only meaningful when trained with `with_binary_logloss()`.
1671    pub fn predict_class_raw(&self, features: &[f64], threshold: f32) -> Vec<u32> {
1672        let proba = self.predict_proba_raw(features);
1673        proba
1674            .iter()
1675            .map(|&p| if p >= threshold { 1 } else { 0 })
1676            .collect()
1677    }
1678
1679    // ============================================================================
1680    // Multi-class raw prediction methods (from raw features, no binning needed)
1681    // ============================================================================
1682
1683    /// Predict class probabilities from raw features (for multi-class classification)
1684    ///
1685    /// Uses the split_value stored in tree nodes to compare directly against raw values.
1686    /// Applies softmax to raw predictions to get probabilities for each class.
1687    /// Only meaningful when trained with `with_multiclass_logloss()`.
1688    ///
1689    /// # Arguments
1690    /// * `features` - Row-major feature matrix: features[row * num_features + feature]
1691    ///
1692    /// # Returns
1693    /// Vector of probability vectors: result[sample][class]
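    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a multi-class model and row-major `features`:
    ///
    /// ```ignore
    /// let proba = model.predict_proba_multiclass_raw(&features);
    /// // each row's probabilities sum to ~1 after softmax
    /// assert!((proba[0].iter().sum::<f32>() - 1.0).abs() < 1e-5);
    /// ```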
1694    pub fn predict_proba_multiclass_raw(&self, features: &[f64]) -> Vec<Vec<f32>> {
1695        if self.num_classes == 0 {
1696            // Not a multi-class model, fall back to binary
1697            return self
1698                .predict_proba_raw(features)
1699                .into_iter()
1700                .map(|p| vec![1.0 - p, p])
1701                .collect();
1702        }
1703
1704        let num_features = self.num_features();
1705        if num_features == 0 {
1706            return vec![];
1707        }
1708
1709        let num_rows = features.len() / num_features;
1710        let num_classes = self.num_classes;
1711        let num_rounds = self.trees.len() / num_classes;
1712
1713        // Initialize raw predictions with base values
1714        let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1715        for _ in 0..num_rows {
1716            raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1717        }
1718
1719        // Add tree predictions
1720        // Trees are stored as: [round0_class0, round0_class1, ..., round0_classK, round1_class0, ...]
1721        for round in 0..num_rounds {
1722            for class_idx in 0..num_classes {
1723                let tree_idx = round * num_classes + class_idx;
1724                let tree = &self.trees[tree_idx];
1725
1726                for row_idx in 0..num_rows {
1727                    let row_offset = row_idx * num_features;
1728                    let delta = tree.predict_raw(|f| features[row_offset + f]);
1729                    raw_preds[row_idx * num_classes + class_idx] += delta;
1730                }
1731            }
1732        }
1733
1734        // Apply softmax to each row
1735        let mut result = Vec::with_capacity(num_rows);
1736        for row_idx in 0..num_rows {
1737            let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1738            result.push(softmax(row_preds));
1739        }
1740
1741        result
1742    }
1743
1744    /// Predict class labels from raw features (for multi-class classification)
1745    ///
1746    /// Returns the class with highest probability (argmax of softmax).
1747    /// Only meaningful when trained with `with_multiclass_logloss()`.
1748    ///
1749    /// # Arguments
1750    /// * `features` - Row-major feature matrix: features[row * num_features + feature]
1751    ///
1752    /// # Returns
1753    /// Vector of class labels (0, 1, 2, ..., K-1)
1754    pub fn predict_class_multiclass_raw(&self, features: &[f64]) -> Vec<u32> {
1755        let proba = self.predict_proba_multiclass_raw(features);
1756        proba
1757            .iter()
1758            .map(|p| {
1759                p.iter()
1760                    .enumerate()
1761                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1762                    .map(|(idx, _)| idx as u32)
1763                    .unwrap_or(0)
1764            })
1765            .collect()
1766    }
1767
1768    /// Predict raw scores from raw features (for multi-class, before softmax)
1769    ///
1770    /// Returns raw predictions for each class (not probabilities).
1771    ///
1772    /// # Arguments
1773    /// * `features` - Row-major feature matrix: features[row * num_features + feature]
1774    ///
1775    /// # Returns
1776    /// Vector of raw score vectors: result[sample][class]
1777    pub fn predict_raw_multiclass_raw(&self, features: &[f64]) -> Vec<Vec<f32>> {
1778        if self.num_classes == 0 {
1779            // Not a multi-class model
1780            return self
1781                .predict_raw(features)
1782                .into_iter()
1783                .map(|p| vec![p])
1784                .collect();
1785        }
1786
1787        let num_features = self.num_features();
1788        if num_features == 0 {
1789            return vec![];
1790        }
1791
1792        let num_rows = features.len() / num_features;
1793        let num_classes = self.num_classes;
1794        let num_rounds = self.trees.len() / num_classes;
1795
1796        // Initialize raw predictions with base values
1797        let mut raw_preds: Vec<f32> = Vec::with_capacity(num_rows * num_classes);
1798        for _ in 0..num_rows {
1799            raw_preds.extend_from_slice(&self.base_predictions_multiclass);
1800        }
1801
1802        // Add tree predictions
1803        for round in 0..num_rounds {
1804            for class_idx in 0..num_classes {
1805                let tree_idx = round * num_classes + class_idx;
1806                let tree = &self.trees[tree_idx];
1807
1808                for row_idx in 0..num_rows {
1809                    let row_offset = row_idx * num_features;
1810                    let delta = tree.predict_raw(|f| features[row_offset + f]);
1811                    raw_preds[row_idx * num_classes + class_idx] += delta;
1812                }
1813            }
1814        }
1815
1816        // Convert to Vec<Vec<f32>>
1817        let mut result = Vec::with_capacity(num_rows);
1818        for row_idx in 0..num_rows {
1819            let row_preds = &raw_preds[row_idx * num_classes..(row_idx + 1) * num_classes];
1820            result.push(row_preds.to_vec());
1821        }
1822
1823        result
1824    }
1825
1826    /// Get number of trees
1827    pub fn num_trees(&self) -> usize {
1828        self.trees.len()
1829    }
1830
1831    /// Get configuration
1832    pub fn config(&self) -> &GBDTConfig {
1833        &self.config
1834    }
1835
1836    /// Get base prediction
1837    pub fn base_prediction(&self) -> f32 {
1838        self.base_prediction
1839    }
1840
1841    /// Get conformal quantile (if calibrated)
1842    pub fn conformal_quantile(&self) -> Option<f32> {
1843        self.conformal_q
1844    }
1845
1846    /// Get trees
1847    pub fn trees(&self) -> &[Tree] {
1848        &self.trees
1849    }
1850
1851    /// Get feature info (for consistent binning during prediction)
1852    pub fn feature_info(&self) -> &[FeatureInfo] {
1853        &self.feature_info
1854    }
1855
1856    /// Get number of features
1857    pub fn num_features(&self) -> usize {
1858        self.feature_info.len()
1859    }
1860
1861    /// Get column permutation (if optimized layout was applied)
1862    pub fn column_permutation(&self) -> Option<&ColumnPermutation> {
1863        self.column_permutation.as_ref()
1864    }
1865
1866    /// Compute feature importance (cover-based: each split contributes its
    /// hessian sum as a proxy for sample weight; scores are normalized to sum to 1)
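    ///
    /// A minimal sketch, assuming a trained `model` with at least one split:
    ///
    /// ```ignore
    /// let importances = model.feature_importance();
    /// // normalized to sum to 1.0 (all zeros if the model has no splits)
    /// assert!((importances.iter().sum::<f32>() - 1.0).abs() < 1e-4);
    /// ```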
1867    pub fn feature_importance(&self) -> Vec<f32> {
1868        let mut importances = vec![0.0f32; self.num_features()];
1869
1870        for tree in &self.trees {
1871            for (_, node) in tree.internal_nodes() {
1872                // Safe to unwrap: internal_nodes() filters to only internal nodes
1873                let (feature_idx, _, _, _, _) = node.split_info().unwrap();
1874                // Use hessian as importance weight (proxy for sample weight)
1875                importances[feature_idx] += node.sum_hessians;
1876            }
1877        }
1878
1879        // Normalize
1880        let total: f32 = importances.iter().sum();
1881        if total > 0.0 {
1882            for imp in &mut importances {
1883                *imp /= total;
1884            }
1885        }
1886
1887        importances
1888    }
1889
1890    /// Create a cache-optimized dataset by reordering columns based on feature importance
1891    ///
1892    /// More frequently used features are placed at the beginning of the dataset
1893    /// for better CPU cache locality during tree traversal.
1894    ///
1895    /// Returns the reordered dataset and the permutation mapping (new_idx -> original_idx)
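    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let (optimized, permutation) = model.optimize_dataset_layout(&dataset);
    /// assert_eq!(permutation.new_to_original().len(), model.num_features());
    /// ```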
1896    pub fn optimize_dataset_layout(
1897        &self,
1898        dataset: &BinnedDataset,
1899    ) -> (BinnedDataset, ColumnPermutation) {
1900        let importances = self.feature_importance();
1901        let permutation = ColumnPermutation::from_importances(&importances);
1902        let optimized = crate::dataset::reorder_dataset(dataset, &permutation);
1903        (optimized, permutation)
1904    }
1905
1906    /// Create a memory-optimized packed dataset from a BinnedDataset
1907    ///
1908    /// Uses 4-bit packing for features with ≤16 unique bins,
1909    /// providing up to 50% memory savings for low-cardinality features.
1910    pub fn create_packed_dataset(&self, dataset: &BinnedDataset) -> crate::dataset::PackedDataset {
1911        crate::dataset::PackedDataset::from_binned(dataset)
1912    }
1913
1914    // =========================================================================
1915    // Incremental Learning Support
1916    // =========================================================================
1917
1918    /// Get number of completed rounds
1919    ///
1920    /// For regression/binary: num_rounds == num_trees
1921    /// For multi-class: num_rounds = num_trees / num_classes
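    ///
    /// e.g. a 3-class model holding 30 trees has completed 10 rounds.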
1922    pub fn num_rounds(&self) -> usize {
1923        if self.num_classes == 0 {
1924            self.trees.len()
1925        } else {
1926            self.trees.len() / self.num_classes
1927        }
1928    }
1929
1930    /// Append new trees to the ensemble (incremental learning)
1931    ///
1932    /// # Arguments
1933    /// * `new_trees` - Trees trained on residuals from current ensemble
1934    ///
1935    /// # Multi-class Note
1936    /// For multi-class classification, trees must be provided in round-major order:
1937    /// `[round_n_class_0, round_n_class_1, ..., round_n_class_k, round_n+1_class_0, ...]`
1938    ///
1939    /// # Example
1940    /// ```ignore
1941    /// // Get current predictions
1942    /// let preds = model.predict(&dataset);
1943    ///
1944    /// // Compute residuals (new_targets - preds)
1945    /// let residuals: Vec<f32> = targets.iter().zip(&preds)
1946    ///     .map(|(t, p)| t - p).collect();
1947    ///
1948    /// // Compute gradients/hessians for new trees
1949    /// // ... train new trees on residuals ...
1950    ///
1951    /// // Append to model
1952    /// model.append_trees(new_trees);
1953    /// ```
1954    pub fn append_trees(&mut self, new_trees: Vec<Tree>) {
1955        self.trees.extend(new_trees);
1956    }
1957
1958    /// Append a single tree to the ensemble
1959    ///
1960    /// Convenience method for adding one tree at a time.
1961    pub fn append_tree(&mut self, tree: Tree) {
1962        self.trees.push(tree);
1963    }
1964
1965    /// Compute residuals from current ensemble predictions
1966    ///
1967    /// Residuals are `target - prediction`, which become the training targets
1968    /// for new trees in incremental learning.
1969    ///
1970    /// # Arguments
1971    /// * `dataset` - Binned dataset for prediction
1972    /// * `targets` - Original target values
1973    ///
1974    /// # Returns
1975    /// Residuals (target - prediction) for each sample
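    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let residuals = model.compute_residuals(&dataset, dataset.targets());
    /// // residuals[i] == targets[i] - predictions[i]
    /// ```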
1976    pub fn compute_residuals(&self, dataset: &BinnedDataset, targets: &[f32]) -> Vec<f32> {
1977        let predictions = self.predict(dataset);
1978        predictions
1979            .iter()
1980            .zip(targets)
1981            .map(|(p, t)| t - p)
1982            .collect()
1983    }
1984
1985    /// Compute residuals from raw feature data
1986    ///
1987    /// # Arguments
1988    /// * `features` - Row-major feature matrix (f64 for raw prediction API)
1989    /// * `targets` - Original target values
1990    ///
1991    /// # Returns
1992    /// Residuals (target - prediction) for each sample
1993    pub fn compute_residuals_raw(&self, features: &[f64], targets: &[f32]) -> Vec<f32> {
1994        let predictions = self.predict_raw(features);
1995        predictions
1996            .iter()
1997            .zip(targets)
1998            .map(|(p, t)| t - p)
1999            .collect()
2000    }
2001
2002    /// Check if incremental update is compatible
2003    ///
2004    /// Verifies that new data has the same number of features as the trained model.
2005    pub fn is_compatible_for_update(&self, num_features: usize) -> bool {
2006        self.num_features() == num_features
2007    }
2008
2009    /// Get mutable reference to trees (for advanced manipulation)
2010    ///
2011    /// Use with caution - modifying trees directly may break model invariants.
2012    pub fn trees_mut(&mut self) -> &mut Vec<Tree> {
2013        &mut self.trees
2014    }
2015
2016    /// Truncate trees to a specific number of rounds
2017    ///
2018    /// Useful for early stopping or reducing model complexity.
2019    ///
2020    /// # Arguments
2021    /// * `num_rounds` - Number of rounds to keep
2022    ///
2023    /// # Multi-class Note
2024    /// Truncates to `num_rounds * num_classes` trees for multi-class models.
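    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// // keep only the first 50 rounds (50 * K trees for a K-class model)
    /// model.truncate_to_rounds(50);
    /// ```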
2025    pub fn truncate_to_rounds(&mut self, num_rounds: usize) {
2026        let trees_per_round = if self.num_classes == 0 {
2027            1
2028        } else {
2029            self.num_classes
2030        };
2031        let target_trees = num_rounds * trees_per_round;
2032        if target_trees < self.trees.len() {
2033            self.trees.truncate(target_trees);
2034        }
2035    }
2036}
2037
2038// =============================================================================
2039// TunableModel Implementation
2040// =============================================================================
2041
2042use crate::tuner::{ParamValue, TunableModel};
2043use std::collections::HashMap;
2044
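// A hedged sketch of how a tuner might drive these hooks (the parameter values
// below are hypothetical):
//
//     let mut config = GBDTModel::default_config();
//     let mut params = HashMap::new();
//     params.insert("max_depth".to_string(), ParamValue::Numeric(6.0));
//     GBDTModel::apply_params(&mut config, &params);
//     assert_eq!(config.max_depth, 6);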
2045impl TunableModel for GBDTModel {
2046    type Config = GBDTConfig;
2047
2048    fn train(dataset: &BinnedDataset, config: &Self::Config) -> Result<Self> {
2049        Self::train_binned(dataset, config.clone())
2050    }
2051
2052    fn train_with_validation(
2053        train_data: &BinnedDataset,
2054        val_data: &BinnedDataset,
2055        val_targets: &[f32],
2056        config: &Self::Config,
2057    ) -> Result<Self> {
2058        Self::train_binned_with_validation(train_data, val_data, val_targets, config.clone())
2059    }
2060
2061    fn predict(&self, dataset: &BinnedDataset) -> Vec<f32> {
2062        GBDTModel::predict(self, dataset)
2063    }
2064
2065    fn num_trees(&self) -> usize {
2066        self.trees.len()
2067    }
2068
2069    fn apply_params(config: &mut Self::Config, params: &HashMap<String, ParamValue>) {
2070        for (name, value) in params {
2071            match (name.as_str(), value) {
2072                ("max_depth", ParamValue::Numeric(v)) => config.max_depth = *v as usize,
2073                ("learning_rate", ParamValue::Numeric(v)) => config.learning_rate = *v,
2074                ("subsample", ParamValue::Numeric(v)) => config.subsample = *v,
2075                ("colsample", ParamValue::Numeric(v)) => config.colsample = *v,
2076                ("lambda", ParamValue::Numeric(v)) => config.lambda = *v,
2077                ("entropy_weight", ParamValue::Numeric(v)) => config.entropy_weight = *v,
2078                ("min_samples_leaf", ParamValue::Numeric(v)) => {
2079                    config.min_samples_leaf = *v as usize
2080                }
2081                ("min_hessian_leaf", ParamValue::Numeric(v)) => config.min_hessian_leaf = *v,
2082                ("min_gain", ParamValue::Numeric(v)) => config.min_gain = *v,
2083                ("num_rounds", ParamValue::Numeric(v)) => config.num_rounds = *v as usize,
2084                ("goss_top_rate", ParamValue::Numeric(v)) => config.goss_top_rate = *v,
2085                ("goss_other_rate", ParamValue::Numeric(v)) => config.goss_other_rate = *v,
2086                _ => {} // Unknown params are ignored
2087            }
2088        }
2089    }
2090
2091    fn valid_params() -> &'static [&'static str] {
2092        &[
2093            "max_depth",
2094            "learning_rate",
2095            "subsample",
2096            "colsample",
2097            "lambda",
2098            "entropy_weight",
2099            "min_samples_leaf",
2100            "min_hessian_leaf",
2101            "min_gain",
2102            "num_rounds",
2103            "goss_top_rate",
2104            "goss_other_rate",
2105        ]
2106    }
2107
2108    fn default_config() -> Self::Config {
2109        GBDTConfig::default()
2110    }
2111
2112    fn is_gpu_config(config: &Self::Config) -> bool {
2113        matches!(
2114            config.backend_type,
2115            BackendType::Wgpu | BackendType::Cuda | BackendType::Auto
2116        )
2117    }
2118
2119    fn get_learning_rate(config: &Self::Config) -> f32 {
2120        config.learning_rate
2121    }
2122
2123    fn configure_validation(
2124        config: &mut Self::Config,
2125        validation_ratio: f32,
2126        early_stopping_rounds: usize,
2127    ) {
2128        config.validation_ratio = validation_ratio;
2129        config.early_stopping_rounds = early_stopping_rounds;
2130    }
2131
2132    fn set_num_rounds(config: &mut Self::Config, num_rounds: usize) {
2133        config.num_rounds = num_rounds;
2134    }
2135
2136    fn save_rkyv(&self, path: &std::path::Path) -> Result<()> {
2137        crate::serialize::save_model(self, path)
2138    }
2139
2140    fn save_bincode(&self, path: &std::path::Path) -> Result<()> {
2141        crate::serialize::save_model_bincode(self, path)
2142    }
2143
2144    fn supports_conformal() -> bool {
2145        true
2146    }
2147
2148    fn conformal_quantile(&self) -> Option<f32> {
2149        self.conformal_q
2150    }
2151
2152    fn configure_conformal(config: &mut Self::Config, calibration_ratio: f32, quantile: f32) {
2153        config.calibration_ratio = calibration_ratio;
2154        config.conformal_quantile = quantile;
2155    }
2156}
2157
2158#[cfg(test)]
2159mod tests {
2160    use super::*;
2161    use crate::dataset::{FeatureInfo, FeatureType};
2162
2163    fn create_regression_dataset(num_rows: usize, noise: f32) -> BinnedDataset {
2164        let num_features = 3;
2165
2166        // Generate features
2167        let mut features = Vec::with_capacity(num_rows * num_features);
2168        for f in 0..num_features {
2169            for r in 0..num_rows {
2170                features.push(((r * (f + 1) * 17) % 256) as u8);
2171            }
2172        }
2173
2174        // Generate targets with some pattern
2175        let targets: Vec<f32> = (0..num_rows)
2176            .map(|i| {
2177                let f0 = features[i] as f32 / 255.0;
2178                let f1 = features[num_rows + i] as f32 / 255.0;
2179                f0 * 10.0 + f1 * 5.0 + noise * (i as f32 % 10.0 - 5.0) / 5.0
2180            })
2181            .collect();
2182
2183        let feature_info = (0..num_features)
2184            .map(|i| FeatureInfo {
2185                name: format!("feature_{}", i),
2186                feature_type: FeatureType::Numeric,
2187                num_bins: 255,
2188                bin_boundaries: vec![],
2189            })
2190            .collect();
2191
2192        BinnedDataset::new(num_rows, features, targets, feature_info)
2193    }
2194
2195    #[test]
2196    fn test_train_basic() {
2197        let dataset = create_regression_dataset(500, 0.1);
2198
2199        let config = GBDTConfig::new()
2200            .with_num_rounds(10)
2201            .with_max_depth(3)
2202            .with_learning_rate(0.1);
2203
2204        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2205
2206        assert_eq!(model.num_trees(), 10);
2207
2208        // Test prediction
2209        let predictions = model.predict(&dataset);
2210        assert_eq!(predictions.len(), 500);
2211    }
2212
2213    #[test]
2214    fn test_train_with_pseudo_huber() {
2215        let dataset = create_regression_dataset(500, 1.0);
2216
2217        let config = GBDTConfig::new()
2218            .with_num_rounds(10)
2219            .with_pseudo_huber_loss(1.0);
2220
2221        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2222        assert_eq!(model.num_trees(), 10);
2223    }
2224
2225    #[test]
2226    fn test_train_with_conformal() {
2227        let dataset = create_regression_dataset(500, 0.5);
2228
2229        let config = GBDTConfig::new()
2230            .with_num_rounds(10)
2231            .with_conformal(0.2, 0.9);
2232
2233        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2234
2235        assert!(model.conformal_quantile().is_some());
2236        assert!(model.conformal_quantile().unwrap() > 0.0);
2237
2238        // Test interval prediction
2239        let (preds, lower, upper) = model.predict_with_intervals(&dataset);
2240        assert_eq!(preds.len(), dataset.num_rows());
2241        assert_eq!(lower.len(), dataset.num_rows());
2242        assert_eq!(upper.len(), dataset.num_rows());
2243
2244        // Intervals should be symmetric
2245        for i in 0..preds.len() {
2246            assert!((preds[i] - lower[i] - (upper[i] - preds[i])).abs() < 1e-6);
2247        }
2248    }
2249
2250    #[test]
2251    fn test_train_with_early_stopping() {
2252        let dataset = create_regression_dataset(1000, 0.1);
2253
2254        let config = GBDTConfig::new()
2255            .with_num_rounds(100) // Max rounds
2256            .with_max_depth(4)
2257            .with_early_stopping(5, 0.2); // Stop after 5 rounds without improvement, 20% validation
2258
2259        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2260
2261        // Should have stopped early (fewer than 100 trees)
2262        // With deterministic data, early stopping should trigger
2263        assert!(model.num_trees() < 100);
2264        assert!(model.num_trees() > 0);
2265    }
2266
2267    #[test]
2268    fn test_train_with_subsampling() {
2269        let dataset = create_regression_dataset(1000, 0.1);
2270
2271        let config = GBDTConfig::new()
2272            .with_num_rounds(10)
2273            .with_max_depth(4)
2274            .with_subsample(0.8) // 80% row subsampling
2275            .with_colsample(0.8); // 80% column subsampling
2276
2277        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2278
2279        assert_eq!(model.num_trees(), 10);
2280
2281        // Predictions should still be reasonable
2282        let predictions = model.predict(&dataset);
2283        assert_eq!(predictions.len(), 1000);
2284    }
2285
2286    #[test]
2287    fn test_train_with_goss() {
2288        let dataset = create_regression_dataset(1000, 0.1);
2289
2290        // GOSS enabled with default rates (top 20%, sample 10% of rest)
2291        let config = GBDTConfig::new()
2292            .with_num_rounds(10)
2293            .with_max_depth(4)
2294            .with_goss(true);
2295
2296        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2297
2298        assert_eq!(model.num_trees(), 10);
2299
2300        // Predictions should still be reasonable
2301        let predictions = model.predict(&dataset);
2302        assert_eq!(predictions.len(), 1000);
2303    }
2304
2305    #[test]
2306    fn test_train_with_goss_custom_rates() {
2307        let dataset = create_regression_dataset(1000, 0.1);
2308
2309        // Custom GOSS rates
2310        let config = GBDTConfig::new()
2311            .with_num_rounds(10)
2312            .with_max_depth(4)
2313            .with_goss_rates(0.3, 0.15); // top 30%, sample 15% of rest
2314
2315        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2316
2317        assert_eq!(model.num_trees(), 10);
2318
2319        let predictions = model.predict(&dataset);
2320        assert_eq!(predictions.len(), 1000);
2321    }
2322
2323    #[test]
2324    fn test_auto_column_reordering() {
2325        let dataset = create_regression_dataset(500, 0.1);
2326
2327        // With column reordering enabled (default)
2328        let config = GBDTConfig::new().with_num_rounds(10).with_max_depth(4);
2329
2330        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2331
2332        // Should have computed column permutation
2333        assert!(model.column_permutation().is_some());
2334        let permutation = model.column_permutation().unwrap();
2335        assert_eq!(permutation.new_to_original().len(), 3); // 3 features
2336    }
2337
2338    #[test]
2339    fn test_column_reordering_disabled() {
2340        let dataset = create_regression_dataset(500, 0.1);
2341
2342        // With column reordering disabled
2343        let config = GBDTConfig::new()
2344            .with_num_rounds(10)
2345            .with_max_depth(4)
2346            .with_column_reordering(false);
2347
2348        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2349
2350        // Should not have computed column permutation
2351        assert!(model.column_permutation().is_none());
2352    }
2353
2354    #[test]
2355    fn test_feature_importance() {
2356        let dataset = create_regression_dataset(500, 0.1);
2357
2358        let config = GBDTConfig::new().with_num_rounds(20).with_max_depth(4);
2359
2360        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2361        let importances = model.feature_importance();
2362
2363        assert_eq!(importances.len(), 3);
2364
2365        // Importances should sum to ~1 (normalized)
2366        let total: f32 = importances.iter().sum();
2367        assert!((total - 1.0).abs() < 0.01);
2368    }
2369
2370    #[test]
2371    fn test_train_with_monotonic_constraints() {
2372        use crate::tree::MonotonicConstraint;
2373
2374        let dataset = create_regression_dataset(500, 0.1);
2375
2376        // Set monotonic increasing constraint on feature 0
2377        let config = GBDTConfig::new()
2378            .with_num_rounds(10)
2379            .with_max_depth(4)
2380            .with_monotonic_constraints(vec![
2381                MonotonicConstraint::Increasing,
2382                MonotonicConstraint::None,
2383                MonotonicConstraint::None,
2384            ]);
2385
2386        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2387
2388        // Model should train successfully with constraints
2389        assert!(model.num_trees() > 0);
2390
2391        // Predictions should still work
2392        let predictions = model.predict(&dataset);
2393        assert_eq!(predictions.len(), 500);
2394    }
2395
2396    #[test]
2397    fn test_train_with_interaction_constraints() {
2398        let dataset = create_regression_dataset(500, 0.1);
2399
2400        // Features 0, 1 can interact; feature 2 is unconstrained
2401        let config = GBDTConfig::new()
2402            .with_num_rounds(10)
2403            .with_max_depth(4)
2404            .with_interaction_groups(vec![vec![0, 1]]);
2405
2406        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2407
2408        // Model should train successfully with constraints
2409        assert!(model.num_trees() > 0);
2410
2411        // Predictions should still work
2412        let predictions = model.predict(&dataset);
2413        assert_eq!(predictions.len(), 500);
2414    }
2415
2416    #[test]
2417    fn test_train_with_era_splitting() {
2418        let num_rows = 600;
2419        let num_eras = 3;
2420
2421        // Create dataset with era indices
2422        let mut dataset = create_regression_dataset(num_rows, 0.1);
2423
2424        // Assign era indices (0, 1, 2) in round-robin fashion
2425        let era_indices: Vec<u16> = (0..num_rows).map(|i| (i % num_eras) as u16).collect();
2426        dataset.set_era_indices(era_indices);
2427
2428        // Train with era splitting enabled
2429        let config = GBDTConfig::new()
2430            .with_num_rounds(10)
2431            .with_max_depth(3)
2432            .with_learning_rate(0.1)
2433            .with_era_splitting(true);
2434
2435        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2436
2437        // Model should train successfully with era splitting
2438        assert!(model.num_trees() > 0);
2439
2440        // Predictions should still work
2441        let predictions = model.predict(&dataset);
2442        assert_eq!(predictions.len(), num_rows);
2443    }
2444
2445    #[test]
2446    fn test_train_with_eras_high_level_api() {
2447        let num_rows = 600;
2448        let num_features = 5;
2449        let num_eras = 3;
2450
2451        // Create random features (row-major)
2452        use rand::{Rng, SeedableRng};
2453        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
2454        let features: Vec<f32> = (0..num_rows * num_features)
2455            .map(|_| rng.gen_range(0.0..1.0))
2456            .collect();
2457
2458        // Create targets based on first two features
2459        let targets: Vec<f32> = (0..num_rows)
2460            .map(|i| {
2461                let f0 = features[i * num_features];
2462                let f1 = features[i * num_features + 1];
2463                f0 * 2.0 + f1 * 3.0 + rng.gen_range(-0.1..0.1)
2464            })
2465            .collect();
2466
2467        // Era indices in round-robin fashion
2468        let era_indices: Vec<u16> = (0..num_rows).map(|i| (i % num_eras) as u16).collect();
2469
2470        // Train with era splitting via high-level API
2471        let config = GBDTConfig::new()
2472            .with_num_rounds(10)
2473            .with_max_depth(3)
2474            .with_learning_rate(0.1)
2475            .with_era_splitting(true);
2476
2477        let model = GBDTModel::train_with_eras(
2478            &features,
2479            num_features,
2480            &targets,
2481            &era_indices,
2482            config,
2483            None,
2484        )
2485        .unwrap();
2486
2487        // Model should train successfully
2488        assert!(model.num_trees() > 0);
2489        assert_eq!(model.num_features(), num_features);
2490
2491        // Predictions should work (convert to f64 for predict_raw)
2492        let features_f64: Vec<f64> = features.iter().map(|&v| v as f64).collect();
2493        let predictions = model.predict_raw(&features_f64);
2494        assert_eq!(predictions.len(), num_rows);
2495    }
2496
2497    // Helper function to create a multi-class dataset
2498    fn create_multiclass_dataset(num_rows: usize, num_classes: usize) -> BinnedDataset {
2499        let num_features = 4;
2500
2501        // Generate features with some class-specific patterns
2502        let mut features = Vec::with_capacity(num_rows * num_features);
2503        for f in 0..num_features {
2504            for r in 0..num_rows {
2505                features.push(((r * (f + 1) * 17 + (r % num_classes) * 50) % 256) as u8);
2506            }
2507        }
2508
2509        // Generate class labels (0, 1, ..., num_classes-1)
2510        let targets: Vec<f32> = (0..num_rows).map(|i| (i % num_classes) as f32).collect();
2511
2512        let feature_info = (0..num_features)
2513            .map(|i| FeatureInfo {
2514                name: format!("feature_{}", i),
2515                feature_type: FeatureType::Numeric,
2516                num_bins: 255,
2517                bin_boundaries: vec![],
2518            })
2519            .collect();
2520
2521        BinnedDataset::new(num_rows, features, targets, feature_info)
2522    }
2523
2524    #[test]
2525    fn test_multiclass_training() {
2526        let num_classes = 3;
2527        let dataset = create_multiclass_dataset(300, num_classes);
2528
2529        let config = GBDTConfig::new()
2530            .with_num_rounds(10)
2531            .with_max_depth(3)
2532            .with_learning_rate(0.1)
2533            .with_multiclass_logloss(num_classes);
2534
2535        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2536
2537        // Should have K trees per round = 10 * 3 = 30 trees
2538        assert_eq!(model.num_trees(), 10 * num_classes);
2539        assert!(model.is_multiclass());
2540        assert_eq!(model.get_num_classes(), num_classes);
2541    }
2542
2543    #[test]
2544    fn test_multiclass_prediction() {
2545        let num_classes = 3;
2546        let dataset = create_multiclass_dataset(150, num_classes);
2547
2548        let config = GBDTConfig::new()
2549            .with_num_rounds(20)
2550            .with_max_depth(4)
2551            .with_learning_rate(0.1)
2552            .with_multiclass_logloss(num_classes);
2553
2554        let model = GBDTModel::train_binned(&dataset, config).unwrap();
2555
2556        // Test probability predictions
2557        let proba = model.predict_proba_multiclass(&dataset);
2558        assert_eq!(proba.len(), 150);
2559
2560        // Each row should have num_classes probabilities that sum to 1
2561        for row_proba in &proba {
2562            assert_eq!(row_proba.len(), num_classes);
2563            let sum: f32 = row_proba.iter().sum();
2564            assert!((sum - 1.0).abs() < 1e-5, "Probabilities should sum to 1");
2565
2566            // All probabilities should be in [0, 1]
2567            for &p in row_proba {
2568                assert!((0.0..=1.0).contains(&p), "Probability should be in [0, 1]");
2569            }
2570        }
2571
2572        // Test class predictions
2573        let classes = model.predict_class_multiclass(&dataset);
2574        assert_eq!(classes.len(), 150);
2575
2576        // All predicted classes should be valid
2577        for &c in &classes {
2578            assert!(
2579                (c as usize) < num_classes,
2580                "Predicted class should be < num_classes"
2581            );
2582        }
2583
2584        // Check that predictions are better than random (at least some correct)
2585        let targets = dataset.targets();
2586        let correct: usize = classes
2587            .iter()
2588            .zip(targets.iter())
2589            .filter(|(&pred, &target)| pred == target as u32)
2590            .count();
2591        let accuracy = correct as f32 / 150.0;
2592
2593        // With balanced classes and learned patterns, should be better than random (33%)
2594        assert!(
2595            accuracy > 0.4,
2596            "Multi-class accuracy {} should be better than random",
2597            accuracy
2598        );
2599    }
2600
2601    #[test]
2602    fn test_save_to_directory() {
2603        use tempfile::tempdir;
2604
2605        let dataset = create_regression_dataset(100, 0.1);
2606
2607        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);
2608
2609        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
2610
2611        // Create temp directory
2612        let dir = tempdir().unwrap();
2613        let output_path = dir.path().join("model_output");
2614
2615        // Save model and config
2616        model
2617            .save_to_directory(&output_path, &config, &[ModelFormat::Rkyv])
2618            .unwrap();
2619
2620        // Verify files exist
2621        assert!(output_path.join("config.json").exists());
2622        assert!(output_path.join("model.rkyv").exists());
2623
2624        // Load model and verify it works
2625        let loaded = crate::serialize::load_model(output_path.join("model.rkyv")).unwrap();
2626        assert_eq!(loaded.num_trees(), model.num_trees());
2627
2628        // Verify config.json is valid JSON
2629        let config_content = std::fs::read_to_string(output_path.join("config.json")).unwrap();
2630        let parsed: serde_json::Value = serde_json::from_str(&config_content).unwrap();
2631        assert!(parsed.get("num_rounds").is_some());
2632        assert_eq!(parsed["num_rounds"], 5);
2633    }
2634
2635    #[test]
2636    fn test_train_with_output() {
2637        use tempfile::tempdir;
2638
2639        // Create simple test data
2640        let features = vec![
2641            1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
2642        ];
2643        let targets = vec![1.0f32, 2.0, 3.0, 4.0];
2644        let num_features = 3;
2645
2646        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(2);
2647
2648        // Create temp directory
2649        let dir = tempdir().unwrap();
2650        let output_path = dir.path().join("train_output");
2651
2652        // Train and save in one go
2653        let model = GBDTModel::train_with_output(
2654            &features,
2655            num_features,
2656            &targets,
2657            config,
2658            None,
2659            &output_path,
2660            &[ModelFormat::Rkyv, ModelFormat::Bincode],
2661        )
2662        .unwrap();
2663
2664        // Verify model was trained
2665        assert_eq!(model.num_trees(), 5);
2666
2667        // Verify files exist
2668        assert!(output_path.join("config.json").exists());
2669        assert!(output_path.join("model.rkyv").exists());
2670        assert!(output_path.join("model.bin").exists());
2671
2672        // Load both formats and verify
2673        let loaded_rkyv = crate::serialize::load_model(output_path.join("model.rkyv")).unwrap();
2674        let loaded_bincode =
2675            crate::serialize::load_model_bincode(output_path.join("model.bin")).unwrap();
2676        assert_eq!(loaded_rkyv.num_trees(), 5);
2677        assert_eq!(loaded_bincode.num_trees(), 5);
2678    }
2679
2680    #[test]
2681    fn test_save_to_directory_empty_formats_error() {
2682        use tempfile::tempdir;
2683
2684        let dataset = create_regression_dataset(100, 0.1);
2685
2686        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);
2687
2688        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
2689
2690        // Create temp directory
2691        let dir = tempdir().unwrap();
2692        let output_path = dir.path().join("model_output");
2693
2694        // Try to save with empty formats - should fail
2695        let result = model.save_to_directory(&output_path, &config, &[]);
2696        assert!(result.is_err());
2697
2698        let err_msg = result.unwrap_err().to_string();
2699        assert!(
2700            err_msg.contains("formats must not be empty"),
2701            "Error message: {}",
2702            err_msg
2703        );
2704    }
2705
2706    #[test]
2707    fn test_save_to_directory_creates_parent_dirs() {
2708        use tempfile::tempdir;
2709
2710        let dataset = create_regression_dataset(100, 0.1);
2711
2712        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);
2713
2714        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
2715
2716        // Create temp directory with nested non-existent path
2717        let dir = tempdir().unwrap();
2718        let output_path = dir
2719            .path()
2720            .join("deeply")
2721            .join("nested")
2722            .join("path")
2723            .join("model");
2724
2725        // Should create all parent directories
2726        model
2727            .save_to_directory(&output_path, &config, &[ModelFormat::Rkyv])
2728            .unwrap();
2729
2730        // Verify files exist
2731        assert!(output_path.join("config.json").exists());
2732        assert!(output_path.join("model.rkyv").exists());
2733    }
2734
2735    #[test]
2736    fn test_save_to_directory_config_json_completeness() {
2737        use tempfile::tempdir;
2738
2739        let dataset = create_regression_dataset(100, 0.1);
2740
2741        // Create a config with various non-default settings
2742        let config = GBDTConfig::new()
2743            .with_num_rounds(42)
2744            .with_max_depth(7)
2745            .with_learning_rate(0.05)
2746            .with_subsample(0.8)
2747            .with_lambda(2.0)
2748            .with_entropy_weight(0.1);
2749
2750        let model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
2751
2752        // Save
2753        let dir = tempdir().unwrap();
2754        let output_path = dir.path().join("model_output");
2755        model
2756            .save_to_directory(&output_path, &config, &[ModelFormat::Rkyv])
2757            .unwrap();
2758
2759        // Load and verify config.json has all fields
2760        let config_content = std::fs::read_to_string(output_path.join("config.json")).unwrap();
2761        let parsed: serde_json::Value = serde_json::from_str(&config_content).unwrap();
2762
2763        // Verify key fields are present and have correct values
2764        assert_eq!(parsed["num_rounds"], 42);
2765        assert_eq!(parsed["max_depth"], 7);
2766        assert!((parsed["learning_rate"].as_f64().unwrap() - 0.05).abs() < 0.001);
2767        assert!((parsed["subsample"].as_f64().unwrap() - 0.8).abs() < 0.001);
2768        assert!((parsed["lambda"].as_f64().unwrap() - 2.0).abs() < 0.001);
2769        assert!((parsed["entropy_weight"].as_f64().unwrap() - 0.1).abs() < 0.001);
2770    }
2771
2772    // =========================================================================
2773    // Incremental Learning Tests
2774    // =========================================================================
2775
    #[test]
    fn test_tree_residual_appending() {
        let dataset = create_regression_dataset(100, 0.1);

        // Train initial model with 5 trees
        let config = GBDTConfig::new()
            .with_num_rounds(5)
            .with_max_depth(3)
            .with_learning_rate(0.1);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        assert_eq!(model.num_trees(), 5);

        let initial_preds = model.predict(&dataset);
        let initial_mse = mse(&initial_preds, dataset.targets());

        // Compute residuals
        let residuals = model.compute_residuals(&dataset, dataset.targets());
        assert_eq!(residuals.len(), 100);

        // Train additional trees and append them. In practice the new trees
        // would be fit to `residuals` via the full training pipeline; here we
        // only need to verify that appending works, so we reuse train_binned.
        let second_model = GBDTModel::train_binned(&dataset, config).unwrap();
        let trees_to_append: Vec<_> = second_model.trees().to_vec();

        model.append_trees(trees_to_append);
        assert_eq!(model.num_trees(), 10);

        // Predictions should use all 10 trees
        let new_preds = model.predict(&dataset);
        assert_eq!(new_preds.len(), 100);

        // MSE may change after appending (better or worse, depending on the
        // data), so only check that both values are finite
        let new_mse = mse(&new_preds, dataset.targets());
        assert!(initial_mse.is_finite());
        assert!(new_mse.is_finite());
    }

    #[test]
    fn test_tree_ensemble_growth() {
        let dataset = create_regression_dataset(100, 0.1);

        // Train with 5 trees
        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        assert_eq!(model.num_trees(), 5);
        assert_eq!(model.num_rounds(), 5);

        // Train another batch and append
        let second_model = GBDTModel::train_binned(&dataset, config).unwrap();
        model.append_trees(second_model.trees().to_vec());

        assert_eq!(model.num_trees(), 10);
        assert_eq!(model.num_rounds(), 10);

        // Predictions should use all trees
        let predictions = model.predict(&dataset);
        assert_eq!(predictions.len(), 100);
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    #[test]
    fn test_append_single_tree() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(1).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        assert_eq!(model.num_trees(), 1);

        // Append single tree
        let second_model = GBDTModel::train_binned(&dataset, config).unwrap();
        model.append_tree(second_model.trees()[0].clone());

        assert_eq!(model.num_trees(), 2);
    }

    #[test]
    fn test_compute_residuals_correctness() {
        let dataset = create_regression_dataset(50, 0.1);

        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let model = GBDTModel::train_binned(&dataset, config).unwrap();

        let predictions = model.predict(&dataset);
        let residuals = model.compute_residuals(&dataset, dataset.targets());

        // Verify residuals = targets - predictions. For squared-error loss
        // this is the negative gradient, i.e. the fitting target for the next
        // boosting round.
        for (i, (r, (p, t))) in residuals
            .iter()
            .zip(predictions.iter().zip(dataset.targets()))
            .enumerate()
        {
            let expected = t - p;
            assert!(
                (r - expected).abs() < 1e-5,
                "Residual {} mismatch: got {}, expected {}",
                i,
                r,
                expected
            );
        }
    }

    #[test]
    fn test_truncate_to_rounds() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(10).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config).unwrap();
        assert_eq!(model.num_trees(), 10);

        // Truncate to 5 rounds
        model.truncate_to_rounds(5);
        assert_eq!(model.num_trees(), 5);
        assert_eq!(model.num_rounds(), 5);

        // Truncating to more rounds than exist should be a no-op
        model.truncate_to_rounds(20);
        assert_eq!(model.num_trees(), 5);
    }
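
    /// Round-trip sketch (an addition, not part of the original suite): after
    /// appending a second batch of trees, `truncate_to_rounds` should roll the
    /// model back to exactly its pre-append state, so predictions must match
    /// the baseline. Assumes `predict` consults only the retained trees, as
    /// the tests above already rely on.
    #[test]
    fn test_append_then_truncate_round_trip() {
        let dataset = create_regression_dataset(100, 0.1);
        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        let baseline = model.predict(&dataset);

        // Append a second batch, then roll back to the original 5 rounds
        let second = GBDTModel::train_binned(&dataset, config).unwrap();
        model.append_trees(second.trees().to_vec());
        model.truncate_to_rounds(5);

        // The rolled-back model should reproduce the baseline predictions
        let rolled_back = model.predict(&dataset);
        for (a, b) in baseline.iter().zip(rolled_back.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }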

    #[test]
    fn test_is_compatible_for_update() {
        let dataset = create_regression_dataset(100, 0.1);

        let config = GBDTConfig::new().with_num_rounds(3);
        let model = GBDTModel::train_binned(&dataset, config).unwrap();

        // Should be compatible with same number of features
        assert!(model.is_compatible_for_update(3));

        // Should not be compatible with different number
        assert!(!model.is_compatible_for_update(5));
        assert!(!model.is_compatible_for_update(2));
    }
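
    /// Usage sketch (an addition, not part of the original suite) of the guard
    /// that `is_compatible_for_update` is meant to provide: only append trees
    /// when the incoming feature count matches the trained model. The dataset
    /// helper produces 3 features, as the compatibility test above shows.
    #[test]
    fn test_guarded_append() {
        let dataset = create_regression_dataset(100, 0.1);
        let config = GBDTConfig::new().with_num_rounds(2).with_max_depth(3);

        let mut model = GBDTModel::train_binned(&dataset, config.clone()).unwrap();
        let update = GBDTModel::train_binned(&dataset, config).unwrap();

        // Guard the incremental update on a matching feature count
        if model.is_compatible_for_update(3) {
            model.append_trees(update.trees().to_vec());
        }
        assert_eq!(model.num_trees(), 4);
    }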
}