1use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use crate::traits::{ContinuousCDF, ContinuousDistribution, Distribution as ScirsDist};
8use scirs2_core::ndarray::Array1;
9use scirs2_core::numeric::{Float, NumCast};
10use scirs2_core::random::{Beta as RandBeta, Distribution};
11use std::fmt::Debug;
12
13#[inline(always)]
15fn const_f64<F: Float + NumCast>(value: f64) -> F {
16 F::from(value).expect("Failed to convert constant to target float type")
17}
18
/// Beta distribution with shape parameters `alpha` and `beta`, shifted by `loc`
/// and stretched by `scale`.
///
/// The standardized variable `(x - loc) / scale` follows Beta(alpha, beta) on
/// `[0, 1]`, so the support of this distribution is `[loc, loc + scale]`.
pub struct Beta<F: Float> {
    /// First shape parameter; `Beta::new` rejects non-positive values.
    pub alpha: F,
    /// Second shape parameter; `Beta::new` rejects non-positive values.
    pub beta: F,
    /// Location: the left end of the support.
    pub loc: F,
    /// Scale: the width of the support; `Beta::new` rejects non-positive values.
    pub scale: F,
    /// Sampler from the random backend, built once in `Beta::new` from `alpha`
    /// and `beta` converted to `f64`.
    /// NOTE(review): `alpha`/`beta` are public, but changing them after
    /// construction does not rebuild this sampler.
    rand_distr: RandBeta,
}
32
33impl<F: Float + NumCast + Debug + std::fmt::Display> Beta<F> {
34 pub fn new(alpha: F, beta: F, loc: F, scale: F) -> StatsResult<Self> {
55 if alpha <= F::zero() {
56 return Err(StatsError::DomainError(
57 "Alpha parameter must be positive".to_string(),
58 ));
59 }
60
61 if beta <= F::zero() {
62 return Err(StatsError::DomainError(
63 "Beta parameter must be positive".to_string(),
64 ));
65 }
66
67 if scale <= F::zero() {
68 return Err(StatsError::DomainError(
69 "Scale parameter must be positive".to_string(),
70 ));
71 }
72
73 let alpha_f64 = NumCast::from(alpha).expect("Failed to convert to f64");
75 let beta_f64 = NumCast::from(beta).expect("Failed to convert to f64");
76
77 match RandBeta::new(alpha_f64, beta_f64) {
78 Ok(rand_distr) => Ok(Beta {
79 alpha,
80 beta,
81 loc,
82 scale,
83 rand_distr,
84 }),
85 Err(_) => Err(StatsError::ComputationError(
86 "Failed to create beta distribution".to_string(),
87 )),
88 }
89 }
90
91 pub fn pdf(&self, x: F) -> F {
112 let x_adj = (x - self.loc) / self.scale;
114
115 if self.alpha == F::one() && self.beta == F::one() {
118 if x_adj < F::zero() || x_adj > F::one() {
119 return F::zero();
120 }
121 return F::one() / self.scale;
122 }
123
124 if x_adj < F::zero() || x_adj > F::one() {
126 return F::zero();
127 }
128
129 let one = F::one();
132
133 let numerator = x_adj.powf(self.alpha - one) * (one - x_adj).powf(self.beta - one);
135 let denominator = beta_function(self.alpha, self.beta);
136
137 numerator / (denominator * self.scale)
139 }
140
141 pub fn cdf(&self, x: F) -> F {
161 let x_adj = (x - self.loc) / self.scale;
163
164 if x_adj < F::zero() {
166 return F::zero();
167 }
168
169 if x_adj > F::one() {
171 return F::one();
172 }
173
174 if x_adj == F::zero() {
176 return F::zero();
177 }
178 if x_adj == F::one() {
179 return F::one();
180 }
181
182 if self.alpha == F::one() && self.beta == F::one() {
184 return x_adj; }
186
187 if (self.alpha - const_f64::<F>(2.0)).abs() < const_f64::<F>(1e-10)
191 && (self.beta - const_f64::<F>(2.0)).abs() < const_f64::<F>(1e-10)
192 && (x_adj - const_f64::<F>(0.5)).abs() < const_f64::<F>(1e-10)
193 {
194 return const_f64::<F>(0.5);
195 }
196
197 regularized_incomplete_beta(x_adj, self.alpha, self.beta)
198 }
199
200 pub fn ppf(&self, p: F) -> StatsResult<F> {
220 if p < F::zero() || p > F::one() {
221 return Err(StatsError::DomainError(
222 "Probability must be between 0 and 1".to_string(),
223 ));
224 }
225
226 if p == F::zero() {
228 return Ok(self.loc);
229 }
230 if p == F::one() {
231 return Ok(self.loc + self.scale);
232 }
233
234 if self.alpha == self.beta {
236 if p == const_f64::<F>(0.5) {
238 return Ok(self.loc + self.scale * const_f64::<F>(0.5));
239 }
240 }
241
242 let eps = const_f64::<F>(1e-12);
244 let mut lo = const_f64::<F>(1e-15);
245 let mut hi = F::one() - const_f64::<F>(1e-15);
246
247 for _ in 0..100 {
249 let mid = (lo + hi) * const_f64::<F>(0.5);
250 let cdf_mid = regularized_incomplete_beta(mid, self.alpha, self.beta);
251 if (cdf_mid - p).abs() < eps {
252 return Ok(self.loc + mid * self.scale);
253 }
254 if cdf_mid < p {
255 lo = mid;
256 } else {
257 hi = mid;
258 }
259 if (hi - lo) < eps {
260 break;
261 }
262 }
263
264 let x_unit = (lo + hi) * const_f64::<F>(0.5);
265 Ok(self.loc + x_unit * self.scale)
266 }
267
268 pub fn rvs_vec(&self, size: usize) -> StatsResult<Vec<F>> {
288 let mut rng = scirs2_core::random::thread_rng();
289 let mut samples = Vec::with_capacity(size);
290
291 for _ in 0..size {
292 let sample = self.rand_distr.sample(&mut rng);
293 samples.push(const_f64::<F>(sample) * self.scale + self.loc);
294 }
295
296 Ok(samples)
297 }
298
299 pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
319 let samples_vec = self.rvs_vec(size)?;
320 Ok(Array1::from(samples_vec))
321 }
322}
323
324#[allow(dead_code)]
326fn beta_function<F: Float + NumCast>(a: F, b: F) -> F {
327 let ga = gamma_fn(a);
328 let gb = gamma_fn(b);
329 let gab = gamma_fn(a + b);
330
331 ga * gb / gab
332}
333
334#[allow(dead_code)]
337fn gamma_fn<F: Float + NumCast>(x: F) -> F {
338 let p = [
340 const_f64::<F>(676.520_368_121_885_1),
341 const_f64::<F>(-1_259.139_216_722_403),
342 const_f64::<F>(771.323_428_777_653_1),
343 const_f64::<F>(-176.615_029_162_140_6),
344 const_f64::<F>(12.507_343_278_686_9),
345 const_f64::<F>(-0.138_571_095_265_72),
346 const_f64::<F>(9.984_369_578_019_572e-6),
347 const_f64::<F>(1.505_632_735_149_31e-7),
348 ];
349
350 let one = F::one();
351 let half = const_f64::<F>(0.5);
352 let sqrt_2pi = const_f64::<F>(2.506_628_274_631); let g = const_f64::<F>(7.0); if x < half {
357 let sinpx = (const_f64::<F>(std::f64::consts::PI) * x).sin();
358 return const_f64::<F>(std::f64::consts::PI) / (sinpx * gamma_fn(one - x));
359 }
360
361 let z = x - one;
363
364 let mut acc = const_f64::<F>(0.999_999_999_999_809_9);
366 for (i, &coef) in p.iter().enumerate() {
367 let i_f = const_f64::<F>(i as f64);
368 acc = acc + coef / (z + i_f + one);
369 }
370
371 let t = z + g + half;
372 sqrt_2pi * t.powf(z + half) * (-t).exp() * acc
373}
374
375#[allow(dead_code)]
377fn initial_beta_quantile_guess<F: Float + NumCast>(p: F, alpha: F, beta: F) -> F {
378 let zero = F::zero();
379 let one = F::one();
380
381 if alpha == one && beta == one {
383 return p;
385 }
386
387 if alpha > const_f64::<F>(8.0) && beta > const_f64::<F>(8.0) {
389 let mu = alpha / (alpha + beta);
391 let sigma =
392 (alpha * beta / ((alpha + beta) * (alpha + beta) * (alpha + beta + one))).sqrt();
393
394 let z = normal_quantile_approx(p);
395 return (mu + z * sigma).max(zero).min(one);
396 }
397
398 if (alpha - beta).abs() < const_f64::<F>(0.01) {
400 if p <= const_f64::<F>(0.5) {
401 return p.powf(one / alpha);
402 } else {
403 return one - (one - p).powf(one / alpha);
404 }
405 }
406
407 if alpha == one && beta == one {
409 return p;
410 }
411
412 if p < const_f64::<F>(0.5) {
414 let approx = p.powf(one / alpha);
416 approx
417 .max(const_f64::<F>(1e-10))
418 .min(one - const_f64::<F>(1e-10))
419 } else {
420 let approx = one - ((one - p).powf(one / beta));
422 approx
423 .max(const_f64::<F>(1e-10))
424 .min(one - const_f64::<F>(1e-10))
425 }
426}
427
428#[allow(dead_code)]
431fn regularized_incomplete_beta<F: Float + NumCast>(x: F, a: F, b: F) -> F {
432 if x <= F::zero() {
433 return F::zero();
434 }
435 if x >= F::one() {
436 return F::one();
437 }
438
439 let one = F::one();
440 let two = const_f64::<F>(2.0);
441 let epsilon = const_f64::<F>(1e-14);
442 let tiny = const_f64::<F>(1e-30);
443 let max_iterations = 300;
444
445 let threshold = (a + one) / (a + b + two);
448 let use_symmetry = x > threshold;
449
450 let (x_cf, a_cf, b_cf) = if use_symmetry {
451 (one - x, b, a)
452 } else {
453 (x, a, b)
454 };
455
456 let ln_prefactor =
459 a_cf * x_cf.ln() + b_cf * (one - x_cf).ln() - a_cf.ln() - ln_beta_fn(a_cf, b_cf);
460 let prefactor = ln_prefactor.exp();
461
462 let mut f = one;
464 let mut c = one;
465 let mut d = one - (a_cf + b_cf) * x_cf / (a_cf + one);
466 if d.abs() < tiny {
467 d = tiny;
468 }
469 d = one / d;
470 f = d;
471
472 for m in 1..=max_iterations {
473 let m_f = const_f64::<F>(m as f64);
474
475 let two_m = two * m_f;
477 let num_even = m_f * (b_cf - m_f) * x_cf / ((a_cf + two_m - one) * (a_cf + two_m));
478
479 d = one + num_even * d;
480 if d.abs() < tiny {
481 d = tiny;
482 }
483 c = one + num_even / c;
484 if c.abs() < tiny {
485 c = tiny;
486 }
487 d = one / d;
488 let delta = c * d;
489 f = f * delta;
490
491 let num_odd =
493 -(a_cf + m_f) * (a_cf + b_cf + m_f) * x_cf / ((a_cf + two_m) * (a_cf + two_m + one));
494
495 d = one + num_odd * d;
496 if d.abs() < tiny {
497 d = tiny;
498 }
499 c = one + num_odd / c;
500 if c.abs() < tiny {
501 c = tiny;
502 }
503 d = one / d;
504 let delta = c * d;
505 f = f * delta;
506
507 if (delta - one).abs() < epsilon {
508 break;
509 }
510 }
511
512 let result = prefactor * f;
513
514 if use_symmetry {
515 one - result
516 } else {
517 result
518 }
519}
520
521#[allow(dead_code)]
523fn ln_beta_fn<F: Float + NumCast>(a: F, b: F) -> F {
524 ln_gamma_fn(a) + ln_gamma_fn(b) - ln_gamma_fn(a + b)
525}
526
527#[allow(dead_code)]
529fn ln_gamma_fn<F: Float + NumCast>(x: F) -> F {
530 let one = F::one();
531 let half = const_f64::<F>(0.5);
532 let pi = const_f64::<F>(std::f64::consts::PI);
533
534 if x < half {
535 let sin_val = (pi * x).sin();
536 if sin_val == F::zero() {
537 return F::infinity();
538 }
539 return pi.ln() - sin_val.abs().ln() - ln_gamma_fn(one - x);
540 }
541
542 let g = const_f64::<F>(7.0);
543 let coefficients: [f64; 9] = [
544 0.99999999999980993,
545 676.5203681218851,
546 -1259.1392167224028,
547 771.32342877765313,
548 -176.61502916214059,
549 12.507343278686905,
550 -0.13857109526572012,
551 9.9843695780195716e-6,
552 1.5056327351493116e-7,
553 ];
554
555 let xx = x - one;
556 let mut sum = const_f64::<F>(coefficients[0]);
557 for (i, &c) in coefficients.iter().enumerate().skip(1) {
558 sum = sum + const_f64::<F>(c) / (xx + const_f64::<F>(i as f64));
559 }
560
561 let t = xx + g + half;
562 half * (const_f64::<F>(2.0) * pi).ln() + (xx + half) * t.ln() - t + sum.ln()
563}
564
565#[allow(dead_code)]
567fn normal_quantile_approx<F: Float + NumCast>(p: F) -> F {
568 let half = const_f64::<F>(0.5);
569
570 let p_adj = if p > half { one_minus_p(p) } else { p };
572
573 let t = (-const_f64::<F>(2.0) * p_adj.ln()).sqrt();
575
576 let c0 = const_f64::<F>(2.515517);
578 let c1 = const_f64::<F>(0.802853);
579 let c2 = const_f64::<F>(0.010328);
580 let d1 = const_f64::<F>(1.432788);
581 let d2 = const_f64::<F>(0.189269);
582 let d3 = const_f64::<F>(0.001308);
583
584 let numerator = c0 + c1 * t + c2 * t * t;
585 let denominator = F::one() + d1 * t + d2 * t * t + d3 * t * t * t;
586
587 let result = t - numerator / denominator;
588
589 if p > half {
591 -result
592 } else {
593 result
594 }
595}
596
597#[allow(dead_code)]
599fn one_minus_p<F: Float>(p: F) -> F {
600 if p < const_f64::<F>(0.5) {
601 F::one() - p
602 } else {
603 let one_minus_p = F::one() - p;
605 if one_minus_p == F::zero() {
606 const_f64::<F>(f64::MIN_POSITIVE) } else {
608 one_minus_p
609 }
610 }
611}
612
613impl<F: Float + NumCast + Debug + std::fmt::Display> SampleableDistribution<F> for Beta<F> {
615 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
616 self.rvs_vec(size)
617 }
618}
619
620impl<F: Float + NumCast + Debug + std::fmt::Display> ScirsDist<F> for Beta<F> {
622 fn mean(&self) -> F {
624 self.alpha / (self.alpha + self.beta)
626 }
627
628 fn var(&self) -> F {
630 let sum = self.alpha + self.beta;
632 let sum_squared = sum * sum;
633 (self.alpha * self.beta) / (sum_squared * (sum + F::one())) * self.scale * self.scale
634 }
635
636 fn std(&self) -> F {
638 self.var().sqrt()
639 }
640
641 fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
643 self.rvs(size)
644 }
645
646 fn entropy(&self) -> F {
648 let bf = beta_function(self.alpha, self.beta);
654 bf.ln() + (self.scale.ln())
655 }
656}
657
/// `ContinuousDistribution` implementation that forwards to the inherent methods.
///
/// Method-call syntax on `self` resolves to inherent methods before trait methods
/// of the same name, so these calls reach `Beta::pdf`, `Beta::cdf` and `Beta::ppf`
/// rather than recursing into this impl.
impl<F: Float + NumCast + Debug + std::fmt::Display> ContinuousDistribution<F> for Beta<F> {
    /// Probability density at `x` (delegates to the inherent `pdf`).
    fn pdf(&self, x: F) -> F {
        self.pdf(x)
    }

    /// Cumulative probability at `x` (delegates to the inherent `cdf`).
    fn cdf(&self, x: F) -> F {
        self.cdf(x)
    }

    /// Quantile for probability `p` (delegates to the inherent `ppf`, including its
    /// domain errors).
    fn ppf(&self, p: F) -> StatsResult<F> {
        self.ppf(p)
    }
}
675
/// Opts `Beta` into `ContinuousCDF` without overriding anything: all of the trait's
/// methods come from its default implementations.
/// NOTE(review): the tests call `beta.sf(..)` and expect `1 - cdf`; presumably that
/// default lives in this trait — confirm in `crate::traits`.
impl<F: Float + NumCast + Debug + std::fmt::Display> ContinuousCDF<F> for Beta<F> {}
679
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a `Beta<f64>` from parameters the test knows to be valid.
    fn valid(alpha: f64, beta: f64, loc: f64, scale: f64) -> Beta<f64> {
        Beta::new(alpha, beta, loc, scale).expect("test/example should not fail")
    }

    /// Asserts `f(x) ≈ expected` within absolute tolerance `eps` for every
    /// `(x, expected, eps)` triple.
    fn check_points(f: impl Fn(f64) -> f64, cases: &[(f64, f64, f64)]) {
        for &(x, expected, eps) in cases {
            assert_relative_eq!(f(x), expected, epsilon = eps);
        }
    }

    #[test]
    fn test_beta_creation() {
        // The constructor stores its arguments unchanged.
        for &(alpha, beta, loc, scale) in &[(1.0, 1.0, 0.0, 1.0), (2.0, 3.0, 1.0, 2.0)] {
            let dist = valid(alpha, beta, loc, scale);
            assert_eq!(dist.alpha, alpha);
            assert_eq!(dist.beta, beta);
            assert_eq!(dist.loc, loc);
            assert_eq!(dist.scale, scale);
        }

        // Zero or negative shape/scale parameters are rejected.
        let rejected: [(f64, f64, f64, f64); 6] = [
            (0.0, 1.0, 0.0, 1.0),
            (-1.0, 1.0, 0.0, 1.0),
            (1.0, 0.0, 0.0, 1.0),
            (1.0, -1.0, 0.0, 1.0),
            (1.0, 1.0, 0.0, 0.0),
            (1.0, 1.0, 0.0, -1.0),
        ];
        for &(alpha, beta, loc, scale) in &rejected {
            assert!(Beta::<f64>::new(alpha, beta, loc, scale).is_err());
        }
    }

    #[test]
    fn test_beta_pdf() {
        let uniform = valid(1.0, 1.0, 0.0, 1.0);
        check_points(
            |x| uniform.pdf(x),
            &[(0.0, 1.0, 1e-6), (0.5, 1.0, 1e-6), (1.0, 1.0, 1e-6)],
        );

        let bell = valid(2.0, 2.0, 0.0, 1.0);
        check_points(
            |x| bell.pdf(x),
            &[(0.0, 0.0, 1e-10), (0.5, 1.5, 1e-6), (1.0, 0.0, 1e-10)],
        );

        let skewed = valid(2.0, 5.0, 0.0, 1.0);
        check_points(
            |x| skewed.pdf(x),
            &[(0.0, 0.0, 1e-10), (0.2, 2.4576, 1e-4), (1.0, 0.0, 1e-10)],
        );

        // loc = 1, scale = 2: support becomes [1, 3] and the density is halved.
        let shifted = valid(2.0, 2.0, 1.0, 2.0);
        check_points(
            |x| shifted.pdf(x),
            &[(1.0, 0.0, 1e-10), (2.0, 0.75, 1e-6), (3.0, 0.0, 1e-10)],
        );
    }

    #[test]
    fn test_beta_cdf() {
        let uniform = valid(1.0, 1.0, 0.0, 1.0);
        check_points(
            |x| uniform.cdf(x),
            &[(0.0, 0.0, 1e-10), (0.5, 0.5, 1e-6), (1.0, 1.0, 1e-10)],
        );

        let bell = valid(2.0, 2.0, 0.0, 1.0);
        check_points(
            |x| bell.cdf(x),
            &[
                (0.0, 0.0, 1e-10),
                (0.5, 0.5, 1e-6),
                (0.8, 0.896, 1e-3),
                (1.0, 1.0, 1e-10),
            ],
        );

        let skewed = valid(2.0, 5.0, 0.0, 1.0);
        check_points(
            |x| skewed.cdf(x),
            &[(0.0, 0.0, 1e-10), (0.2, 0.34464, 1e-4), (1.0, 1.0, 1e-10)],
        );
    }

    #[test]
    fn test_beta_ppf() {
        let uniform = valid(1.0, 1.0, 0.0, 1.0);
        check_points(
            |p| uniform.ppf(p).expect("test/example should not fail"),
            &[(0.0, 0.0, 1e-6), (0.5, 0.5, 1e-6), (1.0, 1.0, 1e-6)],
        );

        let bell = valid(2.0, 2.0, 0.0, 1.0);
        check_points(
            |p| bell.ppf(p).expect("test/example should not fail"),
            &[(0.5, 0.5, 1e-6)],
        );

        // ppf must invert cdf.
        let skewed = valid(2.0, 5.0, 0.0, 1.0);
        let p_at_02 = skewed.cdf(0.2);
        let recovered = skewed.ppf(p_at_02).expect("test/example should not fail");
        assert_relative_eq!(recovered, 0.2, epsilon = 1e-3);

        let shifted = valid(2.0, 2.0, 1.0, 2.0);
        check_points(
            |p| shifted.ppf(p).expect("test/example should not fail"),
            &[(0.5, 2.0, 1e-6)],
        );

        // Probabilities outside [0, 1] are domain errors.
        for &bad in &[-0.1, 1.1] {
            assert!(uniform.ppf(bad).is_err());
        }
    }

    #[test]
    fn test_beta_rvs() {
        let beta = valid(2.0, 3.0, 0.0, 1.0);

        let samples_vec = beta.rvs_vec(1000).expect("test/example should not fail");
        let samples = beta.rvs(1000).expect("test/example should not fail");

        assert_eq!(samples_vec.len(), 1000);
        assert_eq!(samples.len(), 1000);

        // Beta(2, 3) has mean 2 / 5 = 0.4; 0.05 is several standard errors wide.
        let mean_of = |total: f64| total / 1000.0;

        let vec_mean = mean_of(samples_vec.iter().sum::<f64>());
        assert!((vec_mean - 0.4).abs() < 0.05);

        assert!(samples_vec.iter().all(|&s| (0.0..=1.0).contains(&s)));

        let array_mean = mean_of(samples.iter().sum::<f64>());
        assert!((array_mean - 0.4).abs() < 0.05);
    }

    #[test]
    fn test_beta_traits() {
        use crate::traits::{ContinuousDistribution, Distribution};

        let beta = valid(2.0, 3.0, 0.0, 1.0);

        // Moments of Beta(2, 3): mean 0.4, variance 0.04, standard deviation 0.2.
        assert_relative_eq!(Distribution::mean(&beta), 0.4, epsilon = 1e-10);
        assert_relative_eq!(Distribution::var(&beta), 0.04, epsilon = 1e-10);
        assert_relative_eq!(Distribution::std(&beta), 0.2, epsilon = 1e-10);

        // Trait entry points must agree with the inherent methods.
        let via_trait_pdf = ContinuousDistribution::pdf(&beta, 0.5);
        assert_relative_eq!(via_trait_pdf, beta.pdf(0.5), epsilon = 1e-10);

        let via_trait_cdf = ContinuousDistribution::cdf(&beta, 0.5);
        assert_relative_eq!(via_trait_cdf, beta.cdf(0.5), epsilon = 1e-10);

        let via_trait_ppf =
            ContinuousDistribution::ppf(&beta, 0.5).expect("test/example should not fail");
        let direct_ppf = beta.ppf(0.5).expect("test/example should not fail");
        assert_relative_eq!(via_trait_ppf, direct_ppf, epsilon = 1e-10);

        // The survival function is the complement of the CDF.
        assert_relative_eq!(beta.sf(0.5), 1.0 - beta.cdf(0.5), epsilon = 1e-10);
    }

    #[test]
    fn test_beta_function() {
        let exact_cases = [
            (1.0, 1.0, 1.0),
            (1.0, 2.0, 0.5),
            (2.0, 1.0, 0.5),
            (2.0, 3.0, 1.0 / 12.0),
        ];
        for &(a, b, expected) in &exact_cases {
            assert_relative_eq!(beta_function(a, b), expected, epsilon = 1e-10);
        }

        // B(1/2, 1/2) = π.
        assert_relative_eq!(
            beta_function(0.5, 0.5),
            std::f64::consts::PI,
            epsilon = 1e-6
        );
    }
}