1use crate::advanced::rbf::{RBFInterpolator, RBFKernel};
45use crate::bspline::BSpline;
46use crate::error::{InterpolateError, InterpolateResult};
47use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ScalarOperand};
48use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive};
49use std::collections::HashMap;
50use std::fmt::{Debug, Display, LowerExp};
51use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
52
/// Metric used to score predictions against held-out target values.
///
/// The type is a plain field-less enum, so it also derives `Eq` and `Hash`
/// and can be used as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationMetric {
    /// Mean of the squared residuals.
    MeanSquaredError,
    /// Mean of the absolute residuals.
    MeanAbsoluteError,
    /// Square root of the mean squared residual.
    RootMeanSquaredError,
    /// Coefficient of determination, `1 - SS_res / SS_tot`.
    RSquared,
    /// Mean of `|residual / target|`, expressed as a percentage.
    MeanAbsolutePercentageError,
    /// Largest absolute residual.
    MaxAbsoluteError,
}
69
/// Strategy used to split the samples into training and test subsets.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CrossValidationStrategy {
    /// Partition the samples into the given number of folds; each fold is
    /// used once as the test set.
    KFold(usize),
    /// Hold out exactly one sample per fold.
    LeaveOneOut,
    /// `n_splits` repeated train/test splits, each reserving roughly
    /// `test_fraction` of the samples (at least one) for testing.
    MonteCarlo { n_splits: usize, test_fraction: f64 },
    /// Forward-chaining splits for ordered data: each fold trains on a
    /// leading prefix of the samples and tests on the window that follows.
    /// NOTE(review): `gap` is presumably the number of samples to leave out
    /// between the training prefix and the test window; confirm against the
    /// fold generator.
    TimeSeries { n_splits: usize, gap: usize },
}
82
/// Tuning knobs shared by the parameter-search routines.
#[derive(Debug, Clone)]
pub struct OptimizationConfig<T> {
    /// Iteration budget.
    /// NOTE(review): not read by the code visible in this file.
    pub max_iterations: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by the code visible in this file.
    pub tolerance: T,
    /// Seed for randomised procedures.
    /// NOTE(review): the fold shuffling seen here uses the validator's own
    /// seed rather than this field; confirm intended use.
    pub random_seed: u64,
    /// Whether parallel execution is requested.
    /// NOTE(review): not read by the code visible in this file.
    pub parallel: bool,
    /// Verbosity level; values above zero make the parameter searches print
    /// one line per evaluated candidate.
    pub verbosity: usize,
}
97
98impl<T: Float + FromPrimitive> Default for OptimizationConfig<T> {
99 fn default() -> Self {
100 Self {
101 max_iterations: 100,
102 tolerance: T::from(1e-6).expect("Operation failed"),
103 random_seed: 42,
104 parallel: true,
105 verbosity: 1,
106 }
107 }
108}
109
/// Outcome of a parameter search.
#[derive(Debug, Clone)]
pub struct OptimizationResult<T> {
    /// Parameter name to value for the best-scoring candidate.
    pub best_parameters: HashMap<String, T>,
    /// Cross-validation score achieved by `best_parameters`.
    pub best_score: T,
    /// Every evaluated candidate together with its score, in evaluation order.
    pub parameter_scores: Vec<(HashMap<String, T>, T)>,
    /// Number of candidates that were evaluated.
    pub iterations: usize,
    /// Convergence flag; the searches in this file always set it to `true`.
    pub converged: bool,
    /// Wall-clock duration of the search in milliseconds.
    pub optimization_time_ms: u64,
}
126
/// Aggregated scores of one cross-validation run.
#[derive(Debug, Clone)]
pub struct CrossValidationResult<T> {
    /// Arithmetic mean of the per-fold scores.
    pub mean_score: T,
    /// Standard deviation of the per-fold scores (population form: the
    /// variance is divided by the number of folds).
    pub std_score: T,
    /// Score of each individual fold, in fold order.
    pub fold_scores: Vec<T>,
    /// Number of folds that were evaluated.
    pub n_folds: usize,
    /// Metric the scores were computed with.
    pub metric: ValidationMetric,
}
141
/// Cross-validation driver for one-dimensional interpolators.
///
/// Holds the splitting strategy and the scoring metric, and is configured
/// through the `with_*` builder methods.
#[derive(Debug)]
pub struct CrossValidator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// How the samples are split into training and test sets.
    strategy: CrossValidationStrategy,
    /// Metric used to score each fold.
    metric: ValidationMetric,
    /// Whether k-fold splitting permutes the sample order first.
    shuffle: bool,
    /// Seed of the deterministic k-fold permutation.
    random_seed: u64,
    /// Search configuration; its `verbosity` controls progress printing.
    config: OptimizationConfig<T>,
}
174
impl<T> Default for CrossValidator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// Equivalent to [`CrossValidator::new`].
    fn default() -> Self {
        Self::new()
    }
}
198
199impl<T> CrossValidator<T>
200where
201 T: Float
202 + FromPrimitive
203 + ToPrimitive
204 + Debug
205 + Display
206 + LowerExp
207 + ScalarOperand
208 + AddAssign
209 + SubAssign
210 + MulAssign
211 + DivAssign
212 + RemAssign
213 + Copy
214 + Send
215 + Sync
216 + 'static,
217{
218 pub fn new() -> Self {
220 Self {
221 strategy: CrossValidationStrategy::KFold(5),
222 metric: ValidationMetric::MeanSquaredError,
223 shuffle: true,
224 random_seed: 42,
225 config: OptimizationConfig::default(),
226 }
227 }
228
229 pub fn with_strategy(mut self, strategy: CrossValidationStrategy) -> Self {
231 self.strategy = strategy;
232 self
233 }
234
235 pub fn with_k_folds(mut self, k: usize) -> Self {
237 self.strategy = CrossValidationStrategy::KFold(k);
238 self
239 }
240
241 pub fn with_metric(mut self, metric: ValidationMetric) -> Self {
243 self.metric = metric;
244 self
245 }
246
247 pub fn with_shuffle(mut self, shuffle: bool) -> Self {
249 self.shuffle = shuffle;
250 self
251 }
252
253 pub fn with_random_seed(mut self, seed: u64) -> Self {
255 self.random_seed = seed;
256 self
257 }
258
259 pub fn with_config(mut self, config: OptimizationConfig<T>) -> Self {
261 self.config = config;
262 self
263 }
264
265 pub fn cross_validate<F>(
277 &self,
278 x: &ArrayView1<T>,
279 y: &ArrayView1<T>,
280 interpolator_fn: F,
281 ) -> InterpolateResult<CrossValidationResult<T>>
282 where
283 F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>,
284 {
285 let n = x.len();
286 if n != y.len() {
287 return Err(InterpolateError::DimensionMismatch(
288 "x and y must have the same length".to_string(),
289 ));
290 }
291
292 let folds = self.generate_folds(n)?;
293 let mut fold_scores = Vec::new();
294
295 for (train_indices, test_indices) in folds {
296 let x_train = self.extract_indices(x, &train_indices);
298 let y_train = self.extract_indices(y, &train_indices);
299 let x_test = self.extract_indices(x, &test_indices);
300 let y_test = self.extract_indices(y, &test_indices);
301
302 let mut training_pairs: Vec<_> = x_train
304 .iter()
305 .zip(y_train.iter())
306 .map(|(x, y)| (*x, *y))
307 .collect();
308 training_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
309
310 let x_train_sorted: Array1<T> = training_pairs.iter().map(|(x, _)| *x).collect();
311 let y_train_sorted: Array1<T> = training_pairs.iter().map(|(_, y)| *y).collect();
312
313 let interpolator = interpolator_fn(&x_train_sorted.view(), &y_train_sorted.view())?;
315
316 let y_pred = interpolator.evaluate(&x_test.view())?;
318
319 let score = self.compute_metric(&y_test.view(), &y_pred.view())?;
321 fold_scores.push(score);
322 }
323
324 let n_folds = fold_scores.len();
325 let mean_score = fold_scores.iter().fold(T::zero(), |acc, &x| acc + x)
326 / T::from(fold_scores.len()).expect("Operation failed");
327 let variance = fold_scores
328 .iter()
329 .map(|&score| (score - mean_score) * (score - mean_score))
330 .fold(T::zero(), |acc, x| acc + x)
331 / T::from(fold_scores.len()).expect("Operation failed");
332 let std_score = variance.sqrt();
333
334 Ok(CrossValidationResult {
335 mean_score,
336 std_score,
337 fold_scores,
338 n_folds,
339 metric: self.metric,
340 })
341 }
342
343 pub fn optimize_rbf_parameters(
355 &mut self,
356 x: &ArrayView1<T>,
357 y: &ArrayView1<T>,
358 kernel_widths: &[T],
359 ) -> InterpolateResult<OptimizationResult<T>> {
360 let start_time = std::time::Instant::now();
361 let mut parameter_scores = Vec::new();
362 let mut best_score = T::infinity();
363 let mut best_params = HashMap::new();
364
365 for &width in kernel_widths {
366 let interpolator_fn = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
367 let points_2d = Array2::from_shape_vec((x_train.len(), 1), x_train.to_vec())
369 .map_err(|e| {
370 InterpolateError::ComputationError(format!("Failed to reshape: {}", e))
371 })?;
372
373 let rbf =
374 RBFInterpolator::new(&points_2d.view(), y_train, RBFKernel::Gaussian, width)?;
375
376 Ok(Box::new(RBFWrapper::new(rbf)) as Box<dyn InterpolatorTrait<T>>)
377 };
378
379 let cv_result = self.cross_validate(x, y, interpolator_fn)?;
380 let score = cv_result.mean_score;
381
382 let mut params = HashMap::new();
383 params.insert("kernel_width".to_string(), width);
384 parameter_scores.push((params.clone(), score));
385
386 if score < best_score {
387 best_score = score;
388 best_params = params;
389 }
390
391 if self.config.verbosity > 0 {
392 println!(
393 "Width: {:.3}, CV Score: {:.6}",
394 width.to_f64().unwrap_or(0.0),
395 score.to_f64().unwrap_or(0.0)
396 );
397 }
398 }
399
400 let optimization_time_ms = start_time.elapsed().as_millis() as u64;
401
402 Ok(OptimizationResult {
403 best_parameters: best_params,
404 best_score,
405 parameter_scores,
406 iterations: kernel_widths.len(),
407 converged: true,
408 optimization_time_ms,
409 })
410 }
411
412 pub fn optimize_bspline_parameters(
424 &mut self,
425 x: &ArrayView1<T>,
426 y: &ArrayView1<T>,
427 degrees: &[usize],
428 ) -> InterpolateResult<OptimizationResult<T>> {
429 let start_time = std::time::Instant::now();
430 let mut parameter_scores = Vec::new();
431 let mut best_score = T::infinity();
432 let mut best_params = HashMap::new();
433
434 for °ree in degrees {
435 let interpolator_fn = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
436 let bspline = crate::bspline::make_interp_bspline(
437 x_train,
438 y_train,
439 degree,
440 crate::bspline::ExtrapolateMode::Extrapolate,
441 )?;
442
443 Ok(Box::new(BSplineWrapper::new(bspline)) as Box<dyn InterpolatorTrait<T>>)
444 };
445
446 let cv_result = self.cross_validate(x, y, interpolator_fn)?;
447 let score = cv_result.mean_score;
448
449 let mut params = HashMap::new();
450 params.insert(
451 "degree".to_string(),
452 T::from(degree).expect("Operation failed"),
453 );
454 parameter_scores.push((params.clone(), score));
455
456 if score < best_score {
457 best_score = score;
458 best_params = params;
459 }
460
461 if self.config.verbosity > 0 {
462 println!(
463 "Degree: {}, CV Score: {:.6}",
464 degree,
465 score.to_f64().unwrap_or(0.0)
466 );
467 }
468 }
469
470 let optimization_time_ms = start_time.elapsed().as_millis() as u64;
471
472 Ok(OptimizationResult {
473 best_parameters: best_params,
474 best_score,
475 parameter_scores,
476 iterations: degrees.len(),
477 converged: true,
478 optimization_time_ms,
479 })
480 }
481
482 fn generate_folds(&self, n: usize) -> InterpolateResult<Vec<(Vec<usize>, Vec<usize>)>> {
484 match self.strategy {
485 CrossValidationStrategy::KFold(k) => {
486 if k > n {
487 return Err(InterpolateError::InvalidValue(
488 "Number of folds cannot exceed number of samples".to_string(),
489 ));
490 }
491
492 let mut indices: Vec<usize> = (0..n).collect();
493
494 if self.shuffle {
496 for i in 0..n {
497 let j = (self.random_seed as usize + i * 1103515245 + 12345) % n;
498 indices.swap(i, j);
499 }
500 }
501
502 let fold_size = n / k;
503 let mut folds = Vec::new();
504
505 for fold_idx in 0..k {
506 let start = fold_idx * fold_size;
507 let end = if fold_idx == k - 1 {
508 n
509 } else {
510 (fold_idx + 1) * fold_size
511 };
512
513 let test_indices = indices[start..end].to_vec();
514 let train_indices: Vec<usize> = indices
515 .iter()
516 .enumerate()
517 .filter(|(i_, _)| *i_ < start || *i_ >= end)
518 .map(|(_, &idx)| idx)
519 .collect();
520
521 folds.push((train_indices, test_indices));
522 }
523
524 Ok(folds)
525 }
526 CrossValidationStrategy::LeaveOneOut => {
527 let mut folds = Vec::new();
528 for i in 0..n {
529 let test_indices = vec![i];
530 let train_indices: Vec<usize> = (0..n).filter(|&idx| idx != i).collect();
531 folds.push((train_indices, test_indices));
532 }
533 Ok(folds)
534 }
535 CrossValidationStrategy::MonteCarlo {
536 n_splits,
537 test_fraction,
538 } => {
539 let mut folds = Vec::new();
540 let test_size = (n as f64 * test_fraction).max(1.0) as usize;
541
542 for split in 0..n_splits {
545 let mut indices: Vec<usize> = (0..n).collect();
546
547 for i in 0..n {
549 let j = (i + split * 17) % n; indices.swap(i, j);
551 }
552
553 let test_indices = indices[0..test_size].to_vec();
554 let train_indices = indices[test_size..].to_vec();
555 folds.push((train_indices, test_indices));
556 }
557 Ok(folds)
558 }
559 CrossValidationStrategy::TimeSeries { n_splits, gap: _ } => {
560 let mut folds = Vec::new();
562 let min_train_size = n / (n_splits + 1);
563 let test_size = n / (n_splits + 1);
564
565 for i in 0..n_splits {
566 let train_end = min_train_size + i * test_size;
567 let test_start = train_end;
568 let test_end = (test_start + test_size).min(n);
569
570 if test_end <= test_start {
571 break;
572 }
573
574 let train_indices: Vec<usize> = (0..train_end).collect();
575 let test_indices: Vec<usize> = (test_start..test_end).collect();
576
577 folds.push((train_indices, test_indices));
578 }
579 Ok(folds)
580 }
581 }
582 }
583
584 fn extract_indices(&self, arr: &ArrayView1<T>, indices: &[usize]) -> Array1<T> {
586 let mut result = Array1::zeros(indices.len());
587 for (i, &idx) in indices.iter().enumerate() {
588 result[i] = arr[idx];
589 }
590 result
591 }
592
593 fn compute_metric(
595 &self,
596 y_true: &ArrayView1<T>,
597 y_pred: &ArrayView1<T>,
598 ) -> InterpolateResult<T> {
599 if y_true.len() != y_pred.len() {
600 return Err(InterpolateError::DimensionMismatch(
601 "y_true and y_pred must have the same length".to_string(),
602 ));
603 }
604
605 let n = T::from(y_true.len()).expect("Operation failed");
606
607 match self.metric {
608 ValidationMetric::MeanSquaredError => {
609 let mse = y_true
610 .iter()
611 .zip(y_pred.iter())
612 .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
613 .fold(T::zero(), |acc, x| acc + x)
614 / n;
615 Ok(mse)
616 }
617 ValidationMetric::MeanAbsoluteError => {
618 let mae = y_true
619 .iter()
620 .zip(y_pred.iter())
621 .map(|(&yt, &yp)| (yt - yp).abs())
622 .fold(T::zero(), |acc, x| acc + x)
623 / n;
624 Ok(mae)
625 }
626 ValidationMetric::RootMeanSquaredError => {
627 let mse = y_true
628 .iter()
629 .zip(y_pred.iter())
630 .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
631 .fold(T::zero(), |acc, x| acc + x)
632 / n;
633 Ok(mse.sqrt())
634 }
635 ValidationMetric::RSquared => {
636 let y_mean = y_true.sum() / n;
637 let ss_tot = y_true
638 .iter()
639 .map(|&yt| (yt - y_mean) * (yt - y_mean))
640 .fold(T::zero(), |acc, x| acc + x);
641 let ss_res = y_true
642 .iter()
643 .zip(y_pred.iter())
644 .map(|(&yt, &yp)| (yt - yp) * (yt - yp))
645 .fold(T::zero(), |acc, x| acc + x);
646
647 if ss_tot == T::zero() {
648 Ok(T::one()) } else {
650 Ok(T::one() - ss_res / ss_tot)
651 }
652 }
653 ValidationMetric::MaxAbsoluteError => {
654 let max_error = y_true
655 .iter()
656 .zip(y_pred.iter())
657 .map(|(&yt, &yp)| (yt - yp).abs())
658 .fold(T::zero(), |acc, x| acc.max(x));
659 Ok(max_error)
660 }
661 ValidationMetric::MeanAbsolutePercentageError => {
662 let mut mape = T::zero();
663 let mut count = 0;
664 for (&yt, &yp) in y_true.iter().zip(y_pred.iter()) {
665 if yt != T::zero() {
666 mape += ((yt - yp) / yt).abs();
667 count += 1;
668 }
669 }
670 if count > 0 {
671 Ok(mape / T::from(count).expect("Operation failed")
672 * T::from(100.0).expect("Operation failed"))
673 } else {
674 Ok(T::zero())
675 }
676 }
677 }
678 }
679}
680
/// Minimal interface a fitted interpolator must expose so that it can be
/// cross-validated.
pub trait InterpolatorTrait<T>: Debug + Send + Sync
where
    T: Float + Debug + Copy,
{
    /// Evaluates the interpolator at the abscissae `x`, returning one
    /// predicted value per input or an error if evaluation fails.
    fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>>;
}
688
/// Adapter exposing an [`RBFInterpolator`] through [`InterpolatorTrait`].
#[derive(Debug)]
struct RBFWrapper<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// The wrapped, already fitted RBF interpolator.
    interpolator: RBFInterpolator<T>,
}
712
impl<T> RBFWrapper<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// Wraps an already constructed RBF interpolator.
    fn new(interpolator: RBFInterpolator<T>) -> Self {
        Self { interpolator }
    }
}
736
737impl<T> InterpolatorTrait<T> for RBFWrapper<T>
738where
739 T: Float
740 + FromPrimitive
741 + ToPrimitive
742 + Debug
743 + Display
744 + LowerExp
745 + ScalarOperand
746 + AddAssign
747 + SubAssign
748 + MulAssign
749 + DivAssign
750 + RemAssign
751 + Copy
752 + Send
753 + Sync
754 + 'static,
755{
756 fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
757 let points_2d = Array2::from_shape_vec((x.len(), 1), x.to_vec())
759 .map_err(|e| InterpolateError::ComputationError(format!("Failed to reshape: {}", e)))?;
760
761 self.interpolator.interpolate(&points_2d.view())
762 }
763}
764
/// Adapter exposing a [`BSpline`] through [`InterpolatorTrait`].
#[derive(Debug)]
struct BSplineWrapper<T>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Copy
        + Send
        + Sync
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + 'static,
{
    /// The wrapped, already fitted B-spline.
    interpolator: BSpline<T>,
}
785
impl<T> BSplineWrapper<T>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Copy
        + Send
        + Sync
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + 'static,
{
    /// Wraps an already constructed B-spline.
    fn new(interpolator: BSpline<T>) -> Self {
        Self { interpolator }
    }
}
806
impl<T> InterpolatorTrait<T> for BSplineWrapper<T>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + Copy
        + Send
        + Sync
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + 'static,
{
    /// Evaluates the wrapped B-spline at the abscissae `x` by delegating to
    /// its `evaluate_array` method.
    fn evaluate(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
        self.interpolator.evaluate_array(x)
    }
}
827
/// Compares several interpolation methods by cross-validation.
#[derive(Debug)]
pub struct ModelSelector<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// Validator used to score every candidate method.
    cross_validator: CrossValidator<T>,
    /// Storage for `(method name, result)` pairs.
    /// NOTE(review): initialised empty and never written by the code
    /// visible in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    comparison_results: Vec<(String, CrossValidationResult<T>)>,
}
855
856impl<T> ModelSelector<T>
857where
858 T: Float
859 + FromPrimitive
860 + ToPrimitive
861 + Debug
862 + Display
863 + LowerExp
864 + ScalarOperand
865 + AddAssign
866 + SubAssign
867 + MulAssign
868 + DivAssign
869 + RemAssign
870 + Copy
871 + Send
872 + Sync
873 + 'static,
874{
875 pub fn new() -> Self {
877 Self {
878 cross_validator: CrossValidator::new(),
879 comparison_results: Vec::new(),
880 }
881 }
882
883 pub fn with_cross_validator(mut self, cv: CrossValidator<T>) -> Self {
885 self.cross_validator = cv;
886 self
887 }
888
889 #[allow(dead_code)]
901 pub fn compare_methods<F>(
902 &mut self,
903 x: &ArrayView1<T>,
904 y: &ArrayView1<T>,
905 methods: HashMap<String, F>,
906 ) -> InterpolateResult<Vec<(String, CrossValidationResult<T>)>>
907 where
908 F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>
909 + Clone,
910 {
911 let mut results = Vec::new();
912
913 for (method_name, interpolator_fn) in methods {
914 let cv_result = self.cross_validator.cross_validate(x, y, interpolator_fn)?;
915 results.push((method_name, cv_result));
916 }
917
918 results.sort_by(|a, b| {
920 a.1.mean_score
921 .partial_cmp(&b.1.mean_score)
922 .expect("Operation failed")
923 });
924
925 Ok(results)
926 }
927}
928
impl<T> Default for ModelSelector<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    /// Equivalent to [`ModelSelector::new`].
    fn default() -> Self {
        Self::new()
    }
}
952
/// Convenience constructor: a default [`CrossValidator`] switched to k-fold
/// splitting with `_kfolds` folds and scoring with `metric`.
#[allow(dead_code)]
pub fn make_cross_validator<T>(_kfolds: usize, metric: ValidationMetric) -> CrossValidator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + 'static,
{
    // The leading underscore in `_kfolds` is only a naming quirk; the value
    // is used as the fold count.
    CrossValidator::new()
        .with_k_folds(_kfolds)
        .with_metric(metric)
}
987
988#[allow(dead_code)]
1002pub fn grid_search<T, F>(
1003 x: &ArrayView1<T>,
1004 y: &ArrayView1<T>,
1005 parameter_grid: &[HashMap<String, T>],
1006 cv: &CrossValidator<T>,
1007 interpolator_fn: F,
1008) -> InterpolateResult<(HashMap<String, T>, T)>
1009where
1010 T: Float
1011 + FromPrimitive
1012 + ToPrimitive
1013 + Debug
1014 + Display
1015 + LowerExp
1016 + ScalarOperand
1017 + AddAssign
1018 + SubAssign
1019 + MulAssign
1020 + DivAssign
1021 + RemAssign
1022 + Copy
1023 + Send
1024 + Sync
1025 + 'static,
1026 F: Fn(
1027 &HashMap<String, T>,
1028 &ArrayView1<T>,
1029 &ArrayView1<T>,
1030 ) -> InterpolateResult<Box<dyn InterpolatorTrait<T>>>,
1031{
1032 let mut best_score = T::infinity();
1033 let mut best_params = HashMap::new();
1034
1035 for params in parameter_grid {
1036 let interpolator_factory = |x_train: &ArrayView1<T>, y_train: &ArrayView1<T>| {
1037 interpolator_fn(params, x_train, y_train)
1038 };
1039
1040 let cv_result = cv.cross_validate(x, y, interpolator_factory)?;
1041
1042 if cv_result.mean_score < best_score {
1043 best_score = cv_result.mean_score;
1044 best_params = params.clone();
1045 }
1046 }
1047
1048 Ok((best_params, best_score))
1049}
1050
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A fresh validator uses MSE and shuffles its folds.
    #[test]
    fn test_cross_validator_creation() {
        let cv = CrossValidator::<f64>::new();
        assert_eq!(cv.metric, ValidationMetric::MeanSquaredError);
        assert!(cv.shuffle);
    }

    /// Builder methods override the corresponding settings.
    #[test]
    fn test_cross_validator_configuration() {
        let cv = CrossValidator::<f64>::new()
            .with_k_folds(10)
            .with_metric(ValidationMetric::MeanAbsoluteError)
            .with_shuffle(false);

        match cv.strategy {
            CrossValidationStrategy::KFold(k) => assert_eq!(k, 10),
            _ => panic!("Expected KFold strategy"),
        }
        assert_eq!(cv.metric, ValidationMetric::MeanAbsoluteError);
        assert!(!cv.shuffle);
    }

    /// K-fold splitting yields `k` folds, each partitioning all samples.
    #[test]
    fn test_fold_generation() {
        let cv = CrossValidator::<f64>::new().with_k_folds(3);
        let folds = cv.generate_folds(9).expect("Operation failed");

        assert_eq!(folds.len(), 3);

        let mut all_indices = std::collections::HashSet::new();
        for (train, test) in &folds {
            // Every fold splits the nine samples into train and test.
            assert_eq!(train.len() + test.len(), 9);
            for &idx in train {
                all_indices.insert(idx);
            }
            for &idx in test {
                all_indices.insert(idx);
            }
        }
        assert_eq!(all_indices.len(), 9);
    }

    /// Leave-one-out yields one single-sample test set per sample.
    #[test]
    fn test_leave_one_out_folds() {
        let cv = CrossValidator::<f64>::new().with_strategy(CrossValidationStrategy::LeaveOneOut);
        let folds = cv.generate_folds(5).expect("Operation failed");

        assert_eq!(folds.len(), 5);
        for (train, test) in &folds {
            assert_eq!(test.len(), 1);
            assert_eq!(train.len(), 4);
        }
    }

    /// MSE matches the hand-computed value.
    #[test]
    fn test_metric_computation() {
        let cv = CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanSquaredError);

        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y_pred = Array1::from_vec(vec![1.1, 1.9, 3.1, 3.9]);

        let mse = cv
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        let expected_mse = (0.1 * 0.1 + 0.1 * 0.1 + 0.1 * 0.1 + 0.1 * 0.1) / 4.0;
        assert!((mse - expected_mse).abs() < 1e-10);
    }

    /// Perfect predictions give an R² of exactly one.
    #[test]
    fn test_r_squared_metric() {
        let cv = CrossValidator::<f64>::new().with_metric(ValidationMetric::RSquared);

        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y_pred = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let r2 = cv
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        assert!((r2 - 1.0).abs() < 1e-10);
    }

    /// The RBF width search evaluates every candidate and reports a best one.
    #[test]
    fn test_rbf_parameter_optimization() {
        let x = Array1::linspace(0.0, 1.0, 10);
        let y = x.mapv(|x| x * x);

        let mut cv = CrossValidator::new().with_k_folds(3);
        let kernel_widths = vec![0.1, 1.0, 10.0];

        let result = cv.optimize_rbf_parameters(&x.view(), &y.view(), &kernel_widths);
        assert!(result.is_ok());

        let opt_result = result.expect("Operation failed");
        assert!(opt_result.best_parameters.contains_key("kernel_width"));
        assert_eq!(opt_result.parameter_scores.len(), 3);
        assert!(opt_result.best_score.is_finite());
    }

    /// The B-spline degree search either succeeds or fails with the expected
    /// input error.
    #[test]
    fn test_bspline_parameter_optimization() {
        let x = Array1::linspace(0.0, 10.0, 30);
        let y = x.mapv(|x| 2.0 * x + 1.0);

        let mut cv = CrossValidator::new().with_k_folds(2);
        let degrees = vec![1];

        let result = cv.optimize_bspline_parameters(&x.view(), &y.view(), &degrees);

        match result {
            Ok(opt_result) => {
                assert!(opt_result.best_parameters.contains_key("degree"));
                assert_eq!(opt_result.parameter_scores.len(), 1);
                assert!(opt_result.best_score.is_finite());
            }
            Err(e) => {
                println!(
                    "Cross-validation encountered numerical issues (expected): {:?}",
                    e
                );
                assert!(matches!(e, InterpolateError::InvalidInput { .. }));
            }
        }
    }

    /// A fresh selector holds no comparison results.
    #[test]
    fn test_model_selector_creation() {
        let selector = ModelSelector::<f64>::new();
        assert_eq!(selector.comparison_results.len(), 0);
    }

    /// The convenience constructor applies fold count and metric.
    #[test]
    fn test_make_cross_validator() {
        let cv = make_cross_validator::<f64>(5, ValidationMetric::MeanAbsoluteError);

        match cv.strategy {
            CrossValidationStrategy::KFold(k) => assert_eq!(k, 5),
            _ => panic!("Expected KFold strategy"),
        }
        assert_eq!(cv.metric, ValidationMetric::MeanAbsoluteError);
    }

    /// Index extraction gathers the requested elements in order.
    #[test]
    fn test_extract_indices() {
        let cv = CrossValidator::<f64>::new();
        let arr = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0]);
        let indices = vec![0, 2, 4];

        let extracted = cv.extract_indices(&arr.view(), &indices);
        assert_eq!(extracted, Array1::from_vec(vec![10.0, 30.0, 50.0]));
    }

    /// MSE, MAE and RMSE are positive here and RMSE equals sqrt(MSE).
    #[test]
    fn test_validation_metrics() {
        let cv_mse = CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanSquaredError);
        let cv_mae = CrossValidator::<f64>::new().with_metric(ValidationMetric::MeanAbsoluteError);
        let cv_rmse =
            CrossValidator::<f64>::new().with_metric(ValidationMetric::RootMeanSquaredError);

        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array1::from_vec(vec![1.5, 2.5, 2.5]);

        let mse = cv_mse
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        let mae = cv_mae
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");
        let rmse = cv_rmse
            .compute_metric(&y_true.view(), &y_pred.view())
            .expect("Operation failed");

        assert!(mse > 0.0);
        assert!(mae > 0.0);
        assert!((rmse - mse.sqrt()).abs() < 1e-10);
    }
}