1use crate::sparse_gp::core::*;
8use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
9use scirs2_core::ndarray::ndarray_linalg::SVD;
10use scirs2_core::ndarray::s;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::essentials::Uniform as RandUniform;
13use scirs2_core::random::Distribution;
14use scirs2_core::random::{thread_rng, Rng};
15use sklears_core::error::{Result, SklearsError};
16
17pub struct ScalableInference;
19
20impl ScalableInference {
21 pub fn predict<K: SparseKernel>(
23 method: &ScalableInferenceMethod,
24 k_star_m: &Array2<f64>,
25 inducing_points: &Array2<f64>,
26 alpha: &Array1<f64>,
27 kernel: &K,
28 noise_variance: f64,
29 ) -> Result<Array1<f64>> {
30 match method {
31 ScalableInferenceMethod::Direct => Self::predict_direct(k_star_m, alpha),
32 ScalableInferenceMethod::PreconditionedCG {
33 max_iter,
34 tol,
35 preconditioner,
36 } => Self::predict_with_pcg(
37 k_star_m,
38 inducing_points,
39 kernel,
40 noise_variance,
41 *max_iter,
42 *tol,
43 preconditioner,
44 ),
45 ScalableInferenceMethod::Lanczos { num_vectors, tol } => Self::predict_with_lanczos(
46 k_star_m,
47 inducing_points,
48 kernel,
49 noise_variance,
50 *num_vectors,
51 *tol,
52 ),
53 }
54 }
55
56 fn predict_direct(k_star_m: &Array2<f64>, alpha: &Array1<f64>) -> Result<Array1<f64>> {
58 Ok(k_star_m.dot(alpha))
59 }
60
61 fn predict_with_pcg<K: SparseKernel>(
63 k_star_m: &Array2<f64>,
64 inducing_points: &Array2<f64>,
65 kernel: &K,
66 noise_variance: f64,
67 max_iter: usize,
68 tol: f64,
69 preconditioner: &PreconditionerType,
70 ) -> Result<Array1<f64>> {
71 let m = inducing_points.nrows();
72
73 let mut a_matrix = kernel.kernel_matrix(inducing_points, inducing_points);
75 for i in 0..m {
76 a_matrix[(i, i)] += noise_variance;
77 }
78
79 let rhs = k_star_m.t().dot(&Array1::ones(k_star_m.nrows()));
81
82 let solution = PreconditionedCG::solve(&a_matrix, &rhs, max_iter, tol, preconditioner)?;
84
85 Ok(k_star_m.dot(&solution))
86 }
87
88 fn predict_with_lanczos<K: SparseKernel>(
90 k_star_m: &Array2<f64>,
91 inducing_points: &Array2<f64>,
92 kernel: &K,
93 noise_variance: f64,
94 num_vectors: usize,
95 tol: f64,
96 ) -> Result<Array1<f64>> {
97 let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
99
100 let (eigenvals, eigenvecs) = LanczosMethod::eigendecomposition(&k_mm, num_vectors, tol)?;
102
103 let k_star_transformed = k_star_m.dot(&eigenvecs);
106
107 let scaled_eigenvals = eigenvals.mapv(|x| {
109 if x > 1e-10 {
110 1.0 / (x + noise_variance)
111 } else {
112 0.0
113 }
114 });
115
116 let prediction = k_star_transformed.dot(&scaled_eigenvals);
118 Ok(prediction)
119 }
120}
121
122pub struct PreconditionedCG;
124
125impl PreconditionedCG {
126 pub fn solve(
128 a: &Array2<f64>,
129 b: &Array1<f64>,
130 max_iter: usize,
131 tol: f64,
132 preconditioner: &PreconditionerType,
133 ) -> Result<Array1<f64>> {
134 let n = a.nrows();
135 let mut x = Array1::zeros(n);
136 let mut r = b - &a.dot(&x);
137
138 let precond_matrix = PreconditionerSetup::setup_preconditioner(a, preconditioner)?;
140 let mut z = PreconditionerSetup::apply_preconditioner(&precond_matrix, &r, preconditioner)?;
141 let mut p = z.clone();
142 let mut rsold = r.dot(&z);
143
144 for iter in 0..max_iter {
145 let ap = a.dot(&p);
146 let alpha = rsold / p.dot(&ap);
147
148 x = &x + alpha * &p;
149 r = &r - alpha * ≈
150
151 let rnorm = r.mapv(|x| x * x).sum().sqrt();
153 if rnorm < tol {
154 break;
155 }
156
157 z = PreconditionerSetup::apply_preconditioner(&precond_matrix, &r, preconditioner)?;
158 let rsnew = r.dot(&z);
159 let beta = rsnew / rsold;
160
161 p = &z + beta * &p;
162 rsold = rsnew;
163 }
164
165 Ok(x)
166 }
167}
168
169pub struct PreconditionerSetup;
171
172impl PreconditionerSetup {
173 pub fn setup_preconditioner(
175 a: &Array2<f64>,
176 preconditioner: &PreconditionerType,
177 ) -> Result<Array2<f64>> {
178 match preconditioner {
179 PreconditionerType::None => Ok(Array2::eye(a.nrows())),
180
181 PreconditionerType::Diagonal => {
182 let diag_inv = a
184 .diag()
185 .mapv(|x| if x.abs() > 1e-12 { 1.0 / x } else { 1.0 });
186 Ok(Array2::from_diag(&diag_inv))
187 }
188
189 PreconditionerType::IncompleteCholesky { fill_factor: _ } => {
190 let diag_inv = a
192 .diag()
193 .mapv(|x| if x > 1e-12 { 1.0 / x.sqrt() } else { 1.0 });
194 Ok(Array2::from_diag(&diag_inv))
195 }
196
197 PreconditionerType::SSOR { omega } => {
198 let n = a.nrows();
200 let mut d = Array2::zeros((n, n));
201 let mut l = Array2::zeros((n, n));
202
203 for i in 0..n {
205 d[(i, i)] = a[(i, i)];
206 for j in 0..i {
207 l[(i, j)] = a[(i, j)];
208 }
209 }
210
211 let d_inv =
213 Array2::from_diag(
214 &d.diag()
215 .mapv(|x| if x.abs() > 1e-12 { 1.0 / x } else { 1.0 }),
216 );
217 let dl = &d + *omega * &l;
218 Ok(dl.dot(&d_inv).dot(&dl.t()))
219 }
220 }
221 }
222
223 pub fn apply_preconditioner(
225 precond: &Array2<f64>,
226 vector: &Array1<f64>,
227 preconditioner: &PreconditionerType,
228 ) -> Result<Array1<f64>> {
229 match preconditioner {
230 PreconditionerType::None => Ok(vector.clone()),
231 PreconditionerType::Diagonal => Ok(&precond.diag().to_owned() * vector),
232 _ => Ok(precond.dot(vector)),
233 }
234 }
235}
236
/// Partial eigendecomposition of a symmetric matrix via the Lanczos iteration.
pub struct LanczosMethod;

impl LanczosMethod {
    /// Approximate the leading eigenpairs of the symmetric `matrix`.
    ///
    /// Runs up to `num_vectors` Lanczos steps (capped at the matrix order),
    /// building a tridiagonal matrix `T` with diagonal `alpha_vec` and
    /// off-diagonal `beta_vec`, then solves `T`'s eigenproblem and lifts the
    /// eigenvectors back through the Krylov basis `Q`.
    ///
    /// Returns `(eigenvalues, eigenvectors)` where eigenvectors are columns
    /// of an `n × m` matrix.
    ///
    /// NOTE(review): this is the plain three-term recurrence with no
    /// reorthogonalization, so the basis can lose orthogonality for larger
    /// `num_vectors`; also, an early break (beta < tol) leaves trailing
    /// zeros in `alpha_vec`/`beta_vec`, which the tridiagonal solver then
    /// reports as (near-)zero eigenvalues — confirm callers tolerate this.
    pub fn eigendecomposition(
        matrix: &Array2<f64>,
        num_vectors: usize,
        tol: f64,
    ) -> Result<(Array1<f64>, Array2<f64>)> {
        let n = matrix.nrows();
        // Cannot extract more Lanczos vectors than the matrix order.
        let m = num_vectors.min(n);

        // Krylov basis (columns) and the tridiagonal coefficients of T.
        let mut q_matrix = Array2::zeros((n, m));
        let mut alpha_vec = Array1::zeros(m);
        let mut beta_vec = Array1::zeros(m);

        // Random unit-norm starting vector; components i.i.d. U(-1, 1).
        let mut rng = thread_rng();
        let uniform = RandUniform::new(-1.0, 1.0).unwrap();
        let mut q_0 = Array1::zeros(n);
        for i in 0..n {
            q_0[i] = rng.sample(uniform);
        }
        let q_0_norm = (q_0.mapv(|x| x * x).sum() as f64).sqrt();
        q_0 = q_0 / q_0_norm;
        q_matrix.column_mut(0).assign(&q_0);

        let mut beta = 0.0;
        let mut q_prev = Array1::zeros(n);

        // Three-term Lanczos recurrence:
        //   w = A q_j - beta_{j-1} q_{j-1};  alpha_j = q_j' w;
        //   w -= alpha_j q_j;  beta_j = ||w||;  q_{j+1} = w / beta_j.
        for j in 0..m {
            let q_j = q_matrix.column(j).to_owned();
            let mut w: Array1<f64> = matrix.dot(&q_j) - beta * &q_prev;

            alpha_vec[j] = q_j.dot(&w);
            w = &w - alpha_vec[j] * &q_j;

            beta = w.mapv(|x| x * x).sum().sqrt();
            if j < m - 1 {
                beta_vec[j] = beta;
                // Invariant subspace found (residual negligible): stop early.
                if beta < tol {
                    break;
                }
                q_matrix.column_mut(j + 1).assign(&(&w / beta));
            }

            q_prev = q_j;
        }

        // Eigendecomposition of the small tridiagonal T (m × m).
        let (eigenvals, eigenvecs_tri) = TridiagonalEigenSolver::solve(&alpha_vec, &beta_vec)?;

        // Lift Ritz vectors back to the original space: V = Q · V_T.
        let eigenvecs = q_matrix.dot(&eigenvecs_tri);

        Ok((eigenvals, eigenvecs))
    }
}
297
/// Eigenproblem solver for the small symmetric tridiagonal matrices
/// produced by the Lanczos iteration.
pub struct TridiagonalEigenSolver;

impl TridiagonalEigenSolver {
    /// Compute eigenvalues and eigenvectors of the symmetric tridiagonal
    /// matrix with diagonal `alpha` and off-diagonal `beta`.
    ///
    /// # Errors
    /// Returns `SklearsError::NumericalError` if the SVD fails or does not
    /// produce the `U` factor.
    ///
    /// NOTE(review): this uses SVD instead of a symmetric eigensolver, so
    /// the returned values are singular values (|λ|, sorted descending).
    /// That coincides with the eigendecomposition only when the matrix is
    /// positive semi-definite — true for Lanczos on a PSD kernel matrix,
    /// but not in general. Confirm no caller passes an indefinite system.
    pub fn solve(alpha: &Array1<f64>, beta: &Array1<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
        let n = alpha.len();

        // Densify the tridiagonal matrix (n is small: the Lanczos rank).
        let mut tri_matrix = Array2::zeros((n, n));
        for i in 0..n {
            tri_matrix[(i, i)] = alpha[i];
            if i < n - 1 {
                tri_matrix[(i, i + 1)] = beta[i];
                tri_matrix[(i + 1, i)] = beta[i];
            }
        }

        // SVD of a symmetric PSD matrix: U holds eigenvectors, s the eigenvalues.
        let (u, s, _vt) = tri_matrix
            .svd(true, true)
            .map_err(|e| SklearsError::NumericalError(format!("SVD failed: {:?}", e)))?;
        let u =
            u.ok_or_else(|| SklearsError::NumericalError("U matrix not computed".to_string()))?;

        Ok((s, u))
    }
}
326
327pub struct IterativeRefinement;
329
330impl IterativeRefinement {
331 pub fn refine_solution(
333 a: &Array2<f64>,
334 b: &Array1<f64>,
335 x: &Array1<f64>,
336 max_iter: usize,
337 tol: f64,
338 ) -> Result<Array1<f64>> {
339 let mut x_refined = x.clone();
340
341 for _iter in 0..max_iter {
342 let residual = b - &a.dot(&x_refined);
344
345 let residual_norm = residual.mapv(|x| x * x).sum().sqrt();
347 if residual_norm < tol {
348 break;
349 }
350
351 let dx = KernelOps::invert_using_cholesky(a)?.dot(&residual);
353
354 x_refined = &x_refined + &dx;
356 }
357
358 Ok(x_refined)
359 }
360}
361
362pub struct SpecializedSolvers;
364
365impl SpecializedSolvers {
366 pub fn solve_kronecker(
368 a1: &Array2<f64>,
369 a2: &Array2<f64>,
370 b: &Array2<f64>,
371 ) -> Result<Array2<f64>> {
372 let a1_inv = KernelOps::invert_using_cholesky(a1)?;
376 let a2_inv = KernelOps::invert_using_cholesky(a2)?;
377
378 let x = a1_inv.dot(b).dot(&a2_inv.t());
379
380 Ok(x)
381 }
382
383 pub fn solve_block_diagonal(
385 blocks: &[Array2<f64>],
386 rhs_blocks: &[Array1<f64>],
387 ) -> Result<Array1<f64>> {
388 if blocks.len() != rhs_blocks.len() {
389 return Err(SklearsError::InvalidInput(
390 "Number of blocks must match RHS blocks".to_string(),
391 ));
392 }
393
394 let mut solution_blocks = Vec::new();
395
396 for (block, rhs_block) in blocks.iter().zip(rhs_blocks.iter()) {
397 let block_inv = KernelOps::invert_using_cholesky(block)?;
398 let block_solution = block_inv.dot(rhs_block);
399 solution_blocks.push(block_solution);
400 }
401
402 let total_size: usize = solution_blocks.iter().map(|b| b.len()).sum();
404 let mut solution = Array1::zeros(total_size);
405 let mut offset = 0;
406
407 for block_solution in solution_blocks {
408 let block_size = block_solution.len();
409 solution
410 .slice_mut(s![offset..offset + block_size])
411 .assign(&block_solution);
412 offset += block_size;
413 }
414
415 Ok(solution)
416 }
417}
418
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_direct_prediction() {
        let k_star_m = array![[0.5, 0.3], [0.7, 0.2]];
        let alpha = array![1.0, 2.0];

        let result = ScalableInference::predict_direct(&k_star_m, &alpha).unwrap();

        // Row-wise dot products: [0.5*1 + 0.3*2, 0.7*1 + 0.2*2].
        let expected = array![1.1, 1.1];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_diagonal_preconditioner() {
        let matrix = array![[4.0, 1.0], [1.0, 3.0]];
        let precond =
            PreconditionerSetup::setup_preconditioner(&matrix, &PreconditionerType::Diagonal)
                .unwrap();

        // Jacobi preconditioner is diag(1 / a_ii) with zero off-diagonals.
        assert_abs_diff_eq!(precond[(0, 0)], 0.25, epsilon = 1e-10);
        assert_abs_diff_eq!(precond[(1, 1)], 1.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(precond[(0, 1)], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pcg_solver() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];

        let solution =
            PreconditionedCG::solve(&a, &b, 100, 1e-10, &PreconditionerType::Diagonal).unwrap();

        // Verify via the residual rather than a hard-coded solution vector.
        let residual = &b - &a.dot(&solution);
        let residual_norm = residual.mapv(|x| x * x).sum().sqrt();
        assert!(residual_norm < 1e-8);
    }

    #[test]
    fn test_lanczos_eigendecomposition() {
        let matrix = array![[3.0, 1.0], [1.0, 2.0]];

        let (eigenvals, eigenvecs) = LanczosMethod::eigendecomposition(&matrix, 2, 1e-10).unwrap();

        assert_eq!(eigenvals.len(), 2);
        assert_eq!(eigenvecs.shape(), &[2, 2]);

        // The test matrix is positive definite, so every eigenvalue is positive.
        assert!(eigenvals.iter().all(|&lambda| lambda > 0.0));
    }

    #[test]
    fn test_iterative_refinement() {
        let a = array![[2.0, 1.0], [1.0, 2.0]];
        let b = array![3.0, 3.0];
        // [1, 1] is already the exact solution, so refinement must keep it.
        let x_initial = array![1.0, 1.0];

        let x_refined =
            IterativeRefinement::refine_solution(&a, &b, &x_initial, 10, 1e-12).unwrap();

        for (refined, initial) in x_refined.iter().zip(x_initial.iter()) {
            assert_abs_diff_eq!(*refined, *initial, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_block_diagonal_solver() {
        let blocks = vec![array![[2.0, 0.0], [0.0, 3.0]], array![[1.0]]];
        let rhs_blocks = vec![array![4.0, 6.0], array![2.0]];

        let solution = SpecializedSolvers::solve_block_diagonal(&blocks, &rhs_blocks).unwrap();

        // Each block solves independently: [4/2, 6/3, 2/1].
        let expected = array![2.0, 2.0, 2.0];
        assert_eq!(solution.len(), expected.len());
        for (got, want) in solution.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-5);
        }
    }
}