1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::SeedableRng;
10use scirs2_core::RngExt;
11use sklears_core::error::Result;
12use std::collections::HashMap;
13
/// Validates kernel approximation methods by comparing empirical
/// approximation errors against registered theoretical error bounds.
#[derive(Debug, Clone)]
pub struct KernelApproximationValidator {
    // Validation settings (sample sizes, dimensions, tolerances, seed).
    config: ValidationConfig,
    // Theoretical bounds keyed by method name (e.g. "RBF", "Nystroem").
    theoretical_bounds: HashMap<String, TheoreticalBound>,
}
21
/// Configuration parameters controlling a validation run.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Confidence level used for probabilistic bounds (e.g. 0.95).
    pub confidence_level: f64,
    /// Target approximation error used to determine minimum sample size.
    pub max_approximation_error: f64,
    /// Tolerance for convergence checks.
    pub convergence_tolerance: f64,
    /// Standard deviation of the Gaussian perturbations used in
    /// stability analysis.
    pub stability_tolerance: f64,
    /// Subset sizes evaluated during sample-complexity analysis.
    pub sample_sizes: Vec<usize>,
    /// Approximation dimensions (number of components) evaluated.
    pub approximation_dimensions: Vec<usize>,
    /// Number of repeated fits per dimension (errors are averaged).
    pub repetitions: usize,
    /// Optional RNG seed for reproducibility; defaults to 42 when `None`.
    pub random_state: Option<u64>,
}
43
44impl Default for ValidationConfig {
45 fn default() -> Self {
46 Self {
47 confidence_level: 0.95,
48 max_approximation_error: 0.1,
49 convergence_tolerance: 1e-6,
50 stability_tolerance: 1e-4,
51 sample_sizes: vec![100, 500, 1000, 2000],
52 approximation_dimensions: vec![50, 100, 200, 500],
53 repetitions: 10,
54 random_state: Some(42),
55 }
56 }
57}
58
/// A theoretical error bound for a named approximation method.
#[derive(Debug, Clone)]
pub struct TheoreticalBound {
    /// Name of the method this bound applies to (lookup key).
    pub method_name: String,
    /// Nature of the guarantee (probabilistic, deterministic, ...).
    pub bound_type: BoundType,
    /// Which closed-form bound formula to evaluate.
    pub bound_function: BoundFunction,
    /// Named constants consumed by the bound formula
    /// (e.g. "kernel_bound", "lipschitz_constant").
    pub constants: HashMap<String, f64>,
}
72
/// Kind of guarantee a theoretical bound provides.
#[derive(Debug, Clone)]
pub enum BoundType {
    /// Holds with the given probability (confidence in (0, 1)).
    Probabilistic { confidence: f64 },
    /// Holds always (worst case).
    Deterministic,
    /// Holds in expectation over the randomness of the method.
    Expected,
    /// Concentration-style bound parameterized by a deviation term.
    Concentration { deviation_parameter: f64 },
}
86
/// Closed-form bound formula associated with an approximation family.
#[derive(Debug, Clone)]
pub enum BoundFunction {
    /// Random Fourier features (Rahimi–Recht style) bound.
    RandomFourierFeatures,
    /// Nystroem low-rank approximation bound.
    Nystroem,
    /// Structured random features bound.
    StructuredRandomFeatures,
    /// Fastfood transform bound.
    Fastfood,
    /// User-supplied formula; currently only the formula string is stored
    /// and a generic 1/sqrt(n_components) fallback is evaluated.
    Custom { formula: String },
}
102
/// Aggregated outcome of validating one approximation method.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Name of the validated method.
    pub method_name: String,
    /// Mean empirical error per configured approximation dimension.
    pub empirical_errors: Vec<f64>,
    /// Theoretical bound per dimension (infinity when no bound is known).
    pub theoretical_bounds: Vec<f64>,
    /// Number of dimensions where the empirical error exceeded the bound.
    pub bound_violations: usize,
    /// Mean ratio empirical/theoretical over finite bounds (lower = looser bound).
    pub bound_tightness: f64,
    /// Estimated error-decay rate w.r.t. dimension (log-log slope), if computable.
    pub convergence_rate: Option<f64>,
    /// Perturbation/conditioning stability metrics.
    pub stability_analysis: StabilityAnalysis,
    /// Sample-complexity metrics.
    pub sample_complexity: SampleComplexityAnalysis,
    /// Dimension-dependency metrics.
    pub dimension_dependency: DimensionDependencyAnalysis,
}
126
/// Stability metrics derived from data perturbations and conditioning.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Mean approximation change under small Gaussian input perturbations.
    pub perturbation_sensitivity: f64,
    /// Heuristic score derived from mean condition number (1.0 = best).
    pub numerical_stability: f64,
    /// Condition numbers collected from fitted models (may be empty).
    pub condition_numbers: Vec<f64>,
    /// Heuristic eigenvalue stability: 1.0 - perturbation_sensitivity.
    pub eigenvalue_stability: f64,
}
140
/// Sample-complexity metrics for an approximation method.
#[derive(Debug, Clone)]
pub struct SampleComplexityAnalysis {
    /// Smallest tested sample size achieving the target error
    /// (falls back to the largest configured size if none does).
    pub minimum_samples: usize,
    /// Estimated error-decay rate w.r.t. sample size (log-log slope).
    pub convergence_rate: f64,
    /// Heuristic efficiency: 1 / minimum_samples.
    pub sample_efficiency: f64,
    /// Heuristic scaling: n_features / minimum_samples.
    pub dimension_scaling: f64,
}
154
/// How approximation quality and cost vary with the approximation dimension.
#[derive(Debug, Clone)]
pub struct DimensionDependencyAnalysis {
    /// (dimension, quality) pairs where quality = 1 - empirical error.
    pub approximation_quality_vs_dimension: Vec<(usize, f64)>,
    /// (dimension, cost) pairs using cost = dimension * n_samples.
    pub computational_cost_vs_dimension: Vec<(usize, f64)>,
    /// Dimension with the best quality/cost ratio among those tested.
    pub optimal_dimension: usize,
    /// Mean quality across all tested dimensions.
    pub dimension_efficiency: f64,
}
168
/// Outcome of grid-search cross-validation over method parameters.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Name of the validated method.
    pub method_name: String,
    /// Mean CV score per parameter combination (higher is better).
    pub cv_scores: Vec<f64>,
    /// Mean of `cv_scores`.
    pub mean_score: f64,
    /// Standard deviation of `cv_scores` (population, not sample).
    pub std_score: f64,
    /// Best-scoring parameter combination.
    pub best_parameters: HashMap<String, f64>,
    /// Per-parameter mean score degradation when varied alone.
    pub parameter_sensitivity: HashMap<String, f64>,
}
186
187impl KernelApproximationValidator {
188 pub fn new(config: ValidationConfig) -> Self {
190 let mut validator = Self {
191 config,
192 theoretical_bounds: HashMap::new(),
193 };
194
195 validator.add_default_bounds();
197 validator
198 }
199
200 pub fn add_theoretical_bound(&mut self, bound: TheoreticalBound) {
202 self.theoretical_bounds
203 .insert(bound.method_name.clone(), bound);
204 }
205
206 fn add_default_bounds(&mut self) {
207 self.add_theoretical_bound(TheoreticalBound {
209 method_name: "RBF".to_string(),
210 bound_type: BoundType::Probabilistic { confidence: 0.95 },
211 bound_function: BoundFunction::RandomFourierFeatures,
212 constants: [
213 ("kernel_bound".to_string(), 1.0),
214 ("lipschitz_constant".to_string(), 1.0),
215 ]
216 .iter()
217 .cloned()
218 .collect(),
219 });
220
221 self.add_theoretical_bound(TheoreticalBound {
223 method_name: "Nystroem".to_string(),
224 bound_type: BoundType::Expected,
225 bound_function: BoundFunction::Nystroem,
226 constants: [
227 ("trace_bound".to_string(), 1.0),
228 ("effective_rank".to_string(), 100.0),
229 ]
230 .iter()
231 .cloned()
232 .collect(),
233 });
234
235 self.add_theoretical_bound(TheoreticalBound {
237 method_name: "Fastfood".to_string(),
238 bound_type: BoundType::Probabilistic { confidence: 0.95 },
239 bound_function: BoundFunction::Fastfood,
240 constants: [
241 ("dimension_factor".to_string(), 1.0),
242 ("log_factor".to_string(), 2.0),
243 ]
244 .iter()
245 .cloned()
246 .collect(),
247 });
248 }
249
250 pub fn validate_method<T: ValidatableKernelMethod>(
252 &self,
253 method: &T,
254 data: &Array2<f64>,
255 true_kernel: Option<&Array2<f64>>,
256 ) -> Result<ValidationResult> {
257 let method_name = method.method_name();
258 let mut empirical_errors = Vec::new();
259 let mut theoretical_bounds = Vec::new();
260 let mut condition_numbers = Vec::new();
261
262 for &n_components in &self.config.approximation_dimensions {
264 let mut dimension_errors = Vec::new();
265
266 for _ in 0..self.config.repetitions {
267 let fitted = method.fit_with_dimension(data, n_components)?;
269 let approximation = fitted.get_kernel_approximation(data)?;
270
271 let empirical_error = if let Some(true_k) = true_kernel {
273 self.compute_approximation_error(&approximation, true_k)?
274 } else {
275 let rbf_kernel = self.compute_rbf_kernel(data, 1.0)?;
277 self.compute_approximation_error(&approximation, &rbf_kernel)?
278 };
279
280 dimension_errors.push(empirical_error);
281
282 if let Some(cond_num) = fitted.compute_condition_number()? {
284 condition_numbers.push(cond_num);
285 }
286 }
287
288 let mean_error = dimension_errors.iter().sum::<f64>() / dimension_errors.len() as f64;
289 empirical_errors.push(mean_error);
290
291 if let Some(bound) = self.theoretical_bounds.get(&method_name) {
293 let theoretical_bound = self.compute_theoretical_bound(
294 bound,
295 data.nrows(),
296 data.ncols(),
297 n_components,
298 )?;
299 theoretical_bounds.push(theoretical_bound);
300 } else {
301 theoretical_bounds.push(f64::INFINITY);
302 }
303 }
304
305 let bound_violations = empirical_errors
307 .iter()
308 .zip(theoretical_bounds.iter())
309 .filter(|(&emp, &theo)| emp > theo)
310 .count();
311
312 let bound_tightness = empirical_errors
314 .iter()
315 .zip(theoretical_bounds.iter())
316 .filter(|(_, &theo)| theo.is_finite())
317 .map(|(&emp, &theo)| emp / theo)
318 .sum::<f64>()
319 / empirical_errors.len() as f64;
320
321 let convergence_rate = self.estimate_convergence_rate(&empirical_errors);
323
324 let stability_analysis = self.analyze_stability(method, data, &condition_numbers)?;
326
327 let sample_complexity = self.analyze_sample_complexity(method, data)?;
329
330 let dimension_dependency =
332 self.analyze_dimension_dependency(method, data, &empirical_errors)?;
333
334 Ok(ValidationResult {
335 method_name,
336 empirical_errors,
337 theoretical_bounds,
338 bound_violations,
339 bound_tightness,
340 convergence_rate,
341 stability_analysis,
342 sample_complexity,
343 dimension_dependency,
344 })
345 }
346
347 pub fn cross_validate<T: ValidatableKernelMethod>(
349 &self,
350 method: &T,
351 data: &Array2<f64>,
352 targets: Option<&Array1<f64>>,
353 parameter_grid: HashMap<String, Vec<f64>>,
354 ) -> Result<CrossValidationResult> {
355 let mut best_score = f64::NEG_INFINITY;
356 let mut best_parameters = HashMap::new();
357 let mut all_scores = Vec::new();
358 let mut parameter_sensitivity = HashMap::new();
359
360 let param_combinations = self.generate_parameter_combinations(¶meter_grid);
362
363 for params in param_combinations {
364 let cv_scores = self.k_fold_cross_validation(method, data, targets, ¶ms, 5)?;
365 let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
366
367 all_scores.push(mean_score);
368
369 if mean_score > best_score {
370 best_score = mean_score;
371 best_parameters = params.clone();
372 }
373 }
374
375 for (param_name, param_values) in ¶meter_grid {
377 let mut sensitivities = Vec::new();
378
379 for ¶m_value in param_values.iter() {
380 let mut single_param = best_parameters.clone();
381 single_param.insert(param_name.clone(), param_value);
382
383 let cv_scores =
384 self.k_fold_cross_validation(method, data, targets, &single_param, 3)?;
385 let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
386 sensitivities.push((best_score - mean_score).abs());
387 }
388
389 let sensitivity = sensitivities.iter().sum::<f64>() / sensitivities.len() as f64;
390 parameter_sensitivity.insert(param_name.clone(), sensitivity);
391 }
392
393 let mean_score = all_scores.iter().sum::<f64>() / all_scores.len() as f64;
394 let variance = all_scores
395 .iter()
396 .map(|&x| (x - mean_score).powi(2))
397 .sum::<f64>()
398 / all_scores.len() as f64;
399 let std_score = variance.sqrt();
400
401 Ok(CrossValidationResult {
402 method_name: method.method_name(),
403 cv_scores: all_scores,
404 mean_score,
405 std_score,
406 best_parameters,
407 parameter_sensitivity,
408 })
409 }
410
411 fn compute_approximation_error(
412 &self,
413 approx_kernel: &Array2<f64>,
414 true_kernel: &Array2<f64>,
415 ) -> Result<f64> {
416 let diff = approx_kernel - true_kernel;
418 let frobenius_error = diff.mapv(|x| x * x).sum().sqrt();
419
420 let true_norm = true_kernel.mapv(|x| x * x).sum().sqrt();
422 Ok(frobenius_error / true_norm.max(1e-8))
423 }
424
425 fn compute_rbf_kernel(&self, data: &Array2<f64>, gamma: f64) -> Result<Array2<f64>> {
426 let n_samples = data.nrows();
427 let mut kernel = Array2::zeros((n_samples, n_samples));
428
429 for i in 0..n_samples {
430 for j in i..n_samples {
431 let diff = &data.row(i) - &data.row(j);
432 let dist_sq = diff.mapv(|x| x * x).sum();
433 let similarity = (-gamma * dist_sq).exp();
434 kernel[[i, j]] = similarity;
435 kernel[[j, i]] = similarity;
436 }
437 }
438
439 Ok(kernel)
440 }
441
442 fn compute_theoretical_bound(
443 &self,
444 bound: &TheoreticalBound,
445 n_samples: usize,
446 n_features: usize,
447 n_components: usize,
448 ) -> Result<f64> {
449 let bound_value = match &bound.bound_function {
450 BoundFunction::RandomFourierFeatures => {
451 let kernel_bound = bound.constants.get("kernel_bound").unwrap_or(&1.0);
452 let lipschitz = bound.constants.get("lipschitz_constant").unwrap_or(&1.0);
453
454 let log_factor = (n_features as f64).ln();
456 kernel_bound * lipschitz * (log_factor / n_components as f64).sqrt()
457 }
458 BoundFunction::Nystroem => {
459 let trace_bound = bound.constants.get("trace_bound").unwrap_or(&1.0);
460 let effective_rank = bound.constants.get("effective_rank").unwrap_or(&100.0);
461
462 trace_bound * (effective_rank / n_components as f64).sqrt()
464 }
465 BoundFunction::StructuredRandomFeatures => {
466 let log_factor = (n_features as f64).ln();
467 (n_features as f64 * log_factor / n_components as f64).sqrt()
468 }
469 BoundFunction::Fastfood => {
470 let log_factor = bound.constants.get("log_factor").unwrap_or(&2.0);
471 let dim_factor = bound.constants.get("dimension_factor").unwrap_or(&1.0);
472
473 let log_d = (n_features as f64).ln();
474 dim_factor
475 * (n_features as f64 * log_d.powf(*log_factor) / n_components as f64).sqrt()
476 }
477 BoundFunction::Custom { formula: _ } => {
478 1.0 / (n_components as f64).sqrt()
480 }
481 };
482
483 let final_bound = match &bound.bound_type {
485 BoundType::Probabilistic { confidence } => {
486 let z_score = self.inverse_normal_cdf(*confidence);
488 bound_value * (1.0 + z_score / (n_samples as f64).sqrt())
489 }
490 BoundType::Deterministic => bound_value,
491 BoundType::Expected => bound_value * 0.8, BoundType::Concentration {
493 deviation_parameter,
494 } => bound_value * (1.0 + deviation_parameter / (n_samples as f64).sqrt()),
495 };
496
497 Ok(final_bound)
498 }
499
500 fn inverse_normal_cdf(&self, p: f64) -> f64 {
501 if p <= 0.5 {
503 -self.inverse_normal_cdf(1.0 - p)
504 } else {
505 let t = (-2.0 * (1.0 - p).ln()).sqrt();
506 let c0 = 2.515517;
507 let c1 = 0.802853;
508 let c2 = 0.010328;
509 let d1 = 1.432788;
510 let d2 = 0.189269;
511 let d3 = 0.001308;
512
513 t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
514 }
515 }
516
517 fn estimate_convergence_rate(&self, errors: &[f64]) -> Option<f64> {
518 if errors.len() < 3 {
519 return None;
520 }
521
522 let dimensions: Vec<f64> = self
524 .config
525 .approximation_dimensions
526 .iter()
527 .take(errors.len())
528 .map(|&x| (x as f64).ln())
529 .collect();
530
531 let log_errors: Vec<f64> = errors.iter().map(|&x| x.ln()).collect();
532
533 let n = dimensions.len() as f64;
535 let sum_x = dimensions.iter().sum::<f64>();
536 let sum_y = log_errors.iter().sum::<f64>();
537 let sum_xy = dimensions
538 .iter()
539 .zip(log_errors.iter())
540 .map(|(&x, &y)| x * y)
541 .sum::<f64>();
542 let sum_x2 = dimensions.iter().map(|&x| x * x).sum::<f64>();
543
544 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
545 Some(-slope) }
547
548 fn analyze_stability<T: ValidatableKernelMethod>(
549 &self,
550 method: &T,
551 data: &Array2<f64>,
552 condition_numbers: &[f64],
553 ) -> Result<StabilityAnalysis> {
554 let mut rng = RealStdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
555 let normal = RandNormal::new(0.0, self.config.stability_tolerance)
556 .expect("operation should succeed");
557
558 let mut perturbation_errors = Vec::new();
560
561 for _ in 0..5 {
562 let mut perturbed_data = data.clone();
563 for elem in perturbed_data.iter_mut() {
564 *elem += rng.sample(normal);
565 }
566
567 let original_fitted = method.fit_with_dimension(data, 100)?;
568 let perturbed_fitted = method.fit_with_dimension(&perturbed_data, 100)?;
569
570 let original_approx = original_fitted.get_kernel_approximation(data)?;
571 let perturbed_approx = perturbed_fitted.get_kernel_approximation(data)?;
572
573 let error = self.compute_approximation_error(&perturbed_approx, &original_approx)?;
574 perturbation_errors.push(error);
575 }
576
577 let perturbation_sensitivity =
578 perturbation_errors.iter().sum::<f64>() / perturbation_errors.len() as f64;
579
580 let numerical_stability = if condition_numbers.is_empty() {
582 1.0
583 } else {
584 let mean_condition =
585 condition_numbers.iter().sum::<f64>() / condition_numbers.len() as f64;
586 1.0 / mean_condition.ln().max(1.0)
587 };
588
589 let eigenvalue_stability = 1.0 - perturbation_sensitivity;
591
592 Ok(StabilityAnalysis {
593 perturbation_sensitivity,
594 numerical_stability,
595 condition_numbers: condition_numbers.to_vec(),
596 eigenvalue_stability,
597 })
598 }
599
600 fn analyze_sample_complexity<T: ValidatableKernelMethod>(
601 &self,
602 method: &T,
603 data: &Array2<f64>,
604 ) -> Result<SampleComplexityAnalysis> {
605 let mut sample_errors = Vec::new();
606
607 for &n_samples in &self.config.sample_sizes {
609 if n_samples > data.nrows() {
610 continue;
611 }
612
613 let subset_data = data
614 .slice(scirs2_core::ndarray::s![..n_samples, ..])
615 .to_owned();
616 let fitted = method.fit_with_dimension(&subset_data, 100)?;
617 let approx = fitted.get_kernel_approximation(&subset_data)?;
618
619 let rbf_kernel = self.compute_rbf_kernel(&subset_data, 1.0)?;
620 let error = self.compute_approximation_error(&approx, &rbf_kernel)?;
621 sample_errors.push(error);
622 }
623
624 let target_error = self.config.max_approximation_error;
626 let minimum_samples = self
627 .config
628 .sample_sizes
629 .iter()
630 .zip(sample_errors.iter())
631 .find(|(_, &error)| error <= target_error)
632 .map(|(&samples, _)| samples)
633 .unwrap_or(
634 *self
635 .config
636 .sample_sizes
637 .last()
638 .expect("operation should succeed"),
639 );
640
641 let convergence_rate = if sample_errors.len() >= 2 {
643 let log_samples: Vec<f64> = self
644 .config
645 .sample_sizes
646 .iter()
647 .take(sample_errors.len())
648 .map(|&x| (x as f64).ln())
649 .collect();
650 let log_errors: Vec<f64> = sample_errors.iter().map(|&x| x.ln()).collect();
651
652 let n = log_samples.len() as f64;
654 let sum_x = log_samples.iter().sum::<f64>();
655 let sum_y = log_errors.iter().sum::<f64>();
656 let sum_xy = log_samples
657 .iter()
658 .zip(log_errors.iter())
659 .map(|(&x, &y)| x * y)
660 .sum::<f64>();
661 let sum_x2 = log_samples.iter().map(|&x| x * x).sum::<f64>();
662
663 -(n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
664 } else {
665 0.5 };
667
668 let sample_efficiency = 1.0 / minimum_samples as f64;
669 let dimension_scaling = data.ncols() as f64 / minimum_samples as f64;
670
671 Ok(SampleComplexityAnalysis {
672 minimum_samples,
673 convergence_rate,
674 sample_efficiency,
675 dimension_scaling,
676 })
677 }
678
679 fn analyze_dimension_dependency<T: ValidatableKernelMethod>(
680 &self,
681 _method: &T,
682 data: &Array2<f64>,
683 errors: &[f64],
684 ) -> Result<DimensionDependencyAnalysis> {
685 let approximation_quality_vs_dimension: Vec<(usize, f64)> = self
686 .config
687 .approximation_dimensions
688 .iter()
689 .take(errors.len())
690 .zip(errors.iter())
691 .map(|(&dim, &error)| (dim, 1.0 - error)) .collect();
693
694 let computational_cost_vs_dimension: Vec<(usize, f64)> = self
696 .config
697 .approximation_dimensions
698 .iter()
699 .map(|&dim| (dim, dim as f64 * data.nrows() as f64))
700 .collect();
701
702 let optimal_dimension = approximation_quality_vs_dimension
704 .iter()
705 .zip(computational_cost_vs_dimension.iter())
706 .map(|((dim, quality), (_, cost))| (*dim, quality / cost))
707 .max_by(|a, b| a.1.partial_cmp(&b.1).expect("operation should succeed"))
708 .map(|(dim, _)| dim)
709 .unwrap_or(100);
710
711 let dimension_efficiency = approximation_quality_vs_dimension
712 .iter()
713 .map(|(_, quality)| quality)
714 .sum::<f64>()
715 / approximation_quality_vs_dimension.len() as f64;
716
717 Ok(DimensionDependencyAnalysis {
718 approximation_quality_vs_dimension,
719 computational_cost_vs_dimension,
720 optimal_dimension,
721 dimension_efficiency,
722 })
723 }
724
725 fn generate_parameter_combinations(
726 &self,
727 parameter_grid: &HashMap<String, Vec<f64>>,
728 ) -> Vec<HashMap<String, f64>> {
729 let mut combinations = vec![HashMap::new()];
730
731 for (param_name, param_values) in parameter_grid {
732 let mut new_combinations = Vec::new();
733
734 for combination in &combinations {
735 for ¶m_value in param_values {
736 let mut new_combination = combination.clone();
737 new_combination.insert(param_name.clone(), param_value);
738 new_combinations.push(new_combination);
739 }
740 }
741
742 combinations = new_combinations;
743 }
744
745 combinations
746 }
747
748 fn k_fold_cross_validation<T: ValidatableKernelMethod>(
749 &self,
750 method: &T,
751 data: &Array2<f64>,
752 _targets: Option<&Array1<f64>>,
753 parameters: &HashMap<String, f64>,
754 k: usize,
755 ) -> Result<Vec<f64>> {
756 let n_samples = data.nrows();
757 let fold_size = n_samples / k;
758 let mut scores = Vec::new();
759
760 for fold in 0..k {
761 let start_idx = fold * fold_size;
762 let end_idx = if fold == k - 1 {
763 n_samples
764 } else {
765 (fold + 1) * fold_size
766 };
767
768 let train_indices: Vec<usize> = (0..n_samples)
770 .filter(|&i| i < start_idx || i >= end_idx)
771 .collect();
772 let val_indices: Vec<usize> = (start_idx..end_idx).collect();
773
774 let train_data = data.select(Axis(0), &train_indices);
775 let val_data = data.select(Axis(0), &val_indices);
776
777 let fitted = method.fit_with_parameters(&train_data, parameters)?;
779 let train_approx = fitted.get_kernel_approximation(&train_data)?;
780 let val_approx = fitted.get_kernel_approximation(&val_data)?;
781
782 let train_kernel = self.compute_rbf_kernel(&train_data, 1.0)?;
784 let val_kernel = self.compute_rbf_kernel(&val_data, 1.0)?;
785
786 let train_error = self.compute_approximation_error(&train_approx, &train_kernel)?;
787 let val_error = self.compute_approximation_error(&val_approx, &val_kernel)?;
788
789 let score = -(train_error + val_error) / 2.0;
791 scores.push(score);
792 }
793
794 Ok(scores)
795 }
796}
797
/// A kernel approximation method that can be validated by
/// [`KernelApproximationValidator`].
pub trait ValidatableKernelMethod {
    /// Human-readable method name; used to look up theoretical bounds.
    fn method_name(&self) -> String;

    /// Fits the method on `data` with the given number of approximation
    /// components, returning a fitted model.
    fn fit_with_dimension(
        &self,
        data: &Array2<f64>,
        n_components: usize,
    ) -> Result<Box<dyn ValidatedFittedMethod>>;

    /// Fits the method on `data` using named hyperparameter values
    /// (e.g. "n_components"), returning a fitted model.
    fn fit_with_parameters(
        &self,
        data: &Array2<f64>,
        parameters: &HashMap<String, f64>,
    ) -> Result<Box<dyn ValidatedFittedMethod>>;
}
817
/// A fitted kernel approximation model queried by the validator.
pub trait ValidatedFittedMethod {
    /// Returns the approximated kernel matrix for `data`
    /// (n_samples x n_samples).
    fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>>;

    /// Condition number of the fitted approximation, if available.
    fn compute_condition_number(&self) -> Result<Option<f64>>;

    /// Number of approximation components used by the fitted model.
    fn approximation_dimension(&self) -> usize;
}
829
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal stand-in kernel method used to exercise the validator
    // without a real approximation backend.
    struct MockValidatableRBF {
        // NOTE(review): unused by the mock itself — kept so tests can
        // construct the struct with a realistic field.
        gamma: f64,
    }

    impl ValidatableKernelMethod for MockValidatableRBF {
        fn method_name(&self) -> String {
            "MockRBF".to_string()
        }

        fn fit_with_dimension(
            &self,
            _data: &Array2<f64>,
            n_components: usize,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            Ok(Box::new(MockValidatedFitted { n_components }))
        }

        fn fit_with_parameters(
            &self,
            _data: &Array2<f64>,
            parameters: &HashMap<String, f64>,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            // NOTE(review): f64 -> usize cast truncates; fine for the
            // small test values used here.
            let n_components = parameters.get("n_components").copied().unwrap_or(100.0) as usize;
            Ok(Box::new(MockValidatedFitted { n_components }))
        }
    }

    // Fitted mock that returns a deterministic kernel and conditioning.
    struct MockValidatedFitted {
        n_components: usize,
    }

    impl ValidatedFittedMethod for MockValidatedFitted {
        // Produces a fixed kernel: 1.0 on the diagonal, 0.5 elsewhere.
        fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            let n_samples = data.nrows();
            let mut kernel = Array2::zeros((n_samples, n_samples));

            for i in 0..n_samples {
                kernel[[i, i]] = 1.0;
                for j in i + 1..n_samples {
                    let similarity = 0.5;
                    kernel[[i, j]] = similarity;
                    kernel[[j, i]] = similarity;
                }
            }

            Ok(kernel)
        }

        // Fixed, well-conditioned value so stability analysis has input.
        fn compute_condition_number(&self) -> Result<Option<f64>> {
            Ok(Some(10.0))
        }

        fn approximation_dimension(&self) -> usize {
            self.n_components
        }
    }

    // The constructor must register the built-in default bounds.
    #[test]
    fn test_validator_creation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        assert!(!validator.theoretical_bounds.is_empty());
        assert!(validator.theoretical_bounds.contains_key("RBF"));
    }

    // End-to-end validation with a small grid: one error/bound entry per
    // configured approximation dimension.
    #[test]
    fn test_method_validation() {
        let config = ValidationConfig {
            approximation_dimensions: vec![10, 20],
            repetitions: 2,
            ..Default::default()
        };
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((50, 5), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF { gamma: 1.0 };

        let result = validator
            .validate_method(&method, &data, None)
            .expect("operation should succeed");

        assert_eq!(result.method_name, "MockRBF");
        assert_eq!(result.empirical_errors.len(), 2);
        assert_eq!(result.theoretical_bounds.len(), 2);
        if let Some(rate) = result.convergence_rate {
            assert!(rate.is_finite());
        }
    }

    // Grid-search CV must produce scores and pick a best combination.
    #[test]
    fn test_cross_validation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((30, 4), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF { gamma: 1.0 };

        let mut parameter_grid = HashMap::new();
        parameter_grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        parameter_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let result = validator
            .cross_validate(&method, &data, None, parameter_grid)
            .expect("operation should succeed");

        assert_eq!(result.method_name, "MockRBF");
        assert!(!result.cv_scores.is_empty());
        assert!(!result.best_parameters.is_empty());
    }

    // The default RBF bound must evaluate to a positive, finite value.
    #[test]
    fn test_theoretical_bounds() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let bound = validator
            .theoretical_bounds
            .get("RBF")
            .expect("operation should succeed");
        let theoretical_bound = validator
            .compute_theoretical_bound(bound, 100, 10, 50)
            .expect("operation should succeed");

        assert!(theoretical_bound > 0.0);
        assert!(theoretical_bound.is_finite());
    }
}