use crate::{Nystroem, ParameterLearner, ParameterSet, RBFSampler};
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Fit, Transform},
};
use std::collections::HashMap;
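/// Strategy used to split samples into training and test folds.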
#[derive(Debug, Clone)]
pub enum CVStrategy {
    /// Standard k-fold cross-validation.
    KFold {
        n_folds: usize,
        shuffle: bool,
    },
    /// K-fold with folds balanced by class label.
    StratifiedKFold {
        n_folds: usize,
        shuffle: bool,
    },
    /// Every sample serves as the test set exactly once.
    LeaveOneOut,
    /// Every combination of `p` samples serves as a test set.
    LeavePOut {
        p: usize,
    },
    /// Expanding-window splits that preserve temporal order.
    TimeSeriesSplit {
        n_splits: usize,
        max_train_size: Option<usize>,
    },
    /// Repeated random train/test splits (shuffle-split).
    MonteCarlo {
        n_splits: usize,
        test_size: f64,
    },
}
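/// Metric used to score each fold. Error-style metrics (MSE, MAE) are reported
/// negated so that a larger score is always preferred.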
#[derive(Debug, Clone)]
pub enum ScoringMetric {
    KernelAlignment,
    MeanSquaredError,
    MeanAbsoluteError,
    R2Score,
    Accuracy,
    F1Score,
    LogLikelihood,
    Custom,
}
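/// Configuration for a cross-validation run: splitting strategy, scoring metric,
/// seeding, parallelism, and optional per-fit parameters.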
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    pub cv_strategy: CVStrategy,
    pub scoring_metric: ScoringMetric,
    pub random_seed: Option<u64>,
    pub n_jobs: usize,
    pub return_train_score: bool,
    pub verbose: bool,
    pub fit_params: HashMap<String, f64>,
}
impl Default for CrossValidationConfig {
    fn default() -> Self {
        Self {
            cv_strategy: CVStrategy::KFold {
                n_folds: 5,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            random_seed: None,
            n_jobs: num_cpus::get(),
            return_train_score: false,
            verbose: false,
            fit_params: HashMap::new(),
        }
    }
}
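/// Per-fold test/train scores and timings together with their aggregate statistics.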
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    pub test_scores: Vec<f64>,
    pub train_scores: Option<Vec<f64>>,
    pub mean_test_score: f64,
    pub std_test_score: f64,
    pub mean_train_score: Option<f64>,
    pub std_train_score: Option<f64>,
    pub fit_times: Vec<f64>,
    pub score_times: Vec<f64>,
}
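/// Produces `(train_indices, test_indices)` pairs, one per fold.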
pub trait CVSplitter {
    fn split(&self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)>;
}
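/// Plain k-fold splitter: indices are optionally shuffled and then partitioned
/// into `n_folds` contiguous test blocks, the last of which absorbs any remainder.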
pub struct KFoldSplitter {
    n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
}

impl KFoldSplitter {
    pub fn new(n_folds: usize, shuffle: bool, random_seed: Option<u64>) -> Self {
        Self {
            n_folds,
            shuffle,
            random_seed,
        }
    }
}

impl CVSplitter for KFoldSplitter {
    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
        let n_samples = x.nrows();
        let mut indices: Vec<usize> = (0..n_samples).collect();

        if self.shuffle {
            let mut rng = if let Some(seed) = self.random_seed {
                StdRng::seed_from_u64(seed)
            } else {
                StdRng::from_seed(thread_rng().gen())
            };

            indices.shuffle(&mut rng);
        }

        let fold_size = n_samples / self.n_folds;
        let mut splits = Vec::new();

        for fold in 0..self.n_folds {
            let start = fold * fold_size;
            let end = if fold == self.n_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            let test_indices = indices[start..end].to_vec();
            let train_indices = indices[..start]
                .iter()
                .chain(indices[end..].iter())
                .cloned()
                .collect();

            splits.push((train_indices, test_indices));
        }

        splits
    }
}
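/// Expanding-window splitter for time series: each training window ends exactly
/// where its test window begins (optionally capped at `max_train_size` samples),
/// so temporal order is never violated.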
pub struct TimeSeriesSplitter {
    n_splits: usize,
    max_train_size: Option<usize>,
}

impl TimeSeriesSplitter {
    pub fn new(n_splits: usize, max_train_size: Option<usize>) -> Self {
        Self {
            n_splits,
            max_train_size,
        }
    }
}

impl CVSplitter for TimeSeriesSplitter {
    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
        let n_samples = x.nrows();
        let test_size = n_samples / (self.n_splits + 1);
        let mut splits = Vec::new();

        for split in 0..self.n_splits {
            let test_start = (split + 1) * test_size;
            let test_end = if split == self.n_splits - 1 {
                n_samples
            } else {
                (split + 2) * test_size
            };

            let train_end = test_start;
            let train_start = if let Some(max_size) = self.max_train_size {
                train_end.saturating_sub(max_size)
            } else {
                0
            };

            let train_indices = (train_start..train_end).collect();
            let test_indices = (test_start..test_end).collect();

            splits.push((train_indices, test_indices));
        }

        splits
    }
}
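/// Monte Carlo (shuffle-split) cross-validation: each of the `n_splits` rounds
/// draws a fresh random test set containing a `test_size` fraction of the samples.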
pub struct MonteCarloCVSplitter {
    n_splits: usize,
    test_size: f64,
    random_seed: Option<u64>,
}

impl MonteCarloCVSplitter {
    pub fn new(n_splits: usize, test_size: f64, random_seed: Option<u64>) -> Self {
        Self {
            n_splits,
            test_size,
            random_seed,
        }
    }
}

impl CVSplitter for MonteCarloCVSplitter {
    fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
        let n_samples = x.nrows();
        let test_samples = (n_samples as f64 * self.test_size) as usize;
        let mut rng = if let Some(seed) = self.random_seed {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_seed(thread_rng().gen())
        };

        let mut splits = Vec::new();

        for _ in 0..self.n_splits {
            let mut indices: Vec<usize> = (0..n_samples).collect();

            indices.shuffle(&mut rng);

            let test_indices = indices[..test_samples].to_vec();
            let train_indices = indices[test_samples..].to_vec();

            splits.push((train_indices, test_indices));
        }

        splits
    }
}
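/// Orchestrates cross-validation of kernel approximation methods (random Fourier
/// features via [`RBFSampler`] and [`Nystroem`]) according to a
/// [`CrossValidationConfig`].
///
/// A minimal usage sketch, mirroring the unit tests below (paths and data are
/// assumed, so the snippet is not compiled as a doctest):
///
/// ```ignore
/// let config = CrossValidationConfig {
///     cv_strategy: CVStrategy::KFold { n_folds: 3, shuffle: true },
///     scoring_metric: ScoringMetric::KernelAlignment,
///     random_seed: Some(42),
///     ..Default::default()
/// };
/// let cv = CrossValidator::new(config);
/// let params = ParameterSet { gamma: 0.5, n_components: 20, degree: None, coef0: None };
/// let result = cv.cross_validate_rbf(&x, None, &params)?;
/// println!("score = {:.4} ± {:.4}", result.mean_test_score, result.std_test_score);
/// ```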
pub struct CrossValidator {
    config: CrossValidationConfig,
}

impl CrossValidator {
    pub fn new(config: CrossValidationConfig) -> Self {
        Self { config }
    }

    /// Cross-validates an [`RBFSampler`] feature map fitted with the given parameters.
    pub fn cross_validate_rbf(
        &self,
        x: &Array2<f64>,
        y: Option<&Array1<f64>>,
        parameters: &ParameterSet,
    ) -> Result<CrossValidationResult> {
        let splitter = self.create_splitter()?;
        let splits = splitter.split(x, y);

        if self.config.verbose {
            println!("Performing cross-validation with {} splits", splits.len());
        }

        let fold_results: Result<Vec<_>> = splits
            .par_iter()
            .enumerate()
            .map(|(fold_idx, (train_indices, test_indices))| {
                let start_time = std::time::Instant::now();

                let x_train = self.extract_samples(x, train_indices);
                let x_test = self.extract_samples(x, test_indices);
                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));

                let sampler = RBFSampler::new(parameters.n_components).gamma(parameters.gamma);
                let fitted = sampler.fit(&x_train, &())?;
                let fit_time = start_time.elapsed().as_secs_f64();

                let x_train_transformed = fitted.transform(&x_train)?;
                let x_test_transformed = fitted.transform(&x_test)?;

                let score_start = std::time::Instant::now();
                let test_score = self.compute_score(
                    &x_test,
                    &x_test_transformed,
                    y_test.as_ref(),
                    parameters.gamma,
                )?;

                let train_score = if self.config.return_train_score {
                    Some(self.compute_score(
                        &x_train,
                        &x_train_transformed,
                        y_train.as_ref(),
                        parameters.gamma,
                    )?)
                } else {
                    None
                };

                let score_time = score_start.elapsed().as_secs_f64();

                if self.config.verbose {
                    println!(
                        "Fold {}: test_score = {:.6}, fit_time = {:.3}s",
                        fold_idx, test_score, fit_time
                    );
                }

                Ok((test_score, train_score, fit_time, score_time))
            })
            .collect();

        let fold_results = fold_results?;

        self.aggregate_results(fold_results)
    }

    /// Cross-validates a [`Nystroem`] approximation of the RBF kernel with the given parameters.
    pub fn cross_validate_nystroem(
        &self,
        x: &Array2<f64>,
        y: Option<&Array1<f64>>,
        parameters: &ParameterSet,
    ) -> Result<CrossValidationResult> {
        use crate::nystroem::Kernel;

        let splitter = self.create_splitter()?;
        let splits = splitter.split(x, y);

        let fold_results: Result<Vec<_>> = splits
            .par_iter()
            .enumerate()
            .map(|(fold_idx, (train_indices, test_indices))| {
                let start_time = std::time::Instant::now();

                let x_train = self.extract_samples(x, train_indices);
                let x_test = self.extract_samples(x, test_indices);
                let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
                let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));

                let kernel = Kernel::Rbf {
                    gamma: parameters.gamma,
                };
                let nystroem = Nystroem::new(kernel, parameters.n_components);
                let fitted = nystroem.fit(&x_train, &())?;
                let fit_time = start_time.elapsed().as_secs_f64();

                let x_train_transformed = fitted.transform(&x_train)?;
                let x_test_transformed = fitted.transform(&x_test)?;

                let score_start = std::time::Instant::now();
                let test_score = self.compute_score(
                    &x_test,
                    &x_test_transformed,
                    y_test.as_ref(),
                    parameters.gamma,
                )?;

                let train_score = if self.config.return_train_score {
                    Some(self.compute_score(
                        &x_train,
                        &x_train_transformed,
                        y_train.as_ref(),
                        parameters.gamma,
                    )?)
                } else {
                    None
                };

                let score_time = score_start.elapsed().as_secs_f64();

                if self.config.verbose {
                    println!("Fold {}: test_score = {:.6}", fold_idx, test_score);
                }

                Ok((test_score, train_score, fit_time, score_time))
            })
            .collect();

        let fold_results = fold_results?;
        self.aggregate_results(fold_results)
    }
    /// Searches for good RBF parameters with the given learner, then cross-validates the best set.
    pub fn cross_validate_with_search(
        &self,
        x: &Array2<f64>,
        y: Option<&Array1<f64>>,
        parameter_learner: &ParameterLearner,
    ) -> Result<(ParameterSet, CrossValidationResult)> {
        let optimization_result = parameter_learner.optimize_rbf_parameters(x, y)?;
        let best_params = optimization_result.best_parameters;

        if self.config.verbose {
            println!(
                "Best parameters found: gamma={:.6}, n_components={}",
                best_params.gamma, best_params.n_components
            );
        }

        let cv_result = self.cross_validate_rbf(x, y, &best_params)?;

        Ok((best_params, cv_result))
    }
    /// Exhaustive grid search over `gamma` and `n_components`, evaluating each
    /// combination with [`CrossValidator::cross_validate_rbf`].
    pub fn grid_search_cv(
        &self,
        x: &Array2<f64>,
        y: Option<&Array1<f64>>,
        param_grid: &HashMap<String, Vec<f64>>,
    ) -> Result<(
        ParameterSet,
        f64,
        HashMap<ParameterSet, CrossValidationResult>,
    )> {
        let gamma_values = param_grid.get("gamma").ok_or_else(|| {
            SklearsError::InvalidInput("gamma parameter missing from grid".to_string())
        })?;

        let n_components_values = param_grid
            .get("n_components")
            .ok_or_else(|| {
                SklearsError::InvalidInput("n_components parameter missing from grid".to_string())
            })?
            .iter()
            .map(|&v| v as usize)
            .collect::<Vec<_>>();

        let mut best_score = f64::NEG_INFINITY;
        let mut best_params = ParameterSet {
            gamma: gamma_values[0],
            n_components: n_components_values[0],
            degree: None,
            coef0: None,
        };
        let mut all_results = HashMap::new();

        if self.config.verbose {
            println!(
                "Grid search over {} parameter combinations",
                gamma_values.len() * n_components_values.len()
            );
        }

        for &gamma in gamma_values {
            for &n_components in &n_components_values {
                let params = ParameterSet {
                    gamma,
                    n_components,
                    degree: None,
                    coef0: None,
                };

                let cv_result = self.cross_validate_rbf(x, y, &params)?;
                let mean_score = cv_result.mean_test_score;
                let std_score = cv_result.std_test_score;

                all_results.insert(params.clone(), cv_result);

                if mean_score > best_score {
                    best_score = mean_score;
                    best_params = params;
                }

                if self.config.verbose {
                    println!(
                        "gamma={:.6}, n_components={}: score={:.6} ± {:.6}",
                        gamma, n_components, mean_score, std_score
                    );
                }
            }
        }

        Ok((best_params, best_score, all_results))
    }
    fn create_splitter(&self) -> Result<Box<dyn CVSplitter + Send + Sync>> {
        match &self.config.cv_strategy {
            CVStrategy::KFold { n_folds, shuffle } => Ok(Box::new(KFoldSplitter::new(
                *n_folds,
                *shuffle,
                self.config.random_seed,
            ))),
            CVStrategy::TimeSeriesSplit {
                n_splits,
                max_train_size,
            } => Ok(Box::new(TimeSeriesSplitter::new(
                *n_splits,
                *max_train_size,
            ))),
            CVStrategy::MonteCarlo {
                n_splits,
                test_size,
            } => Ok(Box::new(MonteCarloCVSplitter::new(
                *n_splits,
                *test_size,
                self.config.random_seed,
            ))),
            _ => {
                // Strategies without a dedicated splitter fall back to shuffled 5-fold CV.
                Ok(Box::new(KFoldSplitter::new(
                    5,
                    true,
                    self.config.random_seed,
                )))
            }
        }
    }
    fn extract_samples(&self, x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
        let n_features = x.ncols();
        let mut result = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            result.row_mut(i).assign(&x.row(idx));
        }

        result
    }

    fn extract_targets(&self, y: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
        let mut result = Array1::zeros(indices.len());

        for (i, &idx) in indices.iter().enumerate() {
            result[i] = y[idx];
        }

        result
    }
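    /// Dispatches to the configured [`ScoringMetric`]. Scores follow a
    /// "higher is better" convention; error metrics (MSE, MAE) are returned negated.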
    fn compute_score(
        &self,
        x: &Array2<f64>,
        x_transformed: &Array2<f64>,
        y: Option<&Array1<f64>>,
        gamma: f64,
    ) -> Result<f64> {
        match &self.config.scoring_metric {
            ScoringMetric::KernelAlignment => {
                self.compute_kernel_alignment(x, x_transformed, gamma)
            }
            ScoringMetric::MeanSquaredError => {
                if let Some(y_data) = y {
                    self.compute_mse(x_transformed, y_data)
                } else {
                    Err(SklearsError::InvalidInput(
                        "Target values required for MSE".to_string(),
                    ))
                }
            }
            ScoringMetric::MeanAbsoluteError => {
                if let Some(y_data) = y {
                    self.compute_mae(x_transformed, y_data)
                } else {
                    Err(SklearsError::InvalidInput(
                        "Target values required for MAE".to_string(),
                    ))
                }
            }
            ScoringMetric::R2Score => {
                if let Some(y_data) = y {
                    self.compute_r2_score(x_transformed, y_data)
                } else {
                    Err(SklearsError::InvalidInput(
                        "Target values required for R²".to_string(),
                    ))
                }
            }
            _ => {
                // Metrics without a dedicated implementation fall back to kernel alignment.
                self.compute_kernel_alignment(x, x_transformed, gamma)
            }
        }
    }
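    /// Empirical alignment between the exact RBF kernel `K` on a subset of samples
    /// and the approximate kernel `K̃ = Z Zᵀ` reconstructed from the transformed
    /// features `Z`:
    ///
    /// `A(K, K̃) = <K, K̃>_F / (‖K‖_F · ‖K̃‖_F)`
    ///
    /// Values close to 1 mean the feature map preserves the kernel well.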
    fn compute_kernel_alignment(
        &self,
        x: &Array2<f64>,
        x_transformed: &Array2<f64>,
        gamma: f64,
    ) -> Result<f64> {
        // Cap the comparison at 50 samples to keep the exact O(n²) kernel cheap.
        let n_samples = x.nrows().min(50);
        let x_subset = x.slice(s![..n_samples, ..]);

        // Exact RBF kernel on the subset.
        let mut k_exact = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            for j in 0..n_samples {
                let diff = &x_subset.row(i) - &x_subset.row(j);
                let squared_norm = diff.dot(&diff);
                k_exact[[i, j]] = (-gamma * squared_norm).exp();
            }
        }

        // Approximate kernel reconstructed from the transformed features.
        let x_transformed_subset = x_transformed.slice(s![..n_samples, ..]);
        let k_approx = x_transformed_subset.dot(&x_transformed_subset.t());

        let k_exact_frobenius = k_exact.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let k_approx_frobenius = k_approx.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let k_product = (&k_exact * &k_approx).sum();

        let alignment = k_product / (k_exact_frobenius * k_approx_frobenius);
        Ok(alignment)
    }
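    // NOTE: the three metrics below are simplified placeholders: they ignore the
    // transformed features (`_x_transformed`) and score the targets against their
    // own mean (a constant baseline) instead of fitting a downstream model.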
    fn compute_mse(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
        // Variance of the targets around their mean, negated so larger scores are better.
        let y_mean = y.mean().unwrap_or(0.0);
        let mse = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / y.len() as f64;
        Ok(-mse)
    }

    fn compute_mae(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
        // Mean absolute deviation of the targets, negated for the same reason as MSE.
        let y_mean = y.mean().unwrap_or(0.0);
        let mae = y.iter().map(|&yi| (yi - y_mean).abs()).sum::<f64>() / y.len() as f64;
        Ok(-mae)
    }

    fn compute_r2_score(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
        // Placeholder R²: the "residual" sum mirrors the total sum of squares,
        // so this evaluates to 0 whenever the target variance is non-zero.
        let y_mean = y.mean().unwrap_or(0.0);
        let ss_tot = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
        let ss_res = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
        let r2 = 1.0 - (ss_res / ss_tot);
        Ok(r2)
    }
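    /// Aggregates per-fold `(test_score, train_score, fit_time, score_time)` tuples
    /// into per-metric vectors plus means and population standard deviations.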
    fn aggregate_results(
        &self,
        fold_results: Vec<(f64, Option<f64>, f64, f64)>,
    ) -> Result<CrossValidationResult> {
        let test_scores: Vec<f64> = fold_results.iter().map(|(score, _, _, _)| *score).collect();
        let train_scores: Option<Vec<f64>> = if self.config.return_train_score {
            Some(
                fold_results
                    .iter()
                    .filter_map(|(_, train_score, _, _)| *train_score)
                    .collect(),
            )
        } else {
            None
        };
        let fit_times: Vec<f64> = fold_results
            .iter()
            .map(|(_, _, fit_time, _)| *fit_time)
            .collect();
        let score_times: Vec<f64> = fold_results
            .iter()
            .map(|(_, _, _, score_time)| *score_time)
            .collect();

        let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
        let variance_test = test_scores
            .iter()
            .map(|&score| (score - mean_test_score).powi(2))
            .sum::<f64>()
            / test_scores.len() as f64;
        let std_test_score = variance_test.sqrt();

        let (mean_train_score, std_train_score) = if let Some(ref train_scores) = train_scores {
            let mean = train_scores.iter().sum::<f64>() / train_scores.len() as f64;
            let variance = train_scores
                .iter()
                .map(|&score| (score - mean).powi(2))
                .sum::<f64>()
                / train_scores.len() as f64;
            (Some(mean), Some(variance.sqrt()))
        } else {
            (None, None)
        };

        Ok(CrossValidationResult {
            test_scores,
            train_scores,
            mean_test_score,
            std_test_score,
            mean_train_score,
            std_train_score,
            fit_times,
            score_times,
        })
    }
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_kfold_splitter() {
        let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as f64).collect()).unwrap();

        let splitter = KFoldSplitter::new(4, false, Some(42));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 4);

        let mut all_test_indices: Vec<usize> = Vec::new();
        for (_, test_indices) in &splits {
            all_test_indices.extend(test_indices);
        }
        all_test_indices.sort();

        let expected_indices: Vec<usize> = (0..20).collect();
        assert_eq!(all_test_indices, expected_indices);

        for (_, test_indices) in &splits {
            assert!(test_indices.len() >= 4);
            assert!(test_indices.len() <= 6);
        }
    }

    #[test]
    fn test_time_series_splitter() {
        let x = Array2::from_shape_vec((30, 2), (0..60).map(|i| i as f64).collect()).unwrap();

        let splitter = TimeSeriesSplitter::new(3, Some(15));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 3);

        for (train_indices, test_indices) in &splits {
            if !train_indices.is_empty() && !test_indices.is_empty() {
                let max_train = train_indices.iter().max().unwrap();
                let min_test = test_indices.iter().min().unwrap();
                assert!(max_train < min_test);
            }
        }
    }

    #[test]
    fn test_monte_carlo_splitter() {
        let x = Array2::from_shape_vec((50, 4), (0..200).map(|i| i as f64).collect()).unwrap();

        let splitter = MonteCarloCVSplitter::new(5, 0.3, Some(123));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 5);

        for (train_indices, test_indices) in &splits {
            let total_size = train_indices.len() + test_indices.len();
            assert_eq!(total_size, 50);
            // 30% of 50 samples is 15 test samples; allow a little rounding slack.
            assert!(test_indices.len() >= 14);
            assert!(test_indices.len() <= 16);
        }
    }
    #[test]
    fn test_cross_validator_rbf() {
        let x =
            Array2::from_shape_vec((40, 5), (0..200).map(|i| i as f64 * 0.01).collect()).unwrap();

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            return_train_score: true,
            random_seed: Some(42),
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 0.5,
            n_components: 20,
            degree: None,
            coef0: None,
        };

        let result = cv.cross_validate_rbf(&x, None, &params).unwrap();

        assert_eq!(result.test_scores.len(), 3);
        assert!(result.train_scores.is_some());
        assert_eq!(result.train_scores.as_ref().unwrap().len(), 3);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
        assert!(result.mean_train_score.is_some());
        assert!(result.std_train_score.is_some());
        assert_eq!(result.fit_times.len(), 3);
        assert_eq!(result.score_times.len(), 3);
    }

    #[test]
    fn test_cross_validator_nystroem() {
        let x =
            Array2::from_shape_vec((30, 4), (0..120).map(|i| i as f64 * 0.02).collect()).unwrap();

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: false,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 1.0,
            n_components: 15,
            degree: None,
            coef0: None,
        };

        let result = cv.cross_validate_nystroem(&x, None, &params).unwrap();

        assert_eq!(result.test_scores.len(), 4);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
    }

    #[test]
    fn test_grid_search_cv() {
        let x =
            Array2::from_shape_vec((25, 3), (0..75).map(|i| i as f64 * 0.05).collect()).unwrap();

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            random_seed: Some(789),
            verbose: false,
            ..Default::default()
        };

        let cv = CrossValidator::new(config);

        let mut param_grid = HashMap::new();
        param_grid.insert("gamma".to_string(), vec![0.1, 1.0]);
        param_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let (best_params, best_score, all_results) =
            cv.grid_search_cv(&x, None, &param_grid).unwrap();

        assert!(best_score > 0.0);
        assert!(best_params.gamma == 0.1 || best_params.gamma == 1.0);
        assert!(best_params.n_components == 10 || best_params.n_components == 20);
        assert_eq!(all_results.len(), 4);

        let max_score = all_results
            .values()
            .map(|result| result.mean_test_score)
            .fold(f64::NEG_INFINITY, f64::max);
        assert_abs_diff_eq!(best_score, max_score, epsilon = 1e-10);
    }
    #[test]
    fn test_cross_validation_with_targets() {
        let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as f64 * 0.1).collect()).unwrap();
        let y = Array1::from_shape_fn(20, |i| (i as f64 * 0.1).sin());

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::MeanSquaredError,
            random_seed: Some(456),
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 0.8,
            n_components: 15,
            degree: None,
            coef0: None,
        };

        let result = cv.cross_validate_rbf(&x, Some(&y), &params).unwrap();

        assert_eq!(result.test_scores.len(), 4);
        // MSE is reported negated, so scores are at most zero.
        assert!(result.mean_test_score <= 0.0);
    }

    #[test]
    fn test_cv_splitter_consistency() {
        let x = Array2::from_shape_vec((15, 2), (0..30).map(|i| i as f64).collect()).unwrap();

        // Two splitters seeded identically must produce identical folds.
        let splitter1 = KFoldSplitter::new(3, true, Some(42));
        let splitter2 = KFoldSplitter::new(3, true, Some(42));

        let splits1 = splitter1.split(&x, None);
        let splits2 = splitter2.split(&x, None);

        assert_eq!(splits1.len(), splits2.len());
        for (split1, split2) in splits1.iter().zip(splits2.iter()) {
            assert_eq!(split1.0, split2.0);
            assert_eq!(split1.1, split2.1);
        }
    }

    #[test]
    fn test_cross_validation_result_aggregation() {
        let config = CrossValidationConfig {
            return_train_score: true,
            ..Default::default()
        };
        let cv = CrossValidator::new(config);

        let fold_results = vec![
            (0.8, Some(0.85), 0.1, 0.05),
            (0.75, Some(0.8), 0.12, 0.04),
            (0.82, Some(0.88), 0.11, 0.06),
        ];

        let result = cv.aggregate_results(fold_results).unwrap();

        assert_abs_diff_eq!(result.mean_test_score, 0.79, epsilon = 1e-10);
        assert!(result.std_test_score > 0.0);
        assert!(result.mean_train_score.is_some());
        assert_abs_diff_eq!(
            result.mean_train_score.unwrap(),
            0.8433333333333334,
            epsilon = 1e-10
        );
    }
}