Skip to main content

sklears_kernel_approximation/
validation.rs

1//! Advanced Validation Framework for Kernel Approximation Methods
2//!
3//! This module provides comprehensive validation tools including theoretical error bound
4//! validation, convergence analysis, and approximation quality assessment.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::SeedableRng;
10use scirs2_core::RngExt;
11use sklears_core::error::Result;
12use std::collections::HashMap;
13
14/// Comprehensive validation framework for kernel approximation methods
15#[derive(Debug, Clone)]
16/// KernelApproximationValidator
17pub struct KernelApproximationValidator {
18    config: ValidationConfig,
19    theoretical_bounds: HashMap<String, TheoreticalBound>,
20}
21
/// Settings controlling a validation run.
#[derive(Clone, Debug)]
pub struct ValidationConfig {
    /// Confidence level for the run.
    /// NOTE(review): not read by any code visible in this file; confirm usage.
    pub confidence_level: f64,
    /// Target relative error; the sample-complexity analysis reports the first
    /// sample size whose error is at or below this value.
    pub max_approximation_error: f64,
    /// Convergence tolerance.
    /// NOTE(review): not read by any code visible in this file; confirm usage.
    pub convergence_tolerance: f64,
    /// Standard deviation of the Gaussian noise added to the data during the
    /// stability analysis.
    pub stability_tolerance: f64,
    /// Sample sizes (leading rows of the data) probed by the sample-complexity
    /// analysis; sizes larger than the data set are skipped.
    pub sample_sizes: Vec<usize>,
    /// Approximation dimensions (number of components) evaluated by
    /// `validate_method`.
    pub approximation_dimensions: Vec<usize>,
    /// Number of fit/evaluate repetitions averaged per approximation dimension.
    pub repetitions: usize,
    /// Seed for the perturbation RNG of the stability analysis; 42 is used
    /// when this is `None`.
    pub random_state: Option<u64>,
}
43
44impl Default for ValidationConfig {
45    fn default() -> Self {
46        Self {
47            confidence_level: 0.95,
48            max_approximation_error: 0.1,
49            convergence_tolerance: 1e-6,
50            stability_tolerance: 1e-4,
51            sample_sizes: vec![100, 500, 1000, 2000],
52            approximation_dimensions: vec![50, 100, 200, 500],
53            repetitions: 10,
54            random_state: Some(42),
55        }
56    }
57}
58
59/// Theoretical error bounds for different approximation methods
60#[derive(Debug, Clone)]
61/// TheoreticalBound
62pub struct TheoreticalBound {
63    /// method_name
64    pub method_name: String,
65    /// bound_type
66    pub bound_type: BoundType,
67    /// bound_function
68    pub bound_function: BoundFunction,
69    /// constants
70    pub constants: HashMap<String, f64>,
71}
72
/// Kinds of theoretical guarantee a bound can express.
#[derive(Clone, Debug)]
pub enum BoundType {
    /// Bound that holds with the given confidence level.
    Probabilistic { confidence: f64 },
    /// Worst-case bound, used as computed.
    Deterministic,
    /// Bound on the expected error.
    Expected,
    /// Bound derived from a concentration inequality.
    Concentration { deviation_parameter: f64 },
}
86
/// Formulas available for computing a theoretical bound.
///
/// In the notes below `d` is the number of input features and `m` the number
/// of approximation components.
#[derive(Clone, Debug)]
pub enum BoundFunction {
    /// Random Fourier features: O(sqrt(log(d)/m)).
    RandomFourierFeatures,
    /// Nyström approximation: depends on eigenvalue decay.
    Nystroem,
    /// Structured random features: O(sqrt(d*log(d)/m)).
    StructuredRandomFeatures,
    /// Fastfood: O(sqrt(d*log^2(d)/m)).
    Fastfood,
    /// User-supplied bound described by a formula string.
    Custom { formula: String },
}
102
103/// Result of validation analysis
104#[derive(Debug, Clone)]
105/// ValidationResult
106pub struct ValidationResult {
107    /// method_name
108    pub method_name: String,
109    /// empirical_errors
110    pub empirical_errors: Vec<f64>,
111    /// theoretical_bounds
112    pub theoretical_bounds: Vec<f64>,
113    /// bound_violations
114    pub bound_violations: usize,
115    /// bound_tightness
116    pub bound_tightness: f64,
117    /// convergence_rate
118    pub convergence_rate: Option<f64>,
119    /// stability_analysis
120    pub stability_analysis: StabilityAnalysis,
121    /// sample_complexity
122    pub sample_complexity: SampleComplexityAnalysis,
123    /// dimension_dependency
124    pub dimension_dependency: DimensionDependencyAnalysis,
125}
126
/// Results of the stability analysis.
#[derive(Clone, Debug)]
pub struct StabilityAnalysis {
    /// Mean relative difference between approximations fitted on the original
    /// data and on noise-perturbed data.
    pub perturbation_sensitivity: f64,
    /// Score derived from the mean condition number; 1.0 when no condition
    /// numbers were collected.
    pub numerical_stability: f64,
    /// Condition numbers reported by the fitted models.
    pub condition_numbers: Vec<f64>,
    /// Placeholder score, computed as `1 - perturbation_sensitivity`.
    pub eigenvalue_stability: f64,
}
140
/// Results of the sample-complexity analysis.
#[derive(Clone, Debug)]
pub struct SampleComplexityAnalysis {
    /// Sample size reported as sufficient to reach the configured target
    /// error (or the fallback size when no tested size reached it).
    pub minimum_samples: usize,
    /// Estimated rate at which the error decreases with the sample size.
    pub convergence_rate: f64,
    /// Reciprocal of `minimum_samples`.
    pub sample_efficiency: f64,
    /// Number of input features divided by `minimum_samples`.
    pub dimension_scaling: f64,
}
154
/// Results of the quality-versus-dimension analysis.
#[derive(Clone, Debug)]
pub struct DimensionDependencyAnalysis {
    /// `(dimension, quality)` pairs, where quality is `1 - error`.
    pub approximation_quality_vs_dimension: Vec<(usize, f64)>,
    /// `(dimension, cost)` pairs, where cost is a simplified estimate:
    /// dimension times number of samples.
    pub computational_cost_vs_dimension: Vec<(usize, f64)>,
    /// Dimension with the best quality-to-cost ratio.
    pub optimal_dimension: usize,
    /// Mean quality over all evaluated dimensions.
    pub dimension_efficiency: f64,
}
168
/// Outcome of a cross-validated parameter search.
#[derive(Clone, Debug)]
pub struct CrossValidationResult {
    /// Name reported by the evaluated method.
    pub method_name: String,
    /// Mean fold score of each evaluated parameter combination.
    pub cv_scores: Vec<f64>,
    /// Mean of `cv_scores`.
    pub mean_score: f64,
    /// Population standard deviation of `cv_scores`.
    pub std_score: f64,
    /// Parameter combination that achieved the highest mean score.
    pub best_parameters: HashMap<String, f64>,
    /// Per parameter: mean absolute score change observed when that parameter
    /// alone is varied around the best combination.
    pub parameter_sensitivity: HashMap<String, f64>,
}
186
187impl KernelApproximationValidator {
188    /// Create a new validator with configuration
189    pub fn new(config: ValidationConfig) -> Self {
190        let mut validator = Self {
191            config,
192            theoretical_bounds: HashMap::new(),
193        };
194
195        // Add default theoretical bounds
196        validator.add_default_bounds();
197        validator
198    }
199
200    /// Add theoretical bounds for a specific method
201    pub fn add_theoretical_bound(&mut self, bound: TheoreticalBound) {
202        self.theoretical_bounds
203            .insert(bound.method_name.clone(), bound);
204    }
205
206    fn add_default_bounds(&mut self) {
207        // RFF bounds
208        self.add_theoretical_bound(TheoreticalBound {
209            method_name: "RBF".to_string(),
210            bound_type: BoundType::Probabilistic { confidence: 0.95 },
211            bound_function: BoundFunction::RandomFourierFeatures,
212            constants: [
213                ("kernel_bound".to_string(), 1.0),
214                ("lipschitz_constant".to_string(), 1.0),
215            ]
216            .iter()
217            .cloned()
218            .collect(),
219        });
220
221        // Nyström bounds
222        self.add_theoretical_bound(TheoreticalBound {
223            method_name: "Nystroem".to_string(),
224            bound_type: BoundType::Expected,
225            bound_function: BoundFunction::Nystroem,
226            constants: [
227                ("trace_bound".to_string(), 1.0),
228                ("effective_rank".to_string(), 100.0),
229            ]
230            .iter()
231            .cloned()
232            .collect(),
233        });
234
235        // Fastfood bounds
236        self.add_theoretical_bound(TheoreticalBound {
237            method_name: "Fastfood".to_string(),
238            bound_type: BoundType::Probabilistic { confidence: 0.95 },
239            bound_function: BoundFunction::Fastfood,
240            constants: [
241                ("dimension_factor".to_string(), 1.0),
242                ("log_factor".to_string(), 2.0),
243            ]
244            .iter()
245            .cloned()
246            .collect(),
247        });
248    }
249
250    /// Validate a kernel approximation method
251    pub fn validate_method<T: ValidatableKernelMethod>(
252        &self,
253        method: &T,
254        data: &Array2<f64>,
255        true_kernel: Option<&Array2<f64>>,
256    ) -> Result<ValidationResult> {
257        let method_name = method.method_name();
258        let mut empirical_errors = Vec::new();
259        let mut theoretical_bounds = Vec::new();
260        let mut condition_numbers = Vec::new();
261
262        // Test different approximation dimensions
263        for &n_components in &self.config.approximation_dimensions {
264            let mut dimension_errors = Vec::new();
265
266            for _ in 0..self.config.repetitions {
267                // Fit and evaluate the method
268                let fitted = method.fit_with_dimension(data, n_components)?;
269                let approximation = fitted.get_kernel_approximation(data)?;
270
271                // Compute empirical error
272                let empirical_error = if let Some(true_k) = true_kernel {
273                    self.compute_approximation_error(&approximation, true_k)?
274                } else {
275                    // Use RBF kernel as reference
276                    let rbf_kernel = self.compute_rbf_kernel(data, 1.0)?;
277                    self.compute_approximation_error(&approximation, &rbf_kernel)?
278                };
279
280                dimension_errors.push(empirical_error);
281
282                // Compute condition number for stability analysis
283                if let Some(cond_num) = fitted.compute_condition_number()? {
284                    condition_numbers.push(cond_num);
285                }
286            }
287
288            let mean_error = dimension_errors.iter().sum::<f64>() / dimension_errors.len() as f64;
289            empirical_errors.push(mean_error);
290
291            // Compute theoretical bound
292            if let Some(bound) = self.theoretical_bounds.get(&method_name) {
293                let theoretical_bound = self.compute_theoretical_bound(
294                    bound,
295                    data.nrows(),
296                    data.ncols(),
297                    n_components,
298                )?;
299                theoretical_bounds.push(theoretical_bound);
300            } else {
301                theoretical_bounds.push(f64::INFINITY);
302            }
303        }
304
305        // Count bound violations
306        let bound_violations = empirical_errors
307            .iter()
308            .zip(theoretical_bounds.iter())
309            .filter(|(&emp, &theo)| emp > theo)
310            .count();
311
312        // Compute bound tightness (average ratio of empirical to theoretical)
313        let bound_tightness = empirical_errors
314            .iter()
315            .zip(theoretical_bounds.iter())
316            .filter(|(_, &theo)| theo.is_finite())
317            .map(|(&emp, &theo)| emp / theo)
318            .sum::<f64>()
319            / empirical_errors.len() as f64;
320
321        // Analyze convergence rate
322        let convergence_rate = self.estimate_convergence_rate(&empirical_errors);
323
324        // Perform stability analysis
325        let stability_analysis = self.analyze_stability(method, data, &condition_numbers)?;
326
327        // Analyze sample complexity
328        let sample_complexity = self.analyze_sample_complexity(method, data)?;
329
330        // Analyze dimension dependency
331        let dimension_dependency =
332            self.analyze_dimension_dependency(method, data, &empirical_errors)?;
333
334        Ok(ValidationResult {
335            method_name,
336            empirical_errors,
337            theoretical_bounds,
338            bound_violations,
339            bound_tightness,
340            convergence_rate,
341            stability_analysis,
342            sample_complexity,
343            dimension_dependency,
344        })
345    }
346
347    /// Perform cross-validation for parameter selection
348    pub fn cross_validate<T: ValidatableKernelMethod>(
349        &self,
350        method: &T,
351        data: &Array2<f64>,
352        targets: Option<&Array1<f64>>,
353        parameter_grid: HashMap<String, Vec<f64>>,
354    ) -> Result<CrossValidationResult> {
355        let mut best_score = f64::NEG_INFINITY;
356        let mut best_parameters = HashMap::new();
357        let mut all_scores = Vec::new();
358        let mut parameter_sensitivity = HashMap::new();
359
360        // Generate parameter combinations
361        let param_combinations = self.generate_parameter_combinations(&parameter_grid);
362
363        for params in param_combinations {
364            let cv_scores = self.k_fold_cross_validation(method, data, targets, &params, 5)?;
365            let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
366
367            all_scores.push(mean_score);
368
369            if mean_score > best_score {
370                best_score = mean_score;
371                best_parameters = params.clone();
372            }
373        }
374
375        // Analyze parameter sensitivity
376        for (param_name, param_values) in &parameter_grid {
377            let mut sensitivities = Vec::new();
378
379            for &param_value in param_values.iter() {
380                let mut single_param = best_parameters.clone();
381                single_param.insert(param_name.clone(), param_value);
382
383                let cv_scores =
384                    self.k_fold_cross_validation(method, data, targets, &single_param, 3)?;
385                let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
386                sensitivities.push((best_score - mean_score).abs());
387            }
388
389            let sensitivity = sensitivities.iter().sum::<f64>() / sensitivities.len() as f64;
390            parameter_sensitivity.insert(param_name.clone(), sensitivity);
391        }
392
393        let mean_score = all_scores.iter().sum::<f64>() / all_scores.len() as f64;
394        let variance = all_scores
395            .iter()
396            .map(|&x| (x - mean_score).powi(2))
397            .sum::<f64>()
398            / all_scores.len() as f64;
399        let std_score = variance.sqrt();
400
401        Ok(CrossValidationResult {
402            method_name: method.method_name(),
403            cv_scores: all_scores,
404            mean_score,
405            std_score,
406            best_parameters,
407            parameter_sensitivity,
408        })
409    }
410
411    fn compute_approximation_error(
412        &self,
413        approx_kernel: &Array2<f64>,
414        true_kernel: &Array2<f64>,
415    ) -> Result<f64> {
416        // Compute Frobenius norm error
417        let diff = approx_kernel - true_kernel;
418        let frobenius_error = diff.mapv(|x| x * x).sum().sqrt();
419
420        // Normalize by true kernel norm
421        let true_norm = true_kernel.mapv(|x| x * x).sum().sqrt();
422        Ok(frobenius_error / true_norm.max(1e-8))
423    }
424
425    fn compute_rbf_kernel(&self, data: &Array2<f64>, gamma: f64) -> Result<Array2<f64>> {
426        let n_samples = data.nrows();
427        let mut kernel = Array2::zeros((n_samples, n_samples));
428
429        for i in 0..n_samples {
430            for j in i..n_samples {
431                let diff = &data.row(i) - &data.row(j);
432                let dist_sq = diff.mapv(|x| x * x).sum();
433                let similarity = (-gamma * dist_sq).exp();
434                kernel[[i, j]] = similarity;
435                kernel[[j, i]] = similarity;
436            }
437        }
438
439        Ok(kernel)
440    }
441
442    fn compute_theoretical_bound(
443        &self,
444        bound: &TheoreticalBound,
445        n_samples: usize,
446        n_features: usize,
447        n_components: usize,
448    ) -> Result<f64> {
449        let bound_value = match &bound.bound_function {
450            BoundFunction::RandomFourierFeatures => {
451                let kernel_bound = bound.constants.get("kernel_bound").unwrap_or(&1.0);
452                let lipschitz = bound.constants.get("lipschitz_constant").unwrap_or(&1.0);
453
454                // O(sqrt(log(d)/m)) bound for RFF
455                let log_factor = (n_features as f64).ln();
456                kernel_bound * lipschitz * (log_factor / n_components as f64).sqrt()
457            }
458            BoundFunction::Nystroem => {
459                let trace_bound = bound.constants.get("trace_bound").unwrap_or(&1.0);
460                let effective_rank = bound.constants.get("effective_rank").unwrap_or(&100.0);
461
462                // Approximation depends on eigenvalue decay
463                trace_bound * (effective_rank / n_components as f64).sqrt()
464            }
465            BoundFunction::StructuredRandomFeatures => {
466                let log_factor = (n_features as f64).ln();
467                (n_features as f64 * log_factor / n_components as f64).sqrt()
468            }
469            BoundFunction::Fastfood => {
470                let log_factor = bound.constants.get("log_factor").unwrap_or(&2.0);
471                let dim_factor = bound.constants.get("dimension_factor").unwrap_or(&1.0);
472
473                let log_d = (n_features as f64).ln();
474                dim_factor
475                    * (n_features as f64 * log_d.powf(*log_factor) / n_components as f64).sqrt()
476            }
477            BoundFunction::Custom { formula: _ } => {
478                // Placeholder for custom formulas
479                1.0 / (n_components as f64).sqrt()
480            }
481        };
482
483        // Apply bound type modifications
484        let final_bound = match &bound.bound_type {
485            BoundType::Probabilistic { confidence } => {
486                // Add confidence-dependent factor
487                let z_score = self.inverse_normal_cdf(*confidence);
488                bound_value * (1.0 + z_score / (n_samples as f64).sqrt())
489            }
490            BoundType::Deterministic => bound_value,
491            BoundType::Expected => bound_value * 0.8, // Expected is typically tighter
492            BoundType::Concentration {
493                deviation_parameter,
494            } => bound_value * (1.0 + deviation_parameter / (n_samples as f64).sqrt()),
495        };
496
497        Ok(final_bound)
498    }
499
500    fn inverse_normal_cdf(&self, p: f64) -> f64 {
501        // Approximation of inverse normal CDF for confidence intervals
502        if p <= 0.5 {
503            -self.inverse_normal_cdf(1.0 - p)
504        } else {
505            let t = (-2.0 * (1.0 - p).ln()).sqrt();
506            let c0 = 2.515517;
507            let c1 = 0.802853;
508            let c2 = 0.010328;
509            let d1 = 1.432788;
510            let d2 = 0.189269;
511            let d3 = 0.001308;
512
513            t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
514        }
515    }
516
517    fn estimate_convergence_rate(&self, errors: &[f64]) -> Option<f64> {
518        if errors.len() < 3 {
519            return None;
520        }
521
522        // Fit log(error) = a + b * log(dimension) to estimate convergence rate
523        let dimensions: Vec<f64> = self
524            .config
525            .approximation_dimensions
526            .iter()
527            .take(errors.len())
528            .map(|&x| (x as f64).ln())
529            .collect();
530
531        let log_errors: Vec<f64> = errors.iter().map(|&x| x.ln()).collect();
532
533        // Simple linear regression
534        let n = dimensions.len() as f64;
535        let sum_x = dimensions.iter().sum::<f64>();
536        let sum_y = log_errors.iter().sum::<f64>();
537        let sum_xy = dimensions
538            .iter()
539            .zip(log_errors.iter())
540            .map(|(&x, &y)| x * y)
541            .sum::<f64>();
542        let sum_x2 = dimensions.iter().map(|&x| x * x).sum::<f64>();
543
544        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
545        Some(-slope) // Negative because we expect decreasing error
546    }
547
548    fn analyze_stability<T: ValidatableKernelMethod>(
549        &self,
550        method: &T,
551        data: &Array2<f64>,
552        condition_numbers: &[f64],
553    ) -> Result<StabilityAnalysis> {
554        let mut rng = RealStdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
555        let normal = RandNormal::new(0.0, self.config.stability_tolerance)
556            .expect("operation should succeed");
557
558        // Test perturbation sensitivity
559        let mut perturbation_errors = Vec::new();
560
561        for _ in 0..5 {
562            let mut perturbed_data = data.clone();
563            for elem in perturbed_data.iter_mut() {
564                *elem += rng.sample(normal);
565            }
566
567            let original_fitted = method.fit_with_dimension(data, 100)?;
568            let perturbed_fitted = method.fit_with_dimension(&perturbed_data, 100)?;
569
570            let original_approx = original_fitted.get_kernel_approximation(data)?;
571            let perturbed_approx = perturbed_fitted.get_kernel_approximation(data)?;
572
573            let error = self.compute_approximation_error(&perturbed_approx, &original_approx)?;
574            perturbation_errors.push(error);
575        }
576
577        let perturbation_sensitivity =
578            perturbation_errors.iter().sum::<f64>() / perturbation_errors.len() as f64;
579
580        // Numerical stability from condition numbers
581        let numerical_stability = if condition_numbers.is_empty() {
582            1.0
583        } else {
584            let mean_condition =
585                condition_numbers.iter().sum::<f64>() / condition_numbers.len() as f64;
586            1.0 / mean_condition.ln().max(1.0)
587        };
588
589        // Eigenvalue stability (placeholder)
590        let eigenvalue_stability = 1.0 - perturbation_sensitivity;
591
592        Ok(StabilityAnalysis {
593            perturbation_sensitivity,
594            numerical_stability,
595            condition_numbers: condition_numbers.to_vec(),
596            eigenvalue_stability,
597        })
598    }
599
600    fn analyze_sample_complexity<T: ValidatableKernelMethod>(
601        &self,
602        method: &T,
603        data: &Array2<f64>,
604    ) -> Result<SampleComplexityAnalysis> {
605        let mut sample_errors = Vec::new();
606
607        // Test different sample sizes
608        for &n_samples in &self.config.sample_sizes {
609            if n_samples > data.nrows() {
610                continue;
611            }
612
613            let subset_data = data
614                .slice(scirs2_core::ndarray::s![..n_samples, ..])
615                .to_owned();
616            let fitted = method.fit_with_dimension(&subset_data, 100)?;
617            let approx = fitted.get_kernel_approximation(&subset_data)?;
618
619            let rbf_kernel = self.compute_rbf_kernel(&subset_data, 1.0)?;
620            let error = self.compute_approximation_error(&approx, &rbf_kernel)?;
621            sample_errors.push(error);
622        }
623
624        // Estimate minimum required samples
625        let target_error = self.config.max_approximation_error;
626        let minimum_samples = self
627            .config
628            .sample_sizes
629            .iter()
630            .zip(sample_errors.iter())
631            .find(|(_, &error)| error <= target_error)
632            .map(|(&samples, _)| samples)
633            .unwrap_or(
634                *self
635                    .config
636                    .sample_sizes
637                    .last()
638                    .expect("operation should succeed"),
639            );
640
641        // Estimate convergence rate with respect to sample size
642        let convergence_rate = if sample_errors.len() >= 2 {
643            let log_samples: Vec<f64> = self
644                .config
645                .sample_sizes
646                .iter()
647                .take(sample_errors.len())
648                .map(|&x| (x as f64).ln())
649                .collect();
650            let log_errors: Vec<f64> = sample_errors.iter().map(|&x| x.ln()).collect();
651
652            // Linear regression for convergence rate
653            let n = log_samples.len() as f64;
654            let sum_x = log_samples.iter().sum::<f64>();
655            let sum_y = log_errors.iter().sum::<f64>();
656            let sum_xy = log_samples
657                .iter()
658                .zip(log_errors.iter())
659                .map(|(&x, &y)| x * y)
660                .sum::<f64>();
661            let sum_x2 = log_samples.iter().map(|&x| x * x).sum::<f64>();
662
663            -(n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
664        } else {
665            0.5 // Default assumption
666        };
667
668        let sample_efficiency = 1.0 / minimum_samples as f64;
669        let dimension_scaling = data.ncols() as f64 / minimum_samples as f64;
670
671        Ok(SampleComplexityAnalysis {
672            minimum_samples,
673            convergence_rate,
674            sample_efficiency,
675            dimension_scaling,
676        })
677    }
678
679    fn analyze_dimension_dependency<T: ValidatableKernelMethod>(
680        &self,
681        _method: &T,
682        data: &Array2<f64>,
683        errors: &[f64],
684    ) -> Result<DimensionDependencyAnalysis> {
685        let approximation_quality_vs_dimension: Vec<(usize, f64)> = self
686            .config
687            .approximation_dimensions
688            .iter()
689            .take(errors.len())
690            .zip(errors.iter())
691            .map(|(&dim, &error)| (dim, 1.0 - error)) // Convert error to quality
692            .collect();
693
694        // Estimate computational cost (simplified)
695        let computational_cost_vs_dimension: Vec<(usize, f64)> = self
696            .config
697            .approximation_dimensions
698            .iter()
699            .map(|&dim| (dim, dim as f64 * data.nrows() as f64))
700            .collect();
701
702        // Find optimal dimension (best quality-to-cost ratio)
703        let optimal_dimension = approximation_quality_vs_dimension
704            .iter()
705            .zip(computational_cost_vs_dimension.iter())
706            .map(|((dim, quality), (_, cost))| (*dim, quality / cost))
707            .max_by(|a, b| a.1.partial_cmp(&b.1).expect("operation should succeed"))
708            .map(|(dim, _)| dim)
709            .unwrap_or(100);
710
711        let dimension_efficiency = approximation_quality_vs_dimension
712            .iter()
713            .map(|(_, quality)| quality)
714            .sum::<f64>()
715            / approximation_quality_vs_dimension.len() as f64;
716
717        Ok(DimensionDependencyAnalysis {
718            approximation_quality_vs_dimension,
719            computational_cost_vs_dimension,
720            optimal_dimension,
721            dimension_efficiency,
722        })
723    }
724
725    fn generate_parameter_combinations(
726        &self,
727        parameter_grid: &HashMap<String, Vec<f64>>,
728    ) -> Vec<HashMap<String, f64>> {
729        let mut combinations = vec![HashMap::new()];
730
731        for (param_name, param_values) in parameter_grid {
732            let mut new_combinations = Vec::new();
733
734            for combination in &combinations {
735                for &param_value in param_values {
736                    let mut new_combination = combination.clone();
737                    new_combination.insert(param_name.clone(), param_value);
738                    new_combinations.push(new_combination);
739                }
740            }
741
742            combinations = new_combinations;
743        }
744
745        combinations
746    }
747
748    fn k_fold_cross_validation<T: ValidatableKernelMethod>(
749        &self,
750        method: &T,
751        data: &Array2<f64>,
752        _targets: Option<&Array1<f64>>,
753        parameters: &HashMap<String, f64>,
754        k: usize,
755    ) -> Result<Vec<f64>> {
756        let n_samples = data.nrows();
757        let fold_size = n_samples / k;
758        let mut scores = Vec::new();
759
760        for fold in 0..k {
761            let start_idx = fold * fold_size;
762            let end_idx = if fold == k - 1 {
763                n_samples
764            } else {
765                (fold + 1) * fold_size
766            };
767
768            // Create train and validation sets
769            let train_indices: Vec<usize> = (0..n_samples)
770                .filter(|&i| i < start_idx || i >= end_idx)
771                .collect();
772            let val_indices: Vec<usize> = (start_idx..end_idx).collect();
773
774            let train_data = data.select(Axis(0), &train_indices);
775            let val_data = data.select(Axis(0), &val_indices);
776
777            // Fit with parameters
778            let fitted = method.fit_with_parameters(&train_data, parameters)?;
779            let train_approx = fitted.get_kernel_approximation(&train_data)?;
780            let val_approx = fitted.get_kernel_approximation(&val_data)?;
781
782            // Compute validation score (kernel alignment as proxy)
783            let train_kernel = self.compute_rbf_kernel(&train_data, 1.0)?;
784            let val_kernel = self.compute_rbf_kernel(&val_data, 1.0)?;
785
786            let train_error = self.compute_approximation_error(&train_approx, &train_kernel)?;
787            let val_error = self.compute_approximation_error(&val_approx, &val_kernel)?;
788
789            // Score is negative error (higher is better)
790            let score = -(train_error + val_error) / 2.0;
791            scores.push(score);
792        }
793
794        Ok(scores)
795    }
796}
797
798/// Trait for kernel methods that can be validated
799pub trait ValidatableKernelMethod {
800    /// Get method name
801    fn method_name(&self) -> String;
802
803    /// Fit with specific approximation dimension
804    fn fit_with_dimension(
805        &self,
806        data: &Array2<f64>,
807        n_components: usize,
808    ) -> Result<Box<dyn ValidatedFittedMethod>>;
809
810    /// Fit with specific parameters
811    fn fit_with_parameters(
812        &self,
813        data: &Array2<f64>,
814        parameters: &HashMap<String, f64>,
815    ) -> Result<Box<dyn ValidatedFittedMethod>>;
816}
817
818/// Trait for fitted methods that can be validated
819pub trait ValidatedFittedMethod {
820    /// Get kernel approximation matrix
821    fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>>;
822
823    /// Compute condition number if applicable
824    fn compute_condition_number(&self) -> Result<Option<f64>>;
825
826    /// Get approximation dimension
827    fn approximation_dimension(&self) -> usize;
828}
829
// NOTE(review): no non-snake-case identifier is visible in this module, so this allow may be
// removable — confirm before dropping it.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in kernel method whose fits always succeed.
    ///
    /// `gamma` is stored but never read anywhere in this module.
    struct MockValidatableRBF {
        gamma: f64,
    }

    /// Fitted counterpart of the mock; it only remembers the requested dimension.
    struct MockValidatedFitted {
        n_components: usize,
    }

    impl ValidatableKernelMethod for MockValidatableRBF {
        fn method_name(&self) -> String {
            String::from("MockRBF")
        }

        fn fit_with_dimension(
            &self,
            _data: &Array2<f64>,
            n_components: usize,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            let fitted = MockValidatedFitted { n_components };
            Ok(Box::new(fitted))
        }

        fn fit_with_parameters(
            &self,
            _data: &Array2<f64>,
            parameters: &HashMap<String, f64>,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            // A missing "n_components" entry falls back to 100.0 before the float-to-usize cast.
            let requested = parameters.get("n_components").map_or(100.0, |value| *value);
            let fitted = MockValidatedFitted {
                n_components: requested as usize,
            };
            Ok(Box::new(fitted))
        }
    }

    impl ValidatedFittedMethod for MockValidatedFitted {
        fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            let n_samples = data.nrows();

            // Constant mock kernel: 1.0 on the diagonal, 0.5 for every other pair of rows.
            // The values of `data` are not inspected, only its row count.
            let kernel = Array2::from_shape_fn((n_samples, n_samples), |(row, col)| {
                if row == col {
                    1.0
                } else {
                    0.5
                }
            });

            Ok(kernel)
        }

        fn compute_condition_number(&self) -> Result<Option<f64>> {
            // Fixed placeholder value; nothing is actually estimated here.
            Ok(Some(10.0))
        }

        fn approximation_dimension(&self) -> usize {
            self.n_components
        }
    }

    /// A validator built from the default configuration has an "RBF" theoretical bound.
    #[test]
    fn test_validator_creation() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());
        let bounds = &validator.theoretical_bounds;

        assert!(!bounds.is_empty());
        assert!(bounds.contains_key("RBF"));
    }

    /// Validating the mock over two dimensions yields two entries per result series.
    #[test]
    fn test_method_validation() {
        let validator = KernelApproximationValidator::new(ValidationConfig {
            approximation_dimensions: vec![10, 20],
            repetitions: 2,
            ..Default::default()
        });

        let samples = Array2::from_shape_fn((50, 5), |(row, col)| (row + col) as f64 * 0.1);
        let mock = MockValidatableRBF { gamma: 1.0 };

        let report = validator
            .validate_method(&mock, &samples, None)
            .expect("operation should succeed");

        assert_eq!(report.method_name, "MockRBF");
        assert_eq!(report.empirical_errors.len(), 2);
        assert_eq!(report.theoretical_bounds.len(), 2);
        // A missing convergence rate is tolerated; a present one must be a finite number.
        assert!(report.convergence_rate.map_or(true, |rate| rate.is_finite()));
    }

    /// Cross-validating the mock over a small grid produces scores and a best parameter set.
    #[test]
    fn test_cross_validation() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());

        let samples = Array2::from_shape_fn((30, 4), |(row, col)| (row + col) as f64 * 0.1);
        let mock = MockValidatableRBF { gamma: 1.0 };

        let parameter_grid = HashMap::from([
            ("gamma".to_string(), vec![0.5, 1.0, 2.0]),
            ("n_components".to_string(), vec![10.0, 20.0]),
        ]);

        let outcome = validator
            .cross_validate(&mock, &samples, None, parameter_grid)
            .expect("operation should succeed");

        assert_eq!(outcome.method_name, "MockRBF");
        assert!(!outcome.cv_scores.is_empty());
        assert!(!outcome.best_parameters.is_empty());
    }

    /// The "RBF" theoretical bound evaluates to a finite, strictly positive number.
    #[test]
    fn test_theoretical_bounds() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());

        let rbf_bound = validator
            .theoretical_bounds
            .get("RBF")
            .expect("operation should succeed");
        let value = validator
            .compute_theoretical_bound(rbf_bound, 100, 10, 50)
            .expect("operation should succeed");

        assert!(value > 0.0);
        assert!(value.is_finite());
    }
}