// sklears_kernel_approximation/type_safety.rs

1//! Type Safety Enhancements for Kernel Approximation Methods
2//!
3//! This module provides compile-time type safety using phantom types and const generics
4//! to prevent common errors in kernel approximation usage.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::Distribution;
11use scirs2_core::random::{thread_rng, Rng, SeedableRng};
12use sklears_core::error::{Result, SklearsError};
13use std::marker::PhantomData;
14
/// Phantom type to represent the state of a kernel approximation method
pub trait ApproximationState {}

/// Marker for an approximation that has not been fitted yet.
///
/// Zero-sized: exists only so the type system can forbid `transform`
/// before `fit`.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;

impl ApproximationState for Untrained {}

/// Marker for an approximation that has been fitted and can transform data.
#[derive(Debug, Clone, Copy)]
pub struct Trained;

impl ApproximationState for Trained {}
29
/// Phantom type to represent kernel types
pub trait KernelType {
    /// Name of the kernel type
    const NAME: &'static str;

    /// Whether this kernel type supports parameter learning
    const SUPPORTS_PARAMETER_LEARNING: bool;

    /// Default bandwidth/gamma parameter
    const DEFAULT_BANDWIDTH: f64;
}

/// Declares a zero-sized kernel marker type together with its
/// compile-time constants. Every kernel currently uses a default
/// bandwidth of 1.0.
macro_rules! kernel_marker {
    ($(#[$doc:meta])* $ty:ident, $name:literal, learns_params: $learns:literal) => {
        $(#[$doc])*
        #[derive(Debug, Clone, Copy)]
        pub struct $ty;

        impl KernelType for $ty {
            const NAME: &'static str = $name;
            const SUPPORTS_PARAMETER_LEARNING: bool = $learns;
            const DEFAULT_BANDWIDTH: f64 = 1.0;
        }
    };
}

kernel_marker!(
    /// RBF (Gaussian) kernel type
    RBFKernel, "RBF", learns_params: true
);

kernel_marker!(
    /// Laplacian kernel type
    LaplacianKernel, "Laplacian", learns_params: true
);

kernel_marker!(
    /// Polynomial kernel type
    PolynomialKernel, "Polynomial", learns_params: true
);

kernel_marker!(
    /// Arc-cosine kernel type (no learnable parameters)
    ArcCosineKernel, "ArcCosine", learns_params: false
);
81
/// Phantom type to represent approximation methods
pub trait ApproximationMethod {
    /// Name of the approximation method
    const NAME: &'static str;

    /// Whether this method supports incremental updates
    const SUPPORTS_INCREMENTAL: bool;

    /// Whether this method provides theoretical error bounds
    const HAS_ERROR_BOUNDS: bool;

    /// Computational complexity class
    const COMPLEXITY: ComplexityClass;
}

/// Computational complexity classes
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    /// O(n²) complexity
    Quadratic,
    /// O(n log n) complexity
    QuasiLinear,
    /// O(n) complexity
    Linear,
    /// O(d log d) complexity where d is dimension
    DimensionDependent,
}

/// Declares a zero-sized approximation-method marker together with its
/// compile-time characteristics.
macro_rules! method_marker {
    ($(#[$doc:meta])* $ty:ident {
        name: $name:literal,
        incremental: $inc:literal,
        error_bounds: $bounds:literal,
        complexity: $cx:expr,
    }) => {
        $(#[$doc])*
        #[derive(Debug, Clone, Copy)]
        pub struct $ty;

        impl ApproximationMethod for $ty {
            const NAME: &'static str = $name;
            const SUPPORTS_INCREMENTAL: bool = $inc;
            const HAS_ERROR_BOUNDS: bool = $bounds;
            const COMPLEXITY: ComplexityClass = $cx;
        }
    };
}

method_marker!(
    /// Random Fourier Features approximation method
    RandomFourierFeatures {
        name: "RandomFourierFeatures",
        incremental: true,
        error_bounds: true,
        complexity: ComplexityClass::Linear,
    }
);

method_marker!(
    /// Nyström approximation method
    NystromMethod {
        name: "Nystrom",
        incremental: false,
        error_bounds: true,
        complexity: ComplexityClass::Quadratic,
    }
);

method_marker!(
    /// Fastfood approximation method
    FastfoodMethod {
        name: "Fastfood",
        incremental: false,
        error_bounds: true,
        complexity: ComplexityClass::DimensionDependent,
    }
);
143
/// Type-safe kernel approximation with compile-time guarantees.
///
/// `State` tracks fit status (`Untrained`/`Trained`), `Kernel` selects the
/// kernel marker, `Method` selects the approximation algorithm, and
/// `N_COMPONENTS` fixes the number of components at compile time.
/// Unsupported kernel/method combinations are rejected by the trait bounds
/// on `fit`, not at runtime.
#[derive(Debug, Clone)]
pub struct TypeSafeKernelApproximation<State, Kernel, Method, const N_COMPONENTS: usize>
where
    State: ApproximationState,
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Phantom data carrying the state/kernel/method type parameters;
    /// zero-sized at runtime.
    _phantom: PhantomData<(State, Kernel, Method)>,

    /// Method parameters (bandwidth, degree, coef0, custom entries)
    parameters: ApproximationParameters,

    /// Random seed for reproducibility; `None` means a fresh seed is drawn
    /// from the thread RNG at fit time
    random_state: Option<u64>,
}
162
/// Parameters for kernel approximation methods.
///
/// A single parameter bag shared by every kernel/method combination;
/// fields that do not apply to a given kernel (e.g. `degree` for RBF)
/// are simply ignored by it.
#[derive(Debug, Clone)]
pub struct ApproximationParameters {
    /// Bandwidth/gamma parameter
    pub bandwidth: f64,

    /// Polynomial degree (for polynomial kernels)
    pub degree: Option<usize>,

    /// Coefficient for polynomial kernels
    pub coef0: Option<f64>,

    /// Additional custom parameters
    pub custom: std::collections::HashMap<String, f64>,
}

impl Default for ApproximationParameters {
    /// Unit bandwidth with no polynomial settings and no custom entries.
    fn default() -> Self {
        ApproximationParameters {
            bandwidth: 1.0,
            degree: None,
            coef0: None,
            custom: Default::default(),
        }
    }
}
190
/// Fitted kernel approximation that can transform data.
///
/// Produced by `TypeSafeKernelApproximation::fit`; the state parameter is
/// gone from the type, so `transform` is available and re-fitting is not.
#[derive(Debug, Clone)]
pub struct FittedTypeSafeKernelApproximation<Kernel, Method, const N_COMPONENTS: usize>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Phantom data carrying the kernel/method type parameters;
    /// zero-sized at runtime.
    _phantom: PhantomData<(Kernel, Method)>,

    /// Random features or transformation parameters learned during `fit`
    transformation_params: TransformationParameters<N_COMPONENTS>,

    /// Parameters the approximation was fitted with
    fitted_parameters: ApproximationParameters,

    /// Approximation quality metrics (all `None` until computed)
    quality_metrics: QualityMetrics,
}
211
/// Transformation parameters for different approximation methods.
///
/// `N` is the compile-time component count of the owning approximation.
#[derive(Debug, Clone)]
pub enum TransformationParameters<const N: usize> {
    /// Random features for RFF methods
    RandomFeatures {
        /// Projection weights, shape `(N, n_features)`
        weights: Array2<f64>,

        /// Optional phase offsets, length `N`
        biases: Option<Array1<f64>>,
    },
    /// Inducing points and eigendecomposition for Nyström
    Nystrom {
        /// Sampled inducing points, shape `(m, n_features)` with `m <= N`
        inducing_points: Array2<f64>,

        /// Eigenvalues of the inducing-point kernel matrix
        eigenvalues: Array1<f64>,

        /// Corresponding eigenvectors, one per column
        eigenvectors: Array2<f64>,
    },
    /// Structured matrices for Fastfood
    Fastfood {
        /// Square structured matrices applied to the data in order
        structured_matrices: Vec<Array2<f64>>,
        /// Per-component Gaussian scaling factors, length `N`
        scaling: Array1<f64>,
    },
}
236
/// Quality metrics for approximation assessment.
///
/// Every metric is optional: `None` means the metric has not been computed
/// yet. `Default` (now derived — the previous hand-written impl duplicated
/// the all-`None` derive behavior) yields every metric unset.
#[derive(Debug, Clone, Default)]
pub struct QualityMetrics {
    /// Approximation error estimate
    pub approximation_error: Option<f64>,

    /// Effective rank of approximation
    pub effective_rank: Option<f64>,

    /// Condition number
    pub condition_number: Option<f64>,

    /// Kernel alignment score
    pub kernel_alignment: Option<f64>,
}
264
265// Implementation for untrained approximation
266impl<Kernel, Method, const N_COMPONENTS: usize>
267    TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
268where
269    Kernel: KernelType,
270    Method: ApproximationMethod,
271{
272    /// Create a new untrained kernel approximation
273    pub fn new() -> Self {
274        Self {
275            _phantom: PhantomData,
276            parameters: ApproximationParameters {
277                bandwidth: Kernel::DEFAULT_BANDWIDTH,
278                ..Default::default()
279            },
280            random_state: None,
281        }
282    }
283
284    /// Set bandwidth parameter (only available for kernels that support it)
285    pub fn bandwidth(mut self, bandwidth: f64) -> Self
286    where
287        Kernel: KernelTypeWithBandwidth,
288    {
289        self.parameters.bandwidth = bandwidth;
290        self
291    }
292
293    /// Set polynomial degree (only available for polynomial kernels)
294    pub fn degree(mut self, degree: usize) -> Self
295    where
296        Kernel: PolynomialKernelType,
297    {
298        self.parameters.degree = Some(degree);
299        self
300    }
301
302    /// Set random state for reproducibility
303    pub fn random_state(mut self, seed: u64) -> Self {
304        self.random_state = Some(seed);
305        self
306    }
307
308    /// Fit the approximation method (state transition)
309    pub fn fit(
310        self,
311        data: &Array2<f64>,
312    ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
313    where
314        Kernel: FittableKernel<Method>,
315        Method: FittableMethod<Kernel>,
316    {
317        self.fit_impl(data)
318    }
319
320    /// Internal fitting implementation
321    fn fit_impl(
322        self,
323        data: &Array2<f64>,
324    ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
325    where
326        Kernel: FittableKernel<Method>,
327        Method: FittableMethod<Kernel>,
328    {
329        let transformation_params = match Method::NAME {
330            "RandomFourierFeatures" => self.fit_random_fourier_features(data)?,
331            "Nystrom" => self.fit_nystrom(data)?,
332            "Fastfood" => self.fit_fastfood(data)?,
333            _ => {
334                return Err(SklearsError::InvalidOperation(format!(
335                    "Unsupported method: {}",
336                    Method::NAME
337                )));
338            }
339        };
340
341        Ok(FittedTypeSafeKernelApproximation {
342            _phantom: PhantomData,
343            transformation_params,
344            fitted_parameters: self.parameters,
345            quality_metrics: QualityMetrics::default(),
346        })
347    }
348
349    fn fit_random_fourier_features(
350        &self,
351        data: &Array2<f64>,
352    ) -> Result<TransformationParameters<N_COMPONENTS>> {
353        let mut rng = match self.random_state {
354            Some(seed) => RealStdRng::seed_from_u64(seed),
355            None => RealStdRng::from_seed(thread_rng().gen()),
356        };
357
358        let (_, n_features) = data.dim();
359
360        let weights = match Kernel::NAME {
361            "RBF" => {
362                let normal = RandNormal::new(0.0, self.parameters.bandwidth).unwrap();
363                Array2::from_shape_fn((N_COMPONENTS, n_features), |_| rng.sample(normal))
364            }
365            "Laplacian" => {
366                // Laplacian kernel uses Cauchy distribution
367                let uniform = RandUniform::new(0.0, std::f64::consts::PI).unwrap();
368                Array2::from_shape_fn((N_COMPONENTS, n_features), |_| {
369                    let u = rng.sample(uniform);
370                    (u - std::f64::consts::PI / 2.0).tan() / self.parameters.bandwidth
371                })
372            }
373            _ => {
374                return Err(SklearsError::InvalidOperation(format!(
375                    "RFF not supported for kernel: {}",
376                    Kernel::NAME
377                )));
378            }
379        };
380
381        let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
382        let biases = Some(Array1::from_shape_fn(N_COMPONENTS, |_| rng.sample(uniform)));
383
384        Ok(TransformationParameters::RandomFeatures { weights, biases })
385    }
386
387    fn fit_nystrom(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
388        let (n_samples, _) = data.dim();
389        let n_inducing = N_COMPONENTS.min(n_samples);
390
391        // Sample inducing points
392        let mut rng = match self.random_state {
393            Some(seed) => RealStdRng::seed_from_u64(seed),
394            None => RealStdRng::from_seed(thread_rng().gen()),
395        };
396
397        let mut indices: Vec<usize> = (0..n_samples).collect();
398        indices.shuffle(&mut rng);
399
400        let inducing_indices = &indices[..n_inducing];
401        let inducing_points = data.select(scirs2_core::ndarray::Axis(0), inducing_indices);
402
403        // Compute kernel matrix on inducing points
404        let kernel_matrix = self.compute_kernel_matrix(&inducing_points, &inducing_points)?;
405
406        // Simplified eigendecomposition (using diagonal as approximation)
407        let eigenvalues = Array1::from_shape_fn(n_inducing, |i| kernel_matrix[[i, i]]);
408        let eigenvectors = Array2::eye(n_inducing);
409
410        Ok(TransformationParameters::Nystrom {
411            inducing_points,
412            eigenvalues,
413            eigenvectors,
414        })
415    }
416
417    fn fit_fastfood(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
418        let mut rng = match self.random_state {
419            Some(seed) => RealStdRng::seed_from_u64(seed),
420            None => RealStdRng::from_seed(thread_rng().gen()),
421        };
422
423        let (_, n_features) = data.dim();
424
425        // Create structured matrices for Fastfood
426        let mut structured_matrices = Vec::new();
427
428        // Binary matrix
429        let binary_matrix = Array2::from_shape_fn((n_features, n_features), |(i, j)| {
430            if i == j {
431                if rng.gen::<bool>() {
432                    1.0
433                } else {
434                    -1.0
435                }
436            } else {
437                0.0
438            }
439        });
440        structured_matrices.push(binary_matrix);
441
442        // Gaussian scaling
443        let scaling = Array1::from_shape_fn(N_COMPONENTS, |_| {
444            use scirs2_core::random::{RandNormal, Rng};
445            let normal = RandNormal::new(0.0, 1.0).unwrap();
446            rng.sample(normal)
447        });
448
449        Ok(TransformationParameters::Fastfood {
450            structured_matrices,
451            scaling,
452        })
453    }
454
455    fn compute_kernel_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Result<Array2<f64>> {
456        let (n1, _) = x1.dim();
457        let (n2, _) = x2.dim();
458        let mut kernel = Array2::zeros((n1, n2));
459
460        for i in 0..n1 {
461            for j in 0..n2 {
462                let similarity = match Kernel::NAME {
463                    "RBF" => {
464                        let diff = &x1.row(i) - &x2.row(j);
465                        let dist_sq = diff.mapv(|x| x * x).sum();
466                        (-self.parameters.bandwidth * dist_sq).exp()
467                    }
468                    "Laplacian" => {
469                        let diff = &x1.row(i) - &x2.row(j);
470                        let dist = diff.mapv(|x| x.abs()).sum();
471                        (-self.parameters.bandwidth * dist).exp()
472                    }
473                    "Polynomial" => {
474                        let dot_product = x1.row(i).dot(&x2.row(j));
475                        let degree = self.parameters.degree.unwrap_or(2) as i32;
476                        let coef0 = self.parameters.coef0.unwrap_or(1.0);
477                        (self.parameters.bandwidth * dot_product + coef0).powi(degree)
478                    }
479                    _ => {
480                        return Err(SklearsError::InvalidOperation(format!(
481                            "Unsupported kernel: {}",
482                            Kernel::NAME
483                        )));
484                    }
485                };
486                kernel[[i, j]] = similarity;
487            }
488        }
489
490        Ok(kernel)
491    }
492}
493
494// Implementation for fitted approximation
495impl<Kernel, Method, const N_COMPONENTS: usize>
496    FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>
497where
498    Kernel: KernelType,
499    Method: ApproximationMethod,
500{
501    /// Transform data using the fitted approximation
502    pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
503        match &self.transformation_params {
504            TransformationParameters::RandomFeatures { weights, biases } => {
505                self.transform_random_features(data, weights, biases.as_ref())
506            }
507            TransformationParameters::Nystrom {
508                inducing_points,
509                eigenvalues,
510                eigenvectors,
511            } => self.transform_nystrom(data, inducing_points, eigenvalues, eigenvectors),
512            TransformationParameters::Fastfood {
513                structured_matrices,
514                scaling,
515            } => self.transform_fastfood(data, structured_matrices, scaling),
516        }
517    }
518
519    fn transform_random_features(
520        &self,
521        data: &Array2<f64>,
522        weights: &Array2<f64>,
523        biases: Option<&Array1<f64>>,
524    ) -> Result<Array2<f64>> {
525        let (n_samples, _) = data.dim();
526        let mut features = Array2::zeros((n_samples, N_COMPONENTS * 2));
527
528        for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
529            for j in 0..N_COMPONENTS {
530                let projection = sample.dot(&weights.row(j));
531                let phase = if let Some(b) = biases {
532                    projection + b[j]
533                } else {
534                    projection
535                };
536
537                features[[i, 2 * j]] = phase.cos();
538                features[[i, 2 * j + 1]] = phase.sin();
539            }
540        }
541
542        Ok(features)
543    }
544
545    fn transform_nystrom(
546        &self,
547        data: &Array2<f64>,
548        inducing_points: &Array2<f64>,
549        eigenvalues: &Array1<f64>,
550        eigenvectors: &Array2<f64>,
551    ) -> Result<Array2<f64>> {
552        // Compute kernel between data and inducing points
553        let kernel_matrix = self.compute_kernel_matrix_fitted(data, inducing_points)?;
554
555        // Apply eigendecomposition transformation
556        let mut features = Array2::zeros((data.nrows(), eigenvalues.len()));
557
558        for i in 0..data.nrows() {
559            for j in 0..eigenvalues.len() {
560                if eigenvalues[j] > 1e-8 {
561                    let mut feature_value = 0.0;
562                    for k in 0..inducing_points.nrows() {
563                        feature_value += kernel_matrix[[i, k]] * eigenvectors[[k, j]];
564                    }
565                    features[[i, j]] = feature_value / eigenvalues[j].sqrt();
566                }
567            }
568        }
569
570        Ok(features)
571    }
572
573    fn transform_fastfood(
574        &self,
575        data: &Array2<f64>,
576        structured_matrices: &[Array2<f64>],
577        scaling: &Array1<f64>,
578    ) -> Result<Array2<f64>> {
579        let (n_samples, n_features) = data.dim();
580        let mut features = data.clone();
581
582        // Apply structured transformations
583        for matrix in structured_matrices {
584            features = features.dot(matrix);
585        }
586
587        // Apply scaling and take cosine features
588        let mut result = Array2::zeros((n_samples, N_COMPONENTS));
589        for i in 0..n_samples {
590            for j in 0..N_COMPONENTS.min(n_features) {
591                result[[i, j]] = (features[[i, j]] * scaling[j]).cos();
592            }
593        }
594
595        Ok(result)
596    }
597
598    fn compute_kernel_matrix_fitted(
599        &self,
600        x1: &Array2<f64>,
601        x2: &Array2<f64>,
602    ) -> Result<Array2<f64>> {
603        let (n1, _) = x1.dim();
604        let (n2, _) = x2.dim();
605        let mut kernel = Array2::zeros((n1, n2));
606
607        for i in 0..n1 {
608            for j in 0..n2 {
609                let similarity = match Kernel::NAME {
610                    "RBF" => {
611                        let diff = &x1.row(i) - &x2.row(j);
612                        let dist_sq = diff.mapv(|x| x * x).sum();
613                        (-self.fitted_parameters.bandwidth * dist_sq).exp()
614                    }
615                    "Laplacian" => {
616                        let diff = &x1.row(i) - &x2.row(j);
617                        let dist = diff.mapv(|x| x.abs()).sum();
618                        (-self.fitted_parameters.bandwidth * dist).exp()
619                    }
620                    "Polynomial" => {
621                        let dot_product = x1.row(i).dot(&x2.row(j));
622                        let degree = self.fitted_parameters.degree.unwrap_or(2) as i32;
623                        let coef0 = self.fitted_parameters.coef0.unwrap_or(1.0);
624                        (self.fitted_parameters.bandwidth * dot_product + coef0).powi(degree)
625                    }
626                    _ => {
627                        return Err(SklearsError::InvalidOperation(format!(
628                            "Unsupported kernel: {}",
629                            Kernel::NAME
630                        )));
631                    }
632                };
633                kernel[[i, j]] = similarity;
634            }
635        }
636
637        Ok(kernel)
638    }
639
640    /// Get approximation quality metrics
641    pub fn quality_metrics(&self) -> &QualityMetrics {
642        &self.quality_metrics
643    }
644
645    /// Get the number of components (compile-time constant)
646    pub const fn n_components() -> usize {
647        N_COMPONENTS
648    }
649
650    /// Get kernel type name
651    pub fn kernel_name(&self) -> &'static str {
652        Kernel::NAME
653    }
654
655    /// Get approximation method name
656    pub fn method_name(&self) -> &'static str {
657        Method::NAME
658    }
659}
660
// Trait constraints for type safety

/// Marker trait for kernel types that support bandwidth parameters.
/// Gates the `bandwidth` builder method at compile time.
pub trait KernelTypeWithBandwidth: KernelType {}
impl KernelTypeWithBandwidth for RBFKernel {}
impl KernelTypeWithBandwidth for LaplacianKernel {}

/// Marker trait for polynomial kernel types.
/// Gates the `degree` builder method at compile time.
pub trait PolynomialKernelType: KernelType {}
impl PolynomialKernelType for PolynomialKernel {}
671
/// Marker trait for kernels that can be fitted with specific methods.
/// `fit` requires `Kernel: FittableKernel<Method>`, so unsupported
/// combinations fail to compile instead of erroring at runtime.
pub trait FittableKernel<Method: ApproximationMethod>: KernelType {}
impl FittableKernel<RandomFourierFeatures> for RBFKernel {}
impl FittableKernel<RandomFourierFeatures> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for RBFKernel {}
impl FittableKernel<NystromMethod> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for PolynomialKernel {}
impl FittableKernel<FastfoodMethod> for RBFKernel {}

/// Marker trait for methods that can be fitted with specific kernels.
/// Mirror image of `FittableKernel`; the impls below must stay in sync
/// with it, since `fit` requires both bounds.
pub trait FittableMethod<Kernel: KernelType>: ApproximationMethod {}
impl FittableMethod<RBFKernel> for RandomFourierFeatures {}
impl FittableMethod<LaplacianKernel> for RandomFourierFeatures {}
impl FittableMethod<RBFKernel> for NystromMethod {}
impl FittableMethod<LaplacianKernel> for NystromMethod {}
impl FittableMethod<PolynomialKernel> for NystromMethod {}
impl FittableMethod<RBFKernel> for FastfoodMethod {}
689
// Type aliases for common kernel approximation combinations

/// Untrained RBF kernel approximated with random Fourier features.
pub type RBFRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, RandomFourierFeatures, N>;

/// Untrained Laplacian kernel approximated with random Fourier features.
pub type LaplacianRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, LaplacianKernel, RandomFourierFeatures, N>;

/// Untrained RBF kernel approximated with the Nyström method.
pub type RBFNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, NystromMethod, N>;

/// Untrained polynomial kernel approximated with the Nyström method.
pub type PolynomialNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, PolynomialKernel, NystromMethod, N>;

/// Untrained RBF kernel approximated with the Fastfood method.
pub type RBFFastfood<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, FastfoodMethod, N>;

// Fitted type aliases

/// Fitted RBF + random Fourier features approximation.
pub type FittedRBFRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, RandomFourierFeatures, N>;

/// Fitted Laplacian + random Fourier features approximation.
pub type FittedLaplacianRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<LaplacianKernel, RandomFourierFeatures, N>;

/// Fitted RBF + Nyström approximation.
pub type FittedRBFNystrom<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, NystromMethod, N>;
715
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // NOTE: feature values are random draws, so these tests pin shapes and
    // compile-time metadata rather than exact numeric outputs.

    #[test]
    fn test_type_safe_rbf_rff() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];

        // Create untrained approximation with compile-time dimension
        let approximation: RBFRandomFourierFeatures<10> = TypeSafeKernelApproximation::new()
            .bandwidth(1.5)
            .random_state(42);

        // Fit and transform
        let fitted = approximation.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape(), &[4, 20]); // 10 components * 2 (cos, sin)
        assert_eq!(FittedRBFRandomFourierFeatures::<10>::n_components(), 10);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "RandomFourierFeatures");
    }

    #[test]
    fn test_type_safe_nystrom() {
        let data = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ];

        // 5 requested components are capped at the 4 available samples.
        let approximation: RBFNystrom<5> = TypeSafeKernelApproximation::new()
            .bandwidth(2.0)
            .random_state(42);

        let fitted = approximation.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape()[0], 4); // n_samples
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "Nystrom");
    }

    #[test]
    fn test_polynomial_kernel() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        // `degree` is only callable because PolynomialKernel implements
        // PolynomialKernelType.
        let approximation: PolynomialNystrom<3> = TypeSafeKernelApproximation::new()
            .degree(3)
            .random_state(42);

        let fitted = approximation.fit(&data).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(fitted.kernel_name(), "Polynomial");
        assert!(features.nrows() == 3);
    }

    #[test]
    fn test_compile_time_constants() {
        // These should be available at compile time
        assert_eq!(RBFKernel::NAME, "RBF");
        assert!(RBFKernel::SUPPORTS_PARAMETER_LEARNING);
        assert_eq!(RandomFourierFeatures::NAME, "RandomFourierFeatures");
        assert!(RandomFourierFeatures::SUPPORTS_INCREMENTAL);
        assert!(RandomFourierFeatures::HAS_ERROR_BOUNDS);
    }

    // This test demonstrates compile-time type safety
    // The following code would not compile due to type constraints:

    // #[test]
    // fn test_type_safety_violations() {
    //     // This would fail: ArcCosine kernel doesn't support bandwidth
    //     // let _ = TypeSafeKernelApproximation::<Untrained, ArcCosineKernel, RandomFourierFeatures, 10>::new()
    //     //     .bandwidth(1.0); // ERROR: ArcCosineKernel doesn't implement KernelTypeWithBandwidth
    //
    //     // This would fail: RBF kernel with degree parameter
    //     // let _ = TypeSafeKernelApproximation::<Untrained, RBFKernel, RandomFourierFeatures, 10>::new()
    //     //     .degree(2); // ERROR: RBFKernel doesn't implement PolynomialKernelType
    //
    //     // This would fail: Trying to fit incompatible kernel-method combination
    //     // let _ = TypeSafeKernelApproximation::<Untrained, ArcCosineKernel, RandomFourierFeatures, 10>::new()
    //     //     .fit(&data); // ERROR: ArcCosineKernel doesn't implement FittableKernel<RandomFourierFeatures>
    // }
}
805
806// ==================================================================================
807// ADVANCED TYPE SAFETY ENHANCEMENTS - Zero-Cost Abstractions and Compile-Time Validation
808// ==================================================================================
809
/// Advanced compile-time parameter validation traits.
///
/// Implementors pick a `(MIN, MAX)` range through const generics; the
/// well-formedness check and the range accessor come for free as defaults.
pub trait ParameterValidation<const MIN: usize, const MAX: usize> {
    /// `true` when the range is well-formed (`MIN <= MAX`); evaluated at
    /// compile time.
    const IS_VALID: bool = MIN <= MAX;

    /// Return the `(MIN, MAX)` parameter range.
    fn parameter_range() -> (usize, usize) {
        (MIN, MAX)
    }
}
820
/// Zero-cost abstraction for validated component counts.
///
/// The bounds `0 < N <= 10_000` are enforced when `new` (or `default`) is
/// instantiated. The check is a constant evaluated at monomorphization
/// time, so an out-of-range `N` is a compile error rather than the runtime
/// panic the previous `assert!`-in-`new` version produced.
#[derive(Debug, Clone, Copy)]
pub struct ValidatedComponents<const N: usize>;

impl<const N: usize> ValidatedComponents<N> {
    // Referencing this constant in `new` forces its evaluation during
    // monomorphization, turning the bounds check into a compile-time error.
    const VALID: () = assert!(N > 0 && N <= 10_000, "component count must be in 1..=10000");

    /// Create validated components; compiles only for in-range `N`.
    pub fn new() -> Self {
        // Force the compile-time bounds check above.
        let _ = Self::VALID;
        Self
    }

    /// Get the component count
    pub const fn count(&self) -> usize {
        N
    }
}

impl<const N: usize> Default for ValidatedComponents<N> {
    /// Same validation path as [`ValidatedComponents::new`].
    fn default() -> Self {
        Self::new()
    }
}
840
/// Compile-time compatibility checking between kernels and methods.
///
/// Implemented on `()` for each concrete `(K, M)` pair below; the
/// constants describe whether the pair is supported and its expected cost
/// profile, queryable without constructing anything.
pub trait KernelMethodCompatibility<K: KernelType, M: ApproximationMethod> {
    /// Whether this kernel-method combination is supported
    const IS_COMPATIBLE: bool;

    /// Performance characteristics of this combination
    const PERFORMANCE_TIER: PerformanceTier;

    /// Memory complexity
    const MEMORY_COMPLEXITY: ComplexityClass;
}
852
/// Performance tiers for different kernel-method combinations.
///
/// Ordered from best to worst; used by the
/// `KernelMethodCompatibility` constants to advertise how well a given
/// kernel/method pairing is expected to perform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceTier {
    /// Optimal performance combination
    Optimal,
    /// Good performance
    Good,
    /// Acceptable performance
    Acceptable,
    /// Poor performance (not recommended)
    Poor,
}
866
// Implement compatibility rules, one impl on `()` per kernel/method pair.
impl KernelMethodCompatibility<RBFKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<LaplacianKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<RBFKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<RBFKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

// Arc-cosine kernels have limited compatibility.
// NOTE(review): RFF is marked compatible here, but ArcCosineKernel has no
// FittableKernel<RandomFourierFeatures> impl, so it cannot actually be
// fitted — confirm which of the two is intended.
impl KernelMethodCompatibility<ArcCosineKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Acceptable;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<ArcCosineKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = false; // Not recommended
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<ArcCosineKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = false; // Not supported
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}
916
917// Additional compatibility implementations for missing combinations
918impl KernelMethodCompatibility<LaplacianKernel, NystromMethod> for () {
919    const IS_COMPATIBLE: bool = true;
920    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
921    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
922}
923
924impl KernelMethodCompatibility<PolynomialKernel, RandomFourierFeatures> for () {
925    const IS_COMPATIBLE: bool = true;
926    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
927    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
928}
929
/// Zero-cost wrapper for compile-time validated approximations
///
/// The where-clause requires a `KernelMethodCompatibility` impl for the
/// `(K, M)` pair, so unsupported combinations cannot even be named.
#[derive(Debug, Clone)]
/// ValidatedKernelApproximation
pub struct ValidatedKernelApproximation<K, M, const N: usize>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    // The underlying (untrained) approximation this wrapper validates.
    inner: TypeSafeKernelApproximation<Untrained, K, M, N>,
    // Zero-sized proof that the component count `N` passed its bounds check.
    _validation: ValidatedComponents<N>,
}
942
943impl<K, M, const N: usize> ValidatedKernelApproximation<K, M, N>
944where
945    K: KernelType,
946    M: ApproximationMethod,
947    (): KernelMethodCompatibility<K, M>,
948{
949    /// Create a new validated kernel approximation with compile-time checks
950    pub fn new() -> Self {
951        // Compile-time compatibility check
952        assert!(
953            <() as KernelMethodCompatibility<K, M>>::IS_COMPATIBLE,
954            "Incompatible kernel-method combination"
955        );
956
957        Self {
958            inner: TypeSafeKernelApproximation::new(),
959            _validation: ValidatedComponents::new(),
960        }
961    }
962
963    /// Get performance information at compile time
964    pub const fn performance_info() -> (PerformanceTier, ComplexityClass, ComplexityClass) {
965        (
966            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
967            M::COMPLEXITY,
968            <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY,
969        )
970    }
971
972    /// Check if this combination is optimal
973    pub const fn is_optimal() -> bool {
974        matches!(
975            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
976            PerformanceTier::Optimal
977        )
978    }
979}
980
/// Enhanced quality metrics with compile-time bounds
///
/// `MIN_ALIGNMENT` and `MAX_ERROR` are whole-number percentages; e.g.
/// `BoundedQualityMetrics<90, 5>` requires alignment >= 0.90 and error <= 0.05.
#[derive(Debug, Clone, Copy)]
/// BoundedQualityMetrics
pub struct BoundedQualityMetrics<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32> {
    kernel_alignment: f64,
    approximation_error: f64,
    effective_rank: f64,
}

impl<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32>
    BoundedQualityMetrics<MIN_ALIGNMENT, MAX_ERROR>
{
    /// Create new bounded quality metrics with compile-time validation
    ///
    /// Returns `None` when the alignment or error fall outside the const
    /// bounds, or when the effective rank is not strictly positive.
    pub const fn new(alignment: f64, error: f64, rank: f64) -> Option<Self> {
        let (min_align, max_err) = Self::bounds();

        // Negating the acceptance predicate (rather than flipping each
        // comparison) keeps NaN inputs rejected exactly as before.
        if !(alignment >= min_align && error <= max_err && rank > 0.0) {
            return None;
        }

        Some(Self {
            kernel_alignment: alignment,
            approximation_error: error,
            effective_rank: rank,
        })
    }

    /// Get compile-time bounds
    ///
    /// Returns `(min_alignment, max_error)` as fractions in `[0, 1]`.
    pub const fn bounds() -> (f64, f64) {
        let min_align = MIN_ALIGNMENT as f64 / 100.0; // Convert percentage to decimal
        let max_err = MAX_ERROR as f64 / 100.0;
        (min_align, max_err)
    }

    /// Check if metrics meet quality standards
    pub const fn meets_standards(&self) -> bool {
        let (min_align, max_err) = Self::bounds();
        min_align <= self.kernel_alignment && max_err >= self.approximation_error
    }
}
1020
/// Type alias for high-quality metrics (>=90% alignment, <=5% error)
pub type HighQualityMetrics = BoundedQualityMetrics<90, 5>;

/// Type alias for acceptable quality metrics (>=70% alignment, <=15% error)
pub type AcceptableQualityMetrics = BoundedQualityMetrics<70, 15>;
1026
/// Macro for easy creation of validated kernel approximations
///
/// Each arm expands to `ValidatedKernelApproximation::<K, M, N>::new()` for a
/// kernel/method pair that has a compatible `KernelMethodCompatibility` impl.
/// Arms now cover every pair the compatibility table marks as supported
/// (previously Laplacian/Nystrom, Polynomial/*, and ArcCosine/RFF were missing).
#[macro_export]
macro_rules! validated_kernel_approximation {
    (RBF, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, RandomFourierFeatures, $n>::new()
    };
    (Laplacian, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, RandomFourierFeatures, $n>::new()
    };
    (RBF, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, NystromMethod, $n>::new()
    };
    (RBF, Fastfood, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, FastfoodMethod, $n>::new()
    };
    (Laplacian, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, NystromMethod, $n>::new()
    };
    (Polynomial, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<PolynomialKernel, RandomFourierFeatures, $n>::new()
    };
    (Polynomial, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<PolynomialKernel, NystromMethod, $n>::new()
    };
    (ArcCosine, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<ArcCosineKernel, RandomFourierFeatures, $n>::new()
    };
}
1043
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_type_safety_tests {
    use super::*;

    /// `count()` must surface the const parameter unchanged.
    #[test]
    fn test_validated_components() {
        let components = ValidatedComponents::<100>::new();
        assert_eq!(components.count(), 100);
    }

    /// The compatibility table's consts are readable and match expectations.
    #[test]
    fn test_kernel_method_compatibility() {
        // Test compile-time compatibility checks
        assert!(<() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::IS_COMPATIBLE);
        assert!(!<() as KernelMethodCompatibility<ArcCosineKernel, NystromMethod>>::IS_COMPATIBLE);

        // Test performance tiers
        assert_eq!(
            <() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        );
    }

    /// `new` filters metrics against the const bounds (90%/5% here).
    #[test]
    fn test_bounded_quality_metrics() {
        // High quality metrics
        let high_quality = HighQualityMetrics::new(0.95, 0.03, 50.0).unwrap();
        assert!(high_quality.meets_standards());

        // Low quality metrics should fail bounds check
        let low_quality = HighQualityMetrics::new(0.60, 0.20, 10.0);
        assert!(low_quality.is_none());
    }

    /// The macro expands for every supported combination it has an arm for.
    #[test]
    fn test_macro_creation() {
        // Test that the macro compiles for valid combinations
        let _rbf_rff = validated_kernel_approximation!(RBF, RandomFourierFeatures, 100);
        let _lap_rff = validated_kernel_approximation!(Laplacian, RandomFourierFeatures, 50);
        let _rbf_nys = validated_kernel_approximation!(RBF, Nystrom, 30);
        let _rbf_ff = validated_kernel_approximation!(RBF, Fastfood, 128);

        // Performance info should be available at compile time
        let (tier, complexity, memory) = ValidatedKernelApproximation::<
            RBFKernel,
            RandomFourierFeatures,
            100,
        >::performance_info();
        assert_eq!(tier, PerformanceTier::Optimal);
        assert_eq!(complexity, ComplexityClass::Linear);
        assert_eq!(memory, ComplexityClass::Linear);
    }
}
1098
// ============================================================================
// Advanced Zero-Cost Kernel Composition Abstractions
//
// These traits enable compile-time composition of kernel operations
// with zero runtime overhead.
// ============================================================================

/// Trait for kernels that can be composed
pub trait ComposableKernel: KernelType {
    /// Kernel type produced when composing `Self` with `Other`; a generic
    /// associated type, so the composed kernel is resolved at compile time.
    type CompositionResult<Other: ComposableKernel>: ComposableKernel;

    /// Combine this kernel with another kernel
    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other>;
}
1111
/// Sum composition of two kernels
#[derive(Debug, Clone, Copy)]
/// SumKernel
pub struct SumKernel<K1: KernelType, K2: KernelType> {
    // Zero-sized marker: the composed kernel carries no runtime state.
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for SumKernel<K1, K2> {
    const NAME: &'static str = "Sum";
    // Learnable only when both constituents are learnable.
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Heuristic default: arithmetic mean of the constituents' bandwidths.
    const DEFAULT_BANDWIDTH: f64 = (K1::DEFAULT_BANDWIDTH + K2::DEFAULT_BANDWIDTH) / 2.0;
}

/// Product composition of two kernels
#[derive(Debug, Clone, Copy)]
/// ProductKernel
pub struct ProductKernel<K1: KernelType, K2: KernelType> {
    // Zero-sized marker: the composed kernel carries no runtime state.
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for ProductKernel<K1, K2> {
    const NAME: &'static str = "Product";
    // Learnable only when both constituents are learnable.
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Heuristic default: product of the constituents' bandwidths.
    const DEFAULT_BANDWIDTH: f64 = K1::DEFAULT_BANDWIDTH * K2::DEFAULT_BANDWIDTH;
}
1139
// RBF and Laplacian kernels compose additively (yielding `SumKernel`), while
// the polynomial kernel composes multiplicatively (yielding `ProductKernel`).
impl ComposableKernel for RBFKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for LaplacianKernel {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for PolynomialKernel {
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1169
/// Compile-time feature size validation
///
/// Blanket-implemented for `()` over every `N`; the associated consts
/// classify the size but do not by themselves reject anything.
pub trait ValidatedFeatureSize<const N: usize> {
    /// True when `N` is a non-zero power of two (the layout Fastfood-style
    /// transforms prefer). Uses the std const helper instead of the manual
    /// `(N != 0) && ((N & (N - 1)) == 0)` bit trick.
    const IS_POWER_OF_TWO: bool = N.is_power_of_two();
    /// True when `N` lies in the supported range `[8, 8192]`.
    const IS_REASONABLE_SIZE: bool = N >= 8 && N <= 8192;
    /// True when both of the above hold.
    const IS_VALID: bool = Self::IS_POWER_OF_TWO && Self::IS_REASONABLE_SIZE;
}

// Implement for all const generic sizes
impl<const N: usize> ValidatedFeatureSize<N> for () {}

/// Zero-cost wrapper for validated feature dimensions
///
/// NOTE(review): despite the name, `new` does not currently enforce
/// `IS_VALID` — the where-bound is satisfied for every `N` by the blanket
/// impl above. Callers must check `IS_VALID`/`is_fastfood_optimal`
/// themselves; confirm whether enforcement should be added.
#[derive(Debug, Clone)]
/// ValidatedFeatures
pub struct ValidatedFeatures<const N: usize> {
    // Zero-sized marker tying the type to its const dimension.
    _phantom: PhantomData<[f64; N]>,
}

impl<const N: usize> ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    /// Create new validated features (compile-time checked)
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }

    /// Get the validated feature count
    pub const fn count() -> usize {
        N
    }

    /// Check if size is optimal for Fastfood
    pub const fn is_fastfood_optimal() -> bool {
        <() as ValidatedFeatureSize<N>>::IS_POWER_OF_TWO
    }
}
1208
1209/// Advanced approximation quality bounds
1210pub trait ApproximationQualityBounds<Method: ApproximationMethod> {
1211    /// Theoretical error bound for this method
1212    const ERROR_BOUND_CONSTANT: f64;
1213
1214    /// Sample complexity scaling
1215    const SAMPLE_COMPLEXITY_EXPONENT: f64;
1216
1217    /// Dimension dependency
1218    const DIMENSION_DEPENDENCY: f64;
1219
1220    /// Compute theoretical error bound
1221    fn error_bound(n_samples: usize, n_features: usize, n_components: usize) -> f64 {
1222        let base_rate = Self::ERROR_BOUND_CONSTANT;
1223        let sample_factor = (n_samples as f64).powf(-Self::SAMPLE_COMPLEXITY_EXPONENT);
1224        let dim_factor = (n_features as f64).powf(Self::DIMENSION_DEPENDENCY);
1225        let comp_factor = (n_components as f64).powf(-0.5);
1226
1227        base_rate * sample_factor * dim_factor * comp_factor
1228    }
1229}
1230
// Per-method bound constants. NOTE(review): these look like heuristic
// settings rather than published theoretical rates — confirm against the
// literature before relying on them quantitatively.
impl ApproximationQualityBounds<RandomFourierFeatures> for () {
    const ERROR_BOUND_CONSTANT: f64 = 2.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.25;
    const DIMENSION_DEPENDENCY: f64 = 0.1;
}

impl ApproximationQualityBounds<NystromMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 1.5;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.33;
    const DIMENSION_DEPENDENCY: f64 = 0.05;
}

impl ApproximationQualityBounds<FastfoodMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 3.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.2;
    const DIMENSION_DEPENDENCY: f64 = 0.15;
}
1248
/// Advanced type-safe kernel configuration builder
///
/// The where-clauses require validated feature sizes, a compatibility entry
/// for `(K, M)`, and quality bounds for `M`, all checked at compile time.
#[derive(Debug, Clone)]
/// TypeSafeKernelConfig
pub struct TypeSafeKernelConfig<K: KernelType, M: ApproximationMethod, const N: usize>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    // Zero-sized markers recording the kernel and method type parameters.
    kernel_type: PhantomData<K>,
    method_type: PhantomData<M>,
    // Marker for the validated feature dimension `N`.
    features: ValidatedFeatures<N>,
    // Kernel bandwidth/gamma; seeded from `K::DEFAULT_BANDWIDTH`.
    bandwidth: f64,
    // Maximum acceptable theoretical error bound, a fraction in (0, 1).
    quality_threshold: f64,
}
1264
1265impl<K: KernelType, M: ApproximationMethod, const N: usize> TypeSafeKernelConfig<K, M, N>
1266where
1267    (): ValidatedFeatureSize<N>,
1268    (): KernelMethodCompatibility<K, M>,
1269    (): ApproximationQualityBounds<M>,
1270{
1271    /// Create a new type-safe configuration
1272    pub fn new() -> Self {
1273        Self {
1274            kernel_type: PhantomData,
1275            method_type: PhantomData,
1276            features: ValidatedFeatures::new(),
1277            bandwidth: K::DEFAULT_BANDWIDTH,
1278            quality_threshold: 0.1,
1279        }
1280    }
1281
1282    /// Set bandwidth parameter with compile-time validation
1283    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
1284        assert!(bandwidth > 0.0, "Bandwidth must be positive");
1285        self.bandwidth = bandwidth;
1286        self
1287    }
1288
1289    /// Set quality threshold with bounds checking
1290    pub fn quality_threshold(mut self, threshold: f64) -> Self {
1291        assert!(
1292            threshold > 0.0 && threshold < 1.0,
1293            "Quality threshold must be between 0 and 1"
1294        );
1295        self.quality_threshold = threshold;
1296        self
1297    }
1298
1299    /// Get performance tier for this configuration
1300    pub const fn performance_tier() -> PerformanceTier {
1301        <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER
1302    }
1303
1304    /// Get memory complexity for this configuration
1305    pub const fn memory_complexity() -> ComplexityClass {
1306        <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY
1307    }
1308
1309    /// Compute theoretical error bound for given data size
1310    pub fn theoretical_error_bound(&self, n_samples: usize, n_features: usize) -> f64 {
1311        <() as ApproximationQualityBounds<M>>::error_bound(n_samples, n_features, N)
1312    }
1313
1314    /// Check if configuration meets quality requirements
1315    pub fn meets_quality_requirements(&self, n_samples: usize, n_features: usize) -> bool {
1316        let error_bound = self.theoretical_error_bound(n_samples, n_features);
1317        error_bound <= self.quality_threshold
1318    }
1319}
1320
/// Type aliases for common validated configurations
/// RBF kernel approximated with random Fourier features.
pub type ValidatedRBFRandomFourier<const N: usize> =
    TypeSafeKernelConfig<RBFKernel, RandomFourierFeatures, N>;
/// Laplacian kernel approximated with the Nystrom method.
pub type ValidatedLaplacianNystrom<const N: usize> =
    TypeSafeKernelConfig<LaplacianKernel, NystromMethod, N>;
/// Polynomial kernel approximated with random Fourier features.
pub type ValidatedPolynomialRFF<const N: usize> =
    TypeSafeKernelConfig<PolynomialKernel, RandomFourierFeatures, N>;
1328
// Composed kernels themselves compose: sums stay sums and products stay
// products, so arbitrarily nested compositions are resolved at compile time.
impl<K1: KernelType, K2: KernelType> ComposableKernel for SumKernel<K1, K2> {
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl<K1: KernelType, K2: KernelType> ComposableKernel for ProductKernel<K1, K2> {
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1348
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_composition_tests {
    use super::*;

    /// Power-of-two sizes in `[8, 8192]` satisfy `ValidatedFeatureSize`.
    #[test]
    fn test_validated_features() {
        // These should compile
        let _features_64 = ValidatedFeatures::<64>::new();
        let _features_128 = ValidatedFeatures::<128>::new();
        let _features_256 = ValidatedFeatures::<256>::new();

        assert_eq!(ValidatedFeatures::<64>::count(), 64);
        assert!(ValidatedFeatures::<64>::is_fastfood_optimal());
    }

    /// Builder setters and compile-time tier/quality queries work together.
    #[test]
    fn test_type_safe_configs() {
        let config = ValidatedRBFRandomFourier::<128>::new()
            .bandwidth(1.5)
            .quality_threshold(0.05);

        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert!(config.meets_quality_requirements(1000, 10));
    }

    /// Composition is purely a type-level operation; this just needs to compile.
    #[test]
    fn test_kernel_composition() {
        let _rbf = RBFKernel;
        let _laplacian = LaplacianKernel;
        let _polynomial = PolynomialKernel;

        // These should compile and create valid composed kernels
        let _composed1 = _rbf.compose::<LaplacianKernel>();
        let _composed2 = _polynomial.compose::<RBFKernel>();
    }

    /// Error bounds are positive and ordered as the constants imply.
    #[test]
    fn test_approximation_bounds() {
        let rff_bound =
            <() as ApproximationQualityBounds<RandomFourierFeatures>>::error_bound(1000, 10, 100);
        let nystrom_bound =
            <() as ApproximationQualityBounds<NystromMethod>>::error_bound(1000, 10, 100);
        let fastfood_bound =
            <() as ApproximationQualityBounds<FastfoodMethod>>::error_bound(1000, 10, 100);

        assert!(rff_bound > 0.0);
        assert!(nystrom_bound > 0.0);
        assert!(fastfood_bound > 0.0);

        // Nyström should generally have tighter bounds than RFF
        assert!(nystrom_bound < rff_bound);
    }
}
1406
1407// ============================================================================
1408// Configuration Presets for Common Use Cases
1409// ============================================================================
1410
1411/// Configuration presets for common kernel approximation scenarios
1412pub struct KernelPresets;
1413
1414impl KernelPresets {
1415    /// Fast approximation preset - prioritizes speed over accuracy
1416    pub fn fast_rbf_128() -> ValidatedRBFRandomFourier<128> {
1417        ValidatedRBFRandomFourier::<128>::new()
1418            .bandwidth(1.0)
1419            .quality_threshold(0.2)
1420    }
1421
1422    /// Balanced approximation preset - good trade-off between speed and accuracy
1423    pub fn balanced_rbf_256() -> ValidatedRBFRandomFourier<256> {
1424        ValidatedRBFRandomFourier::<256>::new()
1425            .bandwidth(1.0)
1426            .quality_threshold(0.1)
1427    }
1428
1429    /// High-accuracy approximation preset - prioritizes accuracy over speed
1430    pub fn accurate_rbf_512() -> ValidatedRBFRandomFourier<512> {
1431        ValidatedRBFRandomFourier::<512>::new()
1432            .bandwidth(1.0)
1433            .quality_threshold(0.05)
1434    }
1435
1436    /// Ultra-fast approximation for large-scale problems
1437    pub fn ultrafast_rbf_64() -> ValidatedRBFRandomFourier<64> {
1438        ValidatedRBFRandomFourier::<64>::new()
1439            .bandwidth(1.0)
1440            .quality_threshold(0.3)
1441    }
1442
1443    /// High-precision Nyström preset for small to medium datasets
1444    pub fn precise_nystroem_128() -> ValidatedLaplacianNystrom<128> {
1445        ValidatedLaplacianNystrom::<128>::new()
1446            .bandwidth(1.0)
1447            .quality_threshold(0.01)
1448    }
1449
1450    /// Memory-efficient preset for resource-constrained environments
1451    pub fn memory_efficient_rbf_32() -> ValidatedRBFRandomFourier<32> {
1452        ValidatedRBFRandomFourier::<32>::new()
1453            .bandwidth(1.0)
1454            .quality_threshold(0.4)
1455    }
1456
1457    /// Polynomial kernel preset for structured data
1458    pub fn polynomial_features_256() -> ValidatedPolynomialRFF<256> {
1459        ValidatedPolynomialRFF::<256>::new()
1460            .bandwidth(1.0)
1461            .quality_threshold(0.15)
1462    }
1463}
1464
1465// ============================================================================
1466// Profile-Guided Optimization Support
1467// ============================================================================
1468
/// Profile-guided optimization configuration
#[derive(Debug, Clone)]
/// ProfileGuidedConfig
pub struct ProfileGuidedConfig {
    /// Enable PGO-based feature size selection
    pub enable_pgo_feature_selection: bool,

    /// Enable PGO-based bandwidth optimization
    pub enable_pgo_bandwidth_optimization: bool,

    /// Profile data file path
    pub profile_data_path: Option<String>,

    /// Target hardware architecture
    pub target_architecture: TargetArchitecture,

    /// Optimization level
    pub optimization_level: OptimizationLevel,
}

/// Target hardware architecture for optimization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// TargetArchitecture
pub enum TargetArchitecture {
    /// Generic x86_64 architecture
    X86_64Generic,
    /// x86_64 with AVX2 support
    X86_64AVX2,
    /// x86_64 with AVX-512 support
    X86_64AVX512,
    /// ARM64 architecture
    ARM64,
    /// ARM64 with NEON support
    ARM64NEON,
}

/// Optimization level for profile-guided optimization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// OptimizationLevel
pub enum OptimizationLevel {
    /// No optimization
    None,
    /// Basic optimizations
    Basic,
    /// Aggressive optimizations
    Aggressive,
    /// Maximum optimizations (may increase compile time significantly)
    Maximum,
}

impl Default for ProfileGuidedConfig {
    /// Conservative defaults: PGO features off, generic x86_64, basic level.
    fn default() -> Self {
        Self {
            enable_pgo_feature_selection: false,
            enable_pgo_bandwidth_optimization: false,
            profile_data_path: None,
            target_architecture: TargetArchitecture::X86_64Generic,
            optimization_level: OptimizationLevel::Basic,
        }
    }
}

impl ProfileGuidedConfig {
    /// Create a new PGO configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable PGO-based feature size selection
    pub fn enable_feature_selection(mut self) -> Self {
        self.enable_pgo_feature_selection = true;
        self
    }

    /// Enable PGO-based bandwidth optimization
    pub fn enable_bandwidth_optimization(mut self) -> Self {
        self.enable_pgo_bandwidth_optimization = true;
        self
    }

    /// Set profile data file path
    pub fn profile_data_path<P: Into<String>>(mut self, path: P) -> Self {
        self.profile_data_path = Some(path.into());
        self
    }

    /// Set target architecture
    pub fn target_architecture(mut self, arch: TargetArchitecture) -> Self {
        self.target_architecture = arch;
        self
    }

    /// Set optimization level
    pub fn optimization_level(mut self, level: OptimizationLevel) -> Self {
        self.optimization_level = level;
        self
    }

    /// Get recommended feature count based on architecture and optimization level
    ///
    /// Heuristic: a per-architecture base count is scaled by the optimization
    /// level, nudged by data size and dimensionality, then clamped to
    /// `[32, 1024]`.
    pub fn recommended_feature_count(&self, data_size: usize, dimensionality: usize) -> usize {
        // Wider SIMD targets get larger base feature counts.
        let base_features = match self.target_architecture {
            TargetArchitecture::X86_64Generic => 128,
            TargetArchitecture::X86_64AVX2 => 256,
            TargetArchitecture::X86_64AVX512 => 512,
            TargetArchitecture::ARM64 => 128,
            TargetArchitecture::ARM64NEON => 256,
        };

        let scale_factor = match self.optimization_level {
            OptimizationLevel::None => 0.5,
            OptimizationLevel::Basic => 1.0,
            OptimizationLevel::Aggressive => 1.5,
            OptimizationLevel::Maximum => 2.0,
        };

        let scaled_features = (base_features as f64 * scale_factor) as usize;

        // Adjust based on data characteristics
        let data_adjustment = if data_size > 10000 {
            1.2
        } else if data_size < 1000 {
            0.8
        } else {
            1.0
        };

        let dimension_adjustment = if dimensionality > 100 {
            1.1
        } else if dimensionality < 10 {
            0.9
        } else {
            1.0
        };

        // `clamp` replaces the previous `.max(32).min(1024)` chain (same result).
        ((scaled_features as f64 * data_adjustment * dimension_adjustment) as usize)
            .clamp(32, 1024)
    }
}
1608
1609// ============================================================================
1610// Serialization Support for Approximation Models
1611// ============================================================================
1612
/// Serializable kernel approximation configuration
///
/// Plain-data mirror of the type-level configuration, suitable for JSON
/// round-tripping via serde (type parameters are flattened into strings).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// SerializableKernelConfig
pub struct SerializableKernelConfig {
    /// Kernel type name
    pub kernel_type: String,

    /// Approximation method name
    pub approximation_method: String,

    /// Number of features/components
    pub n_components: usize,

    /// Bandwidth parameter
    pub bandwidth: f64,

    /// Quality threshold
    pub quality_threshold: f64,

    /// Random state for reproducibility
    pub random_state: Option<u64>,

    /// Additional parameters
    pub additional_params: std::collections::HashMap<String, f64>,
}
1638
/// Serializable fitted model parameters
///
/// Method-specific fields are optional: RFF models populate
/// `random_features`; Nyström models populate the indices/eigen fields.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// SerializableFittedParams
pub struct SerializableFittedParams {
    /// Configuration used to create this model
    pub config: SerializableKernelConfig,

    /// Random features (for RFF methods)
    pub random_features: Option<Vec<Vec<f64>>>,

    /// Selected indices (for Nyström methods)
    pub selected_indices: Option<Vec<usize>>,

    /// Eigenvalues (for Nyström methods)
    pub eigenvalues: Option<Vec<f64>>,

    /// Eigenvectors (for Nyström methods)
    pub eigenvectors: Option<Vec<Vec<f64>>>,

    /// Quality metrics achieved
    pub quality_metrics: std::collections::HashMap<String, f64>,

    /// Timestamp of when model was fitted
    pub fitted_timestamp: Option<u64>,
}
1664
1665/// Trait for serializable kernel approximation methods
1666pub trait SerializableKernelApproximation {
1667    /// Export configuration to serializable format
1668    fn export_config(&self) -> Result<SerializableKernelConfig>;
1669
1670    /// Import configuration from serializable format
1671    fn import_config(config: &SerializableKernelConfig) -> Result<Self>
1672    where
1673        Self: Sized;
1674
1675    /// Export fitted parameters to serializable format
1676    fn export_fitted_params(&self) -> Result<SerializableFittedParams>;
1677
1678    /// Import fitted parameters from serializable format
1679    fn import_fitted_params(&mut self, params: &SerializableFittedParams) -> Result<()>;
1680
1681    /// Save model to file
1682    fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
1683        let config = self.export_config()?;
1684        let fitted_params = self.export_fitted_params()?;
1685
1686        let model_data = serde_json::json!({
1687            "config": config,
1688            "fitted_params": fitted_params,
1689            "version": "1.0",
1690            "sklears_version": env!("CARGO_PKG_VERSION"),
1691        });
1692
1693        std::fs::write(path, serde_json::to_string_pretty(&model_data).unwrap())
1694            .map_err(SklearsError::from)?;
1695
1696        Ok(())
1697    }
1698
1699    /// Load model from file
1700    fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self>
1701    where
1702        Self: Sized,
1703    {
1704        let content = std::fs::read_to_string(path).map_err(SklearsError::from)?;
1705
1706        let model_data: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
1707            SklearsError::SerializationError(format!("Failed to parse JSON: {}", e))
1708        })?;
1709
1710        let config: SerializableKernelConfig = serde_json::from_value(model_data["config"].clone())
1711            .map_err(|e| {
1712                SklearsError::SerializationError(format!("Failed to deserialize config: {}", e))
1713            })?;
1714
1715        let fitted_params: SerializableFittedParams =
1716            serde_json::from_value(model_data["fitted_params"].clone()).map_err(|e| {
1717                SklearsError::SerializationError(format!(
1718                    "Failed to deserialize fitted params: {}",
1719                    e
1720                ))
1721            })?;
1722
1723        let mut model = Self::import_config(&config)?;
1724        model.import_fitted_params(&fitted_params)?;
1725
1726        Ok(model)
1727    }
1728}
1729
#[cfg(test)]
mod preset_tests {
    use super::*;

    /// Preset constructors must succeed, and the corresponding validated
    /// feature counts (powers of two) must land in the optimal tier.
    #[test]
    fn test_kernel_presets() {
        // Bind with a leading underscore: the configs are exercised only for
        // "constructor does not panic"; the tier assertion is the real check.
        let _fast_config = KernelPresets::fast_rbf_128();
        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _balanced_config = KernelPresets::balanced_rbf_256();
        assert_eq!(
            ValidatedRBFRandomFourier::<256>::performance_tier(),
            PerformanceTier::Optimal
        );

        let _accurate_config = KernelPresets::accurate_rbf_512();
        assert_eq!(
            ValidatedRBFRandomFourier::<512>::performance_tier(),
            PerformanceTier::Optimal
        );
    }

    /// Builder-style PGO configuration should record every requested option
    /// and recommend a feature count within the supported range.
    #[test]
    fn test_profile_guided_config() {
        let pgo_config = ProfileGuidedConfig::new()
            .enable_feature_selection()
            .enable_bandwidth_optimization()
            .target_architecture(TargetArchitecture::X86_64AVX2)
            .optimization_level(OptimizationLevel::Aggressive);

        assert!(pgo_config.enable_pgo_feature_selection);
        assert!(pgo_config.enable_pgo_bandwidth_optimization);
        assert_eq!(
            pgo_config.target_architecture,
            TargetArchitecture::X86_64AVX2
        );
        assert_eq!(pgo_config.optimization_level, OptimizationLevel::Aggressive);

        // Recommended feature count must stay within [32, 1024].
        let features = pgo_config.recommended_feature_count(5000, 50);
        assert!((32..=1024).contains(&features));
    }

    /// A config must survive a JSON round-trip with key fields unchanged.
    #[test]
    fn test_serializable_config() {
        let config = SerializableKernelConfig {
            kernel_type: "RBF".to_string(),
            approximation_method: "RandomFourierFeatures".to_string(),
            n_components: 256,
            bandwidth: 1.5,
            quality_threshold: 0.1,
            random_state: Some(42),
            additional_params: std::collections::HashMap::new(),
        };

        // Serialize and spot-check that the identifying strings are present.
        let serialized = serde_json::to_string(&config).unwrap();
        assert!(serialized.contains("RBF"));
        assert!(serialized.contains("RandomFourierFeatures"));

        // Deserialize and verify the round-tripped fields.
        let deserialized: SerializableKernelConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.kernel_type, "RBF");
        assert_eq!(deserialized.n_components, 256);
        assert_eq!(deserialized.bandwidth, 1.5);
    }
}