saorsa_core/adaptive/
beta_distribution.rs1use rand::Rng;
26
/// A Beta(`alpha`, `beta`) distribution over `[0, 1]`.
///
/// Build it with [`BetaDistribution::new`], which rejects non-positive and
/// non-finite parameters. The fields are public, so values assigned directly
/// bypass that validation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BetaDistribution {
    /// First shape parameter; `update(true)` increments it by one.
    pub alpha: f64,
    /// Second shape parameter; `update(false)` increments it by one.
    pub beta: f64,
}
35
36impl BetaDistribution {
37 pub fn new(alpha: f64, beta: f64) -> Result<Self, BetaError> {
39 if alpha <= 0.0 || beta <= 0.0 {
40 return Err(BetaError::InvalidParameters {
41 alpha,
42 beta,
43 reason: "Alpha and beta must be positive".to_string(),
44 });
45 }
46
47 if !alpha.is_finite() || !beta.is_finite() {
48 return Err(BetaError::InvalidParameters {
49 alpha,
50 beta,
51 reason: "Parameters must be finite".to_string(),
52 });
53 }
54
55 Ok(Self { alpha, beta })
56 }
57
58 pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
60 if self.alpha == 1.0 && self.beta == 1.0 {
62 return rng.r#gen::<f64>();
64 }
65
66 if self.alpha == 1.0 {
67 let u: f64 = rng.r#gen::<f64>();
69 return 1.0 - u.powf(1.0 / self.beta);
70 }
71
72 if self.beta == 1.0 {
73 let u: f64 = rng.r#gen::<f64>();
75 return u.powf(1.0 / self.alpha);
76 }
77
78 let gamma_alpha = sample_gamma(self.alpha, rng);
81 let gamma_beta = sample_gamma(self.beta, rng);
82
83 gamma_alpha / (gamma_alpha + gamma_beta)
84 }
85
86 pub fn mean(&self) -> f64 {
88 self.alpha / (self.alpha + self.beta)
89 }
90
91 pub fn variance(&self) -> f64 {
93 let sum = self.alpha + self.beta;
94 (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
95 }
96
97 pub fn mode(&self) -> Option<f64> {
99 if self.alpha > 1.0 && self.beta > 1.0 {
100 Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
101 } else if self.alpha == 1.0 && self.beta == 1.0 {
102 Some(0.5)
104 } else if self.alpha < 1.0 && self.beta < 1.0 {
105 None
107 } else if self.alpha < 1.0 {
108 Some(0.0)
109 } else if self.beta < 1.0 {
110 Some(1.0)
111 } else {
112 None
113 }
114 }
115
116 pub fn update(&mut self, success: bool) {
118 if success {
119 self.alpha += 1.0;
120 } else {
121 self.beta += 1.0;
122 }
123 }
124
125 pub fn confidence_interval(&self) -> (f64, f64) {
127 if self.alpha > 30.0 && self.beta > 30.0 {
129 let mean = self.mean();
130 let std_dev = self.variance().sqrt();
131 let margin = 1.96 * std_dev; ((mean - margin).max(0.0), (mean + margin).min(1.0))
133 } else {
134 let lower = self.alpha / (self.alpha + self.beta + 2.0);
137 let upper = (self.alpha + 1.0) / (self.alpha + self.beta + 1.0);
138 (lower, upper)
139 }
140 }
141}
142
143#[allow(clippy::many_single_char_names)]
145fn sample_gamma<R: Rng>(shape: f64, rng: &mut R) -> f64 {
146 if shape < 1.0 {
147 let u: f64 = rng.r#gen::<f64>();
149 sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape)
150 } else {
151 let d = shape - 1.0 / 3.0;
153 let c = 1.0 / (9.0 * d).sqrt();
154
155 loop {
156 let mut x;
157 let mut v;
158
159 loop {
160 let (z, ok) = standard_normal(rng);
162 if ok {
163 x = z;
164 } else {
165 continue;
166 }
167 v = 1.0 + c * x;
168 if v > 0.0 {
169 break;
170 }
171 }
172
173 v = v * v * v;
174 let u: f64 = rng.r#gen::<f64>();
175
176 if u < 1.0 - 0.0331 * x * x * x * x {
177 return d * v;
178 }
179
180 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
181 return d * v;
182 }
183 }
184 }
185}
186
187fn standard_normal<R: Rng>(rng: &mut R) -> (f64, bool) {
190 let u1: f64 = rng.r#gen::<f64>();
191 let u2: f64 = rng.r#gen::<f64>();
192 if u1 <= f64::MIN_POSITIVE {
194 return (0.0, false);
195 }
196 let r = (-2.0_f64 * u1.ln()).sqrt();
197 let theta = 2.0 * std::f64::consts::PI * u2;
198 (r * theta.cos(), true)
199}
200
/// Error returned by `BetaDistribution::new` when its parameters are rejected.
#[derive(Debug, Clone, PartialEq)]
pub enum BetaError {
    /// The supplied shape parameters are not positive and finite.
    InvalidParameters {
        /// The `alpha` value that was passed in.
        alpha: f64,
        /// The `beta` value that was passed in.
        beta: f64,
        /// Human-readable explanation of why the parameters were rejected.
        reason: String,
    },
}
211
212impl std::fmt::Display for BetaError {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 BetaError::InvalidParameters {
216 alpha,
217 beta,
218 reason,
219 } => {
220 write!(
221 f,
222 "Invalid Beta parameters (α={}, β={}): {}",
223 alpha, beta, reason
224 )
225 }
226 }
227 }
228}
229
230impl std::error::Error for BetaError {}
231
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    #[test]
    fn test_beta_distribution_creation() {
        assert!(BetaDistribution::new(2.0, 3.0).is_ok());

        // Non-positive and non-finite parameters must be rejected.
        let invalid = [
            (0.0, 1.0),
            (1.0, -1.0),
            (f64::INFINITY, 1.0),
            (1.0, f64::NAN),
        ];
        for &(alpha, beta) in &invalid {
            assert!(BetaDistribution::new(alpha, beta).is_err());
        }
    }

    #[test]
    fn test_beta_distribution_sampling() {
        let mut rng = thread_rng();
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        for _ in 0..1000 {
            let draw = dist.sample(&mut rng);
            assert!(draw.is_finite());
            assert!((0.0..=1.0).contains(&draw));
        }
    }

    #[test]
    fn test_beta_distribution_special_cases() {
        let mut rng = thread_rng();

        // Beta(1, 1) is Uniform(0, 1), so the sample mean should sit near 0.5.
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        let total: f64 = (0..1000).map(|_| uniform.sample(&mut rng)).sum();
        let mean = total / 1000.0;
        assert!((mean - 0.5).abs() < 0.05);

        // The closed-form inverse-CDF paths for alpha == 1 and for beta == 1.
        for (alpha, beta) in [(1.0, 3.0), (3.0, 1.0)] {
            let dist = BetaDistribution::new(alpha, beta).unwrap();
            for _ in 0..100 {
                assert!((0.0..=1.0).contains(&dist.sample(&mut rng)));
            }
        }
    }

    #[test]
    fn test_beta_distribution_moments() {
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        assert_eq!(dist.mean(), 2.0 / 7.0);

        let expected_variance = (2.0 * 5.0) / (7.0 * 7.0 * 8.0);
        assert!((dist.variance() - expected_variance).abs() < 1e-10);

        assert_eq!(dist.mode(), Some(1.0 / 5.0));
    }

    #[test]
    fn test_beta_parameter_updates() {
        let mut dist = BetaDistribution::new(1.0, 1.0).unwrap();

        // A success bumps alpha only...
        dist.update(true);
        assert_eq!((dist.alpha, dist.beta), (2.0, 1.0));

        // ...and a failure bumps beta only.
        dist.update(false);
        assert_eq!((dist.alpha, dist.beta), (2.0, 2.0));
    }

    #[test]
    fn test_beta_confidence_interval() {
        // Small parameters exercise the heuristic branch.
        let (lower, upper) = BetaDistribution::new(2.0, 3.0)
            .unwrap()
            .confidence_interval();
        assert!(lower >= 0.0);
        assert!(upper <= 1.0);
        assert!(lower < upper);

        // Large parameters exercise the normal-approximation branch.
        let dist_large = BetaDistribution::new(50.0, 40.0).unwrap();
        let (lower, upper) = dist_large.confidence_interval();
        let mean = dist_large.mean();
        assert!(lower < mean && mean < upper);
    }

    #[test]
    fn test_beta_distribution_convergence() {
        let mut rng = thread_rng();
        let dist = BetaDistribution::new(3.0, 7.0).unwrap();

        let n_samples: u32 = 10_000;
        let sample_mean = (0..n_samples)
            .map(|_| dist.sample(&mut rng))
            .sum::<f64>()
            / f64::from(n_samples);

        assert!((sample_mean - dist.mean()).abs() < 0.01);
    }

    #[test]
    fn test_mode_edge_cases() {
        let cases = [
            (1.0, 1.0, Some(0.5)), // uniform: 0.5 by convention
            (0.5, 2.0, Some(0.0)), // density diverges at 0
            (2.0, 0.5, Some(1.0)), // density diverges at 1
            (0.5, 0.5, None),      // U-shaped: peaks at both ends
        ];
        for &(alpha, beta, expected) in &cases {
            assert_eq!(BetaDistribution::new(alpha, beta).unwrap().mode(), expected);
        }
    }
}