1use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use crate::traits::{Distribution as DistributionTrait, MultivariateDistribution};
8use scirs2_core::ndarray::{
9 s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix1, Ix2,
10};
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::{Distribution as RandDistribution, Normal as RandNormal};
13use statrs::statistics::Statistics;
14use std::fmt::Debug;
15
16#[derive(Debug, Clone)]
18pub struct MultivariateNormal {
19 pub mean: Array1<f64>,
21 pub cov: Array2<f64>,
23 pub dim: usize,
25 cholesky_l: Array2<f64>,
27 cov_det: f64,
29 cov_inv: Array2<f64>,
31}
32
33impl MultivariateNormal {
34 pub fn new<D1, D2>(mean: ArrayBase<D1, Ix1>, cov: ArrayBase<D2, Ix2>) -> StatsResult<Self>
57 where
58 D1: Data<Elem = f64>,
59 D2: Data<Elem = f64>,
60 {
61 let dim = mean.len();
63 if cov.shape()[0] != dim || cov.shape()[1] != dim {
64 return Err(StatsError::DimensionMismatch(format!(
65 "Covariance matrix shape ({:?}) must match _mean vector length ({})",
66 cov.shape(),
67 dim
68 )));
69 }
70
71 let _mean = mean.to_owned();
73 let cov = cov.to_owned();
74
75 let cholesky_l = compute_cholesky(&cov).map_err(|_| {
77 StatsError::DomainError("Covariance matrix must be positive definite".to_string())
78 })?;
79
80 let cov_det = {
82 let mut det = 1.0;
83 for i in 0..dim {
84 det *= cholesky_l[[i, i]];
85 }
86 det * det };
88
89 let cov_inv = compute_inverse_from_cholesky(&cholesky_l).map_err(|_| {
91 StatsError::ComputationError("Failed to compute matrix inverse".to_string())
92 })?;
93
94 Ok(MultivariateNormal {
95 mean: _mean,
96 cov,
97 dim,
98 cholesky_l,
99 cov_det,
100 cov_inv,
101 })
102 }
103
104 pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
129 where
130 D: Data<Elem = f64>,
131 {
132 if x.len() != self.dim {
133 return 0.0; }
135
136 let pi = std::f64::consts::PI;
137 let two = 2.0;
138 let constant_factor = 1.0 / ((two * pi).powf(self.dim as f64 / two) * self.cov_det.sqrt());
139
140 let diff = x - &self.mean;
142 let mahalanobis_squared = self.mahalanobis_distance_squared(&diff.view());
143
144 constant_factor * (-mahalanobis_squared / two).exp()
146 }
147
148 fn mahalanobis_distance_squared(&self, diff: &ArrayView1<f64>) -> f64 {
150 diff.dot(&self.cov_inv.dot(diff))
152 }
153
154 pub fn rvs(&self, size: usize) -> StatsResult<Array2<f64>> {
178 let mut rng = thread_rng();
179 let normal = RandNormal::new(0.0, 1.0).expect("Operation failed");
180
181 let mut std_normal_samples = Array2::<f64>::zeros((size, self.dim));
183 for i in 0..size {
184 for j in 0..self.dim {
185 let sample = normal.sample(&mut rng);
186 std_normal_samples[[i, j]] = sample;
187 }
188 }
189
190 let mut samples = Array2::<f64>::zeros((size, self.dim));
193 for i in 0..size {
194 let z = std_normal_samples.slice(s![i, ..]);
196
197 let mut transformed = Array1::<f64>::zeros(self.dim);
199 for j in 0..self.dim {
200 for k in 0..=j {
201 transformed[j] += self.cholesky_l[[j, k]] * z[k];
202 }
203 }
204
205 for j in 0..self.dim {
207 samples[[i, j]] = self.mean[j] + transformed[j];
208 }
209 }
210
211 Ok(samples)
212 }
213
214 pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
234 let samples = self.rvs(1)?;
235 Ok(samples.index_axis(Axis(0), 0).to_owned())
236 }
237
238 pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
262 where
263 D: Data<Elem = f64>,
264 {
265 if x.len() != self.dim {
266 return f64::NEG_INFINITY; }
268
269 let pi = std::f64::consts::PI;
270 let two = 2.0;
271
272 let log_const = -(self.dim as f64) / two * (two * pi).ln() - self.cov_det.ln() / two;
274
275 let diff = x - &self.mean;
277 let mahalanobis_squared = self.mahalanobis_distance_squared(&diff.view());
278
279 log_const - mahalanobis_squared / two
281 }
282
283 pub fn dim(&self) -> usize {
285 self.dim
286 }
287
288 pub fn cov(&self) -> ArrayView2<f64> {
290 self.cov.view()
291 }
292
293 pub fn mean(&self) -> ArrayView1<f64> {
295 self.mean.view()
296 }
297}
298
299#[allow(dead_code)]
301pub fn compute_cholesky(a: &Array2<f64>) -> Result<Array2<f64>, String> {
302 let n = a.shape()[0];
303 let mut l = Array2::<f64>::zeros((n, n));
304
305 for i in 0..n {
307 for j in 0..=i {
308 let mut sum = 0.0;
309
310 if j == i {
311 for k in 0..j {
313 sum += l[[j, k]] * l[[j, k]];
314 }
315 let diag_value = a[[j, j]] - sum;
316 if diag_value <= 0.0 {
317 return Err("Matrix is not positive definite".to_string());
318 }
319 l[[j, j]] = diag_value.sqrt();
320 } else {
321 for k in 0..j {
323 sum += l[[i, k]] * l[[j, k]];
324 }
325 l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
326 }
327 }
328 }
329
330 Ok(l)
331}
332
333#[allow(dead_code)]
335pub fn compute_inverse_from_cholesky(l: &Array2<f64>) -> Result<Array2<f64>, String> {
336 let n = l.shape()[0];
337 let mut inv = Array2::<f64>::zeros((n, n));
338
339 let mut l_inv = Array2::<f64>::zeros((n, n));
341
342 for i in 0..n {
344 l_inv[[i, i]] = 1.0 / l[[i, i]];
345 }
346
347 for i in 1..n {
349 for j in 0..i {
350 let mut sum = 0.0;
351 for k in j..i {
352 sum += l[[i, k]] * l_inv[[k, j]];
353 }
354 l_inv[[i, j]] = -sum / l[[i, i]];
355 }
356 }
357
358 for i in 0..n {
360 for j in 0..n {
361 let mut sum = 0.0;
362 let max_idx = i.max(j);
364 for k in max_idx..n {
365 sum += l_inv[[k, i]] * l_inv[[k, j]];
366 }
367 inv[[i, j]] = sum;
368 }
369 }
370
371 Ok(inv)
372}
373
374#[allow(dead_code)]
400pub fn multivariate_normal<D1, D2>(
401 mean: ArrayBase<D1, Ix1>,
402 cov: ArrayBase<D2, Ix2>,
403) -> StatsResult<MultivariateNormal>
404where
405 D1: Data<Elem = f64>,
406 D2: Data<Elem = f64>,
407{
408 MultivariateNormal::new(mean, cov)
409}
410
411impl DistributionTrait<f64> for MultivariateNormal {
413 fn mean(&self) -> f64 {
414 if self.dim > 0 {
416 self.mean[0]
417 } else {
418 0.0
419 }
420 }
421
422 fn var(&self) -> f64 {
423 if self.dim > 0 {
425 self.cov[[0, 0]]
426 } else {
427 0.0
428 }
429 }
430
431 fn std(&self) -> f64 {
432 self.var().sqrt()
433 }
434
435 fn rvs(&self, size: usize) -> StatsResult<Array1<f64>> {
436 let samples_matrix = self.rvs(size)?;
438 Ok(samples_matrix.column(0).to_owned())
439 }
440
441 fn entropy(&self) -> f64 {
442 let k = self.dim as f64;
444 let pi = std::f64::consts::PI;
445
446 k / 2.0 + k / 2.0 * (2.0 * pi).ln() + 0.5 * self.cov_det.ln()
447 }
448}
449
450impl MultivariateDistribution<f64> for MultivariateNormal {
452 fn pdf(&self, x: &Array1<f64>) -> f64 {
453 self.pdf(x)
454 }
455
456 fn rvs(&self, size: usize) -> StatsResult<scirs2_core::ndarray::Array2<f64>> {
457 self.rvs(size)
458 }
459
460 fn mean(&self) -> Array1<f64> {
461 self.mean.clone()
462 }
463
464 fn cov(&self) -> scirs2_core::ndarray::Array2<f64> {
465 self.cov.clone()
466 }
467
468 fn dim(&self) -> usize {
469 self.dim
470 }
471
472 fn logpdf(&self, x: &Array1<f64>) -> f64 {
473 self.logpdf(x)
474 }
475
476 fn rvs_single(&self) -> StatsResult<Vec<f64>> {
477 let sample = self.rvs(1)?;
478 Ok(sample.row(0).to_vec())
479 }
480}
481
482impl SampleableDistribution<Array1<f64>> for MultivariateNormal {
484 fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
485 let samples_matrix = self.rvs(size)?;
486 let mut result = Vec::with_capacity(size);
487
488 for i in 0..size {
489 let row = samples_matrix.slice(s![i, ..]).to_owned();
490 result.push(row);
491 }
492
493 Ok(result)
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Axis};

    /// Construction stores the parameters unchanged for 2-D and 3-D inputs.
    #[test]
    fn test_mvn_creation() {
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let mvn = MultivariateNormal::new(mean.clone(), cov.clone()).expect("Operation failed");

        assert_eq!(mvn.dim, 2);
        assert_eq!(mvn.mean, mean);
        assert_eq!(mvn.cov, cov);

        let mean3 = array![1.0, 2.0, 3.0];
        let cov3 = array![[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 1.5]];
        let mvn3 = MultivariateNormal::new(mean3.clone(), cov3.clone()).expect("Operation failed");

        assert_eq!(mvn3.dim, 3);
        assert_eq!(mvn3.mean, mean3);
        assert_eq!(mvn3.cov, cov3);
    }

    /// Construction fails on mismatched shapes and on non-positive-definite input.
    #[test]
    fn test_mvn_creation_errors() {
        // Mean of length 3 with a 2x2 covariance.
        let mean = array![0.0, 0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        assert!(MultivariateNormal::new(mean, cov).is_err());

        // Symmetric but not positive definite (determinant is negative).
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(MultivariateNormal::new(mean, cov).is_err());
    }

    /// Density of the standard bivariate normal at two reference points.
    #[test]
    fn test_mvn_pdf() {
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");

        // 1 / (2 * pi)
        let pdf_at_origin = mvn.pdf(&array![0.0, 0.0]);
        assert_relative_eq!(pdf_at_origin, 0.15915494, epsilon = 1e-7);

        // exp(-1) / (2 * pi)
        let pdf_at_one = mvn.pdf(&array![1.0, 1.0]);
        assert_relative_eq!(pdf_at_one, 0.05854983, epsilon = 1e-7);
    }

    /// Log-density matches the reference value and is consistent with `pdf`.
    #[test]
    fn test_mvn_logpdf() {
        let mean = array![0.0, 0.0];
        let cov = array![[1.0, 0.0], [0.0, 1.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");

        // -ln(2 * pi)
        let logpdf_at_origin = mvn.logpdf(&array![0.0, 0.0]);
        assert_relative_eq!(logpdf_at_origin, -1.837877, epsilon = 1e-6);

        let x = array![1.0, 1.0];
        let pdf = mvn.pdf(&x);
        let logpdf = mvn.logpdf(&x);
        assert_relative_eq!(logpdf.exp(), pdf, epsilon = 1e-7);
    }

    /// Sample moments are close to the distribution parameters.
    #[test]
    fn test_mvn_rvs() {
        let mean = array![1.0, 2.0];
        let cov = array![[1.0, 0.5], [0.5, 2.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");

        let n_samples = 500;
        let samples = mvn.rvs(n_samples).expect("Operation failed");
        assert_eq!(samples.shape(), &[n_samples, 2]);

        // Tolerances are loose because the test is statistical.
        let sample_mean = samples.mean_axis(Axis(0)).expect("Operation failed");
        assert_relative_eq!(sample_mean[0], 1.0, epsilon = 0.3);
        assert_relative_eq!(sample_mean[1], 2.0, epsilon = 0.3);

        // Unbiased sample covariance of the centered data.
        let centered = &samples - &sample_mean;
        let sample_cov = centered.t().dot(&centered) / (n_samples as f64 - 1.0);
        assert_relative_eq!(sample_cov[[0, 0]], 1.0, epsilon = 0.5);
        assert_relative_eq!(sample_cov[[1, 1]], 2.0, epsilon = 0.5);
        // Compared with its sign so that a flipped correlation is caught.
        assert_relative_eq!(sample_cov[[0, 1]], 0.5, epsilon = 0.3);
    }

    /// A single draw has one entry per dimension.
    #[test]
    fn test_mvn_rvs_single() {
        let mean = array![1.0, 2.0];
        let cov = array![[1.0, 0.5], [0.5, 2.0]];
        let mvn = MultivariateNormal::new(mean, cov).expect("Operation failed");

        let sample = mvn.rvs_single().expect("Operation failed");
        assert_eq!(sample.len(), 2);
    }

    /// `L * L^T` reproduces the factorized matrix.
    #[test]
    fn test_cholesky() {
        let a = array![[4.0, 2.0, 2.0], [2.0, 5.0, 3.0], [2.0, 3.0, 6.0]];

        let l = compute_cholesky(&a).expect("Operation failed");

        let mut a_reconstructed = Array2::<f64>::zeros((3, 3));
        for i in 0..3 {
            for j in 0..3 {
                // L is lower triangular, so terms beyond min(i, j) are zero.
                for k in 0..=j.min(i) {
                    a_reconstructed[[i, j]] += l[[i, k]] * l[[j, k]];
                }
            }
        }

        for i in 0..3 {
            for j in 0..3 {
                assert_relative_eq!(a[[i, j]], a_reconstructed[[i, j]], epsilon = 1e-10);
            }
        }
    }

    /// The inverse obtained from the Cholesky factor satisfies `A * A^{-1} = I`.
    #[test]
    fn test_inverse() {
        let a = array![[4.0, 2.0, 2.0], [2.0, 5.0, 3.0], [2.0, 3.0, 6.0]];

        let l = compute_cholesky(&a).expect("Operation failed");
        let a_inv = compute_inverse_from_cholesky(&l).expect("Operation failed");

        let identity = a.dot(&a_inv);

        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_relative_eq!(identity[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }
}