Skip to main content

sklears_kernel_approximation/
type_safety.rs

1//! Type Safety Enhancements for Kernel Approximation Methods
2//!
3//! This module provides compile-time type safety using phantom types and const generics
4//! to prevent common errors in kernel approximation usage.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::RngExt;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::error::{Result, SklearsError};
13use std::marker::PhantomData;
14
/// Phantom type to represent the state of a kernel approximation method.
///
/// Implemented only by [`Untrained`] and [`Trained`]; used as a type-state
/// marker so `transform` is only reachable on a fitted value.
pub trait ApproximationState {}

/// Untrained state - method hasn't been fitted yet.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;
impl ApproximationState for Untrained {}

/// Trained state - method has been fitted and can transform data.
#[derive(Debug, Clone, Copy)]
pub struct Trained;
impl ApproximationState for Trained {}
29
/// Phantom type to represent kernel types.
///
/// The associated constants describe the kernel at compile time; `NAME`
/// is also used for runtime dispatch inside the fitting routines.
pub trait KernelType {
    /// Name of the kernel type (matched on by `fit_impl`/`compute_kernel_matrix`).
    const NAME: &'static str;

    /// Whether this kernel type supports parameter learning.
    const SUPPORTS_PARAMETER_LEARNING: bool;

    /// Default bandwidth/gamma parameter used when constructing an
    /// untrained approximation.
    const DEFAULT_BANDWIDTH: f64;
}
41
/// RBF (Gaussian) kernel type.
#[derive(Debug, Clone, Copy)]
pub struct RBFKernel;
impl KernelType for RBFKernel {
    const NAME: &'static str = "RBF";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Laplacian kernel type.
#[derive(Debug, Clone, Copy)]
pub struct LaplacianKernel;
impl KernelType for LaplacianKernel {
    const NAME: &'static str = "Laplacian";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Polynomial kernel type.
#[derive(Debug, Clone, Copy)]
pub struct PolynomialKernel;
impl KernelType for PolynomialKernel {
    const NAME: &'static str = "Polynomial";
    const SUPPORTS_PARAMETER_LEARNING: bool = true;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}

/// Arc-cosine kernel type.
///
/// Unlike the other kernels in this module it does not support
/// parameter learning.
#[derive(Debug, Clone, Copy)]
pub struct ArcCosineKernel;
impl KernelType for ArcCosineKernel {
    const NAME: &'static str = "ArcCosine";
    const SUPPORTS_PARAMETER_LEARNING: bool = false;
    const DEFAULT_BANDWIDTH: f64 = 1.0;
}
81
/// Phantom type to represent approximation methods.
///
/// `NAME` is used for runtime dispatch in `fit_impl`; the remaining
/// constants are compile-time metadata about the method.
pub trait ApproximationMethod {
    /// Name of the approximation method.
    const NAME: &'static str;

    /// Whether this method supports incremental updates.
    const SUPPORTS_INCREMENTAL: bool;

    /// Whether this method provides theoretical error bounds.
    const HAS_ERROR_BOUNDS: bool;

    /// Computational complexity class.
    const COMPLEXITY: ComplexityClass;
}

/// Computational complexity classes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    /// O(n²) complexity
    Quadratic,
    /// O(n log n) complexity
    QuasiLinear,
    /// O(n) complexity
    Linear,
    /// O(d log d) complexity where d is dimension
    DimensionDependent,
}
110
/// Random Fourier Features approximation method.
#[derive(Debug, Clone, Copy)]
pub struct RandomFourierFeatures;
impl ApproximationMethod for RandomFourierFeatures {
    const NAME: &'static str = "RandomFourierFeatures";
    const SUPPORTS_INCREMENTAL: bool = true;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

/// Nyström approximation method.
#[derive(Debug, Clone, Copy)]
pub struct NystromMethod;
impl ApproximationMethod for NystromMethod {
    const NAME: &'static str = "Nystrom";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

/// Fastfood approximation method.
#[derive(Debug, Clone, Copy)]
pub struct FastfoodMethod;
impl ApproximationMethod for FastfoodMethod {
    const NAME: &'static str = "Fastfood";
    const SUPPORTS_INCREMENTAL: bool = false;
    const HAS_ERROR_BOUNDS: bool = true;
    const COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}
143
/// Type-safe kernel approximation with compile-time guarantees.
///
/// `State` encodes whether the value has been fitted, `Kernel` and
/// `Method` select the kernel/method combination, and `N_COMPONENTS`
/// fixes the number of approximation components at compile time.
#[derive(Debug, Clone)]
pub struct TypeSafeKernelApproximation<State, Kernel, Method, const N_COMPONENTS: usize>
where
    State: ApproximationState,
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Phantom data for compile-time type checking (no runtime storage).
    _phantom: PhantomData<(State, Kernel, Method)>,

    /// Method parameters (bandwidth, degree, coef0, custom entries).
    parameters: ApproximationParameters,

    /// Random seed for reproducibility; `None` draws from the thread RNG.
    random_state: Option<u64>,
}
162
/// Parameters for kernel approximation methods.
#[derive(Debug, Clone)]
pub struct ApproximationParameters {
    /// Bandwidth/gamma parameter.
    pub bandwidth: f64,

    /// Polynomial degree (for polynomial kernels); defaults to 2 when unset.
    pub degree: Option<usize>,

    /// Coefficient for polynomial kernels; defaults to 1.0 when unset.
    pub coef0: Option<f64>,

    /// Additional custom parameters keyed by name.
    pub custom: std::collections::HashMap<String, f64>,
}

impl Default for ApproximationParameters {
    /// Bandwidth 1.0, no degree/coef0, empty custom map.
    /// (Manual impl because the bandwidth default is 1.0, not 0.0.)
    fn default() -> Self {
        Self {
            bandwidth: 1.0,
            degree: None,
            coef0: None,
            custom: std::collections::HashMap::new(),
        }
    }
}
190
/// Fitted kernel approximation that can transform data.
///
/// Produced by `TypeSafeKernelApproximation::fit`; the untrained state
/// type parameter is gone, so `transform` is always available.
#[derive(Debug, Clone)]
pub struct FittedTypeSafeKernelApproximation<Kernel, Method, const N_COMPONENTS: usize>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Phantom data for compile-time type checking (no runtime storage).
    _phantom: PhantomData<(Kernel, Method)>,

    /// Random features or transformation parameters, by method.
    transformation_params: TransformationParameters<N_COMPONENTS>,

    /// Parameters captured at fit time.
    fitted_parameters: ApproximationParameters,

    /// Approximation quality metrics (currently default-initialized by `fit`).
    quality_metrics: QualityMetrics,
}
211
/// Transformation parameters for different approximation methods.
#[derive(Debug, Clone)]
pub enum TransformationParameters<const N: usize> {
    /// Random features for RFF methods.
    RandomFeatures {
        /// Projection weights, shape `(N, n_features)`.
        weights: Array2<f64>,

        /// Optional per-component phase offsets, length `N`.
        biases: Option<Array1<f64>>,
    },
    /// Inducing points and eigendecomposition for Nyström.
    Nystrom {
        /// Subsampled training rows used as inducing points.
        inducing_points: Array2<f64>,

        /// Eigenvalues of the inducing-point kernel matrix
        /// (currently a diagonal approximation — see `fit_nystrom`).
        eigenvalues: Array1<f64>,

        /// Corresponding eigenvectors.
        eigenvectors: Array2<f64>,
    },
    /// Structured matrices for Fastfood.
    Fastfood {
        /// Structured (e.g. sign-diagonal) transform matrices, applied in order.
        structured_matrices: Vec<Array2<f64>>,
        /// Per-component Gaussian scaling, length `N`.
        scaling: Array1<f64>,
    },
}
236
/// Quality metrics for approximation assessment.
///
/// All fields are optional: `None` means the metric has not been computed.
/// `Default` yields all-`None`.
#[derive(Debug, Clone, Default)]
pub struct QualityMetrics {
    /// Approximation error estimate
    pub approximation_error: Option<f64>,

    /// Effective rank of approximation
    pub effective_rank: Option<f64>,

    /// Condition number
    pub condition_number: Option<f64>,

    /// Kernel alignment score
    pub kernel_alignment: Option<f64>,
}
254
// Implementation for untrained approximation
impl<Kernel, Method, const N_COMPONENTS: usize> Default
    for TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Delegates to [`TypeSafeKernelApproximation::new`] (kernel default
    /// bandwidth, no seed).
    fn default() -> Self {
        Self::new()
    }
}
266
267impl<Kernel, Method, const N_COMPONENTS: usize>
268    TypeSafeKernelApproximation<Untrained, Kernel, Method, N_COMPONENTS>
269where
270    Kernel: KernelType,
271    Method: ApproximationMethod,
272{
273    /// Create a new untrained kernel approximation
274    pub fn new() -> Self {
275        Self {
276            _phantom: PhantomData,
277            parameters: ApproximationParameters {
278                bandwidth: Kernel::DEFAULT_BANDWIDTH,
279                ..Default::default()
280            },
281            random_state: None,
282        }
283    }
284
285    /// Set bandwidth parameter (only available for kernels that support it)
286    pub fn bandwidth(mut self, bandwidth: f64) -> Self
287    where
288        Kernel: KernelTypeWithBandwidth,
289    {
290        self.parameters.bandwidth = bandwidth;
291        self
292    }
293
294    /// Set polynomial degree (only available for polynomial kernels)
295    pub fn degree(mut self, degree: usize) -> Self
296    where
297        Kernel: PolynomialKernelType,
298    {
299        self.parameters.degree = Some(degree);
300        self
301    }
302
303    /// Set random state for reproducibility
304    pub fn random_state(mut self, seed: u64) -> Self {
305        self.random_state = Some(seed);
306        self
307    }
308
309    /// Fit the approximation method (state transition)
310    pub fn fit(
311        self,
312        data: &Array2<f64>,
313    ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
314    where
315        Kernel: FittableKernel<Method>,
316        Method: FittableMethod<Kernel>,
317    {
318        self.fit_impl(data)
319    }
320
321    /// Internal fitting implementation
322    fn fit_impl(
323        self,
324        data: &Array2<f64>,
325    ) -> Result<FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>>
326    where
327        Kernel: FittableKernel<Method>,
328        Method: FittableMethod<Kernel>,
329    {
330        let transformation_params = match Method::NAME {
331            "RandomFourierFeatures" => self.fit_random_fourier_features(data)?,
332            "Nystrom" => self.fit_nystrom(data)?,
333            "Fastfood" => self.fit_fastfood(data)?,
334            _ => {
335                return Err(SklearsError::InvalidOperation(format!(
336                    "Unsupported method: {}",
337                    Method::NAME
338                )));
339            }
340        };
341
342        Ok(FittedTypeSafeKernelApproximation {
343            _phantom: PhantomData,
344            transformation_params,
345            fitted_parameters: self.parameters,
346            quality_metrics: QualityMetrics::default(),
347        })
348    }
349
350    fn fit_random_fourier_features(
351        &self,
352        data: &Array2<f64>,
353    ) -> Result<TransformationParameters<N_COMPONENTS>> {
354        let mut rng = match self.random_state {
355            Some(seed) => RealStdRng::seed_from_u64(seed),
356            None => RealStdRng::from_seed(thread_rng().random()),
357        };
358
359        let (_, n_features) = data.dim();
360
361        let weights = match Kernel::NAME {
362            "RBF" => {
363                let normal = RandNormal::new(0.0, self.parameters.bandwidth)
364                    .expect("operation should succeed");
365                Array2::from_shape_fn((N_COMPONENTS, n_features), |_| rng.sample(normal))
366            }
367            "Laplacian" => {
368                // Laplacian kernel uses Cauchy distribution
369                let uniform =
370                    RandUniform::new(0.0, std::f64::consts::PI).expect("operation should succeed");
371                Array2::from_shape_fn((N_COMPONENTS, n_features), |_| {
372                    let u = rng.sample(uniform);
373                    (u - std::f64::consts::PI / 2.0).tan() / self.parameters.bandwidth
374                })
375            }
376            _ => {
377                return Err(SklearsError::InvalidOperation(format!(
378                    "RFF not supported for kernel: {}",
379                    Kernel::NAME
380                )));
381            }
382        };
383
384        let uniform =
385            RandUniform::new(0.0, 2.0 * std::f64::consts::PI).expect("operation should succeed");
386        let biases = Some(Array1::from_shape_fn(N_COMPONENTS, |_| rng.sample(uniform)));
387
388        Ok(TransformationParameters::RandomFeatures { weights, biases })
389    }
390
391    fn fit_nystrom(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
392        let (n_samples, _) = data.dim();
393        let n_inducing = N_COMPONENTS.min(n_samples);
394
395        // Sample inducing points
396        let mut rng = match self.random_state {
397            Some(seed) => RealStdRng::seed_from_u64(seed),
398            None => RealStdRng::from_seed(thread_rng().random()),
399        };
400
401        let mut indices: Vec<usize> = (0..n_samples).collect();
402        indices.shuffle(&mut rng);
403
404        let inducing_indices = &indices[..n_inducing];
405        let inducing_points = data.select(scirs2_core::ndarray::Axis(0), inducing_indices);
406
407        // Compute kernel matrix on inducing points
408        let kernel_matrix = self.compute_kernel_matrix(&inducing_points, &inducing_points)?;
409
410        // Simplified eigendecomposition (using diagonal as approximation)
411        let eigenvalues = Array1::from_shape_fn(n_inducing, |i| kernel_matrix[[i, i]]);
412        let eigenvectors = Array2::eye(n_inducing);
413
414        Ok(TransformationParameters::Nystrom {
415            inducing_points,
416            eigenvalues,
417            eigenvectors,
418        })
419    }
420
421    fn fit_fastfood(&self, data: &Array2<f64>) -> Result<TransformationParameters<N_COMPONENTS>> {
422        let mut rng = match self.random_state {
423            Some(seed) => RealStdRng::seed_from_u64(seed),
424            None => RealStdRng::from_seed(thread_rng().random()),
425        };
426
427        let (_, n_features) = data.dim();
428
429        // Create structured matrices for Fastfood
430        let mut structured_matrices = Vec::new();
431
432        // Binary matrix
433        let binary_matrix = Array2::from_shape_fn((n_features, n_features), |(i, j)| {
434            if i == j {
435                if rng.random::<bool>() {
436                    1.0
437                } else {
438                    -1.0
439                }
440            } else {
441                0.0
442            }
443        });
444        structured_matrices.push(binary_matrix);
445
446        // Gaussian scaling
447        let scaling = Array1::from_shape_fn(N_COMPONENTS, |_| {
448            use scirs2_core::random::RandNormal;
449            let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
450            rng.sample(normal)
451        });
452
453        Ok(TransformationParameters::Fastfood {
454            structured_matrices,
455            scaling,
456        })
457    }
458
459    fn compute_kernel_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Result<Array2<f64>> {
460        let (n1, _) = x1.dim();
461        let (n2, _) = x2.dim();
462        let mut kernel = Array2::zeros((n1, n2));
463
464        for i in 0..n1 {
465            for j in 0..n2 {
466                let similarity = match Kernel::NAME {
467                    "RBF" => {
468                        let diff = &x1.row(i) - &x2.row(j);
469                        let dist_sq = diff.mapv(|x| x * x).sum();
470                        (-self.parameters.bandwidth * dist_sq).exp()
471                    }
472                    "Laplacian" => {
473                        let diff = &x1.row(i) - &x2.row(j);
474                        let dist = diff.mapv(|x| x.abs()).sum();
475                        (-self.parameters.bandwidth * dist).exp()
476                    }
477                    "Polynomial" => {
478                        let dot_product = x1.row(i).dot(&x2.row(j));
479                        let degree = self.parameters.degree.unwrap_or(2) as i32;
480                        let coef0 = self.parameters.coef0.unwrap_or(1.0);
481                        (self.parameters.bandwidth * dot_product + coef0).powi(degree)
482                    }
483                    _ => {
484                        return Err(SklearsError::InvalidOperation(format!(
485                            "Unsupported kernel: {}",
486                            Kernel::NAME
487                        )));
488                    }
489                };
490                kernel[[i, j]] = similarity;
491            }
492        }
493
494        Ok(kernel)
495    }
496}
497
498// Implementation for fitted approximation
impl<Kernel, Method, const N_COMPONENTS: usize>
    FittedTypeSafeKernelApproximation<Kernel, Method, N_COMPONENTS>
where
    Kernel: KernelType,
    Method: ApproximationMethod,
{
    /// Transform data using the fitted approximation.
    ///
    /// Dispatches on the transformation parameters stored at fit time, so
    /// the output feature width depends on the method: `2 * N_COMPONENTS`
    /// for random features (cos/sin pairs), the number of retained
    /// eigenvalues for Nyström, and `N_COMPONENTS` for Fastfood.
    pub fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>> {
        match &self.transformation_params {
            TransformationParameters::RandomFeatures { weights, biases } => {
                self.transform_random_features(data, weights, biases.as_ref())
            }
            TransformationParameters::Nystrom {
                inducing_points,
                eigenvalues,
                eigenvectors,
            } => self.transform_nystrom(data, inducing_points, eigenvalues, eigenvectors),
            TransformationParameters::Fastfood {
                structured_matrices,
                scaling,
            } => self.transform_fastfood(data, structured_matrices, scaling),
        }
    }

    /// Project each sample onto the random weights and emit cos/sin pairs;
    /// output shape is `(n_samples, 2 * N_COMPONENTS)`.
    fn transform_random_features(
        &self,
        data: &Array2<f64>,
        weights: &Array2<f64>,
        biases: Option<&Array1<f64>>,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = data.dim();
        let mut features = Array2::zeros((n_samples, N_COMPONENTS * 2));

        for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
            for j in 0..N_COMPONENTS {
                let projection = sample.dot(&weights.row(j));
                // Add the per-component phase offset when biases were fitted.
                let phase = if let Some(b) = biases {
                    projection + b[j]
                } else {
                    projection
                };

                features[[i, 2 * j]] = phase.cos();
                features[[i, 2 * j + 1]] = phase.sin();
            }
        }

        Ok(features)
    }

    /// Nyström feature map: kernel against inducing points, projected
    /// through the stored eigendecomposition and whitened by 1/sqrt(λ).
    fn transform_nystrom(
        &self,
        data: &Array2<f64>,
        inducing_points: &Array2<f64>,
        eigenvalues: &Array1<f64>,
        eigenvectors: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        // Compute kernel between data and inducing points
        let kernel_matrix = self.compute_kernel_matrix_fitted(data, inducing_points)?;

        // Apply eigendecomposition transformation
        let mut features = Array2::zeros((data.nrows(), eigenvalues.len()));

        for i in 0..data.nrows() {
            for j in 0..eigenvalues.len() {
                // Skip near-zero eigenvalues to avoid division blow-up;
                // those feature columns stay zero.
                if eigenvalues[j] > 1e-8 {
                    let mut feature_value = 0.0;
                    for k in 0..inducing_points.nrows() {
                        feature_value += kernel_matrix[[i, k]] * eigenvectors[[k, j]];
                    }
                    features[[i, j]] = feature_value / eigenvalues[j].sqrt();
                }
            }
        }

        Ok(features)
    }

    /// Apply the structured transforms in order, then scale and take
    /// cosine features; output shape is `(n_samples, N_COMPONENTS)`.
    ///
    /// NOTE(review): when `N_COMPONENTS > n_features` the trailing columns
    /// remain zero — confirm this truncation is intended.
    fn transform_fastfood(
        &self,
        data: &Array2<f64>,
        structured_matrices: &[Array2<f64>],
        scaling: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let (n_samples, n_features) = data.dim();
        let mut features = data.clone();

        // Apply structured transformations
        for matrix in structured_matrices {
            features = features.dot(matrix);
        }

        // Apply scaling and take cosine features
        let mut result = Array2::zeros((n_samples, N_COMPONENTS));
        for i in 0..n_samples {
            for j in 0..N_COMPONENTS.min(n_features) {
                result[[i, j]] = (features[[i, j]] * scaling[j]).cos();
            }
        }

        Ok(result)
    }

    /// Exact kernel matrix between rows of `x1` and `x2`, using the
    /// parameters captured at fit time. Mirrors the untrained
    /// `compute_kernel_matrix`; keep the two in sync.
    fn compute_kernel_matrix_fitted(
        &self,
        x1: &Array2<f64>,
        x2: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let (n1, _) = x1.dim();
        let (n2, _) = x2.dim();
        let mut kernel = Array2::zeros((n1, n2));

        for i in 0..n1 {
            for j in 0..n2 {
                let similarity = match Kernel::NAME {
                    "RBF" => {
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist_sq = diff.mapv(|x| x * x).sum();
                        (-self.fitted_parameters.bandwidth * dist_sq).exp()
                    }
                    "Laplacian" => {
                        let diff = &x1.row(i) - &x2.row(j);
                        let dist = diff.mapv(|x| x.abs()).sum();
                        (-self.fitted_parameters.bandwidth * dist).exp()
                    }
                    "Polynomial" => {
                        let dot_product = x1.row(i).dot(&x2.row(j));
                        let degree = self.fitted_parameters.degree.unwrap_or(2) as i32;
                        let coef0 = self.fitted_parameters.coef0.unwrap_or(1.0);
                        (self.fitted_parameters.bandwidth * dot_product + coef0).powi(degree)
                    }
                    _ => {
                        return Err(SklearsError::InvalidOperation(format!(
                            "Unsupported kernel: {}",
                            Kernel::NAME
                        )));
                    }
                };
                kernel[[i, j]] = similarity;
            }
        }

        Ok(kernel)
    }

    /// Get approximation quality metrics.
    pub fn quality_metrics(&self) -> &QualityMetrics {
        &self.quality_metrics
    }

    /// Get the number of components (compile-time constant).
    pub const fn n_components() -> usize {
        N_COMPONENTS
    }

    /// Get kernel type name.
    pub fn kernel_name(&self) -> &'static str {
        Kernel::NAME
    }

    /// Get approximation method name.
    pub fn method_name(&self) -> &'static str {
        Method::NAME
    }
}
664
// Trait constraints for type safety

/// Marker trait for kernel types that support bandwidth parameters.
/// Gates the `bandwidth` builder method at compile time.
pub trait KernelTypeWithBandwidth: KernelType {}
impl KernelTypeWithBandwidth for RBFKernel {}
impl KernelTypeWithBandwidth for LaplacianKernel {}

/// Marker trait for polynomial kernel types.
/// Gates the `degree` builder method at compile time.
pub trait PolynomialKernelType: KernelType {}
impl PolynomialKernelType for PolynomialKernel {}

/// Marker trait for kernels that can be fitted with specific methods.
/// Together with [`FittableMethod`] this restricts which kernel/method
/// pairs can call `fit`.
pub trait FittableKernel<Method: ApproximationMethod>: KernelType {}
impl FittableKernel<RandomFourierFeatures> for RBFKernel {}
impl FittableKernel<RandomFourierFeatures> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for RBFKernel {}
impl FittableKernel<NystromMethod> for LaplacianKernel {}
impl FittableKernel<NystromMethod> for PolynomialKernel {}
impl FittableKernel<FastfoodMethod> for RBFKernel {}

/// Marker trait for methods that can be fitted with specific kernels
/// (mirror of [`FittableKernel`]; keep the impl lists in sync).
pub trait FittableMethod<Kernel: KernelType>: ApproximationMethod {}
impl FittableMethod<RBFKernel> for RandomFourierFeatures {}
impl FittableMethod<LaplacianKernel> for RandomFourierFeatures {}
impl FittableMethod<RBFKernel> for NystromMethod {}
impl FittableMethod<LaplacianKernel> for NystromMethod {}
impl FittableMethod<PolynomialKernel> for NystromMethod {}
impl FittableMethod<RBFKernel> for FastfoodMethod {}
693
// Type aliases for common kernel approximation combinations.

/// Untrained RBF kernel approximation via Random Fourier Features.
pub type RBFRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, RandomFourierFeatures, N>;

/// Untrained Laplacian kernel approximation via Random Fourier Features.
pub type LaplacianRandomFourierFeatures<const N: usize> =
    TypeSafeKernelApproximation<Untrained, LaplacianKernel, RandomFourierFeatures, N>;

/// Untrained RBF kernel approximation via the Nyström method.
pub type RBFNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, NystromMethod, N>;

/// Untrained polynomial kernel approximation via the Nyström method.
pub type PolynomialNystrom<const N: usize> =
    TypeSafeKernelApproximation<Untrained, PolynomialKernel, NystromMethod, N>;

/// Untrained RBF kernel approximation via Fastfood.
pub type RBFFastfood<const N: usize> =
    TypeSafeKernelApproximation<Untrained, RBFKernel, FastfoodMethod, N>;

/// Fitted RBF + Random Fourier Features approximation.
pub type FittedRBFRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, RandomFourierFeatures, N>;

/// Fitted Laplacian + Random Fourier Features approximation.
pub type FittedLaplacianRandomFourierFeatures<const N: usize> =
    FittedTypeSafeKernelApproximation<LaplacianKernel, RandomFourierFeatures, N>;

/// Fitted RBF + Nyström approximation.
pub type FittedRBFNystrom<const N: usize> =
    FittedTypeSafeKernelApproximation<RBFKernel, NystromMethod, N>;
719
// NOTE(review): all test fns below are snake_case; this allow looks
// unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // End-to-end RFF: fit on a tiny dataset, check output width is
    // 2 * N_COMPONENTS (cos/sin pairs) and metadata accessors.
    #[test]
    fn test_type_safe_rbf_rff() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];

        // Create untrained approximation with compile-time dimension
        let approximation: RBFRandomFourierFeatures<10> = TypeSafeKernelApproximation::new()
            .bandwidth(1.5)
            .random_state(42);

        // Fit and transform
        let fitted = approximation.fit(&data).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(features.shape(), &[4, 20]); // 10 components * 2 (cos, sin)
        assert_eq!(FittedRBFRandomFourierFeatures::<10>::n_components(), 10);
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "RandomFourierFeatures");
    }

    // Nyström on 4 samples with N_COMPONENTS=5: inducing count is capped
    // at n_samples, so only the row count is asserted.
    #[test]
    fn test_type_safe_nystrom() {
        let data = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ];

        let approximation: RBFNystrom<5> = TypeSafeKernelApproximation::new()
            .bandwidth(2.0)
            .random_state(42);

        let fitted = approximation.fit(&data).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(features.shape()[0], 4); // n_samples
        assert_eq!(fitted.kernel_name(), "RBF");
        assert_eq!(fitted.method_name(), "Nystrom");
    }

    // Polynomial kernel is only fittable with Nyström; exercises the
    // `degree` builder (gated by PolynomialKernelType).
    #[test]
    fn test_polynomial_kernel() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let approximation: PolynomialNystrom<3> = TypeSafeKernelApproximation::new()
            .degree(3)
            .random_state(42);

        let fitted = approximation.fit(&data).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(fitted.kernel_name(), "Polynomial");
        assert!(features.nrows() == 3);
    }

    // Associated constants are usable in const position and carry the
    // advertised metadata.
    #[test]
    fn test_compile_time_constants() {
        // These should be available at compile time
        assert_eq!(RBFKernel::NAME, "RBF");
        assert!(RBFKernel::SUPPORTS_PARAMETER_LEARNING);
        assert_eq!(RandomFourierFeatures::NAME, "RandomFourierFeatures");
        assert!(RandomFourierFeatures::SUPPORTS_INCREMENTAL);
        assert!(RandomFourierFeatures::HAS_ERROR_BOUNDS);
    }

    // This test demonstrates compile-time type safety
    // The following code would not compile due to type constraints:

    // #[test]
    // fn test_type_safety_violations() {
    //     // This would fail: ArcCosine kernel doesn't support bandwidth
    //     // let _ = TypeSafeKernelApproximation::<Untrained, ArcCosineKernel, RandomFourierFeatures, 10>::new()
    //     //     .bandwidth(1.0); // ERROR: ArcCosineKernel doesn't implement KernelTypeWithBandwidth
    //
    //     // This would fail: RBF kernel with degree parameter
    //     // let _ = TypeSafeKernelApproximation::<Untrained, RBFKernel, RandomFourierFeatures, 10>::new()
    //     //     .degree(2); // ERROR: RBFKernel doesn't implement PolynomialKernelType
    //
    //     // This would fail: Trying to fit incompatible kernel-method combination
    //     // let _ = TypeSafeKernelApproximation::<Untrained, ArcCosineKernel, RandomFourierFeatures, 10>::new()
    //     //     .fit(&data); // ERROR: ArcCosineKernel doesn't implement FittableKernel<RandomFourierFeatures>
    // }
}
809
810// ==================================================================================
811// ADVANCED TYPE SAFETY ENHANCEMENTS - Zero-Cost Abstractions and Compile-Time Validation
812// ==================================================================================
813
/// Advanced compile-time parameter validation traits.
///
/// `MIN`/`MAX` are const-generic bounds; `IS_VALID` is evaluated at
/// compile time for each instantiation.
pub trait ParameterValidation<const MIN: usize, const MAX: usize> {
    /// Whether the bounds form a non-empty range (`MIN <= MAX`).
    const IS_VALID: bool = MIN <= MAX;

    /// Get the parameter range as `(MIN, MAX)`.
    fn parameter_range() -> (usize, usize) {
        (MIN, MAX)
    }
}
824
/// Zero-cost abstraction for validated component counts.
///
/// A unit struct carrying the count `N` in its type; `new` enforces the
/// validity bounds.
#[derive(Debug, Clone, Copy)]
pub struct ValidatedComponents<const N: usize>;

impl<const N: usize> Default for ValidatedComponents<N> {
    /// Delegates to [`ValidatedComponents::new`], including its validation.
    fn default() -> Self {
        Self::new()
    }
}

impl<const N: usize> ValidatedComponents<N> {
    /// Create validated components.
    ///
    /// `const fn` so that, when invoked in a const context, the assertions
    /// below are evaluated at compile time; in a runtime context they
    /// panic at the call site instead.
    ///
    /// # Panics
    /// If `N == 0` or `N > 10000`.
    pub const fn new() -> Self {
        assert!(N > 0, "Component count must be positive");
        assert!(N <= 10000, "Component count too large");
        Self
    }

    /// Get the component count.
    pub const fn count(&self) -> usize {
        N
    }
}
850
/// Compile-time compatibility checking between kernels and methods.
///
/// Implemented on `()` for each `(K, M)` pair; the constants describe
/// that pairing's support and performance characteristics.
pub trait KernelMethodCompatibility<K: KernelType, M: ApproximationMethod> {
    /// Whether this kernel-method combination is supported.
    const IS_COMPATIBLE: bool;

    /// Performance characteristics of this combination.
    const PERFORMANCE_TIER: PerformanceTier;

    /// Memory complexity.
    const MEMORY_COMPLEXITY: ComplexityClass;
}

/// Performance tiers for different kernel-method combinations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceTier {
    /// Optimal performance combination
    Optimal,
    /// Good performance
    Good,
    /// Acceptable performance
    Acceptable,
    /// Poor performance (not recommended)
    Poor,
}
876
// Compatibility rules, one impl on `()` per kernel/method pair.
// NOTE(review): these duplicate the FittableKernel/FittableMethod impl
// lists above — keep the three in sync when adding combinations.

impl KernelMethodCompatibility<RBFKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<LaplacianKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
}

impl KernelMethodCompatibility<RBFKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<PolynomialKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<RBFKernel, FastfoodMethod> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Optimal;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
}

// Arc-cosine kernels have limited compatibility.
impl KernelMethodCompatibility<ArcCosineKernel, RandomFourierFeatures> for () {
    const IS_COMPATIBLE: bool = true;
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Acceptable;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}

impl KernelMethodCompatibility<ArcCosineKernel, NystromMethod> for () {
    const IS_COMPATIBLE: bool = false; // Not recommended
    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
}
920
921impl KernelMethodCompatibility<ArcCosineKernel, FastfoodMethod> for () {
922    const IS_COMPATIBLE: bool = false; // Not supported
923    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Poor;
924    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::DimensionDependent;
925}
926
927// Additional compatibility implementations for missing combinations
928impl KernelMethodCompatibility<LaplacianKernel, NystromMethod> for () {
929    const IS_COMPATIBLE: bool = true;
930    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
931    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Quadratic;
932}
933
934impl KernelMethodCompatibility<PolynomialKernel, RandomFourierFeatures> for () {
935    const IS_COMPATIBLE: bool = true;
936    const PERFORMANCE_TIER: PerformanceTier = PerformanceTier::Good;
937    const MEMORY_COMPLEXITY: ComplexityClass = ComplexityClass::Linear;
938}
939
/// Zero-cost wrapper for compile-time validated approximations
///
/// Wraps an untrained `TypeSafeKernelApproximation` together with a
/// `ValidatedComponents<N>` marker. The `(): KernelMethodCompatibility<K, M>`
/// bound restricts construction to kernel/method pairs that have a
/// compatibility impl in this module.
#[derive(Debug, Clone)]
pub struct ValidatedKernelApproximation<K, M, const N: usize>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    // Underlying approximation; always in the `Untrained` state here.
    inner: TypeSafeKernelApproximation<Untrained, K, M, N>,
    // Zero-sized marker whose constructor validates `N`.
    _validation: ValidatedComponents<N>,
}

impl<K, M, const N: usize> Default for ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    fn default() -> Self {
        Self::new()
    }
}
963
impl<K, M, const N: usize> ValidatedKernelApproximation<K, M, N>
where
    K: KernelType,
    M: ApproximationMethod,
    (): KernelMethodCompatibility<K, M>,
{
    /// Create a new validated kernel approximation.
    ///
    /// # Panics
    ///
    /// Panics when the `(K, M)` pair is marked incompatible. Note that despite
    /// the module's compile-time theme this is a runtime `assert!` on a const;
    /// the branch is trivially resolvable by the optimizer but it is not a
    /// compile-time error outside const contexts.
    pub fn new() -> Self {
        // Runtime check of the compile-time constant.
        assert!(
            <() as KernelMethodCompatibility<K, M>>::IS_COMPATIBLE,
            "Incompatible kernel-method combination"
        );

        Self {
            inner: TypeSafeKernelApproximation::new(),
            _validation: ValidatedComponents::new(),
        }
    }

    /// Get performance information at compile time
    ///
    /// Returns `(performance tier, compute complexity, memory complexity)`.
    pub const fn performance_info() -> (PerformanceTier, ComplexityClass, ComplexityClass) {
        (
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            M::COMPLEXITY,
            <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY,
        )
    }

    /// Check if this combination is optimal
    pub const fn is_optimal() -> bool {
        matches!(
            <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        )
    }
}
1001
/// Quality metrics whose acceptable ranges are encoded in the type.
///
/// `MIN_ALIGNMENT` and `MAX_ERROR` are percentages in `[0, 100]`; for
/// example `BoundedQualityMetrics<90, 5>` requires at least 90% kernel
/// alignment and at most 5% approximation error.
#[derive(Debug, Clone, Copy)]
pub struct BoundedQualityMetrics<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32> {
    kernel_alignment: f64,
    approximation_error: f64,
    effective_rank: f64,
}

impl<const MIN_ALIGNMENT: u32, const MAX_ERROR: u32>
    BoundedQualityMetrics<MIN_ALIGNMENT, MAX_ERROR>
{
    /// Build metrics, returning `None` when any value falls outside the
    /// bounds encoded in the type parameters or the rank is not positive.
    pub const fn new(alignment: f64, error: f64, rank: f64) -> Option<Self> {
        let (lo, hi) = Self::bounds();
        // The check is kept in positive form so NaN inputs are rejected
        // (a NaN comparison is always false, which lands in the None arm).
        let acceptable = alignment >= lo && error <= hi && rank > 0.0;
        if acceptable {
            Some(Self {
                kernel_alignment: alignment,
                approximation_error: error,
                effective_rank: rank,
            })
        } else {
            None
        }
    }

    /// Type-level bounds as decimals: `(min_alignment, max_error)`.
    pub const fn bounds() -> (f64, f64) {
        (MIN_ALIGNMENT as f64 / 100.0, MAX_ERROR as f64 / 100.0)
    }

    /// Whether the stored values satisfy the type-level bounds.
    pub const fn meets_standards(&self) -> bool {
        let (lo, hi) = Self::bounds();
        self.kernel_alignment >= lo && self.approximation_error <= hi
    }
}

/// Type alias for high-quality metrics (>90% alignment, <5% error)
pub type HighQualityMetrics = BoundedQualityMetrics<90, 5>;

/// Type alias for acceptable quality metrics (>70% alignment, <15% error)
pub type AcceptableQualityMetrics = BoundedQualityMetrics<70, 15>;
1047
/// Macro for easy creation of validated kernel approximations
///
/// One arm exists for every kernel/method pair that has a compatible
/// `KernelMethodCompatibility` impl in this module. The Laplacian/Nyström
/// and Polynomial arms were previously missing even though their
/// compatibility impls exist; they are included now. Existing arms are
/// unchanged, so all current invocations keep working.
#[macro_export]
macro_rules! validated_kernel_approximation {
    (RBF, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, RandomFourierFeatures, $n>::new()
    };
    (Laplacian, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, RandomFourierFeatures, $n>::new()
    };
    (RBF, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, NystromMethod, $n>::new()
    };
    (RBF, Fastfood, $n:literal) => {
        ValidatedKernelApproximation::<RBFKernel, FastfoodMethod, $n>::new()
    };
    (Laplacian, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<LaplacianKernel, NystromMethod, $n>::new()
    };
    (Polynomial, RandomFourierFeatures, $n:literal) => {
        ValidatedKernelApproximation::<PolynomialKernel, RandomFourierFeatures, $n>::new()
    };
    (Polynomial, Nystrom, $n:literal) => {
        ValidatedKernelApproximation::<PolynomialKernel, NystromMethod, $n>::new()
    };
}
1064
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_type_safety_tests {
    //! Tests for component validation, compatibility constants, bounded
    //! quality metrics, and the `validated_kernel_approximation!` macro.

    use super::*;

    #[test]
    fn test_validated_components() {
        let components = ValidatedComponents::<100>::new();
        assert_eq!(components.count(), 100);
    }

    #[test]
    fn test_kernel_method_compatibility() {
        // Test compile-time compatibility checks
        assert!(<() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::IS_COMPATIBLE);
        assert!(!<() as KernelMethodCompatibility<ArcCosineKernel, NystromMethod>>::IS_COMPATIBLE);

        // Test performance tiers
        assert_eq!(
            <() as KernelMethodCompatibility<RBFKernel, RandomFourierFeatures>>::PERFORMANCE_TIER,
            PerformanceTier::Optimal
        );
    }

    #[test]
    fn test_bounded_quality_metrics() {
        // High quality metrics (0.95 alignment / 0.03 error satisfy the 90/5 bounds)
        let high_quality =
            HighQualityMetrics::new(0.95, 0.03, 50.0).expect("operation should succeed");
        assert!(high_quality.meets_standards());

        // Low quality metrics should fail bounds check
        let low_quality = HighQualityMetrics::new(0.60, 0.20, 10.0);
        assert!(low_quality.is_none());
    }

    #[test]
    fn test_macro_creation() {
        // Test that the macro compiles for valid combinations
        let _rbf_rff = validated_kernel_approximation!(RBF, RandomFourierFeatures, 100);
        let _lap_rff = validated_kernel_approximation!(Laplacian, RandomFourierFeatures, 50);
        let _rbf_nys = validated_kernel_approximation!(RBF, Nystrom, 30);
        let _rbf_ff = validated_kernel_approximation!(RBF, Fastfood, 128);

        // Performance info should be available at compile time
        let (tier, complexity, memory) = ValidatedKernelApproximation::<
            RBFKernel,
            RandomFourierFeatures,
            100,
        >::performance_info();
        assert_eq!(tier, PerformanceTier::Optimal);
        assert_eq!(complexity, ComplexityClass::Linear);
        assert_eq!(memory, ComplexityClass::Linear);
    }
}
1120
// ============================================================================
// Advanced Zero-Cost Kernel Composition Abstractions
//
// These traits enable compile-time composition of kernel operations with
// zero runtime overhead: composed kernels are zero-sized phantom types.
// ============================================================================

/// Trait for kernels that can be composed
pub trait ComposableKernel: KernelType {
    /// Kernel type produced when composing with `Other` (a generic
    /// associated type, so the result depends on the operand kernel).
    type CompositionResult<Other: ComposableKernel>: ComposableKernel;

    /// Combine this kernel with another kernel
    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other>;
}
1133
/// Sum composition of two kernels
///
/// Zero-sized: the component kernels exist only at the type level.
#[derive(Debug, Clone, Copy)]
pub struct SumKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for SumKernel<K1, K2> {
    const NAME: &'static str = "Sum";
    // Parameters are learnable only when both operands support it.
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Default bandwidth is the mean of the operands' defaults.
    const DEFAULT_BANDWIDTH: f64 = (K1::DEFAULT_BANDWIDTH + K2::DEFAULT_BANDWIDTH) / 2.0;
}

/// Product composition of two kernels
///
/// Zero-sized: the component kernels exist only at the type level.
#[derive(Debug, Clone, Copy)]
pub struct ProductKernel<K1: KernelType, K2: KernelType> {
    _phantom: PhantomData<(K1, K2)>,
}

impl<K1: KernelType, K2: KernelType> KernelType for ProductKernel<K1, K2> {
    const NAME: &'static str = "Product";
    const SUPPORTS_PARAMETER_LEARNING: bool =
        K1::SUPPORTS_PARAMETER_LEARNING && K2::SUPPORTS_PARAMETER_LEARNING;
    // Default bandwidth is the product of the operands' defaults.
    const DEFAULT_BANDWIDTH: f64 = K1::DEFAULT_BANDWIDTH * K2::DEFAULT_BANDWIDTH;
}
1161
// Base kernels choose their composition flavor: RBF and Laplacian compose
// additively (SumKernel), the polynomial kernel multiplicatively
// (ProductKernel).
impl ComposableKernel for RBFKernel {
    /// RBF composes into a sum kernel.
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for LaplacianKernel {
    /// Laplacian composes into a sum kernel.
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl ComposableKernel for PolynomialKernel {
    /// Polynomial composes into a product kernel.
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1191
/// Compile-time feature size validation
///
/// Implemented blanket-style for `()` so the flags can be queried for any
/// `N` via `<() as ValidatedFeatureSize<N>>`.
pub trait ValidatedFeatureSize<const N: usize> {
    /// True when `N` is a (non-zero) power of two. The `N != 0` guard
    /// short-circuits so `N - 1` never underflows.
    const IS_POWER_OF_TWO: bool = (N != 0) && ((N & (N - 1)) == 0);
    /// True when `N` lies in the supported range `[8, 8192]`.
    const IS_REASONABLE_SIZE: bool = N >= 8 && N <= 8192;
    /// True when both conditions above hold.
    const IS_VALID: bool = Self::IS_POWER_OF_TWO && Self::IS_REASONABLE_SIZE;
}

// Implement for all const generic sizes
impl<const N: usize> ValidatedFeatureSize<N> for () {}

/// Zero-cost wrapper for validated feature dimensions
#[derive(Debug, Clone)]
pub struct ValidatedFeatures<const N: usize> {
    _phantom: PhantomData<[f64; N]>,
}

impl<const N: usize> Default for ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<const N: usize> ValidatedFeatures<N>
where
    (): ValidatedFeatureSize<N>,
{
    /// Create new validated features.
    ///
    /// The previous implementation claimed to be "compile-time checked" but
    /// never enforced `IS_VALID` anywhere; the blanket impl holds for every
    /// `N`. The validity flag is now asserted: in a const context an invalid
    /// `N` becomes a compile error, at runtime it panics.
    ///
    /// # Panics
    ///
    /// Panics if `N` is not a power of two in `[8, 8192]`.
    pub const fn new() -> Self {
        assert!(
            <() as ValidatedFeatureSize<N>>::IS_VALID,
            "feature count must be a power of two in [8, 8192]"
        );
        Self {
            _phantom: PhantomData,
        }
    }

    /// Get the validated feature count
    pub const fn count() -> usize {
        N
    }

    /// Check if size is optimal for Fastfood
    pub const fn is_fastfood_optimal() -> bool {
        <() as ValidatedFeatureSize<N>>::IS_POWER_OF_TWO
    }
}
1239
1240/// Advanced approximation quality bounds
1241pub trait ApproximationQualityBounds<Method: ApproximationMethod> {
1242    /// Theoretical error bound for this method
1243    const ERROR_BOUND_CONSTANT: f64;
1244
1245    /// Sample complexity scaling
1246    const SAMPLE_COMPLEXITY_EXPONENT: f64;
1247
1248    /// Dimension dependency
1249    const DIMENSION_DEPENDENCY: f64;
1250
1251    /// Compute theoretical error bound
1252    fn error_bound(n_samples: usize, n_features: usize, n_components: usize) -> f64 {
1253        let base_rate = Self::ERROR_BOUND_CONSTANT;
1254        let sample_factor = (n_samples as f64).powf(-Self::SAMPLE_COMPLEXITY_EXPONENT);
1255        let dim_factor = (n_features as f64).powf(Self::DIMENSION_DEPENDENCY);
1256        let comp_factor = (n_components as f64).powf(-0.5);
1257
1258        base_rate * sample_factor * dim_factor * comp_factor
1259    }
1260}
1261
// Per-method constants consumed by `ApproximationQualityBounds::error_bound`.
// NOTE(review): the specific values look heuristic rather than taken from a
// cited bound — confirm against the RFF/Nyström approximation literature.
impl ApproximationQualityBounds<RandomFourierFeatures> for () {
    const ERROR_BOUND_CONSTANT: f64 = 2.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.25;
    const DIMENSION_DEPENDENCY: f64 = 0.1;
}

// Nyström: smallest constant and largest sample exponent of the three,
// i.e. the tightest bound.
impl ApproximationQualityBounds<NystromMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 1.5;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.33;
    const DIMENSION_DEPENDENCY: f64 = 0.05;
}

// Fastfood: loosest constant and strongest dimension dependency.
impl ApproximationQualityBounds<FastfoodMethod> for () {
    const ERROR_BOUND_CONSTANT: f64 = 3.0;
    const SAMPLE_COMPLEXITY_EXPONENT: f64 = 0.2;
    const DIMENSION_DEPENDENCY: f64 = 0.15;
}
1279
/// Advanced type-safe kernel configuration builder
///
/// The `where` bounds tie the configuration to a validated feature size, a
/// kernel/method pair with a compatibility impl, and a method with known
/// quality-bound constants.
#[derive(Debug, Clone)]
pub struct TypeSafeKernelConfig<K: KernelType, M: ApproximationMethod, const N: usize>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    // Zero-sized markers for the kernel and method type parameters.
    kernel_type: PhantomData<K>,
    method_type: PhantomData<M>,
    // Marker for the validated feature count `N`.
    features: ValidatedFeatures<N>,
    // Kernel bandwidth/gamma; initialized to `K::DEFAULT_BANDWIDTH`.
    bandwidth: f64,
    // Maximum acceptable theoretical error bound, in the open interval (0, 1).
    quality_threshold: f64,
}

impl<K: KernelType, M: ApproximationMethod, const N: usize> Default
    for TypeSafeKernelConfig<K, M, N>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    fn default() -> Self {
        Self::new()
    }
}
1307
impl<K: KernelType, M: ApproximationMethod, const N: usize> TypeSafeKernelConfig<K, M, N>
where
    (): ValidatedFeatureSize<N>,
    (): KernelMethodCompatibility<K, M>,
    (): ApproximationQualityBounds<M>,
{
    /// Create a new type-safe configuration
    ///
    /// Bandwidth starts at `K::DEFAULT_BANDWIDTH` and the quality threshold
    /// at 0.1; both can be adjusted with the builder methods below.
    pub fn new() -> Self {
        Self {
            kernel_type: PhantomData,
            method_type: PhantomData,
            features: ValidatedFeatures::new(),
            bandwidth: K::DEFAULT_BANDWIDTH,
            quality_threshold: 0.1,
        }
    }

    /// Set bandwidth parameter
    ///
    /// # Panics
    ///
    /// Panics (at runtime — the original "compile-time validation" wording
    /// was inaccurate) when `bandwidth` is not strictly positive; NaN is
    /// rejected as well since `NaN > 0.0` is false.
    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
        assert!(bandwidth > 0.0, "Bandwidth must be positive");
        self.bandwidth = bandwidth;
        self
    }

    /// Set quality threshold with bounds checking
    ///
    /// # Panics
    ///
    /// Panics when `threshold` is outside the open interval (0, 1).
    pub fn quality_threshold(mut self, threshold: f64) -> Self {
        assert!(
            threshold > 0.0 && threshold < 1.0,
            "Quality threshold must be between 0 and 1"
        );
        self.quality_threshold = threshold;
        self
    }

    /// Get performance tier for this configuration
    pub const fn performance_tier() -> PerformanceTier {
        <() as KernelMethodCompatibility<K, M>>::PERFORMANCE_TIER
    }

    /// Get memory complexity for this configuration
    pub const fn memory_complexity() -> ComplexityClass {
        <() as KernelMethodCompatibility<K, M>>::MEMORY_COMPLEXITY
    }

    /// Compute theoretical error bound for given data size
    ///
    /// Delegates to `ApproximationQualityBounds::error_bound` with `N`
    /// components.
    pub fn theoretical_error_bound(&self, n_samples: usize, n_features: usize) -> f64 {
        <() as ApproximationQualityBounds<M>>::error_bound(n_samples, n_features, N)
    }

    /// Check if configuration meets quality requirements
    ///
    /// True when the theoretical error bound does not exceed the configured
    /// `quality_threshold`.
    pub fn meets_quality_requirements(&self, n_samples: usize, n_features: usize) -> bool {
        let error_bound = self.theoretical_error_bound(n_samples, n_features);
        error_bound <= self.quality_threshold
    }
}
1363
// Type aliases for common validated configurations.

/// RBF kernel approximated with random Fourier features (`N` components).
pub type ValidatedRBFRandomFourier<const N: usize> =
    TypeSafeKernelConfig<RBFKernel, RandomFourierFeatures, N>;
/// Laplacian kernel approximated with the Nyström method (`N` components).
pub type ValidatedLaplacianNystrom<const N: usize> =
    TypeSafeKernelConfig<LaplacianKernel, NystromMethod, N>;
/// Polynomial kernel approximated with random Fourier features (`N` components).
pub type ValidatedPolynomialRFF<const N: usize> =
    TypeSafeKernelConfig<PolynomialKernel, RandomFourierFeatures, N>;
1371
// Composed kernels are themselves composable: a sum composed with anything
// stays a sum and a product stays a product, so arbitrary nesting works
// (e.g. `SumKernel<SumKernel<A, B>, C>`).
impl<K1: KernelType, K2: KernelType> ComposableKernel for SumKernel<K1, K2> {
    /// Composing a sum yields a wider sum.
    type CompositionResult<Other: ComposableKernel> = SumKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        SumKernel {
            _phantom: PhantomData,
        }
    }
}

impl<K1: KernelType, K2: KernelType> ComposableKernel for ProductKernel<K1, K2> {
    /// Composing a product yields a wider product.
    type CompositionResult<Other: ComposableKernel> = ProductKernel<Self, Other>;

    fn compose<Other: ComposableKernel>(self) -> Self::CompositionResult<Other> {
        ProductKernel {
            _phantom: PhantomData,
        }
    }
}
1391
#[allow(non_snake_case)]
#[cfg(test)]
mod advanced_composition_tests {
    //! Tests for feature-size validation, type-safe configs, kernel
    //! composition, and the theoretical error bounds.

    use super::*;

    #[test]
    fn test_validated_features() {
        // These should compile
        let _features_64 = ValidatedFeatures::<64>::new();
        let _features_128 = ValidatedFeatures::<128>::new();
        let _features_256 = ValidatedFeatures::<256>::new();

        assert_eq!(ValidatedFeatures::<64>::count(), 64);
        assert!(ValidatedFeatures::<64>::is_fastfood_optimal());
    }

    #[test]
    fn test_type_safe_configs() {
        let config = ValidatedRBFRandomFourier::<128>::new()
            .bandwidth(1.5)
            .quality_threshold(0.05);

        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert!(config.meets_quality_requirements(1000, 10));
    }

    #[test]
    fn test_kernel_composition() {
        let _rbf = RBFKernel;
        let _laplacian = LaplacianKernel;
        let _polynomial = PolynomialKernel;

        // These should compile and create valid composed kernels
        let _composed1 = _rbf.compose::<LaplacianKernel>();
        let _composed2 = _polynomial.compose::<RBFKernel>();
    }

    #[test]
    fn test_approximation_bounds() {
        let rff_bound =
            <() as ApproximationQualityBounds<RandomFourierFeatures>>::error_bound(1000, 10, 100);
        let nystrom_bound =
            <() as ApproximationQualityBounds<NystromMethod>>::error_bound(1000, 10, 100);
        let fastfood_bound =
            <() as ApproximationQualityBounds<FastfoodMethod>>::error_bound(1000, 10, 100);

        // All bounds must be positive.
        assert!(rff_bound > 0.0);
        assert!(nystrom_bound > 0.0);
        assert!(fastfood_bound > 0.0);

        // Nyström should generally have tighter bounds than RFF
        assert!(nystrom_bound < rff_bound);
    }
}
1449
1450// ============================================================================
1451// Configuration Presets for Common Use Cases
1452// ============================================================================
1453
1454/// Configuration presets for common kernel approximation scenarios
1455pub struct KernelPresets;
1456
1457impl KernelPresets {
1458    /// Fast approximation preset - prioritizes speed over accuracy
1459    pub fn fast_rbf_128() -> ValidatedRBFRandomFourier<128> {
1460        ValidatedRBFRandomFourier::<128>::new()
1461            .bandwidth(1.0)
1462            .quality_threshold(0.2)
1463    }
1464
1465    /// Balanced approximation preset - good trade-off between speed and accuracy
1466    pub fn balanced_rbf_256() -> ValidatedRBFRandomFourier<256> {
1467        ValidatedRBFRandomFourier::<256>::new()
1468            .bandwidth(1.0)
1469            .quality_threshold(0.1)
1470    }
1471
1472    /// High-accuracy approximation preset - prioritizes accuracy over speed
1473    pub fn accurate_rbf_512() -> ValidatedRBFRandomFourier<512> {
1474        ValidatedRBFRandomFourier::<512>::new()
1475            .bandwidth(1.0)
1476            .quality_threshold(0.05)
1477    }
1478
1479    /// Ultra-fast approximation for large-scale problems
1480    pub fn ultrafast_rbf_64() -> ValidatedRBFRandomFourier<64> {
1481        ValidatedRBFRandomFourier::<64>::new()
1482            .bandwidth(1.0)
1483            .quality_threshold(0.3)
1484    }
1485
1486    /// High-precision Nyström preset for small to medium datasets
1487    pub fn precise_nystroem_128() -> ValidatedLaplacianNystrom<128> {
1488        ValidatedLaplacianNystrom::<128>::new()
1489            .bandwidth(1.0)
1490            .quality_threshold(0.01)
1491    }
1492
1493    /// Memory-efficient preset for resource-constrained environments
1494    pub fn memory_efficient_rbf_32() -> ValidatedRBFRandomFourier<32> {
1495        ValidatedRBFRandomFourier::<32>::new()
1496            .bandwidth(1.0)
1497            .quality_threshold(0.4)
1498    }
1499
1500    /// Polynomial kernel preset for structured data
1501    pub fn polynomial_features_256() -> ValidatedPolynomialRFF<256> {
1502        ValidatedPolynomialRFF::<256>::new()
1503            .bandwidth(1.0)
1504            .quality_threshold(0.15)
1505    }
1506}
1507
1508// ============================================================================
1509// Profile-Guided Optimization Support
1510// ============================================================================
1511
/// Profile-guided optimization configuration
///
/// Builder-style settings controlling PGO-driven choices (feature count,
/// bandwidth), the target architecture, and the optimization level.
#[derive(Debug, Clone)]
pub struct ProfileGuidedConfig {
    /// Enable PGO-based feature size selection
    pub enable_pgo_feature_selection: bool,

    /// Enable PGO-based bandwidth optimization
    pub enable_pgo_bandwidth_optimization: bool,

    /// Profile data file path
    pub profile_data_path: Option<String>,

    /// Target hardware architecture
    pub target_architecture: TargetArchitecture,

    /// Optimization level
    pub optimization_level: OptimizationLevel,
}

/// Target hardware architecture for optimization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetArchitecture {
    /// Generic x86_64 architecture
    X86_64Generic,
    /// x86_64 with AVX2 support
    X86_64AVX2,
    /// x86_64 with AVX-512 support
    X86_64AVX512,
    /// ARM64 architecture
    ARM64,
    /// ARM64 with NEON support
    ARM64NEON,
}

/// Optimization level for profile-guided optimization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// No optimization
    None,
    /// Basic optimizations
    Basic,
    /// Aggressive optimizations
    Aggressive,
    /// Maximum optimizations (may increase compile time significantly)
    Maximum,
}

impl Default for ProfileGuidedConfig {
    fn default() -> Self {
        Self {
            enable_pgo_feature_selection: false,
            enable_pgo_bandwidth_optimization: false,
            profile_data_path: None,
            target_architecture: TargetArchitecture::X86_64Generic,
            optimization_level: OptimizationLevel::Basic,
        }
    }
}

impl ProfileGuidedConfig {
    /// Create a new PGO configuration with default settings
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable PGO-based feature size selection
    pub fn enable_feature_selection(mut self) -> Self {
        self.enable_pgo_feature_selection = true;
        self
    }

    /// Enable PGO-based bandwidth optimization
    pub fn enable_bandwidth_optimization(mut self) -> Self {
        self.enable_pgo_bandwidth_optimization = true;
        self
    }

    /// Set profile data file path
    pub fn profile_data_path<P: Into<String>>(mut self, path: P) -> Self {
        self.profile_data_path = Some(path.into());
        self
    }

    /// Set target architecture
    pub fn target_architecture(mut self, arch: TargetArchitecture) -> Self {
        self.target_architecture = arch;
        self
    }

    /// Set optimization level
    pub fn optimization_level(mut self, level: OptimizationLevel) -> Self {
        self.optimization_level = level;
        self
    }

    /// Get recommended feature count based on architecture and optimization level
    ///
    /// Starts from an architecture-specific baseline, scales it by the
    /// optimization level, then nudges it for dataset size and
    /// dimensionality. The result is clamped to `[32, 1024]`.
    pub fn recommended_feature_count(&self, data_size: usize, dimensionality: usize) -> usize {
        // Architecture-specific baseline.
        let base_features = match self.target_architecture {
            TargetArchitecture::X86_64Generic => 128,
            TargetArchitecture::X86_64AVX2 => 256,
            TargetArchitecture::X86_64AVX512 => 512,
            TargetArchitecture::ARM64 => 128,
            TargetArchitecture::ARM64NEON => 256,
        };

        let scale_factor = match self.optimization_level {
            OptimizationLevel::None => 0.5,
            OptimizationLevel::Basic => 1.0,
            OptimizationLevel::Aggressive => 1.5,
            OptimizationLevel::Maximum => 2.0,
        };

        let scaled_features = (base_features as f64 * scale_factor) as usize;

        // Adjust based on data characteristics: bigger datasets get more
        // features, tiny ones fewer.
        let data_adjustment = if data_size > 10000 {
            1.2
        } else if data_size < 1000 {
            0.8
        } else {
            1.0
        };

        let dimension_adjustment = if dimensionality > 100 {
            1.1
        } else if dimensionality < 10 {
            0.9
        } else {
            1.0
        };

        // `clamp` replaces the previous `.max(32).min(1024)` chain — same
        // result, clearer intent (clippy::manual_clamp).
        ((scaled_features as f64 * data_adjustment * dimension_adjustment) as usize)
            .clamp(32, 1024)
    }
}
1651
1652// ============================================================================
1653// Serialization Support for Approximation Models
1654// ============================================================================
1655
/// Serializable kernel approximation configuration
///
/// Plain-data mirror of a configuration, suitable for JSON (de)serialization;
/// kernel and method are identified by their name strings.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableKernelConfig {
    /// Kernel type name
    pub kernel_type: String,

    /// Approximation method name
    pub approximation_method: String,

    /// Number of features/components
    pub n_components: usize,

    /// Bandwidth parameter
    pub bandwidth: f64,

    /// Quality threshold
    pub quality_threshold: f64,

    /// Random state for reproducibility
    pub random_state: Option<u64>,

    /// Additional parameters
    pub additional_params: std::collections::HashMap<String, f64>,
}

/// Serializable fitted model parameters
///
/// Optional fields cover both RFF-style models (random features) and
/// Nyström-style models (selected indices, eigenvalues, eigenvectors).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializableFittedParams {
    /// Configuration used to create this model
    pub config: SerializableKernelConfig,

    /// Random features (for RFF methods)
    pub random_features: Option<Vec<Vec<f64>>>,

    /// Selected indices (for Nyström methods)
    pub selected_indices: Option<Vec<usize>>,

    /// Eigenvalues (for Nyström methods)
    pub eigenvalues: Option<Vec<f64>>,

    /// Eigenvectors (for Nyström methods)
    pub eigenvectors: Option<Vec<Vec<f64>>>,

    /// Quality metrics achieved
    pub quality_metrics: std::collections::HashMap<String, f64>,

    /// Timestamp of when model was fitted
    pub fitted_timestamp: Option<u64>,
}
1707
1708/// Trait for serializable kernel approximation methods
1709pub trait SerializableKernelApproximation {
1710    /// Export configuration to serializable format
1711    fn export_config(&self) -> Result<SerializableKernelConfig>;
1712
1713    /// Import configuration from serializable format
1714    fn import_config(config: &SerializableKernelConfig) -> Result<Self>
1715    where
1716        Self: Sized;
1717
1718    /// Export fitted parameters to serializable format
1719    fn export_fitted_params(&self) -> Result<SerializableFittedParams>;
1720
1721    /// Import fitted parameters from serializable format
1722    fn import_fitted_params(&mut self, params: &SerializableFittedParams) -> Result<()>;
1723
1724    /// Save model to file
1725    fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
1726        let config = self.export_config()?;
1727        let fitted_params = self.export_fitted_params()?;
1728
1729        let model_data = serde_json::json!({
1730            "config": config,
1731            "fitted_params": fitted_params,
1732            "version": "1.0",
1733            "sklears_version": env!("CARGO_PKG_VERSION"),
1734        });
1735
1736        std::fs::write(
1737            path,
1738            serde_json::to_string_pretty(&model_data).expect("operation should succeed"),
1739        )
1740        .map_err(SklearsError::from)?;
1741
1742        Ok(())
1743    }
1744
1745    /// Load model from file
1746    fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self>
1747    where
1748        Self: Sized,
1749    {
1750        let content = std::fs::read_to_string(path).map_err(SklearsError::from)?;
1751
1752        let model_data: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
1753            SklearsError::SerializationError(format!("Failed to parse JSON: {}", e))
1754        })?;
1755
1756        let config: SerializableKernelConfig = serde_json::from_value(model_data["config"].clone())
1757            .map_err(|e| {
1758                SklearsError::SerializationError(format!("Failed to deserialize config: {}", e))
1759            })?;
1760
1761        let fitted_params: SerializableFittedParams =
1762            serde_json::from_value(model_data["fitted_params"].clone()).map_err(|e| {
1763                SklearsError::SerializationError(format!(
1764                    "Failed to deserialize fitted params: {}",
1765                    e
1766                ))
1767            })?;
1768
1769        let mut model = Self::import_config(&config)?;
1770        model.import_fitted_params(&fitted_params)?;
1771
1772        Ok(model)
1773    }
1774}
1775
#[allow(non_snake_case)]
#[cfg(test)]
mod preset_tests {
    use super::*;

    /// All three RBF presets must sit in the optimal performance tier.
    #[test]
    fn test_kernel_presets() {
        let _fast_config = KernelPresets::fast_rbf_128();
        let _balanced_config = KernelPresets::balanced_rbf_256();
        let _accurate_config = KernelPresets::accurate_rbf_512();

        assert_eq!(
            ValidatedRBFRandomFourier::<128>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert_eq!(
            ValidatedRBFRandomFourier::<256>::performance_tier(),
            PerformanceTier::Optimal
        );
        assert_eq!(
            ValidatedRBFRandomFourier::<512>::performance_tier(),
            PerformanceTier::Optimal
        );
    }

    /// Builder-style configuration should record every requested option and
    /// recommend a feature count within the supported bounds.
    #[test]
    fn test_profile_guided_config() {
        let config = ProfileGuidedConfig::new()
            .enable_feature_selection()
            .enable_bandwidth_optimization()
            .target_architecture(TargetArchitecture::X86_64AVX2)
            .optimization_level(OptimizationLevel::Aggressive);

        assert!(config.enable_pgo_feature_selection);
        assert!(config.enable_pgo_bandwidth_optimization);
        assert_eq!(config.target_architecture, TargetArchitecture::X86_64AVX2);
        assert_eq!(config.optimization_level, OptimizationLevel::Aggressive);

        // Recommended feature count must stay inside [32, 1024].
        let n_features = config.recommended_feature_count(5000, 50);
        assert!((32..=1024).contains(&n_features));
    }

    /// Round-trip a kernel config through JSON and check key fields survive.
    #[test]
    fn test_serializable_config() {
        let original = SerializableKernelConfig {
            kernel_type: "RBF".to_string(),
            approximation_method: "RandomFourierFeatures".to_string(),
            n_components: 256,
            bandwidth: 1.5,
            quality_threshold: 0.1,
            random_state: Some(42),
            additional_params: std::collections::HashMap::new(),
        };

        // Serialize and spot-check the JSON payload.
        let json = serde_json::to_string(&original).expect("operation should succeed");
        assert!(json.contains("RBF"));
        assert!(json.contains("RandomFourierFeatures"));

        // Deserialize and verify the round-tripped fields.
        let restored: SerializableKernelConfig =
            serde_json::from_str(&json).expect("operation should succeed");
        assert_eq!(restored.kernel_type, "RBF");
        assert_eq!(restored.n_components, 256);
        assert_eq!(restored.bandwidth, 1.5);
    }
}