1use crate::{Nystroem, ParameterLearner, ParameterSet, RBFSampler};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Fit, Transform},
15};
16use std::collections::HashMap;
17
/// Strategy for partitioning samples into cross-validation folds.
#[derive(Debug, Clone)]
pub enum CVStrategy {
    /// Classic k-fold split into contiguous, roughly equal folds.
    KFold {
        /// Number of folds; each fold serves as the test set exactly once.
        n_folds: usize,
        /// Whether to shuffle sample order before folding.
        shuffle: bool,
    },
    /// K-fold intended to preserve class proportions.
    /// NOTE(review): `create_splitter` currently falls back to a default
    /// KFold for this variant — stratification is not implemented yet.
    StratifiedKFold {
        /// Number of folds.
        n_folds: usize,
        /// Whether to shuffle before folding.
        shuffle: bool,
    },
    /// One held-out sample per fold.
    /// NOTE(review): currently falls back to a default KFold in `create_splitter`.
    LeaveOneOut,
    /// Hold out every combination of `p` samples.
    /// NOTE(review): currently falls back to a default KFold in `create_splitter`.
    LeavePOut {
        /// Number of held-out samples per fold.
        p: usize,
    },
    /// Forward-chaining splits for ordered data: training indices always
    /// precede test indices.
    TimeSeriesSplit {
        /// Number of train/test splits to generate.
        n_splits: usize,
        /// Optional cap on the size of the training window.
        max_train_size: Option<usize>,
    },
    /// Repeated random (shuffle-split) partitions.
    MonteCarlo {
        /// Number of independent random splits to draw.
        n_splits: usize,
        /// Fraction of samples assigned to the test set (expected in 0.0..1.0).
        test_size: f64,
    },
}
58
/// Metric used to score each cross-validation fold.
#[derive(Debug, Clone)]
pub enum ScoringMetric {
    /// Alignment between the exact RBF kernel matrix and its low-rank
    /// approximation (unsupervised; the default).
    KernelAlignment,
    /// Negated mean squared error (requires targets).
    MeanSquaredError,
    /// Negated mean absolute error (requires targets).
    MeanAbsoluteError,
    /// Coefficient of determination (requires targets).
    R2Score,
    /// Classification accuracy.
    /// NOTE(review): currently scored via the kernel-alignment fallback
    /// in `compute_score` — not implemented as accuracy.
    Accuracy,
    /// F1 score.
    /// NOTE(review): currently scored via the kernel-alignment fallback.
    F1Score,
    /// Log-likelihood.
    /// NOTE(review): currently scored via the kernel-alignment fallback.
    LogLikelihood,
    /// User-supplied metric.
    /// NOTE(review): currently scored via the kernel-alignment fallback.
    Custom,
}
80
/// Configuration for a cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// How samples are partitioned into folds.
    pub cv_strategy: CVStrategy,
    /// Metric used to score each fold.
    pub scoring_metric: ScoringMetric,
    /// Seed for fold shuffling/sampling; `None` means non-deterministic.
    pub random_seed: Option<u64>,
    /// Requested degree of parallelism.
    /// NOTE(review): not consumed by `CrossValidator` itself — folds run on
    /// rayon's global pool; confirm whether this should configure the pool.
    pub n_jobs: usize,
    /// Also compute scores on the training partition of each fold.
    pub return_train_score: bool,
    /// Print per-fold progress to stdout.
    pub verbose: bool,
    /// Extra named fit parameters.
    /// NOTE(review): currently unused by the fold loops — confirm intent.
    pub fit_params: HashMap<String, f64>,
}
100
101impl Default for CrossValidationConfig {
102 fn default() -> Self {
103 Self {
104 cv_strategy: CVStrategy::KFold {
105 n_folds: 5,
106 shuffle: true,
107 },
108 scoring_metric: ScoringMetric::KernelAlignment,
109 random_seed: None,
110 n_jobs: num_cpus::get(),
111 return_train_score: false,
112 verbose: false,
113 fit_params: HashMap::new(),
114 }
115 }
116}
117
/// Aggregated outcome of a cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Per-fold scores on the held-out test partitions.
    pub test_scores: Vec<f64>,
    /// Per-fold training scores (present only when requested).
    pub train_scores: Option<Vec<f64>>,
    /// Mean of `test_scores`.
    pub mean_test_score: f64,
    /// Population standard deviation of `test_scores`.
    pub std_test_score: f64,
    /// Mean of `train_scores`, if collected.
    pub mean_train_score: Option<f64>,
    /// Population standard deviation of `train_scores`, if collected.
    pub std_train_score: Option<f64>,
    /// Per-fold model-fitting wall-clock times, in seconds.
    pub fit_times: Vec<f64>,
    /// Per-fold scoring wall-clock times, in seconds.
    pub score_times: Vec<f64>,
}
139
/// Produces `(train_indices, test_indices)` pairs for a dataset.
pub trait CVSplitter {
    /// Split `x` (and optionally targets `y`) into folds; in each pair the
    /// training indices come first and the test indices second.
    fn split(&self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)>;
}
145
/// K-fold splitter: contiguous folds of size `n_samples / n_folds`, with the
/// final fold absorbing any remainder.
pub struct KFoldSplitter {
    /// Number of folds.
    n_folds: usize,
    /// Shuffle sample order before folding.
    shuffle: bool,
    /// RNG seed; only consulted when `shuffle` is true.
    random_seed: Option<u64>,
}
152
153impl KFoldSplitter {
154 pub fn new(n_folds: usize, shuffle: bool, random_seed: Option<u64>) -> Self {
155 Self {
156 n_folds,
157 shuffle,
158 random_seed,
159 }
160 }
161}
162
163impl CVSplitter for KFoldSplitter {
164 fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
165 let n_samples = x.nrows();
166 let mut indices: Vec<usize> = (0..n_samples).collect();
167
168 if self.shuffle {
169 let mut rng = if let Some(seed) = self.random_seed {
170 StdRng::seed_from_u64(seed)
171 } else {
172 StdRng::from_seed(thread_rng().random())
173 };
174
175 indices.shuffle(&mut rng);
176 }
177
178 let fold_size = n_samples / self.n_folds;
179 let mut splits = Vec::new();
180
181 for fold in 0..self.n_folds {
182 let start = fold * fold_size;
183 let end = if fold == self.n_folds - 1 {
184 n_samples
185 } else {
186 (fold + 1) * fold_size
187 };
188
189 let test_indices = indices[start..end].to_vec();
190 let train_indices = indices[..start]
191 .iter()
192 .chain(indices[end..].iter())
193 .cloned()
194 .collect();
195
196 splits.push((train_indices, test_indices));
197 }
198
199 splits
200 }
201}
202
/// Forward-chaining splitter for ordered data: every training index precedes
/// every test index within a split.
pub struct TimeSeriesSplitter {
    /// Number of train/test splits to generate.
    n_splits: usize,
    /// Optional cap on the training window size.
    max_train_size: Option<usize>,
}
208
209impl TimeSeriesSplitter {
210 pub fn new(n_splits: usize, max_train_size: Option<usize>) -> Self {
211 Self {
212 n_splits,
213 max_train_size,
214 }
215 }
216}
217
218impl CVSplitter for TimeSeriesSplitter {
219 fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
220 let n_samples = x.nrows();
221 let test_size = n_samples / (self.n_splits + 1);
222 let mut splits = Vec::new();
223
224 for split in 0..self.n_splits {
225 let test_start = (split + 1) * test_size;
226 let test_end = if split == self.n_splits - 1 {
227 n_samples
228 } else {
229 (split + 2) * test_size
230 };
231
232 let train_end = test_start;
233 let train_start = if let Some(max_size) = self.max_train_size {
234 train_end.saturating_sub(max_size)
235 } else {
236 0
237 };
238
239 let train_indices = (train_start..train_end).collect();
240 let test_indices = (test_start..test_end).collect();
241
242 splits.push((train_indices, test_indices));
243 }
244
245 splits
246 }
247}
248
/// Repeated random (shuffle-split) splitter: each split is an independent
/// random partition into train and test sets.
pub struct MonteCarloCVSplitter {
    /// Number of random splits to draw.
    n_splits: usize,
    /// Fraction of samples assigned to the test set.
    test_size: f64,
    /// RNG seed for reproducible splits; `None` seeds from the thread RNG.
    random_seed: Option<u64>,
}
255
256impl MonteCarloCVSplitter {
257 pub fn new(n_splits: usize, test_size: f64, random_seed: Option<u64>) -> Self {
258 Self {
259 n_splits,
260 test_size,
261 random_seed,
262 }
263 }
264}
265
266impl CVSplitter for MonteCarloCVSplitter {
267 fn split(&self, x: &Array2<f64>, _y: Option<&Array1<f64>>) -> Vec<(Vec<usize>, Vec<usize>)> {
268 let n_samples = x.nrows();
269 let test_samples = (n_samples as f64 * self.test_size) as usize;
270 let mut rng = if let Some(seed) = self.random_seed {
271 StdRng::seed_from_u64(seed)
272 } else {
273 StdRng::from_seed(thread_rng().random())
274 };
275
276 let mut splits = Vec::new();
277
278 for _ in 0..self.n_splits {
279 let mut indices: Vec<usize> = (0..n_samples).collect();
280
281 indices.shuffle(&mut rng);
282
283 let test_indices = indices[..test_samples].to_vec();
284 let train_indices = indices[test_samples..].to_vec();
285
286 splits.push((train_indices, test_indices));
287 }
288
289 splits
290 }
291}
292
/// Runs cross-validation of kernel-approximation feature maps
/// (RBF sampler / Nystroem) according to a `CrossValidationConfig`.
pub struct CrossValidator {
    /// Immutable run configuration (strategy, metric, seed, verbosity, ...).
    config: CrossValidationConfig,
}
297
298impl CrossValidator {
299 pub fn new(config: CrossValidationConfig) -> Self {
301 Self { config }
302 }
303
304 pub fn cross_validate_rbf(
306 &self,
307 x: &Array2<f64>,
308 y: Option<&Array1<f64>>,
309 parameters: &ParameterSet,
310 ) -> Result<CrossValidationResult> {
311 let splitter = self.create_splitter()?;
312 let splits = splitter.split(x, y);
313
314 if self.config.verbose {
315 println!("Performing cross-validation with {} splits", splits.len());
316 }
317
318 let fold_results: Result<Vec<_>> = splits
320 .par_iter()
321 .enumerate()
322 .map(|(fold_idx, (train_indices, test_indices))| {
323 let start_time = std::time::Instant::now();
324
325 let x_train = self.extract_samples(x, train_indices);
327 let x_test = self.extract_samples(x, test_indices);
328 let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
329 let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));
330
331 let sampler = RBFSampler::new(parameters.n_components).gamma(parameters.gamma);
333 let fitted = sampler.fit(&x_train, &())?;
334 let fit_time = start_time.elapsed().as_secs_f64();
335
336 let x_train_transformed = fitted.transform(&x_train)?;
338 let x_test_transformed = fitted.transform(&x_test)?;
339
340 let score_start = std::time::Instant::now();
342 let test_score = self.compute_score(
343 &x_test,
344 &x_test_transformed,
345 y_test.as_ref(),
346 parameters.gamma,
347 )?;
348
349 let train_score = if self.config.return_train_score {
350 Some(self.compute_score(
351 &x_train,
352 &x_train_transformed,
353 y_train.as_ref(),
354 parameters.gamma,
355 )?)
356 } else {
357 None
358 };
359
360 let score_time = score_start.elapsed().as_secs_f64();
361
362 if self.config.verbose {
363 println!(
364 "Fold {}: test_score = {:.6}, fit_time = {:.3}s",
365 fold_idx, test_score, fit_time
366 );
367 }
368
369 Ok((test_score, train_score, fit_time, score_time))
370 })
371 .collect();
372
373 let fold_results = fold_results?;
374
375 self.aggregate_results(fold_results)
376 }
377
378 pub fn cross_validate_nystroem(
380 &self,
381 x: &Array2<f64>,
382 y: Option<&Array1<f64>>,
383 parameters: &ParameterSet,
384 ) -> Result<CrossValidationResult> {
385 use crate::nystroem::Kernel;
386
387 let splitter = self.create_splitter()?;
388 let splits = splitter.split(x, y);
389
390 let fold_results: Result<Vec<_>> = splits
391 .par_iter()
392 .enumerate()
393 .map(|(fold_idx, (train_indices, test_indices))| {
394 let start_time = std::time::Instant::now();
395
396 let x_train = self.extract_samples(x, train_indices);
398 let x_test = self.extract_samples(x, test_indices);
399 let y_train = y.map(|y_data| self.extract_targets(y_data, train_indices));
400 let y_test = y.map(|y_data| self.extract_targets(y_data, test_indices));
401
402 let kernel = Kernel::Rbf {
404 gamma: parameters.gamma,
405 };
406 let nystroem = Nystroem::new(kernel, parameters.n_components);
407 let fitted = nystroem.fit(&x_train, &())?;
408 let fit_time = start_time.elapsed().as_secs_f64();
409
410 let x_train_transformed = fitted.transform(&x_train)?;
412 let x_test_transformed = fitted.transform(&x_test)?;
413
414 let score_start = std::time::Instant::now();
416 let test_score = self.compute_score(
417 &x_test,
418 &x_test_transformed,
419 y_test.as_ref(),
420 parameters.gamma,
421 )?;
422
423 let train_score = if self.config.return_train_score {
424 Some(self.compute_score(
425 &x_train,
426 &x_train_transformed,
427 y_train.as_ref(),
428 parameters.gamma,
429 )?)
430 } else {
431 None
432 };
433
434 let score_time = score_start.elapsed().as_secs_f64();
435
436 if self.config.verbose {
437 println!("Fold {}: test_score = {:.6}", fold_idx, test_score);
438 }
439
440 Ok((test_score, train_score, fit_time, score_time))
441 })
442 .collect();
443
444 let fold_results = fold_results?;
445 self.aggregate_results(fold_results)
446 }
447
448 pub fn cross_validate_with_search(
450 &self,
451 x: &Array2<f64>,
452 y: Option<&Array1<f64>>,
453 parameter_learner: &ParameterLearner,
454 ) -> Result<(ParameterSet, CrossValidationResult)> {
455 let optimization_result = parameter_learner.optimize_rbf_parameters(x, y)?;
457 let best_params = optimization_result.best_parameters;
458
459 if self.config.verbose {
460 println!(
461 "Best parameters found: gamma={:.6}, n_components={}",
462 best_params.gamma, best_params.n_components
463 );
464 }
465
466 let cv_result = self.cross_validate_rbf(x, y, &best_params)?;
468
469 Ok((best_params, cv_result))
470 }
471
472 pub fn grid_search_cv(
474 &self,
475 x: &Array2<f64>,
476 y: Option<&Array1<f64>>,
477 param_grid: &HashMap<String, Vec<f64>>,
478 ) -> Result<(
479 ParameterSet,
480 f64,
481 HashMap<ParameterSet, CrossValidationResult>,
482 )> {
483 let gamma_values = param_grid.get("gamma").ok_or_else(|| {
484 SklearsError::InvalidInput("gamma parameter missing from grid".to_string())
485 })?;
486
487 let n_components_values = param_grid
488 .get("n_components")
489 .ok_or_else(|| {
490 SklearsError::InvalidInput("n_components parameter missing from grid".to_string())
491 })?
492 .iter()
493 .map(|&x| x as usize)
494 .collect::<Vec<_>>();
495
496 let mut best_score = f64::NEG_INFINITY;
497 let mut best_params = ParameterSet {
498 gamma: gamma_values[0],
499 n_components: n_components_values[0],
500 degree: None,
501 coef0: None,
502 };
503 let mut all_results = HashMap::new();
504
505 if self.config.verbose {
506 println!(
507 "Grid search over {} parameter combinations",
508 gamma_values.len() * n_components_values.len()
509 );
510 }
511
512 for &gamma in gamma_values {
513 for &n_components in &n_components_values {
514 let params = ParameterSet {
515 gamma,
516 n_components,
517 degree: None,
518 coef0: None,
519 };
520
521 let cv_result = self.cross_validate_rbf(x, y, ¶ms)?;
522 let mean_score = cv_result.mean_test_score;
523
524 all_results.insert(params.clone(), cv_result);
525
526 if mean_score > best_score {
527 best_score = mean_score;
528 best_params = params;
529 }
530
531 if self.config.verbose {
532 println!(
533 "gamma={:.6}, n_components={}: score={:.6} ± {:.6}",
534 gamma,
535 n_components,
536 mean_score,
537 all_results
538 .get(&ParameterSet {
539 gamma,
540 n_components,
541 degree: None,
542 coef0: None
543 })
544 .expect("operation should succeed")
545 .std_test_score
546 );
547 }
548 }
549 }
550
551 Ok((best_params, best_score, all_results))
552 }
553
554 fn create_splitter(&self) -> Result<Box<dyn CVSplitter + Send + Sync>> {
555 match &self.config.cv_strategy {
556 CVStrategy::KFold { n_folds, shuffle } => Ok(Box::new(KFoldSplitter::new(
557 *n_folds,
558 *shuffle,
559 self.config.random_seed,
560 ))),
561 CVStrategy::TimeSeriesSplit {
562 n_splits,
563 max_train_size,
564 } => Ok(Box::new(TimeSeriesSplitter::new(
565 *n_splits,
566 *max_train_size,
567 ))),
568 CVStrategy::MonteCarlo {
569 n_splits,
570 test_size,
571 } => Ok(Box::new(MonteCarloCVSplitter::new(
572 *n_splits,
573 *test_size,
574 self.config.random_seed,
575 ))),
576 _ => {
577 Ok(Box::new(KFoldSplitter::new(
579 5,
580 true,
581 self.config.random_seed,
582 )))
583 }
584 }
585 }
586
587 fn extract_samples(&self, x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
588 let n_features = x.ncols();
589 let mut result = Array2::zeros((indices.len(), n_features));
590
591 for (i, &idx) in indices.iter().enumerate() {
592 result.row_mut(i).assign(&x.row(idx));
593 }
594
595 result
596 }
597
598 fn extract_targets(&self, y: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
599 let mut result = Array1::zeros(indices.len());
600
601 for (i, &idx) in indices.iter().enumerate() {
602 result[i] = y[idx];
603 }
604
605 result
606 }
607
608 fn compute_score(
609 &self,
610 x: &Array2<f64>,
611 x_transformed: &Array2<f64>,
612 y: Option<&Array1<f64>>,
613 gamma: f64,
614 ) -> Result<f64> {
615 match &self.config.scoring_metric {
616 ScoringMetric::KernelAlignment => {
617 self.compute_kernel_alignment(x, x_transformed, gamma)
618 }
619 ScoringMetric::MeanSquaredError => {
620 if let Some(y_data) = y {
621 self.compute_mse(x_transformed, y_data)
622 } else {
623 Err(SklearsError::InvalidInput(
624 "Target values required for MSE".to_string(),
625 ))
626 }
627 }
628 ScoringMetric::MeanAbsoluteError => {
629 if let Some(y_data) = y {
630 self.compute_mae(x_transformed, y_data)
631 } else {
632 Err(SklearsError::InvalidInput(
633 "Target values required for MAE".to_string(),
634 ))
635 }
636 }
637 ScoringMetric::R2Score => {
638 if let Some(y_data) = y {
639 self.compute_r2_score(x_transformed, y_data)
640 } else {
641 Err(SklearsError::InvalidInput(
642 "Target values required for R²".to_string(),
643 ))
644 }
645 }
646 _ => {
647 self.compute_kernel_alignment(x, x_transformed, gamma)
649 }
650 }
651 }
652
653 fn compute_kernel_alignment(
654 &self,
655 x: &Array2<f64>,
656 x_transformed: &Array2<f64>,
657 gamma: f64,
658 ) -> Result<f64> {
659 let n_samples = x.nrows().min(50); let x_subset = x.slice(s![..n_samples, ..]);
661
662 let mut k_exact = Array2::zeros((n_samples, n_samples));
664 for i in 0..n_samples {
665 for j in 0..n_samples {
666 let diff = &x_subset.row(i) - &x_subset.row(j);
667 let squared_norm = diff.dot(&diff);
668 k_exact[[i, j]] = (-gamma * squared_norm).exp();
669 }
670 }
671
672 let x_transformed_subset = x_transformed.slice(s![..n_samples, ..]);
674 let k_approx = x_transformed_subset.dot(&x_transformed_subset.t());
675
676 let k_exact_frobenius = k_exact.iter().map(|&x| x * x).sum::<f64>().sqrt();
678 let k_approx_frobenius = k_approx.iter().map(|&x| x * x).sum::<f64>().sqrt();
679 let k_product = (&k_exact * &k_approx).sum();
680
681 let alignment = k_product / (k_exact_frobenius * k_approx_frobenius);
682 Ok(alignment)
683 }
684
685 fn compute_mse(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
686 let y_mean = y.mean().unwrap_or(0.0);
689 let mse = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / y.len() as f64;
690 Ok(-mse) }
692
693 fn compute_mae(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
694 let y_mean = y.mean().unwrap_or(0.0);
696 let mae = y.iter().map(|&yi| (yi - y_mean).abs()).sum::<f64>() / y.len() as f64;
697 Ok(-mae) }
699
700 fn compute_r2_score(&self, _x_transformed: &Array2<f64>, y: &Array1<f64>) -> Result<f64> {
701 let y_mean = y.mean().unwrap_or(0.0);
703 let ss_tot = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>();
704 let ss_res = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>(); let r2 = 1.0 - (ss_res / ss_tot);
707 Ok(r2)
708 }
709
710 fn aggregate_results(
711 &self,
712 fold_results: Vec<(f64, Option<f64>, f64, f64)>,
713 ) -> Result<CrossValidationResult> {
714 let test_scores: Vec<f64> = fold_results.iter().map(|(score, _, _, _)| *score).collect();
715 let train_scores: Option<Vec<f64>> = if self.config.return_train_score {
716 Some(
717 fold_results
718 .iter()
719 .filter_map(|(_, train_score, _, _)| *train_score)
720 .collect(),
721 )
722 } else {
723 None
724 };
725 let fit_times: Vec<f64> = fold_results
726 .iter()
727 .map(|(_, _, fit_time, _)| *fit_time)
728 .collect();
729 let score_times: Vec<f64> = fold_results
730 .iter()
731 .map(|(_, _, _, score_time)| *score_time)
732 .collect();
733
734 let mean_test_score = test_scores.iter().sum::<f64>() / test_scores.len() as f64;
735 let variance_test = test_scores
736 .iter()
737 .map(|&score| (score - mean_test_score).powi(2))
738 .sum::<f64>()
739 / test_scores.len() as f64;
740 let std_test_score = variance_test.sqrt();
741
742 let (mean_train_score, std_train_score) = if let Some(ref train_scores) = train_scores {
743 let mean = train_scores.iter().sum::<f64>() / train_scores.len() as f64;
744 let variance = train_scores
745 .iter()
746 .map(|&score| (score - mean).powi(2))
747 .sum::<f64>()
748 / train_scores.len() as f64;
749 (Some(mean), Some(variance.sqrt()))
750 } else {
751 (None, None)
752 };
753
754 Ok(CrossValidationResult {
755 test_scores,
756 train_scores,
757 mean_test_score,
758 std_test_score,
759 mean_train_score,
760 std_train_score,
761 fit_times,
762 score_times,
763 })
764 }
765}
766
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Unshuffled 4-fold split of 20 samples: folds cover every index once,
    /// each fold holding 5 samples (bounds below are deliberately loose).
    #[test]
    fn test_kfold_splitter() {
        let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as f64).collect())
            .expect("operation should succeed");

        let splitter = KFoldSplitter::new(4, false, Some(42));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 4);

        // Every index must appear in exactly one test fold.
        let mut all_test_indices: Vec<usize> = Vec::new();
        for (_, test_indices) in &splits {
            all_test_indices.extend(test_indices);
        }
        all_test_indices.sort();

        let expected_indices: Vec<usize> = (0..20).collect();
        assert_eq!(all_test_indices, expected_indices);

        for (_, test_indices) in &splits {
            assert!(test_indices.len() >= 4);
            assert!(test_indices.len() <= 6);
        }
    }

    /// Time-series splits must keep train strictly before test.
    #[test]
    fn test_time_series_splitter() {
        let x = Array2::from_shape_vec((30, 2), (0..60).map(|i| i as f64).collect())
            .expect("operation should succeed");

        let splitter = TimeSeriesSplitter::new(3, Some(15));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 3);

        for (train_indices, test_indices) in &splits {
            if !train_indices.is_empty() && !test_indices.is_empty() {
                let max_train = train_indices
                    .iter()
                    .max()
                    .expect("operation should succeed");
                let min_test = test_indices.iter().min().expect("operation should succeed");
                assert!(max_train < min_test);
            }
        }
    }

    /// Monte Carlo splits must partition all samples, with the test set
    /// close to the requested 30% fraction (floor of 50 * 0.3 = 15).
    #[test]
    fn test_monte_carlo_splitter() {
        let x = Array2::from_shape_vec((50, 4), (0..200).map(|i| i as f64).collect())
            .expect("operation should succeed");

        let splitter = MonteCarloCVSplitter::new(5, 0.3, Some(123));
        let splits = splitter.split(&x, None);

        assert_eq!(splits.len(), 5);

        for (train_indices, test_indices) in &splits {
            let total_size = train_indices.len() + test_indices.len();
            assert_eq!(total_size, 50);
            assert!(test_indices.len() >= 14);
            assert!(test_indices.len() <= 16);
        }
    }

    /// End-to-end RBF cross-validation with train scores enabled.
    #[test]
    fn test_cross_validator_rbf() {
        let x = Array2::from_shape_vec((40, 5), (0..200).map(|i| i as f64 * 0.01).collect())
            .expect("operation should succeed");

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            return_train_score: true,
            random_seed: Some(42),
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 0.5,
            n_components: 20,
            degree: None,
            coef0: None,
        };

        let result = cv
            .cross_validate_rbf(&x, None, &params)
            .expect("operation should succeed");

        assert_eq!(result.test_scores.len(), 3);
        assert!(result.train_scores.is_some());
        assert_eq!(
            result
                .train_scores
                .as_ref()
                .expect("operation should succeed")
                .len(),
            3
        );
        // Kernel alignment is positive for a reasonable approximation.
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
        assert!(result.mean_train_score.is_some());
        assert!(result.std_train_score.is_some());
        assert_eq!(result.fit_times.len(), 3);
        assert_eq!(result.score_times.len(), 3);
    }

    /// End-to-end Nystroem cross-validation.
    #[test]
    fn test_cross_validator_nystroem() {
        let x = Array2::from_shape_vec((30, 4), (0..120).map(|i| i as f64 * 0.02).collect())
            .expect("operation should succeed");

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: false,
            },
            scoring_metric: ScoringMetric::KernelAlignment,
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 1.0,
            n_components: 15,
            degree: None,
            coef0: None,
        };

        let result = cv
            .cross_validate_nystroem(&x, None, &params)
            .expect("operation should succeed");

        assert_eq!(result.test_scores.len(), 4);
        assert!(result.mean_test_score > 0.0);
        assert!(result.std_test_score >= 0.0);
    }

    /// Grid search must evaluate every combination and return the maximum
    /// mean score as the best score.
    #[test]
    fn test_grid_search_cv() {
        let x = Array2::from_shape_vec((25, 3), (0..75).map(|i| i as f64 * 0.05).collect())
            .expect("operation should succeed");

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 3,
                shuffle: true,
            },
            random_seed: Some(789),
            verbose: false,
            ..Default::default()
        };

        let cv = CrossValidator::new(config);

        let mut param_grid = HashMap::new();
        param_grid.insert("gamma".to_string(), vec![0.1, 1.0]);
        param_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let (best_params, best_score, all_results) = cv
            .grid_search_cv(&x, None, &param_grid)
            .expect("operation should succeed");

        assert!(best_score > 0.0);
        assert!(best_params.gamma == 0.1 || best_params.gamma == 1.0);
        assert!(best_params.n_components == 10 || best_params.n_components == 20);
        // 2 gammas x 2 component counts = 4 combinations.
        assert_eq!(all_results.len(), 4);

        let max_score = all_results
            .values()
            .map(|result| result.mean_test_score)
            .fold(f64::NEG_INFINITY, f64::max);
        assert_abs_diff_eq!(best_score, max_score, epsilon = 1e-10);
    }

    /// MSE scoring is negated, so mean test score must be non-positive.
    #[test]
    fn test_cross_validation_with_targets() {
        let x = Array2::from_shape_vec((20, 3), (0..60).map(|i| i as f64 * 0.1).collect())
            .expect("operation should succeed");
        let y = Array1::from_shape_fn(20, |i| (i as f64 * 0.1).sin());

        let config = CrossValidationConfig {
            cv_strategy: CVStrategy::KFold {
                n_folds: 4,
                shuffle: true,
            },
            scoring_metric: ScoringMetric::MeanSquaredError,
            random_seed: Some(456),
            ..Default::default()
        };

        let cv = CrossValidator::new(config);
        let params = ParameterSet {
            gamma: 0.8,
            n_components: 15,
            degree: None,
            coef0: None,
        };

        let result = cv
            .cross_validate_rbf(&x, Some(&y), &params)
            .expect("operation should succeed");

        assert_eq!(result.test_scores.len(), 4);
        assert!(result.mean_test_score <= 0.0);
    }

    /// The same seed must yield identical splits.
    #[test]
    fn test_cv_splitter_consistency() {
        let x = Array2::from_shape_vec((15, 2), (0..30).map(|i| i as f64).collect())
            .expect("operation should succeed");

        let splitter1 = KFoldSplitter::new(3, true, Some(42));
        let splitter2 = KFoldSplitter::new(3, true, Some(42));

        let splits1 = splitter1.split(&x, None);
        let splits2 = splitter2.split(&x, None);

        assert_eq!(splits1.len(), splits2.len());
        for (split1, split2) in splits1.iter().zip(splits2.iter()) {
            assert_eq!(split1.0, split2.0);
            assert_eq!(split1.1, split2.1);
        }
    }

    /// Aggregation computes the correct means over hand-built fold tuples.
    #[test]
    fn test_cross_validation_result_aggregation() {
        let config = CrossValidationConfig {
            return_train_score: true,
            ..Default::default()
        };
        let cv = CrossValidator::new(config);

        let fold_results = vec![
            (0.8, Some(0.85), 0.1, 0.05),
            (0.75, Some(0.8), 0.12, 0.04),
            (0.82, Some(0.88), 0.11, 0.06),
        ];

        let result = cv
            .aggregate_results(fold_results)
            .expect("operation should succeed");

        assert_abs_diff_eq!(result.mean_test_score, 0.79, epsilon = 1e-10);
        assert!(result.std_test_score > 0.0);
        assert!(result.mean_train_score.is_some());
        assert_abs_diff_eq!(
            result.mean_train_score.expect("operation should succeed"),
            0.8433333333333334,
            epsilon = 1e-10
        );
    }
}