sklears_kernel_approximation/
kernel_framework.rs

1//! Comprehensive trait-based framework for kernel approximations
2//!
3//! This module provides a unified trait system for implementing kernel approximation
4//! methods, making it easy to create new approximation strategies and compose them.
5//!
6//! # Architecture
7//!
8//! - **KernelMethod**: Core trait for all kernel approximation methods
9//! - **SamplingStrategy**: Abstract sampling strategies (uniform, importance, etc.)
10//! - **FeatureMap**: Abstract feature transformations
11//! - **ApproximationQuality**: Quality metrics and guarantees
12//! - **ComposableKernel**: Combine multiple kernels
13
14use scirs2_core::ndarray::{Array1, Array2};
15use sklears_core::error::SklearsError;
16use std::fmt::Debug;
17
/// Core trait for kernel approximation methods.
///
/// Implementors describe a kernel approximation strategy: its name, output
/// dimensionality, computational complexity class, theoretical error
/// guarantees, and which kernel families it can approximate. The trait is
/// object-safe so methods can be stored as `Box<dyn KernelMethod>` (see
/// `CompositeKernelMethod`).
pub trait KernelMethod: Send + Sync + Debug {
    /// Get the approximation method name (short, human-readable identifier)
    fn name(&self) -> &str;

    /// Get the number of output features, or `None` when it is only known
    /// after fitting (e.g. data-dependent methods)
    fn n_output_features(&self) -> Option<usize>;

    /// Get approximation complexity (e.g., O(n*d), O(n^2))
    fn complexity(&self) -> Complexity;

    /// Get theoretical error bounds, or `None` when no analytical bound
    /// is available for this method
    fn error_bound(&self) -> Option<ErrorBound>;

    /// Check if this method supports the given kernel type
    fn supports_kernel(&self, kernel_type: KernelType) -> bool;

    /// Get all kernel types this method can approximate
    fn supported_kernels(&self) -> Vec<KernelType>;
}
38
/// Sampling strategy for selecting landmarks/components from a dataset.
///
/// Used by approximation methods that need a representative subset of
/// training rows (e.g. Nyström-style landmark selection). Strategies may be
/// stateless (uniform) or require a fitting pass (clustering-based).
pub trait SamplingStrategy: Send + Sync + Debug {
    /// Sample `n_samples` row indices from `data`.
    ///
    /// Implementations return an error when `n_samples` exceeds the number
    /// of rows in `data`.
    fn sample(&self, data: &Array2<f64>, n_samples: usize) -> Result<Vec<usize>, SklearsError>;

    /// Get the sampling strategy name
    fn name(&self) -> &str;

    /// Check if this strategy requires fitting before `sample` is called.
    /// Defaults to `false` for stateless strategies.
    fn requires_fitting(&self) -> bool {
        false
    }

    /// Fit the sampling strategy. The default implementation is a no-op
    /// for strategies that carry no fitted state.
    fn fit(&mut self, _data: &Array2<f64>) -> Result<(), SklearsError> {
        Ok(())
    }

    /// Get per-sample weights, or `None` when sampling is unweighted
    /// (the default).
    fn weights(&self) -> Option<Array1<f64>> {
        None
    }
}
62
/// Feature map transformation.
///
/// A feature map `phi` turns input rows into an explicit feature space so
/// that inner products `phi(x)·phi(y)` approximate a kernel `k(x, y)`.
pub trait FeatureMap: Send + Sync + Debug {
    /// Apply the feature map to input data, returning the transformed
    /// matrix with `output_dim()` columns
    fn transform(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError>;

    /// Get the output dimension of the transformed features
    fn output_dim(&self) -> usize;

    /// Get the feature map name
    fn name(&self) -> &str;

    /// Check if the feature map is invertible. Defaults to `false`;
    /// override together with `inverse_transform`.
    fn is_invertible(&self) -> bool {
        false
    }

    /// Inverse transform (if supported). The default implementation
    /// rejects the call, matching `is_invertible() == false`.
    fn inverse_transform(&self, _features: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
        Err(SklearsError::InvalidInput(
            "Inverse transform not supported".to_string(),
        ))
    }
}
86
/// Approximation quality metrics.
///
/// Compares an exact kernel matrix against an approximated one and reduces
/// the comparison to a single scalar score.
pub trait ApproximationQuality: Send + Sync + Debug {
    /// Compute the quality metric for the given pair of kernel matrices.
    /// Implementations should error when the shapes differ.
    fn compute(
        &self,
        exact_kernel: &Array2<f64>,
        approx_kernel: &Array2<f64>,
    ) -> Result<f64, SklearsError>;

    /// Get the metric name
    fn name(&self) -> &str;

    /// Check if higher values indicate better quality (metric polarity)
    fn higher_is_better(&self) -> bool;

    /// Get an acceptable quality threshold, or `None` when the metric has
    /// no conventional cutoff (the default).
    fn acceptable_threshold(&self) -> Option<f64> {
        None
    }
}
107
/// Computational complexity classification for approximation methods.
///
/// `n` denotes the number of samples and `d` the feature dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Complexity {
    /// O(d) - linear in feature dimension
    Linear,
    /// O(d log d) - quasi-linear
    QuasiLinear,
    /// O(n*d) - linear in samples and features
    LinearBoth,
    /// O(n*d^2) - quadratic in features
    QuadraticFeatures,
    /// O(n^2*d) - quadratic in samples
    QuadraticSamples,
    /// O(n^3) - cubic (e.g., exact methods)
    Cubic,
    /// Custom complexity, described free-form
    Custom(String),
}

impl Complexity {
    /// Return a short human-readable description of this complexity class.
    /// For `Custom`, the stored description is returned verbatim.
    pub fn description(&self) -> &str {
        use Complexity as C;
        match self {
            C::Linear => "O(d) - Linear in features",
            C::QuasiLinear => "O(d log d) - Quasi-linear",
            C::LinearBoth => "O(n*d) - Linear in samples and features",
            C::QuadraticFeatures => "O(n*d^2) - Quadratic in features",
            C::QuadraticSamples => "O(n^2*d) - Quadratic in samples",
            C::Cubic => "O(n^3) - Cubic complexity",
            C::Custom(text) => text,
        }
    }
}
141
/// Error bound information for a kernel approximation.
///
/// Pairs a numeric error value with the kind of guarantee it represents
/// (see `BoundType`).
#[derive(Debug, Clone)]
pub struct ErrorBound {
    /// Type of bound (probabilistic, deterministic, etc.)
    pub bound_type: BoundType,
    /// Error value (magnitude of the bound)
    pub error: f64,
    /// Confidence level for probabilistic bounds; `None` otherwise
    pub confidence: Option<f64>,
    /// Human-readable description of the bound and its assumptions
    pub description: String,
}
154
/// Type of error bound attached to an `ErrorBound`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundType {
    /// Probabilistic bound (holds with some probability / confidence)
    Probabilistic,
    /// Deterministic bound (always holds)
    Deterministic,
    /// Expected error bound (holds in expectation)
    Expected,
    /// Empirical bound estimated from validation data
    Empirical,
}
167
/// Kernel type classification.
///
/// Enumerates the kernel families that approximation methods may declare
/// support for via `KernelMethod::supported_kernels`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Radial Basis Function (Gaussian)
    RBF,
    /// Laplacian kernel
    Laplacian,
    /// Polynomial kernel
    Polynomial,
    /// Linear kernel
    Linear,
    /// Arc-cosine (neural network) kernel
    ArcCosine,
    /// Chi-squared kernel
    ChiSquared,
    /// String kernel
    String,
    /// Graph kernel
    Graph,
    /// Custom kernel
    Custom,
}

impl KernelType {
    /// Return the canonical display name of this kernel type.
    pub fn name(&self) -> &str {
        // Alias rather than glob-import: a glob would shadow `std::string::String`
        // with the `String` variant in this scope.
        use KernelType as K;
        match self {
            K::RBF => "RBF",
            K::Laplacian => "Laplacian",
            K::Polynomial => "Polynomial",
            K::Linear => "Linear",
            K::ArcCosine => "ArcCosine",
            K::ChiSquared => "ChiSquared",
            K::String => "String",
            K::Graph => "Graph",
            K::Custom => "Custom",
        }
    }
}
207
/// Uniform random sampling strategy: every row is equally likely to be
/// selected as a landmark.
#[derive(Debug, Clone)]
pub struct UniformSampling {
    /// Random seed; `None` falls back to a fixed default seed at sample time
    pub random_state: Option<u64>,
}

impl UniformSampling {
    /// Create a new uniform sampling strategy with an optional seed.
    pub fn new(random_state: Option<u64>) -> Self {
        UniformSampling { random_state }
    }
}
221
impl SamplingStrategy for UniformSampling {
    /// Draw `n_samples` distinct row indices uniformly at random using
    /// reservoir sampling (Algorithm R): each row ends up in the result
    /// with equal probability without materializing a full permutation.
    fn sample(&self, data: &Array2<f64>, n_samples: usize) -> Result<Vec<usize>, SklearsError> {
        use scirs2_core::random::seeded_rng;

        let (n_rows, _) = data.dim();
        if n_samples > n_rows {
            return Err(SklearsError::InvalidInput(format!(
                "Cannot sample {} points from {} samples",
                n_samples, n_rows
            )));
        }

        // NOTE(review): a missing seed falls back to the fixed value 42, so
        // `random_state == None` yields deterministic draws, not fresh
        // randomness — confirm this is the intended default.
        let mut rng = seeded_rng(self.random_state.unwrap_or(42));

        // Reservoir sampling for unbiased selection: seed the reservoir
        // with the first n_samples indices, then give each later index i a
        // n_samples/(i+1) chance of replacing a random reservoir slot.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        for i in n_samples..n_rows {
            let j = rng.gen_range(0..=i);
            if j < n_samples {
                indices[j] = i;
            }
        }

        Ok(indices)
    }

    fn name(&self) -> &str {
        "UniformSampling"
    }
}
252
/// K-means based sampling strategy.
///
/// Selects landmarks by clustering the data and picking a representative
/// row for each cluster center. More expensive than uniform sampling but
/// tends to cover the data distribution better.
#[derive(Debug, Clone)]
pub struct KMeansSampling {
    /// Number of Lloyd iterations run per `sample` call
    pub n_iterations: usize,
    /// Random seed; `None` falls back to a fixed default seed
    pub random_state: Option<u64>,
    /// Cluster centers (fitted)
    // NOTE(review): this field is never populated — `fit` is currently a
    // no-op and `sample` recomputes centers on every call. It appears to be
    // reserved for future caching; confirm before relying on it.
    centers: Option<Array2<f64>>,
}

impl KMeansSampling {
    /// Create a new k-means sampling strategy.
    ///
    /// `n_iterations` is the number of Lloyd iterations to run per sample
    /// call; `random_state` optionally seeds center initialization.
    pub fn new(n_iterations: usize, random_state: Option<u64>) -> Self {
        Self {
            n_iterations,
            random_state,
            centers: None,
        }
    }
}
274
275impl SamplingStrategy for KMeansSampling {
276    fn sample(&self, data: &Array2<f64>, n_samples: usize) -> Result<Vec<usize>, SklearsError> {
277        use scirs2_core::random::seeded_rng;
278
279        let (n_rows, n_features) = data.dim();
280        if n_samples > n_rows {
281            return Err(SklearsError::InvalidInput(format!(
282                "Cannot sample {} points from {} samples",
283                n_samples, n_rows
284            )));
285        }
286
287        let mut rng = seeded_rng(self.random_state.unwrap_or(42));
288
289        // Initialize centers randomly
290        let mut centers = Array2::zeros((n_samples, n_features));
291        let mut initial_indices: Vec<usize> = (0..n_rows).collect();
292        for i in 0..n_samples {
293            let idx = rng.gen_range(0..initial_indices.len());
294            let sample_idx = initial_indices.swap_remove(idx);
295            centers.row_mut(i).assign(&data.row(sample_idx));
296        }
297
298        // K-means iterations
299        let mut assignments = vec![0; n_rows];
300        for _ in 0..self.n_iterations {
301            // Assign points to nearest center
302            for i in 0..n_rows {
303                let point = data.row(i);
304                let mut min_dist = f64::INFINITY;
305                let mut best_cluster = 0;
306
307                for j in 0..n_samples {
308                    let center = centers.row(j);
309                    let dist: f64 = point
310                        .iter()
311                        .zip(center.iter())
312                        .map(|(a, b)| (a - b).powi(2))
313                        .sum();
314
315                    if dist < min_dist {
316                        min_dist = dist;
317                        best_cluster = j;
318                    }
319                }
320                assignments[i] = best_cluster;
321            }
322
323            // Update centers
324            let mut counts = vec![0; n_samples];
325            centers.fill(0.0);
326
327            for i in 0..n_rows {
328                let cluster = assignments[i];
329                let point = data.row(i);
330                for (j, &val) in point.iter().enumerate() {
331                    centers[[cluster, j]] += val;
332                }
333                counts[cluster] += 1;
334            }
335
336            for j in 0..n_samples {
337                if counts[j] > 0 {
338                    for k in 0..n_features {
339                        centers[[j, k]] /= counts[j] as f64;
340                    }
341                }
342            }
343        }
344
345        // Find nearest point to each center
346        let mut selected_indices = Vec::with_capacity(n_samples);
347        for center_idx in 0..n_samples {
348            let center = centers.row(center_idx);
349            let mut min_dist = f64::INFINITY;
350            let mut best_idx = 0;
351
352            for i in 0..n_rows {
353                let point = data.row(i);
354                let dist: f64 = point
355                    .iter()
356                    .zip(center.iter())
357                    .map(|(a, b)| (a - b).powi(2))
358                    .sum();
359
360                if dist < min_dist {
361                    min_dist = dist;
362                    best_idx = i;
363                }
364            }
365            selected_indices.push(best_idx);
366        }
367
368        Ok(selected_indices)
369    }
370
371    fn name(&self) -> &str {
372        "KMeansSampling"
373    }
374
375    fn requires_fitting(&self) -> bool {
376        true
377    }
378
379    fn fit(&mut self, data: &Array2<f64>) -> Result<(), SklearsError> {
380        // Store centers for potential reuse
381        let _ = data;
382        Ok(())
383    }
384}
385
386/// Kernel alignment quality metric
387#[derive(Debug, Clone)]
388pub struct KernelAlignmentMetric;
389
390impl ApproximationQuality for KernelAlignmentMetric {
391    fn compute(
392        &self,
393        exact_kernel: &Array2<f64>,
394        approx_kernel: &Array2<f64>,
395    ) -> Result<f64, SklearsError> {
396        let (n1, m1) = exact_kernel.dim();
397        let (n2, m2) = approx_kernel.dim();
398
399        if n1 != n2 || m1 != m2 {
400            return Err(SklearsError::InvalidInput(
401                "Kernel matrices must have the same shape".to_string(),
402            ));
403        }
404
405        // Compute Frobenius inner product
406        let mut inner_product = 0.0;
407        let mut exact_norm = 0.0;
408        let mut approx_norm = 0.0;
409
410        for i in 0..n1 {
411            for j in 0..m1 {
412                let exact_val = exact_kernel[[i, j]];
413                let approx_val = approx_kernel[[i, j]];
414                inner_product += exact_val * approx_val;
415                exact_norm += exact_val * exact_val;
416                approx_norm += approx_val * approx_val;
417            }
418        }
419
420        if exact_norm < 1e-10 || approx_norm < 1e-10 {
421            return Ok(0.0);
422        }
423
424        Ok(inner_product / (exact_norm.sqrt() * approx_norm.sqrt()))
425    }
426
427    fn name(&self) -> &str {
428        "KernelAlignment"
429    }
430
431    fn higher_is_better(&self) -> bool {
432        true
433    }
434
435    fn acceptable_threshold(&self) -> Option<f64> {
436        Some(0.9) // 90% alignment is typically considered good
437    }
438}
439
/// Composable kernel that combines multiple kernel methods.
///
/// Holds a list of boxed `KernelMethod`s plus a `CombinationStrategy`
/// describing how their outputs are to be combined.
#[derive(Debug)]
pub struct CompositeKernelMethod {
    /// List of kernel methods to compose (in insertion order)
    methods: Vec<Box<dyn KernelMethod>>,
    /// Combination strategy applied across the composed methods
    strategy: CombinationStrategy,
}
448
/// Strategy for combining multiple kernels in a `CompositeKernelMethod`.
#[derive(Debug, Clone, Copy)]
pub enum CombinationStrategy {
    /// Concatenate features from all kernels (output dims add up)
    Concatenate,
    /// Average kernel matrices
    Average,
    /// Weighted sum of kernel matrices
    WeightedSum,
    /// Element-wise product of kernels
    Product,
}
461
462impl CompositeKernelMethod {
463    /// Create a new composite kernel method
464    pub fn new(strategy: CombinationStrategy) -> Self {
465        Self {
466            methods: Vec::new(),
467            strategy,
468        }
469    }
470
471    /// Add a kernel method to the composition
472    pub fn add_method(&mut self, method: Box<dyn KernelMethod>) {
473        self.methods.push(method);
474    }
475
476    /// Get the combination strategy
477    pub fn strategy(&self) -> CombinationStrategy {
478        self.strategy
479    }
480
481    /// Get number of methods
482    pub fn len(&self) -> usize {
483        self.methods.len()
484    }
485
486    /// Check if empty
487    pub fn is_empty(&self) -> bool {
488        self.methods.is_empty()
489    }
490}
491
492impl KernelMethod for CompositeKernelMethod {
493    fn name(&self) -> &str {
494        "CompositeKernel"
495    }
496
497    fn n_output_features(&self) -> Option<usize> {
498        match self.strategy {
499            CombinationStrategy::Concatenate => {
500                let mut total = 0;
501                for method in &self.methods {
502                    if let Some(n) = method.n_output_features() {
503                        total += n;
504                    } else {
505                        return None;
506                    }
507                }
508                Some(total)
509            }
510            _ => {
511                // For other strategies, use the first method's output size
512                self.methods.first().and_then(|m| m.n_output_features())
513            }
514        }
515    }
516
517    fn complexity(&self) -> Complexity {
518        // Return the worst complexity among all methods
519        let mut worst = Complexity::Linear;
520        for method in &self.methods {
521            let c = method.complexity();
522            worst = match (worst, c.clone()) {
523                (Complexity::Cubic, _) | (_, Complexity::Cubic) => Complexity::Cubic,
524                (Complexity::QuadraticSamples, _) | (_, Complexity::QuadraticSamples) => {
525                    Complexity::QuadraticSamples
526                }
527                (Complexity::QuadraticFeatures, _) | (_, Complexity::QuadraticFeatures) => {
528                    Complexity::QuadraticFeatures
529                }
530                _ => c,
531            };
532        }
533        worst
534    }
535
536    fn error_bound(&self) -> Option<ErrorBound> {
537        // Combine error bounds (if available)
538        // For simplicity, return None if any method doesn't have a bound
539        None
540    }
541
542    fn supports_kernel(&self, kernel_type: KernelType) -> bool {
543        // Check if any method supports this kernel type
544        self.methods.iter().any(|m| m.supports_kernel(kernel_type))
545    }
546
547    fn supported_kernels(&self) -> Vec<KernelType> {
548        let mut kernels = Vec::new();
549        for method in &self.methods {
550            for kernel in method.supported_kernels() {
551                if !kernels.contains(&kernel) {
552                    kernels.push(kernel);
553                }
554            }
555        }
556        kernels
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Descriptions should identify the complexity class by name.
    #[test]
    fn test_complexity_description() {
        let c = Complexity::Linear;
        assert!(c.description().contains("Linear"));

        let c = Complexity::QuasiLinear;
        assert!(c.description().contains("Quasi-linear"));
    }

    #[test]
    fn test_kernel_type_name() {
        assert_eq!(KernelType::RBF.name(), "RBF");
        assert_eq!(KernelType::Polynomial.name(), "Polynomial");
    }

    // With a fixed seed the sampler must return in-range indices.
    #[test]
    fn test_uniform_sampling() {
        let strategy = UniformSampling::new(Some(42));
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let indices = strategy.sample(&data, 2).unwrap();
        assert_eq!(indices.len(), 2);
        assert!(indices[0] < 4);
        assert!(indices[1] < 4);
    }

    // Three well-separated pairs of points: k-means with k = 3 should pick
    // one landmark per pair.
    #[test]
    fn test_kmeans_sampling() {
        let strategy = KMeansSampling::new(5, Some(42));
        let data = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [5.0, 5.0],
            [5.1, 5.1],
            [9.0, 9.0],
            [9.1, 9.1]
        ];

        let indices = strategy.sample(&data, 3).unwrap();
        assert_eq!(indices.len(), 3);
    }

    // Nearly identical kernels should have alignment close to (but at
    // most) 1.0.
    #[test]
    fn test_kernel_alignment_metric() {
        let metric = KernelAlignmentMetric;
        let exact = array![[1.0, 0.5], [0.5, 1.0]];
        let approx = array![[1.0, 0.6], [0.6, 1.0]];

        let alignment = metric.compute(&exact, &approx).unwrap();
        assert!(alignment > 0.9 && alignment <= 1.0);
        assert!(metric.higher_is_better());
    }

    #[test]
    fn test_composite_kernel_method() {
        let composite = CompositeKernelMethod::new(CombinationStrategy::Concatenate);
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);
    }

    #[test]
    fn test_bound_type() {
        let bound = ErrorBound {
            bound_type: BoundType::Probabilistic,
            error: 0.1,
            confidence: Some(0.95),
            description: "Test bound".to_string(),
        };

        assert_eq!(bound.bound_type, BoundType::Probabilistic);
        assert_eq!(bound.error, 0.1);
    }
}