1use crate::{Nystroem, ParameterLearner, ParameterSet, RBFSampler};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::{thread_rng, Rng, SeedableRng};
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Fit, Transform},
15};
16use std::collections::HashMap;
17
/// Strategy used to partition samples into train/test folds.
#[derive(Debug, Clone)]
pub enum CVStrategy {
    /// Standard k-fold: rows are divided into `n_folds` contiguous folds.
    KFold {
        /// Number of folds.
        n_folds: usize,
        /// Shuffle row order before folding.
        shuffle: bool,
    },
    /// K-fold intended to preserve class proportions per fold.
    /// NOTE(review): `create_splitter` has no dedicated implementation for
    /// this variant yet — it silently falls back to shuffled 5-fold.
    StratifiedKFold {
        n_folds: usize,
        shuffle: bool,
    },
    /// One test sample per split.
    /// NOTE(review): not implemented by `create_splitter`; falls back to KFold.
    LeaveOneOut,
    /// `p` test samples per split.
    /// NOTE(review): not implemented by `create_splitter`; falls back to KFold.
    LeavePOut {
        p: usize,
    },
    /// Walk-forward splits where training data always precedes test data.
    TimeSeriesSplit {
        n_splits: usize,
        /// Optional cap on the number of most recent training samples used.
        max_train_size: Option<usize>,
    },
    /// Repeated random subsampling (shuffle-split).
    MonteCarlo {
        n_splits: usize,
        /// Fraction of samples assigned to the test set in each split.
        test_size: f64,
    },
}
58
/// Metric used to score each cross-validation fold.
#[derive(Debug, Clone)]
pub enum ScoringMetric {
    /// Alignment between the exact and the approximated kernel matrix
    /// (unsupervised; the default).
    KernelAlignment,
    /// Negated MSE (requires targets).
    MeanSquaredError,
    /// Negated MAE (requires targets).
    MeanAbsoluteError,
    /// Coefficient of determination (requires targets).
    R2Score,
    /// NOTE(review): no dedicated implementation in `compute_score`;
    /// falls back to kernel alignment.
    Accuracy,
    /// NOTE(review): falls back to kernel alignment (see `compute_score`).
    F1Score,
    /// NOTE(review): falls back to kernel alignment (see `compute_score`).
    LogLikelihood,
    /// NOTE(review): falls back to kernel alignment (see `compute_score`).
    Custom,
}
80
/// Configuration controlling how cross-validation is executed.
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// How samples are partitioned into train/test folds.
    pub cv_strategy: CVStrategy,
    /// Metric used to score each fold.
    pub scoring_metric: ScoringMetric,
    /// Seed for shuffling/subsampling splitters; `None` draws from thread-local entropy.
    pub random_seed: Option<u64>,
    /// Requested parallelism level.
    /// NOTE(review): folds currently run on rayon's global pool; this value
    /// is not read anywhere in this file — confirm intended use.
    pub n_jobs: usize,
    /// Also compute scores on the training folds (overfitting diagnostics).
    pub return_train_score: bool,
    /// Print per-fold progress to stdout.
    pub verbose: bool,
    /// Extra fit parameters keyed by name.
    /// NOTE(review): not consumed anywhere in this file — confirm intended use.
    pub fit_params: HashMap<String, f64>,
}
100
impl Default for CrossValidationConfig {
    /// Defaults: shuffled 5-fold CV scored by kernel alignment, one job per
    /// logical CPU, no train scores, quiet output, unseeded RNG.
    fn default() -> Self {
        Self {
            cv_strategy: CVStrategy::KFold {
                n_folds: 5,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            random_seed: None,
            n_jobs: num_cpus::get(),
            return_train_score: false,
            verbose: false,
            fit_params: HashMap::new(),
        }
    }
}
117
/// Aggregated outcome of one cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Per-fold test scores, in fold order.
    pub test_scores: Vec<f64>,
    /// Per-fold training scores (only when `return_train_score` was set).
    pub train_scores: Option<Vec<f64>>,
    /// Mean of `test_scores`.
    pub mean_test_score: f64,
    /// Population standard deviation of `test_scores`.
    pub std_test_score: f64,
    /// Mean of `train_scores`, when available.
    pub mean_train_score: Option<f64>,
    /// Population standard deviation of `train_scores`, when available.
    pub std_train_score: Option<f64>,
    /// Per-fold wall-clock fit durations, in seconds.
    pub fit_times: Vec<f64>,
    /// Per-fold wall-clock scoring durations, in seconds.
    pub score_times: Vec<f64>,
}
139
/// Produces (train, test) index partitions over the rows of `x`.
pub trait CVSplitter {
    /// Return one `(train_indices, test_indices)` pair per split.
    /// `y` is accepted for splitters that need targets (e.g. stratification);
    /// the splitters in this file ignore it.
    fn split(&self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)>;
}
145
/// K-fold splitter with optional, optionally-seeded shuffling.
pub struct KFoldSplitter {
    /// Number of folds to produce.
    n_folds: usize,
    /// Whether to shuffle row order before folding.
    shuffle: bool,
    /// Seed for reproducible shuffles; `None` uses thread-local entropy.
    random_seed: Option<u64>,
}
152
impl KFoldSplitter {
    /// Create a k-fold splitter. A `Some(seed)` makes shuffles reproducible.
    pub fn new(n_folds: usize, shuffle: bool, random_seed: Option<u64>) -> Self {
        Self {
            n_folds,
            shuffle,
            random_seed,
        }
    }
}
162
163impl CVSplitter for KFoldSplitter {
164 fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
165 let n_samples = x.nrows();
166 let mut indices: Vec<usize> = (0..n_samples).collect();
167
168 if self.shuffle {
169 let mut rng = if let Some(seed) = self.random_seed {
170 StdRng::seed_from_u64(seed)
171 } else {
172 StdRng::from_seed(thread_rng().gen())
173 };
174
175 indices.shuffle(&mut rng);
176 }
177
178 let fold_size = n_samples / self.n_folds;
179 let mut splits = Vec::new();
180
181 for fold in 0..self.n_folds {
182 let start = fold * fold_size;
183 let end = if fold == self.n_folds - 1 {
184 n_samples
185 } else {
186 (fold + 1) * fold_size
187 };
188
189 let test_indices = indices[start..end].to_vec();
190 let train_indices = indices[..start]
191 .iter()
192 .chain(indices[end..].iter())
193 .cloned()
194 .collect();
195
196 splits.push((train_indices, test_indices));
197 }
198
199 splits
200 }
201}
202
/// Walk-forward splitter for ordered (time-series) data.
pub struct TimeSeriesSplitter {
    /// Number of (train, test) splits to produce.
    n_splits: usize,
    /// Optional cap on the number of most recent training samples per split.
    max_train_size: Option<usize>,
}
208
impl TimeSeriesSplitter {
    /// Create a walk-forward splitter; `max_train_size` limits training history.
    pub fn new(n_splits: usize, max_train_size: Option<usize>) -> Self {
        Self {
            n_splits,
            max_train_size,
        }
    }
}
217
218impl CVSplitter for TimeSeriesSplitter {
219 fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
220 let n_samples = x.nrows();
221 let test_size = n_samples / (self.n_splits + 1);
222 let mut splits = Vec::new();
223
224 for split in 0..self.n_splits {
225 let test_start = (split + 1) * test_size;
226 let test_end = if split == self.n_splits - 1 {
227 n_samples
228 } else {
229 (split + 2) * test_size
230 };
231
232 let train_end = test_start;
233 let train_start = if let Some(max_size) = self.max_train_size {
234 train_end.saturating_sub(max_size)
235 } else {
236 0
237 };
238
239 let train_indices = (train_start..train_end).collect();
240 let test_indices = (test_start..test_end).collect();
241
242 splits.push((train_indices, test_indices));
243 }
244
245 splits
246 }
247}
248
/// Repeated random subsampling (shuffle-split) splitter.
pub struct MonteCarloCVSplitter {
    /// Number of independent random splits to draw.
    n_splits: usize,
    /// Fraction of samples placed in each test set.
    test_size: f64,
    /// Seed for reproducible draws; `None` uses thread-local entropy.
    random_seed: Option<u64>,
}
255
impl MonteCarloCVSplitter {
    /// Create a shuffle-split splitter; `test_size` is a fraction in (0, 1).
    pub fn new(n_splits: usize, test_size: f64, random_seed: Option<u64>) -> Self {
        Self {
            n_splits,
            test_size,
            random_seed,
        }
    }
}
265
266impl CVSplitter for MonteCarloCVSplitter {
267 fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
268 let n_samples = x.nrows();
269 let test_samples = (n_samples as f64 * self.test_size) as usize;
270 let mut rng = if let Some(seed) = self.random_seed {
271 StdRng::seed_from_u64(seed)
272 } else {
273 StdRng::from_seed(thread_rng().gen())
274 };
275
276 let mut splits = Vec::new();
277
278 for _ in 0..self.n_splits {
279 let mut indices: Vec<usize> = (0..n_samples).collect();
280
281 indices.shuffle(&mut rng);
282
283 let test_indices = indices[..test_samples].to_vec();
284 let train_indices = indices[test_samples..].to_vec();
285
286 splits.push((train_indices, test_indices));
287 }
288
289 splits
290 }
291}
292
/// Runs cross-validation of kernel-approximation transformers
/// (`RBFSampler`, `Nystroem`) according to a `CrossValidationConfig`.
pub struct CrossValidator {
    /// Immutable configuration shared by all folds.
    config: CrossValidationConfig,
}
297
298impl CrossValidator {
    /// Create a validator that evaluates folds according to `config`.
    pub fn new(config: CrossValidationConfig) -> Self {
        Self { config }
    }
303
    /// Cross-validate an `RBFSampler` random-feature approximation with the
    /// fixed `parameters` (gamma, n_components).
    ///
    /// Each fold fits the sampler on its training rows only, transforms both
    /// partitions, and scores them with the configured metric. Targets `y`
    /// are optional and only needed for supervised metrics.
    ///
    /// # Errors
    /// Propagates fit/transform/scoring errors from any fold.
    pub fn cross_validate_rbf(
        &self,
        x: &Array2<f64>,
        y: Option<&Array1<f64>>,
        parameters: &ParameterSet,
    ) -> Result<CrossValidationResult> {
        let splitter = self.create_splitter()?;
        let splits = splitter.split(x, y);

        if self.config.verbose {
            println!("Performing cross-validation with {} splits", splits.len());
        }

        // Folds are independent, so evaluate them in parallel on rayon's
        // global pool; `collect` into Result short-circuits on the first error.
        let fold_results: Result<Vec<_>> = splits
            .par_iter()
            .enumerate()
            .map(|(fold_idx, (train_indices, test_indices))| {
                let start_time = std::time::Instant::now();

                let x_train = self.extract_samples(x, train_indices);
                let x_test = self.extract_samples(x, test_indices);
                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));

                // Fit the random-feature map on the training rows only.
                let sampler = RBFSampler::new(parameters.n_components).gamma(parameters.gamma);
                let fitted = sampler.fit(&x_train, &())?;
                // NOTE(review): fit_time also includes the index extraction above.
                let fit_time = start_time.elapsed().as_secs_f64();

                let x_train_transformed = fitted.transform(&x_train)?;
                let x_test_transformed = fitted.transform(&x_test)?;

                let score_start = std::time::Instant::now();
                let test_score = self.compute_score(
                    &x_test,
                    &x_test_transformed,
                    y_test.as_ref(),
                    parameters.gamma,
                )?;

                // Optional training-fold score for overfitting diagnostics.
                let train_score = if self.config.return_train_score {
                    Some(self.compute_score(
                        &x_train,
                        &x_train_transformed,
                        y_train.as_ref(),
                        parameters.gamma,
                    )?)
                } else {
                    None
                };

                let score_time = score_start.elapsed().as_secs_f64();

                if self.config.verbose {
                    println!(
                        "Fold {}: test_score = {:.6}, fit_time = {:.3}s",
                        fold_idx, test_score, fit_time
                    );
                }

                Ok((test_score, train_score, fit_time, score_time))
            })
            .collect();

        let fold_results = fold_results?;

        self.aggregate_results(fold_results)
    }
377
    /// Cross-validate a `Nystroem` RBF-kernel approximation with the fixed
    /// `parameters`. Mirrors `cross_validate_rbf` but fits a Nystroem
    /// transformer per fold.
    ///
    /// # Errors
    /// Propagates fit/transform/scoring errors from any fold.
    pub fn cross_validate_nystroem(
        &self,
        x: &Array2<f64>,
        y: Option<&Array1<f64>>,
        parameters: &ParameterSet,
    ) -> Result<CrossValidationResult> {
        use crate::nystroem::Kernel;

        let splitter = self.create_splitter()?;
        let splits = splitter.split(x, y);

        // Folds are independent; evaluate them in parallel.
        let fold_results: Result<Vec<_>> = splits
            .par_iter()
            .enumerate()
            .map(|(fold_idx, (train_indices, test_indices))| {
                let start_time = std::time::Instant::now();

                let x_train = self.extract_samples(x, train_indices);
                let x_test = self.extract_samples(x, test_indices);
                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));

                // Fit the Nystroem landmarks on the training rows only.
                let kernel = Kernel::Rbf {
                    gamma: parameters.gamma,
                };
                let nystroem = Nystroem::new(kernel, parameters.n_components);
                let fitted = nystroem.fit(&x_train, &())?;
                // NOTE(review): fit_time also includes the index extraction above.
                let fit_time = start_time.elapsed().as_secs_f64();

                let x_train_transformed = fitted.transform(&x_train)?;
                let x_test_transformed = fitted.transform(&x_test)?;

                let score_start = std::time::Instant::now();
                let test_score = self.compute_score(
                    &x_test,
                    &x_test_transformed,
                    y_test.as_ref(),
                    parameters.gamma,
                )?;

                // Optional training-fold score for overfitting diagnostics.
                let train_score = if self.config.return_train_score {
                    Some(self.compute_score(
                        &x_train,
                        &x_train_transformed,
                        y_train.as_ref(),
                        parameters.gamma,
                    )?)
                } else {
                    None
                };

                let score_time = score_start.elapsed().as_secs_f64();

                if self.config.verbose {
                    println!("Fold {}: test_score = {:.6}", fold_idx, test_score);
                }

                Ok((test_score, train_score, fit_time, score_time))
            })
            .collect();

        let fold_results = fold_results?;
        self.aggregate_results(fold_results)
    }
447
448 pub fn cross_validate_with_search(
450 &self,
451 x: &Array2<f64>,
452 y: Option<&Array1<f64>>,
453 parameter_learner: &ParameterLearner,
454 ) -> Result<(ParameterSet, CrossValidationResult)> {
455 let optimization_result = parameter_learner.optimize_rbf_parameters(x, y)?;
457 let best_params = optimization_result.best_parameters;
458
459 if self.config.verbose {
460 println!(
461 "Best parameters found: gamma={:.6}, n_components={}",
462 best_params.gamma, best_params.n_components
463 );
464 }
465
466 let cv_result = self.cross_validate_rbf(x, y, &best_params)?;
468
469 Ok((best_params, cv_result))
470 }
471
472 pub fn grid_search_cv(
474 &self,
475 x: &Array2<f64>,
476 y: Option<&Array1<f64>>,
477 param_grid: &HashMap<String, Vec<f64>>,
478 ) -> Result<(
479 ParameterSet,
480 f64,
481 HashMap<ParameterSet, CrossValidationResult>,
482 )> {
483 let gamma_values = param_grid.get("gamma").ok_or_else(|| {
484 SklearsError::InvalidInput("gamma parameter missing from grid".to_string())
485 })?;
486
487 let n_components_values = param_grid
488 .get("n_components")
489 .ok_or_else(|| {
490 SklearsError::InvalidInput("n_components parameter missing from grid".to_string())
491 })?
492 .iter()
493 .map(|&x| x as usize)
494 .collect::<Vec<_>>();
495
496 let mut best_score = f64::NEG_INFINITY;
497 let mut best_params = ParameterSet {
498 gamma: gamma_values[0],
499 n_components: n_components_values[0],
500 degree: None,
501 coef0: None,
502 };
503 let mut all_results = HashMap::new();
504
505 if self.config.verbose {
506 println!(
507 "Grid search over {} parameter combinations",
508 gamma_values.len() * n_components_values.len()
509 );
510 }
511
512 for &gamma in gamma_values {
513 for &n_components in &n_components_values {
514 let params = ParameterSet {
515 gamma,
516 n_components,
517 degree: None,
518 coef0: None,
519 };
520
521 let cv_result = self.cross_validate_rbf(x, y, ¶ms)?;
522 let mean_score = cv_result.mean_test_score;
523
524 all_results.insert(params.clone(), cv_result);
525
526 if mean_score > best_score {
527 best_score = mean_score;
528 best_params = params;
529 }
530
531 if self.config.verbose {
532 println!(
533 "gamma={:.6}, n_components={}: score={:.6} ± {:.6}",
534 gamma,
535 n_components,
536 mean_score,
537 all_results
538 .get(&ParameterSet {
539 gamma,
540 n_components,
541 degree: None,
542 coef0: None
543 })
544 .unwrap()
545 .std_test_score
546 );
547 }
548 }
549 }
550
551 Ok((best_params, best_score, all_results))
552 }
553
    /// Build the splitter matching `self.config.cv_strategy`.
    ///
    /// Only `KFold`, `TimeSeriesSplit` and `MonteCarlo` have dedicated
    /// implementations; every other strategy (StratifiedKFold, LeaveOneOut,
    /// LeavePOut) silently falls back to a shuffled 5-fold split.
    /// NOTE(review): consider surfacing this fallback to callers instead of
    /// substituting silently.
    fn create_splitter(&self) -> Result<Box<dyn CVSplitter + Send + Sync>> {
        match &self.config.cv_strategy {
            CVStrategy::KFold { n_folds, shuffle } => Ok(Box::new(KFoldSplitter::new(
                *n_folds,
                *shuffle,
                self.config.random_seed,
            ))),
            CVStrategy::TimeSeriesSplit {
                n_splits,
                max_train_size,
            } => Ok(Box::new(TimeSeriesSplitter::new(
                *n_splits,
                *max_train_size,
            ))),
            CVStrategy::MonteCarlo {
                n_splits,
                test_size,
            } => Ok(Box::new(MonteCarloCVSplitter::new(
                *n_splits,
                *test_size,
                self.config.random_seed,
            ))),
            _ => {
                // Fallback for strategies without a dedicated splitter yet.
                Ok(Box::new(KFoldSplitter::new(
                    5,
                    true,
                    self.config.random_seed,
                )))
            }
        }
    }
586
587 fn extract_samples(&self, x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
588 let n_features = x.ncols();
589 let mut result = Array2::zeros((indices.len(), n_features));
590
591 for (i, &idx) in indices.iter().enumerate() {
592 result.row_mut(i).assign(&x.row(idx));
593 }
594
595 result
596 }
597
598 fn extract_targets(&self, y: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
599 let mut result = Array1::zeros(indices.len());
600
601 for (i, &idx) in indices.iter().enumerate() {
602 result[i] = y[idx];
603 }
604
605 result
606 }
607
608 fn compute_score(
609 &self,
610 x: &Array2<f64>,
611 x_transformed: &Array2<f64>,
612 y: Option<&Array1<f64>>,
613 gamma: f64,
614 ) -> Result<f64> {
615 match &self.config.scoring_metric {
616 ScoringMetric::KernelAlignment => {
617 self.compute_kernel_alignment(x, x_transformed, gamma)
618 }
619 ScoringMetric::MeanSquaredError => {
620 if let Some(y_data) = y {
621 self.compute_mse(x_transformed, y_data)
622 } else {
623 Err(SklearsError::InvalidInput(
624 "Target values required for MSE".to_string(),
625 ))
626 }
627 }
628 ScoringMetric::MeanAbsoluteError => {
629 if let Some(y_data) = y {
630 self.compute_mae(x_transformed, y_data)
631 } else {
632 Err(SklearsError::InvalidInput(
633 "Target values required for MAE".to_string(),
634 ))
635 }
636 }
637 ScoringMetric::R2Score => {
638 if let Some(y_data) = y {
639 self.compute_r2_score(x_transformed, y_data)
640 } else {
641 Err(SklearsError::InvalidInput(
642 "Target values required for R²".to_string(),
643 ))
644 }
645 }
646 _ => {
647 self.compute_kernel_alignment(x, x_transformed, gamma)
649 }
650 }
651 }
652
    /// Kernel alignment between the exact RBF Gram matrix and the Gram matrix
    /// implied by the approximate feature map: <K, K̃>_F / (‖K‖_F · ‖K̃‖_F).
    ///
    /// Only the first `min(n, 50)` rows are used, to keep the O(n²) exact
    /// kernel computation cheap.
    fn compute_kernel_alignment(
        &self,
        x: &Array2<f64>,
        x_transformed: &Array2<f64>,
        gamma: f64,
    ) -> Result<f64> {
        // Cap the subset size so the double loop below stays at most 50×50.
        let n_samples = x.nrows().min(50); let x_subset = x.slice(s![..n_samples, ..]);

        // Exact RBF Gram matrix: K[i, j] = exp(-gamma * ||x_i - x_j||²).
        let mut k_exact = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in 0..n_samples {
                let diff = &x_subset.row(i) - &x_subset.row(j);
                let squared_norm = diff.dot(&diff);
                k_exact[[i, j]] = (-gamma * squared_norm).exp();
            }
        }

        // Approximate Gram matrix from the feature map: K̃ = Z Zᵀ.
        let x_transformed_subset = x_transformed.slice(s![..n_samples, ..]);
        let k_approx = x_transformed_subset.dot(&x_transformed_subset.t());

        let k_exact_frobenius = k_exact.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let k_approx_frobenius = k_approx.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let k_product = (&k_exact * &k_approx).sum();

        // NOTE(review): returns NaN when either Frobenius norm is zero
        // (e.g. an all-zero transform) — confirm inputs are non-degenerate.
        let alignment = k_product / (k_exact_frobenius * k_approx_frobenius);
        Ok(alignment)
    }
684
685 fn compute_mse(&self, x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
686 let y_mean = y.mean().unwrap_or(0.0);
689 let mse = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / y.len() as f64;
690 Ok(-mse) }
692
693 fn compute_mae(&self, x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
694 let y_mean = y.mean().unwrap_or(0.0);
696 let mae = y.iter().map(|&yi| (yi - y_mean).abs()).sum::<f64>() / y.len() as f64;
697 Ok(-mae) }
699
700 fn compute_r2_score(&self, x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
701 let y_mean = y.mean().unwrap_or(0.0);
703 let ss_tot = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
704 let ss_res = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>(); let r2 = 1.0 - (ss_res / ss_tot);
707 Ok(r2)
708 }
709
710 fn aggregate_results(
711 &self,
712 fold_results: Vec<(f64, Option<f64>, f64, f64)>,
713 ) -> Result<CrossValidationResult> {
714 let test_scores: Vec<f64> = fold_results.iter().map(|(score, _, _, _)| *score).collect();
715 let train_scores: Option<Vec<f64>> = if self.config.return_train_score {
716 Some(
717 fold_results
718 .iter()
719 .filter_map(|(_, train_score, _, _)| *train_score)
720 .collect(),
721 )
722 } else {
723 None
724 };
725 let fit_times: Vec<f64> = fold_results
726 .iter()
727 .map(|(_, _, fit_time, _)| *fit_time)
728 .collect();
729 let score_times: Vec<f64> = fold_results
730 .iter()
731 .map(|(_, _, _, score_time)| *score_time)
732 .collect();
733
734 let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
735 let variance_test = test_scores
736 .iter()
737 .map(|&score| (score - mean_test_score).powi(2))
738 .sum::<f64>()
739 / test_scores.len() as f64;
740 let std_test_score = variance_test.sqrt();
741
742 let (mean_train_score, std_train_score) = if let Some(ref train_scores) = train_scores {
743 let mean = train_scores.iter().sum::<f64>() / train_scores.len() as f64;
744 let variance = train_scores
745 .iter()
746 .map(|&score| (score - mean).powi(2))
747 .sum::<f64>()
748 / train_scores.len() as f64;
749 (Some(mean), Some(variance.sqrt()))
750 } else {
751 (None, None)
752 };
753
754 Ok(CrossValidationResult {
755 test_scores,
756 train_scores,
757 mean_test_score,
758 std_test_score,
759 mean_train_score,
760 std_train_score,
761 fit_times,
762 score_times,
763 })
764 }
765}
766
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Note: the original file contained HTML-entity mojibake (`¶ms` for
    // `&params`, `¶m_grid` for `&param_grid`), repaired throughout below.

    #[test]
    fn test_kfold_splitter() {
        let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as f64).collect()).unwrap();

        let splitter = KFoldSplitter::new(4, false, Some(42));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 4);

        // Every sample index must appear in exactly one test fold.
        let mut all_test_indices: Vec<usize> = Vec::new();
        for (_, test_indices) in &splits {
            all_test_indices.extend(test_indices);
        }
        all_test_indices.sort();

        let expected_indices: Vec<usize> = (0..20).collect();
        assert_eq!(all_test_indices, expected_indices);

        // Fold sizes stay roughly balanced (20 / 4 = 5, ± remainder).
        for (_, test_indices) in &splits {
            assert!(test_indices.len() >= 4);
            assert!(test_indices.len() <= 6);
        }
    }

    #[test]
    fn test_time_series_splitter() {
        let x = Array2::from_shape_vec((30, 2), (0..60).map(|i| i as f64).collect()).unwrap();

        let splitter = TimeSeriesSplitter::new(3, Some(15));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 3);

        // Training data must strictly precede test data in every split.
        for (train_indices, test_indices) in &splits {
            if !train_indices.is_empty() && !test_indices.is_empty() {
                let max_train = train_indices.iter().max().unwrap();
                let min_test = test_indices.iter().min().unwrap();
                assert!(max_train < min_test);
            }
        }
    }

    #[test]
    fn test_monte_carlo_splitter() {
        let x = Array2::from_shape_vec((50, 4), (0..200).map(|i| i as f64).collect()).unwrap();

        let splitter = MonteCarloCVSplitter::new(5, 0.3, Some(123));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 5);

        // Each split partitions all 50 samples, with a ~30% test set.
        for (train_indices, test_indices) in &splits {
            let total_size = train_indices.len() + test_indices.len();
            assert_eq!(total_size, 50);
            assert!(test_indices.len() >= 14);
            assert!(test_indices.len() <= 16);
        }
    }

    #[test]
    fn test_cross_validator_rbf() {
        let x =
            Array2::from_shape_vec((40, 5), (0..200).map(|i| i as f64 * 0.01).collect()).unwrap();

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            return_train_score: true,
            random_seed: Some(42),
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 0.5,
            n_components: 20,
            degree: None,
            coef0: None,
        };

        let result = cv.cross_validate_rbf(&x, None, &params).unwrap();

        assert_eq!(result.test_scores.len(), 3);
        assert!(result.train_scores.is_some());
        assert_eq!(result.train_scores.as_ref().unwrap().len(), 3);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
        assert!(result.mean_train_score.is_some());
        assert!(result.std_train_score.is_some());
        assert_eq!(result.fit_times.len(), 3);
        assert_eq!(result.score_times.len(), 3);
    }

    #[test]
    fn test_cross_validator_nystroem() {
        let x =
            Array2::from_shape_vec((30, 4), (0..120).map(|i| i as f64 * 0.02).collect()).unwrap();

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: false,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 1.0,
            n_components: 15,
            degree: None,
            coef0: None,
        };

        let result = cv.cross_validate_nystroem(&x, None, &params).unwrap();

        assert_eq!(result.test_scores.len(), 4);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
    }

    #[test]
    fn test_grid_search_cv() {
        let x =
            Array2::from_shape_vec((25, 3), (0..75).map(|i| i as f64 * 0.05).collect()).unwrap();

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            random_seed: Some(789),
            verbose: false,
            ..Default::default()
        };

        let cv = CrossValidator::new(config);

        let mut param_grid = HashMap::new();
        param_grid.insert("gamma".to_string(), vec![0.1, 1.0]);
        param_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let (best_params, best_score, all_results) =
            cv.grid_search_cv(&x, None, &param_grid).unwrap();

        assert!(best_score > 0.0);
        assert!(best_params.gamma == 0.1 || best_params.gamma == 1.0);
        assert!(best_params.n_components == 10 || best_params.n_components == 20);
        // 2 gammas × 2 component counts = 4 combinations evaluated.
        assert_eq!(all_results.len(), 4);

        // The reported best score must equal the maximum over all results.
        let max_score = all_results
            .values()
            .map(|result| result.mean_test_score)
            .fold(f64::NEG_INFINITY, f64::max);
        assert_abs_diff_eq!(best_score, max_score, epsilon = 1e-10);
    }

    #[test]
    fn test_cross_validation_with_targets() {
        let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as f64 * 0.1).collect()).unwrap();
        let y = Array1::from_shape_fn(20, |i| (i as f64 * 0.1).sin());

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::MeanSquaredError,
            random_seed: Some(456),
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 0.8,
            n_components: 15,
            degree: None,
            coef0: None,
        };

        let result = cv.cross_validate_rbf(&x, Some(&y), &params).unwrap();

        assert_eq!(result.test_scores.len(), 4);
        // MSE scores are negated, so the mean must be non-positive.
        assert!(result.mean_test_score <= 0.0);
    }

    #[test]
    fn test_cv_splitter_consistency() {
        let x = Array2::from_shape_vec((15, 2), (0..30).map(|i| i as f64).collect()).unwrap();

        // Identical seeds must produce identical shuffled splits.
        let splitter1 = KFoldSplitter::new(3, true, Some(42));
        let splitter2 = KFoldSplitter::new(3, true, Some(42));

        let splits1 = splitter1.split(&x, None);
        let splits2 = splitter2.split(&x, None);

        assert_eq!(splits1.len(), splits2.len());
        for (split1, split2) in splits1.iter().zip(splits2.iter()) {
            assert_eq!(split1.0, split2.0);
            assert_eq!(split1.1, split2.1);
        }
    }

    #[test]
    fn test_cross_validation_result_aggregation() {
        let config = CrossValidationConfig {
            return_train_score: true,
            ..Default::default()
        };
        let cv = CrossValidator::new(config);

        let fold_results = vec![
            (0.8, Some(0.85), 0.1, 0.05),
            (0.75, Some(0.8), 0.12, 0.04),
            (0.82, Some(0.88), 0.11, 0.06),
        ];

        let result = cv.aggregate_results(fold_results).unwrap();

        assert_abs_diff_eq!(result.mean_test_score, 0.79, epsilon = 1e-10);
        assert!(result.std_test_score > 0.0);
        assert!(result.mean_train_score.is_some());
        assert_abs_diff_eq!(
            result.mean_train_score.unwrap(),
            0.8433333333333334,
            epsilon = 1e-10
        );
    }
}