sklears_kernel_approximation/
validation.rs

//! Advanced validation framework for kernel approximation methods.
//!
//! This module provides tools for checking empirical approximation errors against
//! theoretical error bounds, analysing convergence, and assessing approximation quality.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::random::essentials::Normal as RandNormal;
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::Distribution;
10use scirs2_core::random::Rng;
11use scirs2_core::random::SeedableRng;
12use sklears_core::error::Result;
13use std::collections::HashMap;
14
15/// Comprehensive validation framework for kernel approximation methods
16#[derive(Debug, Clone)]
17/// KernelApproximationValidator
18pub struct KernelApproximationValidator {
19    config: ValidationConfig,
20    theoretical_bounds: HashMap<String, TheoreticalBound>,
21}
22
/// Settings for [`KernelApproximationValidator`].
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Confidence level for probabilistic checks.
    /// NOTE(review): this module never reads it; probabilistic bounds carry their own
    /// confidence.
    pub confidence_level: f64,
    /// Target relative error used to judge whether a sample size is sufficient.
    pub max_approximation_error: f64,
    /// Convergence tolerance.
    /// NOTE(review): this module never reads it.
    pub convergence_tolerance: f64,
    /// Standard deviation of the Gaussian noise added to the data in the stability analysis.
    pub stability_tolerance: f64,
    /// Training-set sizes probed by the sample-complexity analysis. Sizes larger than
    /// the dataset are skipped.
    pub sample_sizes: Vec<usize>,
    /// Numbers of approximation components evaluated by `validate_method`.
    pub approximation_dimensions: Vec<usize>,
    /// Number of fits averaged for each approximation dimension.
    pub repetitions: usize,
    /// Seed for the perturbation RNG (`None` falls back to a fixed seed).
    pub random_state: Option<u64>,
}
44
45impl Default for ValidationConfig {
46    fn default() -> Self {
47        Self {
48            confidence_level: 0.95,
49            max_approximation_error: 0.1,
50            convergence_tolerance: 1e-6,
51            stability_tolerance: 1e-4,
52            sample_sizes: vec![100, 500, 1000, 2000],
53            approximation_dimensions: vec![50, 100, 200, 500],
54            repetitions: 10,
55            random_state: Some(42),
56        }
57    }
58}
59
60/// Theoretical error bounds for different approximation methods
61#[derive(Debug, Clone)]
62/// TheoreticalBound
63pub struct TheoreticalBound {
64    /// method_name
65    pub method_name: String,
66    /// bound_type
67    pub bound_type: BoundType,
68    /// bound_function
69    pub bound_function: BoundFunction,
70    /// constants
71    pub constants: HashMap<String, f64>,
72}
73
/// The kind of guarantee a [`TheoreticalBound`] provides.
///
/// The variant controls how the raw bound value is adjusted for a concrete sample size.
#[derive(Debug, Clone)]
pub enum BoundType {
    /// Holds with probability `confidence` (expected to lie in `(0, 1)`).
    Probabilistic { confidence: f64 },
    /// Worst-case bound that always holds.
    Deterministic,
    /// Bound on the expected approximation error.
    Expected,
    /// Concentration-inequality bound, widened according to `deviation_parameter`.
    Concentration { deviation_parameter: f64 },
}
87
/// Functional form of a theoretical bound.
///
/// In the formulas below, `d` is the number of input features and `m` is the number
/// of approximation components.
#[derive(Debug, Clone)]
pub enum BoundFunction {
    /// Random Fourier features: `O(sqrt(log(d) / m))`.
    RandomFourierFeatures,
    /// Nyström: governed by the kernel's eigenvalue decay.
    Nystroem,
    /// Structured random features: `O(sqrt(d * log(d) / m))`.
    StructuredRandomFeatures,
    /// Fastfood: `O(sqrt(d * log^2(d) / m))`.
    Fastfood,
    /// User-supplied formula.
    ///
    /// The validator does not evaluate `formula` yet; it falls back to a generic
    /// `O(1 / sqrt(m))` bound.
    Custom { formula: String },
}
103
104/// Result of validation analysis
105#[derive(Debug, Clone)]
106/// ValidationResult
107pub struct ValidationResult {
108    /// method_name
109    pub method_name: String,
110    /// empirical_errors
111    pub empirical_errors: Vec<f64>,
112    /// theoretical_bounds
113    pub theoretical_bounds: Vec<f64>,
114    /// bound_violations
115    pub bound_violations: usize,
116    /// bound_tightness
117    pub bound_tightness: f64,
118    /// convergence_rate
119    pub convergence_rate: Option<f64>,
120    /// stability_analysis
121    pub stability_analysis: StabilityAnalysis,
122    /// sample_complexity
123    pub sample_complexity: SampleComplexityAnalysis,
124    /// dimension_dependency
125    pub dimension_dependency: DimensionDependencyAnalysis,
126}
127
/// Results of the stability analysis run by the validator.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Mean relative difference between kernel approximations fitted on the original
    /// data and on noise-perturbed copies.
    pub perturbation_sensitivity: f64,
    /// Score derived from the mean condition number (`1.0` when none were reported).
    pub numerical_stability: f64,
    /// Condition numbers reported by the fitted models.
    pub condition_numbers: Vec<f64>,
    /// Placeholder, currently `1 - perturbation_sensitivity`.
    pub eigenvalue_stability: f64,
}
141
/// How the approximation error depends on the number of training samples.
#[derive(Debug, Clone)]
pub struct SampleComplexityAnalysis {
    /// Tested sample size judged sufficient to reach
    /// [`ValidationConfig::max_approximation_error`]. When no tested size reaches it,
    /// this is the last configured sample size.
    pub minimum_samples: usize,
    /// Estimated decay exponent of the error with respect to sample size.
    pub convergence_rate: f64,
    /// `1 / minimum_samples`.
    pub sample_efficiency: f64,
    /// Number of input features divided by `minimum_samples`.
    pub dimension_scaling: f64,
}
155
/// Quality/cost trade-off across approximation dimensions.
#[derive(Debug, Clone)]
pub struct DimensionDependencyAnalysis {
    /// `(dimension, 1 - empirical_error)` pairs.
    pub approximation_quality_vs_dimension: Vec<(usize, f64)>,
    /// `(dimension, dimension * n_samples)` pairs, used as a simple cost proxy.
    pub computational_cost_vs_dimension: Vec<(usize, f64)>,
    /// Dimension with the best quality-to-cost ratio.
    pub optimal_dimension: usize,
    /// Mean quality over all evaluated dimensions.
    pub dimension_efficiency: f64,
}
169
/// Outcome of a grid-search cross-validation of a kernel approximation method.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Name reported by the validated method.
    pub method_name: String,
    /// Mean cross-validation score for each parameter combination (higher is better).
    pub cv_scores: Vec<f64>,
    /// Mean of `cv_scores`.
    pub mean_score: f64,
    /// Population standard deviation of `cv_scores`.
    pub std_score: f64,
    /// Parameter combination with the highest score.
    pub best_parameters: HashMap<String, f64>,
    /// Per-parameter average score change when only that parameter is varied.
    pub parameter_sensitivity: HashMap<String, f64>,
}
187
188impl KernelApproximationValidator {
189    /// Create a new validator with configuration
190    pub fn new(config: ValidationConfig) -> Self {
191        let mut validator = Self {
192            config,
193            theoretical_bounds: HashMap::new(),
194        };
195
196        // Add default theoretical bounds
197        validator.add_default_bounds();
198        validator
199    }
200
201    /// Add theoretical bounds for a specific method
202    pub fn add_theoretical_bound(&mut self, bound: TheoreticalBound) {
203        self.theoretical_bounds
204            .insert(bound.method_name.clone(), bound);
205    }
206
207    fn add_default_bounds(&mut self) {
208        // RFF bounds
209        self.add_theoretical_bound(TheoreticalBound {
210            method_name: "RBF".to_string(),
211            bound_type: BoundType::Probabilistic { confidence: 0.95 },
212            bound_function: BoundFunction::RandomFourierFeatures,
213            constants: [
214                ("kernel_bound".to_string(), 1.0),
215                ("lipschitz_constant".to_string(), 1.0),
216            ]
217            .iter()
218            .cloned()
219            .collect(),
220        });
221
222        // Nyström bounds
223        self.add_theoretical_bound(TheoreticalBound {
224            method_name: "Nystroem".to_string(),
225            bound_type: BoundType::Expected,
226            bound_function: BoundFunction::Nystroem,
227            constants: [
228                ("trace_bound".to_string(), 1.0),
229                ("effective_rank".to_string(), 100.0),
230            ]
231            .iter()
232            .cloned()
233            .collect(),
234        });
235
236        // Fastfood bounds
237        self.add_theoretical_bound(TheoreticalBound {
238            method_name: "Fastfood".to_string(),
239            bound_type: BoundType::Probabilistic { confidence: 0.95 },
240            bound_function: BoundFunction::Fastfood,
241            constants: [
242                ("dimension_factor".to_string(), 1.0),
243                ("log_factor".to_string(), 2.0),
244            ]
245            .iter()
246            .cloned()
247            .collect(),
248        });
249    }
250
251    /// Validate a kernel approximation method
252    pub fn validate_method<T: ValidatableKernelMethod>(
253        &self,
254        method: &T,
255        data: &Array2<f64>,
256        true_kernel: Option<&Array2<f64>>,
257    ) -> Result<ValidationResult> {
258        let method_name = method.method_name();
259        let mut empirical_errors = Vec::new();
260        let mut theoretical_bounds = Vec::new();
261        let mut condition_numbers = Vec::new();
262
263        // Test different approximation dimensions
264        for &n_components in &self.config.approximation_dimensions {
265            let mut dimension_errors = Vec::new();
266
267            for _ in 0..self.config.repetitions {
268                // Fit and evaluate the method
269                let fitted = method.fit_with_dimension(data, n_components)?;
270                let approximation = fitted.get_kernel_approximation(data)?;
271
272                // Compute empirical error
273                let empirical_error = if let Some(true_k) = true_kernel {
274                    self.compute_approximation_error(&approximation, true_k)?
275                } else {
276                    // Use RBF kernel as reference
277                    let rbf_kernel = self.compute_rbf_kernel(data, 1.0)?;
278                    self.compute_approximation_error(&approximation, &rbf_kernel)?
279                };
280
281                dimension_errors.push(empirical_error);
282
283                // Compute condition number for stability analysis
284                if let Some(cond_num) = fitted.compute_condition_number()? {
285                    condition_numbers.push(cond_num);
286                }
287            }
288
289            let mean_error = dimension_errors.iter().sum::<f64>() / dimension_errors.len() as f64;
290            empirical_errors.push(mean_error);
291
292            // Compute theoretical bound
293            if let Some(bound) = self.theoretical_bounds.get(&method_name) {
294                let theoretical_bound = self.compute_theoretical_bound(
295                    bound,
296                    data.nrows(),
297                    data.ncols(),
298                    n_components,
299                )?;
300                theoretical_bounds.push(theoretical_bound);
301            } else {
302                theoretical_bounds.push(f64::INFINITY);
303            }
304        }
305
306        // Count bound violations
307        let bound_violations = empirical_errors
308            .iter()
309            .zip(theoretical_bounds.iter())
310            .filter(|(&emp, &theo)| emp > theo)
311            .count();
312
313        // Compute bound tightness (average ratio of empirical to theoretical)
314        let bound_tightness = empirical_errors
315            .iter()
316            .zip(theoretical_bounds.iter())
317            .filter(|(_, &theo)| theo.is_finite())
318            .map(|(&emp, &theo)| emp / theo)
319            .sum::<f64>()
320            / empirical_errors.len() as f64;
321
322        // Analyze convergence rate
323        let convergence_rate = self.estimate_convergence_rate(&empirical_errors);
324
325        // Perform stability analysis
326        let stability_analysis = self.analyze_stability(method, data, &condition_numbers)?;
327
328        // Analyze sample complexity
329        let sample_complexity = self.analyze_sample_complexity(method, data)?;
330
331        // Analyze dimension dependency
332        let dimension_dependency =
333            self.analyze_dimension_dependency(method, data, &empirical_errors)?;
334
335        Ok(ValidationResult {
336            method_name,
337            empirical_errors,
338            theoretical_bounds,
339            bound_violations,
340            bound_tightness,
341            convergence_rate,
342            stability_analysis,
343            sample_complexity,
344            dimension_dependency,
345        })
346    }
347
348    /// Perform cross-validation for parameter selection
349    pub fn cross_validate<T: ValidatableKernelMethod>(
350        &self,
351        method: &T,
352        data: &Array2<f64>,
353        targets: Option<&Array1<f64>>,
354        parameter_grid: HashMap<String, Vec<f64>>,
355    ) -> Result<CrossValidationResult> {
356        let mut best_score = f64::NEG_INFINITY;
357        let mut best_parameters = HashMap::new();
358        let mut all_scores = Vec::new();
359        let mut parameter_sensitivity = HashMap::new();
360
361        // Generate parameter combinations
362        let param_combinations = self.generate_parameter_combinations(&parameter_grid);
363
364        for params in param_combinations {
365            let cv_scores = self.k_fold_cross_validation(method, data, targets, &params, 5)?;
366            let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
367
368            all_scores.push(mean_score);
369
370            if mean_score > best_score {
371                best_score = mean_score;
372                best_parameters = params.clone();
373            }
374        }
375
376        // Analyze parameter sensitivity
377        for (param_name, param_values) in &parameter_grid {
378            let mut sensitivities = Vec::new();
379
380            for (i, &param_value) in param_values.iter().enumerate() {
381                let mut single_param = best_parameters.clone();
382                single_param.insert(param_name.clone(), param_value);
383
384                let cv_scores =
385                    self.k_fold_cross_validation(method, data, targets, &single_param, 3)?;
386                let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
387                sensitivities.push((best_score - mean_score).abs());
388            }
389
390            let sensitivity = sensitivities.iter().sum::<f64>() / sensitivities.len() as f64;
391            parameter_sensitivity.insert(param_name.clone(), sensitivity);
392        }
393
394        let mean_score = all_scores.iter().sum::<f64>() / all_scores.len() as f64;
395        let variance = all_scores
396            .iter()
397            .map(|&x| (x - mean_score).powi(2))
398            .sum::<f64>()
399            / all_scores.len() as f64;
400        let std_score = variance.sqrt();
401
402        Ok(CrossValidationResult {
403            method_name: method.method_name(),
404            cv_scores: all_scores,
405            mean_score,
406            std_score,
407            best_parameters,
408            parameter_sensitivity,
409        })
410    }
411
412    fn compute_approximation_error(
413        &self,
414        approx_kernel: &Array2<f64>,
415        true_kernel: &Array2<f64>,
416    ) -> Result<f64> {
417        // Compute Frobenius norm error
418        let diff = approx_kernel - true_kernel;
419        let frobenius_error = diff.mapv(|x| x * x).sum().sqrt();
420
421        // Normalize by true kernel norm
422        let true_norm = true_kernel.mapv(|x| x * x).sum().sqrt();
423        Ok(frobenius_error / true_norm.max(1e-8))
424    }
425
426    fn compute_rbf_kernel(&self, data: &Array2<f64>, gamma: f64) -> Result<Array2<f64>> {
427        let n_samples = data.nrows();
428        let mut kernel = Array2::zeros((n_samples, n_samples));
429
430        for i in 0..n_samples {
431            for j in i..n_samples {
432                let diff = &data.row(i) - &data.row(j);
433                let dist_sq = diff.mapv(|x| x * x).sum();
434                let similarity = (-gamma * dist_sq).exp();
435                kernel[[i, j]] = similarity;
436                kernel[[j, i]] = similarity;
437            }
438        }
439
440        Ok(kernel)
441    }
442
443    fn compute_theoretical_bound(
444        &self,
445        bound: &TheoreticalBound,
446        n_samples: usize,
447        n_features: usize,
448        n_components: usize,
449    ) -> Result<f64> {
450        let bound_value = match &bound.bound_function {
451            BoundFunction::RandomFourierFeatures => {
452                let kernel_bound = bound.constants.get("kernel_bound").unwrap_or(&1.0);
453                let lipschitz = bound.constants.get("lipschitz_constant").unwrap_or(&1.0);
454
455                // O(sqrt(log(d)/m)) bound for RFF
456                let log_factor = (n_features as f64).ln();
457                kernel_bound * lipschitz * (log_factor / n_components as f64).sqrt()
458            }
459            BoundFunction::Nystroem => {
460                let trace_bound = bound.constants.get("trace_bound").unwrap_or(&1.0);
461                let effective_rank = bound.constants.get("effective_rank").unwrap_or(&100.0);
462
463                // Approximation depends on eigenvalue decay
464                trace_bound * (effective_rank / n_components as f64).sqrt()
465            }
466            BoundFunction::StructuredRandomFeatures => {
467                let log_factor = (n_features as f64).ln();
468                (n_features as f64 * log_factor / n_components as f64).sqrt()
469            }
470            BoundFunction::Fastfood => {
471                let log_factor = bound.constants.get("log_factor").unwrap_or(&2.0);
472                let dim_factor = bound.constants.get("dimension_factor").unwrap_or(&1.0);
473
474                let log_d = (n_features as f64).ln();
475                dim_factor
476                    * (n_features as f64 * log_d.powf(*log_factor) / n_components as f64).sqrt()
477            }
478            BoundFunction::Custom { formula: _ } => {
479                // Placeholder for custom formulas
480                1.0 / (n_components as f64).sqrt()
481            }
482        };
483
484        // Apply bound type modifications
485        let final_bound = match &bound.bound_type {
486            BoundType::Probabilistic { confidence } => {
487                // Add confidence-dependent factor
488                let z_score = self.inverse_normal_cdf(*confidence);
489                bound_value * (1.0 + z_score / (n_samples as f64).sqrt())
490            }
491            BoundType::Deterministic => bound_value,
492            BoundType::Expected => bound_value * 0.8, // Expected is typically tighter
493            BoundType::Concentration {
494                deviation_parameter,
495            } => bound_value * (1.0 + deviation_parameter / (n_samples as f64).sqrt()),
496        };
497
498        Ok(final_bound)
499    }
500
501    fn inverse_normal_cdf(&self, p: f64) -> f64 {
502        // Approximation of inverse normal CDF for confidence intervals
503        if p <= 0.5 {
504            -self.inverse_normal_cdf(1.0 - p)
505        } else {
506            let t = (-2.0 * (1.0 - p).ln()).sqrt();
507            let c0 = 2.515517;
508            let c1 = 0.802853;
509            let c2 = 0.010328;
510            let d1 = 1.432788;
511            let d2 = 0.189269;
512            let d3 = 0.001308;
513
514            t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t)
515        }
516    }
517
518    fn estimate_convergence_rate(&self, errors: &[f64]) -> Option<f64> {
519        if errors.len() < 3 {
520            return None;
521        }
522
523        // Fit log(error) = a + b * log(dimension) to estimate convergence rate
524        let dimensions: Vec<f64> = self
525            .config
526            .approximation_dimensions
527            .iter()
528            .take(errors.len())
529            .map(|&x| (x as f64).ln())
530            .collect();
531
532        let log_errors: Vec<f64> = errors.iter().map(|&x| x.ln()).collect();
533
534        // Simple linear regression
535        let n = dimensions.len() as f64;
536        let sum_x = dimensions.iter().sum::<f64>();
537        let sum_y = log_errors.iter().sum::<f64>();
538        let sum_xy = dimensions
539            .iter()
540            .zip(log_errors.iter())
541            .map(|(&x, &y)| x * y)
542            .sum::<f64>();
543        let sum_x2 = dimensions.iter().map(|&x| x * x).sum::<f64>();
544
545        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
546        Some(-slope) // Negative because we expect decreasing error
547    }
548
549    fn analyze_stability<T: ValidatableKernelMethod>(
550        &self,
551        method: &T,
552        data: &Array2<f64>,
553        condition_numbers: &[f64],
554    ) -> Result<StabilityAnalysis> {
555        let mut rng = RealStdRng::seed_from_u64(self.config.random_state.unwrap_or(42));
556        let normal = RandNormal::new(0.0, self.config.stability_tolerance).unwrap();
557
558        // Test perturbation sensitivity
559        let mut perturbation_errors = Vec::new();
560
561        for _ in 0..5 {
562            let mut perturbed_data = data.clone();
563            for elem in perturbed_data.iter_mut() {
564                *elem += rng.sample(normal);
565            }
566
567            let original_fitted = method.fit_with_dimension(data, 100)?;
568            let perturbed_fitted = method.fit_with_dimension(&perturbed_data, 100)?;
569
570            let original_approx = original_fitted.get_kernel_approximation(data)?;
571            let perturbed_approx = perturbed_fitted.get_kernel_approximation(data)?;
572
573            let error = self.compute_approximation_error(&perturbed_approx, &original_approx)?;
574            perturbation_errors.push(error);
575        }
576
577        let perturbation_sensitivity =
578            perturbation_errors.iter().sum::<f64>() / perturbation_errors.len() as f64;
579
580        // Numerical stability from condition numbers
581        let numerical_stability = if condition_numbers.is_empty() {
582            1.0
583        } else {
584            let mean_condition =
585                condition_numbers.iter().sum::<f64>() / condition_numbers.len() as f64;
586            1.0 / mean_condition.ln().max(1.0)
587        };
588
589        // Eigenvalue stability (placeholder)
590        let eigenvalue_stability = 1.0 - perturbation_sensitivity;
591
592        Ok(StabilityAnalysis {
593            perturbation_sensitivity,
594            numerical_stability,
595            condition_numbers: condition_numbers.to_vec(),
596            eigenvalue_stability,
597        })
598    }
599
600    fn analyze_sample_complexity<T: ValidatableKernelMethod>(
601        &self,
602        method: &T,
603        data: &Array2<f64>,
604    ) -> Result<SampleComplexityAnalysis> {
605        let mut sample_errors = Vec::new();
606
607        // Test different sample sizes
608        for &n_samples in &self.config.sample_sizes {
609            if n_samples > data.nrows() {
610                continue;
611            }
612
613            let subset_data = data
614                .slice(scirs2_core::ndarray::s![..n_samples, ..])
615                .to_owned();
616            let fitted = method.fit_with_dimension(&subset_data, 100)?;
617            let approx = fitted.get_kernel_approximation(&subset_data)?;
618
619            let rbf_kernel = self.compute_rbf_kernel(&subset_data, 1.0)?;
620            let error = self.compute_approximation_error(&approx, &rbf_kernel)?;
621            sample_errors.push(error);
622        }
623
624        // Estimate minimum required samples
625        let target_error = self.config.max_approximation_error;
626        let minimum_samples = self
627            .config
628            .sample_sizes
629            .iter()
630            .zip(sample_errors.iter())
631            .find(|(_, &error)| error <= target_error)
632            .map(|(&samples, _)| samples)
633            .unwrap_or(*self.config.sample_sizes.last().unwrap());
634
635        // Estimate convergence rate with respect to sample size
636        let convergence_rate = if sample_errors.len() >= 2 {
637            let log_samples: Vec<f64> = self
638                .config
639                .sample_sizes
640                .iter()
641                .take(sample_errors.len())
642                .map(|&x| (x as f64).ln())
643                .collect();
644            let log_errors: Vec<f64> = sample_errors.iter().map(|&x| x.ln()).collect();
645
646            // Linear regression for convergence rate
647            let n = log_samples.len() as f64;
648            let sum_x = log_samples.iter().sum::<f64>();
649            let sum_y = log_errors.iter().sum::<f64>();
650            let sum_xy = log_samples
651                .iter()
652                .zip(log_errors.iter())
653                .map(|(&x, &y)| x * y)
654                .sum::<f64>();
655            let sum_x2 = log_samples.iter().map(|&x| x * x).sum::<f64>();
656
657            -(n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
658        } else {
659            0.5 // Default assumption
660        };
661
662        let sample_efficiency = 1.0 / minimum_samples as f64;
663        let dimension_scaling = data.ncols() as f64 / minimum_samples as f64;
664
665        Ok(SampleComplexityAnalysis {
666            minimum_samples,
667            convergence_rate,
668            sample_efficiency,
669            dimension_scaling,
670        })
671    }
672
673    fn analyze_dimension_dependency<T: ValidatableKernelMethod>(
674        &self,
675        method: &T,
676        data: &Array2<f64>,
677        errors: &[f64],
678    ) -> Result<DimensionDependencyAnalysis> {
679        let approximation_quality_vs_dimension: Vec<(usize, f64)> = self
680            .config
681            .approximation_dimensions
682            .iter()
683            .take(errors.len())
684            .zip(errors.iter())
685            .map(|(&dim, &error)| (dim, 1.0 - error)) // Convert error to quality
686            .collect();
687
688        // Estimate computational cost (simplified)
689        let computational_cost_vs_dimension: Vec<(usize, f64)> = self
690            .config
691            .approximation_dimensions
692            .iter()
693            .map(|&dim| (dim, dim as f64 * data.nrows() as f64))
694            .collect();
695
696        // Find optimal dimension (best quality-to-cost ratio)
697        let optimal_dimension = approximation_quality_vs_dimension
698            .iter()
699            .zip(computational_cost_vs_dimension.iter())
700            .map(|((dim, quality), (_, cost))| (*dim, quality / cost))
701            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
702            .map(|(dim, _)| dim)
703            .unwrap_or(100);
704
705        let dimension_efficiency = approximation_quality_vs_dimension
706            .iter()
707            .map(|(_, quality)| quality)
708            .sum::<f64>()
709            / approximation_quality_vs_dimension.len() as f64;
710
711        Ok(DimensionDependencyAnalysis {
712            approximation_quality_vs_dimension,
713            computational_cost_vs_dimension,
714            optimal_dimension,
715            dimension_efficiency,
716        })
717    }
718
719    fn generate_parameter_combinations(
720        &self,
721        parameter_grid: &HashMap<String, Vec<f64>>,
722    ) -> Vec<HashMap<String, f64>> {
723        let mut combinations = vec![HashMap::new()];
724
725        for (param_name, param_values) in parameter_grid {
726            let mut new_combinations = Vec::new();
727
728            for combination in &combinations {
729                for &param_value in param_values {
730                    let mut new_combination = combination.clone();
731                    new_combination.insert(param_name.clone(), param_value);
732                    new_combinations.push(new_combination);
733                }
734            }
735
736            combinations = new_combinations;
737        }
738
739        combinations
740    }
741
742    fn k_fold_cross_validation<T: ValidatableKernelMethod>(
743        &self,
744        method: &T,
745        data: &Array2<f64>,
746        _targets: Option<&Array1<f64>>,
747        parameters: &HashMap<String, f64>,
748        k: usize,
749    ) -> Result<Vec<f64>> {
750        let n_samples = data.nrows();
751        let fold_size = n_samples / k;
752        let mut scores = Vec::new();
753
754        for fold in 0..k {
755            let start_idx = fold * fold_size;
756            let end_idx = if fold == k - 1 {
757                n_samples
758            } else {
759                (fold + 1) * fold_size
760            };
761
762            // Create train and validation sets
763            let train_indices: Vec<usize> = (0..n_samples)
764                .filter(|&i| i < start_idx || i >= end_idx)
765                .collect();
766            let val_indices: Vec<usize> = (start_idx..end_idx).collect();
767
768            let train_data = data.select(Axis(0), &train_indices);
769            let val_data = data.select(Axis(0), &val_indices);
770
771            // Fit with parameters
772            let fitted = method.fit_with_parameters(&train_data, parameters)?;
773            let train_approx = fitted.get_kernel_approximation(&train_data)?;
774            let val_approx = fitted.get_kernel_approximation(&val_data)?;
775
776            // Compute validation score (kernel alignment as proxy)
777            let train_kernel = self.compute_rbf_kernel(&train_data, 1.0)?;
778            let val_kernel = self.compute_rbf_kernel(&val_data, 1.0)?;
779
780            let train_error = self.compute_approximation_error(&train_approx, &train_kernel)?;
781            let val_error = self.compute_approximation_error(&val_approx, &val_kernel)?;
782
783            // Score is negative error (higher is better)
784            let score = -(train_error + val_error) / 2.0;
785            scores.push(score);
786        }
787
788        Ok(scores)
789    }
790}
791
792/// Trait for kernel methods that can be validated
793pub trait ValidatableKernelMethod {
794    /// Get method name
795    fn method_name(&self) -> String;
796
797    /// Fit with specific approximation dimension
798    fn fit_with_dimension(
799        &self,
800        data: &Array2<f64>,
801        n_components: usize,
802    ) -> Result<Box<dyn ValidatedFittedMethod>>;
803
804    /// Fit with specific parameters
805    fn fit_with_parameters(
806        &self,
807        data: &Array2<f64>,
808        parameters: &HashMap<String, f64>,
809    ) -> Result<Box<dyn ValidatedFittedMethod>>;
810}
811
812/// Trait for fitted methods that can be validated
813pub trait ValidatedFittedMethod {
814    /// Get kernel approximation matrix
815    fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>>;
816
817    /// Compute condition number if applicable
818    fn compute_condition_number(&self) -> Result<Option<f64>>;
819
820    /// Get approximation dimension
821    fn approximation_dimension(&self) -> usize;
822}
823
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal method whose fits ignore the data and every parameter except
    /// `n_components`.
    struct MockValidatableRBF;

    impl ValidatableKernelMethod for MockValidatableRBF {
        fn method_name(&self) -> String {
            "MockRBF".to_string()
        }

        fn fit_with_dimension(
            &self,
            _data: &Array2<f64>,
            n_components: usize,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            Ok(Box::new(MockValidatedFitted { n_components }))
        }

        fn fit_with_parameters(
            &self,
            _data: &Array2<f64>,
            parameters: &HashMap<String, f64>,
        ) -> Result<Box<dyn ValidatedFittedMethod>> {
            let n_components = parameters.get("n_components").copied().unwrap_or(100.0) as usize;
            Ok(Box::new(MockValidatedFitted { n_components }))
        }
    }

    /// Fitted mock. Its kernel has ones on the diagonal and 0.5 elsewhere, and it always
    /// reports a condition number of 10.
    struct MockValidatedFitted {
        n_components: usize,
    }

    impl ValidatedFittedMethod for MockValidatedFitted {
        fn get_kernel_approximation(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
            let n_samples = data.nrows();
            let mut kernel = Array2::zeros((n_samples, n_samples));

            for i in 0..n_samples {
                kernel[[i, i]] = 1.0;
                for j in i + 1..n_samples {
                    let similarity = 0.5;
                    kernel[[i, j]] = similarity;
                    kernel[[j, i]] = similarity;
                }
            }

            Ok(kernel)
        }

        fn compute_condition_number(&self) -> Result<Option<f64>> {
            Ok(Some(10.0))
        }

        fn approximation_dimension(&self) -> usize {
            self.n_components
        }
    }

    #[test]
    fn test_validator_creation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        assert!(!validator.theoretical_bounds.is_empty());
        assert!(validator.theoretical_bounds.contains_key("RBF"));
    }

    #[test]
    fn test_method_validation() {
        let config = ValidationConfig {
            approximation_dimensions: vec![10, 20],
            repetitions: 2,
            ..Default::default()
        };
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((50, 5), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF;

        let result = validator.validate_method(&method, &data, None).unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert_eq!(result.empirical_errors.len(), 2);
        assert_eq!(result.theoretical_bounds.len(), 2);
        // With only two dimensions, the convergence rate may legitimately be `None`.
        if let Some(rate) = result.convergence_rate {
            assert!(rate.is_finite());
        }
    }

    #[test]
    fn test_cross_validation() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let data = Array2::from_shape_fn((30, 4), |(i, j)| (i + j) as f64 * 0.1);
        let method = MockValidatableRBF;

        let mut parameter_grid = HashMap::new();
        parameter_grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        parameter_grid.insert("n_components".to_string(), vec![10.0, 20.0]);

        let result = validator
            .cross_validate(&method, &data, None, parameter_grid)
            .unwrap();

        assert_eq!(result.method_name, "MockRBF");
        assert!(!result.cv_scores.is_empty());
        assert!(!result.best_parameters.is_empty());
    }

    #[test]
    fn test_theoretical_bounds() {
        let config = ValidationConfig::default();
        let validator = KernelApproximationValidator::new(config);

        let bound = validator.theoretical_bounds.get("RBF").unwrap();
        let theoretical_bound = validator
            .compute_theoretical_bound(bound, 100, 10, 50)
            .unwrap();

        assert!(theoretical_bound > 0.0);
        assert!(theoretical_bound.is_finite());
    }

    #[test]
    fn test_inverse_normal_cdf_known_quantiles() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());
        // The Abramowitz & Stegun approximation is accurate to better than 4.5e-4.
        for &(p, expected) in &[(0.95, 1.644_853_6), (0.975, 1.959_964), (0.05, -1.644_853_6)] {
            let z = validator.inverse_normal_cdf(p);
            assert!(
                (z - expected).abs() < 1e-3,
                "p = {}: got {}, expected {}",
                p,
                z,
                expected
            );
        }
    }

    #[test]
    fn test_rbf_kernel_is_symmetric_with_unit_diagonal() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());
        // Rows: [0, 1], [2, 3], [4, 5].
        let data = Array2::from_shape_fn((3, 2), |(i, j)| (i * 2 + j) as f64);
        let kernel = validator.compute_rbf_kernel(&data, 0.5).unwrap();

        for i in 0..3 {
            assert!((kernel[[i, i]] - 1.0).abs() < 1e-12);
            for j in 0..3 {
                assert_eq!(kernel[[i, j]], kernel[[j, i]]);
            }
        }
        // The squared distances are 8 (rows 0 and 1) and 32 (rows 0 and 2).
        assert!((kernel[[0, 1]] - (-0.5_f64 * 8.0).exp()).abs() < 1e-12);
        assert!((kernel[[0, 2]] - (-0.5_f64 * 32.0).exp()).abs() < 1e-12);
    }

    #[test]
    fn test_identical_kernels_have_zero_error() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());
        let kernel = Array2::from_shape_fn((4, 4), |(i, j)| if i == j { 1.0 } else { 0.25 });
        let error = validator
            .compute_approximation_error(&kernel, &kernel)
            .unwrap();
        assert_eq!(error, 0.0);
    }

    #[test]
    fn test_parameter_grid_expansion() {
        let validator = KernelApproximationValidator::new(ValidationConfig::default());

        let mut grid = HashMap::new();
        grid.insert("gamma".to_string(), vec![0.5, 1.0, 2.0]);
        grid.insert("n_components".to_string(), vec![10.0, 20.0]);
        let combinations = validator.generate_parameter_combinations(&grid);
        assert_eq!(combinations.len(), 6);
        assert!(combinations.iter().all(|c| c.len() == 2));

        let empty = validator.generate_parameter_combinations(&HashMap::new());
        assert_eq!(empty.len(), 1);
        assert!(empty[0].is_empty());
    }
}