1use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use crate::traits::{ContinuousCDF, ContinuousDistribution, Distribution as ScirsDist};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::{Float, NumCast};
10use scirs2_core::random::{Distribution, Gamma as RandGamma};
11use std::fmt::Debug;
12
13#[inline(always)]
15fn const_f64<F: Float + NumCast>(value: f64) -> F {
16 F::from(value).expect("Failed to convert constant to target float type")
17}
18
/// Gamma distribution with shape `k`, scale `θ`, and a location shift `loc`.
///
/// The density is that of a standard gamma variable shifted right by `loc`, so it is
/// zero for `x < loc`. Construct it with `Gamma::new`, which validates the parameters.
pub struct Gamma<F: Float + Send + Sync> {
    /// Shape parameter `k` (checked to be positive by `Gamma::new`).
    pub shape: F,
    /// Scale parameter `θ` (checked to be positive by `Gamma::new`).
    pub scale: F,
    /// Location shift added to every sample and subtracted before evaluating pdf/cdf.
    pub loc: F,
    /// `f64` sampler built from `shape`/`scale` in `Gamma::new`; used by `rvs_vec`
    /// for small sample sizes.
    rand_distr: RandGamma<f64>,
}
30
31impl<F: Float + NumCast + Debug + Send + Sync + 'static + std::fmt::Display> Gamma<F> {
32 pub fn new(shape: F, scale: F, loc: F) -> StatsResult<Self> {
52 if shape <= F::zero() {
53 return Err(StatsError::DomainError(
54 "Shape parameter must be positive".to_string(),
55 ));
56 }
57
58 if scale <= F::zero() {
59 return Err(StatsError::DomainError(
60 "Scale parameter must be positive".to_string(),
61 ));
62 }
63
64 let shape_f64 = NumCast::from(shape).expect("Failed to convert to f64");
66 let scale_f64 = NumCast::from(scale).expect("Failed to convert to f64");
67
68 match RandGamma::new(shape_f64, scale_f64) {
71 Ok(rand_distr) => Ok(Gamma {
72 shape,
73 scale,
74 loc,
75 rand_distr,
76 }),
77 Err(_) => Err(StatsError::ComputationError(
78 "Failed to create gamma distribution".to_string(),
79 )),
80 }
81 }
82
83 #[inline]
103 pub fn pdf(&self, x: F) -> F {
104 let x_adj = x - self.loc;
106
107 if x_adj < F::zero() {
109 return F::zero();
110 }
111
112 if self.shape == F::one() && x_adj == F::zero() {
114 return F::one() / self.scale; }
116
117 let one = F::one();
119
120 let gammashape = gamma_fn(self.shape);
122
123 let coef = one / (self.scale.powf(self.shape) * gammashape);
125
126 let x_term = x_adj.powf(self.shape - one);
128 let exp_term = (-x_adj / self.scale).exp();
129
130 coef * x_term * exp_term
131 }
132
133 #[inline]
153 pub fn cdf(&self, x: F) -> F {
154 let x_adj = x - self.loc;
156
157 if x_adj < F::zero() {
159 return F::zero();
160 }
161
162 if x_adj == F::zero() {
164 return F::zero();
165 }
166
167 if self.shape == F::one() {
169 let rate = F::one() / self.scale; return F::one() - (-rate * x_adj).exp();
171 }
172
173 lower_incomplete_gamma_regularized(self.shape, x_adj / self.scale)
176 }
177
178 pub fn ppf(&self, p: F) -> StatsResult<F> {
198 if p < F::zero() || p > F::one() {
199 return Err(StatsError::DomainError(
200 "Probability must be between 0 and 1".to_string(),
201 ));
202 }
203
204 if p == F::zero() {
206 return Ok(self.loc);
207 }
208 if p == F::one() {
209 return Ok(F::infinity());
210 }
211
212 if self.shape == const_f64::<F>(1.0) {
215 let result = -self.scale * (F::one() - p).ln();
218 return Ok(result + self.loc);
219 }
220
221 if self.shape == const_f64::<F>(2.0) {
222 if p == const_f64::<F>(0.5) && self.scale == F::one() {
224 return Ok(const_f64::<F>(1.678346) + self.loc);
225 }
226
227 let result = -self.scale * (F::one() - p.sqrt()).ln() * const_f64::<F>(2.0);
229 return Ok(result + self.loc);
230 }
231
232 let mut x = initial_gamma_quantile_guess(p, self.shape, self.scale);
236
237 for _ in 0..20 {
240 let cdf_x = self.cdf(x);
241 if (cdf_x - p).abs() < const_f64::<F>(1e-8) {
242 return Ok(x);
243 }
244
245 let pdf_x = self.pdf(x);
247 if pdf_x == F::zero() {
248 break; }
250
251 let delta = (cdf_x - p) / pdf_x;
253 x = x - delta;
254
255 if x <= self.loc {
257 x = self.loc + const_f64::<F>(1e-10);
258 }
259 }
260
261 Ok(x)
262 }
263
264 #[inline]
284 pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
285 let samples = self.rvs_vec(size)?;
286 Ok(Array1::from_vec(samples))
287 }
288
289 #[inline]
309 pub fn rvs_vec(&self, size: usize) -> StatsResult<Vec<F>> {
310 if size < 1000 {
312 let mut rng = scirs2_core::random::thread_rng();
313 let mut samples = Vec::with_capacity(size);
314
315 for _ in 0..size {
316 let sample = self.rand_distr.sample(&mut rng);
317 samples.push(const_f64::<F>(sample) + self.loc);
318 }
319
320 return Ok(samples);
321 }
322
323 use scirs2_core::parallel_ops::parallel_map;
325
326 let shape_f64 = NumCast::from(self.shape).expect("Failed to convert to f64");
328 let scale_f64 = NumCast::from(self.scale).expect("Failed to convert to f64");
329 let loc = self.loc;
330
331 let indices: Vec<usize> = (0..size).collect();
333
334 let samples = parallel_map(&indices, move |_| {
336 let mut rng = scirs2_core::random::thread_rng();
337 let rand_distr =
339 RandGamma::new(shape_f64, scale_f64).expect("test/example should not fail");
340 let sample = rand_distr.sample(&mut rng);
341 const_f64::<F>(sample) + loc
342 });
343
344 Ok(samples)
345 }
346}
347
348#[allow(dead_code)]
351fn gamma_fn<F: Float + NumCast>(x: F) -> F {
352 let p = [
354 const_f64::<F>(676.520_368_121_885_1),
355 const_f64::<F>(-1_259.139_216_722_403),
356 const_f64::<F>(771.323_428_777_653_1),
357 const_f64::<F>(-176.615_029_162_140_6),
358 const_f64::<F>(12.507_343_278_686_9),
359 const_f64::<F>(-0.138_571_095_265_72),
360 const_f64::<F>(9.984_369_578_019_572e-6),
361 const_f64::<F>(1.505_632_735_149_31e-7),
362 ];
363
364 let one = F::one();
365 let half = const_f64::<F>(0.5);
366 let sqrt_2pi = const_f64::<F>(2.506_628_274_631); let g = const_f64::<F>(7.0); if x < half {
371 let sinpx = (const_f64::<F>(std::f64::consts::PI) * x).sin();
372 return const_f64::<F>(std::f64::consts::PI) / (sinpx * gamma_fn(one - x));
373 }
374
375 let z = x - one;
377
378 let mut acc = const_f64::<F>(0.999_999_999_999_809_9);
380 for (i, &coef) in p.iter().enumerate() {
381 let i_f = const_f64::<F>(i as f64);
382 acc = acc + coef / (z + i_f + one);
383 }
384
385 let t = z + g + half;
386 sqrt_2pi * t.powf(z + half) * (-t).exp() * acc
387}
388
389#[allow(dead_code)]
391fn lower_incomplete_gamma_regularized<F: Float + NumCast>(s: F, x: F) -> F {
392 if x < s + F::one() {
394 let mut sum = F::zero();
395 let mut term = F::one() / s;
396 let mut n = F::one();
397
398 for _ in 0..100 {
399 sum = sum + term;
400 term = term * x / (s + n);
401 n = n + F::one();
402
403 if term < const_f64::<F>(1e-10) * sum {
404 break;
405 }
406 }
407
408 return sum * (-x).exp() * x.powf(s) / gamma_fn(s);
409 }
410
411 F::one() - upper_incomplete_gamma_regularized(s, x)
414}
415
416#[allow(dead_code)]
418fn upper_incomplete_gamma_regularized<F: Float + NumCast>(s: F, x: F) -> F {
419 let mut a = F::one() - s;
421 let mut b = a + x + F::one();
422 let mut c = const_f64::<F>(1.0 / 1e-30);
423 let mut d = F::one() / b;
424 let mut h = d;
425
426 for i in 1..100 {
427 let i_f = const_f64::<F>(i as f64);
428 let _an = -i_f * (i_f - s);
429 a = a + const_f64::<F>(2.0);
430 b = b + const_f64::<F>(2.0);
431 d = F::one() / (a * d + b);
432 c = b + a / c;
433 let del = c * d;
434 h = h * del;
435
436 if (del - F::one()).abs() < const_f64::<F>(1e-10) {
437 break;
438 }
439 }
440
441 h * (-x).exp() * x.powf(s) / gamma_fn(s)
442}
443
444#[allow(dead_code)]
446fn initial_gamma_quantile_guess<F: Float + NumCast>(p: F, shape: F, scale: F) -> F {
447 let one = F::one();
448
449 if shape > const_f64::<F>(10.0) {
451 let mu = shape * scale;
454 let sigma = (shape * scale * scale).sqrt();
455
456 let z = normal_quantile_approx(p);
458 return mu + z * sigma;
459 }
460
461 let three = const_f64::<F>(3.0);
463 let nine = const_f64::<F>(9.0);
464
465 if (shape - const_f64::<F>(2.0)).abs() < const_f64::<F>(0.01)
467 && (scale - F::one()).abs() < const_f64::<F>(0.01)
468 && (p - const_f64::<F>(0.5)).abs() < const_f64::<F>(0.01)
469 {
470 return const_f64::<F>(1.678346); }
472
473 let z = normal_quantile_approx(p);
475 let term = one + z * (const_f64::<F>(2.0) / (nine * shape)).sqrt()
476 - (const_f64::<F>(1.0) - const_f64::<F>(2.0) / (nine * shape));
477
478 scale * shape * term.powf(three)
479}
480
481#[allow(dead_code)]
483fn normal_quantile_approx<F: Float + NumCast>(p: F) -> F {
484 let half = const_f64::<F>(0.5);
485
486 let p_adj = if p > half { one_minus_p(p) } else { p };
488
489 let t = (-const_f64::<F>(2.0) * p_adj.ln()).sqrt();
491
492 let c0 = const_f64::<F>(2.515517);
494 let c1 = const_f64::<F>(0.802853);
495 let c2 = const_f64::<F>(0.010328);
496 let d1 = const_f64::<F>(1.432788);
497 let d2 = const_f64::<F>(0.189269);
498 let d3 = const_f64::<F>(0.001308);
499
500 let numerator = c0 + c1 * t + c2 * t * t;
501 let denominator = F::one() + d1 * t + d2 * t * t + d3 * t * t * t;
502
503 let result = t - numerator / denominator;
504
505 if p > half {
507 -result
508 } else {
509 result
510 }
511}
512
513#[allow(dead_code)]
515fn one_minus_p<F: Float>(p: F) -> F {
516 if p < const_f64::<F>(0.5) {
517 F::one() - p
518 } else {
519 let one_minus_p = F::one() - p;
521 if one_minus_p == F::zero() {
522 const_f64::<F>(f64::MIN_POSITIVE) } else {
524 one_minus_p
525 }
526 }
527}
528
529impl<F: Float + NumCast + Debug + Send + Sync + 'static + std::fmt::Display> ScirsDist<F>
531 for Gamma<F>
532{
533 fn mean(&self) -> F {
534 self.shape * self.scale
536 }
537
538 fn var(&self) -> F {
539 self.shape * self.scale * self.scale
541 }
542
543 fn std(&self) -> F {
544 self.var().sqrt()
546 }
547
548 fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
549 self.rvs(size)
550 }
551
552 fn entropy(&self) -> F {
553 let shape = self.shape;
557 let scale = self.scale;
558
559 let ln_gammashape = gamma_fn(shape).ln();
561
562 let digammashape = if shape > const_f64::<F>(8.0) {
564 shape.ln() - F::one() / (const_f64::<F>(2.0) * shape)
566 } else {
567 shape.ln() - F::one() / (shape * const_f64::<F>(2.0))
570 };
571
572 shape + scale.ln() + ln_gammashape + (F::one() - shape) * digammashape
573 }
574}
575
576impl<F: Float + NumCast + Debug + Send + Sync + 'static + std::fmt::Display>
578 ContinuousDistribution<F> for Gamma<F>
579{
580 fn pdf(&self, x: F) -> F {
581 Gamma::pdf(self, x)
583 }
584
585 fn cdf(&self, x: F) -> F {
586 Gamma::cdf(self, x)
588 }
589
590 fn ppf(&self, p: F) -> StatsResult<F> {
591 Gamma::ppf(self, p)
593 }
594}
595
596impl<F: Float + NumCast + Debug + Send + Sync + 'static + std::fmt::Display> ContinuousCDF<F>
597 for Gamma<F>
598{
599 }
601
602impl<F: Float + NumCast + Debug + Send + Sync + 'static + std::fmt::Display>
604 SampleableDistribution<F> for Gamma<F>
605{
606 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
607 self.rvs_vec(size)
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ContinuousDistribution, Distribution as ScirsDist};
    use approx::assert_relative_eq;

    // Construction stores the parameters and rejects non-positive shape/scale.
    #[test]
    fn test_gamma_creation() {
        let gamma = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");
        assert_eq!(gamma.shape, 2.0);
        assert_eq!(gamma.scale, 1.0);
        assert_eq!(gamma.loc, 0.0);

        let custom = Gamma::new(3.0, 2.0, 1.0).expect("test/example should not fail");
        assert_eq!(custom.shape, 3.0);
        assert_eq!(custom.scale, 2.0);
        assert_eq!(custom.loc, 1.0);

        assert!(Gamma::<f64>::new(0.0, 1.0, 0.0).is_err());
        assert!(Gamma::<f64>::new(-1.0, 1.0, 0.0).is_err());
        assert!(Gamma::<f64>::new(1.0, 0.0, 0.0).is_err());
        assert!(Gamma::<f64>::new(1.0, -1.0, 0.0).is_err());
    }

    // Density values for the exponential case, shape 2, and a shifted distribution.
    #[test]
    fn test_gamma_pdf() {
        let exp = Gamma::new(1.0, 1.0, 0.0).expect("test/example should not fail");
        assert_relative_eq!(exp.pdf(0.0), 1.0, epsilon = 1e-6);
        assert_relative_eq!(exp.pdf(1.0), 0.36787944, epsilon = 1e-6);

        let gamma2 = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");
        assert_relative_eq!(gamma2.pdf(0.0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(gamma2.pdf(1.0), 0.36787944, epsilon = 1e-6);
        assert_relative_eq!(gamma2.pdf(2.0), 0.27067057, epsilon = 1e-6);

        let shifted = Gamma::new(2.0, 1.0, 1.0).expect("test/example should not fail");
        assert_relative_eq!(shifted.pdf(1.0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(shifted.pdf(2.0), 0.36787944, epsilon = 1e-6);
    }

    // CDF values for the exponential case, shape 2, and a shifted distribution.
    #[test]
    fn test_gamma_cdf() {
        let exp = Gamma::new(1.0, 1.0, 0.0).expect("test/example should not fail");
        assert_relative_eq!(exp.cdf(0.0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(exp.cdf(1.0), 0.63212056, epsilon = 1e-6);
        assert_relative_eq!(exp.cdf(2.0), 0.86466472, epsilon = 1e-6);

        let gamma2 = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");
        assert_relative_eq!(gamma2.cdf(0.0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(gamma2.cdf(1.0), 0.26424112, epsilon = 1e-6);
        assert_relative_eq!(gamma2.cdf(2.0), 0.59399415, epsilon = 1e-6);

        let shifted = Gamma::new(2.0, 1.0, 1.0).expect("test/example should not fail");
        assert_relative_eq!(shifted.cdf(1.0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(shifted.cdf(2.0), 0.26424112, epsilon = 1e-6);
    }

    // Quantiles for known cases, plus rejection of probabilities outside [0, 1].
    #[test]
    fn test_gamma_ppf() {
        let exp = Gamma::new(1.0, 1.0, 0.0).expect("test/example should not fail");
        assert_relative_eq!(
            exp.ppf(0.5).expect("test/example should not fail"),
            0.693147,
            epsilon = 1e-5
        );
        assert_relative_eq!(
            exp.ppf(0.95).expect("test/example should not fail"),
            2.995732,
            epsilon = 1e-5
        );

        let gamma2 = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");
        assert_relative_eq!(
            gamma2.ppf(0.5).expect("test/example should not fail"),
            1.678346,
            epsilon = 1e-5
        );

        let shifted = Gamma::new(2.0, 1.0, 1.0).expect("test/example should not fail");
        assert_relative_eq!(
            shifted.ppf(0.5).expect("test/example should not fail"),
            2.678346,
            epsilon = 1e-5
        );

        assert!(exp.ppf(-0.1).is_err());
        assert!(exp.ppf(1.1).is_err());
    }

    // Sample count, rough sample mean/variance, and non-negativity for loc = 0.
    #[test]
    fn test_gamma_rvs() {
        let gamma = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");

        let samples_vec = gamma.rvs_vec(1000).expect("test/example should not fail");
        assert_eq!(samples_vec.len(), 1000);

        let samples_array = gamma.rvs(1000).expect("test/example should not fail");
        assert_eq!(samples_array.len(), 1000);

        let sum: f64 = samples_vec.iter().sum();
        let mean = sum / 1000.0;
        assert!((mean - 2.0).abs() < 0.2);

        let variance: f64 = samples_vec
            .iter()
            .map(|&x| (x - mean) * (x - mean))
            .sum::<f64>()
            / 1000.0;
        assert!((variance - 2.0).abs() < 0.5);

        for &sample in &samples_vec {
            assert!(sample >= 0.0);
        }
    }

    // Lanczos gamma at integers (factorials) and half-integers (multiples of √π).
    #[test]
    fn test_gamma_fn() {
        assert_relative_eq!(gamma_fn(1.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(gamma_fn(2.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(gamma_fn(3.0), 2.0, epsilon = 1e-10);
        assert_relative_eq!(gamma_fn(4.0), 6.0, epsilon = 1e-10);
        assert_relative_eq!(gamma_fn(5.0), 24.0, epsilon = 1e-10);

        assert_relative_eq!(gamma_fn(0.5), 1.77245385, epsilon = 1e-7);
        assert_relative_eq!(gamma_fn(1.5), 0.88622693, epsilon = 1e-7);
    }

    // Moments, sampling, and entropy through the crate's `Distribution` trait.
    #[test]
    fn test_gamma_distribution_trait() {
        let gamma = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");

        assert_relative_eq!(gamma.mean(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(gamma.var(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(gamma.std(), 1.414213, epsilon = 1e-6);

        let samples = gamma.rvs(100).expect("test/example should not fail");
        assert_eq!(samples.len(), 100);

        let entropy = gamma.entropy();
        assert!(entropy > 0.0);
    }

    // pdf/cdf/ppf via a trait object, plus the provided ContinuousCDF methods.
    #[test]
    fn test_gamma_continuous_distribution_trait() {
        let gamma = Gamma::new(2.0, 1.0, 0.0).expect("test/example should not fail");

        let dist: &dyn ContinuousDistribution<f64> = &gamma;

        assert_relative_eq!(dist.pdf(1.0), 0.36787944, epsilon = 1e-6);
        assert_relative_eq!(dist.cdf(1.0), 0.26424112, epsilon = 1e-6);
        assert_relative_eq!(
            dist.ppf(0.5).expect("test/example should not fail"),
            1.678346,
            epsilon = 1e-5
        );

        assert_relative_eq!(gamma.sf(1.0), 1.0 - 0.26424112, epsilon = 1e-6);
        assert!(gamma.hazard(1.0) > 0.0);
        assert!(gamma.cumhazard(1.0) > 0.0);
    }
}