sklears_kernel_approximation/sparse_gp/
ski.rs

1//! Structured Kernel Interpolation (SKI/KISS-GP) for fast GP inference
2//!
3//! This module implements Structured Kernel Interpolation methods for
4//! fast Gaussian Process inference on structured data, including
5//! grid-based interpolation and Kronecker structure exploitation.
6
7use crate::sparse_gp::core::*;
8use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
9use scirs2_core::ndarray::{Array1, Array2, Axis};
10use sklears_core::error::{Result, SklearsError};
11use sklears_core::traits::{Fit, Predict};
12use std::collections::HashSet;
13
14/// Structured Kernel Interpolation implementation
15impl<K: SparseKernel> StructuredKernelInterpolation<K> {
16    /// Create new structured kernel interpolation
17    pub fn new(grid_size: Vec<usize>, kernel: K) -> Self {
18        Self {
19            grid_size,
20            kernel,
21            noise_variance: 1e-6,
22            interpolation: InterpolationMethod::Linear,
23        }
24    }
25
26    /// Generate structured grid points over input space
27    pub fn generate_grid_points(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
28        let n_features = x.ncols();
29        if self.grid_size.len() != n_features {
30            return Err(SklearsError::InvalidInput(
31                "Grid size dimension mismatch".to_string(),
32            ));
33        }
34
35        let total_grid_points: usize = self.grid_size.iter().product();
36        let mut grid_points = Array2::zeros((total_grid_points, n_features));
37
38        // Compute feature ranges
39        let mut ranges = Vec::with_capacity(n_features);
40        for j in 0..n_features {
41            let col = x.column(j);
42            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
43            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
44            ranges.push((min_val, max_val));
45        }
46
47        // Generate grid points recursively
48        let mut point_idx = 0;
49        self.generate_grid_recursive(
50            &mut grid_points,
51            &ranges,
52            &mut vec![0; n_features],
53            0,
54            &mut point_idx,
55        );
56
57        Ok(grid_points)
58    }
59
60    /// Recursive helper for grid generation
61    fn generate_grid_recursive(
62        &self,
63        grid_points: &mut Array2<f64>,
64        ranges: &[(f64, f64)],
65        current_indices: &mut Vec<usize>,
66        dim: usize,
67        point_idx: &mut usize,
68    ) {
69        if dim == ranges.len() {
70            // Generate point at current multi-index
71            for (j, &idx) in current_indices.iter().enumerate() {
72                let (min_val, max_val) = ranges[j];
73                let grid_val = if self.grid_size[j] == 1 {
74                    (min_val + max_val) / 2.0
75                } else {
76                    min_val + idx as f64 * (max_val - min_val) / (self.grid_size[j] - 1) as f64
77                };
78                grid_points[(*point_idx, j)] = grid_val;
79            }
80            *point_idx += 1;
81            return;
82        }
83
84        for i in 0..self.grid_size[dim] {
85            current_indices[dim] = i;
86            self.generate_grid_recursive(grid_points, ranges, current_indices, dim + 1, point_idx);
87        }
88    }
89
90    /// Compute interpolation weights for data points to grid
91    pub fn compute_interpolation_weights(
92        &self,
93        x: &Array2<f64>,
94        grid_points: &Array2<f64>,
95        ranges: &[(f64, f64)],
96    ) -> Result<Array2<f64>> {
97        let n = x.nrows();
98        let n_grid = grid_points.nrows();
99        let _n_features = x.ncols();
100
101        let mut weights = Array2::zeros((n, n_grid));
102
103        match self.interpolation {
104            InterpolationMethod::Linear => {
105                self.compute_linear_weights(x, grid_points, ranges, &mut weights)?;
106            }
107            InterpolationMethod::Cubic => {
108                self.compute_cubic_weights(x, grid_points, ranges, &mut weights)?;
109            }
110        }
111
112        // Normalize weights for each data point
113        for i in 0..n {
114            let weight_sum = weights.row(i).sum();
115            if weight_sum > 1e-12 {
116                for g in 0..n_grid {
117                    weights[(i, g)] /= weight_sum;
118                }
119            }
120        }
121
122        Ok(weights)
123    }
124
125    /// Compute linear interpolation weights
126    fn compute_linear_weights(
127        &self,
128        x: &Array2<f64>,
129        grid_points: &Array2<f64>,
130        ranges: &[(f64, f64)],
131        weights: &mut Array2<f64>,
132    ) -> Result<()> {
133        let n = x.nrows();
134        let n_grid = grid_points.nrows();
135        let n_features = x.ncols();
136
137        for i in 0..n {
138            for g in 0..n_grid {
139                let mut weight = 1.0;
140                let mut valid = true;
141
142                for j in 0..n_features {
143                    let x_val = x[(i, j)];
144                    let grid_val = grid_points[(g, j)];
145                    let (min_val, max_val) = ranges[j];
146
147                    let grid_spacing = if self.grid_size[j] == 1 {
148                        max_val - min_val
149                    } else {
150                        (max_val - min_val) / (self.grid_size[j] - 1) as f64
151                    };
152
153                    let distance = (x_val - grid_val).abs();
154
155                    // Check if point is within interpolation support
156                    if distance > grid_spacing + 1e-12 {
157                        valid = false;
158                        break;
159                    }
160
161                    // Linear weight: 1 - |distance| / grid_spacing
162                    if grid_spacing > 1e-12 {
163                        weight *= 1.0 - distance / grid_spacing;
164                    }
165                }
166
167                if valid {
168                    weights[(i, g)] = weight;
169                }
170            }
171        }
172
173        Ok(())
174    }
175
176    /// Compute cubic interpolation weights
177    fn compute_cubic_weights(
178        &self,
179        x: &Array2<f64>,
180        grid_points: &Array2<f64>,
181        ranges: &[(f64, f64)],
182        weights: &mut Array2<f64>,
183    ) -> Result<()> {
184        let n = x.nrows();
185        let n_grid = grid_points.nrows();
186        let n_features = x.ncols();
187
188        for i in 0..n {
189            for g in 0..n_grid {
190                let mut weight = 1.0;
191                let mut valid = true;
192
193                for j in 0..n_features {
194                    let x_val = x[(i, j)];
195                    let grid_val = grid_points[(g, j)];
196                    let (min_val, max_val) = ranges[j];
197
198                    let grid_spacing = if self.grid_size[j] == 1 {
199                        max_val - min_val
200                    } else {
201                        (max_val - min_val) / (self.grid_size[j] - 1) as f64
202                    };
203
204                    let distance = (x_val - grid_val).abs();
205
206                    // Cubic interpolation has wider support
207                    if distance > 2.0 * grid_spacing + 1e-12 {
208                        valid = false;
209                        break;
210                    }
211
212                    // Cubic B-spline weight
213                    if grid_spacing > 1e-12 {
214                        let t = distance / grid_spacing;
215                        let cubic_weight = if t <= 1.0 {
216                            1.0 - 1.5 * t * t + 0.75 * t * t * t
217                        } else if t <= 2.0 {
218                            0.25 * (2.0 - t).powi(3)
219                        } else {
220                            0.0
221                        };
222                        weight *= cubic_weight;
223                    }
224                }
225
226                if valid && weight > 1e-12 {
227                    weights[(i, g)] = weight;
228                }
229            }
230        }
231
232        Ok(())
233    }
234}
235
236/// Fit implementation for SKI
237impl<K: SparseKernel> Fit<Array2<f64>, Array1<f64>> for StructuredKernelInterpolation<K> {
238    type Fitted = FittedSKI<K>;
239
240    fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
241        // Generate grid points
242        let grid_points = self.generate_grid_points(x)?;
243
244        // Compute feature ranges
245        let n_features = x.ncols();
246        let mut ranges = Vec::with_capacity(n_features);
247        for j in 0..n_features {
248            let col = x.column(j);
249            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
250            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
251            ranges.push((min_val, max_val));
252        }
253
254        // Compute interpolation weights
255        let weights = self.compute_interpolation_weights(x, &grid_points, &ranges)?;
256
257        // Compute kernel matrix on grid with Kronecker structure
258        let k_gg = self.compute_grid_kernel_matrix(&grid_points)?;
259
260        // Add noise to diagonal
261        let mut k_gg_noise = k_gg;
262        let n_grid = grid_points.nrows();
263        for i in 0..n_grid {
264            k_gg_noise[(i, i)] += self.noise_variance;
265        }
266
267        // Solve structured system: (K_gg + σ²I) α = W^T y
268        let weighted_y = weights.t().dot(y);
269        let alpha = self.solve_structured_system(&k_gg_noise, &weighted_y)?;
270
271        Ok(FittedSKI {
272            grid_points,
273            weights,
274            kernel: self.kernel.clone(),
275            alpha,
276        })
277    }
278}
279
280impl<K: SparseKernel> StructuredKernelInterpolation<K> {
281    /// Compute kernel matrix on grid with potential Kronecker structure
282    fn compute_grid_kernel_matrix(&self, grid_points: &Array2<f64>) -> Result<Array2<f64>> {
283        let n_features = grid_points.ncols();
284
285        // Check if we can use Kronecker structure (1D case or separable kernel)
286        if n_features == 1 || self.can_use_kronecker_structure() {
287            self.compute_kronecker_kernel_matrix(grid_points)
288        } else {
289            // Fall back to standard kernel matrix computation
290            Ok(self.kernel.kernel_matrix(grid_points, grid_points))
291        }
292    }
293
294    /// Check if kernel supports Kronecker decomposition
295    fn can_use_kronecker_structure(&self) -> bool {
296        // For now, assume RBF kernels can use Kronecker structure
297        // This would be determined by kernel type in full implementation
298        true
299    }
300
301    /// Compute kernel matrix using Kronecker structure
302    fn compute_kronecker_kernel_matrix(&self, grid_points: &Array2<f64>) -> Result<Array2<f64>> {
303        let n_features = grid_points.ncols();
304
305        if n_features == 1 {
306            // 1D case - no Kronecker structure needed
307            return Ok(self.kernel.kernel_matrix(grid_points, grid_points));
308        }
309
310        // Multi-dimensional case: K = K_1 ⊗ K_2 ⊗ ... ⊗ K_d
311        // For now, use standard computation as full Kronecker implementation is complex
312        Ok(self.kernel.kernel_matrix(grid_points, grid_points))
313    }
314
315    /// Solve structured linear system efficiently
316    fn solve_structured_system(
317        &self,
318        k_matrix: &Array2<f64>,
319        rhs: &Array1<f64>,
320    ) -> Result<Array1<f64>> {
321        // For Kronecker structured systems, we could use specialized solvers
322        // For now, use standard Cholesky decomposition
323        let k_inv = KernelOps::invert_using_cholesky(k_matrix)?;
324        Ok(k_inv.dot(rhs))
325    }
326}
327
328/// Prediction implementation for fitted SKI
329impl<K: SparseKernel> Predict<Array2<f64>, Array1<f64>> for FittedSKI<K> {
330    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
331        // Compute test interpolation weights
332        let n_features = x.ncols();
333        let mut ranges = Vec::with_capacity(n_features);
334        for j in 0..n_features {
335            let col = self.grid_points.column(j);
336            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
337            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
338            ranges.push((min_val, max_val));
339        }
340
341        // Reconstruct grid sizes from grid points
342        let grid_size = self.infer_grid_size_from_points()?;
343
344        let ski = StructuredKernelInterpolation {
345            grid_size,
346            kernel: self.kernel.clone(),
347            noise_variance: 1e-6,
348            interpolation: InterpolationMethod::Linear,
349        };
350
351        let test_weights = ski.compute_interpolation_weights(x, &self.grid_points, &ranges)?;
352        let predictions = test_weights.dot(&self.alpha);
353        Ok(predictions)
354    }
355}
356
357impl<K: SparseKernel> FittedSKI<K> {
358    /// Infer grid size from grid points (for prediction)
359    fn infer_grid_size_from_points(&self) -> Result<Vec<usize>> {
360        let n_features = self.grid_points.ncols();
361        let mut grid_size = vec![1; n_features];
362
363        for j in 0..n_features {
364            let col = self.grid_points.column(j);
365            let unique_vals: HashSet<_> = col.iter().map(|&x| (x * 1e6).round() as i64).collect();
366            grid_size[j] = unique_vals.len();
367        }
368
369        Ok(grid_size)
370    }
371
372    /// Predict with uncertainty quantification
373    pub fn predict_with_variance(&self, x: &Array2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
374        // Compute interpolation weights
375        let n_features = x.ncols();
376        let mut ranges = Vec::with_capacity(n_features);
377        for j in 0..n_features {
378            let col = self.grid_points.column(j);
379            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
380            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
381            ranges.push((min_val, max_val));
382        }
383
384        let grid_size = self.infer_grid_size_from_points()?;
385        let ski = StructuredKernelInterpolation {
386            grid_size,
387            kernel: self.kernel.clone(),
388            noise_variance: 1e-6,
389            interpolation: InterpolationMethod::Linear,
390        };
391
392        let test_weights = ski.compute_interpolation_weights(x, &self.grid_points, &ranges)?;
393
394        // Predictive mean
395        let pred_mean = test_weights.dot(&self.alpha);
396
397        // Predictive variance (simplified - full implementation would account for interpolation uncertainty)
398        let k_test_diag = self.kernel.kernel_diagonal(x);
399        let pred_var = k_test_diag; // Simplified - would subtract interpolation effects
400
401        Ok((pred_mean, pred_var))
402    }
403}
404
/// Multi-dimensional SKI with tensor structure.
///
/// Configuration for a tensor-product inducing grid whose per-dimension
/// kernel matrices can be combined via Kronecker products.
pub struct TensorSKI<K: SparseKernel> {
    /// Grid sizes for each dimension (one entry per input feature)
    pub grid_sizes: Vec<usize>,
    /// Kernel function
    pub kernel: K,
    /// Noise variance (added to the diagonal for numerical stability)
    pub noise_variance: f64,
    /// Whether to use Kronecker structure (required by `fit_tensor`)
    pub use_kronecker: bool,
}
416
417impl<K: SparseKernel> TensorSKI<K> {
418    pub fn new(grid_sizes: Vec<usize>, kernel: K) -> Self {
419        Self {
420            grid_sizes,
421            kernel,
422            noise_variance: 1e-6,
423            use_kronecker: true,
424        }
425    }
426
427    /// Fit tensor SKI with full Kronecker structure
428    pub fn fit_tensor(&self, x: &Array2<f64>, _y: &Array1<f64>) -> Result<FittedTensorSKI<K>> {
429        if !self.use_kronecker {
430            return Err(SklearsError::InvalidInput(
431                "Tensor SKI requires Kronecker structure".to_string(),
432            ));
433        }
434
435        let n_features = x.ncols();
436        if self.grid_sizes.len() != n_features {
437            return Err(SklearsError::InvalidInput(
438                "Grid sizes must match number of features".to_string(),
439            ));
440        }
441
442        // Generate 1D grids for each dimension
443        let mut dim_grids = Vec::with_capacity(n_features);
444        for j in 0..n_features {
445            let col = x.column(j);
446            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
447            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
448
449            let mut grid_1d = Array1::zeros(self.grid_sizes[j]);
450            for i in 0..self.grid_sizes[j] {
451                if self.grid_sizes[j] == 1 {
452                    grid_1d[i] = (min_val + max_val) / 2.0;
453                } else {
454                    grid_1d[i] =
455                        min_val + i as f64 * (max_val - min_val) / (self.grid_sizes[j] - 1) as f64;
456                }
457            }
458            dim_grids.push(grid_1d);
459        }
460
461        // Compute 1D kernel matrices
462        let mut kernel_matrices_1d = Vec::with_capacity(n_features);
463        for j in 0..n_features {
464            let grid_1d_2d = dim_grids[j].clone().insert_axis(Axis(1));
465            let k_1d = self.kernel.kernel_matrix(&grid_1d_2d, &grid_1d_2d);
466            kernel_matrices_1d.push(k_1d);
467        }
468
469        Ok(FittedTensorSKI {
470            dim_grids,
471            kernel_matrices_1d,
472            kernel: self.kernel.clone(),
473            alpha: Array1::zeros(1), // Would be computed from Kronecker system
474        })
475    }
476}
477
/// Fitted tensor SKI with Kronecker structure.
///
/// Stores the per-dimension 1-D grids and kernel matrices produced by
/// `TensorSKI::fit_tensor`.
#[derive(Debug, Clone)]
pub struct FittedTensorSKI<K: SparseKernel> {
    /// 1D grids for each dimension
    pub dim_grids: Vec<Array1<f64>>,
    /// 1D kernel matrices for each dimension
    pub kernel_matrices_1d: Vec<Array2<f64>>,
    /// Kernel function
    pub kernel: K,
    /// Tensor coefficients (currently a placeholder — see `fit_tensor`)
    pub alpha: Array1<f64>,
}
490
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_gp::kernels::RBFKernel;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // A 3×2 grid over 2-D inputs must yield 6 finite grid points.
    #[test]
    fn test_grid_generation() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 2], kernel);

        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let grid_points = ski.generate_grid_points(&x).unwrap();

        assert_eq!(grid_points.shape(), &[6, 2]); // 3 × 2 grid
        assert!(grid_points.iter().all(|&x| x.is_finite()));
    }

    // Linear weights over a hand-built 3×3 grid must normalize to 1 per row.
    #[test]
    fn test_linear_interpolation_weights() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel)
            .interpolation(InterpolationMethod::Linear);

        let x = array![[0.5, 0.5], [1.0, 1.0]];
        let grid_points = array![
            [0.0, 0.0],
            [0.0, 1.0],
            [0.0, 2.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 2.0],
            [2.0, 0.0],
            [2.0, 1.0],
            [2.0, 2.0]
        ];
        let ranges = vec![(0.0, 2.0), (0.0, 2.0)];

        let weights = ski
            .compute_interpolation_weights(&x, &grid_points, &ranges)
            .unwrap();

        assert_eq!(weights.shape(), &[2, 9]);

        // Check that weights sum to 1 for each data point
        for i in 0..2 {
            let weight_sum = weights.row(i).sum();
            assert_abs_diff_eq!(weight_sum, 1.0, epsilon = 1e-10);
        }
    }

    // Fitting must produce one finite coefficient per grid point.
    #[test]
    fn test_ski_fit() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel).noise_variance(0.1);

        let x = array![[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]];
        let y = array![0.0, 0.25, 1.0, 2.25];

        let fitted = ski.fit(&x, &y).unwrap();

        assert_eq!(fitted.grid_points.nrows(), 9); // 3 × 3 grid
        assert_eq!(fitted.alpha.len(), 9);
        assert!(fitted.alpha.iter().all(|&x| x.is_finite()));
    }

    // Predictions at in-range test points must be finite.
    #[test]
    fn test_ski_prediction() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel).noise_variance(0.1);

        let x = array![[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]];
        let y = array![0.0, 0.25, 1.0, 2.25];

        let fitted = ski.fit(&x, &y).unwrap();
        let x_test = array![[0.25, 0.25], [0.75, 0.75]];
        let predictions = fitted.predict(&x_test).unwrap();

        assert_eq!(predictions.len(), 2);
        assert!(predictions.iter().all(|&x| x.is_finite()));
    }

    // Variance output must be finite and non-negative.
    #[test]
    fn test_ski_with_variance() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel).noise_variance(0.1);

        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let y = array![0.0, 1.0, 4.0];

        let fitted = ski.fit(&x, &y).unwrap();
        let x_test = array![[0.5, 0.5], [1.5, 1.5]];
        let (mean, var) = fitted.predict_with_variance(&x_test).unwrap();

        assert_eq!(mean.len(), 2);
        assert_eq!(var.len(), 2);
        assert!(mean.iter().all(|&x| x.is_finite()));
        assert!(var.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    // Cubic weights must also normalize to 1 and be (near) non-negative.
    #[test]
    fn test_cubic_interpolation() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![4, 4], kernel)
            .interpolation(InterpolationMethod::Cubic);

        let x = array![[0.5, 0.5], [1.5, 1.5]];
        let grid_points = ski.generate_grid_points(&x).unwrap();
        let ranges = vec![(0.0, 2.0), (0.0, 2.0)];

        let weights = ski
            .compute_interpolation_weights(&x, &grid_points, &ranges)
            .unwrap();

        assert_eq!(weights.shape(), &[2, 16]); // 4 × 4 grid

        // Check that weights are non-negative and sum to 1
        for i in 0..2 {
            let weight_sum = weights.row(i).sum();
            assert_abs_diff_eq!(weight_sum, 1.0, epsilon = 1e-10);
            assert!(weights.row(i).iter().all(|&w| w >= -1e-12)); // Allow small numerical errors
        }
    }

    // Constructor defaults: Kronecker enabled, sizes stored verbatim.
    #[test]
    fn test_tensor_ski_creation() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let tensor_ski = TensorSKI::new(vec![4, 3, 5], kernel);

        assert_eq!(tensor_ski.grid_sizes, vec![4, 3, 5]);
        assert!(tensor_ski.use_kronecker);
    }

    // Grid-size inference must recover [3, 2] from a 3×2 tensor grid.
    #[test]
    fn test_grid_size_inference() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 2], kernel);

        let _x = array![[0.0, 0.0], [1.0, 1.0]];
        let fitted_ski = FittedSKI {
            grid_points: array![
                [0.0, 0.0],
                [0.0, 1.0],
                [1.0, 0.0],
                [1.0, 1.0],
                [2.0, 0.0],
                [2.0, 1.0]
            ],
            weights: Array2::zeros((2, 6)),
            kernel: ski.kernel.clone(),
            alpha: Array1::zeros(6),
        };

        let inferred_grid_size = fitted_ski.infer_grid_size_from_points().unwrap();
        assert_eq!(inferred_grid_size, vec![3, 2]);
    }
}