1use crate::sparse_gp::core::*;
8use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
9use scirs2_core::ndarray::{Array1, Array2, Axis};
10use sklears_core::error::{Result, SklearsError};
11use sklears_core::traits::{Fit, Predict};
12use std::collections::HashSet;
13
14impl<K: SparseKernel> StructuredKernelInterpolation<K> {
16 pub fn new(grid_size: Vec<usize>, kernel: K) -> Self {
18 Self {
19 grid_size,
20 kernel,
21 noise_variance: 1e-6,
22 interpolation: InterpolationMethod::Linear,
23 }
24 }
25
26 pub fn generate_grid_points(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
28 let n_features = x.ncols();
29 if self.grid_size.len() != n_features {
30 return Err(SklearsError::InvalidInput(
31 "Grid size dimension mismatch".to_string(),
32 ));
33 }
34
35 let total_grid_points: usize = self.grid_size.iter().product();
36 let mut grid_points = Array2::zeros((total_grid_points, n_features));
37
38 let mut ranges = Vec::with_capacity(n_features);
40 for j in 0..n_features {
41 let col = x.column(j);
42 let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
43 let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
44 ranges.push((min_val, max_val));
45 }
46
47 let mut point_idx = 0;
49 self.generate_grid_recursive(
50 &mut grid_points,
51 &ranges,
52 &mut vec![0; n_features],
53 0,
54 &mut point_idx,
55 );
56
57 Ok(grid_points)
58 }
59
60 fn generate_grid_recursive(
62 &self,
63 grid_points: &mut Array2<f64>,
64 ranges: &[(f64, f64)],
65 current_indices: &mut Vec<usize>,
66 dim: usize,
67 point_idx: &mut usize,
68 ) {
69 if dim == ranges.len() {
70 for (j, &idx) in current_indices.iter().enumerate() {
72 let (min_val, max_val) = ranges[j];
73 let grid_val = if self.grid_size[j] == 1 {
74 (min_val + max_val) / 2.0
75 } else {
76 min_val + idx as f64 * (max_val - min_val) / (self.grid_size[j] - 1) as f64
77 };
78 grid_points[(*point_idx, j)] = grid_val;
79 }
80 *point_idx += 1;
81 return;
82 }
83
84 for i in 0..self.grid_size[dim] {
85 current_indices[dim] = i;
86 self.generate_grid_recursive(grid_points, ranges, current_indices, dim + 1, point_idx);
87 }
88 }
89
90 pub fn compute_interpolation_weights(
92 &self,
93 x: &Array2<f64>,
94 grid_points: &Array2<f64>,
95 ranges: &[(f64, f64)],
96 ) -> Result<Array2<f64>> {
97 let n = x.nrows();
98 let n_grid = grid_points.nrows();
99 let _n_features = x.ncols();
100
101 let mut weights = Array2::zeros((n, n_grid));
102
103 match self.interpolation {
104 InterpolationMethod::Linear => {
105 self.compute_linear_weights(x, grid_points, ranges, &mut weights)?;
106 }
107 InterpolationMethod::Cubic => {
108 self.compute_cubic_weights(x, grid_points, ranges, &mut weights)?;
109 }
110 }
111
112 for i in 0..n {
114 let weight_sum = weights.row(i).sum();
115 if weight_sum > 1e-12 {
116 for g in 0..n_grid {
117 weights[(i, g)] /= weight_sum;
118 }
119 }
120 }
121
122 Ok(weights)
123 }
124
125 fn compute_linear_weights(
127 &self,
128 x: &Array2<f64>,
129 grid_points: &Array2<f64>,
130 ranges: &[(f64, f64)],
131 weights: &mut Array2<f64>,
132 ) -> Result<()> {
133 let n = x.nrows();
134 let n_grid = grid_points.nrows();
135 let n_features = x.ncols();
136
137 for i in 0..n {
138 for g in 0..n_grid {
139 let mut weight = 1.0;
140 let mut valid = true;
141
142 for j in 0..n_features {
143 let x_val = x[(i, j)];
144 let grid_val = grid_points[(g, j)];
145 let (min_val, max_val) = ranges[j];
146
147 let grid_spacing = if self.grid_size[j] == 1 {
148 max_val - min_val
149 } else {
150 (max_val - min_val) / (self.grid_size[j] - 1) as f64
151 };
152
153 let distance = (x_val - grid_val).abs();
154
155 if distance > grid_spacing + 1e-12 {
157 valid = false;
158 break;
159 }
160
161 if grid_spacing > 1e-12 {
163 weight *= 1.0 - distance / grid_spacing;
164 }
165 }
166
167 if valid {
168 weights[(i, g)] = weight;
169 }
170 }
171 }
172
173 Ok(())
174 }
175
176 fn compute_cubic_weights(
178 &self,
179 x: &Array2<f64>,
180 grid_points: &Array2<f64>,
181 ranges: &[(f64, f64)],
182 weights: &mut Array2<f64>,
183 ) -> Result<()> {
184 let n = x.nrows();
185 let n_grid = grid_points.nrows();
186 let n_features = x.ncols();
187
188 for i in 0..n {
189 for g in 0..n_grid {
190 let mut weight = 1.0;
191 let mut valid = true;
192
193 for j in 0..n_features {
194 let x_val = x[(i, j)];
195 let grid_val = grid_points[(g, j)];
196 let (min_val, max_val) = ranges[j];
197
198 let grid_spacing = if self.grid_size[j] == 1 {
199 max_val - min_val
200 } else {
201 (max_val - min_val) / (self.grid_size[j] - 1) as f64
202 };
203
204 let distance = (x_val - grid_val).abs();
205
206 if distance > 2.0 * grid_spacing + 1e-12 {
208 valid = false;
209 break;
210 }
211
212 if grid_spacing > 1e-12 {
214 let t = distance / grid_spacing;
215 let cubic_weight = if t <= 1.0 {
216 1.0 - 1.5 * t * t + 0.75 * t * t * t
217 } else if t <= 2.0 {
218 0.25 * (2.0 - t).powi(3)
219 } else {
220 0.0
221 };
222 weight *= cubic_weight;
223 }
224 }
225
226 if valid && weight > 1e-12 {
227 weights[(i, g)] = weight;
228 }
229 }
230 }
231
232 Ok(())
233 }
234}
235
236impl<K: SparseKernel> Fit<Array2<f64>, Array1<f64>> for StructuredKernelInterpolation<K> {
238 type Fitted = FittedSKI<K>;
239
240 fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
241 let grid_points = self.generate_grid_points(x)?;
243
244 let n_features = x.ncols();
246 let mut ranges = Vec::with_capacity(n_features);
247 for j in 0..n_features {
248 let col = x.column(j);
249 let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
250 let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
251 ranges.push((min_val, max_val));
252 }
253
254 let weights = self.compute_interpolation_weights(x, &grid_points, &ranges)?;
256
257 let k_gg = self.compute_grid_kernel_matrix(&grid_points)?;
259
260 let mut k_gg_noise = k_gg;
262 let n_grid = grid_points.nrows();
263 for i in 0..n_grid {
264 k_gg_noise[(i, i)] += self.noise_variance;
265 }
266
267 let weighted_y = weights.t().dot(y);
269 let alpha = self.solve_structured_system(&k_gg_noise, &weighted_y)?;
270
271 Ok(FittedSKI {
272 grid_points,
273 weights,
274 kernel: self.kernel.clone(),
275 alpha,
276 })
277 }
278}
279
280impl<K: SparseKernel> StructuredKernelInterpolation<K> {
281 fn compute_grid_kernel_matrix(&self, grid_points: &Array2<f64>) -> Result<Array2<f64>> {
283 let n_features = grid_points.ncols();
284
285 if n_features == 1 || self.can_use_kronecker_structure() {
287 self.compute_kronecker_kernel_matrix(grid_points)
288 } else {
289 Ok(self.kernel.kernel_matrix(grid_points, grid_points))
291 }
292 }
293
294 fn can_use_kronecker_structure(&self) -> bool {
296 true
299 }
300
301 fn compute_kronecker_kernel_matrix(&self, grid_points: &Array2<f64>) -> Result<Array2<f64>> {
303 let n_features = grid_points.ncols();
304
305 if n_features == 1 {
306 return Ok(self.kernel.kernel_matrix(grid_points, grid_points));
308 }
309
310 Ok(self.kernel.kernel_matrix(grid_points, grid_points))
313 }
314
315 fn solve_structured_system(
317 &self,
318 k_matrix: &Array2<f64>,
319 rhs: &Array1<f64>,
320 ) -> Result<Array1<f64>> {
321 let k_inv = KernelOps::invert_using_cholesky(k_matrix)?;
324 Ok(k_inv.dot(rhs))
325 }
326}
327
impl<K: SparseKernel> Predict<Array2<f64>, Array1<f64>> for FittedSKI<K> {
    /// Predicts the posterior mean at `x` by interpolating the fitted grid
    /// coefficients: `y_hat = W_test * alpha`.
    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
        // Recover per-feature [min, max] from the stored grid, which spans
        // the training data's ranges by construction.
        let n_features = x.ncols();
        let mut ranges = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let col = self.grid_points.column(j);
            let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            ranges.push((min_val, max_val));
        }

        // Grid size is reconstructed from the stored points (counting
        // distinct coordinate values per column).
        let grid_size = self.infer_grid_size_from_points()?;

        // NOTE(review): this rebuilt SKI uses hard-coded defaults (noise
        // 1e-6, linear interpolation) rather than whatever was configured at
        // fit time, because FittedSKI does not store those settings — confirm
        // this is intended before relying on cubic-interpolation predictions.
        let ski = StructuredKernelInterpolation {
            grid_size,
            kernel: self.kernel.clone(),
            noise_variance: 1e-6,
            interpolation: InterpolationMethod::Linear,
        };

        // Interpolation weights for the test points, then the weighted sum
        // of the grid coefficients.
        let test_weights = ski.compute_interpolation_weights(x, &self.grid_points, &ranges)?;
        let predictions = test_weights.dot(&self.alpha);
        Ok(predictions)
    }
}
356
357impl<K: SparseKernel> FittedSKI<K> {
358 fn infer_grid_size_from_points(&self) -> Result<Vec<usize>> {
360 let n_features = self.grid_points.ncols();
361 let mut grid_size = vec![1; n_features];
362
363 for j in 0..n_features {
364 let col = self.grid_points.column(j);
365 let unique_vals: HashSet<_> = col.iter().map(|&x| (x * 1e6).round() as i64).collect();
366 grid_size[j] = unique_vals.len();
367 }
368
369 Ok(grid_size)
370 }
371
372 pub fn predict_with_variance(&self, x: &Array2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
374 let n_features = x.ncols();
376 let mut ranges = Vec::with_capacity(n_features);
377 for j in 0..n_features {
378 let col = self.grid_points.column(j);
379 let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
380 let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
381 ranges.push((min_val, max_val));
382 }
383
384 let grid_size = self.infer_grid_size_from_points()?;
385 let ski = StructuredKernelInterpolation {
386 grid_size,
387 kernel: self.kernel.clone(),
388 noise_variance: 1e-6,
389 interpolation: InterpolationMethod::Linear,
390 };
391
392 let test_weights = ski.compute_interpolation_weights(x, &self.grid_points, &ranges)?;
393
394 let pred_mean = test_weights.dot(&self.alpha);
396
397 let k_test_diag = self.kernel.kernel_diagonal(x);
399 let pred_var = k_test_diag; Ok((pred_mean, pred_var))
402 }
403}
404
/// Tensor-product SKI variant that prepares per-dimension grids and their
/// 1-D kernel matrices for a Kronecker-structured approximation.
pub struct TensorSKI<K: SparseKernel> {
    /// Number of grid nodes along each feature dimension.
    pub grid_sizes: Vec<usize>,
    /// Covariance kernel evaluated on the per-dimension grids.
    pub kernel: K,
    /// Diagonal noise/jitter term. NOTE(review): not read anywhere in the
    /// visible code — presumably consumed by a solver elsewhere; confirm.
    pub noise_variance: f64,
    /// Whether fitting requires (and should exploit) Kronecker structure.
    pub use_kronecker: bool,
}
416
417impl<K: SparseKernel> TensorSKI<K> {
418 pub fn new(grid_sizes: Vec<usize>, kernel: K) -> Self {
419 Self {
420 grid_sizes,
421 kernel,
422 noise_variance: 1e-6,
423 use_kronecker: true,
424 }
425 }
426
427 pub fn fit_tensor(&self, x: &Array2<f64>, _y: &Array1<f64>) -> Result<FittedTensorSKI<K>> {
429 if !self.use_kronecker {
430 return Err(SklearsError::InvalidInput(
431 "Tensor SKI requires Kronecker structure".to_string(),
432 ));
433 }
434
435 let n_features = x.ncols();
436 if self.grid_sizes.len() != n_features {
437 return Err(SklearsError::InvalidInput(
438 "Grid sizes must match number of features".to_string(),
439 ));
440 }
441
442 let mut dim_grids = Vec::with_capacity(n_features);
444 for j in 0..n_features {
445 let col = x.column(j);
446 let min_val = col.fold(f64::INFINITY, |a, &b| a.min(b));
447 let max_val = col.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
448
449 let mut grid_1d = Array1::zeros(self.grid_sizes[j]);
450 for i in 0..self.grid_sizes[j] {
451 if self.grid_sizes[j] == 1 {
452 grid_1d[i] = (min_val + max_val) / 2.0;
453 } else {
454 grid_1d[i] =
455 min_val + i as f64 * (max_val - min_val) / (self.grid_sizes[j] - 1) as f64;
456 }
457 }
458 dim_grids.push(grid_1d);
459 }
460
461 let mut kernel_matrices_1d = Vec::with_capacity(n_features);
463 for j in 0..n_features {
464 let grid_1d_2d = dim_grids[j].clone().insert_axis(Axis(1));
465 let k_1d = self.kernel.kernel_matrix(&grid_1d_2d, &grid_1d_2d);
466 kernel_matrices_1d.push(k_1d);
467 }
468
469 Ok(FittedTensorSKI {
470 dim_grids,
471 kernel_matrices_1d,
472 kernel: self.kernel.clone(),
473 alpha: Array1::zeros(1), })
475 }
476}
477
/// Result of `TensorSKI::fit_tensor`: per-dimension grids plus their 1-D
/// kernel matrices.
#[derive(Debug, Clone)]
pub struct FittedTensorSKI<K: SparseKernel> {
    /// One 1-D grid of node coordinates per feature dimension.
    pub dim_grids: Vec<Array1<f64>>,
    /// Kernel matrix evaluated on each 1-D grid (one per dimension).
    pub kernel_matrices_1d: Vec<Array2<f64>>,
    /// Covariance kernel used to build the matrices above.
    pub kernel: K,
    /// Grid coefficients. NOTE(review): currently always a length-1 zero
    /// vector — the structured solve is not implemented yet.
    pub alpha: Array1<f64>,
}
490
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_gp::kernels::RBFKernel;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_grid_generation() {
        // A 3x2 grid over two features must yield 6 finite grid points.
        let ski = StructuredKernelInterpolation::new(vec![3, 2], RBFKernel::new(1.0, 1.0));
        let data = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let grid = ski
            .generate_grid_points(&data)
            .expect("operation should succeed");

        assert_eq!(grid.shape(), &[6, 2]);
        assert!(grid.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_linear_interpolation_weights() {
        let ski = StructuredKernelInterpolation::new(vec![3, 3], RBFKernel::new(1.0, 1.0))
            .interpolation(InterpolationMethod::Linear);

        // Two query points against an explicit 3x3 grid on [0, 2]^2.
        let queries = array![[0.5, 0.5], [1.0, 1.0]];
        let grid = array![
            [0.0, 0.0],
            [0.0, 1.0],
            [0.0, 2.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 2.0],
            [2.0, 0.0],
            [2.0, 1.0],
            [2.0, 2.0]
        ];
        let ranges = vec![(0.0, 2.0), (0.0, 2.0)];

        let w = ski
            .compute_interpolation_weights(&queries, &grid, &ranges)
            .expect("operation should succeed");

        assert_eq!(w.shape(), &[2, 9]);

        // Every row must be normalized to a convex combination.
        for row in 0..2 {
            let total = w.row(row).sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ski_fit() {
        let ski = StructuredKernelInterpolation::new(vec![3, 3], RBFKernel::new(1.0, 1.0))
            .noise_variance(0.1);

        let inputs = array![[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]];
        let targets = array![0.0, 0.25, 1.0, 2.25];

        let model = ski.fit(&inputs, &targets).expect("operation should succeed");

        // 3x3 grid -> 9 inducing points and 9 finite coefficients.
        assert_eq!(model.grid_points.nrows(), 9);
        assert_eq!(model.alpha.len(), 9);
        assert!(model.alpha.iter().all(|&a| a.is_finite()));
    }

    #[test]
    fn test_ski_prediction() {
        let ski = StructuredKernelInterpolation::new(vec![3, 3], RBFKernel::new(1.0, 1.0))
            .noise_variance(0.1);

        let inputs = array![[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]];
        let targets = array![0.0, 0.25, 1.0, 2.25];

        let model = ski.fit(&inputs, &targets).expect("operation should succeed");

        let queries = array![[0.25, 0.25], [0.75, 0.75]];
        let preds = model.predict(&queries).expect("operation should succeed");

        assert_eq!(preds.len(), 2);
        assert!(preds.iter().all(|&p| p.is_finite()));
    }

    #[test]
    fn test_ski_with_variance() {
        let ski = StructuredKernelInterpolation::new(vec![3, 3], RBFKernel::new(1.0, 1.0))
            .noise_variance(0.1);

        let inputs = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let targets = array![0.0, 1.0, 4.0];

        let model = ski.fit(&inputs, &targets).expect("operation should succeed");

        let queries = array![[0.5, 0.5], [1.5, 1.5]];
        let (mean, var) = model
            .predict_with_variance(&queries)
            .expect("operation should succeed");

        assert_eq!(mean.len(), 2);
        assert_eq!(var.len(), 2);
        assert!(mean.iter().all(|&m| m.is_finite()));
        // Variances must be finite and non-negative.
        assert!(var.iter().all(|&v| v >= 0.0 && v.is_finite()));
    }

    #[test]
    fn test_cubic_interpolation() {
        let ski = StructuredKernelInterpolation::new(vec![4, 4], RBFKernel::new(1.0, 1.0))
            .interpolation(InterpolationMethod::Cubic);

        let queries = array![[0.5, 0.5], [1.5, 1.5]];
        let grid = ski
            .generate_grid_points(&queries)
            .expect("operation should succeed");
        let ranges = vec![(0.0, 2.0), (0.0, 2.0)];

        let w = ski
            .compute_interpolation_weights(&queries, &grid, &ranges)
            .expect("operation should succeed");

        assert_eq!(w.shape(), &[2, 16]);

        for row in 0..2 {
            // Normalized rows with (numerically) non-negative entries.
            let total = w.row(row).sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
            assert!(w.row(row).iter().all(|&entry| entry >= -1e-12));
        }
    }

    #[test]
    fn test_tensor_ski_creation() {
        let tensor_ski = TensorSKI::new(vec![4, 3, 5], RBFKernel::new(1.0, 1.0));

        assert_eq!(tensor_ski.grid_sizes, vec![4, 3, 5]);
        assert!(tensor_ski.use_kronecker);
    }

    #[test]
    fn test_grid_size_inference() {
        let ski = StructuredKernelInterpolation::new(vec![3, 2], RBFKernel::new(1.0, 1.0));

        let _x = array![[0.0, 0.0], [1.0, 1.0]];
        // Hand-built fitted model with a 3x2 grid; inference should recover
        // the sizes from the unique coordinate values per column.
        let fitted = FittedSKI {
            grid_points: array![
                [0.0, 0.0],
                [0.0, 1.0],
                [1.0, 0.0],
                [1.0, 1.0],
                [2.0, 0.0],
                [2.0, 1.0]
            ],
            weights: Array2::zeros((2, 6)),
            kernel: ski.kernel.clone(),
            alpha: Array1::zeros(6),
        };

        let inferred = fitted
            .infer_grid_size_from_points()
            .expect("operation should succeed");
        assert_eq!(inferred, vec![3, 2]);
    }
}