use crate::sparse_gp::core::*;
use crate::sparse_gp::kernels::{KernelOps, SparseKernel};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::error::{Result, SklearsError};
use std::f64::consts::PI;

/// Variational Free Energy (VFE) approximation for sparse Gaussian processes.
///
/// Optimizes a Gaussian variational posterior q(u) = N(m, S) over the
/// inducing outputs by gradient ascent on the evidence lower bound (ELBO):
/// ELBO = E_q[log p(y | f)] - KL(q(u) || p(u)).
pub struct VariationalFreeEnergy;

impl VariationalFreeEnergy {
    /// Fits the variational posterior and returns the predictive weights
    /// `alpha`, the inverse inducing kernel matrix `K_mm^{-1}`, and the
    /// optimized variational parameters.
    pub fn fit<K: SparseKernel>(
        x: &Array2<f64>,
        y: &Array1<f64>,
        inducing_points: &Array2<f64>,
        kernel: &K,
        noise_variance: f64,
        whitened: bool,
        natural_gradients: bool,
        max_iter: usize,
        tol: f64,
    ) -> Result<(Array1<f64>, Array2<f64>, VariationalParams)> {
        let _n = x.nrows();
        let m = inducing_points.nrows();

        // Start from q(u) = N(0, I): zero mean, identity covariance factor.
        let mut variational_mean = Array1::zeros(m);
        let mut variational_cov_factor = Array2::eye(m);

        // Kernel blocks that stay fixed throughout the optimization.
        let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
        let k_nm = kernel.kernel_matrix(x, inducing_points);
        let k_diag = kernel.kernel_diagonal(x);

        let k_mm_inv = KernelOps::invert_using_cholesky(&k_mm)?;

        let mut best_elbo = f64::NEG_INFINITY;
        let mut best_params = None;

        for _iter in 0..max_iter {
            let old_mean = variational_mean.clone();
            let old_cov_factor = variational_cov_factor.clone();

            let elbo_result = Self::compute_elbo_and_gradients(
                y,
                &k_nm,
                &k_diag,
                &k_mm,
                &k_mm_inv,
                &variational_mean,
                &variational_cov_factor,
                noise_variance,
                whitened,
            )?;

            // Keep the best parameters seen so far: plain gradient ascent is
            // not guaranteed to increase the ELBO monotonically.
            if elbo_result.elbo > best_elbo {
                best_elbo = elbo_result.elbo;
                best_params = Some((
                    variational_mean.clone(),
                    variational_cov_factor.clone(),
                    elbo_result.clone(),
                ));
            }

            if natural_gradients {
                Self::natural_gradient_update(
                    &mut variational_mean,
                    &mut variational_cov_factor,
                    &elbo_result,
                    0.01,
                )?;
            } else {
                Self::standard_gradient_update(
                    &mut variational_mean,
                    &mut variational_cov_factor,
                    &elbo_result,
                    0.01,
                )?;
            }

            // Converged once both parameter blocks stop moving (Frobenius norm).
            let mean_change = (&variational_mean - &old_mean).mapv(|x| x * x).sum().sqrt();
            let cov_change = (&variational_cov_factor - &old_cov_factor)
                .mapv(|x| x * x)
                .sum()
                .sqrt();

            if mean_change < tol && cov_change < tol {
                break;
            }
        }

        let (best_mean, best_cov_factor, best_elbo_result) = best_params
            .ok_or_else(|| SklearsError::NumericalError("VFE optimization failed".to_string()))?;

        let alpha = if whitened {
            Self::compute_alpha_whitened(&k_mm, &k_mm_inv, &best_mean)?
        } else {
            Self::compute_alpha_standard(&k_mm_inv, &best_mean)?
        };

        let vfe_params = VariationalParams {
            mean: best_mean,
            cov_factor: best_cov_factor,
            elbo: best_elbo_result.elbo,
            kl_divergence: best_elbo_result.kl_divergence,
            log_likelihood: best_elbo_result.log_likelihood,
        };

        Ok((alpha, k_mm_inv, vfe_params))
    }

    /// Evaluates the ELBO and its gradients with respect to the variational
    /// mean and covariance.
    fn compute_elbo_and_gradients(
        y: &Array1<f64>,
        k_nm: &Array2<f64>,
        k_diag: &Array1<f64>,
        k_mm: &Array2<f64>,
        k_mm_inv: &Array2<f64>,
        variational_mean: &Array1<f64>,
        variational_cov_factor: &Array2<f64>,
        noise_variance: f64,
        whitened: bool,
    ) -> Result<ELBOResult> {
        let _n = y.len();
        let _m = variational_mean.len();

        // Full variational covariance S = F F^T from its factor F.
        let variational_cov = variational_cov_factor.dot(&variational_cov_factor.t());

        let (log_likelihood, ll_grad_mean, ll_grad_cov) = Self::compute_expected_log_likelihood(
            y,
            k_nm,
            k_diag,
            variational_mean,
            &variational_cov,
            noise_variance,
        )?;

        let (kl_divergence, kl_grad_mean, kl_grad_cov) = if whitened {
            Self::compute_kl_divergence_whitened(variational_mean, &variational_cov)?
        } else {
            Self::compute_kl_divergence_standard(
                variational_mean,
                &variational_cov,
                k_mm,
                k_mm_inv,
            )?
        };

        // ELBO = expected log-likelihood minus the KL regularizer.
        let elbo = log_likelihood - kl_divergence;

        let grad_mean = &ll_grad_mean - &kl_grad_mean;
        let grad_cov = &ll_grad_cov - &kl_grad_cov;

        Ok(ELBOResult {
            elbo,
            log_likelihood,
            kl_divergence,
            grad_mean,
            grad_cov,
        })
    }

    /// Expected log-likelihood E_q[log p(y | f)] under the variational
    /// posterior, with gradients. The covariance gradient of this term is
    /// currently left at zero, so only the KL term shapes the covariance.
    fn compute_expected_log_likelihood(
        y: &Array1<f64>,
        k_nm: &Array2<f64>,
        k_diag: &Array1<f64>,
        variational_mean: &Array1<f64>,
        variational_cov: &Array2<f64>,
        noise_variance: f64,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n = y.len();

        // Predictive mean at the training inputs: K_nm m.
        let pred_mean = k_nm.dot(variational_mean);

        // Predictive variance diagonal k_ii - [K_nm S K_mn]_ii, clamped away
        // from zero for numerical stability.
        let k_nm_s_k_mn = k_nm.dot(variational_cov).dot(&k_nm.t());
        let mut pred_var_diag = k_diag.clone();

        for i in 0..n {
            pred_var_diag[i] -= k_nm_s_k_mn[(i, i)];
            pred_var_diag[i] = pred_var_diag[i].max(1e-6);
        }

        let residuals = y - &pred_mean;
        let total_var = &pred_var_diag + noise_variance;

        // Gaussian log-density with per-point variance v_i:
        // -0.5 * sum_i [ ln(2*pi*v_i) + r_i^2 / v_i ].
        let log_likelihood = -0.5
            * (n as f64 * (2.0 * PI).ln()
                + total_var.mapv(|x| x.ln()).sum()
                + residuals
                    .iter()
                    .zip(total_var.iter())
                    .map(|(r, v)| r * r / v)
                    .sum::<f64>());

        // Mean gradient consistent with the objective above: K_mn (r / v),
        // where v is the per-point total variance.
        let weighted_residuals = &residuals / &total_var;
        let grad_mean = k_nm.t().dot(&weighted_residuals);
        let grad_cov = Array2::zeros(variational_cov.dim());

        Ok((log_likelihood, grad_mean, grad_cov))
    }

    /// KL(q(u) || p(u)) for q(u) = N(m, S) against the prior p(u) = N(0, K_mm):
    /// KL = 0.5 * [tr(K_mm^{-1} S) + m^T K_mm^{-1} m - M + ln|K_mm| - ln|S|].
    fn compute_kl_divergence_standard(
        variational_mean: &Array1<f64>,
        variational_cov: &Array2<f64>,
        k_mm: &Array2<f64>,
        k_mm_inv: &Array2<f64>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let m = variational_mean.len();

        // tr(K_mm^{-1} S) needs the matrix product, not the elementwise one.
        let trace_term = k_mm_inv.dot(variational_cov).diag().sum();

        let quad_term = variational_mean.dot(&k_mm_inv.dot(variational_mean));

        let log_det_s = Self::log_det_psd(variational_cov)?;
        let log_det_k_mm =
            KernelOps::log_det_from_cholesky(&KernelOps::cholesky_with_jitter(k_mm, 1e-6)?);

        let kl = 0.5 * (trace_term + quad_term - m as f64 - log_det_s + log_det_k_mm);

        // Simplified gradients: the exact covariance gradient would be
        // 0.5 * (K_mm^{-1} - S^{-1}); the S^{-1} shrinkage term is omitted here.
        let grad_mean = k_mm_inv.dot(variational_mean);
        let grad_cov = k_mm_inv.clone();

        Ok((kl, grad_mean, grad_cov))
    }

    /// KL divergence in the whitened parameterization, where the prior on the
    /// whitened inducing outputs is standard normal: KL(N(m, S) || N(0, I)).
    fn compute_kl_divergence_whitened(
        variational_mean: &Array1<f64>,
        variational_cov: &Array2<f64>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let m = variational_mean.len();

        let trace_term = variational_cov.diag().sum();
        let quad_term = variational_mean.dot(variational_mean);
        let log_det_s = Self::log_det_psd(variational_cov)?;

        let kl = 0.5 * (trace_term + quad_term - m as f64 - log_det_s);

        let grad_mean = variational_mean.clone();
        let grad_cov = Array2::eye(m);

        Ok((kl, grad_mean, grad_cov))
    }

    /// Gradient step in the natural-parameter direction. Note: a full
    /// natural-gradient update would precondition by the inverse Fisher
    /// information of q(u); this version currently applies the same plain
    /// ascent step as `standard_gradient_update`.
    fn natural_gradient_update(
        variational_mean: &mut Array1<f64>,
        variational_cov_factor: &mut Array2<f64>,
        elbo_result: &ELBOResult,
        learning_rate: f64,
    ) -> Result<()> {
        *variational_mean = &*variational_mean + learning_rate * &elbo_result.grad_mean;

        let cov_update = learning_rate * &elbo_result.grad_cov;
        *variational_cov_factor = &*variational_cov_factor + &cov_update;

        Ok(())
    }

    /// Plain gradient-ascent step on the mean and covariance factor.
    fn standard_gradient_update(
        variational_mean: &mut Array1<f64>,
        variational_cov_factor: &mut Array2<f64>,
        elbo_result: &ELBOResult,
        learning_rate: f64,
    ) -> Result<()> {
        *variational_mean = &*variational_mean + learning_rate * &elbo_result.grad_mean;
        *variational_cov_factor = &*variational_cov_factor + learning_rate * &elbo_result.grad_cov;

        Ok(())
    }

    /// Maps the whitened variational mean back to predictive weights. With
    /// K_mm = L L^T and u = L v, the weights are alpha = K_mm^{-1} L m.
    fn compute_alpha_whitened(
        k_mm: &Array2<f64>,
        k_mm_inv: &Array2<f64>,
        variational_mean: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        let l_mm = KernelOps::cholesky_with_jitter(k_mm, 1e-6)?;

        let alpha = k_mm_inv.dot(&l_mm.dot(variational_mean));
        Ok(alpha)
    }

    /// In the standard parameterization the predictive weights are
    /// alpha = K_mm^{-1} m.
    fn compute_alpha_standard(
        k_mm_inv: &Array2<f64>,
        variational_mean: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        Ok(k_mm_inv.dot(variational_mean))
    }

    /// Log-determinant of a symmetric positive (semi-)definite matrix via a
    /// jittered Cholesky factorization.
    fn log_det_psd(matrix: &Array2<f64>) -> Result<f64> {
        let chol = KernelOps::cholesky_with_jitter(matrix, 1e-6)?;
        Ok(KernelOps::log_det_from_cholesky(&chol))
    }
}

/// Result of one ELBO evaluation: the bound, its two components, and the
/// gradients consumed by the update rules.
#[derive(Debug, Clone)]
pub struct ELBOResult {
    /// Evidence lower bound: `log_likelihood - kl_divergence`.
    pub elbo: f64,
    /// Expected log-likelihood term.
    pub log_likelihood: f64,
    /// KL(q(u) || p(u)) regularization term.
    pub kl_divergence: f64,
    /// Gradient of the ELBO w.r.t. the variational mean.
    pub grad_mean: Array1<f64>,
    /// Gradient of the ELBO w.r.t. the variational covariance.
    pub grad_cov: Array2<f64>,
}

/// Stochastic variational inference (SVI) for sparse GPs: optimizes the same
/// ELBO as `VariationalFreeEnergy`, but from minibatch gradient estimates
/// rescaled to the full dataset.
pub struct StochasticVariationalInference;

impl StochasticVariationalInference {
    pub fn fit<K: SparseKernel>(
        x: &Array2<f64>,
        y: &Array1<f64>,
        inducing_points: &Array2<f64>,
        kernel: &K,
        noise_variance: f64,
        batch_size: usize,
        max_iter: usize,
        learning_rate: f64,
    ) -> Result<(Array1<f64>, Array2<f64>, VariationalParams)> {
        let n = x.nrows();
        let m = inducing_points.nrows();

        let mut variational_mean = Array1::zeros(m);
        let mut variational_cov_factor = Array2::eye(m);

        let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
        let k_mm_inv = KernelOps::invert_using_cholesky(&k_mm)?;

        let mut rng = thread_rng();

        for _iter in 0..max_iter {
            // Draw a random minibatch and evaluate its kernel blocks.
            let batch_indices = Self::sample_batch(&mut rng, n, batch_size);
            let (x_batch, y_batch) = Self::extract_batch(x, y, &batch_indices);

            let k_batch_m = kernel.kernel_matrix(&x_batch, inducing_points);
            let k_batch_diag = kernel.kernel_diagonal(&x_batch);

            let elbo_result = VariationalFreeEnergy::compute_elbo_and_gradients(
                &y_batch,
                &k_batch_m,
                &k_batch_diag,
                &k_mm,
                &k_mm_inv,
                &variational_mean,
                &variational_cov_factor,
                noise_variance,
                false, // standard (non-whitened) parameterization
            )?;

            // Rescale the stochastic gradient by n / batch_size so the
            // likelihood term estimates the full-data gradient. (This also
            // rescales the KL component, kept here as part of the stochastic
            // approximation.)
            let scale_factor = n as f64 / batch_size as f64;
            let scaled_grad_mean = &elbo_result.grad_mean * scale_factor;
            let scaled_grad_cov = &elbo_result.grad_cov * scale_factor;

            variational_mean = &variational_mean + learning_rate * &scaled_grad_mean;
            variational_cov_factor = &variational_cov_factor + learning_rate * &scaled_grad_cov;
        }

        let alpha = VariationalFreeEnergy::compute_alpha_standard(&k_mm_inv, &variational_mean)?;

        let _variational_cov = variational_cov_factor.dot(&variational_cov_factor.t());

        // Evaluate the final ELBO once on the full dataset for reporting.
        let k_nm = kernel.kernel_matrix(x, inducing_points);
        let k_diag = kernel.kernel_diagonal(x);

        let final_elbo = VariationalFreeEnergy::compute_elbo_and_gradients(
            y,
            &k_nm,
            &k_diag,
            &k_mm,
            &k_mm_inv,
            &variational_mean,
            &variational_cov_factor,
            noise_variance,
            false,
        )?;

        let vfe_params = VariationalParams {
            mean: variational_mean,
            cov_factor: variational_cov_factor,
            elbo: final_elbo.elbo,
            kl_divergence: final_elbo.kl_divergence,
            log_likelihood: final_elbo.log_likelihood,
        };

        Ok((alpha, k_mm_inv, vfe_params))
    }

    /// Samples `batch_size` distinct indices from `0..n`: Fisher-Yates shuffle
    /// followed by truncation.
    fn sample_batch(rng: &mut impl Rng, n: usize, batch_size: usize) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..n).collect();

        for i in (1..n).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        indices.into_iter().take(batch_size).collect()
    }

    /// Gathers the rows of `x` and entries of `y` selected by `indices`.
    fn extract_batch(
        x: &Array2<f64>,
        y: &Array1<f64>,
        indices: &[usize],
    ) -> (Array2<f64>, Array1<f64>) {
        let batch_size = indices.len();
        let n_features = x.ncols();

        let mut x_batch = Array2::zeros((batch_size, n_features));
        let mut y_batch = Array1::zeros(batch_size);

        for (i, &idx) in indices.iter().enumerate() {
            x_batch.row_mut(i).assign(&x.row(idx));
            y_batch[i] = y[idx];
        }

        (x_batch, y_batch)
    }
}

/// Helper routines shared by the variational sparse-GP methods.
pub mod variational_utils {
    use super::*;

    /// Initializes the variational parameters from a noise-regularized
    /// least-squares fit of the inducing outputs to the observations.
    pub fn initialize_variational_params<K: SparseKernel>(
        x: &Array2<f64>,
        y: &Array1<f64>,
        inducing_points: &Array2<f64>,
        kernel: &K,
        noise_variance: f64,
    ) -> Result<(Array1<f64>, Array2<f64>)> {
        let m = inducing_points.nrows();

        let k_mm = kernel.kernel_matrix(inducing_points, inducing_points);
        let k_nm = kernel.kernel_matrix(x, inducing_points);

        // Regularize K_mm with the noise variance before inverting, which
        // keeps the initial mean well conditioned.
        let mut k_mm_reg = k_mm.clone();
        for i in 0..m {
            k_mm_reg[(i, i)] += noise_variance;
        }

        let k_mm_reg_inv = KernelOps::invert_using_cholesky(&k_mm_reg)?;
        let initial_mean = k_mm_reg_inv.dot(&k_nm.t()).dot(y);

        // Start from a small isotropic covariance factor.
        let initial_cov_factor = Array2::eye(m) * 0.1;

        Ok((initial_mean, initial_cov_factor))
    }

    /// Predictive mean and variance at test inputs under the variational
    /// posterior, including observation noise.
    pub fn predictive_moments<K: SparseKernel>(
        x_test: &Array2<f64>,
        inducing_points: &Array2<f64>,
        kernel: &K,
        vfe_params: &VariationalParams,
        noise_variance: f64,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let k_star_m = kernel.kernel_matrix(x_test, inducing_points);
        let k_star_star = kernel.kernel_diagonal(x_test);

        // Predictive mean: K_*m m.
        let pred_mean = k_star_m.dot(&vfe_params.mean);

        let variational_cov = vfe_params.cov_factor.dot(&vfe_params.cov_factor.t());
        let epistemic_var = compute_epistemic_variance(&k_star_m, &variational_cov);

        // Prior variance reduced by the variance explained through the
        // inducing points, plus observation noise.
        let pred_var = &k_star_star - &epistemic_var + noise_variance;

        Ok((pred_mean, pred_var))
    }

    /// Diagonal of K_*m S K_m*, computed row by row to avoid forming the full
    /// test-by-test matrix.
    fn compute_epistemic_variance(
        k_star_m: &Array2<f64>,
        variational_cov: &Array2<f64>,
    ) -> Array1<f64> {
        let temp = k_star_m.dot(variational_cov);
        let mut epistemic_var = Array1::zeros(k_star_m.nrows());

        for i in 0..k_star_m.nrows() {
            epistemic_var[i] = k_star_m.row(i).dot(&temp.row(i));
        }

        epistemic_var
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_gp::kernels::RBFKernel;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_vfe_initialization() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let y = array![0.0, 1.0, 4.0];
        let inducing_points = array![[0.0, 0.0], [2.0, 2.0]];

        let (mean, cov_factor) = variational_utils::initialize_variational_params(
            &x,
            &y,
            &inducing_points,
            &kernel,
            0.1,
        )
        .unwrap();

        assert_eq!(mean.len(), 2);
        assert_eq!(cov_factor.shape(), &[2, 2]);
        assert!(mean.iter().all(|&x| x.is_finite()));
    }

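    // A minimal end-to-end smoke test for `VariationalFreeEnergy::fit`, added
    // as a sketch: the kernel hyperparameters, noise variance, and iteration
    // settings are illustrative choices, and the assertions only cover shapes
    // and finiteness rather than convergence quality.
    #[test]
    fn test_vfe_fit_smoke() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let y = array![0.0, 0.5, 1.0, 1.5];
        let inducing_points = array![[0.0, 0.0], [3.0, 3.0]];

        let (alpha, k_mm_inv, params) = VariationalFreeEnergy::fit(
            &x,
            &y,
            &inducing_points,
            &kernel,
            1.0,   // noise variance
            false, // whitened
            false, // natural gradients
            50,    // max_iter
            1e-6,  // tol
        )
        .unwrap();

        assert_eq!(alpha.len(), 2);
        assert_eq!(k_mm_inv.shape(), &[2, 2]);
        assert!(params.elbo.is_finite());
        assert!(params.kl_divergence.is_finite());
    }
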
    #[test]
    fn test_kl_divergence_whitened() {
        let mean = array![0.5, -0.3];
        let cov_factor = Array2::eye(2) * 0.8;
        let cov = cov_factor.dot(&cov_factor.t());

        let (kl, grad_mean, grad_cov) =
            VariationalFreeEnergy::compute_kl_divergence_whitened(&mean, &cov).unwrap();

        assert!(kl >= 0.0);
        assert_eq!(grad_mean.len(), 2);
        assert_eq!(grad_cov.shape(), &[2, 2]);
    }

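    // Complements the whitened test above: with an identity prior and unit
    // variational covariance the standard KL reduces to 0.5 * ||m||^2, which
    // gives a closed-form value to check against (up to the small Cholesky
    // jitter used internally).
    #[test]
    fn test_kl_divergence_standard_identity_prior() {
        let mean = array![1.0, 0.0];
        let cov = Array2::eye(2);
        let k_mm = Array2::eye(2);
        let k_mm_inv = Array2::eye(2);

        let (kl, grad_mean, grad_cov) = VariationalFreeEnergy::compute_kl_divergence_standard(
            &mean, &cov, &k_mm, &k_mm_inv,
        )
        .unwrap();

        assert!((kl - 0.5).abs() < 1e-3);
        // For an identity prior the mean gradient K_mm^{-1} m equals m itself.
        assert!((grad_mean[0] - 1.0).abs() < 1e-10);
        assert_eq!(grad_cov.shape(), &[2, 2]);
    }
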
    #[test]
    fn test_expected_log_likelihood() {
        let y = array![0.0, 1.0, 2.0];
        let k_nm = array![[1.0, 0.5], [0.8, 0.3], [0.6, 0.9]];
        let k_diag = array![1.0, 1.0, 1.0];
        let mean = array![0.5, 0.3];
        let cov = Array2::eye(2) * 0.1;

        let (ll, grad_mean, grad_cov) = VariationalFreeEnergy::compute_expected_log_likelihood(
            &y, &k_nm, &k_diag, &mean, &cov, 0.1,
        )
        .unwrap();

        assert!(ll.is_finite());
        assert_eq!(grad_mean.len(), 2);
        assert_eq!(grad_cov.shape(), &[2, 2]);
    }

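    // Sanity check for the log-determinant helper shared by both KL paths:
    // for a diagonal matrix the log-determinant is the sum of the logs of the
    // diagonal entries, so the Cholesky-based value should match it closely.
    #[test]
    fn test_log_det_psd_diagonal() {
        let m = array![[2.0, 0.0], [0.0, 3.0]];
        let log_det = VariationalFreeEnergy::log_det_psd(&m).unwrap();
        let expected = 2.0f64.ln() + 3.0f64.ln();
        assert!((log_det - expected).abs() < 1e-3);
    }
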
    #[test]
    fn test_stochastic_batch_sampling() {
        let mut rng = thread_rng();
        let batch_indices = StochasticVariationalInference::sample_batch(&mut rng, 10, 3);

        assert_eq!(batch_indices.len(), 3);
        assert!(batch_indices.iter().all(|&i| i < 10));

        // All sampled indices must be distinct.
        let mut sorted_indices = batch_indices.clone();
        sorted_indices.sort();
        sorted_indices.dedup();
        assert_eq!(sorted_indices.len(), batch_indices.len());
    }

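    // Deterministic companion to the sampling test above: `extract_batch`
    // should copy exactly the selected rows and targets, in index order.
    #[test]
    fn test_extract_batch() {
        let x = array![[0.0, 0.0], [1.0, 10.0], [2.0, 20.0]];
        let y = array![0.0, 1.0, 2.0];
        let indices = vec![2, 0];

        let (x_batch, y_batch) = StochasticVariationalInference::extract_batch(&x, &y, &indices);

        assert_eq!(x_batch.shape(), &[2, 2]);
        assert_eq!(x_batch[(0, 1)], 20.0);
        assert_eq!(y_batch[0], 2.0);
        assert_eq!(y_batch[1], 0.0);
    }
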
    #[test]
    fn test_predictive_moments() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let x_test = array![[0.5, 0.5], [1.5, 1.5]];
        let inducing_points = array![[0.0, 0.0], [2.0, 2.0]];

        let vfe_params = VariationalParams {
            mean: array![0.5, 0.3],
            cov_factor: Array2::eye(2) * 0.1,
            elbo: -10.5,
            kl_divergence: 2.3,
            log_likelihood: -12.8,
        };

        let (pred_mean, pred_var) = variational_utils::predictive_moments(
            &x_test,
            &inducing_points,
            &kernel,
            &vfe_params,
            0.1,
        )
        .unwrap();

        assert_eq!(pred_mean.len(), 2);
        assert_eq!(pred_var.len(), 2);
        assert!(pred_mean.iter().all(|&x| x.is_finite()));
        assert!(pred_var.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }
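
    // A minimal smoke test for the SVI path, mirroring the VFE one above; the
    // batch size, iteration count, and learning rate are illustrative choices
    // and the assertions only cover shapes and finiteness.
    #[test]
    fn test_svi_fit_smoke() {
        let kernel = RBFKernel::new(1.0, 1.0);
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let y = array![0.0, 0.5, 1.0, 1.5];
        let inducing_points = array![[0.5, 0.5], [2.5, 2.5]];

        let (alpha, _k_mm_inv, params) = StochasticVariationalInference::fit(
            &x,
            &y,
            &inducing_points,
            &kernel,
            1.0,  // noise variance
            2,    // batch size
            20,   // iterations
            0.01, // learning rate
        )
        .unwrap();

        assert_eq!(alpha.len(), 2);
        assert!(params.elbo.is_finite());
    }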
}