1use crate::sparse_gp::core::*;
8use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
9
10use scirs2_core::ndarray::s;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::essentials::Uniform as RandUniform;
13use scirs2_core::random::thread_rng;
14use scirs2_linalg::compat::ArrayLinalgExt;
15use sklears_core::error::{Result, SklearsError};
16
17pub struct ScalableInference;
19
20impl ScalableInference {
21 pub fn predict<K: SparseKernel>(
23 method: &ScalableInferenceMethod,
24 k_star_m: &Array2<f64>,
25 inducing_points: &Array2<f64>,
26 alpha: &Array1<f64>,
27 kernel: &K,
28 noise_variance: f64,
29 ) -> Result<Array1<f64>> {
30 match method {
31 ScalableInferenceMethod::Direct => Self::predict_direct(k_star_m, alpha),
32 ScalableInferenceMethod::PreconditionedCG {
33 max_iter,
34 tol,
35 preconditioner,
36 } => Self::predict_with_pcg(
37 k_star_m,
38 inducing_points,
39 kernel,
40 noise_variance,
41 *max_iter,
42 *tol,
43 preconditioner,
44 ),
45 ScalableInferenceMethod::Lanczos { num_vectors, tol } => Self::predict_with_lanczos(
46 k_star_m,
47 inducing_points,
48 kernel,
49 noise_variance,
50 *num_vectors,
51 *tol,
52 ),
53 }
54 }
55
56 fn predict_direct(k_star_m: &Array2<f64>, alpha: &Array1<f64>) -> Result<Array1<f64>> {
58 Ok(k_star_m.dot(alpha))
59 }
60
61 fn predict_with_pcg<K: SparseKernel>(
63 k_star_m: &Array2<f64>,
64 inducing_points: &Array2<f64>,
65 kernel: &K,
66 noise_variance: f64,
67 max_iter: usize,
68 tol: f64,
69 preconditioner: &PreconditionerType,
70 ) -> Result<Array1<f64>> {
71 let m = inducing_points.nrows();
72
73 let mut a_matrix = kernel.kernel_matrix(inducing_points, inducing_points);
75 for i in 0..m {
76 a_matrix[(i, i)] += noise_variance;
77 }
78
79 let rhs = k_star_m.t().dot(&Array1::ones(k_star_m.nrows()));
81
82 let solution = PreconditionedCG::solve(&a_matrix, &rhs, max_iter, tol, preconditioner)?;
84
85 Ok(k_star_m.dot(&solution))
86 }
87
88 fn predict_with_lanczos<K: SparseKernel>(
90 k_star_m: &Array2<f64>,
91 inducing_points: &Array2<f64>,
92 kernel: &K,
93 noise_variance: f64,
94 num_vectors: usize,
95 tol: f64,
96 ) -> Result<Array1<f64>> {
97 let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
99
100 let (eigenvals, eigenvecs) = LanczosMethod::eigendecomposition(&k_mm, num_vectors, tol)?;
102
103 let k_star_transformed = k_star_m.dot(&eigenvecs);
106
107 let scaled_eigenvals = eigenvals.mapv(|x| {
109 if x > 1e-10 {
110 1.0 / (x + noise_variance)
111 } else {
112 0.0
113 }
114 });
115
116 let prediction = k_star_transformed.dot(&scaled_eigenvals);
118 Ok(prediction)
119 }
120}
121
122pub struct PreconditionedCG;
124
125impl PreconditionedCG {
126 pub fn solve(
128 a: &Array2<f64>,
129 b: &Array1<f64>,
130 max_iter: usize,
131 tol: f64,
132 preconditioner: &PreconditionerType,
133 ) -> Result<Array1<f64>> {
134 let n = a.nrows();
135 let mut x = Array1::zeros(n);
136 let mut r = b - &a.dot(&x);
137
138 let precond_matrix = PreconditionerSetup::setup_preconditioner(a, preconditioner)?;
140 let mut z = PreconditionerSetup::apply_preconditioner(&precond_matrix, &r, preconditioner)?;
141 let mut p = z.clone();
142 let mut rsold = r.dot(&z);
143
144 for _iter in 0..max_iter {
145 let ap = a.dot(&p);
146 let alpha = rsold / p.dot(&ap);
147
148 x = &x + alpha * &p;
149 r = &r - alpha * ≈
150
151 let rnorm = r.mapv(|x| x * x).sum().sqrt();
153 if rnorm < tol {
154 break;
155 }
156
157 z = PreconditionerSetup::apply_preconditioner(&precond_matrix, &r, preconditioner)?;
158 let rsnew = r.dot(&z);
159 let beta = rsnew / rsold;
160
161 p = &z + beta * &p;
162 rsold = rsnew;
163 }
164
165 Ok(x)
166 }
167}
168
/// Construction and application of preconditioners for [`PreconditionedCG`].
pub struct PreconditionerSetup;

impl PreconditionerSetup {
    /// Builds the preconditioner matrix for the system matrix `a`.
    ///
    /// Returned matrix by variant:
    /// * `None` — identity (no preconditioning).
    /// * `Diagonal` — inverse of diag(a) (Jacobi); entries with |a_ii| <= 1e-12 map to 1.0.
    /// * `IncompleteCholesky` — diagonal 1/sqrt(a_ii) scaling; `fill_factor` is currently
    ///   ignored (this is a diagonal stand-in, not a true incomplete factorization).
    /// * `SSOR` — (D + omega*L) D^{-1} (D + omega*L)^T from the diagonal D and strictly
    ///   lower triangle L of `a`.
    ///
    /// NOTE(review): the `Diagonal`/`IncompleteCholesky` branches store (an approximation
    /// of) M^{-1}, which `apply_preconditioner` multiplies directly, but the `SSOR` branch
    /// stores M itself and the fallback arm also multiplies — i.e. it applies M rather
    /// than solving M z = r. Confirm whether that is the intended SSOR semantics.
    pub fn setup_preconditioner(
        a: &Array2<f64>,
        preconditioner: &PreconditionerType,
    ) -> Result<Array2<f64>> {
        match preconditioner {
            PreconditionerType::None => Ok(Array2::eye(a.nrows())),

            PreconditionerType::Diagonal => {
                // Jacobi: invert the diagonal, guarding near-zero entries.
                let diag_inv = a
                    .diag()
                    .mapv(|x| if x.abs() > 1e-12 { 1.0 / x } else { 1.0 });
                Ok(Array2::from_diag(&diag_inv))
            }

            PreconditionerType::IncompleteCholesky { fill_factor: _ } => {
                // Diagonal square-root scaling; note the guard here is `x > 1e-12`
                // (sign-sensitive), unlike the Diagonal branch's `x.abs()`.
                let diag_inv = a
                    .diag()
                    .mapv(|x| if x > 1e-12 { 1.0 / x.sqrt() } else { 1.0 });
                Ok(Array2::from_diag(&diag_inv))
            }

            PreconditionerType::SSOR { omega } => {
                let n = a.nrows();
                // Split `a` into diagonal D and strictly lower triangle L.
                let mut d = Array2::zeros((n, n));
                let mut l = Array2::zeros((n, n));

                for i in 0..n {
                    d[(i, i)] = a[(i, i)];
                    for j in 0..i {
                        l[(i, j)] = a[(i, j)];
                    }
                }

                let d_inv =
                    Array2::from_diag(
                        &d.diag()
                            .mapv(|x| if x.abs() > 1e-12 { 1.0 / x } else { 1.0 }),
                    );
                // M = (D + omega*L) D^{-1} (D + omega*L)^T
                let dl = &d + *omega * &l;
                Ok(dl.dot(&d_inv).dot(&dl.t()))
            }
        }
    }

    /// Applies the stored preconditioner to `vector`.
    ///
    /// `None` clones the input; `Diagonal` scales element-wise by the stored
    /// diagonal; every other variant multiplies by the stored matrix (see the
    /// NOTE on [`Self::setup_preconditioner`] about the SSOR case).
    pub fn apply_preconditioner(
        precond: &Array2<f64>,
        vector: &Array1<f64>,
        preconditioner: &PreconditionerType,
    ) -> Result<Array1<f64>> {
        match preconditioner {
            PreconditionerType::None => Ok(vector.clone()),
            PreconditionerType::Diagonal => Ok(&precond.diag().to_owned() * vector),
            _ => Ok(precond.dot(vector)),
        }
    }
}
236
/// Lanczos iteration for approximate eigendecomposition of symmetric matrices.
pub struct LanczosMethod;

impl LanczosMethod {
    /// Approximates leading eigenpairs of a symmetric `matrix` using at most
    /// `num_vectors` Lanczos vectors (capped at the matrix dimension).
    ///
    /// Returns `(eigenvalues, eigenvectors)` where the eigenvectors are the
    /// Lanczos basis Q times the eigenvectors of the tridiagonal projection
    /// T = Q^T A Q.
    ///
    /// NOTE(review): on early breakdown (`beta < tol`) the remaining columns of Q
    /// and entries of alpha/beta stay zero but are still passed to the tridiagonal
    /// solver, so the output may contain spurious zero eigenpairs — confirm callers
    /// tolerate this. No reorthogonalization is performed, so the basis can lose
    /// orthogonality for larger `num_vectors`.
    pub fn eigendecomposition(
        matrix: &Array2<f64>,
        num_vectors: usize,
        tol: f64,
    ) -> Result<(Array1<f64>, Array2<f64>)> {
        let n = matrix.nrows();
        // Cannot use more Lanczos vectors than the matrix dimension.
        let m = num_vectors.min(n);

        let mut q_matrix = Array2::zeros((n, m)); // Lanczos basis, one vector per column
        let mut alpha_vec = Array1::zeros(m); // diagonal of the tridiagonal projection
        let mut beta_vec = Array1::zeros(m); // off-diagonal of the tridiagonal projection

        // Random unit-norm starting vector.
        let mut rng = thread_rng();
        let uniform = RandUniform::new(-1.0, 1.0).unwrap();
        let mut q_0 = Array1::zeros(n);
        for i in 0..n {
            q_0[i] = rng.sample(uniform);
        }
        #[allow(clippy::unnecessary_cast)]
        let q_0_norm = (q_0.mapv(|x| x * x).sum() as f64).sqrt();
        q_0 /= q_0_norm;
        q_matrix.column_mut(0).assign(&q_0);

        let mut beta = 0.0;
        let mut q_prev = Array1::zeros(n);

        // Three-term recurrence: w = A q_j - beta_{j-1} q_{j-1};
        // alpha_j = q_j·w; then orthogonalize w against q_j.
        for j in 0..m {
            let q_j = q_matrix.column(j).to_owned();
            let mut w: Array1<f64> = matrix.dot(&q_j) - beta * &q_prev;

            alpha_vec[j] = q_j.dot(&w);
            w = &w - alpha_vec[j] * &q_j;

            beta = w.mapv(|x| x * x).sum().sqrt();
            if j < m - 1 {
                beta_vec[j] = beta;
                // Breakdown: the Krylov subspace is exhausted (w numerically zero);
                // checked BEFORE dividing by beta so the normalization is safe.
                if beta < tol {
                    break;
                }
                q_matrix.column_mut(j + 1).assign(&(&w / beta));
            }

            q_prev = q_j;
        }

        // Eigendecomposition of the small tridiagonal projection.
        let (eigenvals, eigenvecs_tri) = TridiagonalEigenSolver::solve(&alpha_vec, &beta_vec)?;

        // Lift tridiagonal eigenvectors back to the original space: V = Q * V_tri.
        let eigenvecs = q_matrix.dot(&eigenvecs_tri);

        Ok((eigenvals, eigenvecs))
    }
}
298
/// Eigen-solver for the symmetric tridiagonal matrices produced by Lanczos.
pub struct TridiagonalEigenSolver;

impl TridiagonalEigenSolver {
    /// Computes an eigendecomposition of the symmetric tridiagonal matrix with
    /// diagonal `alpha` and off-diagonal `beta`.
    ///
    /// Implemented via SVD of the densified matrix: returns the singular values
    /// and the left singular vectors.
    ///
    /// NOTE(review): for a symmetric matrix the singular values are the
    /// *absolute values* of the eigenvalues, so any negative eigenvalue loses
    /// its sign here. That is harmless for PSD-derived tridiagonals but wrong
    /// for indefinite input — confirm callers only pass PSD projections.
    ///
    /// # Errors
    /// Returns `NumericalError` if the SVD fails.
    pub fn solve(alpha: &Array1<f64>, beta: &Array1<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
        let n = alpha.len();

        // Densify the tridiagonal matrix (symmetric: mirror the off-diagonal).
        let mut tri_matrix = Array2::zeros((n, n));
        for i in 0..n {
            tri_matrix[(i, i)] = alpha[i];
            if i < n - 1 {
                tri_matrix[(i, i + 1)] = beta[i];
                tri_matrix[(i + 1, i)] = beta[i];
            }
        }

        let (u, s, _vt) = tri_matrix
            .svd(true)
            .map_err(|e| SklearsError::NumericalError(format!("SVD failed: {:?}", e)))?;

        Ok((s, u))
    }
}
325
326pub struct IterativeRefinement;
328
329impl IterativeRefinement {
330 pub fn refine_solution(
332 a: &Array2<f64>,
333 b: &Array1<f64>,
334 x: &Array1<f64>,
335 max_iter: usize,
336 tol: f64,
337 ) -> Result<Array1<f64>> {
338 let mut x_refined = x.clone();
339
340 for _iter in 0..max_iter {
341 let residual = b - &a.dot(&x_refined);
343
344 let residual_norm = residual.mapv(|x| x * x).sum().sqrt();
346 if residual_norm < tol {
347 break;
348 }
349
350 let dx = KernelOps::invert_using_cholesky(a)?.dot(&residual);
352
353 x_refined = &x_refined + &dx;
355 }
356
357 Ok(x_refined)
358 }
359}
360
361pub struct SpecializedSolvers;
363
364impl SpecializedSolvers {
365 pub fn solve_kronecker(
367 a1: &Array2<f64>,
368 a2: &Array2<f64>,
369 b: &Array2<f64>,
370 ) -> Result<Array2<f64>> {
371 let a1_inv = KernelOps::invert_using_cholesky(a1)?;
375 let a2_inv = KernelOps::invert_using_cholesky(a2)?;
376
377 let x = a1_inv.dot(b).dot(&a2_inv.t());
378
379 Ok(x)
380 }
381
382 pub fn solve_block_diagonal(
384 blocks: &[Array2<f64>],
385 rhs_blocks: &[Array1<f64>],
386 ) -> Result<Array1<f64>> {
387 if blocks.len() != rhs_blocks.len() {
388 return Err(SklearsError::InvalidInput(
389 "Number of blocks must match RHS blocks".to_string(),
390 ));
391 }
392
393 let mut solution_blocks = Vec::new();
394
395 for (block, rhs_block) in blocks.iter().zip(rhs_blocks.iter()) {
396 let block_inv = KernelOps::invert_using_cholesky(block)?;
397 let block_solution = block_inv.dot(rhs_block);
398 solution_blocks.push(block_solution);
399 }
400
401 let total_size: usize = solution_blocks.iter().map(|b| b.len()).sum();
403 let mut solution = Array1::zeros(total_size);
404 let mut offset = 0;
405
406 for block_solution in solution_blocks {
407 let block_size = block_solution.len();
408 solution
409 .slice_mut(s![offset..offset + block_size])
410 .assign(&block_solution);
411 offset += block_size;
412 }
413
414 Ok(solution)
415 }
416}
417
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_direct_prediction() {
        let k_star_m = array![[0.5, 0.3], [0.7, 0.2]];
        let alpha = array![1.0, 2.0];

        let result = ScalableInference::predict_direct(&k_star_m, &alpha).unwrap();

        // Row-wise dot products with alpha: 0.5 + 0.6 = 1.1 and 0.7 + 0.4 = 1.1.
        let expected = array![1.1, 1.1];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_diagonal_preconditioner() {
        let matrix = array![[4.0, 1.0], [1.0, 3.0]];

        let precond =
            PreconditionerSetup::setup_preconditioner(&matrix, &PreconditionerType::Diagonal)
                .unwrap();

        // Jacobi preconditioner: diagonal holds 1/a_ii, off-diagonal is zero.
        assert_abs_diff_eq!(precond[(0, 0)], 0.25, epsilon = 1e-10);
        assert_abs_diff_eq!(precond[(1, 1)], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(precond[(0, 1)], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pcg_solver() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];

        let solution =
            PreconditionedCG::solve(&a, &b, 100, 1e-10, &PreconditionerType::Diagonal).unwrap();

        // Verify via the residual rather than a hand-computed solution.
        let residual = &b - &a.dot(&solution);
        let residual_norm = residual.mapv(|v| v * v).sum().sqrt();
        assert!(residual_norm < 1e-8);
    }

    #[test]
    fn test_lanczos_eigendecomposition() {
        let matrix = array![[3.0, 1.0], [1.0, 2.0]];

        let (eigenvals, eigenvecs) = LanczosMethod::eigendecomposition(&matrix, 2, 1e-10).unwrap();

        assert_eq!(eigenvals.len(), 2);
        assert_eq!(eigenvecs.shape(), &[2, 2]);

        // The test matrix is positive definite, so all eigenvalues are positive.
        assert!(eigenvals.iter().all(|&lambda| lambda > 0.0));
    }

    #[test]
    fn test_iterative_refinement() {
        let a = array![[2.0, 1.0], [1.0, 2.0]];
        let b = array![3.0, 3.0];
        // [1, 1] is already the exact solution, so refinement must leave it unchanged.
        let x_initial = array![1.0, 1.0];

        let x_refined =
            IterativeRefinement::refine_solution(&a, &b, &x_initial, 10, 1e-12).unwrap();

        for (got, want) in x_refined.iter().zip(x_initial.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_block_diagonal_solver() {
        let block1 = array![[2.0, 0.0], [0.0, 3.0]];
        let block2 = array![[1.0]];
        let blocks = vec![block1, block2];

        let rhs1 = array![4.0, 6.0];
        let rhs2 = array![2.0];
        let rhs_blocks = vec![rhs1, rhs2];

        let solution = SpecializedSolvers::solve_block_diagonal(&blocks, &rhs_blocks).unwrap();

        // diag(2,3) x = (4,6) -> (2,2); [1] x = 2 -> 2.
        let expected = array![2.0, 2.0, 2.0];
        assert_eq!(solution.len(), expected.len());

        for (got, want) in solution.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-5);
        }
    }
}