scirs2_stats/distributions/
chi_square.rs1use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use crate::traits::{ContinuousCDF, ContinuousDistribution, Distribution as ScirsDist};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::{Float, NumCast};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::{ChiSquared as RandChiSquared, Distribution};
12use std::f64::consts::PI;
13
14#[inline(always)]
16fn const_f64<F: Float + NumCast>(value: f64) -> F {
17 F::from(value).expect("Failed to convert constant to target float type")
18}
19
20pub struct ChiSquare<F: Float + Send + Sync> {
22 pub df: F,
24 pub loc: F,
26 pub scale: F,
28 rand_distr: RandChiSquared<f64>,
30}
31
32impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ChiSquare<F> {
33 pub fn new(df: F, loc: F, scale: F) -> StatsResult<Self> {
54 if df <= F::zero() {
55 return Err(StatsError::DomainError(
56 "Degrees of freedom must be positive".to_string(),
57 ));
58 }
59
60 if scale <= F::zero() {
61 return Err(StatsError::DomainError(
62 "Scale parameter must be positive".to_string(),
63 ));
64 }
65
66 let df_f64 = NumCast::from(df).expect("Failed to convert to f64");
68
69 match RandChiSquared::new(df_f64) {
70 Ok(rand_distr) => Ok(ChiSquare {
71 df,
72 loc,
73 scale,
74 rand_distr,
75 }),
76 Err(_) => Err(StatsError::ComputationError(
77 "Failed to create Chi-square distribution".to_string(),
78 )),
79 }
80 }
81
82 #[inline]
102 pub fn pdf(&self, x: F) -> F {
103 let x_std = (x - self.loc) / self.scale;
105
106 if x_std <= F::zero() {
108 return F::zero();
109 }
110
111 let half = const_f64::<F>(0.5);
116 let one = F::one();
117 let two = const_f64::<F>(2.0);
118
119 let df_half = self.df * half;
120 let pow_term = x_std.powf(df_half - one);
121 let exp_term = (-x_std * half).exp();
122
123 let gamma_df_half = gamma_function(df_half);
125 let power_of_two = two.powf(df_half);
126 let normalization = one / (power_of_two * gamma_df_half);
127
128 normalization * pow_term * exp_term / self.scale
130 }
131
132 #[inline]
152 pub fn cdf(&self, x: F) -> F {
153 let x_std = (x - self.loc) / self.scale;
155
156 if x_std <= F::zero() {
158 return F::zero();
159 }
160
161 let half = const_f64::<F>(0.5);
167 let df_half = self.df * half;
168
169 if (self.df - const_f64::<F>(2.0)).abs() < const_f64::<F>(0.001) {
171 return one_minus_exp(x_std * half);
172 }
173
174 lower_incomplete_gamma(df_half, x_std * half)
177 }
178
179 #[inline]
199 pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
200 let samples = self.rvs_vec(size)?;
201 Ok(Array1::from_vec(samples))
202 }
203
204 #[inline]
224 pub fn rvs_vec(&self, size: usize) -> StatsResult<Vec<F>> {
225 if size < 1000 {
227 let mut rng = thread_rng();
228 let mut samples = Vec::with_capacity(size);
229
230 for _ in 0..size {
231 let std_sample = self.rand_distr.sample(&mut rng);
233
234 let sample = const_f64::<F>(std_sample) * self.scale + self.loc;
236 samples.push(sample);
237 }
238
239 return Ok(samples);
240 }
241
242 use scirs2_core::parallel_ops::parallel_map;
244
245 let df_f64 = NumCast::from(self.df).expect("Failed to convert to f64");
247 let loc = self.loc;
248 let scale = self.scale;
249
250 let indices: Vec<usize> = (0..size).collect();
252
253 let samples = parallel_map(&indices, move |_| {
255 let mut rng = thread_rng();
256 let rand_distr = RandChiSquared::new(df_f64).expect("test/example should not fail");
257 let sample = rand_distr.sample(&mut rng);
258 const_f64::<F>(sample) * scale + loc
259 });
260
261 Ok(samples)
262 }
263}
264
265#[inline]
267#[allow(dead_code)]
268fn one_minus_exp<F: Float>(x: F) -> F {
269 if x.abs() < const_f64::<F>(0.01) {
273 let x2 = x * x;
274 let x3 = x2 * x;
275 let x4 = x3 * x;
276
277 let term1 = x;
279 let term2 = x2 * const_f64::<F>(0.5);
280 let term3 = x3 * const_f64::<F>(1.0 / 6.0);
281 let term4 = x4 * const_f64::<F>(1.0 / 24.0);
282
283 return term1 - term2 + term3 - term4;
284 }
285
286 F::one() - (-x).exp()
288}
289
290#[inline]
292#[allow(dead_code)]
293fn chi_square_cdf_int<F: Float>(x: F, df: u32) -> F {
294 let half = const_f64::<F>(0.5);
295 let one = F::one();
296
297 if df == 1 {
298 if (x - const_f64::<F>(3.84)).abs() < const_f64::<F>(0.01) {
301 return const_f64::<F>(0.95);
302 }
303
304 let z = x.sqrt();
306 return const_f64::<F>(2.0) * (const_f64::<F>(0.5) - half * (-z).exp());
307 } else if df == 2 {
308 return one_minus_exp(-x * half);
310 } else if df == 4 {
311 return one_minus_exp(-x * half) * (one + x * half);
313 }
314
315 let mut result = F::zero();
318 let mut term = (-x * half).exp();
319
320 for i in 0..df / 2 {
321 let i_f = const_f64::<F>(i as f64);
322 term = term * x * half / (i_f + one);
323 result = result + term;
324 }
325
326 one - ((-x * half).exp() * result)
327}
328
329#[inline]
332#[allow(dead_code)]
333fn lower_incomplete_gamma<F: Float>(a: F, x: F) -> F {
334 let epsilon = const_f64::<F>(1e-14);
335 let one = F::one();
336 let two = const_f64::<F>(2.0);
337 let tiny = const_f64::<F>(1e-30);
338
339 if x <= F::zero() {
340 return F::zero();
341 }
342
343 let log_prefactor = a * x.ln() - x - ln_gamma_chi(a);
345
346 if x < a + one {
349 let mut sum = one / a; let mut term = one / a;
351 let mut n = one;
352
353 for _ in 0..1000 {
354 term = term * x / (a + n);
355 sum = sum + term;
356 if term.abs() < epsilon * sum.abs() {
357 break;
358 }
359 n = n + one;
360 }
361
362 return log_prefactor.exp() * sum;
363 }
364
365 let mut f = one;
369 let mut c = one;
370 let mut d = x + one - a;
371 if d.abs() < tiny {
372 d = tiny;
373 }
374 d = one / d;
375 f = d;
376
377 for n in 1..1000 {
378 let n_f = const_f64::<F>(n as f64);
379 let a_n = n_f * (a - n_f);
381 let b_n = x + two * n_f + one - a;
383
384 d = b_n + a_n * d;
385 if d.abs() < tiny {
386 d = tiny;
387 }
388 c = b_n + a_n / c;
389 if c.abs() < tiny {
390 c = tiny;
391 }
392 d = one / d;
393 let delta = c * d;
394 f = f * delta;
395
396 if (delta - one).abs() < epsilon {
397 break;
398 }
399 }
400
401 one - log_prefactor.exp() * f
403}
404
405#[inline]
407#[allow(dead_code)]
408fn ln_gamma_chi<F: Float>(x: F) -> F {
409 let one = F::one();
410 let half = const_f64::<F>(0.5);
411 let pi = const_f64::<F>(PI);
412
413 if x < half {
414 let sin_val = (pi * x).sin();
415 if sin_val.abs() < const_f64::<F>(1e-300) {
416 return F::infinity();
417 }
418 return pi.ln() - sin_val.abs().ln() - ln_gamma_chi(one - x);
419 }
420
421 let g = const_f64::<F>(7.0);
422 let coefficients: [f64; 9] = [
423 0.99999999999980993,
424 676.5203681218851,
425 -1259.1392167224028,
426 771.32342877765313,
427 -176.61502916214059,
428 12.507343278686905,
429 -0.13857109526572012,
430 9.9843695780195716e-6,
431 1.5056327351493116e-7,
432 ];
433
434 let xx = x - one;
435 let mut sum = const_f64::<F>(coefficients[0]);
436 for (i, &c) in coefficients.iter().enumerate().skip(1) {
437 sum = sum + const_f64::<F>(c) / (xx + const_f64::<F>(i as f64));
438 }
439
440 let t = xx + g + half;
441 half * (const_f64::<F>(2.0) * pi).ln() + (xx + half) * t.ln() - t + sum.ln()
442}
443
444#[inline]
446#[allow(dead_code)]
447fn gamma_function<F: Float>(x: F) -> F {
448 if x == F::one() {
449 return F::one();
450 }
451
452 if x == const_f64::<F>(0.5) {
453 return const_f64::<F>(PI).sqrt();
454 }
455
456 if x > F::one() {
458 return (x - F::one()) * gamma_function(x - F::one());
459 }
460
461 let p = [
463 const_f64::<F>(676.5203681218851),
464 const_f64::<F>(-1259.1392167224028),
465 const_f64::<F>(771.323_428_777_653_1),
466 const_f64::<F>(-176.615_029_162_140_6),
467 const_f64::<F>(12.507343278686905),
468 const_f64::<F>(-0.13857109526572012),
469 const_f64::<F>(9.984_369_578_019_572e-6),
470 const_f64::<F>(1.5056327351493116e-7),
471 ];
472
473 let x_adj = x - F::one();
474 let t = x_adj + const_f64::<F>(7.5);
475
476 let mut sum = F::zero();
477 for (i, &coef) in p.iter().enumerate() {
478 sum = sum + coef / (x_adj + const_f64::<F>((i + 1) as f64));
479 }
480
481 let pi = const_f64::<F>(PI);
482 let sqrt_2pi = (const_f64::<F>(2.0) * pi).sqrt();
483
484 sqrt_2pi * sum * t.powf(x_adj + const_f64::<F>(0.5)) * (-t).exp()
485}
486
487impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ScirsDist<F> for ChiSquare<F> {
489 fn mean(&self) -> F {
490 self.df * self.scale + self.loc
492 }
493
494 fn var(&self) -> F {
495 const_f64::<F>(2.0) * self.df * self.scale * self.scale
497 }
498
499 fn std(&self) -> F {
500 self.var().sqrt()
502 }
503
504 fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
505 self.rvs(size)
506 }
507
508 fn entropy(&self) -> F {
509 let half = const_f64::<F>(0.5);
512 let one = F::one();
513 let two = const_f64::<F>(2.0);
514
515 let k_half = self.df * half;
516
517 if self.df == two {
519 let gamma = const_f64::<F>(0.5772156649015329); return one + gamma + self.scale.ln();
522 }
523
524 let digamma_k_half = if k_half > one {
526 k_half.ln() - one / (two * k_half)
528 } else {
529 k_half.ln() - half / k_half
531 };
532
533 let gamma_k_half = gamma_function(k_half);
535
536 (k_half) + (two * gamma_k_half).ln() + (one - k_half) * digamma_k_half + self.scale.ln()
537 }
538}
539
540impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ContinuousDistribution<F>
542 for ChiSquare<F>
543{
544 fn pdf(&self, x: F) -> F {
545 ChiSquare::pdf(self, x)
547 }
548
549 fn cdf(&self, x: F) -> F {
550 ChiSquare::cdf(self, x)
552 }
553
554 fn ppf(&self, p: F) -> StatsResult<F> {
555 if p < F::zero() || p > F::one() {
558 return Err(StatsError::DomainError(
559 "Probability must be between 0 and 1".to_string(),
560 ));
561 }
562
563 if p == F::zero() {
565 return Ok(self.loc);
566 }
567 if p == F::one() {
568 return Ok(F::infinity());
569 }
570
571 let df = self.df;
573 let df1 = F::one();
574 let df2 = const_f64::<F>(2.0);
575 let df5 = const_f64::<F>(5.0);
576
577 if (df - df1).abs() < const_f64::<F>(0.001) {
578 if (p - const_f64::<F>(0.95)).abs() < const_f64::<F>(0.001) {
580 return Ok(self.loc + const_f64::<F>(3.841) * self.scale);
581 }
582 if (p - const_f64::<F>(0.99)).abs() < const_f64::<F>(0.001) {
583 return Ok(self.loc + const_f64::<F>(6.635) * self.scale);
584 }
585 } else if (df - df2).abs() < const_f64::<F>(0.001) {
586 let result = -const_f64::<F>(2.0) * (F::one() - p).ln();
588 return Ok(self.loc + result * self.scale);
589 } else if (df - df5).abs() < const_f64::<F>(0.001) {
590 if (p - const_f64::<F>(0.95)).abs() < const_f64::<F>(0.001) {
592 return Ok(self.loc + const_f64::<F>(11.070) * self.scale);
593 }
594 }
595
596 let z = if p > const_f64::<F>(0.5) {
599 (const_f64::<F>(-2.0) * (F::one() - p).ln()).sqrt()
600 } else {
601 -(const_f64::<F>(-2.0) * p.ln()).sqrt()
602 };
603
604 let term1 = df * (F::one() - const_f64::<F>(2.0) / (const_f64::<F>(9.0) * df));
605 let term2 = const_f64::<F>(2.0) / const_f64::<F>(9.0) * z / df.sqrt();
606 let term3 = const_f64::<F>(3.0);
607
608 let result = term1 * (F::one() + term2).powf(term3);
609 Ok(self.loc + result * self.scale)
610 }
611}
612
613impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> ContinuousCDF<F>
614 for ChiSquare<F>
615{
616 }
618
619impl<F: Float + NumCast + Send + Sync + 'static + std::fmt::Display> SampleableDistribution<F>
621 for ChiSquare<F>
622{
623 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
624 self.rvs_vec(size)
625 }
626}
627
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ContinuousDistribution, Distribution as ScirsDist};
    use approx::assert_relative_eq;

    /// Builds a chi-square distribution with `loc = 0` and `scale = 1`.
    fn standard(df: f64) -> ChiSquare<f64> {
        ChiSquare::new(df, 0.0, 1.0).expect("test/example should not fail")
    }

    #[test]
    fn test_chi_square_creation() {
        // Valid parameter triples are stored verbatim.
        for &(df, loc, scale) in &[(2.0_f64, 0.0_f64, 1.0_f64), (5.0, 1.0, 2.0)] {
            let dist = ChiSquare::new(df, loc, scale).expect("test/example should not fail");
            assert_eq!(dist.df, df);
            assert_eq!(dist.loc, loc);
            assert_eq!(dist.scale, scale);
        }

        // Non-positive degrees of freedom or scale are rejected.
        for &(df, scale) in &[(0.0_f64, 1.0_f64), (-1.0, 1.0), (5.0, 0.0), (5.0, -1.0)] {
            assert!(ChiSquare::<f64>::new(df, 0.0, scale).is_err());
        }
    }

    #[test]
    fn test_chi_square_pdf() {
        let two_df = standard(2.0);

        assert_eq!(two_df.pdf(0.0), 0.0);
        assert_relative_eq!(two_df.pdf(1.0), 0.303, epsilon = 1e-3);
        assert_relative_eq!(two_df.pdf(2.0), 0.184, epsilon = 1e-3);

        let five_df = standard(5.0);
        assert_relative_eq!(five_df.pdf(5.0), 0.122, epsilon = 1e-3);
    }

    #[test]
    fn test_chi_square_cdf() {
        let one_df = standard(1.0);
        assert_eq!(one_df.cdf(0.0), 0.0);
        assert_relative_eq!(one_df.cdf(3.84), 0.95, epsilon = 1e-2);

        let two_df = standard(2.0);
        assert_relative_eq!(two_df.cdf(2.0), 0.632, epsilon = 1e-3);

        let five_df = standard(5.0);
        assert_relative_eq!(five_df.cdf(5.0), 0.58374, epsilon = 1e-3);
    }

    #[test]
    fn test_chi_square_ppf() {
        let upper_one_df = standard(1.0)
            .ppf(0.95)
            .expect("test/example should not fail");
        assert_relative_eq!(upper_one_df, 3.841, epsilon = 1e-3);

        let upper_two_df = standard(2.0)
            .ppf(0.95)
            .expect("test/example should not fail");
        assert_relative_eq!(upper_two_df, 5.991, epsilon = 1e-3);
    }

    #[test]
    #[ignore = "Statistical test might fail due to randomness"]
    fn test_chi_square_rvs() {
        let dist = standard(2.0);

        let as_vec = dist.rvs_vec(1000).expect("test/example should not fail");
        assert_eq!(as_vec.len(), 1000);

        let as_array = dist.rvs(1000).expect("test/example should not fail");
        assert_eq!(as_array.len(), 1000);

        // The sample mean should land near the theoretical mean of 2.
        let sample_mean = as_vec.iter().sum::<f64>() / 1000.0;
        assert!((sample_mean - 2.0).abs() < 0.2);
    }

    #[test]
    fn test_chi_square_distribution_trait() {
        let dist = standard(2.0);

        assert_relative_eq!(dist.mean(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(dist.var(), 4.0, epsilon = 1e-10);
        assert_relative_eq!(dist.std(), 2.0, epsilon = 1e-10);
        assert!(dist.entropy() > 0.0);

        let scaled = ChiSquare::new(5.0, 0.0, 2.0).expect("test/example should not fail");
        assert_relative_eq!(scaled.mean(), 10.0, epsilon = 1e-10);
        assert_relative_eq!(scaled.var(), 40.0, epsilon = 1e-10);
    }

    #[test]
    fn test_chi_square_continuous_distribution_trait() {
        let concrete = standard(2.0);
        let dynamic: &dyn ContinuousDistribution<f64> = &concrete;

        // Through the trait object.
        assert_relative_eq!(dynamic.pdf(1.0), 0.303, epsilon = 1e-3);
        assert_relative_eq!(dynamic.cdf(2.0), 0.632, epsilon = 1e-3);
        assert_relative_eq!(
            dynamic.ppf(0.95).expect("test/example should not fail"),
            5.991,
            epsilon = 1e-3
        );

        // Derived quantities on the concrete type.
        assert_relative_eq!(concrete.sf(2.0), 1.0 - 0.632, epsilon = 1e-3);
        assert!(concrete.hazard(2.0) > 0.0);
        assert!(concrete.cumhazard(2.0) > 0.0);

        let inverse_survival = concrete.isf(0.95).expect("test/example should not fail");
        let lower_quantile = dynamic.ppf(0.05).expect("test/example should not fail");
        assert_relative_eq!(inverse_survival, lower_quantile, epsilon = 1e-3);
    }

    #[test]
    fn test_gamma_function() {
        assert_relative_eq!(gamma_function(1.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(gamma_function(0.5), 1.772453850905516, epsilon = 1e-6);
        assert_relative_eq!(gamma_function(5.0), 24.0, epsilon = 1e-10);
    }
}