scirs2_stats/distributions/multivariate/student_t.rs
1//! Multivariate Student's t-distribution functions
2//!
3//! This module provides functionality for the Multivariate Student's t-distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::ndarray::{
8 s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix1, Ix2,
9};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::{ChiSquared, Distribution, Normal as RandNormal};
12use std::fmt::Debug;
13
14// Import the helper functions used by MultivariateNormal
15use super::normal::{compute_cholesky, compute_inverse_from_cholesky};
16
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// This is a stand-in for the gamma functions that are still unstable in the
/// Rust standard library.
///
/// Strategy:
/// * integers up to 20 are summed exactly as `ln((n - 1)!)`;
/// * `x == 0.5` uses the closed form `Γ(1/2) = sqrt(π)`;
/// * `0 < x < 0.5` applies the reflection formula once, which moves the
///   argument into `(0.5, 1)` where the series below is valid;
/// * everything else is evaluated directly with the Lanczos approximation
///   (`g = 7`, nine coefficients), so the cost and stack depth do not grow
///   with `x`.
///
/// # Panics
///
/// Panics if `x <= 0.0`.
#[allow(dead_code)]
fn lgamma(x: f64) -> f64 {
    // Lanczos parameter `g` that belongs to the coefficient table below.
    const LANCZOS_G: f64 = 7.0;
    // Lanczos coefficients for g = 7; the first entry is the constant term of
    // the series and must not be omitted.
    const LANCZOS_COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        panic!("lgamma requires positive input");
    }

    // Γ grows without bound; the Lanczos expression would yield inf - inf here.
    if x.is_infinite() {
        return f64::INFINITY;
    }

    // Small integers: ln Γ(n) = ln((n - 1)!) = ln 2 + ln 3 + ... + ln(n - 1)
    if x.fract() == 0.0 && x <= 20.0 {
        let n = x as u32;
        return (2..n).fold(0.0, |acc, i| acc + f64::from(i).ln());
    }

    let pi = std::f64::consts::PI;

    // Γ(0.5) = sqrt(π)
    if x == 0.5 {
        return pi.sqrt().ln();
    }

    // Reflection formula Γ(x)·Γ(1 - x) = π / sin(πx). It is applied only for
    // x < 0.5, so the recursive call receives an argument in (0.5, 1) and
    // terminates after exactly one step. The logarithm is split so that a tiny
    // sin(πx) cannot overflow the quotient.
    if x < 0.5 {
        return pi.ln() - (pi * x).sin().ln() - lgamma(1.0 - x);
    }

    // Lanczos approximation of Γ(z + 1) with z = x - 1:
    //   Γ(z + 1) ≈ sqrt(2π) · t^(z + 1/2) · e^(-t) · A(z),   t = z + g + 1/2
    //   A(z) = c0 + Σ_{k=1..8} c_k / (z + k)
    let z = x - 1.0;
    let t = z + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEFFS[0], |acc, (k, &coef)| acc + coef / (z + k as f64));

    0.5 * (2.0 * pi).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
79
80/// Multivariate Student's t-distribution structure
81#[derive(Debug, Clone)]
82pub struct MultivariateT {
83 /// Mean vector
84 pub mean: Array1<f64>,
85 /// Scale matrix (like covariance but scaled by df/(df-2) for df > 2)
86 pub scale: Array2<f64>,
87 /// Dimensionality of the distribution
88 pub dim: usize,
89 /// Degrees of freedom
90 pub df: f64,
91 /// Cholesky decomposition of the scale matrix (lower triangular)
92 cholesky_l: Array2<f64>,
93 /// Determinant of the scale matrix
94 scale_det: f64,
95 /// Inverse of the scale matrix
96 scale_inv: Array2<f64>,
97}
98
99impl MultivariateT {
100 /// Create a new multivariate Student's t-distribution with given parameters
101 ///
102 /// # Arguments
103 ///
104 /// * `mean` - Mean vector (k-dimensional)
105 /// * `scale` - Scale matrix (k x k, symmetric positive-definite)
106 /// * `df` - Degrees of freedom (> 0)
107 ///
108 /// # Returns
109 ///
110 /// * A new MultivariateT distribution instance
111 ///
112 /// # Examples
113 ///
114 /// ```
115 /// use scirs2_core::ndarray::array;
116 /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
117 ///
118 /// // Create a 2D multivariate Student's t-distribution with 5 degrees of freedom
119 /// let mean = array![0.0, 0.0];
120 /// let scale = array![[1.0, 0.5], [0.5, 2.0]];
121 /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
122 /// ```
123 pub fn new<D1, D2>(
124 mean: ArrayBase<D1, Ix1>,
125 scale: ArrayBase<D2, Ix2>,
126 df: f64,
127 ) -> StatsResult<Self>
128 where
129 D1: Data<Elem = f64>,
130 D2: Data<Elem = f64>,
131 {
132 // Validate dimensions
133 let dim = mean.len();
134 if scale.shape()[0] != dim || scale.shape()[1] != dim {
135 return Err(StatsError::DimensionMismatch(format!(
136 "Scale matrix shape ({:?}) must match mean vector length ({})",
137 scale.shape(),
138 dim
139 )));
140 }
141
142 // Validate degrees of freedom
143 if df <= 0.0 {
144 return Err(StatsError::DomainError(
145 "Degrees of freedom must be positive".to_string(),
146 ));
147 }
148
149 // Create owned copies of inputs
150 let mean = mean.to_owned();
151 let scale = scale.to_owned();
152
153 // Compute Cholesky decomposition (lower triangular L where Σ = L·L^T)
154 let cholesky_l = compute_cholesky(&scale).map_err(|_| {
155 StatsError::DomainError("Scale matrix must be positive definite".to_string())
156 })?;
157
158 // For positive definite matrix, det(Σ) = det(L)^2 = prod(diag(L))^2
159 let scale_det = {
160 let mut det = 1.0;
161 for i in 0..dim {
162 det *= cholesky_l[[i, i]];
163 }
164 det * det // Square it since det(Σ) = det(L)^2
165 };
166
167 // Compute inverse using Cholesky decomposition
168 let scale_inv = compute_inverse_from_cholesky(&cholesky_l).map_err(|_| {
169 StatsError::ComputationError("Failed to compute matrix inverse".to_string())
170 })?;
171
172 Ok(MultivariateT {
173 mean,
174 scale,
175 dim,
176 df,
177 cholesky_l,
178 scale_det,
179 scale_inv,
180 })
181 }
182
183 /// Calculate the probability density function (PDF) at a given point
184 ///
185 /// # Arguments
186 ///
187 /// * `x` - The point at which to evaluate the PDF
188 ///
189 /// # Returns
190 ///
191 /// * The value of the PDF at the given point
192 ///
193 /// # Examples
194 ///
195 /// ```
196 /// use scirs2_core::ndarray::array;
197 /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
198 ///
199 /// let mean = array![0.0, 0.0];
200 /// let scale = array![[1.0, 0.0], [0.0, 1.0]];
201 /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
202 ///
203 /// // PDF at origin
204 /// let pdf_at_origin = mvt.pdf(&array![0.0, 0.0]);
205 /// ```
206 pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
207 where
208 D: Data<Elem = f64>,
209 {
210 if x.len() != self.dim {
211 return 0.0; // Return zero for invalid dimensions
212 }
213
214 let pi = std::f64::consts::PI;
215
216 // Calculate the constant part of the PDF
217 let gamma_term_num = lgamma((self.df + self.dim as f64) / 2.0).exp();
218 let gamma_term_denom = lgamma(self.df / 2.0).exp()
219 * lgamma(self.dim as f64 / 2.0).exp()
220 * self.df.powf(self.dim as f64 / 2.0);
221 let constant_factor = gamma_term_num
222 / (gamma_term_denom * pi.powf(self.dim as f64 / 2.0) * self.scale_det.sqrt());
223
224 // Calculate Mahalanobis distance: (x - μ)^T Σ^-1 (x - μ)
225 let diff = x - &self.mean;
226 let mahalanobis_squared = self.mahalanobis_distance_squared(&diff.view());
227
228 // PDF = C * [1 + (1/v) * dist]^(-(v+p)/2)
229 // where C is a normalization constant, v is df, p is dimension, and dist is Mahalanobis distance squared
230 constant_factor
231 * (1.0 + mahalanobis_squared / self.df).powf(-(self.df + self.dim as f64) / 2.0)
232 }
233
234 /// Calculate the Mahalanobis distance squared: (x - μ)^T Σ^-1 (x - μ)
235 fn mahalanobis_distance_squared(&self, diff: &ArrayView1<f64>) -> f64 {
236 // Compute (x - μ)^T Σ^-1 (x - μ)
237 diff.dot(&self.scale_inv.dot(diff))
238 }
239
240 /// Generate random samples from the distribution
241 ///
242 /// # Arguments
243 ///
244 /// * `size` - Number of samples to generate
245 ///
246 /// # Returns
247 ///
248 /// * Matrix where each row is a random sample
249 ///
250 /// # Examples
251 ///
252 /// ```
253 /// use scirs2_core::ndarray::array;
254 /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
255 ///
256 /// let mean = array![0.0, 0.0];
257 /// let scale = array![[1.0, 0.5], [0.5, 2.0]];
258 /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
259 ///
260 /// let samples = mvt.rvs(100).expect("Operation failed");
261 /// assert_eq!(samples.shape(), &[100, 2]);
262 /// ```
263 pub fn rvs(&self, size: usize) -> StatsResult<Array2<f64>> {
264 let mut rng = thread_rng();
265 let normal_dist = RandNormal::new(0.0, 1.0).expect("Operation failed");
266 let chi2_dist = ChiSquared::new(self.df).expect("Operation failed");
267
268 // Create a matrix for the samples
269 let mut samples = Array2::<f64>::zeros((size, self.dim));
270
271 // For each sample
272 for i in 0..size {
273 // Generate standard normal samples for each dimension
274 let mut z = Array1::<f64>::zeros(self.dim);
275 for j in 0..self.dim {
276 z[j] = normal_dist.sample(&mut rng);
277 }
278
279 // Generate chi-square sample with df degrees of freedom
280 let w = chi2_dist.sample(&mut rng);
281
282 // Transform Z using the Cholesky decomposition
283 let mut transformed = Array1::<f64>::zeros(self.dim);
284 for j in 0..self.dim {
285 for k in 0..=j {
286 transformed[j] += self.cholesky_l[[j, k]] * z[k];
287 }
288 }
289
290 // Apply the t-distribution scaling
291 let scaling_factor = (self.df / w).sqrt();
292 for j in 0..self.dim {
293 samples[[i, j]] = self.mean[j] + transformed[j] * scaling_factor;
294 }
295 }
296
297 Ok(samples)
298 }
299
300 /// Generate a single random sample from the distribution
301 ///
302 /// # Returns
303 ///
304 /// * Vector representing a single sample
305 ///
306 /// # Examples
307 ///
308 /// ```
309 /// use scirs2_core::ndarray::array;
310 /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
311 ///
312 /// let mean = array![0.0, 0.0];
313 /// let scale = array![[1.0, 0.5], [0.5, 2.0]];
314 /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
315 ///
316 /// let sample = mvt.rvs_single().expect("Operation failed");
317 /// assert_eq!(sample.len(), 2);
318 /// ```
319 pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
320 let samples = self.rvs(1)?;
321 Ok(samples.index_axis(Axis(0), 0).to_owned())
322 }
323
324 /// Calculate the log probability density function (log PDF) at a given point
325 ///
326 /// # Arguments
327 ///
328 /// * `x` - The point at which to evaluate the log PDF
329 ///
330 /// # Returns
331 ///
332 /// * The value of the log PDF at the given point
333 ///
334 /// # Examples
335 ///
336 /// ```
337 /// use scirs2_core::ndarray::array;
338 /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
339 ///
340 /// let mean = array![0.0, 0.0];
341 /// let scale = array![[1.0, 0.0], [0.0, 1.0]];
342 /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
343 ///
344 /// let log_pdf = mvt.logpdf(&array![0.0, 0.0]);
345 /// ```
346 pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
347 where
348 D: Data<Elem = f64>,
349 {
350 if x.len() != self.dim {
351 return f64::NEG_INFINITY; // Return -∞ for invalid dimensions
352 }
353
354 let pi = std::f64::consts::PI;
355
356 // Calculate the constant part of the log PDF
357 let gamma_term_num = lgamma((self.df + self.dim as f64) / 2.0);
358 let gamma_term_denom = lgamma(self.df / 2.0)
359 + lgamma(self.dim as f64 / 2.0)
360 + (self.dim as f64 / 2.0) * self.df.ln();
361 let log_const = gamma_term_num
362 - gamma_term_denom
363 - (self.dim as f64 / 2.0) * pi.ln()
364 - 0.5 * self.scale_det.ln();
365
366 // Calculate Mahalanobis distance: (x - μ)^T Σ^-1 (x - μ)
367 let diff = x - &self.mean;
368 let mahalanobis_squared = self.mahalanobis_distance_squared(&diff.view());
369
370 // log(PDF) = log(C) - ((v+p)/2) * log(1 + (1/v) * dist)
371 log_const - ((self.df + self.dim as f64) / 2.0) * (1.0 + mahalanobis_squared / self.df).ln()
372 }
373
374 /// Get the dimension of the distribution
375 pub fn dim(&self) -> usize {
376 self.dim
377 }
378
379 /// Get the scale matrix of the distribution
380 pub fn scale(&self) -> ArrayView2<f64> {
381 self.scale.view()
382 }
383
384 /// Get the mean vector of the distribution
385 pub fn mean(&self) -> ArrayView1<f64> {
386 self.mean.view()
387 }
388
389 /// Get the degrees of freedom of the distribution
390 pub fn df(&self) -> f64 {
391 self.df
392 }
393}
394
395/// Create a multivariate Student's t-distribution with the given parameters.
396///
397/// This is a convenience function to create a multivariate Student's t-distribution with
398/// the given mean vector, scale matrix, and degrees of freedom.
399///
400/// # Arguments
401///
402/// * `mean` - Mean vector (k-dimensional)
403/// * `scale` - Scale matrix (k x k, symmetric positive-definite)
404/// * `df` - Degrees of freedom (> 0)
405///
406/// # Returns
407///
408/// * A multivariate Student's t-distribution object
409///
410/// # Examples
411///
412/// ```
413/// use scirs2_core::ndarray::array;
414/// use scirs2_stats::distributions::multivariate;
415///
416/// let mean = array![0.0, 0.0];
417/// let scale = array![[1.0, 0.5], [0.5, 2.0]];
418/// let mvt = multivariate::multivariate_t(mean, scale, 5.0).expect("Operation failed");
419/// let pdf_at_origin = mvt.pdf(&array![0.0, 0.0]);
420/// ```
421#[allow(dead_code)]
422pub fn multivariate_t<D1, D2>(
423 mean: ArrayBase<D1, Ix1>,
424 scale: ArrayBase<D2, Ix2>,
425 df: f64,
426) -> StatsResult<MultivariateT>
427where
428 D1: Data<Elem = f64>,
429 D2: Data<Elem = f64>,
430{
431 MultivariateT::new(mean, scale, df)
432}
433
434/// Implementation of SampleableDistribution for MultivariateT
435impl SampleableDistribution<Array1<f64>> for MultivariateT {
436 fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
437 let samples_matrix = self.rvs(size)?;
438 let mut result = Vec::with_capacity(size);
439
440 for i in 0..size {
441 let row = samples_matrix.slice(s![i, ..]).to_owned();
442 result.push(row);
443 }
444
445 Ok(result)
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Axis};

    /// Zero-location, identity-scale 2D distribution shared by the density tests
    fn standard_bivariate(df: f64) -> MultivariateT {
        MultivariateT::new(array![0.0, 0.0], array![[1.0, 0.0], [0.0, 1.0]], df)
            .expect("Operation failed")
    }

    /// Correlated 2D distribution shared by the sampling tests
    fn correlated_bivariate(df: f64) -> MultivariateT {
        MultivariateT::new(array![1.0, 2.0], array![[1.0, 0.5], [0.5, 2.0]], df)
            .expect("Operation failed")
    }

    #[test]
    fn test_mvt_creation() {
        // The constructor must store its 2D inputs unchanged
        let loc2 = array![0.0, 0.0];
        let scale2 = array![[1.0, 0.0], [0.0, 1.0]];
        let dist2 =
            MultivariateT::new(loc2.clone(), scale2.clone(), 5.0).expect("Operation failed");

        assert_eq!(dist2.dim, 2);
        assert_eq!(dist2.mean, loc2);
        assert_eq!(dist2.scale, scale2);
        assert_eq!(dist2.df, 5.0);

        // Same check for a 3D distribution with off-diagonal scale entries
        let loc3 = array![1.0, 2.0, 3.0];
        let scale3 = array![[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 1.5]];
        let dist3 =
            MultivariateT::new(loc3.clone(), scale3.clone(), 10.0).expect("Operation failed");

        assert_eq!(dist3.dim, 3);
        assert_eq!(dist3.mean, loc3);
        assert_eq!(dist3.scale, scale3);
        assert_eq!(dist3.df, 10.0);
    }

    #[test]
    fn test_mvt_creation_errors() {
        let identity = array![[1.0, 0.0], [0.0, 1.0]];

        // A 3-vector cannot be paired with a 2x2 scale matrix
        let too_long = array![0.0, 0.0, 0.0];
        assert!(MultivariateT::new(too_long, identity.clone(), 5.0).is_err());

        // [[1, 2], [2, 1]] is not positive definite
        let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(MultivariateT::new(array![0.0, 0.0], indefinite, 5.0).is_err());

        // Zero and negative degrees of freedom are rejected
        let origin = array![0.0, 0.0];
        assert!(MultivariateT::new(origin.clone(), identity.clone(), 0.0).is_err());
        assert!(MultivariateT::new(origin, identity, -1.0).is_err());
    }

    #[test]
    fn test_mvt_pdf() {
        let dist = standard_bivariate(5.0);

        // The density at the origin is strictly positive
        let at_origin = dist.pdf(&array![0.0, 0.0]);
        assert!(at_origin > 0.0);

        // The density at the origin exceeds the density at [1, 1]
        let at_ones = dist.pdf(&array![1.0, 1.0]);
        assert!(at_origin > at_ones);

        // The density is symmetric about the location
        let forward = dist.pdf(&array![2.0, 1.0]);
        let mirrored = dist.pdf(&array![-2.0, -1.0]);
        assert_relative_eq!(forward, mirrored, epsilon = 1e-10);
    }

    #[test]
    fn test_mvt_logpdf() {
        let dist = standard_bivariate(5.0);

        // exp(logpdf) must agree with pdf
        let point = array![1.0, 1.0];
        let density = dist.pdf(&point);
        let log_density = dist.logpdf(&point);
        assert_relative_eq!(log_density.exp(), density, epsilon = 1e-7);
    }

    #[test]
    fn test_mvt_rvs() {
        let dist = correlated_bivariate(10.0);

        // The sample matrix has one row per draw and one column per dimension
        let draws = 1000;
        let samples = dist.rvs(draws).expect("Operation failed");
        assert_eq!(samples.shape(), &[draws, 2]);

        // Loose check on the sample mean: the draws are random and heavy-tailed
        let column_means = samples.mean_axis(Axis(0)).expect("Operation failed");
        assert_relative_eq!(column_means[0], 1.0, epsilon = 0.3);
        assert_relative_eq!(column_means[1], 2.0, epsilon = 0.3);
    }

    #[test]
    fn test_mvt_rvs_single() {
        let dist = correlated_bivariate(5.0);

        let draw = dist.rvs_single().expect("Operation failed");
        assert_eq!(draw.len(), 2);
    }
}