sklears_kernel_approximation/sparse_gp/core.rs

//! Core types, enums, and structures for sparse Gaussian Process implementation
//!
//! This module provides the foundational data structures and type definitions
//! for sparse Gaussian Process approximations with SIMD acceleration.

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use std::fmt;

/// Available sparse approximation methods for Gaussian Processes
#[derive(Debug, Clone)]
pub enum SparseApproximation {
    /// Subset of Regressors (SoR) - uses a subset of the training points as inducing points
    SubsetOfRegressors,

    /// Fully Independent Conditional (FIC) - assumes independence given inducing points
    FullyIndependentConditional,

    /// Partially Independent Conditional (PIC) - block-diagonal conditional independence
    PartiallyIndependentConditional {
        /// Block size for PIC approximation
        block_size: usize,
    },

    /// Variational Free Energy (VFE) - variational sparse approximation
    VariationalFreeEnergy {
        /// Use whitened representation
        whitened: bool,
        /// Use natural gradients for optimization
        natural_gradients: bool,
    },
}
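
// All four methods exploit the Nystrom-style low-rank structure
// Q_nn = K_nm K_mm^{-1} K_mn (m inducing points, n training points), which
// drops training cost from O(n^3) to O(n m^2); they differ in how the
// residual K_nn - Q_nn is treated (SoR discards it, FIC keeps its diagonal,
// PIC keeps block-diagonal parts, VFE folds it into a variational bound).
// A small illustrative helper (a sketch, not an API used elsewhere):
impl SparseApproximation {
    /// Whether this method optimizes a variational objective (an ELBO)
    /// rather than a plug-in approximate marginal likelihood.
    pub fn is_variational(&self) -> bool {
        matches!(self, SparseApproximation::VariationalFreeEnergy { .. })
    }
}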

/// Strategies for selecting inducing points in sparse GP approximations
#[derive(Debug, Clone)]
pub enum InducingPointStrategy {
    /// Random selection from training data
    Random,

    /// K-means clustering to find representative points
    KMeans,

    /// Uniform grid over input space
    UniformGrid {
        /// Grid size for each dimension
        grid_size: Vec<usize>,
    },

    /// Greedy selection based on maximum posterior variance
    GreedyVariance,

    /// User-specified inducing points
    UserSpecified(Array2<f64>),
}
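
// Two of these strategies pin down the number of inducing points ahead of
// time; the rest take it from `num_inducing` on the estimator. A minimal
// illustrative helper capturing that distinction (a sketch, consistent with
// `utils::validate_inducing_points` below):
impl InducingPointStrategy {
    /// Number of inducing points the strategy itself determines, if any.
    pub fn fixed_num_points(&self) -> Option<usize> {
        match self {
            InducingPointStrategy::UniformGrid { grid_size } => {
                Some(grid_size.iter().product())
            }
            InducingPointStrategy::UserSpecified(points) => Some(points.nrows()),
            _ => None,
        }
    }
}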

/// Scalable inference methods for large-scale sparse GP prediction
#[derive(Debug, Clone)]
pub enum ScalableInferenceMethod {
    /// Direct matrix inversion (for small problems)
    Direct,

    /// Preconditioned Conjugate Gradient solver
    PreconditionedCG {
        /// Maximum number of iterations
        max_iter: usize,
        /// Convergence tolerance
        tol: f64,
        /// Preconditioner type
        preconditioner: PreconditionerType,
    },

    /// Lanczos eigendecomposition method
    Lanczos {
        /// Number of Lanczos vectors to compute
        num_vectors: usize,
        /// Tolerance for convergence
        tol: f64,
    },
}

/// Preconditioner types for iterative solvers
#[derive(Debug, Clone)]
pub enum PreconditionerType {
    /// No preconditioning
    None,

    /// Diagonal (Jacobi) preconditioning, M = diag(A)^(-1)
    Diagonal,

    /// Incomplete Cholesky factorization
    IncompleteCholesky {
        /// Fill factor for sparsity control
        fill_factor: f64,
    },

    /// Symmetric Successive Over-Relaxation (SSOR)
    SSOR {
        /// Relaxation parameter
        omega: f64,
    },
}
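
// For large kernel systems the usual trade-off is `Direct` for small n,
// `PreconditionedCG` when a decent preconditioner is available (each
// iteration costs one matrix-vector product), and `Lanczos` when low-rank
// spectral information is also wanted. A configuration sketch (the values
// are illustrative, not tuned recommendations):
#[cfg(test)]
mod inference_method_example {
    use super::*;

    #[test]
    fn pcg_with_jacobi_preconditioner() {
        let method = ScalableInferenceMethod::PreconditionedCG {
            max_iter: 1000,
            tol: 1e-8,
            preconditioner: PreconditionerType::Diagonal,
        };
        assert!(matches!(
            method,
            ScalableInferenceMethod::PreconditionedCG { .. }
        ));
    }
}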

/// Interpolation methods for structured kernel approximations
#[derive(Debug, Clone)]
pub enum InterpolationMethod {
    /// Linear interpolation
    Linear,
    /// Cubic interpolation
    Cubic,
}

/// Core sparse Gaussian Process structure with configuration parameters
#[derive(Debug, Clone)]
pub struct SparseGaussianProcess<K> {
    /// Number of inducing points
    pub num_inducing: usize,

    /// Kernel function
    pub kernel: K,

    /// Sparse approximation method
    pub approximation: SparseApproximation,

    /// Strategy for selecting inducing points
    pub inducing_strategy: InducingPointStrategy,

    /// Observation noise variance
    pub noise_variance: f64,

    /// Maximum optimization iterations
    pub max_iter: usize,

    /// Convergence tolerance
    pub tol: f64,
}

/// Fitted sparse Gaussian Process with learned parameters
#[derive(Debug, Clone)]
pub struct FittedSparseGP<K> {
    /// Inducing point locations
    pub inducing_points: Array2<f64>,

    /// Kernel function with learned parameters
    pub kernel: K,

    /// Sparse approximation method used
    pub approximation: SparseApproximation,

    /// Precomputed alpha coefficients
    pub alpha: Array1<f64>,

    /// Inverse of K_mm (inducing point kernel matrix)
    pub k_mm_inv: Array2<f64>,

    /// Noise variance
    pub noise_variance: f64,

    /// Variational parameters (if using VFE)
    pub variational_params: Option<VariationalParams>,
}
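
// With this representation the predictive mean at test inputs is a single
// matrix-vector product, mean(X*) = K_*m * alpha, where K_*m is the
// cross-kernel between test points and the inducing points. A minimal
// sketch (it assumes the caller has already evaluated K_*m, since the
// kernel type `K` is opaque at this level):
impl<K> FittedSparseGP<K> {
    /// Predictive mean from a precomputed (n_test, num_inducing)
    /// test-to-inducing cross-kernel matrix.
    pub fn mean_from_cross_kernel(&self, k_star_m: &Array2<f64>) -> Array1<f64> {
        k_star_m.dot(&self.alpha)
    }
}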

/// Variational parameters for Variational Free Energy approximation
#[derive(Debug, Clone)]
pub struct VariationalParams {
    /// Variational mean parameter
    pub mean: Array1<f64>,

    /// Cholesky factor of variational covariance
    pub cov_factor: Array2<f64>,

    /// Evidence Lower Bound (ELBO) value
    pub elbo: f64,

    /// KL divergence term
    pub kl_divergence: f64,

    /// Expected log-likelihood term
    pub log_likelihood: f64,
}
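
// For VFE these fields satisfy ELBO = E_q[log p(y | f)] - KL(q(u) || p(u)),
// so `elbo` should equal `log_likelihood - kl_divergence` up to floating-point
// error. A small consistency check expressing that invariant (a sketch, not
// a method used elsewhere in the crate):
impl VariationalParams {
    /// True when the stored ELBO matches its two terms within `atol`.
    pub fn elbo_is_consistent(&self, atol: f64) -> bool {
        (self.elbo - (self.log_likelihood - self.kl_divergence)).abs() <= atol
    }
}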

/// Structured Kernel Interpolation (KISS-GP) for fast structured GP inference
#[derive(Debug, Clone)]
pub struct StructuredKernelInterpolation<K> {
    /// Grid size for each dimension
    pub grid_size: Vec<usize>,

    /// Kernel function
    pub kernel: K,

    /// Noise variance
    pub noise_variance: f64,

    /// Interpolation method
    pub interpolation: InterpolationMethod,
}

/// Fitted structured kernel interpolation
#[derive(Debug, Clone)]
pub struct FittedSKI<K> {
    /// Grid points
    pub grid_points: Array2<f64>,

    /// Interpolation weights
    pub weights: Array2<f64>,

    /// Kernel function
    pub kernel: K,

    /// Precomputed alpha
    pub alpha: Array1<f64>,
}

/// Configuration for sparse GP optimization
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Maximum number of iterations
    pub max_iter: usize,

    /// Convergence tolerance
    pub tolerance: f64,

    /// Learning rate for gradient-based methods
    pub learning_rate: f64,

    /// Whether to use natural gradients
    pub natural_gradients: bool,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            max_iter: 100,
            tolerance: 1e-6,
            learning_rate: 0.01,
            natural_gradients: false,
        }
    }
}
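
// Since `OptimizationConfig` implements `Default`, callers can override a
// single field with struct-update syntax:
#[cfg(test)]
mod optimization_config_example {
    use super::*;

    #[test]
    fn override_one_default() {
        let config = OptimizationConfig {
            max_iter: 500,
            ..Default::default()
        };
        assert_eq!(config.max_iter, 500);
        assert_eq!(config.tolerance, 1e-6);
    }
}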

/// Error types specific to sparse GP operations
#[derive(Debug)]
pub enum SparseGPError {
    /// Invalid inducing point configuration
    InvalidInducingPoints(String),

    /// Numerical instability in computation
    NumericalInstability(String),

    /// Convergence failure
    ConvergenceFailure(String),

    /// Invalid approximation parameters
    InvalidApproximation(String),
}

impl fmt::Display for SparseGPError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SparseGPError::InvalidInducingPoints(msg) => {
                write!(f, "Invalid inducing points: {}", msg)
            }
            SparseGPError::NumericalInstability(msg) => {
                write!(f, "Numerical instability: {}", msg)
            }
            SparseGPError::ConvergenceFailure(msg) => {
                write!(f, "Convergence failure: {}", msg)
            }
            SparseGPError::InvalidApproximation(msg) => {
                write!(f, "Invalid approximation: {}", msg)
            }
        }
    }
}

impl std::error::Error for SparseGPError {}

/// Convert SparseGPError to SklearsError
impl From<SparseGPError> for SklearsError {
    fn from(err: SparseGPError) -> Self {
        match err {
            SparseGPError::InvalidInducingPoints(msg) => SklearsError::InvalidInput(msg),
            SparseGPError::NumericalInstability(msg) => SklearsError::NumericalError(msg),
            SparseGPError::ConvergenceFailure(msg) => SklearsError::NumericalError(msg),
            SparseGPError::InvalidApproximation(msg) => SklearsError::InvalidInput(msg),
        }
    }
}
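
// Thanks to this `From` impl, a `SparseGPError` propagates through the
// `Result` type imported above via the `?` operator. A minimal sketch:
#[cfg(test)]
mod error_conversion_example {
    use super::*;

    fn inner() -> std::result::Result<(), SparseGPError> {
        Err(SparseGPError::ConvergenceFailure(
            "ELBO did not improve".to_string(),
        ))
    }

    fn outer() -> Result<()> {
        // `?` converts SparseGPError into SklearsError automatically.
        inner()?;
        Ok(())
    }

    #[test]
    fn question_mark_converts_the_error() {
        assert!(outer().is_err());
    }
}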

/// Builder-style methods for sparse GP configuration
impl<K> SparseGaussianProcess<K> {
    /// Set the sparse approximation method
    pub fn approximation(mut self, approximation: SparseApproximation) -> Self {
        self.approximation = approximation;
        self
    }

    /// Set the inducing point selection strategy
    pub fn inducing_strategy(mut self, strategy: InducingPointStrategy) -> Self {
        self.inducing_strategy = strategy;
        self
    }

    /// Set the observation noise variance
    pub fn noise_variance(mut self, noise_variance: f64) -> Self {
        self.noise_variance = noise_variance;
        self
    }

    /// Set optimization parameters
    pub fn optimization_params(mut self, max_iter: usize, tol: f64) -> Self {
        self.max_iter = max_iter;
        self.tol = tol;
        self
    }
}
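
// The builder methods compose left to right. Every field is public, so a
// value can be constructed literally first; `()` stands in for a concrete
// kernel type purely for illustration:
#[cfg(test)]
mod sparse_gp_builder_example {
    use super::*;

    #[test]
    fn builder_chain() {
        let gp = SparseGaussianProcess {
            num_inducing: 50,
            kernel: (),
            approximation: SparseApproximation::SubsetOfRegressors,
            inducing_strategy: InducingPointStrategy::Random,
            noise_variance: 1e-2,
            max_iter: 100,
            tol: 1e-6,
        }
        .approximation(SparseApproximation::FullyIndependentConditional)
        .noise_variance(1e-3)
        .optimization_params(200, 1e-8);

        assert_eq!(gp.max_iter, 200);
        assert!(matches!(
            gp.approximation,
            SparseApproximation::FullyIndependentConditional
        ));
    }
}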

/// Builder-style methods for SKI configuration
impl<K> StructuredKernelInterpolation<K> {
    /// Set noise variance
    pub fn noise_variance(mut self, noise_variance: f64) -> Self {
        self.noise_variance = noise_variance;
        self
    }

    /// Set interpolation method
    pub fn interpolation(mut self, interpolation: InterpolationMethod) -> Self {
        self.interpolation = interpolation;
        self
    }
}
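
// SKI configuration follows the same builder pattern; `()` is again a
// placeholder kernel type for illustration only:
#[cfg(test)]
mod ski_builder_example {
    use super::*;

    #[test]
    fn ski_builder_chain() {
        let ski = StructuredKernelInterpolation {
            grid_size: vec![32, 32],
            kernel: (),
            noise_variance: 1e-2,
            interpolation: InterpolationMethod::Linear,
        }
        .interpolation(InterpolationMethod::Cubic)
        .noise_variance(1e-3);

        assert!(matches!(ski.interpolation, InterpolationMethod::Cubic));
        assert_eq!(ski.noise_variance, 1e-3);
    }
}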

/// Helper functions for sparse GP operations
pub mod utils {
    use super::*;

    /// Validate inducing point configuration
    pub fn validate_inducing_points(
        num_inducing: usize,
        n_features: usize,
        strategy: &InducingPointStrategy,
    ) -> Result<()> {
        match strategy {
            InducingPointStrategy::UniformGrid { grid_size } => {
                if grid_size.len() != n_features {
                    return Err(SklearsError::InvalidInput(
                        "Grid size must match number of features".to_string(),
                    ));
                }

                let total_points: usize = grid_size.iter().product();
                if total_points != num_inducing {
                    return Err(SklearsError::InvalidInput(format!(
                        "Grid size product {} must equal num_inducing {}",
                        total_points, num_inducing
                    )));
                }
            }
            InducingPointStrategy::UserSpecified(points) => {
                if points.nrows() != num_inducing {
                    return Err(SklearsError::InvalidInput(
                        "User-specified points must match num_inducing".to_string(),
                    ));
                }
                if points.ncols() != n_features {
                    return Err(SklearsError::InvalidInput(
                        "User-specified points must match number of features".to_string(),
                    ));
                }
            }
            _ => {} // Other strategies are validated during execution
        }

        Ok(())
    }

    /// Check for numerical stability in matrices
    pub fn check_matrix_stability(matrix: &Array2<f64>, name: &str) -> Result<()> {
        let has_nan = matrix.iter().any(|&x| x.is_nan());
        let has_inf = matrix.iter().any(|&x| x.is_infinite());

        if has_nan || has_inf {
            return Err(SklearsError::NumericalError(format!(
                "Matrix {} contains NaN or infinite values",
                name
            )));
        }

        Ok(())
    }

    /// Compute a cheap estimate of the matrix condition number
    pub fn estimate_condition_number(matrix: &Array2<f64>) -> f64 {
        // Crude surrogate based on diagonal dominance: the ratio of the total
        // absolute entry sum to the absolute diagonal sum. It equals 1.0 for
        // diagonal matrices and grows with off-diagonal mass; it is not the
        // true spectral condition number.
        let diag_sum: f64 = matrix.diag().iter().map(|x| x.abs()).sum();
        let off_diag_sum: f64 = matrix.iter().map(|x| x.abs()).sum::<f64>() - diag_sum;

        if diag_sum > 0.0 {
            (diag_sum + off_diag_sum) / diag_sum
        } else {
            f64::INFINITY
        }
    }
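
    // Illustrative checks for the helpers above; the expected values follow
    // directly from the definitions (the identity matrix has no off-diagonal
    // mass, so its estimate is exactly 1.0):
    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn grid_strategy_validates_shape_and_count() {
            let strategy = InducingPointStrategy::UniformGrid {
                grid_size: vec![5, 5],
            };
            assert!(validate_inducing_points(25, 2, &strategy).is_ok());
            assert!(validate_inducing_points(24, 2, &strategy).is_err());
        }

        #[test]
        fn identity_matrix_is_stable_and_well_conditioned() {
            let eye = Array2::<f64>::eye(3);
            assert!(check_matrix_stability(&eye, "eye").is_ok());
            assert!((estimate_condition_number(&eye) - 1.0).abs() < 1e-12);
        }
    }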
}