use crate::sparse_gp::core::*;
use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::error::{Result, SklearsError};
use sklears_core::traits::{Fit, Predict};
use std::collections::HashSet;

impl<K: SparseKernel> StructuredKernelInterpolation<K> {
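    /// Create a new SKI model with the given per-dimension grid sizes and kernel,
    /// defaulting to a small jitter noise variance and linear interpolation.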
    pub fn new(grid_size: Vec<usize>, kernel: K) -> Self {
        Self {
            grid_size,
            kernel,
            noise_variance: 1e-6,
            interpolation: InterpolationMethod::Linear,
        }
    }

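    /// Build the regular inducing grid spanning the bounding box of `x`.
    /// Returns a `(grid_size.iter().product(), n_features)` matrix of grid locations.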
    pub fn generate_grid_points(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let n_features = x.ncols();
        if self.grid_size.len() != n_features {
            return Err(SklearsError::InvalidInput(
                "Grid size dimension mismatch".to_string(),
            ));
        }

        let total_grid_points: usize = self.grid_size.iter().product();
        let mut grid_points = Array2::zeros((total_grid_points, n_features));

        // Per-dimension (min, max) ranges of the training data.
        let mut ranges = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            ranges.push((min_val, max_val));
        }

        let mut point_idx = 0;
        self.generate_grid_recursive(
            &mut grid_points,
            &ranges,
            &mut vec![0; n_features],
            0,
            &mut point_idx,
        );

        Ok(grid_points)
    }

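    /// Recursively enumerate the Cartesian product of per-dimension grid indices,
    /// writing each grid point into `grid_points` in row-major order.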
    fn generate_grid_recursive(
        &self,
        grid_points: &mut Array2<f64>,
        ranges: &[(f64, f64)],
        current_indices: &mut Vec<usize>,
        dim: usize,
        point_idx: &mut usize,
    ) {
        if dim == ranges.len() {
            for (j, &idx) in current_indices.iter().enumerate() {
                let (min_val, max_val) = ranges[j];
                let grid_val = if self.grid_size[j] == 1 {
                    (min_val + max_val) / 2.0
                } else {
                    min_val + idx as f64 * (max_val - min_val) / (self.grid_size[j] - 1) as f64
                };
                grid_points[(*point_idx, j)] = grid_val;
            }
            *point_idx += 1;
            return;
        }

        for i in 0..self.grid_size[dim] {
            current_indices[dim] = i;
            self.generate_grid_recursive(grid_points, ranges, current_indices, dim + 1, point_idx);
        }
    }

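    /// Compute the interpolation weight matrix mapping grid values to the inputs in `x`.
    /// Each row is normalized to sum to one when its weight mass is nonzero.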
    pub fn compute_interpolation_weights(
        &self,
        x: &Array2<f64>,
        grid_points: &Array2<f64>,
        ranges: &[(f64, f64)],
    ) -> Result<Array2<f64>> {
        let n = x.nrows();
        let n_grid = grid_points.nrows();

        let mut weights = Array2::zeros((n, n_grid));

        match self.interpolation {
            InterpolationMethod::Linear => {
                self.compute_linear_weights(x, grid_points, ranges, &mut weights)?;
            }
            InterpolationMethod::Cubic => {
                self.compute_cubic_weights(x, grid_points, ranges, &mut weights)?;
            }
        }

        // Normalize each row so the weights sum to one (when any weight is nonzero).
        for i in 0..n {
            let weight_sum = weights.row(i).sum();
            if weight_sum > 1e-12 {
                for g in 0..n_grid {
                    weights[(i, g)] /= weight_sum;
                }
            }
        }

        Ok(weights)
    }

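    /// Tent-function (linear) weights: a grid point contributes only if it lies within
    /// one grid spacing of the input in every dimension.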
    fn compute_linear_weights(
        &self,
        x: &Array2<f64>,
        grid_points: &Array2<f64>,
        ranges: &[(f64, f64)],
        weights: &mut Array2<f64>,
    ) -> Result<()> {
        let n = x.nrows();
        let n_grid = grid_points.nrows();
        let n_features = x.ncols();

        for i in 0..n {
            for g in 0..n_grid {
                let mut weight = 1.0;
                let mut valid = true;

                for j in 0..n_features {
                    let x_val = x[(i, j)];
                    let grid_val = grid_points[(g, j)];
                    let (min_val, max_val) = ranges[j];

                    let grid_spacing = if self.grid_size[j] == 1 {
                        max_val - min_val
                    } else {
                        (max_val - min_val) / (self.grid_size[j] - 1) as f64
                    };

                    let distance = (x_val - grid_val).abs();

                    // Grid point is outside the interpolation support in this dimension.
                    if distance > grid_spacing + 1e-12 {
                        valid = false;
                        break;
                    }

                    if grid_spacing > 1e-12 {
                        weight *= 1.0 - distance / grid_spacing;
                    }
                }

                if valid {
                    weights[(i, g)] = weight;
                }
            }
        }

        Ok(())
    }

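    /// Cubic weights with a support of two grid spacings per dimension.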
    fn compute_cubic_weights(
        &self,
        x: &Array2<f64>,
        grid_points: &Array2<f64>,
        ranges: &[(f64, f64)],
        weights: &mut Array2<f64>,
    ) -> Result<()> {
        let n = x.nrows();
        let n_grid = grid_points.nrows();
        let n_features = x.ncols();

        for i in 0..n {
            for g in 0..n_grid {
                let mut weight = 1.0;
                let mut valid = true;

                for j in 0..n_features {
                    let x_val = x[(i, j)];
                    let grid_val = grid_points[(g, j)];
                    let (min_val, max_val) = ranges[j];

                    let grid_spacing = if self.grid_size[j] == 1 {
                        max_val - min_val
                    } else {
                        (max_val - min_val) / (self.grid_size[j] - 1) as f64
                    };

                    let distance = (x_val - grid_val).abs();

                    if distance > 2.0 * grid_spacing + 1e-12 {
                        valid = false;
                        break;
                    }

                    if grid_spacing > 1e-12 {
                        let t = distance / grid_spacing;
                        let cubic_weight = if t <= 1.0 {
                            1.0 - 1.5 * t * t + 0.75 * t * t * t
                        } else if t <= 2.0 {
                            0.25 * (2.0 - t).powi(3)
                        } else {
                            0.0
                        };
                        weight *= cubic_weight;
                    }
                }

                if valid && weight > 1e-12 {
                    weights[(i, g)] = weight;
                }
            }
        }

        Ok(())
    }
}

impl<K: SparseKernel> Fit<Array2<f64>, Array1<f64>> for StructuredKernelInterpolation<K> {
    type Fitted = FittedSKI<K>;

    fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
        let grid_points = self.generate_grid_points(x)?;

        // Per-dimension data ranges, reused for the interpolation weights.
        let n_features = x.ncols();
        let mut ranges = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            ranges.push((min_val, max_val));
        }

        let weights = self.compute_interpolation_weights(x, &grid_points, &ranges)?;

        // Grid kernel matrix with the noise variance added to the diagonal.
        let k_gg = self.compute_grid_kernel_matrix(&grid_points)?;

        let mut k_gg_noise = k_gg;
        let n_grid = grid_points.nrows();
        for i in 0..n_grid {
            k_gg_noise[(i, i)] += self.noise_variance;
        }

        // Project the targets onto the grid and solve for the grid coefficients.
        let weighted_y = weights.t().dot(y);
        let alpha = self.solve_structured_system(&k_gg_noise, &weighted_y)?;

        Ok(FittedSKI {
            grid_points,
            weights,
            kernel: self.kernel.clone(),
            alpha,
        })
    }
}

impl<K: SparseKernel> StructuredKernelInterpolation<K> {
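    /// Kernel matrix over the grid points, using the Kronecker code path when available.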
    fn compute_grid_kernel_matrix(&self, grid_points: &Array2<f64>) -> Result<Array2<f64>> {
        let n_features = grid_points.ncols();

        if n_features == 1 || self.can_use_kronecker_structure() {
            self.compute_kronecker_kernel_matrix(grid_points)
        } else {
            Ok(self.kernel.kernel_matrix(grid_points, grid_points))
        }
    }

    /// Whether the grid kernel can be treated as Kronecker-structured.
    /// Currently this is assumed to hold for every supported kernel.
    fn can_use_kronecker_structure(&self) -> bool {
        true
    }

    /// Kernel matrix over the grid. A true Kronecker factorization is not implemented
    /// yet, so the dense grid kernel matrix is computed directly in all cases.
    fn compute_kronecker_kernel_matrix(&self, grid_points: &Array2<f64>) -> Result<Array2<f64>> {
        let n_features = grid_points.ncols();

        if n_features == 1 {
            return Ok(self.kernel.kernel_matrix(grid_points, grid_points));
        }

        Ok(self.kernel.kernel_matrix(grid_points, grid_points))
    }

    /// Solve `(K_gg + noise * I) * alpha = rhs` via a Cholesky-based inverse.
    fn solve_structured_system(
        &self,
        k_matrix: &Array2<f64>,
        rhs: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        let k_inv = KernelOps::invert_using_cholesky(k_matrix)?;
        Ok(k_inv.dot(rhs))
    }
}

impl<K: SparseKernel> Predict<Array2<f64>, Array1<f64>> for FittedSKI<K> {
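    /// Predict means at `x` by interpolating the fitted grid coefficients
    /// (`predictions = W_test · alpha`).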
    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
        // Recover per-dimension ranges from the stored grid points.
        let n_features = x.ncols();
        let mut ranges = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let col = self.grid_points.column(j);
            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            ranges.push((min_val, max_val));
        }

        let grid_size = self.infer_grid_size_from_points()?;

        // Rebuild a lightweight SKI helper to compute interpolation weights for `x`.
        let ski = StructuredKernelInterpolation {
            grid_size,
            kernel: self.kernel.clone(),
            noise_variance: 1e-6,
            interpolation: InterpolationMethod::Linear,
        };

        let test_weights = ski.compute_interpolation_weights(x, &self.grid_points, &ranges)?;
        let predictions = test_weights.dot(&self.alpha);
        Ok(predictions)
    }
}

impl<K: SparseKernel> FittedSKI<K> {
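    /// Infer the per-dimension grid size by counting distinct coordinate values
    /// (rounded at 1e-6 resolution) among the stored grid points.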
    fn infer_grid_size_from_points(&self) -> Result<Vec<usize>> {
        let n_features = self.grid_points.ncols();
        let mut grid_size = vec![1; n_features];

        for j in 0..n_features {
            let col = self.grid_points.column(j);
            let unique_vals: HashSet<_> = col.iter().map(|&x| (x * 1e6).round() as i64).collect();
            grid_size[j] = unique_vals.len();
        }

        Ok(grid_size)
    }

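    /// Predict means together with a simplified predictive variance that currently
    /// uses only the prior kernel diagonal `k(x*, x*)`.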
    pub fn predict_with_variance(&self, x: &Array2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
        let n_features = x.ncols();
        let mut ranges = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let col = self.grid_points.column(j);
            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            ranges.push((min_val, max_val));
        }

        let grid_size = self.infer_grid_size_from_points()?;
        let ski = StructuredKernelInterpolation {
            grid_size,
            kernel: self.kernel.clone(),
            noise_variance: 1e-6,
            interpolation: InterpolationMethod::Linear,
        };

        let test_weights = ski.compute_interpolation_weights(x, &self.grid_points, &ranges)?;

        let pred_mean = test_weights.dot(&self.alpha);

        // Simplified variance: prior kernel diagonal only.
        let k_test_diag = self.kernel.kernel_diagonal(x);
        let pred_var = k_test_diag;

        Ok((pred_mean, pred_var))
    }
}

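/// Tensor-product SKI variant that keeps one 1-D grid and one 1-D kernel matrix
/// per input dimension so that the grid kernel has Kronecker structure.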
pub struct TensorSKI<K: SparseKernel> {
    /// Number of grid points in each input dimension.
    pub grid_sizes: Vec<usize>,
    /// Covariance kernel.
    pub kernel: K,
    /// Observation noise variance (jitter).
    pub noise_variance: f64,
    /// Whether to exploit Kronecker structure (required by `fit_tensor`).
    pub use_kronecker: bool,
}

impl<K: SparseKernel> TensorSKI<K> {
    pub fn new(grid_sizes: Vec<usize>, kernel: K) -> Self {
        Self {
            grid_sizes,
            kernel,
            noise_variance: 1e-6,
            use_kronecker: true,
        }
    }

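    /// Build the per-dimension grids and their 1-D kernel matrices. Solving for the
    /// coefficients is not implemented here; `alpha` is returned as a placeholder.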
    pub fn fit_tensor(&self, x: &Array2<f64>, _y: &Array1<f64>) -> Result<FittedTensorSKI<K>> {
        if !self.use_kronecker {
            return Err(SklearsError::InvalidInput(
                "Tensor SKI requires Kronecker structure".to_string(),
            ));
        }

        let n_features = x.ncols();
        if self.grid_sizes.len() != n_features {
            return Err(SklearsError::InvalidInput(
                "Grid sizes must match number of features".to_string(),
            ));
        }

        // One evenly spaced 1-D grid per dimension, spanning the data range.
        let mut dim_grids = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let col = x.column(j);
            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            let mut grid_1d = Array1::zeros(self.grid_sizes[j]);
            for i in 0..self.grid_sizes[j] {
                if self.grid_sizes[j] == 1 {
                    grid_1d[i] = (min_val + max_val) / 2.0;
                } else {
                    grid_1d[i] =
                        min_val + i as f64 * (max_val - min_val) / (self.grid_sizes[j] - 1) as f64;
                }
            }
            dim_grids.push(grid_1d);
        }

        // One-dimensional kernel matrices, one per dimension.
        let mut kernel_matrices_1d = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let grid_1d_2d = dim_grids[j].clone().insert_axis(Axis(1));
            let k_1d = self.kernel.kernel_matrix(&grid_1d_2d, &grid_1d_2d);
            kernel_matrices_1d.push(k_1d);
        }

        Ok(FittedTensorSKI {
            dim_grids,
            kernel_matrices_1d,
            kernel: self.kernel.clone(),
            alpha: Array1::zeros(1), // Placeholder: coefficients are not solved here.
        })
    }
}

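/// Fitted state of [`TensorSKI`]: per-dimension grids, their 1-D kernel matrices,
/// and the (placeholder) coefficient vector.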
#[derive(Debug, Clone)]
pub struct FittedTensorSKI<K: SparseKernel> {
    /// 1-D grid locations, one array per input dimension.
    pub dim_grids: Vec<Array1<f64>>,
    /// 1-D kernel matrices, one per input dimension.
    pub kernel_matrices_1d: Vec<Array2<f64>>,
    /// Covariance kernel.
    pub kernel: K,
    /// Coefficient vector (placeholder until solving is implemented).
    pub alpha: Array1<f64>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_gp::kernels::RBFKernel;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_grid_generation() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 2], kernel);

        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let grid_points = ski.generate_grid_points(&x).unwrap();

        assert_eq!(grid_points.shape(), &[6, 2]);
        assert!(grid_points.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_linear_interpolation_weights() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel)
            .interpolation(InterpolationMethod::Linear);

        let x = array![[0.5, 0.5], [1.0, 1.0]];
        let grid_points = array![
            [0.0, 0.0],
            [0.0, 1.0],
            [0.0, 2.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 2.0],
            [2.0, 0.0],
            [2.0, 1.0],
            [2.0, 2.0]
        ];
        let ranges = vec![(0.0, 2.0), (0.0, 2.0)];

        let weights = ski
            .compute_interpolation_weights(&x, &grid_points, &ranges)
            .unwrap();

        assert_eq!(weights.shape(), &[2, 9]);

        // Each row of weights should sum to one after normalization.
        for i in 0..2 {
            let weight_sum = weights.row(i).sum();
            assert_abs_diff_eq!(weight_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ski_fit() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel).noise_variance(0.1);

        let x = array![[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]];
        let y = array![0.0, 0.25, 1.0, 2.25];

        let fitted = ski.fit(&x, &y).unwrap();

        assert_eq!(fitted.grid_points.nrows(), 9);
        assert_eq!(fitted.alpha.len(), 9);
        assert!(fitted.alpha.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_ski_prediction() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel).noise_variance(0.1);

        let x = array![[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]];
        let y = array![0.0, 0.25, 1.0, 2.25];

        let fitted = ski.fit(&x, &y).unwrap();
        let x_test = array![[0.25, 0.25], [0.75, 0.75]];
        let predictions = fitted.predict(&x_test).unwrap();

        assert_eq!(predictions.len(), 2);
        assert!(predictions.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_ski_with_variance() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 3], kernel).noise_variance(0.1);

        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let y = array![0.0, 1.0, 4.0];

        let fitted = ski.fit(&x, &y).unwrap();
        let x_test = array![[0.5, 0.5], [1.5, 1.5]];
        let (mean, var) = fitted.predict_with_variance(&x_test).unwrap();

        assert_eq!(mean.len(), 2);
        assert_eq!(var.len(), 2);
        assert!(mean.iter().all(|&x| x.is_finite()));
        assert!(var.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    #[test]
    fn test_cubic_interpolation() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![4, 4], kernel)
            .interpolation(InterpolationMethod::Cubic);

        let x = array![[0.5, 0.5], [1.5, 1.5]];
        let grid_points = ski.generate_grid_points(&x).unwrap();
        let ranges = vec![(0.0, 2.0), (0.0, 2.0)];

        let weights = ski
            .compute_interpolation_weights(&x, &grid_points, &ranges)
            .unwrap();

        assert_eq!(weights.shape(), &[2, 16]);

        // Rows should be normalized and (numerically) non-negative.
        for i in 0..2 {
            let weight_sum = weights.row(i).sum();
            assert_abs_diff_eq!(weight_sum, 1.0, epsilon = 1e-10);
            assert!(weights.row(i).iter().all(|&w| w >= -1e-12));
        }
    }

    #[test]
    fn test_tensor_ski_creation() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let tensor_ski = TensorSKI::new(vec![4, 3, 5], kernel);

        assert_eq!(tensor_ski.grid_sizes, vec![4, 3, 5]);
        assert!(tensor_ski.use_kronecker);
    }

    #[test]
    fn test_grid_size_inference() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let ski = StructuredKernelInterpolation::new(vec![3, 2], kernel);

        let _x = array![[0.0, 0.0], [1.0, 1.0]];
        let fitted_ski = FittedSKI {
            grid_points: array![
                [0.0, 0.0],
                [0.0, 1.0],
                [1.0, 0.0],
                [1.0, 1.0],
                [2.0, 0.0],
                [2.0, 1.0]
            ],
            weights: Array2::zeros((2, 6)),
            kernel: ski.kernel.clone(),
            alpha: Array1::zeros(6),
        };

        let inferred_grid_size = fitted_ski.infer_grid_size_from_points().unwrap();
        assert_eq!(inferred_grid_size, vec![3, 2]);
    }
}