use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::essentials::Normal as RandNormal;
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use sklears_core::error::Result;
use std::collections::HashMap;

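/// Validates kernel approximation methods against theoretical error bounds.
///
/// A minimal usage sketch, assuming some type `MyMethod` implementing
/// [`ValidatableKernelMethod`] is available (hypothetical, for illustration):
///
/// ```ignore
/// let validator = KernelApproximationValidator::new(ValidationConfig::default());
/// let report = validator.validate_method(&MyMethod::default(), &data, None)?;
/// println!("bound violations: {}", report.bound_violations);
/// ```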
#[derive(Debug, Clone)]
pub struct KernelApproximationValidator {
    config: ValidationConfig,
    theoretical_bounds: HashMap<String, TheoreticalBound>,
}

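/// Configuration controlling the validation experiments: tolerance targets,
/// the sample sizes and approximation dimensions to sweep, the number of
/// repetitions per setting, and an optional RNG seed.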
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    pub confidence_level: f64,
    pub max_approximation_error: f64,
    pub convergence_tolerance: f64,
    pub stability_tolerance: f64,
    pub sample_sizes: Vec<usize>,
    pub approximation_dimensions: Vec<usize>,
    pub repetitions: usize,
    pub random_state: Option<u64>,
}

impl Default for ValidationConfig {
    fn default() -> Self {
        Self {
            confidence_level: 0.95,
            max_approximation_error: 0.1,
            convergence_tolerance: 1e-6,
            stability_tolerance: 1e-4,
            sample_sizes: vec![100, 500, 1000, 2000],
            approximation_dimensions: vec![50, 100, 200, 500],
            repetitions: 10,
            random_state: Some(42),
        }
    }
}

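/// A theoretical approximation-error bound registered under a method name.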
#[derive(Debug, Clone)]
pub struct TheoreticalBound {
    pub method_name: String,
    pub bound_type: BoundType,
    pub bound_function: BoundFunction,
    pub constants: HashMap<String, f64>,
}

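/// Statistical flavor of a bound: holds with a given probability, always,
/// in expectation, or as a concentration inequality.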
#[derive(Debug, Clone)]
pub enum BoundType {
    Probabilistic { confidence: f64 },
    Deterministic,
    Expected,
    Concentration { deviation_parameter: f64 },
}

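/// Functional form of a bound. The built-in variants follow the shapes of
/// standard results (for example, the O(sqrt(log d / D)) error rate for
/// random Fourier features and effective-rank rates for Nystroem); the
/// constants used with them are illustrative defaults, not sharp values.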
#[derive(Debug, Clone)]
pub enum BoundFunction {
    RandomFourierFeatures,
    Nystroem,
    StructuredRandomFeatures,
    Fastfood,
    Custom { formula: String },
}

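/// Aggregated outcome of validating a single method.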
#[derive(Debug, Clone)]
pub struct ValidationResult {
    pub method_name: String,
    pub empirical_errors: Vec<f64>,
    pub theoretical_bounds: Vec<f64>,
    pub bound_violations: usize,
    pub bound_tightness: f64,
    pub convergence_rate: Option<f64>,
    pub stability_analysis: StabilityAnalysis,
    pub sample_complexity: SampleComplexityAnalysis,
    pub dimension_dependency: DimensionDependencyAnalysis,
}

#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    pub perturbation_sensitivity: f64,
    pub numerical_stability: f64,
    pub condition_numbers: Vec<f64>,
    pub eigenvalue_stability: f64,
}

#[derive(Debug, Clone)]
pub struct SampleComplexityAnalysis {
    pub minimum_samples: usize,
    pub convergence_rate: f64,
    pub sample_efficiency: f64,
    pub dimension_scaling: f64,
}

#[derive(Debug, Clone)]
pub struct DimensionDependencyAnalysis {
    pub approximation_quality_vs_dimension: Vec<(usize, f64)>,
    pub computational_cost_vs_dimension: Vec<(usize, f64)>,
    pub optimal_dimension: usize,
    pub dimension_efficiency: f64,
}

#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    pub method_name: String,
    pub cv_scores: Vec<f64>,
    pub mean_score: f64,
    pub std_score: f64,
    pub best_parameters: HashMap<String, f64>,
    pub parameter_sensitivity: HashMap<String, f64>,
}

impl KernelApproximationValidator {
    pub fn new(config: ValidationConfig) -> Self {
        let mut validator = Self {
            config,
            theoretical_bounds: HashMap::new(),
        };

        validator.add_default_bounds();
        validator
    }

    pub fn add_theoretical_bound(&mut self, bound: TheoreticalBound) {
        self.theoretical_bounds
            .insert(bound.method_name.clone(), bound);
    }

    fn add_default_bounds(&mut self) {
        self.add_theoretical_bound(TheoreticalBound {
            method_name: "RBF".to_string(),
            bound_type: BoundType::Probabilistic { confidence: 0.95 },
            bound_function: BoundFunction::RandomFourierFeatures,
            constants: HashMap::from([
                ("kernel_bound".to_string(), 1.0),
                ("lipschitz_constant".to_string(), 1.0),
            ]),
        });

        self.add_theoretical_bound(TheoreticalBound {
            method_name: "Nystroem".to_string(),
            bound_type: BoundType::Expected,
            bound_function: BoundFunction::Nystroem,
            constants: HashMap::from([
                ("trace_bound".to_string(), 1.0),
                ("effective_rank".to_string(), 100.0),
            ]),
        });

        self.add_theoretical_bound(TheoreticalBound {
            method_name: "Fastfood".to_string(),
            bound_type: BoundType::Probabilistic { confidence: 0.95 },
            bound_function: BoundFunction::Fastfood,
            constants: HashMap::from([
                ("dimension_factor".to_string(), 1.0),
                ("log_factor".to_string(), 2.0),
            ]),
        });
    }

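    /// Validates `method` on `data` across every configured approximation
    /// dimension, averaging the error over the configured repetitions. When
    /// no exact kernel is supplied, an RBF kernel with gamma = 1.0 serves as
    /// the reference.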
    pub fn validate_method<T: ValidatableKernelMethod>(
        &self,
        method: &T,
        data: &Array2<f64>,
        true_kernel: Option<&Array2<f64>>,
    ) -> Result<ValidationResult> {
        let method_name = method.method_name();
        let mut empirical_errors = Vec::new();
        let mut theoretical_bounds = Vec::new();
        let mut condition_numbers = Vec::new();

        for &n_components in &self.config.approximation_dimensions {
            let mut dimension_errors = Vec::new();

            for _ in 0..self.config.repetitions {
                let fitted = method.fit_with_dimension(data, n_components)?;
                let approximation = fitted.get_kernel_approximation(data)?;

                let empirical_error = if let Some(true_k) = true_kernel {
                    self.compute_approximation_error(&approximation, true_k)?
                } else {
                    let rbf_kernel = self.compute_rbf_kernel(data, 1.0)?;
                    self.compute_approximation_error(&approximation, &rbf_kernel)?
                };

                dimension_errors.push(empirical_error);

                if let Some(cond_num) = fitted.compute_condition_number()? {
                    condition_numbers.push(cond_num);
                }
            }

            let mean_error = dimension_errors.iter().sum::<f64>() / dimension_errors.len() as f64;
            empirical_errors.push(mean_error);

            if let Some(bound) = self.theoretical_bounds.get(&method_name) {
                let theoretical_bound = self.compute_theoretical_bound(
                    bound,
                    data.nrows(),
                    data.ncols(),
                    n_components,
                )?;
                theoretical_bounds.push(theoretical_bound);
            } else {
                theoretical_bounds.push(f64::INFINITY);
            }
        }

        let bound_violations = empirical_errors
            .iter()
            .zip(theoretical_bounds.iter())
            .filter(|(&emp, &theo)| emp > theo)
            .count();

        let tightness_ratios: Vec<f64> = empirical_errors
            .iter()
            .zip(theoretical_bounds.iter())
            .filter(|(_, &theo)| theo.is_finite())
            .map(|(&emp, &theo)| emp / theo)
            .collect();
        let bound_tightness = if tightness_ratios.is_empty() {
            f64::NAN
        } else {
            tightness_ratios.iter().sum::<f64>() / tightness_ratios.len() as f64
        };

        let convergence_rate = self.estimate_convergence_rate(&empirical_errors);

        let stability_analysis = self.analyze_stability(method, data, &condition_numbers)?;

        let sample_complexity = self.analyze_sample_complexity(method, data)?;

        let dimension_dependency =
            self.analyze_dimension_dependency(method, data, &empirical_errors)?;

        Ok(ValidationResult {
            method_name,
            empirical_errors,
            theoretical_bounds,
            bound_violations,
            bound_tightness,
            convergence_rate,
            stability_analysis,
            sample_complexity,
            dimension_dependency,
        })
    }

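    /// Grid-searches `parameter_grid` with 5-fold cross-validation, then
    /// probes each parameter's sensitivity around the best configuration
    /// with a cheaper 3-fold pass.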
    pub fn cross_validate<T: ValidatableKernelMethod>(
        &self,
        method: &T,
        data: &Array2<f64>,
        targets: Option<&Array1<f64>>,
        parameter_grid: HashMap<String, Vec<f64>>,
    ) -> Result<CrossValidationResult> {
        let mut best_score = f64::NEG_INFINITY;
        let mut best_parameters = HashMap::new();
        let mut all_scores = Vec::new();
        let mut parameter_sensitivity = HashMap::new();

        let param_combinations = self.generate_parameter_combinations(&parameter_grid);

        for params in param_combinations {
            let cv_scores = self.k_fold_cross_validation(method, data, targets, &params, 5)?;
            let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;

            all_scores.push(mean_score);

            if mean_score > best_score {
                best_score = mean_score;
                best_parameters = params.clone();
            }
        }

        for (param_name, param_values) in &parameter_grid {
            let mut sensitivities = Vec::new();

            for &param_value in param_values.iter() {
                let mut single_param = best_parameters.clone();
                single_param.insert(param_name.clone(), param_value);

                let cv_scores =
                    self.k_fold_cross_validation(method, data, targets, &single_param, 3)?;
                let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
                sensitivities.push((best_score - mean_score).abs());
            }

            let sensitivity = sensitivities.iter().sum::<f64>() / sensitivities.len() as f64;
            parameter_sensitivity.insert(param_name.clone(), sensitivity);
        }

        let mean_score = all_scores.iter().sum::<f64>() / all_scores.len() as f64;
        let variance = all_scores
            .iter()
            .map(|&x| (x - mean_score).powi(2))
            .sum::<f64>()
            / all_scores.len() as f64;
        let std_score = variance.sqrt();

        Ok(CrossValidationResult {
            method_name: method.method_name(),
            cv_scores: all_scores,
            mean_score,
            std_score,
            best_parameters,
            parameter_sensitivity,
        })
    }

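    /// Relative Frobenius-norm error ||K_approx - K||_F / ||K||_F, with the
    /// denominator floored at 1e-8 to avoid division by zero.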
    fn compute_approximation_error(
        &self,
        approx_kernel: &Array2<f64>,
        true_kernel: &Array2<f64>,
    ) -> Result<f64> {
        let diff = approx_kernel - true_kernel;
        let frobenius_error = diff.mapv(|x| x * x).sum().sqrt();

        let true_norm = true_kernel.mapv(|x| x * x).sum().sqrt();
        Ok(frobenius_error / true_norm.max(1e-8))
    }

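    /// Dense RBF (Gaussian) kernel: K[i, j] = exp(-gamma * ||x_i - x_j||^2).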
    fn compute_rbf_kernel(&self, data: &Array2<f64>, gamma: f64) -> Result<Array2<f64>> {
        let n_samples = data.nrows();
        let mut kernel = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in i..n_samples {
                let diff = &data.row(i) - &data.row(j);
                let dist_sq = diff.mapv(|x| x * x).sum();
                let similarity = (-gamma * dist_sq).exp();
                kernel[[i, j]] = similarity;
                kernel[[j, i]] = similarity;
            }
        }

        Ok(kernel)
    }

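    /// Evaluates a bound of the configured form, then adjusts it for the
    /// bound type (for probabilistic bounds, via a normal-quantile correction
    /// whose influence shrinks with sqrt(n_samples)). The constants default
    /// to illustrative values rather than sharp theoretical ones.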
    fn compute_theoretical_bound(
        &self,
        bound: &TheoreticalBound,
        n_samples: usize,
        n_features: usize,
        n_components: usize,
    ) -> Result<f64> {
        let bound_value = match &bound.bound_function {
            BoundFunction::RandomFourierFeatures => {
                let kernel_bound = bound.constants.get("kernel_bound").unwrap_or(&1.0);
                let lipschitz = bound.constants.get("lipschitz_constant").unwrap_or(&1.0);

                let log_factor = (n_features as f64).ln();
                kernel_bound * lipschitz * (log_factor / n_components as f64).sqrt()
            }
            BoundFunction::Nystroem => {
                let trace_bound = bound.constants.get("trace_bound").unwrap_or(&1.0);
                let effective_rank = bound.constants.get("effective_rank").unwrap_or(&100.0);

                trace_bound * (effective_rank / n_components as f64).sqrt()
            }
            BoundFunction::StructuredRandomFeatures => {
                let log_factor = (n_features as f64).ln();
                (n_features as f64 * log_factor / n_components as f64).sqrt()
            }
            BoundFunction::Fastfood => {
                let log_factor = bound.constants.get("log_factor").unwrap_or(&2.0);
                let dim_factor = bound.constants.get("dimension_factor").unwrap_or(&1.0);

                let log_d = (n_features as f64).ln();
                dim_factor
                    * (n_features as f64 * log_d.powf(*log_factor) / n_components as f64).sqrt()
            }
            BoundFunction::Custom { formula: _ } => 1.0 / (n_components as f64).sqrt(),
        };

        let final_bound = match &bound.bound_type {
            BoundType::Probabilistic { confidence } => {
                let z_score = self.inverse_normal_cdf(*confidence);
                bound_value * (1.0 + z_score / (n_samples as f64).sqrt())
            }
            BoundType::Deterministic => bound_value,
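            // Heuristic: treat an expected-value bound as 20% tighter than
            // the corresponding worst-case value.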
            BoundType::Expected => bound_value * 0.8,
            BoundType::Concentration {
                deviation_parameter,
            } => bound_value * (1.0 + deviation_parameter / (n_samples as f64).sqrt()),
        };

        Ok(final_bound)
    }

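    /// Inverse standard-normal CDF via the Abramowitz & Stegun 26.2.23
    /// rational approximation (absolute error below 4.5e-4); only valid for
    /// p in (0, 1).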
    fn inverse_normal_cdf(&self, p: f64) -> f64 {
        if p < 0.5 {
            -self.inverse_normal_cdf(1.0 - p)
        } else {
            let t = (-2.0 * (1.0 - p).ln()).sqrt();
            let c0 = 2.515517;
            let c1 = 0.802853;
            let c2 = 0.010328;
            let d1 = 1.432788;
            let d2 = 0.189269;
            let d3 = 0.001308;

            t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
        }
    }

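    /// Least-squares slope of log(error) against log(dimension), negated so
    /// that faster error decay yields a larger positive rate. Returns `None`
    /// when fewer than three error values are available.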
    fn estimate_convergence_rate(&self, errors: &[f64]) -> Option<f64> {
        if errors.len() < 3 {
            return None;
        }

        let dimensions: Vec<f64> = self
            .config
            .approximation_dimensions
            .iter()
            .take(errors.len())
            .map(|&x| (x as f64).ln())
            .collect();

        let log_errors: Vec<f64> = errors.iter().map(|&x| x.ln()).collect();

        let n = dimensions.len() as f64;
        let sum_x = dimensions.iter().sum::<f64>();
        let sum_y = log_errors.iter().sum::<f64>();
        let sum_xy = dimensions
            .iter()
            .zip(log_errors.iter())
            .map(|(&x, &y)| x * y)
            .sum::<f64>();
        let sum_x2 = dimensions.iter().map(|&x| x * x).sum::<f64>();

        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
        Some(-slope)
    }

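    /// Perturbs the inputs with Gaussian noise (standard deviation
    /// `stability_tolerance`) and measures how far the refitted approximation
    /// moves from the unperturbed one. Eigenvalue stability is reported as a
    /// proxy: one minus the perturbation sensitivity, clamped at zero.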
    fn analyze_stability<T: ValidatableKernelMethod>(
        &self,
        method: &T,
        data: &Array2<f64>,
        condition_numbers: &[f64],
    ) -> Result<StabilityAnalysis> {
        let mut rng = RealStdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
        let normal = RandNormal::new(0.0, self.config.stability_tolerance).unwrap();

        let original_fitted = method.fit_with_dimension(data, 100)?;
        let original_approx = original_fitted.get_kernel_approximation(data)?;

        let mut perturbation_errors = Vec::new();

        for _ in 0..5 {
            let mut perturbed_data = data.clone();
            for elem in perturbed_data.iter_mut() {
                *elem += rng.sample(normal);
            }

            let perturbed_fitted = method.fit_with_dimension(&perturbed_data, 100)?;
            let perturbed_approx = perturbed_fitted.get_kernel_approximation(data)?;

            let error = self.compute_approximation_error(&perturbed_approx, &original_approx)?;
            perturbation_errors.push(error);
        }

        let perturbation_sensitivity =
            perturbation_errors.iter().sum::<f64>() / perturbation_errors.len() as f64;

        let numerical_stability = if condition_numbers.is_empty() {
            1.0
        } else {
            let mean_condition =
                condition_numbers.iter().sum::<f64>() / condition_numbers.len() as f64;
            1.0 / mean_condition.ln().max(1.0)
        };

        let eigenvalue_stability = (1.0 - perturbation_sensitivity).max(0.0);

        Ok(StabilityAnalysis {
            perturbation_sensitivity,
            numerical_stability,
            condition_numbers: condition_numbers.to_vec(),
            eigenvalue_stability,
        })
    }

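    /// Refits on growing data prefixes to find the smallest configured sample
    /// size reaching `max_approximation_error`, and fits a power-law
    /// convergence rate on the log-log curve (falling back to the Monte Carlo
    /// rate 0.5 when fewer than two points are available).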
    fn analyze_sample_complexity<T: ValidatableKernelMethod>(
        &self,
        method: &T,
        data: &Array2<f64>,
    ) -> Result<SampleComplexityAnalysis> {
        let mut sample_errors = Vec::new();

        for &n_samples in &self.config.sample_sizes {
            if n_samples > data.nrows() {
                continue;
            }

            let subset_data = data
                .slice(scirs2_core::ndarray::s![..n_samples, ..])
                .to_owned();
            let fitted = method.fit_with_dimension(&subset_data, 100)?;
            let approx = fitted.get_kernel_approximation(&subset_data)?;

            let rbf_kernel = self.compute_rbf_kernel(&subset_data, 1.0)?;
            let error = self.compute_approximation_error(&approx, &rbf_kernel)?;
            sample_errors.push(error);
        }

        let target_error = self.config.max_approximation_error;
        let minimum_samples = self
            .config
            .sample_sizes
            .iter()
            .zip(sample_errors.iter())
            .find(|(_, &error)| error <= target_error)
            .map(|(&samples, _)| samples)
            .unwrap_or(*self.config.sample_sizes.last().unwrap());

        let convergence_rate = if sample_errors.len() >= 2 {
            let log_samples: Vec<f64> = self
                .config
                .sample_sizes
                .iter()
                .take(sample_errors.len())
                .map(|&x| (x as f64).ln())
                .collect();
            let log_errors: Vec<f64> = sample_errors.iter().map(|&x| x.ln()).collect();

            let n = log_samples.len() as f64;
            let sum_x = log_samples.iter().sum::<f64>();
            let sum_y = log_errors.iter().sum::<f64>();
            let sum_xy = log_samples
                .iter()
                .zip(log_errors.iter())
                .map(|(&x, &y)| x * y)
                .sum::<f64>();
            let sum_x2 = log_samples.iter().map(|&x| x * x).sum::<f64>();

            -(n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
        } else {
            0.5
        };

        let sample_efficiency = 1.0 / minimum_samples as f64;
        let dimension_scaling = data.ncols() as f64 / minimum_samples as f64;

        Ok(SampleComplexityAnalysis {
            minimum_samples,
            convergence_rate,
            sample_efficiency,
            dimension_scaling,
        })
    }

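    /// Relates approximation quality (1 - error) and a naive cost model
    /// (dimension x samples) across the configured dimensions, picking the
    /// dimension with the best quality-per-cost ratio.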
    fn analyze_dimension_dependency<T: ValidatableKernelMethod>(
        &self,
        _method: &T,
        data: &Array2<f64>,
        errors: &[f64],
    ) -> Result<DimensionDependencyAnalysis> {
        let approximation_quality_vs_dimension: Vec<(usize, f64)> = self
            .config
            .approximation_dimensions
            .iter()
            .take(errors.len())
            .zip(errors.iter())
            .map(|(&dim, &error)| (dim, 1.0 - error))
            .collect();

        let computational_cost_vs_dimension: Vec<(usize, f64)> = self
            .config
            .approximation_dimensions
            .iter()
            .map(|&dim| (dim, dim as f64 * data.nrows() as f64))
            .collect();

        let optimal_dimension = approximation_quality_vs_dimension
            .iter()
            .zip(computational_cost_vs_dimension.iter())
            .map(|((dim, quality), (_, cost))| (*dim, quality / cost))
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .map(|(dim, _)| dim)
            .unwrap_or(100);

        let dimension_efficiency = approximation_quality_vs_dimension
            .iter()
            .map(|(_, quality)| quality)
            .sum::<f64>()
            / approximation_quality_vs_dimension.len() as f64;

        Ok(DimensionDependencyAnalysis {
            approximation_quality_vs_dimension,
            computational_cost_vs_dimension,
            optimal_dimension,
            dimension_efficiency,
        })
    }

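    /// Cartesian product of the parameter grid; e.g. gamma in {0.5, 1.0}
    /// crossed with n_components in {10.0, 20.0} yields four candidates.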
    fn generate_parameter_combinations(
        &self,
        parameter_grid: &HashMap<String, Vec<f64>>,
    ) -> Vec<HashMap<String, f64>> {
        let mut combinations = vec![HashMap::new()];

        for (param_name, param_values) in parameter_grid {
            let mut new_combinations = Vec::new();

            for combination in &combinations {
                for &param_value in param_values {
                    let mut new_combination = combination.clone();
                    new_combination.insert(param_name.clone(), param_value);
                    new_combinations.push(new_combination);
                }
            }

            combinations = new_combinations;
        }

        combinations
    }

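    /// Contiguous (unshuffled) k-fold split. The score is the negated mean of
    /// the train and validation approximation errors, so higher is better.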
    fn k_fold_cross_validation<T: ValidatableKernelMethod>(
        &self,
        method: &T,
        data: &Array2<f64>,
        _targets: Option<&Array1<f64>>,
        parameters: &HashMap<String, f64>,
        k: usize,
    ) -> Result<Vec<f64>> {
        let n_samples = data.nrows();
        let fold_size = n_samples / k;
        let mut scores = Vec::new();

        for fold in 0..k {
            let start_idx = fold * fold_size;
            let end_idx = if fold == k - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };

            let train_indices: Vec<usize> = (0..n_samples)
                .filter(|&i| i < start_idx || i >= end_idx)
                .collect();
            let val_indices: Vec<usize> = (start_idx..end_idx).collect();

            let train_data = data.select(Axis(0), &train_indices);
            let val_data = data.select(Axis(0), &val_indices);

            let fitted = method.fit_with_parameters(&train_data, parameters)?;
            let train_approx = fitted.get_kernel_approximation(&train_data)?;
            let val_approx = fitted.get_kernel_approximation(&val_data)?;

            let train_kernel = self.compute_rbf_kernel(&train_data, 1.0)?;
            let val_kernel = self.compute_rbf_kernel(&val_data, 1.0)?;

            let train_error = self.compute_approximation_error(&train_approx, &train_kernel)?;
            let val_error = self.compute_approximation_error(&val_approx, &val_kernel)?;

            let score = -(train_error + val_error) / 2.0;
            scores.push(score);
        }

        Ok(scores)
    }
}

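/// Implemented by kernel approximation methods so the validator can fit them
/// with a chosen output dimension or an explicit parameter set.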
pub trait ValidatableKernelMethod {
    fn method_name(&self) -> String;

    fn fit_with_dimension(
        &self,
        data: &Array2<f64>,
        n_components: usize,
    ) -> Result<Box<dyn ValidatedFittedMethod>>;

    fn fit_with_parameters(
        &self,
        data: &Array2<f64>,
        parameters: &HashMap<String, f64>,
    ) -> Result<Box<dyn ValidatedFittedMethod>>;
}

pub trait ValidatedFittedMethod {
    fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>>;

    fn compute_condition_number(&self) -> Result<Option<f64>>;

    fn approximation_dimension(&self) -> usize;
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    struct MockValidatableRBF {
        #[allow(dead_code)]
        gamma: f64,
    }

    impl ValidatableKernelMethod for MockValidatableRBF {
        fn method_name(&self) -> String {
            "MockRBF".to_string()
        }

        fn fit_with_dimension(
            &self,
            _data: &Array2<f64>,
            n_components: usize,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            Ok(Box::new(MockValidatedFitted { n_components }))
        }

        fn fit_with_parameters(
            &self,
            _data: &Array2<f64>,
            parameters: &HashMap<String, f64>,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            let n_components = parameters.get("n_components").copied().unwrap_or(100.0) as usize;
            Ok(Box::new(MockValidatedFitted { n_components }))
        }
    }

    struct MockValidatedFitted {
        n_components: usize,
    }

    impl ValidatedFittedMethod for MockValidatedFitted {
        fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            let n_samples = data.nrows();
            let mut kernel = Array2::zeros((n_samples, n_samples));

            for i in 0..n_samples {
                kernel[[i, i]] = 1.0;
                for j in i + 1..n_samples {
                    let similarity = 0.5;
                    kernel[[i, j]] = similarity;
                    kernel[[j, i]] = similarity;
                }
            }

            Ok(kernel)
        }

        fn compute_condition_number(&self) -> Result<Option<f64>> {
            Ok(Some(10.0))
        }

        fn approximation_dimension(&self) -> usize {
            self.n_components
        }
    }

    #[test]
    fn test_validator_creation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        assert!(!validator.theoretical_bounds.is_empty());
        assert!(validator.theoretical_bounds.contains_key("RBF"));
    }

    #[test]
    fn test_method_validation() {
        let config = ValidationConfig {
            approximation_dimensions: vec![10, 20],
            repetitions: 2,
            ..Default::default()
        };
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((50, 5), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF { gamma: 1.0 };

        let result = validator.validate_method(&method, &data, None).unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert_eq!(result.empirical_errors.len(), 2);
        assert_eq!(result.theoretical_bounds.len(), 2);
        if let Some(rate) = result.convergence_rate {
            assert!(rate.is_finite());
        }
    }

    #[test]
    fn test_cross_validation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((30, 4), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF { gamma: 1.0 };

        let mut parameter_grid = HashMap::new();
        parameter_grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        parameter_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let result = validator
            .cross_validate(&method, &data, None, parameter_grid)
            .unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert!(!result.cv_scores.is_empty());
        assert!(!result.best_parameters.is_empty());
    }

    #[test]
    fn test_theoretical_bounds() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let bound = validator.theoretical_bounds.get("RBF").unwrap();
        let theoretical_bound = validator
            .compute_theoretical_bound(bound, 100, 10, 50)
            .unwrap();

        assert!(theoretical_bound > 0.0);
        assert!(theoretical_bound.is_finite());
    }
}