sklears_kernel_approximation/
cross_validation.rs

//! Cross-validation framework for kernel parameter selection
//!
//! This module provides cross-validation utilities designed specifically
//! for kernel approximation methods and their parameter selection.
5
use crate::{Nystroem, ParameterLearner, ParameterSet, RBFSampler};
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Fit, Transform},
};
use std::collections::HashMap;
18
/// Strategy used to partition samples into train/test splits.
#[derive(Debug, Clone, PartialEq)]
pub enum CVStrategy {
    /// K-fold cross-validation: each of `n_folds` folds is the test set once.
    KFold {
        /// Number of folds
        n_folds: usize,
        /// Shuffle data before folding
        shuffle: bool,
    },
    /// Stratified K-fold: K-fold that keeps the class proportions of the
    /// targets in every fold (intended for classification tasks).
    StratifiedKFold {
        /// Number of folds
        n_folds: usize,
        /// Shuffle data before folding
        shuffle: bool,
    },
    /// Leave-one-out cross-validation: every sample is the test set once.
    LeaveOneOut,
    /// Leave-P-out cross-validation: test sets contain `p` samples each.
    LeavePOut {
        /// Number of points to leave out
        p: usize,
    },
    /// Ordered (time series) splits: each test window follows its training
    /// window in index order.
    TimeSeriesSplit {
        /// Number of splits
        n_splits: usize,
        /// Maximum training size; `None` keeps every earlier sample
        max_train_size: Option<usize>,
    },
    /// Monte Carlo cross-validation: independent random train/test splits.
    MonteCarlo {
        /// Number of random splits
        n_splits: usize,
        /// Fraction of the samples placed in each test set (expected in `(0, 1)`)
        test_size: f64,
    },
}
59
/// Metric used to score each cross-validation split.
///
/// Scores are oriented so that larger is better: error metrics are reported
/// negated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoringMetric {
    /// Kernel alignment between the exact and the approximated kernel
    KernelAlignment,
    /// Mean squared error (for regression), reported negated
    MeanSquaredError,
    /// Mean absolute error (for regression), reported negated
    MeanAbsoluteError,
    /// R² score (for regression)
    R2Score,
    /// Accuracy (for classification)
    Accuracy,
    /// F1 score (for classification)
    F1Score,
    /// Log-likelihood
    LogLikelihood,
    /// Custom scoring function
    Custom,
}
81
82/// Configuration for cross-validation
83#[derive(Debug, Clone)]
84/// CrossValidationConfig
85pub struct CrossValidationConfig {
86    /// Cross-validation strategy
87    pub cv_strategy: CVStrategy,
88    /// Scoring metric
89    pub scoring_metric: ScoringMetric,
90    /// Random seed for reproducibility
91    pub random_seed: Option<u64>,
92    /// Number of parallel jobs
93    pub n_jobs: usize,
94    /// Return training scores as well
95    pub return_train_score: bool,
96    /// Verbose output
97    pub verbose: bool,
98    /// Fit parameters for kernel methods
99    pub fit_params: HashMap<String, f64>,
100}
101
102impl Default for CrossValidationConfig {
103    fn default() -> Self {
104        Self {
105            cv_strategy: CVStrategy::KFold {
106                n_folds: 5,
107                shuffle: true,
108            },
109            scoring_metric: ScoringMetric::KernelAlignment,
110            random_seed: None,
111            n_jobs: num_cpus::get(),
112            return_train_score: false,
113            verbose: false,
114            fit_params: HashMap::new(),
115        }
116    }
117}
118
/// Per-fold scores and timings collected by a cross-validation run, together
/// with their summary statistics.
#[derive(Debug, Clone, PartialEq)]
pub struct CrossValidationResult {
    /// Test score of each fold
    pub test_scores: Vec<f64>,
    /// Training score of each fold (only when training scores were requested)
    pub train_scores: Option<Vec<f64>>,
    /// Mean of `test_scores`
    pub mean_test_score: f64,
    /// Population standard deviation (ddof = 0) of `test_scores`
    pub std_test_score: f64,
    /// Mean of `train_scores`, if available
    pub mean_train_score: Option<f64>,
    /// Population standard deviation of `train_scores`, if available
    pub std_train_score: Option<f64>,
    /// Wall-clock fit time of each fold, in seconds
    pub fit_times: Vec<f64>,
    /// Wall-clock scoring time of each fold, in seconds
    pub score_times: Vec<f64>,
}
140
141/// Cross-validation splitter interface
142pub trait CVSplitter {
143    /// Generate train/test indices for all splits
144    fn split(&self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)>;
145}
146
147/// K-fold cross-validation splitter
148pub struct KFoldSplitter {
149    n_folds: usize,
150    shuffle: bool,
151    random_seed: Option<u64>,
152}
153
154impl KFoldSplitter {
155    pub fn new(n_folds: usize, shuffle: bool, random_seed: Option<u64>) -> Self {
156        Self {
157            n_folds,
158            shuffle,
159            random_seed,
160        }
161    }
162}
163
164impl CVSplitter for KFoldSplitter {
165    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
166        let n_samples = x.nrows();
167        let mut indices: Vec<usize> = (0..n_samples).collect();
168
169        if self.shuffle {
170            let mut rng = if let Some(seed) = self.random_seed {
171                StdRng::seed_from_u64(seed)
172            } else {
173                StdRng::from_seed(thread_rng().gen())
174            };
175
176            indices.shuffle(&mut rng);
177        }
178
179        let fold_size = n_samples / self.n_folds;
180        let mut splits = Vec::new();
181
182        for fold in 0..self.n_folds {
183            let start = fold * fold_size;
184            let end = if fold == self.n_folds - 1 {
185                n_samples
186            } else {
187                (fold + 1) * fold_size
188            };
189
190            let test_indices = indices[start..end].to_vec();
191            let train_indices = indices[..start]
192                .iter()
193                .chain(indices[end..].iter())
194                .cloned()
195                .collect();
196
197            splits.push((train_indices, test_indices));
198        }
199
200        splits
201    }
202}
203
204/// Time series cross-validation splitter
205pub struct TimeSeriesSplitter {
206    n_splits: usize,
207    max_train_size: Option<usize>,
208}
209
210impl TimeSeriesSplitter {
211    pub fn new(n_splits: usize, max_train_size: Option<usize>) -> Self {
212        Self {
213            n_splits,
214            max_train_size,
215        }
216    }
217}
218
219impl CVSplitter for TimeSeriesSplitter {
220    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
221        let n_samples = x.nrows();
222        let test_size = n_samples / (self.n_splits + 1);
223        let mut splits = Vec::new();
224
225        for split in 0..self.n_splits {
226            let test_start = (split + 1) * test_size;
227            let test_end = if split == self.n_splits - 1 {
228                n_samples
229            } else {
230                (split + 2) * test_size
231            };
232
233            let train_end = test_start;
234            let train_start = if let Some(max_size) = self.max_train_size {
235                train_end.saturating_sub(max_size)
236            } else {
237                0
238            };
239
240            let train_indices = (train_start..train_end).collect();
241            let test_indices = (test_start..test_end).collect();
242
243            splits.push((train_indices, test_indices));
244        }
245
246        splits
247    }
248}
249
250/// Monte Carlo cross-validation splitter
251pub struct MonteCarloCVSplitter {
252    n_splits: usize,
253    test_size: f64,
254    random_seed: Option<u64>,
255}
256
257impl MonteCarloCVSplitter {
258    pub fn new(n_splits: usize, test_size: f64, random_seed: Option<u64>) -> Self {
259        Self {
260            n_splits,
261            test_size,
262            random_seed,
263        }
264    }
265}
266
267impl CVSplitter for MonteCarloCVSplitter {
268    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
269        let n_samples = x.nrows();
270        let test_samples = (n_samples as f64 * self.test_size) as usize;
271        let mut rng = if let Some(seed) = self.random_seed {
272            StdRng::seed_from_u64(seed)
273        } else {
274            StdRng::from_seed(thread_rng().gen())
275        };
276
277        let mut splits = Vec::new();
278
279        for _ in 0..self.n_splits {
280            let mut indices: Vec<usize> = (0..n_samples).collect();
281
282            indices.shuffle(&mut rng);
283
284            let test_indices = indices[..test_samples].to_vec();
285            let train_indices = indices[test_samples..].to_vec();
286
287            splits.push((train_indices, test_indices));
288        }
289
290        splits
291    }
292}
293
294/// Main cross-validation framework
295pub struct CrossValidator {
296    config: CrossValidationConfig,
297}
298
299impl CrossValidator {
300    /// Create a new cross-validator
301    pub fn new(config: CrossValidationConfig) -> Self {
302        Self { config }
303    }
304
305    /// Perform cross-validation for RBF sampler
306    pub fn cross_validate_rbf(
307        &self,
308        x: &Array2<f64>,
309        y: Option<&Array1<f64>>,
310        parameters: &ParameterSet,
311    ) -> Result<CrossValidationResult> {
312        let splitter = self.create_splitter()?;
313        let splits = splitter.split(x, y);
314
315        if self.config.verbose {
316            println!("Performing cross-validation with {} splits", splits.len());
317        }
318
319        // Parallel evaluation of each fold
320        let fold_results: Result<Vec<_>> = splits
321            .par_iter()
322            .enumerate()
323            .map(|(fold_idx, (train_indices, test_indices))| {
324                let start_time = std::time::Instant::now();
325
326                // Extract training and test data
327                let x_train = self.extract_samples(x, train_indices);
328                let x_test = self.extract_samples(x, test_indices);
329                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
330                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));
331
332                // Fit RBF sampler
333                let sampler = RBFSampler::new(parameters.n_components).gamma(parameters.gamma);
334                let fitted = sampler.fit(&x_train, &())?;
335                let fit_time = start_time.elapsed().as_secs_f64();
336
337                // Transform data
338                let x_train_transformed = fitted.transform(&x_train)?;
339                let x_test_transformed = fitted.transform(&x_test)?;
340
341                // Compute scores
342                let score_start = std::time::Instant::now();
343                let test_score = self.compute_score(
344                    &x_test,
345                    &x_test_transformed,
346                    y_test.as_ref(),
347                    parameters.gamma,
348                )?;
349
350                let train_score = if self.config.return_train_score {
351                    Some(self.compute_score(
352                        &x_train,
353                        &x_train_transformed,
354                        y_train.as_ref(),
355                        parameters.gamma,
356                    )?)
357                } else {
358                    None
359                };
360
361                let score_time = score_start.elapsed().as_secs_f64();
362
363                if self.config.verbose {
364                    println!(
365                        "Fold {}: test_score = {:.6}, fit_time = {:.3}s",
366                        fold_idx, test_score, fit_time
367                    );
368                }
369
370                Ok((test_score, train_score, fit_time, score_time))
371            })
372            .collect();
373
374        let fold_results = fold_results?;
375
376        self.aggregate_results(fold_results)
377    }
378
379    /// Perform cross-validation for Nyström method
380    pub fn cross_validate_nystroem(
381        &self,
382        x: &Array2<f64>,
383        y: Option<&Array1<f64>>,
384        parameters: &ParameterSet,
385    ) -> Result<CrossValidationResult> {
386        use crate::nystroem::Kernel;
387
388        let splitter = self.create_splitter()?;
389        let splits = splitter.split(x, y);
390
391        let fold_results: Result<Vec<_>> = splits
392            .par_iter()
393            .enumerate()
394            .map(|(fold_idx, (train_indices, test_indices))| {
395                let start_time = std::time::Instant::now();
396
397                // Extract data
398                let x_train = self.extract_samples(x, train_indices);
399                let x_test = self.extract_samples(x, test_indices);
400                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
401                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));
402
403                // Fit Nyström
404                let kernel = Kernel::Rbf {
405                    gamma: parameters.gamma,
406                };
407                let nystroem = Nystroem::new(kernel, parameters.n_components);
408                let fitted = nystroem.fit(&x_train, &())?;
409                let fit_time = start_time.elapsed().as_secs_f64();
410
411                // Transform data
412                let x_train_transformed = fitted.transform(&x_train)?;
413                let x_test_transformed = fitted.transform(&x_test)?;
414
415                // Compute scores
416                let score_start = std::time::Instant::now();
417                let test_score = self.compute_score(
418                    &x_test,
419                    &x_test_transformed,
420                    y_test.as_ref(),
421                    parameters.gamma,
422                )?;
423
424                let train_score = if self.config.return_train_score {
425                    Some(self.compute_score(
426                        &x_train,
427                        &x_train_transformed,
428                        y_train.as_ref(),
429                        parameters.gamma,
430                    )?)
431                } else {
432                    None
433                };
434
435                let score_time = score_start.elapsed().as_secs_f64();
436
437                if self.config.verbose {
438                    println!("Fold {}: test_score = {:.6}", fold_idx, test_score);
439                }
440
441                Ok((test_score, train_score, fit_time, score_time))
442            })
443            .collect();
444
445        let fold_results = fold_results?;
446        self.aggregate_results(fold_results)
447    }
448
449    /// Cross-validate with parameter search
450    pub fn cross_validate_with_search(
451        &self,
452        x: &Array2<f64>,
453        y: Option<&Array1<f64>>,
454        parameter_learner: &ParameterLearner,
455    ) -> Result<(ParameterSet, CrossValidationResult)> {
456        // First, optimize parameters using the parameter learner
457        let optimization_result = parameter_learner.optimize_rbf_parameters(x, y)?;
458        let best_params = optimization_result.best_parameters;
459
460        if self.config.verbose {
461            println!(
462                "Best parameters found: gamma={:.6}, n_components={}",
463                best_params.gamma, best_params.n_components
464            );
465        }
466
467        // Then perform cross-validation with the best parameters
468        let cv_result = self.cross_validate_rbf(x, y, &best_params)?;
469
470        Ok((best_params, cv_result))
471    }
472
473    /// Grid search with cross-validation
474    pub fn grid_search_cv(
475        &self,
476        x: &Array2<f64>,
477        y: Option<&Array1<f64>>,
478        param_grid: &HashMap<String, Vec<f64>>,
479    ) -> Result<(
480        ParameterSet,
481        f64,
482        HashMap<ParameterSet, CrossValidationResult>,
483    )> {
484        let gamma_values = param_grid.get("gamma").ok_or_else(|| {
485            SklearsError::InvalidInput("gamma parameter missing from grid".to_string())
486        })?;
487
488        let n_components_values = param_grid
489            .get("n_components")
490            .ok_or_else(|| {
491                SklearsError::InvalidInput("n_components parameter missing from grid".to_string())
492            })?
493            .iter()
494            .map(|&x| x as usize)
495            .collect::<Vec<_>>();
496
497        let mut best_score = f64::NEG_INFINITY;
498        let mut best_params = ParameterSet {
499            gamma: gamma_values[0],
500            n_components: n_components_values[0],
501            degree: None,
502            coef0: None,
503        };
504        let mut all_results = HashMap::new();
505
506        if self.config.verbose {
507            println!(
508                "Grid search over {} parameter combinations",
509                gamma_values.len() * n_components_values.len()
510            );
511        }
512
513        for &gamma in gamma_values {
514            for &n_components in &n_components_values {
515                let params = ParameterSet {
516                    gamma,
517                    n_components,
518                    degree: None,
519                    coef0: None,
520                };
521
522                let cv_result = self.cross_validate_rbf(x, y, &params)?;
523                let mean_score = cv_result.mean_test_score;
524
525                all_results.insert(params.clone(), cv_result);
526
527                if mean_score > best_score {
528                    best_score = mean_score;
529                    best_params = params;
530                }
531
532                if self.config.verbose {
533                    println!(
534                        "gamma={:.6}, n_components={}: score={:.6} ± {:.6}",
535                        gamma,
536                        n_components,
537                        mean_score,
538                        all_results
539                            .get(&ParameterSet {
540                                gamma,
541                                n_components,
542                                degree: None,
543                                coef0: None
544                            })
545                            .unwrap()
546                            .std_test_score
547                    );
548                }
549            }
550        }
551
552        Ok((best_params, best_score, all_results))
553    }
554
555    fn create_splitter(&self) -> Result<Box<dyn CVSplitter + Send + Sync>> {
556        match &self.config.cv_strategy {
557            CVStrategy::KFold { n_folds, shuffle } => Ok(Box::new(KFoldSplitter::new(
558                *n_folds,
559                *shuffle,
560                self.config.random_seed,
561            ))),
562            CVStrategy::TimeSeriesSplit {
563                n_splits,
564                max_train_size,
565            } => Ok(Box::new(TimeSeriesSplitter::new(
566                *n_splits,
567                *max_train_size,
568            ))),
569            CVStrategy::MonteCarlo {
570                n_splits,
571                test_size,
572            } => Ok(Box::new(MonteCarloCVSplitter::new(
573                *n_splits,
574                *test_size,
575                self.config.random_seed,
576            ))),
577            _ => {
578                // Fallback to K-fold for unsupported strategies
579                Ok(Box::new(KFoldSplitter::new(
580                    5,
581                    true,
582                    self.config.random_seed,
583                )))
584            }
585        }
586    }
587
588    fn extract_samples(&self, x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
589        let n_features = x.ncols();
590        let mut result = Array2::zeros((indices.len(), n_features));
591
592        for (i, &idx) in indices.iter().enumerate() {
593            result.row_mut(i).assign(&x.row(idx));
594        }
595
596        result
597    }
598
599    fn extract_targets(&self, y: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
600        let mut result = Array1::zeros(indices.len());
601
602        for (i, &idx) in indices.iter().enumerate() {
603            result[i] = y[idx];
604        }
605
606        result
607    }
608
609    fn compute_score(
610        &self,
611        x: &Array2<f64>,
612        x_transformed: &Array2<f64>,
613        y: Option<&Array1<f64>>,
614        gamma: f64,
615    ) -> Result<f64> {
616        match &self.config.scoring_metric {
617            ScoringMetric::KernelAlignment => {
618                self.compute_kernel_alignment(x, x_transformed, gamma)
619            }
620            ScoringMetric::MeanSquaredError => {
621                if let Some(y_data) = y {
622                    self.compute_mse(x_transformed, y_data)
623                } else {
624                    Err(SklearsError::InvalidInput(
625                        "Target values required for MSE".to_string(),
626                    ))
627                }
628            }
629            ScoringMetric::MeanAbsoluteError => {
630                if let Some(y_data) = y {
631                    self.compute_mae(x_transformed, y_data)
632                } else {
633                    Err(SklearsError::InvalidInput(
634                        "Target values required for MAE".to_string(),
635                    ))
636                }
637            }
638            ScoringMetric::R2Score => {
639                if let Some(y_data) = y {
640                    self.compute_r2_score(x_transformed, y_data)
641                } else {
642                    Err(SklearsError::InvalidInput(
643                        "Target values required for R²".to_string(),
644                    ))
645                }
646            }
647            _ => {
648                // Fallback to kernel alignment
649                self.compute_kernel_alignment(x, x_transformed, gamma)
650            }
651        }
652    }
653
654    fn compute_kernel_alignment(
655        &self,
656        x: &Array2<f64>,
657        x_transformed: &Array2<f64>,
658        gamma: f64,
659    ) -> Result<f64> {
660        let n_samples = x.nrows().min(50); // Limit for efficiency
661        let x_subset = x.slice(s![..n_samples, ..]);
662
663        // Compute exact kernel matrix
664        let mut k_exact = Array2::zeros((n_samples, n_samples));
665        for i in 0..n_samples {
666            for j in 0..n_samples {
667                let diff = &x_subset.row(i) - &x_subset.row(j);
668                let squared_norm = diff.dot(&diff);
669                k_exact[[i, j]] = (-gamma * squared_norm).exp();
670            }
671        }
672
673        // Compute approximate kernel matrix
674        let x_transformed_subset = x_transformed.slice(s![..n_samples, ..]);
675        let k_approx = x_transformed_subset.dot(&x_transformed_subset.t());
676
677        // Compute alignment
678        let k_exact_frobenius = k_exact.iter().map(|&x| x * x).sum::<f64>().sqrt();
679        let k_approx_frobenius = k_approx.iter().map(|&x| x * x).sum::<f64>().sqrt();
680        let k_product = (&k_exact * &k_approx).sum();
681
682        let alignment = k_product / (k_exact_frobenius * k_approx_frobenius);
683        Ok(alignment)
684    }
685
686    fn compute_mse(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
687        // Simple linear regression MSE
688        // In practice, you'd want to use a proper regressor
689        let y_mean = y.mean().unwrap_or(0.0);
690        let mse = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / y.len() as f64;
691        Ok(-mse) // Negative because we want to maximize score
692    }
693
694    fn compute_mae(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
695        // Simple linear regression MAE
696        let y_mean = y.mean().unwrap_or(0.0);
697        let mae = y.iter().map(|&yi| (yi - y_mean).abs()).sum::<f64>() / y.len() as f64;
698        Ok(-mae) // Negative because we want to maximize score
699    }
700
701    fn compute_r2_score(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
702        // Simple R² score computation
703        let y_mean = y.mean().unwrap_or(0.0);
704        let ss_tot = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
705        let ss_res = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>(); // Simplified
706
707        let r2 = 1.0 - (ss_res / ss_tot);
708        Ok(r2)
709    }
710
711    fn aggregate_results(
712        &self,
713        fold_results: Vec<(f64, Option<f64>, f64, f64)>,
714    ) -> Result<CrossValidationResult> {
715        let test_scores: Vec<f64> = fold_results.iter().map(|(score, _, _, _)| *score).collect();
716        let train_scores: Option<Vec<f64>> = if self.config.return_train_score {
717            Some(
718                fold_results
719                    .iter()
720                    .filter_map(|(_, train_score, _, _)| *train_score)
721                    .collect(),
722            )
723        } else {
724            None
725        };
726        let fit_times: Vec<f64> = fold_results
727            .iter()
728            .map(|(_, _, fit_time, _)| *fit_time)
729            .collect();
730        let score_times: Vec<f64> = fold_results
731            .iter()
732            .map(|(_, _, _, score_time)| *score_time)
733            .collect();
734
735        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
736        let variance_test = test_scores
737            .iter()
738            .map(|&score| (score - mean_test_score).powi(2))
739            .sum::<f64>()
740            / test_scores.len() as f64;
741        let std_test_score = variance_test.sqrt();
742
743        let (mean_train_score, std_train_score) = if let Some(ref train_scores) = train_scores {
744            let mean = train_scores.iter().sum::<f64>() / train_scores.len() as f64;
745            let variance = train_scores
746                .iter()
747                .map(|&score| (score - mean).powi(2))
748                .sum::<f64>()
749                / train_scores.len() as f64;
750            (Some(mean), Some(variance.sqrt()))
751        } else {
752            (None, None)
753        };
754
755        Ok(CrossValidationResult {
756            test_scores,
757            train_scores,
758            mean_test_score,
759            std_test_score,
760            mean_train_score,
761            std_train_score,
762            fit_times,
763            score_times,
764        })
765    }
766}
767
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// `rows × cols` matrix filled in row-major order with `0, step, 2·step, …`.
    fn ramp(rows: usize, cols: usize, step: f64) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|i| i as f64 * step).collect();
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    /// RBF parameter set without polynomial settings.
    fn rbf_params(gamma: f64, n_components: usize) -> ParameterSet {
        ParameterSet {
            gamma,
            n_components,
            degree: None,
            coef0: None,
        }
    }

    fn kfold(n_folds: usize, shuffle: bool) -> CVStrategy {
        CVStrategy::KFold { n_folds, shuffle }
    }

    #[test]
    fn test_kfold_splitter() {
        let x = ramp(20, 3, 1.0);
        let splits = KFoldSplitter::new(4, false, Some(42)).split(&x, None);

        assert_eq!(splits.len(), 4);

        // Every sample must appear in exactly one test fold.
        let mut covered: Vec<usize> = splits
            .iter()
            .flat_map(|(_, test)| test.iter().copied())
            .collect();
        covered.sort();
        assert_eq!(covered, (0..20).collect::<Vec<usize>>());

        // Folds should be roughly the same size.
        assert!(splits
            .iter()
            .all(|(_, test)| (4..=6).contains(&test.len())));
    }

    #[test]
    fn test_time_series_splitter() {
        let x = ramp(30, 2, 1.0);
        let splits = TimeSeriesSplitter::new(3, Some(15)).split(&x, None);

        assert_eq!(splits.len(), 3);

        // Training data must come strictly before test data.
        for (train, test) in &splits {
            if let (Some(last_train), Some(first_test)) = (train.iter().max(), test.iter().min()) {
                assert!(last_train < first_test);
            }
        }
    }

    #[test]
    fn test_monte_carlo_splitter() {
        let x = ramp(50, 4, 1.0);
        let splits = MonteCarloCVSplitter::new(5, 0.3, Some(123)).split(&x, None);

        assert_eq!(splits.len(), 5);

        // 30% of 50 is 15; allow one sample of slack either way.
        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 50);
            assert!((14..=16).contains(&test.len()));
        }
    }

    #[test]
    fn test_cross_validator_rbf() {
        let x = ramp(40, 5, 0.01);
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: kfold(3, true),
            scoring_metric: ScoringMetric::KernelAlignment,
            return_train_score: true,
            random_seed: Some(42),
            ..Default::default()
        });

        let result = cv.cross_validate_rbf(&x, None, &rbf_params(0.5, 20)).unwrap();

        assert_eq!(result.test_scores.len(), 3);
        let train_scores = result
            .train_scores
            .as_ref()
            .expect("training scores were requested");
        assert_eq!(train_scores.len(), 3);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
        assert!(result.mean_train_score.is_some());
        assert!(result.std_train_score.is_some());
        assert_eq!(result.fit_times.len(), 3);
        assert_eq!(result.score_times.len(), 3);
    }

    #[test]
    fn test_cross_validator_nystroem() {
        let x = ramp(30, 4, 0.02);
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: kfold(4, false),
            scoring_metric: ScoringMetric::KernelAlignment,
            ..Default::default()
        });

        let result = cv
            .cross_validate_nystroem(&x, None, &rbf_params(1.0, 15))
            .unwrap();

        assert_eq!(result.test_scores.len(), 4);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
    }

    #[test]
    fn test_grid_search_cv() {
        let x = ramp(25, 3, 0.05);
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: kfold(3, true),
            random_seed: Some(789),
            verbose: false,
            ..Default::default()
        });

        let param_grid: HashMap<String, Vec<f64>> = HashMap::from([
            ("gamma".to_string(), vec![0.1, 1.0]),
            ("n_components".to_string(), vec![10.0, 20.0]),
        ]);

        let (best_params, best_score, all_results) =
            cv.grid_search_cv(&x, None, &param_grid).unwrap();

        assert!(best_score > 0.0);
        assert!(best_params.gamma == 0.1 || best_params.gamma == 1.0);
        assert!([10, 20].contains(&best_params.n_components));
        assert_eq!(all_results.len(), 4); // 2 × 2 grid

        // The reported best score must be the maximum mean test score.
        let max_score = all_results
            .values()
            .map(|result| result.mean_test_score)
            .fold(f64::NEG_INFINITY, f64::max);
        assert_abs_diff_eq!(best_score, max_score, epsilon = 1e-10);
    }

    #[test]
    fn test_cross_validation_with_targets() {
        let x = ramp(20, 3, 0.1);
        let y = Array1::from_shape_fn(20, |i| (i as f64 * 0.1).sin());
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: kfold(4, true),
            scoring_metric: ScoringMetric::MeanSquaredError,
            random_seed: Some(456),
            ..Default::default()
        });

        let result = cv
            .cross_validate_rbf(&x, Some(&y), &rbf_params(0.8, 15))
            .unwrap();

        assert_eq!(result.test_scores.len(), 4);
        // MSE is negated so that larger is better, hence never positive.
        assert!(result.mean_test_score <= 0.0);
    }

    #[test]
    fn test_cv_splitter_consistency() {
        let x = ramp(15, 2, 1.0);

        // Identical seeds must give identical splits.
        let first = KFoldSplitter::new(3, true, Some(42)).split(&x, None);
        let second = KFoldSplitter::new(3, true, Some(42)).split(&x, None);

        assert_eq!(first.len(), second.len());
        for ((train_a, test_a), (train_b, test_b)) in first.iter().zip(&second) {
            assert_eq!(train_a, train_b);
            assert_eq!(test_a, test_b);
        }
    }

    #[test]
    fn test_cross_validation_result_aggregation() {
        let cv = CrossValidator::new(CrossValidationConfig {
            return_train_score: true,
            ..Default::default()
        });

        let fold_results = vec![
            (0.8, Some(0.85), 0.1, 0.05),
            (0.75, Some(0.8), 0.12, 0.04),
            (0.82, Some(0.88), 0.11, 0.06),
        ];

        let result = cv.aggregate_results(fold_results).unwrap();

        assert_abs_diff_eq!(result.mean_test_score, 0.79, epsilon = 1e-10);
        assert!(result.std_test_score > 0.0);
        let mean_train = result.mean_train_score.expect("train scores were supplied");
        assert_abs_diff_eq!(mean_train, 0.8433333333333334, epsilon = 1e-10);
    }
}