// sklears_preprocessing/cross_validation.rs

1//! Cross-Validation Utilities for Preprocessing
2//!
3//! Provides cross-validation support for preprocessing parameter tuning,
4//! including grid search and random search for optimal preprocessing parameters.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::random::essentials::Uniform;
8use scirs2_core::random::{seeded_rng, Distribution};
9use sklears_core::prelude::SklearsError;
10use std::collections::HashMap;
11
/// K-Fold cross-validation splitter.
///
/// Partitions sample indices into `n_splits` folds; each fold serves once as
/// the test set while the remaining indices form the training set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds
    pub n_splits: usize,
    /// Whether to shuffle data before splitting
    pub shuffle: bool,
    /// Random seed for shuffling; `None` falls back to a wall-clock-based seed
    pub random_state: Option<u64>,
}
22
23impl KFold {
24    /// Create a new K-Fold splitter
25    pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
26        Self {
27            n_splits,
28            shuffle,
29            random_state,
30        }
31    }
32
33    /// Generate train/test splits
34    pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
35        if n_samples < self.n_splits {
36            return Err(SklearsError::InvalidInput(format!(
37                "Cannot split {} samples into {} folds",
38                n_samples, self.n_splits
39            )));
40        }
41
42        let mut indices: Vec<usize> = (0..n_samples).collect();
43
44        if self.shuffle {
45            use std::time::{SystemTime, UNIX_EPOCH};
46
47            let seed = self.random_state.unwrap_or_else(|| {
48                SystemTime::now()
49                    .duration_since(UNIX_EPOCH)
50                    .unwrap()
51                    .as_secs()
52            });
53
54            let mut rng = seeded_rng(seed);
55
56            // Fisher-Yates shuffle
57            for i in (1..indices.len()).rev() {
58                let uniform = Uniform::new(0, i + 1).unwrap();
59                let j = uniform.sample(&mut rng);
60                indices.swap(i, j);
61            }
62        }
63
64        let fold_size = n_samples / self.n_splits;
65        let mut splits = Vec::new();
66
67        for fold_idx in 0..self.n_splits {
68            let test_start = fold_idx * fold_size;
69            let test_end = if fold_idx == self.n_splits - 1 {
70                n_samples
71            } else {
72                (fold_idx + 1) * fold_size
73            };
74
75            let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
76            let train_indices: Vec<usize> = indices[..test_start]
77                .iter()
78                .chain(&indices[test_end..])
79                .copied()
80                .collect();
81
82            splits.push((train_indices, test_indices));
83        }
84
85        Ok(splits)
86    }
87}
88
/// Stratified K-Fold cross-validation splitter.
///
/// Like [`KFold`], but each fold preserves (approximately) the per-class
/// sample proportions of the label vector passed to `split`.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds
    pub n_splits: usize,
    /// Whether to shuffle data before splitting
    pub shuffle: bool,
    /// Random seed for shuffling; `None` falls back to a wall-clock-based seed
    pub random_state: Option<u64>,
}
99
100impl StratifiedKFold {
101    /// Create a new Stratified K-Fold splitter
102    pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
103        Self {
104            n_splits,
105            shuffle,
106            random_state,
107        }
108    }
109
110    /// Generate stratified train/test splits
111    pub fn split(&self, y: &Array1<i32>) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
112        let n_samples = y.len();
113
114        if n_samples < self.n_splits {
115            return Err(SklearsError::InvalidInput(format!(
116                "Cannot split {} samples into {} folds",
117                n_samples, self.n_splits
118            )));
119        }
120
121        // Group indices by class
122        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
123        for (idx, &label) in y.iter().enumerate() {
124            class_indices.entry(label).or_default().push(idx);
125        }
126
127        // Shuffle within each class
128        if self.shuffle {
129            use std::time::{SystemTime, UNIX_EPOCH};
130
131            let seed = self.random_state.unwrap_or_else(|| {
132                SystemTime::now()
133                    .duration_since(UNIX_EPOCH)
134                    .unwrap()
135                    .as_secs()
136            });
137
138            let mut rng = seeded_rng(seed);
139
140            for indices in class_indices.values_mut() {
141                for i in (1..indices.len()).rev() {
142                    let uniform = Uniform::new(0, i + 1).unwrap();
143                    let j = uniform.sample(&mut rng);
144                    indices.swap(i, j);
145                }
146            }
147        }
148
149        // Create splits maintaining class distribution
150        let mut splits: Vec<(Vec<usize>, Vec<usize>)> = vec![];
151
152        for fold_idx in 0..self.n_splits {
153            let mut train_indices = Vec::new();
154            let mut test_indices = Vec::new();
155
156            for indices in class_indices.values() {
157                let fold_size = indices.len() / self.n_splits;
158                let test_start = fold_idx * fold_size;
159                let test_end = if fold_idx == self.n_splits - 1 {
160                    indices.len()
161                } else {
162                    (fold_idx + 1) * fold_size
163                };
164
165                test_indices.extend(&indices[test_start..test_end]);
166                train_indices.extend(&indices[..test_start]);
167                train_indices.extend(&indices[test_end..]);
168            }
169
170            splits.push((train_indices, test_indices));
171        }
172
173        Ok(splits)
174    }
175}
176
/// Cross-validation score result.
///
/// Aggregates per-fold evaluation scores into summary statistics.
#[derive(Debug, Clone)]
pub struct CVScore {
    /// Mean score across folds
    pub mean: f64,
    /// Standard deviation of scores across folds
    pub std: f64,
    /// Individual fold scores, one per fold
    pub scores: Vec<f64>,
}
187
/// Grid search parameter specification.
///
/// Maps each parameter name to the explicit list of candidate values to try;
/// `combinations` enumerates the full Cartesian product.
#[derive(Debug, Clone)]
pub struct ParameterGrid {
    // Parameter name -> candidate values for that parameter.
    parameters: HashMap<String, Vec<f64>>,
}
193
194impl ParameterGrid {
195    /// Create a new parameter grid
196    pub fn new() -> Self {
197        Self {
198            parameters: HashMap::new(),
199        }
200    }
201
202    /// Add a parameter with possible values
203    pub fn add_parameter(mut self, name: String, values: Vec<f64>) -> Self {
204        self.parameters.insert(name, values);
205        self
206    }
207
208    /// Generate all parameter combinations
209    pub fn combinations(&self) -> Vec<HashMap<String, f64>> {
210        if self.parameters.is_empty() {
211            return vec![HashMap::new()];
212        }
213
214        let mut result = vec![HashMap::new()];
215
216        for (param_name, param_values) in &self.parameters {
217            let mut new_result = Vec::new();
218
219            for combination in &result {
220                for &value in param_values {
221                    let mut new_combination = combination.clone();
222                    new_combination.insert(param_name.clone(), value);
223                    new_result.push(new_combination);
224                }
225            }
226
227            result = new_result;
228        }
229
230        result
231    }
232
233    /// Get total number of combinations
234    pub fn n_combinations(&self) -> usize {
235        if self.parameters.is_empty() {
236            return 0;
237        }
238
239        self.parameters.values().map(|v| v.len()).product()
240    }
241}
242
impl Default for ParameterGrid {
    /// Equivalent to [`ParameterGrid::new`]: an empty grid.
    fn default() -> Self {
        Self::new()
    }
}
248
/// Random search parameter specification.
///
/// Maps each parameter name to a `(min, max)` range; `sample` draws values
/// uniformly from each range.
#[derive(Debug, Clone)]
pub struct ParameterDistribution {
    parameters: HashMap<String, (f64, f64)>, // (min, max) for uniform distribution
}
254
255impl ParameterDistribution {
256    /// Create a new parameter distribution
257    pub fn new() -> Self {
258        Self {
259            parameters: HashMap::new(),
260        }
261    }
262
263    /// Add a parameter with range
264    pub fn add_parameter(mut self, name: String, min: f64, max: f64) -> Self {
265        self.parameters.insert(name, (min, max));
266        self
267    }
268
269    /// Sample random parameters
270    pub fn sample(&self, n_iter: usize, random_state: Option<u64>) -> Vec<HashMap<String, f64>> {
271        use std::time::{SystemTime, UNIX_EPOCH};
272
273        let seed = random_state.unwrap_or_else(|| {
274            SystemTime::now()
275                .duration_since(UNIX_EPOCH)
276                .unwrap()
277                .as_secs()
278        });
279
280        let mut rng = seeded_rng(seed);
281
282        (0..n_iter)
283            .map(|_| {
284                self.parameters
285                    .iter()
286                    .map(|(name, &(min, max))| {
287                        let uniform = Uniform::new_inclusive(min, max).unwrap();
288                        (name.clone(), uniform.sample(&mut rng))
289                    })
290                    .collect()
291            })
292            .collect()
293    }
294}
295
impl Default for ParameterDistribution {
    /// Equivalent to [`ParameterDistribution::new`]: an empty distribution.
    fn default() -> Self {
        Self::new()
    }
}
301
/// Evaluation metric for preprocessing quality.
///
/// Implementors compare the data before and after a preprocessing transform
/// and return a scalar score (the scale and interpretation are defined by
/// each implementor).
pub trait PreprocessingMetric {
    /// Evaluate preprocessing quality by comparing `x_original` (pre-transform)
    /// with `x_transformed` (post-transform).
    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64;
}
307
/// Variance preservation metric: mean per-column ratio of post-transform to
/// pre-transform variance. A value near 1.0 means variance was preserved.
pub struct VariancePreservationMetric;
310
311impl PreprocessingMetric for VariancePreservationMetric {
312    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
313        let mut total_variance_ratio = 0.0;
314
315        for j in 0..x_original.ncols() {
316            let original_col = x_original.column(j);
317            let transformed_col = x_transformed.column(j);
318
319            let original_var = Self::compute_variance(original_col);
320            let transformed_var = Self::compute_variance(transformed_col);
321
322            if original_var > 1e-10 {
323                total_variance_ratio += transformed_var / original_var;
324            }
325        }
326
327        total_variance_ratio / x_original.ncols() as f64
328    }
329}
330
331impl VariancePreservationMetric {
332    fn compute_variance<'a, I>(values: I) -> f64
333    where
334        I: IntoIterator<Item = &'a f64>,
335    {
336        let vals: Vec<f64> = values
337            .into_iter()
338            .copied()
339            .filter(|v| !v.is_nan())
340            .collect();
341
342        if vals.is_empty() {
343            return 0.0;
344        }
345
346        let mean = vals.iter().sum::<f64>() / vals.len() as f64;
347        vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / vals.len() as f64
348    }
349}
350
/// Information preservation metric.
///
/// Uses the average absolute per-column Pearson correlation between the
/// original and transformed data as a simple proxy for mutual-information
/// preservation (see `evaluate`).
pub struct InformationPreservationMetric;
353
354impl PreprocessingMetric for InformationPreservationMetric {
355    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
356        // Simplified: Use correlation as proxy for information preservation
357        let mut total_correlation = 0.0;
358        let mut count = 0;
359
360        for j in 0..x_original.ncols().min(x_transformed.ncols()) {
361            let corr = Self::compute_correlation(x_original, x_transformed, j);
362            if !corr.is_nan() {
363                total_correlation += corr.abs();
364                count += 1;
365            }
366        }
367
368        if count > 0 {
369            total_correlation / count as f64
370        } else {
371            0.0
372        }
373    }
374}
375
376impl InformationPreservationMetric {
377    fn compute_correlation(x1: &Array2<f64>, x2: &Array2<f64>, col_idx: usize) -> f64 {
378        let col1 = x1.column(col_idx);
379        let col2 = x2.column(col_idx);
380
381        let pairs: Vec<(f64, f64)> = col1
382            .iter()
383            .zip(col2.iter())
384            .filter(|(a, b)| !a.is_nan() && !b.is_nan())
385            .map(|(&a, &b)| (a, b))
386            .collect();
387
388        if pairs.len() < 2 {
389            return 0.0;
390        }
391
392        let mean1 = pairs.iter().map(|(a, _)| a).sum::<f64>() / pairs.len() as f64;
393        let mean2 = pairs.iter().map(|(_, b)| b).sum::<f64>() / pairs.len() as f64;
394
395        let mut cov = 0.0;
396        let mut var1 = 0.0;
397        let mut var2 = 0.0;
398
399        for (a, b) in &pairs {
400            let d1 = a - mean1;
401            let d2 = b - mean2;
402            cov += d1 * d2;
403            var1 += d1 * d1;
404            var2 += d2 * d2;
405        }
406
407        if var1 < 1e-10 || var2 < 1e-10 {
408            return 0.0;
409        }
410
411        cov / (var1 * var2).sqrt()
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Build an `nrows x ncols` matrix of N(0, 1) draws from a seeded RNG so
    /// the fixture is reproducible across runs.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).unwrap()
    }

    // Every fold's train and test sets must be non-empty and partition all samples.
    #[test]
    fn test_kfold_split() {
        let kfold = KFold::new(5, false, Some(42));
        let splits = kfold.split(100).unwrap();

        assert_eq!(splits.len(), 5);

        for (train, test) in &splits {
            assert!(train.len() > 0);
            assert!(test.len() > 0);
            assert_eq!(train.len() + test.len(), 100);
        }
    }

    // A shuffled split (fixed seed 42) should differ from the unshuffled ordering.
    #[test]
    fn test_kfold_shuffle() {
        let kfold1 = KFold::new(3, true, Some(42));
        let splits1 = kfold1.split(30).unwrap();

        let kfold2 = KFold::new(3, false, None);
        let splits2 = kfold2.split(30).unwrap();

        // Shuffled and non-shuffled should be different
        let different = splits1[0].0 != splits2[0].0;
        assert!(different);
    }

    // With 5 samples per class and 3 folds, every class must appear in each test fold.
    #[test]
    fn test_stratified_kfold() {
        let y = Array1::from_vec(vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]);

        let stratified = StratifiedKFold::new(3, false, Some(42));
        let splits = stratified.split(&y).unwrap();

        assert_eq!(splits.len(), 3);

        // Check that each split maintains class distribution
        for (train_indices, test_indices) in &splits {
            let _train_classes: Vec<i32> = train_indices.iter().map(|&i| y[i]).collect();
            let test_classes: Vec<i32> = test_indices.iter().map(|&i| y[i]).collect();

            // Count classes in test set
            let test_0 = test_classes.iter().filter(|&&c| c == 0).count();
            let test_1 = test_classes.iter().filter(|&&c| c == 1).count();
            let test_2 = test_classes.iter().filter(|&&c| c == 2).count();

            // Each class should appear roughly equally
            assert!(test_0 > 0);
            assert!(test_1 > 0);
            assert!(test_2 > 0);
        }
    }

    // Cartesian product of 3 alpha values x 2 beta values = 6 combinations.
    #[test]
    fn test_parameter_grid() {
        let grid = ParameterGrid::new()
            .add_parameter("alpha".to_string(), vec![0.1, 1.0, 10.0])
            .add_parameter("beta".to_string(), vec![0.5, 1.5]);

        let combinations = grid.combinations();

        assert_eq!(combinations.len(), 6); // 3 * 2 = 6
        assert_eq!(grid.n_combinations(), 6);

        // Check that all combinations are present
        let has_alpha_0_1 = combinations.iter().any(|c| c.get("alpha") == Some(&0.1));
        assert!(has_alpha_0_1);
    }

    // Each sampled value must fall within its configured [min, max] range.
    #[test]
    fn test_parameter_distribution() {
        let dist = ParameterDistribution::new()
            .add_parameter("alpha".to_string(), 0.0, 1.0)
            .add_parameter("beta".to_string(), 0.0, 10.0);

        let samples = dist.sample(10, Some(42));

        assert_eq!(samples.len(), 10);

        for sample in &samples {
            let alpha = sample.get("alpha").unwrap();
            let beta = sample.get("beta").unwrap();

            assert!(*alpha >= 0.0 && *alpha <= 1.0);
            assert!(*beta >= 0.0 && *beta <= 10.0);
        }
    }

    // Identity transform preserves all variance, so the ratio should be ~1.0.
    #[test]
    fn test_variance_preservation_metric() {
        let x_original = generate_test_data(100, 5, 42);
        let x_transformed = x_original.clone();

        let metric = VariancePreservationMetric;
        let score = metric.evaluate(&x_original, &x_transformed);

        // Same data should have score close to 1.0
        assert!((score - 1.0).abs() < 0.1);
    }

    // Identity transform correlates perfectly with itself, so the score is ~1.0.
    #[test]
    fn test_information_preservation_metric() {
        let x_original = generate_test_data(100, 5, 123);
        let x_transformed = x_original.clone();

        let metric = InformationPreservationMetric;
        let score = metric.evaluate(&x_original, &x_transformed);

        // Same data should have high correlation
        assert!(score > 0.9);
    }

    // Fewer samples than folds is invalid and must return an error, not panic.
    #[test]
    fn test_kfold_edge_case_small_dataset() {
        let kfold = KFold::new(5, false, Some(42));
        let result = kfold.split(3);

        assert!(result.is_err());
    }

    // Pins the documented quirk: an empty grid yields one (empty) combination
    // from combinations() while n_combinations() reports 0.
    #[test]
    fn test_empty_parameter_grid() {
        let grid = ParameterGrid::new();
        let combinations = grid.combinations();

        assert_eq!(combinations.len(), 1);
        assert!(combinations[0].is_empty());
        assert_eq!(grid.n_combinations(), 0);
    }
}