saorsa_core/adaptive/
beta_distribution.rs1use rand::Rng;
26use statrs::distribution::{Beta as StatBeta, ContinuousCDF};
27
/// Beta distribution described by its two shape parameters.
///
/// Build it with [`BetaDistribution::new`] to get parameter validation; the
/// fields are public, so values assigned directly are not checked.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BetaDistribution {
    /// First shape parameter (α); incremented by `update(true)`.
    pub alpha: f64,
    /// Second shape parameter (β); incremented by `update(false)`.
    pub beta: f64,
}
36
37impl BetaDistribution {
38 pub fn new(alpha: f64, beta: f64) -> Result<Self, BetaError> {
40 if alpha <= 0.0 || beta <= 0.0 {
41 return Err(BetaError::InvalidParameters {
42 alpha,
43 beta,
44 reason: "Alpha and beta must be positive".to_string(),
45 });
46 }
47
48 if !alpha.is_finite() || !beta.is_finite() {
49 return Err(BetaError::InvalidParameters {
50 alpha,
51 beta,
52 reason: "Parameters must be finite".to_string(),
53 });
54 }
55
56 Ok(Self { alpha, beta })
57 }
58
59 pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
61 if self.alpha == 1.0 && self.beta == 1.0 {
63 return rng.r#gen::<f64>();
65 }
66
67 if self.alpha == 1.0 {
68 let u: f64 = rng.r#gen::<f64>();
70 return 1.0 - u.powf(1.0 / self.beta);
71 }
72
73 if self.beta == 1.0 {
74 let u: f64 = rng.r#gen::<f64>();
76 return u.powf(1.0 / self.alpha);
77 }
78
79 let gamma_alpha = sample_gamma(self.alpha, rng);
82 let gamma_beta = sample_gamma(self.beta, rng);
83
84 gamma_alpha / (gamma_alpha + gamma_beta)
85 }
86
87 pub fn mean(&self) -> f64 {
89 self.alpha / (self.alpha + self.beta)
90 }
91
92 pub fn variance(&self) -> f64 {
94 let sum = self.alpha + self.beta;
95 (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
96 }
97
98 pub fn mode(&self) -> Option<f64> {
100 if self.alpha > 1.0 && self.beta > 1.0 {
101 Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
102 } else if self.alpha == 1.0 && self.beta == 1.0 {
103 Some(0.5)
105 } else if self.alpha < 1.0 && self.beta < 1.0 {
106 None
108 } else if self.alpha < 1.0 {
109 Some(0.0)
110 } else if self.beta < 1.0 {
111 Some(1.0)
112 } else {
113 None
114 }
115 }
116
117 pub fn update(&mut self, success: bool) {
119 if success {
120 self.alpha += 1.0;
121 } else {
122 self.beta += 1.0;
123 }
124 }
125
126 pub fn confidence_interval(&self) -> (f64, f64) {
128 const LOWER_QUANTILE: f64 = 0.05;
129 const UPPER_QUANTILE: f64 = 0.95;
130
131 match StatBeta::new(self.alpha, self.beta) {
132 Ok(beta) => {
133 let lower = beta.inverse_cdf(LOWER_QUANTILE).clamp(0.0, 1.0);
134 let upper = beta.inverse_cdf(UPPER_QUANTILE).clamp(0.0, 1.0);
135 (lower, upper)
136 }
137 Err(_) => (0.0, 1.0),
138 }
139 }
140}
141
142#[allow(clippy::many_single_char_names)]
144fn sample_gamma<R: Rng>(shape: f64, rng: &mut R) -> f64 {
145 if shape < 1.0 {
146 let u: f64 = rng.r#gen::<f64>();
148 sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape)
149 } else {
150 let d = shape - 1.0 / 3.0;
152 let c = 1.0 / (9.0 * d).sqrt();
153
154 loop {
155 let mut x;
156 let mut v;
157
158 loop {
159 let (z, ok) = standard_normal(rng);
161 if ok {
162 x = z;
163 } else {
164 continue;
165 }
166 v = 1.0 + c * x;
167 if v > 0.0 {
168 break;
169 }
170 }
171
172 v = v * v * v;
173 let u: f64 = rng.r#gen::<f64>();
174
175 if u < 1.0 - 0.0331 * x * x * x * x {
176 return d * v;
177 }
178
179 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
180 return d * v;
181 }
182 }
183 }
184}
185
186fn standard_normal<R: Rng>(rng: &mut R) -> (f64, bool) {
189 let u1: f64 = rng.r#gen::<f64>();
190 let u2: f64 = rng.r#gen::<f64>();
191 if u1 <= f64::MIN_POSITIVE {
193 return (0.0, false);
194 }
195 let r = (-2.0_f64 * u1.ln()).sqrt();
196 let theta = 2.0 * std::f64::consts::PI * u2;
197 (r * theta.cos(), true)
198}
199
/// Errors produced when constructing a [`BetaDistribution`].
#[derive(Debug, Clone, PartialEq)]
pub enum BetaError {
    /// The shape parameters were rejected by [`BetaDistribution::new`].
    InvalidParameters {
        /// The rejected `alpha` value.
        alpha: f64,
        /// The rejected `beta` value.
        beta: f64,
        /// Explanation of why the parameters were rejected.
        reason: String,
    },
}

impl std::fmt::Display for BetaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BetaError::InvalidParameters {
                alpha,
                beta,
                reason,
            } => write!(f, "Invalid Beta parameters (α={alpha}, β={beta}): {reason}"),
        }
    }
}

// No underlying cause is stored, so the default `source()` (None) is correct.
impl std::error::Error for BetaError {}
230
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Seed shared by every sampling test.
    const TEST_SEED: u64 = 0x5EED_BE7A;

    /// Fixed-seed RNG: the sampling tests assert statistical properties, and a
    /// deterministic stream keeps them reproducible from run to run.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(TEST_SEED)
    }

    #[test]
    fn test_beta_distribution_creation() {
        let dist = BetaDistribution::new(2.0, 3.0);
        assert!(dist.is_ok());

        // Non-positive, infinite and NaN parameters are all rejected.
        assert!(BetaDistribution::new(0.0, 1.0).is_err());
        assert!(BetaDistribution::new(1.0, -1.0).is_err());
        assert!(BetaDistribution::new(f64::INFINITY, 1.0).is_err());
        assert!(BetaDistribution::new(1.0, f64::NAN).is_err());
    }

    #[test]
    fn test_beta_distribution_sampling() {
        let mut rng = seeded_rng();
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        for _ in 0..1000 {
            let sample = dist.sample(&mut rng);
            assert!(sample >= 0.0);
            assert!(sample <= 1.0);
            assert!(sample.is_finite());
        }
    }

    #[test]
    fn test_beta_distribution_special_cases() {
        let mut rng = seeded_rng();

        // Beta(1, 1) is uniform, so the sample mean should be close to 0.5.
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        let samples: Vec<f64> = (0..1000).map(|_| uniform.sample(&mut rng)).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - 0.5).abs() < 0.05);

        // alpha == 1 special case.
        let beta_1_b = BetaDistribution::new(1.0, 3.0).unwrap();
        for _ in 0..100 {
            let sample = beta_1_b.sample(&mut rng);
            assert!((0.0..=1.0).contains(&sample));
        }

        // beta == 1 special case.
        let beta_a_1 = BetaDistribution::new(3.0, 1.0).unwrap();
        for _ in 0..100 {
            let sample = beta_a_1.sample(&mut rng);
            assert!((0.0..=1.0).contains(&sample));
        }
    }

    #[test]
    fn test_beta_distribution_moments() {
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        assert_eq!(dist.mean(), 2.0 / 7.0);

        let expected_variance = (2.0 * 5.0) / (7.0 * 7.0 * 8.0);
        assert!((dist.variance() - expected_variance).abs() < 1e-10);

        // (alpha - 1) / (alpha + beta - 2) = 1 / 5.
        let mode = dist.mode().unwrap();
        assert_eq!(mode, 1.0 / 5.0);
    }

    #[test]
    fn test_beta_parameter_updates() {
        let mut dist = BetaDistribution::new(1.0, 1.0).unwrap();

        dist.update(true);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 1.0);

        dist.update(false);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 2.0);
    }

    #[test]
    fn test_beta_confidence_interval() {
        let dist_small = BetaDistribution::new(2.0, 3.0).unwrap();
        let (lower, upper) = dist_small.confidence_interval();
        assert!(lower >= 0.0);
        assert!(upper <= 1.0);
        assert!(lower < upper);

        let dist_large = BetaDistribution::new(50.0, 40.0).unwrap();
        let (lower, upper) = dist_large.confidence_interval();
        let mean = dist_large.mean();
        assert!(lower < mean);
        assert!(mean < upper);
    }

    #[test]
    fn test_beta_distribution_convergence() {
        let mut rng = seeded_rng();
        let dist = BetaDistribution::new(3.0, 7.0).unwrap();
        let expected_mean = dist.mean();

        let n_samples = 10000;
        let samples: Vec<f64> = (0..n_samples).map(|_| dist.sample(&mut rng)).collect();
        let sample_mean = samples.iter().sum::<f64>() / n_samples as f64;

        assert!((sample_mean - expected_mean).abs() < 0.01);
    }

    #[test]
    fn test_mode_edge_cases() {
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        assert_eq!(uniform.mode(), Some(0.5));

        let mode_0 = BetaDistribution::new(0.5, 2.0).unwrap();
        assert_eq!(mode_0.mode(), Some(0.0));

        let mode_1 = BetaDistribution::new(2.0, 0.5).unwrap();
        assert_eq!(mode_1.mode(), Some(1.0));

        let bimodal = BetaDistribution::new(0.5, 0.5).unwrap();
        assert_eq!(bimodal.mode(), None);
    }
}