Skip to main content

scirs2_stats/distributions/multivariate/
student_t.rs

1//! Multivariate Student's t-distribution functions
2//!
3//! This module provides functionality for the Multivariate Student's t-distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::ndarray::{
8    s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix1, Ix2,
9};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::{ChiSquared, Distribution, Normal as RandNormal};
12use std::fmt::Debug;
13
14// Import the helper functions used by MultivariateNormal
15use super::normal::{compute_cholesky, compute_inverse_from_cholesky};
16
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// * Small positive integers (`x <= 20`) are evaluated exactly as `ln((x - 1)!)`.
/// * For `x >= 0.5` the Lanczos approximation (g = 7, nine coefficients) is used, which is
///   accurate to roughly 15 significant digits and works for arbitrarily large `x` without
///   any recursion.
/// * For `0 < x < 0.5` the reflection formula `Γ(x) Γ(1 - x) = π / sin(πx)` maps the
///   argument to `1 - x > 0.5`, so at most one recursive call is made.
///
/// # Panics
///
/// Panics if `x <= 0.0`.
fn lgamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    if x <= 0.0 {
        panic!("lgamma requires positive input");
    }

    // Exact path for small positive integers: ln Γ(n) = ln((n-1)!) = Σ_{i=2}^{n-1} ln i.
    if x.fract() == 0.0 && x <= 20.0 {
        let n = x as usize;
        return (2..n).fold(0.0, |acc, i| acc + (i as f64).ln());
    }

    // Reflection, written in log form so that tiny `x` does not overflow `π / sin(πx)`.
    // Because 1 - x > 0.5 here, the recursive call always takes the Lanczos branch.
    if x < 0.5 {
        return PI.ln() - (PI * x).sin().ln() - lgamma(1.0 - x);
    }

    // Lanczos approximation with g = 7 and n = 9 coefficients.
    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    // Γ(x) = Γ(z + 1) with z = x - 1.
    let z = x - 1.0;
    let series = COEFFS[1..]
        .iter()
        .enumerate()
        .fold(COEFFS[0], |acc, (i, &c)| acc + c / (z + (i + 1) as f64));
    let t = z + G + 0.5;

    // ln Γ(z + 1) = ½ ln(2π) + (z + ½) ln t - t + ln(series)
    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
79
80/// Multivariate Student's t-distribution structure
81#[derive(Debug, Clone)]
82pub struct MultivariateT {
83    /// Mean vector
84    pub mean: Array1<f64>,
85    /// Scale matrix (like covariance but scaled by df/(df-2) for df > 2)
86    pub scale: Array2<f64>,
87    /// Dimensionality of the distribution
88    pub dim: usize,
89    /// Degrees of freedom
90    pub df: f64,
91    /// Cholesky decomposition of the scale matrix (lower triangular)
92    cholesky_l: Array2<f64>,
93    /// Determinant of the scale matrix
94    scale_det: f64,
95    /// Inverse of the scale matrix
96    scale_inv: Array2<f64>,
97}
98
99impl MultivariateT {
100    /// Create a new multivariate Student's t-distribution with given parameters
101    ///
102    /// # Arguments
103    ///
104    /// * `mean` - Mean vector (k-dimensional)
105    /// * `scale` - Scale matrix (k x k, symmetric positive-definite)
106    /// * `df` - Degrees of freedom (> 0)
107    ///
108    /// # Returns
109    ///
110    /// * A new MultivariateT distribution instance
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// use scirs2_core::ndarray::array;
116    /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
117    ///
118    /// // Create a 2D multivariate Student's t-distribution with 5 degrees of freedom
119    /// let mean = array![0.0, 0.0];
120    /// let scale = array![[1.0, 0.5], [0.5, 2.0]];
121    /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
122    /// ```
123    pub fn new<D1, D2>(
124        mean: ArrayBase<D1, Ix1>,
125        scale: ArrayBase<D2, Ix2>,
126        df: f64,
127    ) -> StatsResult<Self>
128    where
129        D1: Data<Elem = f64>,
130        D2: Data<Elem = f64>,
131    {
132        // Validate dimensions
133        let dim = mean.len();
134        if scale.shape()[0] != dim || scale.shape()[1] != dim {
135            return Err(StatsError::DimensionMismatch(format!(
136                "Scale matrix shape ({:?}) must match mean vector length ({})",
137                scale.shape(),
138                dim
139            )));
140        }
141
142        // Validate degrees of freedom
143        if df <= 0.0 {
144            return Err(StatsError::DomainError(
145                "Degrees of freedom must be positive".to_string(),
146            ));
147        }
148
149        // Create owned copies of inputs
150        let mean = mean.to_owned();
151        let scale = scale.to_owned();
152
153        // Compute Cholesky decomposition (lower triangular L where Σ = L·L^T)
154        let cholesky_l = compute_cholesky(&scale).map_err(|_| {
155            StatsError::DomainError("Scale matrix must be positive definite".to_string())
156        })?;
157
158        // For positive definite matrix, det(Σ) = det(L)^2 = prod(diag(L))^2
159        let scale_det = {
160            let mut det = 1.0;
161            for i in 0..dim {
162                det *= cholesky_l[[i, i]];
163            }
164            det * det // Square it since det(Σ) = det(L)^2
165        };
166
167        // Compute inverse using Cholesky decomposition
168        let scale_inv = compute_inverse_from_cholesky(&cholesky_l).map_err(|_| {
169            StatsError::ComputationError("Failed to compute matrix inverse".to_string())
170        })?;
171
172        Ok(MultivariateT {
173            mean,
174            scale,
175            dim,
176            df,
177            cholesky_l,
178            scale_det,
179            scale_inv,
180        })
181    }
182
183    /// Calculate the probability density function (PDF) at a given point
184    ///
185    /// # Arguments
186    ///
187    /// * `x` - The point at which to evaluate the PDF
188    ///
189    /// # Returns
190    ///
191    /// * The value of the PDF at the given point
192    ///
193    /// # Examples
194    ///
195    /// ```
196    /// use scirs2_core::ndarray::array;
197    /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
198    ///
199    /// let mean = array![0.0, 0.0];
200    /// let scale = array![[1.0, 0.0], [0.0, 1.0]];
201    /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
202    ///
203    /// // PDF at origin
204    /// let pdf_at_origin = mvt.pdf(&array![0.0, 0.0]);
205    /// ```
206    pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
207    where
208        D: Data<Elem = f64>,
209    {
210        if x.len() != self.dim {
211            return 0.0; // Return zero for invalid dimensions
212        }
213
214        let pi = std::f64::consts::PI;
215
216        // Calculate the constant part of the PDF
217        let gamma_term_num = lgamma((self.df + self.dim as f64) / 2.0).exp();
218        let gamma_term_denom = lgamma(self.df / 2.0).exp()
219            * lgamma(self.dim as f64 / 2.0).exp()
220            * self.df.powf(self.dim as f64 / 2.0);
221        let constant_factor = gamma_term_num
222            / (gamma_term_denom * pi.powf(self.dim as f64 / 2.0) * self.scale_det.sqrt());
223
224        // Calculate Mahalanobis distance: (x - μ)^T Σ^-1 (x - μ)
225        let diff = x - &self.mean;
226        let mahalanobis_squared = self.mahalanobis_distance_squared(&diff.view());
227
228        // PDF = C * [1 + (1/v) * dist]^(-(v+p)/2)
229        // where C is a normalization constant, v is df, p is dimension, and dist is Mahalanobis distance squared
230        constant_factor
231            * (1.0 + mahalanobis_squared / self.df).powf(-(self.df + self.dim as f64) / 2.0)
232    }
233
234    /// Calculate the Mahalanobis distance squared: (x - μ)^T Σ^-1 (x - μ)
235    fn mahalanobis_distance_squared(&self, diff: &ArrayView1<f64>) -> f64 {
236        // Compute (x - μ)^T Σ^-1 (x - μ)
237        diff.dot(&self.scale_inv.dot(diff))
238    }
239
240    /// Generate random samples from the distribution
241    ///
242    /// # Arguments
243    ///
244    /// * `size` - Number of samples to generate
245    ///
246    /// # Returns
247    ///
248    /// * Matrix where each row is a random sample
249    ///
250    /// # Examples
251    ///
252    /// ```
253    /// use scirs2_core::ndarray::array;
254    /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
255    ///
256    /// let mean = array![0.0, 0.0];
257    /// let scale = array![[1.0, 0.5], [0.5, 2.0]];
258    /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
259    ///
260    /// let samples = mvt.rvs(100).expect("Operation failed");
261    /// assert_eq!(samples.shape(), &[100, 2]);
262    /// ```
263    pub fn rvs(&self, size: usize) -> StatsResult<Array2<f64>> {
264        let mut rng = thread_rng();
265        let normal_dist = RandNormal::new(0.0, 1.0).expect("Operation failed");
266        let chi2_dist = ChiSquared::new(self.df).expect("Operation failed");
267
268        // Create a matrix for the samples
269        let mut samples = Array2::<f64>::zeros((size, self.dim));
270
271        // For each sample
272        for i in 0..size {
273            // Generate standard normal samples for each dimension
274            let mut z = Array1::<f64>::zeros(self.dim);
275            for j in 0..self.dim {
276                z[j] = normal_dist.sample(&mut rng);
277            }
278
279            // Generate chi-square sample with df degrees of freedom
280            let w = chi2_dist.sample(&mut rng);
281
282            // Transform Z using the Cholesky decomposition
283            let mut transformed = Array1::<f64>::zeros(self.dim);
284            for j in 0..self.dim {
285                for k in 0..=j {
286                    transformed[j] += self.cholesky_l[[j, k]] * z[k];
287                }
288            }
289
290            // Apply the t-distribution scaling
291            let scaling_factor = (self.df / w).sqrt();
292            for j in 0..self.dim {
293                samples[[i, j]] = self.mean[j] + transformed[j] * scaling_factor;
294            }
295        }
296
297        Ok(samples)
298    }
299
300    /// Generate a single random sample from the distribution
301    ///
302    /// # Returns
303    ///
304    /// * Vector representing a single sample
305    ///
306    /// # Examples
307    ///
308    /// ```
309    /// use scirs2_core::ndarray::array;
310    /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
311    ///
312    /// let mean = array![0.0, 0.0];
313    /// let scale = array![[1.0, 0.5], [0.5, 2.0]];
314    /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
315    ///
316    /// let sample = mvt.rvs_single().expect("Operation failed");
317    /// assert_eq!(sample.len(), 2);
318    /// ```
319    pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
320        let samples = self.rvs(1)?;
321        Ok(samples.index_axis(Axis(0), 0).to_owned())
322    }
323
324    /// Calculate the log probability density function (log PDF) at a given point
325    ///
326    /// # Arguments
327    ///
328    /// * `x` - The point at which to evaluate the log PDF
329    ///
330    /// # Returns
331    ///
332    /// * The value of the log PDF at the given point
333    ///
334    /// # Examples
335    ///
336    /// ```
337    /// use scirs2_core::ndarray::array;
338    /// use scirs2_stats::distributions::multivariate::student_t::MultivariateT;
339    ///
340    /// let mean = array![0.0, 0.0];
341    /// let scale = array![[1.0, 0.0], [0.0, 1.0]];
342    /// let mvt = MultivariateT::new(mean, scale, 5.0).expect("Operation failed");
343    ///
344    /// let log_pdf = mvt.logpdf(&array![0.0, 0.0]);
345    /// ```
346    pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
347    where
348        D: Data<Elem = f64>,
349    {
350        if x.len() != self.dim {
351            return f64::NEG_INFINITY; // Return -∞ for invalid dimensions
352        }
353
354        let pi = std::f64::consts::PI;
355
356        // Calculate the constant part of the log PDF
357        let gamma_term_num = lgamma((self.df + self.dim as f64) / 2.0);
358        let gamma_term_denom = lgamma(self.df / 2.0)
359            + lgamma(self.dim as f64 / 2.0)
360            + (self.dim as f64 / 2.0) * self.df.ln();
361        let log_const = gamma_term_num
362            - gamma_term_denom
363            - (self.dim as f64 / 2.0) * pi.ln()
364            - 0.5 * self.scale_det.ln();
365
366        // Calculate Mahalanobis distance: (x - μ)^T Σ^-1 (x - μ)
367        let diff = x - &self.mean;
368        let mahalanobis_squared = self.mahalanobis_distance_squared(&diff.view());
369
370        // log(PDF) = log(C) - ((v+p)/2) * log(1 + (1/v) * dist)
371        log_const - ((self.df + self.dim as f64) / 2.0) * (1.0 + mahalanobis_squared / self.df).ln()
372    }
373
374    /// Get the dimension of the distribution
375    pub fn dim(&self) -> usize {
376        self.dim
377    }
378
379    /// Get the scale matrix of the distribution
380    pub fn scale(&self) -> ArrayView2<f64> {
381        self.scale.view()
382    }
383
384    /// Get the mean vector of the distribution
385    pub fn mean(&self) -> ArrayView1<f64> {
386        self.mean.view()
387    }
388
389    /// Get the degrees of freedom of the distribution
390    pub fn df(&self) -> f64 {
391        self.df
392    }
393}
394
395/// Create a multivariate Student's t-distribution with the given parameters.
396///
397/// This is a convenience function to create a multivariate Student's t-distribution with
398/// the given mean vector, scale matrix, and degrees of freedom.
399///
400/// # Arguments
401///
402/// * `mean` - Mean vector (k-dimensional)
403/// * `scale` - Scale matrix (k x k, symmetric positive-definite)
404/// * `df` - Degrees of freedom (> 0)
405///
406/// # Returns
407///
408/// * A multivariate Student's t-distribution object
409///
410/// # Examples
411///
412/// ```
413/// use scirs2_core::ndarray::array;
414/// use scirs2_stats::distributions::multivariate;
415///
416/// let mean = array![0.0, 0.0];
417/// let scale = array![[1.0, 0.5], [0.5, 2.0]];
418/// let mvt = multivariate::multivariate_t(mean, scale, 5.0).expect("Operation failed");
419/// let pdf_at_origin = mvt.pdf(&array![0.0, 0.0]);
420/// ```
421#[allow(dead_code)]
422pub fn multivariate_t<D1, D2>(
423    mean: ArrayBase<D1, Ix1>,
424    scale: ArrayBase<D2, Ix2>,
425    df: f64,
426) -> StatsResult<MultivariateT>
427where
428    D1: Data<Elem = f64>,
429    D2: Data<Elem = f64>,
430{
431    MultivariateT::new(mean, scale, df)
432}
433
434/// Implementation of SampleableDistribution for MultivariateT
435impl SampleableDistribution<Array1<f64>> for MultivariateT {
436    fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
437        let samples_matrix = self.rvs(size)?;
438        let mut result = Vec::with_capacity(size);
439
440        for i in 0..size {
441            let row = samples_matrix.slice(s![i, ..]).to_owned();
442            result.push(row);
443        }
444
445        Ok(result)
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Axis};

    /// Constructs a distribution that the calling test expects to be valid.
    fn build(loc: Array1<f64>, shape: Array2<f64>, nu: f64) -> MultivariateT {
        MultivariateT::new(loc, shape, nu).expect("Operation failed")
    }

    #[test]
    fn test_mvt_creation() {
        // Standard bivariate case followed by a correlated trivariate case.
        let cases = [
            (array![0.0, 0.0], array![[1.0, 0.0], [0.0, 1.0]], 5.0),
            (
                array![1.0, 2.0, 3.0],
                array![[1.0, 0.5, 0.3], [0.5, 2.0, 0.2], [0.3, 0.2, 1.5]],
                10.0,
            ),
        ];

        for (loc, shape, nu) in cases {
            let dist = build(loc.clone(), shape.clone(), nu);
            assert_eq!(dist.dim, loc.len());
            assert_eq!(dist.mean, loc);
            assert_eq!(dist.scale, shape);
            assert_eq!(dist.df, nu);
        }
    }

    #[test]
    fn test_mvt_creation_errors() {
        let identity = array![[1.0, 0.0], [0.0, 1.0]];

        // Mean has three entries while the scale matrix is 2x2.
        assert!(MultivariateT::new(array![0.0, 0.0, 0.0], identity.clone(), 5.0).is_err());

        // Indefinite scale matrix (eigenvalues 3 and -1).
        let indefinite = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(MultivariateT::new(array![0.0, 0.0], indefinite, 5.0).is_err());

        // Non-positive degrees of freedom.
        for bad_df in [0.0, -1.0] {
            assert!(MultivariateT::new(array![0.0, 0.0], identity.clone(), bad_df).is_err());
        }
    }

    #[test]
    fn test_mvt_pdf() {
        let dist = build(array![0.0, 0.0], array![[1.0, 0.0], [0.0, 1.0]], 5.0);

        let at_origin = dist.pdf(&array![0.0, 0.0]);
        let at_unit = dist.pdf(&array![1.0, 1.0]);

        // Density is positive at the mode and falls off away from it.
        assert!(at_origin > 0.0);
        assert!(at_origin > at_unit);

        // Reflecting a point through the location leaves the density unchanged.
        let forward = dist.pdf(&array![2.0, 1.0]);
        let backward = dist.pdf(&array![-2.0, -1.0]);
        assert_relative_eq!(forward, backward, epsilon = 1e-10);
    }

    #[test]
    fn test_mvt_logpdf() {
        let dist = build(array![0.0, 0.0], array![[1.0, 0.0], [0.0, 1.0]], 5.0);

        // exp(logpdf) must agree with pdf.
        let point = array![1.0, 1.0];
        assert_relative_eq!(dist.logpdf(&point).exp(), dist.pdf(&point), epsilon = 1e-7);
    }

    #[test]
    fn test_mvt_rvs() {
        let dist = build(array![1.0, 2.0], array![[1.0, 0.5], [0.5, 2.0]], 10.0);

        let count = 1000;
        let draws = dist.rvs(count).expect("Operation failed");
        assert_eq!(draws.shape(), &[count, 2]);

        // Loose check on the sample mean: the t-distribution has heavy tails.
        let centre = draws.mean_axis(Axis(0)).expect("Operation failed");
        for (got, want) in centre.iter().zip([1.0, 2.0]) {
            assert_relative_eq!(*got, want, epsilon = 0.3);
        }
    }

    #[test]
    fn test_mvt_rvs_single() {
        let dist = build(array![1.0, 2.0], array![[1.0, 0.5], [0.5, 2.0]], 5.0);
        let draw = dist.rvs_single().expect("Operation failed");
        assert_eq!(draw.len(), 2);
    }
}