// sklears_kernel_approximation/sparse_gp/core.rs

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use std::fmt;

/// Sparse approximation scheme used to condense the full GP onto a set of
/// inducing points.
#[derive(Debug, Clone)]
pub enum SparseApproximation {
    /// Subset of Regressors (SoR) approximation.
    SubsetOfRegressors,

    /// Fully Independent Training Conditional (FITC) approximation.
    FullyIndependentConditional,

    /// Partially Independent Training Conditional (PITC) approximation.
    PartiallyIndependentConditional {
        /// Size of each conditionally independent block.
        block_size: usize,
    },

    /// Variational free energy (VFE) bound over the inducing variables.
    VariationalFreeEnergy {
        /// Use a whitened parameterization of the inducing variables.
        whitened: bool,
        /// Optimize the variational distribution with natural gradients.
        natural_gradients: bool,
    },
}

/// Strategy for selecting the inducing-point locations.
#[derive(Debug, Clone)]
pub enum InducingPointStrategy {
    /// Random selection of inducing points.
    Random,

    /// K-means cluster centers of the training inputs.
    KMeans,

    /// Regular grid over the input domain.
    UniformGrid {
        /// Number of grid points along each feature dimension.
        grid_size: Vec<usize>,
    },

    /// Greedy selection based on a variance criterion.
    GreedyVariance,

    /// Explicit user-provided inducing-point matrix (rows are points).
    UserSpecified(Array2<f64>),
}

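/// Method used for the linear solves behind scalable inference.
///
/// # Example
///
/// A small configuration sketch (marked `ignore` so it is not run as a
/// doc-test; the numeric settings are illustrative, not defaults):
///
/// ```ignore
/// let inference = ScalableInferenceMethod::PreconditionedCG {
///     max_iter: 1000,
///     tol: 1e-6,
///     preconditioner: PreconditionerType::Diagonal,
/// };
/// ```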
#[derive(Debug, Clone)]
pub enum ScalableInferenceMethod {
    /// Direct (dense) solve.
    Direct,

    /// Preconditioned conjugate gradients for large systems.
    PreconditionedCG {
        /// Maximum number of CG iterations.
        max_iter: usize,
        /// Convergence tolerance.
        tol: f64,
        /// Preconditioner applied to the system.
        preconditioner: PreconditionerType,
    },

    /// Lanczos iteration with a fixed number of retained vectors.
    Lanczos {
        /// Number of Lanczos vectors to retain.
        num_vectors: usize,
        /// Convergence tolerance.
        tol: f64,
    },
}

/// Preconditioner used by iterative solvers such as `PreconditionedCG`.
#[derive(Debug, Clone)]
pub enum PreconditionerType {
    /// No preconditioning.
    None,

    /// Jacobi (diagonal) preconditioner.
    Diagonal,

    /// Incomplete Cholesky factorization.
    IncompleteCholesky {
        /// Fill factor controlling the amount of additional fill-in.
        fill_factor: f64,
    },

    /// Symmetric successive over-relaxation.
    SSOR {
        /// Relaxation factor `omega`.
        omega: f64,
    },
}

/// Interpolation scheme used by structured kernel interpolation (SKI).
#[derive(Debug, Clone)]
pub enum InterpolationMethod {
    /// Piecewise-linear interpolation.
    Linear,
    /// Cubic interpolation.
    Cubic,
}

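/// Configuration for a sparse Gaussian process model, generic over the kernel
/// type `K`.
///
/// # Example
///
/// A minimal construction sketch (marked `ignore` so it is not compiled as a
/// doc-test; the kernel is left as a caller-supplied generic because kernel
/// types live elsewhere in the crate, and the numeric values are illustrative
/// only):
///
/// ```ignore
/// fn configure<K>(kernel: K) -> SparseGaussianProcess<K> {
///     SparseGaussianProcess {
///         num_inducing: 50,
///         kernel,
///         approximation: SparseApproximation::FullyIndependentConditional,
///         inducing_strategy: InducingPointStrategy::KMeans,
///         noise_variance: 1e-2,
///         max_iter: 100,
///         tol: 1e-6,
///     }
///     // Builder methods below tweak individual fields after construction.
///     .noise_variance(1e-3)
///     .optimization_params(200, 1e-8)
/// }
/// ```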
#[derive(Debug, Clone)]
pub struct SparseGaussianProcess<K> {
    /// Number of inducing points.
    pub num_inducing: usize,

    /// Covariance (kernel) function.
    pub kernel: K,

    /// Sparse approximation scheme.
    pub approximation: SparseApproximation,

    /// Strategy for placing the inducing points.
    pub inducing_strategy: InducingPointStrategy,

    /// Observation noise variance.
    pub noise_variance: f64,

    /// Maximum number of optimization iterations.
    pub max_iter: usize,

    /// Convergence tolerance for optimization.
    pub tol: f64,
}

/// A sparse Gaussian process after fitting; holds the quantities needed for
/// prediction.
#[derive(Debug, Clone)]
pub struct FittedSparseGP<K> {
    /// Inducing-point locations (rows are points).
    pub inducing_points: Array2<f64>,

    /// Covariance (kernel) function.
    pub kernel: K,

    /// Sparse approximation scheme used during fitting.
    pub approximation: SparseApproximation,

    /// Precomputed prediction weights.
    pub alpha: Array1<f64>,

    /// Inverse of the inducing-point Gram matrix `K_mm`.
    pub k_mm_inv: Array2<f64>,

    /// Observation noise variance.
    pub noise_variance: f64,

    /// Variational parameters (present for variational approximations).
    pub variational_params: Option<VariationalParams>,
}

/// Parameters of the variational posterior over the inducing variables.
#[derive(Debug, Clone)]
pub struct VariationalParams {
    /// Variational mean.
    pub mean: Array1<f64>,

    /// Factor of the variational covariance matrix.
    pub cov_factor: Array2<f64>,

    /// Evidence lower bound (ELBO) value.
    pub elbo: f64,

    /// KL divergence component of the bound.
    pub kl_divergence: f64,

    /// Log-likelihood component of the bound.
    pub log_likelihood: f64,
}

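/// Configuration for structured kernel interpolation (SKI), generic over the
/// kernel type `K`.
///
/// # Example
///
/// A construction sketch (marked `ignore`; the kernel is caller-supplied and
/// the grid sizes are illustrative):
///
/// ```ignore
/// fn configure_ski<K>(kernel: K) -> StructuredKernelInterpolation<K> {
///     StructuredKernelInterpolation {
///         grid_size: vec![32, 32],
///         kernel,
///         noise_variance: 1e-2,
///         interpolation: InterpolationMethod::Linear,
///     }
///     // Switch interpolation scheme via the builder method.
///     .interpolation(InterpolationMethod::Cubic)
/// }
/// ```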
#[derive(Debug, Clone)]
pub struct StructuredKernelInterpolation<K> {
    /// Number of grid points along each feature dimension.
    pub grid_size: Vec<usize>,

    /// Covariance (kernel) function.
    pub kernel: K,

    /// Observation noise variance.
    pub noise_variance: f64,

    /// Interpolation scheme between data points and grid points.
    pub interpolation: InterpolationMethod,
}

/// A fitted structured kernel interpolation (SKI) model.
#[derive(Debug, Clone)]
pub struct FittedSKI<K> {
    /// Grid-point locations (rows are points).
    pub grid_points: Array2<f64>,

    /// Interpolation weights mapping data points onto the grid.
    pub weights: Array2<f64>,

    /// Covariance (kernel) function.
    pub kernel: K,

    /// Precomputed prediction weights.
    pub alpha: Array1<f64>,
}

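/// Settings controlling iterative, gradient-based optimization.
///
/// # Example
///
/// The struct implements `Default`, so individual fields can be overridden
/// with struct-update syntax (marked `ignore`; the override value is
/// illustrative):
///
/// ```ignore
/// let config = OptimizationConfig {
///     learning_rate: 0.1,
///     ..Default::default()
/// };
/// assert_eq!(config.max_iter, 100);
/// ```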
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Maximum number of optimization iterations.
    pub max_iter: usize,

    /// Convergence tolerance.
    pub tolerance: f64,

    /// Step size for gradient updates.
    pub learning_rate: f64,

    /// Use natural gradients instead of ordinary gradients.
    pub natural_gradients: bool,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            max_iter: 100,
            tolerance: 1e-6,
            learning_rate: 0.01,
            natural_gradients: false,
        }
    }
}

/// Errors specific to sparse GP fitting and prediction.
#[derive(Debug)]
pub enum SparseGPError {
    /// The inducing-point configuration is inconsistent or invalid.
    InvalidInducingPoints(String),

    /// A numerical problem was detected.
    NumericalInstability(String),

    /// An iterative procedure failed to converge.
    ConvergenceFailure(String),

    /// The requested sparse approximation is invalid in this context.
    InvalidApproximation(String),
}

impl fmt::Display for SparseGPError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SparseGPError::InvalidInducingPoints(msg) => {
                write!(f, "Invalid inducing points: {}", msg)
            }
            SparseGPError::NumericalInstability(msg) => {
                write!(f, "Numerical instability: {}", msg)
            }
            SparseGPError::ConvergenceFailure(msg) => {
                write!(f, "Convergence failure: {}", msg)
            }
            SparseGPError::InvalidApproximation(msg) => {
                write!(f, "Invalid approximation: {}", msg)
            }
        }
    }
}

impl std::error::Error for SparseGPError {}

impl From<SparseGPError> for SklearsError {
    fn from(err: SparseGPError) -> Self {
        match err {
            SparseGPError::InvalidInducingPoints(msg) => SklearsError::InvalidInput(msg),
            SparseGPError::NumericalInstability(msg) => SklearsError::NumericalError(msg),
            SparseGPError::ConvergenceFailure(msg) => SklearsError::NumericalError(msg),
            SparseGPError::InvalidApproximation(msg) => SklearsError::InvalidInput(msg),
        }
    }
}

/// Builder-style configuration methods.
impl<K> SparseGaussianProcess<K> {
    /// Set the sparse approximation scheme.
    pub fn approximation(mut self, approximation: SparseApproximation) -> Self {
        self.approximation = approximation;
        self
    }

    /// Set the inducing-point selection strategy.
    pub fn inducing_strategy(mut self, strategy: InducingPointStrategy) -> Self {
        self.inducing_strategy = strategy;
        self
    }

    /// Set the observation noise variance.
    pub fn noise_variance(mut self, noise_variance: f64) -> Self {
        self.noise_variance = noise_variance;
        self
    }

    /// Set the maximum number of iterations and the convergence tolerance.
    pub fn optimization_params(mut self, max_iter: usize, tol: f64) -> Self {
        self.max_iter = max_iter;
        self.tol = tol;
        self
    }
}

/// Builder-style configuration methods.
impl<K> StructuredKernelInterpolation<K> {
    /// Set the observation noise variance.
    pub fn noise_variance(mut self, noise_variance: f64) -> Self {
        self.noise_variance = noise_variance;
        self
    }

    /// Set the interpolation scheme.
    pub fn interpolation(mut self, interpolation: InterpolationMethod) -> Self {
        self.interpolation = interpolation;
        self
    }
}

/// Helper utilities for sparse GP models.
pub mod utils {
    use super::*;

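    /// Checks that an inducing-point strategy is consistent with the requested
    /// number of inducing points and the feature dimension.
    ///
    /// # Example
    ///
    /// A usage sketch (marked `ignore`; it only exercises the grid-consistency
    /// check implemented below):
    ///
    /// ```ignore
    /// let strategy = InducingPointStrategy::UniformGrid { grid_size: vec![5, 5] };
    /// // A 5 x 5 grid in 2 feature dimensions matches num_inducing = 25.
    /// assert!(validate_inducing_points(25, 2, &strategy).is_ok());
    /// // Product of grid sizes (25) != num_inducing (30): rejected.
    /// assert!(validate_inducing_points(30, 2, &strategy).is_err());
    /// ```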
    pub fn validate_inducing_points(
        num_inducing: usize,
        n_features: usize,
        strategy: &InducingPointStrategy,
    ) -> Result<()> {
        match strategy {
            InducingPointStrategy::UniformGrid { grid_size } => {
                if grid_size.len() != n_features {
                    return Err(SklearsError::InvalidInput(
                        "Grid size must match number of features".to_string(),
                    ));
                }

                let total_points: usize = grid_size.iter().product();
                if total_points != num_inducing {
                    return Err(SklearsError::InvalidInput(format!(
                        "Grid size product {} must equal num_inducing {}",
                        total_points, num_inducing
                    )));
                }
            }
            InducingPointStrategy::UserSpecified(points) => {
                if points.nrows() != num_inducing {
                    return Err(SklearsError::InvalidInput(
                        "User-specified points must match num_inducing".to_string(),
                    ));
                }
                if points.ncols() != n_features {
                    return Err(SklearsError::InvalidInput(
                        "User-specified points must match number of features".to_string(),
                    ));
                }
            }
            _ => {}
        }

        Ok(())
    }

    /// Returns an error if `matrix` contains NaN or infinite entries.
    pub fn check_matrix_stability(matrix: &Array2<f64>, name: &str) -> Result<()> {
        let has_nan = matrix.iter().any(|&x| x.is_nan());
        let has_inf = matrix.iter().any(|&x| x.is_infinite());

        if has_nan || has_inf {
            return Err(SklearsError::NumericalError(format!(
                "Matrix {} contains NaN or infinite values",
                name
            )));
        }

        Ok(())
    }

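    /// Cheap conditioning heuristic: the ratio of the total absolute entry sum
    /// to the absolute diagonal sum. This measures how much off-diagonal mass
    /// the matrix carries relative to its diagonal; it is not the true
    /// condition number, and returns infinity when the diagonal sum is zero.
    ///
    /// # Example
    ///
    /// A small sanity check (marked `ignore`; `Array2::eye` comes from the
    /// re-exported `ndarray` API):
    ///
    /// ```ignore
    /// let eye = Array2::<f64>::eye(3);
    /// // Identity: diagonal sum 3, off-diagonal sum 0 => estimate 1.0.
    /// assert!((estimate_condition_number(&eye) - 1.0).abs() < 1e-12);
    /// ```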
    pub fn estimate_condition_number(matrix: &Array2<f64>) -> f64 {
        let diag_sum: f64 = matrix.diag().iter().map(|x| x.abs()).sum();
        let off_diag_sum: f64 = matrix.iter().map(|x| x.abs()).sum::<f64>() - diag_sum;

        if diag_sum > 0.0 {
            (diag_sum + off_diag_sum) / diag_sum
        } else {
            f64::INFINITY
        }
    }
}