scirs2_stats/distributions/
student_t.rs1use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use crate::traits::{ContinuousCDF, ContinuousDistribution, Distribution as ScirsDist};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::{Float, NumCast};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::{Distribution, StudentT as RandStudentT};
12use statrs::function::beta::{beta_reg, inv_beta_reg};
13use std::f64::consts::PI;
14
15#[inline(always)]
17fn const_f64<F: Float + NumCast>(value: f64) -> F {
18 F::from(value).expect("Failed to convert constant to target float type")
19}
20
21pub struct StudentT<F: Float + Send + Sync> {
23 pub df: F,
25 pub loc: F,
27 pub scale: F,
29 rand_distr: RandStudentT<f64>,
31}
32
33impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> StudentT<F> {
34 pub fn new(df: F, loc: F, scale: F) -> StatsResult<Self> {
55 if df <= F::zero() {
56 return Err(StatsError::DomainError(
57 "Degrees of freedom must be positive".to_string(),
58 ));
59 }
60
61 if scale <= F::zero() {
62 return Err(StatsError::DomainError(
63 "Scale parameter must be positive".to_string(),
64 ));
65 }
66
67 let df_f64 = NumCast::from(df).expect("Failed to convert to f64");
69
70 match RandStudentT::new(df_f64) {
71 Ok(rand_distr) => Ok(StudentT {
72 df,
73 loc,
74 scale,
75 rand_distr,
76 }),
77 Err(_) => Err(StatsError::ComputationError(
78 "Failed to create Student's t distribution".to_string(),
79 )),
80 }
81 }
82
83 #[inline]
103 pub fn pdf(&self, x: F) -> F {
104 let x_std = (x - self.loc) / self.scale;
106
107 let df_half = self.df / const_f64::<F>(2.0);
109 let df_plus_one_half = (self.df + F::one()) / const_f64::<F>(2.0);
110
111 let one = F::one();
113 let pi = const_f64::<F>(PI);
114
115 let numerator = gamma_function(df_plus_one_half);
117 let denominator = gamma_function(df_half) * (self.df * pi).sqrt();
118
119 let factor = numerator / denominator / self.scale;
120 let exponent = -(df_plus_one_half) * (one + x_std * x_std / self.df).ln();
121
122 factor * exponent.exp()
123 }
124
125 #[inline]
147 pub fn cdf(&self, x: F) -> F {
148 let x_std = (x - self.loc) / self.scale;
150
151 if x_std.is_nan() {
153 return F::nan();
154 }
155
156 if x_std == F::infinity() {
158 return F::one();
159 }
160 if x_std == F::neg_infinity() {
161 return F::zero();
162 }
163
164 if x_std == F::zero() {
166 return const_f64::<F>(0.5);
167 }
168
169 let x_f64: f64 = NumCast::from(x_std).unwrap_or(0.0);
171 let df_f64: f64 = NumCast::from(self.df).unwrap_or(1.0);
172
173 let h = df_f64 / (df_f64 + x_f64 * x_f64);
178 let ib = 0.5 * beta_reg(df_f64 / 2.0, 0.5, h);
179
180 let result = if x_f64 <= 0.0 { ib } else { 1.0 - ib };
181
182 const_f64::<F>(result)
183 }
184
185 #[inline]
205 pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
206 let samples = self.rvs_vec(size)?;
207 Ok(Array1::from_vec(samples))
208 }
209
210 #[inline]
230 pub fn rvs_vec(&self, size: usize) -> StatsResult<Vec<F>> {
231 if size < 1000 {
233 let mut rng = thread_rng();
234 let mut samples = Vec::with_capacity(size);
235
236 for _ in 0..size {
237 let std_sample = self.rand_distr.sample(&mut rng);
239
240 let sample = const_f64::<F>(std_sample) * self.scale + self.loc;
242 samples.push(sample);
243 }
244
245 return Ok(samples);
246 }
247
248 use scirs2_core::parallel_ops::parallel_map;
250
251 let df_f64 = NumCast::from(self.df).expect("Failed to convert to f64");
253 let loc = self.loc;
254 let scale = self.scale;
255
256 let indices: Vec<usize> = (0..size).collect();
258
259 let samples = parallel_map(&indices, move |_| {
261 let mut rng = thread_rng();
262 let rand_distr = RandStudentT::new(df_f64).expect("test/example should not fail");
263 let sample = rand_distr.sample(&mut rng);
264 const_f64::<F>(sample) * scale + loc
265 });
266
267 Ok(samples)
268 }
269}
270
271#[inline]
273#[allow(dead_code)]
274fn gamma_function<F: Float>(x: F) -> F {
275 if x == F::one() {
276 return F::one();
277 }
278
279 if x == const_f64::<F>(0.5) {
280 return const_f64::<F>(PI).sqrt();
281 }
282
283 if x > F::one() {
285 return (x - F::one()) * gamma_function(x - F::one());
286 }
287
288 let p = [
290 const_f64::<F>(676.5203681218851),
291 const_f64::<F>(-1259.1392167224028),
292 const_f64::<F>(771.323_428_777_653_1),
293 const_f64::<F>(-176.615_029_162_140_6),
294 const_f64::<F>(12.507343278686905),
295 const_f64::<F>(-0.13857109526572012),
296 const_f64::<F>(9.984_369_578_019_572e-6),
297 const_f64::<F>(1.5056327351493116e-7),
298 ];
299
300 let x_adj = x - F::one();
301 let t = x_adj + const_f64::<F>(7.5);
302
303 let mut sum = F::zero();
304 for (i, &coef) in p.iter().enumerate() {
305 sum = sum + coef / (x_adj + const_f64::<F>((i + 1) as f64));
306 }
307
308 let pi = const_f64::<F>(PI);
309 let sqrt_2pi = (const_f64::<F>(2.0) * pi).sqrt();
310
311 sqrt_2pi * sum * t.powf(x_adj + const_f64::<F>(0.5)) * (-t).exp()
312}
313
314impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ScirsDist<F> for StudentT<F> {
316 fn mean(&self) -> F {
317 if self.df <= F::one() {
319 F::nan()
320 } else {
321 self.loc
322 }
323 }
324
325 fn var(&self) -> F {
326 if self.df <= const_f64::<F>(2.0) {
329 F::nan()
330 } else {
331 self.df / (self.df - const_f64::<F>(2.0)) * self.scale * self.scale
332 }
333 }
334
335 fn std(&self) -> F {
336 self.var().sqrt()
338 }
339
340 fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
341 self.rvs(size)
342 }
343
344 fn entropy(&self) -> F {
345 let df = self.df;
348 let half = const_f64::<F>(0.5);
349 let one = F::one();
350
351 if df <= F::zero() {
352 return F::nan();
353 }
354
355 if df > const_f64::<F>(1000.0) {
357 let e = const_f64::<F>(std::f64::consts::E);
358 return half * (const_f64::<F>(2.0) * const_f64::<F>(std::f64::consts::PI) * e).ln()
359 + self.scale.ln();
360 }
361
362 let half_df_plus_half = (df + one) * half;
364 let half_df = df * half;
365
366 let term1 = half_df_plus_half * (gamma_function(half) / gamma_function(half_df)).ln();
367 let term2 = half_df_plus_half;
368 let term3 = half * (df * const_f64::<F>(std::f64::consts::PI)).ln();
369
370 term1 + term2 + term3
371 }
372}
373
374impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ContinuousDistribution<F>
376 for StudentT<F>
377{
378 fn pdf(&self, x: F) -> F {
379 StudentT::pdf(self, x)
381 }
382
383 fn cdf(&self, x: F) -> F {
384 StudentT::cdf(self, x)
386 }
387
388 fn ppf(&self, p: F) -> StatsResult<F> {
389 if p < F::zero() || p > F::one() {
391 return Err(StatsError::DomainError(
392 "Probability must be between 0 and 1".to_string(),
393 ));
394 }
395
396 if p == F::zero() {
398 return Ok(F::neg_infinity());
399 }
400 if p == F::one() {
401 return Ok(F::infinity());
402 }
403 if p == const_f64::<F>(0.5) {
404 return Ok(self.loc); }
406
407 let p_f64: f64 = NumCast::from(p).unwrap_or(0.5);
408 let df_f64: f64 = NumCast::from(self.df).unwrap_or(1.0);
409
410 let p1 = p_f64.min(1.0 - p_f64);
416 let y = inv_beta_reg(df_f64 / 2.0, 0.5, 2.0 * p1);
417
418 let t_value = if y == 0.0 {
419 f64::INFINITY
421 } else {
422 (df_f64 * (1.0 - y) / y).sqrt()
423 };
424
425 let signed_t = if p_f64 >= 0.5 { t_value } else { -t_value };
426
427 Ok(const_f64::<F>(signed_t) * self.scale + self.loc)
428 }
429}
430
431impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ContinuousCDF<F>
432 for StudentT<F>
433{
434 }
436
437impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> SampleableDistribution<F>
439 for StudentT<F>
440{
441 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
442 self.rvs_vec(size)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ContinuousDistribution, Distribution as ScirsDist};
    use approx::assert_relative_eq;

    /// Standard t distribution with 5 degrees of freedom, shared by most tests.
    fn standard_t5() -> StudentT<f64> {
        StudentT::new(5.0, 0.0, 1.0).expect("df = 5, loc = 0, scale = 1 are valid parameters")
    }

    #[test]
    fn test_student_t_creation() {
        let t5 = standard_t5();
        assert_eq!(t5.df, 5.0);
        assert_eq!(t5.loc, 0.0);
        assert_eq!(t5.scale, 1.0);

        let custom =
            StudentT::new(10.0, 1.0, 2.0).expect("df = 10, loc = 1, scale = 2 are valid parameters");
        assert_eq!(custom.df, 10.0);
        assert_eq!(custom.loc, 1.0);
        assert_eq!(custom.scale, 2.0);

        // Non-positive degrees of freedom and scale are rejected.
        assert!(StudentT::<f64>::new(0.0, 0.0, 1.0).is_err());
        assert!(StudentT::<f64>::new(-1.0, 0.0, 1.0).is_err());
        assert!(StudentT::<f64>::new(5.0, 0.0, 0.0).is_err());
        assert!(StudentT::<f64>::new(5.0, 0.0, -1.0).is_err());
    }

    #[test]
    fn test_student_t_pdf() {
        let t5 = standard_t5();

        assert_relative_eq!(t5.pdf(0.0), 0.3796, epsilon = 1e-4);

        // The density is symmetric around the location.
        assert_relative_eq!(t5.pdf(1.0), 0.220, epsilon = 1e-3);
        assert_relative_eq!(t5.pdf(-1.0), 0.220, epsilon = 1e-3);
    }

    #[test]
    fn test_student_t_cdf() {
        let t5 = standard_t5();

        assert_relative_eq!(t5.cdf(0.0), 0.5, epsilon = 1e-10);
        assert_relative_eq!(t5.cdf(1.0), 0.8183916424979924, epsilon = 1e-6);
        assert_relative_eq!(t5.cdf(-1.0), 1.0 - 0.8183916424979924, epsilon = 1e-6);
        assert_relative_eq!(t5.cdf(2.0), 0.9490302071648776, epsilon = 1e-6);
    }

    #[test]
    fn test_student_t_ppf() {
        let t5 = standard_t5();

        let median = t5.ppf(0.5).expect("0.5 is a valid probability");
        assert_relative_eq!(median, 0.0, epsilon = 1e-10);

        let p95 = t5.ppf(0.95).expect("0.95 is a valid probability");
        assert_relative_eq!(p95, 2.0150483726691575, epsilon = 1e-6);

        let p05 = t5.ppf(0.05).expect("0.05 is a valid probability");
        assert_relative_eq!(p05, -2.0150483726691575, epsilon = 1e-6);

        // cdf(ppf(p)) must give back p.
        for &p in &[0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99] {
            let x = t5.ppf(p).expect("probabilities inside (0, 1) are valid");
            assert_relative_eq!(t5.cdf(x), p, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_student_t_rvs() {
        let t5 = standard_t5();

        let samples_vec = t5.rvs_vec(1000).expect("sampling 1000 values should succeed");
        assert_eq!(samples_vec.len(), 1000);

        let samples_array = t5.rvs(1000).expect("sampling 1000 values should succeed");
        assert_eq!(samples_array.len(), 1000);

        let sum: f64 = samples_vec.iter().sum();
        let mean = sum / 1000.0;

        // Loose bound: the sample mean of 1000 draws should sit near the true mean of 0.
        assert!(mean.abs() < 0.2);
    }

    #[test]
    fn test_student_t_rvs_small_sizes() {
        let t5 = standard_t5();

        // Sizes below 1000 take the sequential branch of `rvs_vec`.
        let few = t5.rvs_vec(10).expect("sampling 10 values should succeed");
        assert_eq!(few.len(), 10);

        let few_array = t5.rvs(10).expect("sampling 10 values should succeed");
        assert_eq!(few_array.len(), 10);

        let none = t5.rvs_vec(0).expect("sampling 0 values should succeed");
        assert!(none.is_empty());
    }

    #[test]
    fn test_student_t_distribution_trait() {
        let t5 = standard_t5();

        assert_relative_eq!(t5.mean(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(t5.var(), 5.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(t5.std(), (5.0 / 3.0f64).sqrt(), epsilon = 1e-10);

        // With one degree of freedom the moments are reported as NaN.
        let t1 = StudentT::new(1.0, 0.0, 1.0).expect("df = 1, loc = 0, scale = 1 are valid parameters");
        assert!(t1.mean().is_nan());
        assert!(t1.var().is_nan());
        assert!(t1.std().is_nan());

        let entropy = t5.entropy();
        assert!(entropy > 0.0);
    }

    #[test]
    fn test_student_t_continuous_distribution_trait() {
        let t5 = standard_t5();

        let dist: &dyn ContinuousDistribution<f64> = &t5;

        assert_relative_eq!(dist.pdf(0.0), 0.3796, epsilon = 1e-4);
        assert_relative_eq!(dist.cdf(0.0), 0.5, epsilon = 1e-10);
        assert_relative_eq!(
            dist.ppf(0.5).expect("0.5 is a valid probability"),
            0.0,
            epsilon = 1e-10
        );

        assert_relative_eq!(t5.sf(0.0), 0.5, epsilon = 1e-10);
        assert!(t5.hazard(0.0) > 0.0);
        assert!(t5.cumhazard(0.0) > 0.0);

        assert_relative_eq!(
            t5.isf(0.95).expect("0.95 is a valid probability"),
            dist.ppf(0.05).expect("0.05 is a valid probability"),
            epsilon = 1e-6
        );
    }
}