sklears_model_selection/cv/
regression_cv.rs1use super::RegressionCrossValidator;
4use scirs2_core::ndarray::Array1;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::random::SeedableRng;
7use scirs2_core::SliceRandomExt;
8use sklears_core::types::Float;
9use std::collections::HashMap;
10
/// K-fold cross-validation splitter for regression targets that stratifies on
/// rank-based bins of the continuous target.
///
/// Configure it through the builder-style methods; the fields are private.
#[derive(Clone, Debug)]
pub struct StratifiedRegressionKFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Number of rank-based bins the target values are grouped into.
    n_bins: usize,
    /// Whether the indices inside each bin are shuffled before fold assignment.
    shuffle: bool,
    /// Optional seed for the shuffling RNG; `None` seeds from the thread RNG.
    random_state: Option<u64>,
}
24
25impl StratifiedRegressionKFold {
26 pub fn new(n_splits: usize) -> Self {
28 assert!(n_splits >= 2, "n_splits must be at least 2");
29 Self {
30 n_splits,
31 n_bins: 10, shuffle: false,
33 random_state: None,
34 }
35 }
36
37 pub fn n_bins(mut self, n_bins: usize) -> Self {
39 assert!(n_bins >= 2, "n_bins must be at least 2");
40 self.n_bins = n_bins;
41 self
42 }
43
44 pub fn shuffle(mut self, shuffle: bool) -> Self {
46 self.shuffle = shuffle;
47 self
48 }
49
50 pub fn random_state(mut self, seed: u64) -> Self {
52 self.random_state = Some(seed);
53 self
54 }
55
56 fn create_bins(&self, y: &Array1<Float>) -> Array1<i32> {
58 let n_samples = y.len();
59 let mut y_sorted: Vec<(Float, usize)> =
60 y.iter().enumerate().map(|(i, &val)| (val, i)).collect();
61 y_sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
62
63 let mut bins = Array1::<i32>::zeros(n_samples);
64 let bin_size = n_samples as f64 / self.n_bins as f64;
65
66 for (rank, &(_val, orig_idx)) in y_sorted.iter().enumerate() {
67 let bin = ((rank as f64 / bin_size).floor() as usize).min(self.n_bins - 1);
68 bins[orig_idx] = bin as i32;
69 }
70
71 bins
72 }
73
74 fn calculate_fold_sizes(&self, n_samples: usize) -> Vec<usize> {
76 let min_fold_size = n_samples / self.n_splits;
77 let n_larger_folds = n_samples % self.n_splits;
78
79 let mut fold_sizes = vec![min_fold_size; self.n_splits];
80 for fold_size in fold_sizes.iter_mut().take(n_larger_folds) {
81 *fold_size += 1;
82 }
83
84 fold_sizes
85 }
86}
87
88impl RegressionCrossValidator for StratifiedRegressionKFold {
89 fn n_splits(&self) -> usize {
90 self.n_splits
91 }
92
93 fn split_regression(
94 &self,
95 n_samples: usize,
96 y: &Array1<Float>,
97 ) -> Vec<(Vec<usize>, Vec<usize>)> {
98 assert_eq!(
99 y.len(),
100 n_samples,
101 "y must have the same length as n_samples"
102 );
103 assert!(
104 self.n_splits <= n_samples,
105 "Cannot have number of splits {} greater than the number of samples {}",
106 self.n_splits,
107 n_samples
108 );
109
110 let y_binned = self.create_bins(y);
112
113 let mut bin_indices: HashMap<i32, Vec<usize>> = HashMap::new();
115 for (idx, &bin) in y_binned.iter().enumerate() {
116 bin_indices.entry(bin).or_default().push(idx);
117 }
118
119 for indices in bin_indices.values() {
121 assert!(
122 indices.len() >= self.n_splits,
123 "The least populated bin has only {} members, which is less than n_splits={}",
124 indices.len(),
125 self.n_splits
126 );
127 }
128
129 if self.shuffle {
131 let mut rng = match self.random_state {
132 Some(seed) => StdRng::seed_from_u64(seed),
133 None => {
134 use scirs2_core::random::thread_rng;
135 StdRng::from_rng(&mut thread_rng())
136 }
137 };
138 for indices in bin_indices.values_mut() {
139 indices.shuffle(&mut rng);
140 }
141 }
142
143 let mut splits = vec![(Vec::new(), Vec::new()); self.n_splits];
145
146 for (_bin, indices) in bin_indices {
147 let fold_sizes = self.calculate_fold_sizes(indices.len());
148 let mut current = 0;
149
150 for i in 0..self.n_splits {
151 let fold_size = fold_sizes[i];
152 let test_end = current + fold_size;
153
154 splits[i].1.extend(&indices[current..test_end]);
156
157 for (j, split) in splits.iter_mut().enumerate().take(self.n_splits) {
159 if i != j {
160 split.0.extend(&indices[current..test_end]);
161 }
162 }
163
164 current = test_end;
165 }
166 }
167
168 splits
169 }
170}