sklears_model_selection/cv/
regression_cv.rs

1//! Regression-specific cross-validation iterators
2
3use super::RegressionCrossValidator;
4use scirs2_core::ndarray::Array1;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::random::SeedableRng;
7use scirs2_core::SliceRandomExt;
8use sklears_core::types::Float;
9use std::collections::HashMap;
10
/// K-fold cross-validator for regression that stratifies on quantile bins of the target.
///
/// A continuous target cannot be stratified directly, so every sample is first placed into one
/// of `n_bins` quantile bins according to the rank of its target value. Each bin is then spread
/// across all folds, which gives every test fold a similar spread of target values. This is
/// especially useful when the target distribution is skewed or otherwise non-uniform.
#[derive(Debug, Clone)]
pub struct StratifiedRegressionKFold {
    /// Number of folds to produce (the constructor rejects values below 2).
    n_splits: usize,
    /// Number of quantile bins the continuous target is discretised into.
    n_bins: usize,
    /// Whether indices are shuffled inside each bin before being dealt into folds.
    shuffle: bool,
    /// Seed for the shuffle; `None` means the RNG is seeded from the thread RNG.
    random_state: Option<u64>,
}
24
25impl StratifiedRegressionKFold {
26    /// Create a new StratifiedRegressionKFold cross-validator
27    pub fn new(n_splits: usize) -> Self {
28        assert!(n_splits >= 2, "n_splits must be at least 2");
29        Self {
30            n_splits,
31            n_bins: 10, // Default to 10 bins
32            shuffle: false,
33            random_state: None,
34        }
35    }
36
37    /// Set the number of bins for stratifying continuous targets
38    pub fn n_bins(mut self, n_bins: usize) -> Self {
39        assert!(n_bins >= 2, "n_bins must be at least 2");
40        self.n_bins = n_bins;
41        self
42    }
43
44    /// Set whether to shuffle the data before splitting
45    pub fn shuffle(mut self, shuffle: bool) -> Self {
46        self.shuffle = shuffle;
47        self
48    }
49
50    /// Set the random state for shuffling
51    pub fn random_state(mut self, seed: u64) -> Self {
52        self.random_state = Some(seed);
53        self
54    }
55
56    /// Convert continuous targets to discrete bins using quantile-based binning
57    fn create_bins(&self, y: &Array1<Float>) -> Array1<i32> {
58        let n_samples = y.len();
59        let mut y_sorted: Vec<(Float, usize)> =
60            y.iter().enumerate().map(|(i, &val)| (val, i)).collect();
61        y_sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
62
63        let mut bins = Array1::<i32>::zeros(n_samples);
64        let bin_size = n_samples as f64 / self.n_bins as f64;
65
66        for (rank, &(_val, orig_idx)) in y_sorted.iter().enumerate() {
67            let bin = ((rank as f64 / bin_size).floor() as usize).min(self.n_bins - 1);
68            bins[orig_idx] = bin as i32;
69        }
70
71        bins
72    }
73
74    /// Calculate the size of each fold
75    fn calculate_fold_sizes(&self, n_samples: usize) -> Vec<usize> {
76        let min_fold_size = n_samples / self.n_splits;
77        let n_larger_folds = n_samples % self.n_splits;
78
79        let mut fold_sizes = vec![min_fold_size; self.n_splits];
80        for fold_size in fold_sizes.iter_mut().take(n_larger_folds) {
81            *fold_size += 1;
82        }
83
84        fold_sizes
85    }
86}
87
88impl RegressionCrossValidator for StratifiedRegressionKFold {
89    fn n_splits(&self) -> usize {
90        self.n_splits
91    }
92
93    fn split_regression(
94        &self,
95        n_samples: usize,
96        y: &Array1<Float>,
97    ) -> Vec<(Vec<usize>, Vec<usize>)> {
98        assert_eq!(
99            y.len(),
100            n_samples,
101            "y must have the same length as n_samples"
102        );
103        assert!(
104            self.n_splits <= n_samples,
105            "Cannot have number of splits {} greater than the number of samples {}",
106            self.n_splits,
107            n_samples
108        );
109
110        // Convert continuous targets to discrete bins
111        let y_binned = self.create_bins(y);
112
113        // Group indices by bin
114        let mut bin_indices: HashMap<i32, Vec<usize>> = HashMap::new();
115        for (idx, &bin) in y_binned.iter().enumerate() {
116            bin_indices.entry(bin).or_default().push(idx);
117        }
118
119        // Check we have enough samples in each bin
120        for indices in bin_indices.values() {
121            assert!(
122                indices.len() >= self.n_splits,
123                "The least populated bin has only {} members, which is less than n_splits={}",
124                indices.len(),
125                self.n_splits
126            );
127        }
128
129        // Shuffle within each bin if requested
130        if self.shuffle {
131            let mut rng = match self.random_state {
132                Some(seed) => StdRng::seed_from_u64(seed),
133                None => {
134                    use scirs2_core::random::thread_rng;
135                    StdRng::from_rng(&mut thread_rng())
136                }
137            };
138            for indices in bin_indices.values_mut() {
139                indices.shuffle(&mut rng);
140            }
141        }
142
143        // Create stratified folds
144        let mut splits = vec![(Vec::new(), Vec::new()); self.n_splits];
145
146        for (_bin, indices) in bin_indices {
147            let fold_sizes = self.calculate_fold_sizes(indices.len());
148            let mut current = 0;
149
150            for i in 0..self.n_splits {
151                let fold_size = fold_sizes[i];
152                let test_end = current + fold_size;
153
154                // Add to test set for this fold
155                splits[i].1.extend(&indices[current..test_end]);
156
157                // Add to train sets for other folds
158                for (j, split) in splits.iter_mut().enumerate().take(self.n_splits) {
159                    if i != j {
160                        split.0.extend(&indices[current..test_end]);
161                    }
162                }
163
164                current = test_end;
165            }
166        }
167
168        splits
169    }
170}