1use crate::sparse_gp::core::*;
8use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::{thread_rng, Rng};
11use scirs2_core::RngExt;
12use sklears_core::error::{Result, SklearsError};
13use std::f64::consts::PI;
14
/// Variational Free Energy (VFE) sparse-GP approximation.
///
/// Stateless namespace type: all functionality is provided through the
/// associated functions on the `impl` block below.
pub struct VariationalFreeEnergy;
17
18impl VariationalFreeEnergy {
19 pub fn fit<K: SparseKernel>(
21 x: &Array2<f64>,
22 y: &Array1<f64>,
23 inducing_points: &Array2<f64>,
24 kernel: &K,
25 noise_variance: f64,
26 whitened: bool,
27 natural_gradients: bool,
28 max_iter: usize,
29 tol: f64,
30 ) -> Result<(Array1<f64>, Array2<f64>, VariationalParams)> {
31 let _n = x.nrows();
32 let m = inducing_points.nrows();
33
34 let mut variational_mean = Array1::zeros(m);
36 let mut variational_cov_factor = Array2::eye(m);
37
38 let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
40 let k_nm = kernel.kernel_matrix(x, inducing_points);
41 let k_diag = kernel.kernel_diagonal(x);
42
43 let k_mm_inv = KernelOps::invert_using_cholesky(&k_mm)?;
44
45 let mut best_elbo = f64::NEG_INFINITY;
47 let mut best_params = None;
48
49 for _iter in 0..max_iter {
50 let old_mean = variational_mean.clone();
51 let old_cov_factor = variational_cov_factor.clone();
52
53 let elbo_result = Self::compute_elbo_and_gradients(
55 y,
56 &k_nm,
57 &k_diag,
58 &k_mm,
59 &k_mm_inv,
60 &variational_mean,
61 &variational_cov_factor,
62 noise_variance,
63 whitened,
64 )?;
65
66 if elbo_result.elbo > best_elbo {
67 best_elbo = elbo_result.elbo;
68 best_params = Some((
69 variational_mean.clone(),
70 variational_cov_factor.clone(),
71 elbo_result.clone(),
72 ));
73 }
74
75 if natural_gradients {
77 Self::natural_gradient_update(
78 &mut variational_mean,
79 &mut variational_cov_factor,
80 &elbo_result,
81 0.01, )?;
83 } else {
84 Self::standard_gradient_update(
85 &mut variational_mean,
86 &mut variational_cov_factor,
87 &elbo_result,
88 0.01, )?;
90 }
91
92 let mean_change = (&variational_mean - &old_mean).mapv(|x| x * x).sum().sqrt();
94 let cov_change = (&variational_cov_factor - &old_cov_factor)
95 .mapv(|x| x * x)
96 .sum()
97 .sqrt();
98
99 if mean_change < tol && cov_change < tol {
100 break;
101 }
102 }
103
104 let (best_mean, best_cov_factor, best_elbo_result) = best_params
105 .ok_or_else(|| SklearsError::NumericalError("VFE optimization failed".to_string()))?;
106
107 let alpha = if whitened {
109 Self::compute_alpha_whitened(&k_nm, &k_mm_inv, &best_mean, noise_variance)?
110 } else {
111 Self::compute_alpha_standard(&k_mm_inv, &best_mean)?
112 };
113
114 let vfe_params = VariationalParams {
115 mean: best_mean,
116 cov_factor: best_cov_factor,
117 elbo: best_elbo_result.elbo,
118 kl_divergence: best_elbo_result.kl_divergence,
119 log_likelihood: best_elbo_result.log_likelihood,
120 };
121
122 Ok((alpha, k_mm_inv, vfe_params))
123 }
124
125 fn compute_elbo_and_gradients(
127 y: &Array1<f64>,
128 k_nm: &Array2<f64>,
129 k_diag: &Array1<f64>,
130 k_mm: &Array2<f64>,
131 k_mm_inv: &Array2<f64>,
132 variational_mean: &Array1<f64>,
133 variational_cov_factor: &Array2<f64>,
134 noise_variance: f64,
135 whitened: bool,
136 ) -> Result<ELBOResult> {
137 let _n = y.len();
138 let _m = variational_mean.len();
139
140 let variational_cov = variational_cov_factor.dot(&variational_cov_factor.t());
142
143 let (log_likelihood, ll_grad_mean, ll_grad_cov) = Self::compute_expected_log_likelihood(
145 y,
146 k_nm,
147 k_diag,
148 variational_mean,
149 &variational_cov,
150 noise_variance,
151 )?;
152
153 let (kl_divergence, kl_grad_mean, kl_grad_cov) = if whitened {
155 Self::compute_kl_divergence_whitened(variational_mean, &variational_cov)?
156 } else {
157 Self::compute_kl_divergence_standard(
158 variational_mean,
159 &variational_cov,
160 k_mm,
161 k_mm_inv,
162 )?
163 };
164
165 let elbo = log_likelihood - kl_divergence;
167
168 let grad_mean = &ll_grad_mean - &kl_grad_mean;
170 let grad_cov = &ll_grad_cov - &kl_grad_cov;
171
172 Ok(ELBOResult {
173 elbo,
174 log_likelihood,
175 kl_divergence,
176 grad_mean,
177 grad_cov,
178 })
179 }
180
181 fn compute_expected_log_likelihood(
183 y: &Array1<f64>,
184 k_nm: &Array2<f64>,
185 k_diag: &Array1<f64>,
186 variational_mean: &Array1<f64>,
187 variational_cov: &Array2<f64>,
188 noise_variance: f64,
189 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
190 let n = y.len();
191
192 let pred_mean = k_nm.dot(variational_mean);
194
195 let k_nm_s_k_mn = k_nm.dot(variational_cov).dot(&k_nm.t());
197 let mut pred_var_diag = k_diag.clone();
198
199 for i in 0..n {
200 pred_var_diag[i] -= k_nm_s_k_mn[(i, i)];
201 pred_var_diag[i] = pred_var_diag[i].max(1e-6); }
203
204 let residuals = y - &pred_mean;
206 let total_var = &pred_var_diag + noise_variance;
207
208 let log_likelihood = -0.5
209 * (n as f64 * (2.0 * PI).ln()
210 + total_var.mapv(|x| x.ln()).sum()
211 + residuals
212 .iter()
213 .zip(total_var.iter())
214 .map(|(r, v)| r * r / v)
215 .sum::<f64>());
216
217 let grad_mean = k_nm.t().dot(&(y - &pred_mean)) / noise_variance;
219 let grad_cov = Array2::zeros(variational_cov.dim());
220
221 Ok((log_likelihood, grad_mean, grad_cov))
222 }
223
224 fn compute_kl_divergence_standard(
226 variational_mean: &Array1<f64>,
227 variational_cov: &Array2<f64>,
228 k_mm: &Array2<f64>,
229 k_mm_inv: &Array2<f64>,
230 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
231 let m = variational_mean.len();
232
233 let trace_term = (k_mm_inv * variational_cov).diag().sum();
237
238 let quad_term = variational_mean.dot(&k_mm_inv.dot(variational_mean));
240
241 let log_det_s = Self::log_det_from_cholesky_factor(variational_cov)?;
243 let log_det_k_mm =
244 KernelOps::log_det_from_cholesky(&KernelOps::cholesky_with_jitter(k_mm, 1e-6)?);
245
246 let kl = 0.5 * (trace_term + quad_term - m as f64 - log_det_s + log_det_k_mm);
247
248 let grad_mean = k_mm_inv.dot(variational_mean);
250 let grad_cov = k_mm_inv.clone();
251
252 Ok((kl, grad_mean, grad_cov))
253 }
254
255 fn compute_kl_divergence_whitened(
257 variational_mean: &Array1<f64>,
258 variational_cov: &Array2<f64>,
259 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
260 let m = variational_mean.len();
261
262 let trace_term = variational_cov.diag().sum();
264 let quad_term = variational_mean.dot(variational_mean);
265 let log_det_s = Self::log_det_from_cholesky_factor(variational_cov)?;
266
267 let kl = 0.5 * (trace_term + quad_term - m as f64 - log_det_s);
268
269 let grad_mean = variational_mean.clone();
271 let grad_cov = Array2::eye(m);
272
273 Ok((kl, grad_mean, grad_cov))
274 }
275
276 fn natural_gradient_update(
278 variational_mean: &mut Array1<f64>,
279 variational_cov_factor: &mut Array2<f64>,
280 elbo_result: &ELBOResult,
281 learning_rate: f64,
282 ) -> Result<()> {
283 *variational_mean = &*variational_mean + learning_rate * &elbo_result.grad_mean;
288
289 let cov_update = learning_rate * &elbo_result.grad_cov;
291 *variational_cov_factor = &*variational_cov_factor + &cov_update;
292
293 Ok(())
294 }
295
296 fn standard_gradient_update(
298 variational_mean: &mut Array1<f64>,
299 variational_cov_factor: &mut Array2<f64>,
300 elbo_result: &ELBOResult,
301 learning_rate: f64,
302 ) -> Result<()> {
303 *variational_mean = &*variational_mean + learning_rate * &elbo_result.grad_mean;
305 *variational_cov_factor = &*variational_cov_factor + learning_rate * &elbo_result.grad_cov;
306
307 Ok(())
308 }
309
310 fn compute_alpha_whitened(
312 _k_nm: &Array2<f64>,
313 k_mm_inv: &Array2<f64>,
314 variational_mean: &Array1<f64>,
315 _noise_variance: f64,
316 ) -> Result<Array1<f64>> {
317 let l_mm = KernelOps::cholesky_with_jitter(k_mm_inv, 1e-6)?;
320
321 let alpha = k_mm_inv.dot(&l_mm.dot(variational_mean));
322 Ok(alpha)
323 }
324
325 fn compute_alpha_standard(
327 k_mm_inv: &Array2<f64>,
328 variational_mean: &Array1<f64>,
329 ) -> Result<Array1<f64>> {
330 Ok(k_mm_inv.dot(variational_mean))
331 }
332
333 fn log_det_from_cholesky_factor(matrix: &Array2<f64>) -> Result<f64> {
335 let eigenvals = matrix.diag();
337 let log_det = eigenvals.mapv(|x| x.abs().max(1e-12).ln()).sum();
338 Ok(log_det)
339 }
340}
341
/// Result of a single ELBO evaluation, bundling the objective value and the
/// gradients used by the parameter-update routines.
#[derive(Debug, Clone)]
pub struct ELBOResult {
    /// Evidence lower bound: `log_likelihood - kl_divergence`.
    pub elbo: f64,
    /// Expected log-likelihood term E_q[log p(y | f)].
    pub log_likelihood: f64,
    /// KL divergence between the variational posterior and the prior.
    pub kl_divergence: f64,
    /// Gradient of the ELBO w.r.t. the variational mean.
    pub grad_mean: Array1<f64>,
    /// Gradient of the ELBO w.r.t. the variational covariance.
    pub grad_cov: Array2<f64>,
}
356
/// Mini-batch (stochastic) variational inference for sparse GPs.
///
/// Stateless namespace type; see the associated functions below.
pub struct StochasticVariationalInference;
359
360impl StochasticVariationalInference {
361 pub fn fit<K: SparseKernel>(
363 x: &Array2<f64>,
364 y: &Array1<f64>,
365 inducing_points: &Array2<f64>,
366 kernel: &K,
367 noise_variance: f64,
368 batch_size: usize,
369 max_iter: usize,
370 learning_rate: f64,
371 ) -> Result<(Array1<f64>, Array2<f64>, VariationalParams)> {
372 let n = x.nrows();
373 let m = inducing_points.nrows();
374
375 let mut variational_mean = Array1::zeros(m);
377 let mut variational_cov_factor = Array2::eye(m);
378
379 let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
381 let k_mm_inv = KernelOps::invert_using_cholesky(&k_mm)?;
382
383 let mut rng = thread_rng();
384
385 for iter in 0..max_iter {
387 let batch_indices = Self::sample_batch(&mut rng, n, batch_size);
389 let (x_batch, y_batch) = Self::extract_batch(x, y, &batch_indices);
390
391 let k_batch_m = kernel.kernel_matrix(&x_batch, inducing_points);
393 let k_batch_diag = kernel.kernel_diagonal(&x_batch);
394
395 let elbo_result = VariationalFreeEnergy::compute_elbo_and_gradients(
397 &y_batch,
398 &k_batch_m,
399 &k_batch_diag,
400 &k_mm,
401 &k_mm_inv,
402 &variational_mean,
403 &variational_cov_factor,
404 noise_variance,
405 false, )?;
407
408 let scale_factor = n as f64 / batch_size as f64;
410 let scaled_grad_mean = &elbo_result.grad_mean * scale_factor;
411 let scaled_grad_cov = &elbo_result.grad_cov * scale_factor;
412
413 variational_mean = &variational_mean + learning_rate * &scaled_grad_mean;
415 variational_cov_factor = &variational_cov_factor + learning_rate * &scaled_grad_cov;
416
417 if iter % 100 == 0 {
419 }
421 }
422
423 let alpha = VariationalFreeEnergy::compute_alpha_standard(&k_mm_inv, &variational_mean)?;
425
426 let _variational_cov = variational_cov_factor.dot(&variational_cov_factor.t());
427
428 let k_nm = kernel.kernel_matrix(x, inducing_points);
430 let k_diag = kernel.kernel_diagonal(x);
431
432 let final_elbo = VariationalFreeEnergy::compute_elbo_and_gradients(
433 y,
434 &k_nm,
435 &k_diag,
436 &k_mm,
437 &k_mm_inv,
438 &variational_mean,
439 &variational_cov_factor,
440 noise_variance,
441 false,
442 )?;
443
444 let vfe_params = VariationalParams {
445 mean: variational_mean,
446 cov_factor: variational_cov_factor,
447 elbo: final_elbo.elbo,
448 kl_divergence: final_elbo.kl_divergence,
449 log_likelihood: final_elbo.log_likelihood,
450 };
451
452 Ok((alpha, k_mm_inv, vfe_params))
453 }
454
455 fn sample_batch(rng: &mut impl Rng, n: usize, batch_size: usize) -> Vec<usize> {
457 let mut indices: Vec<usize> = (0..n).collect();
458
459 for i in (1..n).rev() {
461 let j = rng.random_range(0..i + 1);
462 indices.swap(i, j);
463 }
464
465 indices.into_iter().take(batch_size).collect()
466 }
467
468 fn extract_batch(
470 x: &Array2<f64>,
471 y: &Array1<f64>,
472 indices: &[usize],
473 ) -> (Array2<f64>, Array1<f64>) {
474 let batch_size = indices.len();
475 let n_features = x.ncols();
476
477 let mut x_batch = Array2::zeros((batch_size, n_features));
478 let mut y_batch = Array1::zeros(batch_size);
479
480 for (i, &idx) in indices.iter().enumerate() {
481 x_batch.row_mut(i).assign(&x.row(idx));
482 y_batch[i] = y[idx];
483 }
484
485 (x_batch, y_batch)
486 }
487}
488
489pub mod variational_utils {
491 use super::*;
492
493 pub fn initialize_variational_params<K: SparseKernel>(
495 x: &Array2<f64>,
496 y: &Array1<f64>,
497 inducing_points: &Array2<f64>,
498 kernel: &K,
499 noise_variance: f64,
500 ) -> Result<(Array1<f64>, Array2<f64>)> {
501 let m = inducing_points.nrows();
502
503 let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
505 let k_nm = kernel.kernel_matrix(x, inducing_points);
506
507 let _k_mm_inv = KernelOps::invert_using_cholesky(&k_mm)?;
508
509 let mut k_mm_reg = k_mm.clone();
511 for i in 0..m {
512 k_mm_reg[(i, i)] += noise_variance;
513 }
514
515 let k_mm_reg_inv = KernelOps::invert_using_cholesky(&k_mm_reg)?;
516 let initial_mean = k_mm_reg_inv.dot(&k_nm.t()).dot(y);
517
518 let initial_cov_factor = Array2::eye(m) * 0.1;
520
521 Ok((initial_mean, initial_cov_factor))
522 }
523
524 pub fn predictive_moments<K: SparseKernel>(
526 x_test: &Array2<f64>,
527 inducing_points: &Array2<f64>,
528 kernel: &K,
529 vfe_params: &VariationalParams,
530 noise_variance: f64,
531 ) -> Result<(Array1<f64>, Array1<f64>)> {
532 let k_star_m = kernel.kernel_matrix(x_test, inducing_points);
533 let k_star_star = kernel.kernel_diagonal(x_test);
534
535 let pred_mean = k_star_m.dot(&vfe_params.mean);
537
538 let variational_cov = vfe_params.cov_factor.dot(&vfe_params.cov_factor.t());
540 let epistemic_var = compute_epistemic_variance(&k_star_m, &variational_cov);
541
542 let pred_var = &k_star_star - &epistemic_var + noise_variance;
543
544 Ok((pred_mean, pred_var))
545 }
546
547 fn compute_epistemic_variance(
549 k_star_m: &Array2<f64>,
550 variational_cov: &Array2<f64>,
551 ) -> Array1<f64> {
552 let temp = k_star_m.dot(variational_cov);
553 let mut epistemic_var = Array1::zeros(k_star_m.nrows());
554
555 for i in 0..k_star_m.nrows() {
556 epistemic_var[i] = k_star_m.row(i).dot(&temp.row(i));
557 }
558
559 epistemic_var
560 }
561}
562
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_gp::kernels::RBFKernel;
    use scirs2_core::ndarray::array;

    /// Initialization yields finite, correctly shaped variational parameters.
    #[test]
    fn test_vfe_initialization() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let train_x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let train_y = array![0.0, 1.0, 4.0];
        let z = array![[0.0, 0.0], [2.0, 2.0]];

        let (mean, cov_factor) =
            variational_utils::initialize_variational_params(&train_x, &train_y, &z, &kernel, 0.1)
                .expect("operation should succeed");

        assert_eq!(mean.len(), 2);
        assert_eq!(cov_factor.shape(), &[2, 2]);
        assert!(mean.iter().all(|&v| v.is_finite()));
    }

    /// The whitened KL is non-negative and gradients have matching shapes.
    #[test]
    fn test_kl_divergence_whitened() {
        let q_mean = array![0.5, -0.3];
        let factor = Array2::eye(2) * 0.8;
        let q_cov = factor.dot(&factor.t());

        let (kl, grad_mean, grad_cov) =
            VariationalFreeEnergy::compute_kl_divergence_whitened(&q_mean, &q_cov)
                .expect("operation should succeed");

        // KL divergence between two distributions is always >= 0.
        assert!(kl >= 0.0);
        assert_eq!(grad_mean.len(), 2);
        assert_eq!(grad_cov.shape(), &[2, 2]);
    }

    /// Expected log-likelihood is finite with consistent gradient shapes.
    #[test]
    fn test_expected_log_likelihood() {
        let targets = array![0.0, 1.0, 2.0];
        let k_nm = array![[1.0, 0.5], [0.8, 0.3], [0.6, 0.9]];
        let k_diag = array![1.0, 1.0, 1.0];
        let q_mean = array![0.5, 0.3];
        let q_cov = Array2::eye(2) * 0.1;

        let (ll, grad_mean, grad_cov) = VariationalFreeEnergy::compute_expected_log_likelihood(
            &targets, &k_nm, &k_diag, &q_mean, &q_cov, 0.1,
        )
        .expect("operation should succeed");

        assert!(ll.is_finite());
        assert_eq!(grad_mean.len(), 2);
        assert_eq!(grad_cov.shape(), &[2, 2]);
    }

    /// Batch sampling returns the requested number of distinct, in-range indices.
    #[test]
    fn test_stochastic_batch_sampling() {
        let mut rng = thread_rng();
        let picked = StochasticVariationalInference::sample_batch(&mut rng, 10, 3);

        assert_eq!(picked.len(), 3);
        assert!(picked.iter().all(|&i| i < 10));

        // Sampling is without replacement: deduplication must not shrink the set.
        let mut unique = picked.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), picked.len());
    }

    /// Predictive moments are finite, with non-negative variances.
    #[test]
    fn test_predictive_moments() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let query_points = array![[0.5, 0.5], [1.5, 1.5]];
        let z = array![[0.0, 0.0], [2.0, 2.0]];

        let vfe_params = VariationalParams {
            mean: array![0.5, 0.3],
            cov_factor: Array2::eye(2) * 0.1,
            elbo: -10.5,
            kl_divergence: 2.3,
            log_likelihood: -12.8,
        };

        let (pred_mean, pred_var) =
            variational_utils::predictive_moments(&query_points, &z, &kernel, &vfe_params, 0.1)
                .expect("operation should succeed");

        assert_eq!(pred_mean.len(), 2);
        assert_eq!(pred_var.len(), 2);
        assert!(pred_mean.iter().all(|&v| v.is_finite()));
        assert!(pred_var.iter().all(|&v| v >= 0.0 && v.is_finite()));
    }
}