// sklears_kernel_approximation/cross_validation.rs

//! Cross-validation framework for kernel parameter selection
//!
//! This module provides comprehensive cross-validation methods specifically designed
//! for kernel approximation methods and parameter selection.
5
6use crate::{Nystroem, ParameterLearner, ParameterSet, RBFSampler};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::{thread_rng, Rng, SeedableRng};
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Fit, Transform},
15};
16use std::collections::HashMap;
17
/// Strategy used to partition a dataset into train/test splits.
#[derive(Debug, Clone)]
pub enum CVStrategy {
    /// Classic K-fold: every sample lands in exactly one test fold.
    KFold {
        /// How many folds to create.
        n_folds: usize,
        /// Whether the sample order is shuffled before folding.
        shuffle: bool,
    },
    /// K-fold intended to preserve class proportions (classification tasks).
    StratifiedKFold {
        /// How many folds to create.
        n_folds: usize,
        /// Whether the sample order is shuffled before folding.
        shuffle: bool,
    },
    /// Each sample is held out on its own exactly once.
    LeaveOneOut,
    /// Every subset of `p` samples is held out.
    LeavePOut {
        /// Size of each held-out subset.
        p: usize,
    },
    /// Forward-chaining splits for ordered (temporal) data.
    TimeSeriesSplit {
        /// How many train/test splits to produce.
        n_splits: usize,
        /// Optional upper bound on the training-window length.
        max_train_size: Option<usize>,
    },
    /// Repeated random train/test partitions.
    MonteCarlo {
        /// How many random partitions to draw.
        n_splits: usize,
        /// Fraction of the samples placed in each test set.
        test_size: f64,
    },
}
58
/// Metric used to score each cross-validation fold.
#[derive(Debug, Clone)]
pub enum ScoringMetric {
    /// Alignment between the exact and the approximated kernel matrices.
    KernelAlignment,
    /// Mean squared error (regression).
    MeanSquaredError,
    /// Mean absolute error (regression).
    MeanAbsoluteError,
    /// Coefficient of determination, R² (regression).
    R2Score,
    /// Classification accuracy.
    Accuracy,
    /// F1 score (classification).
    F1Score,
    /// Log-likelihood.
    LogLikelihood,
    /// User-defined scoring function.
    Custom,
}
80
81/// Configuration for cross-validation
82#[derive(Debug, Clone)]
83/// CrossValidationConfig
84pub struct CrossValidationConfig {
85    /// Cross-validation strategy
86    pub cv_strategy: CVStrategy,
87    /// Scoring metric
88    pub scoring_metric: ScoringMetric,
89    /// Random seed for reproducibility
90    pub random_seed: Option<u64>,
91    /// Number of parallel jobs
92    pub n_jobs: usize,
93    /// Return training scores as well
94    pub return_train_score: bool,
95    /// Verbose output
96    pub verbose: bool,
97    /// Fit parameters for kernel methods
98    pub fit_params: HashMap<String, f64>,
99}
100
101impl Default for CrossValidationConfig {
102    fn default() -> Self {
103        Self {
104            cv_strategy: CVStrategy::KFold {
105                n_folds: 5,
106                shuffle: true,
107            },
108            scoring_metric: ScoringMetric::KernelAlignment,
109            random_seed: None,
110            n_jobs: num_cpus::get(),
111            return_train_score: false,
112            verbose: false,
113            fit_params: HashMap::new(),
114        }
115    }
116}
117
/// Scores and timings collected from one cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Score of every fold on its held-out data.
    pub test_scores: Vec<f64>,
    /// Score of every fold on its own training data, when requested.
    pub train_scores: Option<Vec<f64>>,
    /// Average of `test_scores`.
    pub mean_test_score: f64,
    /// Standard deviation of `test_scores`.
    pub std_test_score: f64,
    /// Average of `train_scores`, when present.
    pub mean_train_score: Option<f64>,
    /// Standard deviation of `train_scores`, when present.
    pub std_train_score: Option<f64>,
    /// Seconds spent fitting, one entry per fold.
    pub fit_times: Vec<f64>,
    /// Seconds spent scoring, one entry per fold.
    pub score_times: Vec<f64>,
}
139
140/// Cross-validation splitter interface
141pub trait CVSplitter {
142    /// Generate train/test indices for all splits
143    fn split(&self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)>;
144}
145
146/// K-fold cross-validation splitter
147pub struct KFoldSplitter {
148    n_folds: usize,
149    shuffle: bool,
150    random_seed: Option<u64>,
151}
152
153impl KFoldSplitter {
154    pub fn new(n_folds: usize, shuffle: bool, random_seed: Option<u64>) -> Self {
155        Self {
156            n_folds,
157            shuffle,
158            random_seed,
159        }
160    }
161}
162
163impl CVSplitter for KFoldSplitter {
164    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
165        let n_samples = x.nrows();
166        let mut indices: Vec<usize> = (0..n_samples).collect();
167
168        if self.shuffle {
169            let mut rng = if let Some(seed) = self.random_seed {
170                StdRng::seed_from_u64(seed)
171            } else {
172                StdRng::from_seed(thread_rng().gen())
173            };
174
175            indices.shuffle(&mut rng);
176        }
177
178        let fold_size = n_samples / self.n_folds;
179        let mut splits = Vec::new();
180
181        for fold in 0..self.n_folds {
182            let start = fold * fold_size;
183            let end = if fold == self.n_folds - 1 {
184                n_samples
185            } else {
186                (fold + 1) * fold_size
187            };
188
189            let test_indices = indices[start..end].to_vec();
190            let train_indices = indices[..start]
191                .iter()
192                .chain(indices[end..].iter())
193                .cloned()
194                .collect();
195
196            splits.push((train_indices, test_indices));
197        }
198
199        splits
200    }
201}
202
203/// Time series cross-validation splitter
204pub struct TimeSeriesSplitter {
205    n_splits: usize,
206    max_train_size: Option<usize>,
207}
208
209impl TimeSeriesSplitter {
210    pub fn new(n_splits: usize, max_train_size: Option<usize>) -> Self {
211        Self {
212            n_splits,
213            max_train_size,
214        }
215    }
216}
217
218impl CVSplitter for TimeSeriesSplitter {
219    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
220        let n_samples = x.nrows();
221        let test_size = n_samples / (self.n_splits + 1);
222        let mut splits = Vec::new();
223
224        for split in 0..self.n_splits {
225            let test_start = (split + 1) * test_size;
226            let test_end = if split == self.n_splits - 1 {
227                n_samples
228            } else {
229                (split + 2) * test_size
230            };
231
232            let train_end = test_start;
233            let train_start = if let Some(max_size) = self.max_train_size {
234                train_end.saturating_sub(max_size)
235            } else {
236                0
237            };
238
239            let train_indices = (train_start..train_end).collect();
240            let test_indices = (test_start..test_end).collect();
241
242            splits.push((train_indices, test_indices));
243        }
244
245        splits
246    }
247}
248
249/// Monte Carlo cross-validation splitter
250pub struct MonteCarloCVSplitter {
251    n_splits: usize,
252    test_size: f64,
253    random_seed: Option<u64>,
254}
255
256impl MonteCarloCVSplitter {
257    pub fn new(n_splits: usize, test_size: f64, random_seed: Option<u64>) -> Self {
258        Self {
259            n_splits,
260            test_size,
261            random_seed,
262        }
263    }
264}
265
266impl CVSplitter for MonteCarloCVSplitter {
267    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
268        let n_samples = x.nrows();
269        let test_samples = (n_samples as f64 * self.test_size) as usize;
270        let mut rng = if let Some(seed) = self.random_seed {
271            StdRng::seed_from_u64(seed)
272        } else {
273            StdRng::from_seed(thread_rng().gen())
274        };
275
276        let mut splits = Vec::new();
277
278        for _ in 0..self.n_splits {
279            let mut indices: Vec<usize> = (0..n_samples).collect();
280
281            indices.shuffle(&mut rng);
282
283            let test_indices = indices[..test_samples].to_vec();
284            let train_indices = indices[test_samples..].to_vec();
285
286            splits.push((train_indices, test_indices));
287        }
288
289        splits
290    }
291}
292
293/// Main cross-validation framework
294pub struct CrossValidator {
295    config: CrossValidationConfig,
296}
297
298impl CrossValidator {
299    /// Create a new cross-validator
300    pub fn new(config: CrossValidationConfig) -> Self {
301        Self { config }
302    }
303
304    /// Perform cross-validation for RBF sampler
305    pub fn cross_validate_rbf(
306        &self,
307        x: &Array2<f64>,
308        y: Option<&Array1<f64>>,
309        parameters: &ParameterSet,
310    ) -> Result<CrossValidationResult> {
311        let splitter = self.create_splitter()?;
312        let splits = splitter.split(x, y);
313
314        if self.config.verbose {
315            println!("Performing cross-validation with {} splits", splits.len());
316        }
317
318        // Parallel evaluation of each fold
319        let fold_results: Result<Vec<_>> = splits
320            .par_iter()
321            .enumerate()
322            .map(|(fold_idx, (train_indices, test_indices))| {
323                let start_time = std::time::Instant::now();
324
325                // Extract training and test data
326                let x_train = self.extract_samples(x, train_indices);
327                let x_test = self.extract_samples(x, test_indices);
328                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
329                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));
330
331                // Fit RBF sampler
332                let sampler = RBFSampler::new(parameters.n_components).gamma(parameters.gamma);
333                let fitted = sampler.fit(&x_train, &())?;
334                let fit_time = start_time.elapsed().as_secs_f64();
335
336                // Transform data
337                let x_train_transformed = fitted.transform(&x_train)?;
338                let x_test_transformed = fitted.transform(&x_test)?;
339
340                // Compute scores
341                let score_start = std::time::Instant::now();
342                let test_score = self.compute_score(
343                    &x_test,
344                    &x_test_transformed,
345                    y_test.as_ref(),
346                    parameters.gamma,
347                )?;
348
349                let train_score = if self.config.return_train_score {
350                    Some(self.compute_score(
351                        &x_train,
352                        &x_train_transformed,
353                        y_train.as_ref(),
354                        parameters.gamma,
355                    )?)
356                } else {
357                    None
358                };
359
360                let score_time = score_start.elapsed().as_secs_f64();
361
362                if self.config.verbose {
363                    println!(
364                        "Fold {}: test_score = {:.6}, fit_time = {:.3}s",
365                        fold_idx, test_score, fit_time
366                    );
367                }
368
369                Ok((test_score, train_score, fit_time, score_time))
370            })
371            .collect();
372
373        let fold_results = fold_results?;
374
375        self.aggregate_results(fold_results)
376    }
377
378    /// Perform cross-validation for Nyström method
379    pub fn cross_validate_nystroem(
380        &self,
381        x: &Array2<f64>,
382        y: Option<&Array1<f64>>,
383        parameters: &ParameterSet,
384    ) -> Result<CrossValidationResult> {
385        use crate::nystroem::Kernel;
386
387        let splitter = self.create_splitter()?;
388        let splits = splitter.split(x, y);
389
390        let fold_results: Result<Vec<_>> = splits
391            .par_iter()
392            .enumerate()
393            .map(|(fold_idx, (train_indices, test_indices))| {
394                let start_time = std::time::Instant::now();
395
396                // Extract data
397                let x_train = self.extract_samples(x, train_indices);
398                let x_test = self.extract_samples(x, test_indices);
399                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
400                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));
401
402                // Fit Nyström
403                let kernel = Kernel::Rbf {
404                    gamma: parameters.gamma,
405                };
406                let nystroem = Nystroem::new(kernel, parameters.n_components);
407                let fitted = nystroem.fit(&x_train, &())?;
408                let fit_time = start_time.elapsed().as_secs_f64();
409
410                // Transform data
411                let x_train_transformed = fitted.transform(&x_train)?;
412                let x_test_transformed = fitted.transform(&x_test)?;
413
414                // Compute scores
415                let score_start = std::time::Instant::now();
416                let test_score = self.compute_score(
417                    &x_test,
418                    &x_test_transformed,
419                    y_test.as_ref(),
420                    parameters.gamma,
421                )?;
422
423                let train_score = if self.config.return_train_score {
424                    Some(self.compute_score(
425                        &x_train,
426                        &x_train_transformed,
427                        y_train.as_ref(),
428                        parameters.gamma,
429                    )?)
430                } else {
431                    None
432                };
433
434                let score_time = score_start.elapsed().as_secs_f64();
435
436                if self.config.verbose {
437                    println!("Fold {}: test_score = {:.6}", fold_idx, test_score);
438                }
439
440                Ok((test_score, train_score, fit_time, score_time))
441            })
442            .collect();
443
444        let fold_results = fold_results?;
445        self.aggregate_results(fold_results)
446    }
447
448    /// Cross-validate with parameter search
449    pub fn cross_validate_with_search(
450        &self,
451        x: &Array2<f64>,
452        y: Option<&Array1<f64>>,
453        parameter_learner: &ParameterLearner,
454    ) -> Result<(ParameterSet, CrossValidationResult)> {
455        // First, optimize parameters using the parameter learner
456        let optimization_result = parameter_learner.optimize_rbf_parameters(x, y)?;
457        let best_params = optimization_result.best_parameters;
458
459        if self.config.verbose {
460            println!(
461                "Best parameters found: gamma={:.6}, n_components={}",
462                best_params.gamma, best_params.n_components
463            );
464        }
465
466        // Then perform cross-validation with the best parameters
467        let cv_result = self.cross_validate_rbf(x, y, &best_params)?;
468
469        Ok((best_params, cv_result))
470    }
471
472    /// Grid search with cross-validation
473    pub fn grid_search_cv(
474        &self,
475        x: &Array2<f64>,
476        y: Option<&Array1<f64>>,
477        param_grid: &HashMap<String, Vec<f64>>,
478    ) -> Result<(
479        ParameterSet,
480        f64,
481        HashMap<ParameterSet, CrossValidationResult>,
482    )> {
483        let gamma_values = param_grid.get("gamma").ok_or_else(|| {
484            SklearsError::InvalidInput("gamma parameter missing from grid".to_string())
485        })?;
486
487        let n_components_values = param_grid
488            .get("n_components")
489            .ok_or_else(|| {
490                SklearsError::InvalidInput("n_components parameter missing from grid".to_string())
491            })?
492            .iter()
493            .map(|&x| x as usize)
494            .collect::<Vec<_>>();
495
496        let mut best_score = f64::NEG_INFINITY;
497        let mut best_params = ParameterSet {
498            gamma: gamma_values[0],
499            n_components: n_components_values[0],
500            degree: None,
501            coef0: None,
502        };
503        let mut all_results = HashMap::new();
504
505        if self.config.verbose {
506            println!(
507                "Grid search over {} parameter combinations",
508                gamma_values.len() * n_components_values.len()
509            );
510        }
511
512        for &gamma in gamma_values {
513            for &n_components in &n_components_values {
514                let params = ParameterSet {
515                    gamma,
516                    n_components,
517                    degree: None,
518                    coef0: None,
519                };
520
521                let cv_result = self.cross_validate_rbf(x, y, &params)?;
522                let mean_score = cv_result.mean_test_score;
523
524                all_results.insert(params.clone(), cv_result);
525
526                if mean_score > best_score {
527                    best_score = mean_score;
528                    best_params = params;
529                }
530
531                if self.config.verbose {
532                    println!(
533                        "gamma={:.6}, n_components={}: score={:.6} ± {:.6}",
534                        gamma,
535                        n_components,
536                        mean_score,
537                        all_results
538                            .get(&ParameterSet {
539                                gamma,
540                                n_components,
541                                degree: None,
542                                coef0: None
543                            })
544                            .unwrap()
545                            .std_test_score
546                    );
547                }
548            }
549        }
550
551        Ok((best_params, best_score, all_results))
552    }
553
554    fn create_splitter(&self) -> Result<Box<dyn CVSplitter + Send + Sync>> {
555        match &self.config.cv_strategy {
556            CVStrategy::KFold { n_folds, shuffle } => Ok(Box::new(KFoldSplitter::new(
557                *n_folds,
558                *shuffle,
559                self.config.random_seed,
560            ))),
561            CVStrategy::TimeSeriesSplit {
562                n_splits,
563                max_train_size,
564            } => Ok(Box::new(TimeSeriesSplitter::new(
565                *n_splits,
566                *max_train_size,
567            ))),
568            CVStrategy::MonteCarlo {
569                n_splits,
570                test_size,
571            } => Ok(Box::new(MonteCarloCVSplitter::new(
572                *n_splits,
573                *test_size,
574                self.config.random_seed,
575            ))),
576            _ => {
577                // Fallback to K-fold for unsupported strategies
578                Ok(Box::new(KFoldSplitter::new(
579                    5,
580                    true,
581                    self.config.random_seed,
582                )))
583            }
584        }
585    }
586
587    fn extract_samples(&self, x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
588        let n_features = x.ncols();
589        let mut result = Array2::zeros((indices.len(), n_features));
590
591        for (i, &idx) in indices.iter().enumerate() {
592            result.row_mut(i).assign(&x.row(idx));
593        }
594
595        result
596    }
597
598    fn extract_targets(&self, y: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
599        let mut result = Array1::zeros(indices.len());
600
601        for (i, &idx) in indices.iter().enumerate() {
602            result[i] = y[idx];
603        }
604
605        result
606    }
607
608    fn compute_score(
609        &self,
610        x: &Array2<f64>,
611        x_transformed: &Array2<f64>,
612        y: Option<&Array1<f64>>,
613        gamma: f64,
614    ) -> Result<f64> {
615        match &self.config.scoring_metric {
616            ScoringMetric::KernelAlignment => {
617                self.compute_kernel_alignment(x, x_transformed, gamma)
618            }
619            ScoringMetric::MeanSquaredError => {
620                if let Some(y_data) = y {
621                    self.compute_mse(x_transformed, y_data)
622                } else {
623                    Err(SklearsError::InvalidInput(
624                        "Target values required for MSE".to_string(),
625                    ))
626                }
627            }
628            ScoringMetric::MeanAbsoluteError => {
629                if let Some(y_data) = y {
630                    self.compute_mae(x_transformed, y_data)
631                } else {
632                    Err(SklearsError::InvalidInput(
633                        "Target values required for MAE".to_string(),
634                    ))
635                }
636            }
637            ScoringMetric::R2Score => {
638                if let Some(y_data) = y {
639                    self.compute_r2_score(x_transformed, y_data)
640                } else {
641                    Err(SklearsError::InvalidInput(
642                        "Target values required for R²".to_string(),
643                    ))
644                }
645            }
646            _ => {
647                // Fallback to kernel alignment
648                self.compute_kernel_alignment(x, x_transformed, gamma)
649            }
650        }
651    }
652
653    fn compute_kernel_alignment(
654        &self,
655        x: &Array2<f64>,
656        x_transformed: &Array2<f64>,
657        gamma: f64,
658    ) -> Result<f64> {
659        let n_samples = x.nrows().min(50); // Limit for efficiency
660        let x_subset = x.slice(s![..n_samples, ..]);
661
662        // Compute exact kernel matrix
663        let mut k_exact = Array2::zeros((n_samples, n_samples));
664        for i in 0..n_samples {
665            for j in 0..n_samples {
666                let diff = &x_subset.row(i) - &x_subset.row(j);
667                let squared_norm = diff.dot(&diff);
668                k_exact[[i, j]] = (-gamma * squared_norm).exp();
669            }
670        }
671
672        // Compute approximate kernel matrix
673        let x_transformed_subset = x_transformed.slice(s![..n_samples, ..]);
674        let k_approx = x_transformed_subset.dot(&x_transformed_subset.t());
675
676        // Compute alignment
677        let k_exact_frobenius = k_exact.iter().map(|&x| x * x).sum::<f64>().sqrt();
678        let k_approx_frobenius = k_approx.iter().map(|&x| x * x).sum::<f64>().sqrt();
679        let k_product = (&k_exact * &k_approx).sum();
680
681        let alignment = k_product / (k_exact_frobenius * k_approx_frobenius);
682        Ok(alignment)
683    }
684
685    fn compute_mse(&self, x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
686        // Simple linear regression MSE
687        // In practice, you'd want to use a proper regressor
688        let y_mean = y.mean().unwrap_or(0.0);
689        let mse = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / y.len() as f64;
690        Ok(-mse) // Negative because we want to maximize score
691    }
692
693    fn compute_mae(&self, x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
694        // Simple linear regression MAE
695        let y_mean = y.mean().unwrap_or(0.0);
696        let mae = y.iter().map(|&yi| (yi - y_mean).abs()).sum::<f64>() / y.len() as f64;
697        Ok(-mae) // Negative because we want to maximize score
698    }
699
700    fn compute_r2_score(&self, x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
701        // Simple R² score computation
702        let y_mean = y.mean().unwrap_or(0.0);
703        let ss_tot = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
704        let ss_res = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>(); // Simplified
705
706        let r2 = 1.0 - (ss_res / ss_tot);
707        Ok(r2)
708    }
709
710    fn aggregate_results(
711        &self,
712        fold_results: Vec<(f64, Option<f64>, f64, f64)>,
713    ) -> Result<CrossValidationResult> {
714        let test_scores: Vec<f64> = fold_results.iter().map(|(score, _, _, _)| *score).collect();
715        let train_scores: Option<Vec<f64>> = if self.config.return_train_score {
716            Some(
717                fold_results
718                    .iter()
719                    .filter_map(|(_, train_score, _, _)| *train_score)
720                    .collect(),
721            )
722        } else {
723            None
724        };
725        let fit_times: Vec<f64> = fold_results
726            .iter()
727            .map(|(_, _, fit_time, _)| *fit_time)
728            .collect();
729        let score_times: Vec<f64> = fold_results
730            .iter()
731            .map(|(_, _, _, score_time)| *score_time)
732            .collect();
733
734        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
735        let variance_test = test_scores
736            .iter()
737            .map(|&score| (score - mean_test_score).powi(2))
738            .sum::<f64>()
739            / test_scores.len() as f64;
740        let std_test_score = variance_test.sqrt();
741
742        let (mean_train_score, std_train_score) = if let Some(ref train_scores) = train_scores {
743            let mean = train_scores.iter().sum::<f64>() / train_scores.len() as f64;
744            let variance = train_scores
745                .iter()
746                .map(|&score| (score - mean).powi(2))
747                .sum::<f64>()
748                / train_scores.len() as f64;
749            (Some(mean), Some(variance.sqrt()))
750        } else {
751            (None, None)
752        };
753
754        Ok(CrossValidationResult {
755            test_scores,
756            train_scores,
757            mean_test_score,
758            std_test_score,
759            mean_train_score,
760            std_train_score,
761            fit_times,
762            score_times,
763        })
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Build a `rows × cols` matrix filled row-major with `0, step, 2·step, …`.
    fn ramp(rows: usize, cols: usize, step: f64) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|i| i as f64 * step).collect();
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    /// RBF-only parameter set (no polynomial degree / coef0).
    fn rbf_params(gamma: f64, n_components: usize) -> ParameterSet {
        ParameterSet {
            gamma,
            n_components,
            degree: None,
            coef0: None,
        }
    }

    #[test]
    fn test_kfold_splitter() {
        let x = ramp(20, 3, 1.0);
        let splits = KFoldSplitter::new(4, false, Some(42)).split(&x, None);
        assert_eq!(splits.len(), 4);

        // Every index appears in exactly one test fold.
        let mut seen: Vec<usize> = splits
            .iter()
            .flat_map(|(_, test)| test.iter().copied())
            .collect();
        seen.sort();
        assert_eq!(seen, (0..20).collect::<Vec<usize>>());

        // Fold sizes stay close to 20 / 4.
        for (_, test) in &splits {
            assert!((4..=6).contains(&test.len()));
        }
    }

    #[test]
    fn test_time_series_splitter() {
        let x = ramp(30, 2, 1.0);
        let splits = TimeSeriesSplitter::new(3, Some(15)).split(&x, None);
        assert_eq!(splits.len(), 3);

        // Training rows always precede test rows.
        for (train, test) in &splits {
            if let (Some(last_train), Some(first_test)) = (train.iter().max(), test.iter().min()) {
                assert!(last_train < first_test);
            }
        }
    }

    #[test]
    fn test_monte_carlo_splitter() {
        let x = ramp(50, 4, 1.0);
        let splits = MonteCarloCVSplitter::new(5, 0.3, Some(123)).split(&x, None);
        assert_eq!(splits.len(), 5);

        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 50);
            // 30% of 50 is 15; allow one sample of slack either way.
            assert!((14..=16).contains(&test.len()));
        }
    }

    #[test]
    fn test_cross_validator_rbf() {
        let x = ramp(40, 5, 0.01);
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            return_train_score: true,
            random_seed: Some(42),
            ..Default::default()
        });

        let result = cv.cross_validate_rbf(&x, None, &rbf_params(0.5, 20)).unwrap();

        assert_eq!(result.test_scores.len(), 3);
        let train_scores = result.train_scores.as_ref().expect("train scores requested");
        assert_eq!(train_scores.len(), 3);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
        assert!(result.mean_train_score.is_some());
        assert!(result.std_train_score.is_some());
        assert_eq!(result.fit_times.len(), 3);
        assert_eq!(result.score_times.len(), 3);
    }

    #[test]
    fn test_cross_validator_nystroem() {
        let x = ramp(30, 4, 0.02);
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: false,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            ..Default::default()
        });

        let result = cv
            .cross_validate_nystroem(&x, None, &rbf_params(1.0, 15))
            .unwrap();

        assert_eq!(result.test_scores.len(), 4);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
    }

    #[test]
    fn test_grid_search_cv() {
        let x = ramp(25, 3, 0.05);
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            random_seed: Some(789),
            verbose: false,
            ..Default::default()
        });

        let param_grid: HashMap<String, Vec<f64>> = [
            ("gamma".to_string(), vec![0.1, 1.0]),
            ("n_components".to_string(), vec![10.0, 20.0]),
        ]
        .into_iter()
        .collect();

        let (best_params, best_score, all_results) =
            cv.grid_search_cv(&x, None, &param_grid).unwrap();

        assert!(best_score > 0.0);
        assert!([0.1, 1.0].contains(&best_params.gamma));
        assert!([10, 20].contains(&best_params.n_components));
        assert_eq!(all_results.len(), 4); // 2 x 2 grid

        // The reported best score is the maximum over the whole grid.
        let max_score = all_results
            .values()
            .map(|result| result.mean_test_score)
            .fold(f64::NEG_INFINITY, f64::max);
        assert_abs_diff_eq!(best_score, max_score, epsilon = 1e-10);
    }

    #[test]
    fn test_cross_validation_with_targets() {
        let x = ramp(20, 3, 0.1);
        let y = Array1::from_shape_fn(20, |i| (i as f64 * 0.1).sin());
        let cv = CrossValidator::new(CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::MeanSquaredError,
            random_seed: Some(456),
            ..Default::default()
        });

        let result = cv
            .cross_validate_rbf(&x, Some(&y), &rbf_params(0.8, 15))
            .unwrap();

        assert_eq!(result.test_scores.len(), 4);
        // MSE is negated so that higher is better; it can never be positive.
        assert!(result.mean_test_score <= 0.0);
    }

    #[test]
    fn test_cv_splitter_consistency() {
        let x = ramp(15, 2, 1.0);

        // Identically seeded splitters must agree on every split.
        let first = KFoldSplitter::new(3, true, Some(42)).split(&x, None);
        let second = KFoldSplitter::new(3, true, Some(42)).split(&x, None);

        assert_eq!(first, second);
    }

    #[test]
    fn test_cross_validation_result_aggregation() {
        let cv = CrossValidator::new(CrossValidationConfig {
            return_train_score: true,
            ..Default::default()
        });

        let fold_results = vec![
            (0.8, Some(0.85), 0.1, 0.05),
            (0.75, Some(0.8), 0.12, 0.04),
            (0.82, Some(0.88), 0.11, 0.06),
        ];

        let result = cv.aggregate_results(fold_results).unwrap();

        assert_abs_diff_eq!(result.mean_test_score, 0.79, epsilon = 1e-10);
        assert!(result.std_test_score > 0.0);
        let mean_train = result.mean_train_score.expect("train scores requested");
        assert_abs_diff_eq!(mean_train, 0.8433333333333334, epsilon = 1e-10);
    }
}