saorsa_core/adaptive/
beta_distribution.rs1use rand::Rng;
26
/// A Beta distribution `Beta(alpha, beta)` on the unit interval `[0, 1]`.
///
/// [`BetaDistribution::new`] validates the shape parameters. The fields are public,
/// so assigning to them afterwards bypasses that validation.
///
/// [`BetaDistribution::update`] treats the parameters as running counts: a success
/// increments `alpha` and a failure increments `beta`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BetaDistribution {
    /// First shape parameter; incremented by `update(true)`.
    pub alpha: f64,
    /// Second shape parameter; incremented by `update(false)`.
    pub beta: f64,
}
35
36impl BetaDistribution {
37 pub fn new(alpha: f64, beta: f64) -> Result<Self, BetaError> {
39 if alpha <= 0.0 || beta <= 0.0 {
40 return Err(BetaError::InvalidParameters {
41 alpha,
42 beta,
43 reason: "Alpha and beta must be positive".to_string(),
44 });
45 }
46
47 if !alpha.is_finite() || !beta.is_finite() {
48 return Err(BetaError::InvalidParameters {
49 alpha,
50 beta,
51 reason: "Parameters must be finite".to_string(),
52 });
53 }
54
55 Ok(Self { alpha, beta })
56 }
57
58 pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
60 if self.alpha == 1.0 && self.beta == 1.0 {
62 return rng.r#gen::<f64>();
64 }
65
66 if self.alpha == 1.0 {
67 let u: f64 = rng.r#gen::<f64>();
69 return 1.0 - u.powf(1.0 / self.beta);
70 }
71
72 if self.beta == 1.0 {
73 let u: f64 = rng.r#gen::<f64>();
75 return u.powf(1.0 / self.alpha);
76 }
77
78 let gamma_alpha = sample_gamma(self.alpha, rng);
81 let gamma_beta = sample_gamma(self.beta, rng);
82
83 gamma_alpha / (gamma_alpha + gamma_beta)
84 }
85
86 pub fn mean(&self) -> f64 {
88 self.alpha / (self.alpha + self.beta)
89 }
90
91 pub fn variance(&self) -> f64 {
93 let sum = self.alpha + self.beta;
94 (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
95 }
96
97 pub fn mode(&self) -> Option<f64> {
99 if self.alpha > 1.0 && self.beta > 1.0 {
100 Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
101 } else if self.alpha == 1.0 && self.beta == 1.0 {
102 Some(0.5)
104 } else if self.alpha < 1.0 && self.beta < 1.0 {
105 None
107 } else if self.alpha < 1.0 {
108 Some(0.0)
109 } else if self.beta < 1.0 {
110 Some(1.0)
111 } else {
112 None
113 }
114 }
115
116 pub fn update(&mut self, success: bool) {
118 if success {
119 self.alpha += 1.0;
120 } else {
121 self.beta += 1.0;
122 }
123 }
124
125 pub fn confidence_interval(&self) -> (f64, f64) {
127 if self.alpha > 30.0 && self.beta > 30.0 {
129 let mean = self.mean();
130 let std_dev = self.variance().sqrt();
131 let margin = 1.96 * std_dev; ((mean - margin).max(0.0), (mean + margin).min(1.0))
133 } else {
134 let lower = self.alpha / (self.alpha + self.beta + 2.0);
137 let upper = (self.alpha + 1.0) / (self.alpha + self.beta + 1.0);
138 (lower, upper)
139 }
140 }
141}
142
143#[allow(clippy::many_single_char_names)]
145fn sample_gamma<R: Rng>(shape: f64, rng: &mut R) -> f64 {
146 if shape < 1.0 {
147 let u: f64 = rng.r#gen::<f64>();
149 sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape)
150 } else {
151 let d = shape - 1.0 / 3.0;
153 let c = 1.0 / (9.0 * d).sqrt();
154
155 loop {
156 let mut x;
157 let mut v;
158
159 loop {
160 x = rng.gen_range(-1.0..1.0); v = 1.0 + c * x;
162 if v > 0.0 {
163 break;
164 }
165 }
166
167 v = v * v * v;
168 let u: f64 = rng.r#gen::<f64>();
169
170 if u < 1.0 - 0.0331 * x * x * x * x {
171 return d * v;
172 }
173
174 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
175 return d * v;
176 }
177 }
178 }
179}
180
/// Error returned by [`BetaDistribution::new`] when the shape parameters are rejected.
#[derive(Debug, Clone)]
pub enum BetaError {
    /// At least one shape parameter failed validation.
    InvalidParameters {
        /// The `alpha` value that was supplied.
        alpha: f64,
        /// The `beta` value that was supplied.
        beta: f64,
        /// Human-readable description of the failed check.
        reason: String,
    },
}

impl std::fmt::Display for BetaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this destructuring is irrefutable.
        let Self::InvalidParameters {
            alpha,
            beta,
            reason,
        } = self;
        write!(
            f,
            "Invalid Beta parameters (α={}, β={}): {}",
            alpha, beta, reason
        )
    }
}

impl std::error::Error for BetaError {}
211
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    /// Fails the test unless `x` lies in the closed unit interval (NaN fails too).
    fn assert_unit_interval(x: f64) {
        assert!((0.0..=1.0).contains(&x), "sample {x} outside [0, 1]");
    }

    #[test]
    fn test_beta_distribution_creation() {
        assert!(BetaDistribution::new(2.0, 3.0).is_ok());

        let rejected = [
            (0.0, 1.0),
            (1.0, -1.0),
            (f64::INFINITY, 1.0),
            (1.0, f64::NAN),
        ];
        for (alpha, beta) in rejected {
            assert!(
                BetaDistribution::new(alpha, beta).is_err(),
                "accepted alpha={alpha}, beta={beta}"
            );
        }
    }

    #[test]
    fn test_beta_distribution_sampling() {
        let mut rng = thread_rng();
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        for sample in (0..1000).map(|_| dist.sample(&mut rng)) {
            assert!(sample.is_finite());
            assert_unit_interval(sample);
        }
    }

    #[test]
    fn test_beta_distribution_special_cases() {
        let mut rng = thread_rng();

        // Beta(1, 1) is uniform, so the sample mean should sit near 0.5.
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        let draws: u32 = 1000;
        let total: f64 = (0..draws).map(|_| uniform.sample(&mut rng)).sum();
        assert!((total / f64::from(draws) - 0.5).abs() < 0.05);

        // The closed-form inverse-CDF paths: alpha == 1, then beta == 1.
        for (alpha, beta) in [(1.0, 3.0), (3.0, 1.0)] {
            let dist = BetaDistribution::new(alpha, beta).unwrap();
            for _ in 0..100 {
                assert_unit_interval(dist.sample(&mut rng));
            }
        }
    }

    #[test]
    fn test_beta_distribution_moments() {
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        assert_eq!(dist.mean(), 2.0 / 7.0);

        let expected_variance = (2.0 * 5.0) / (7.0 * 7.0 * 8.0);
        assert!((dist.variance() - expected_variance).abs() < 1e-10);

        assert_eq!(dist.mode(), Some(1.0 / 5.0));
    }

    #[test]
    fn test_beta_parameter_updates() {
        let mut dist = BetaDistribution::new(1.0, 1.0).unwrap();

        dist.update(true);
        assert_eq!((dist.alpha, dist.beta), (2.0, 1.0));

        dist.update(false);
        assert_eq!((dist.alpha, dist.beta), (2.0, 2.0));
    }

    #[test]
    fn test_beta_confidence_interval() {
        let (lower, upper) = BetaDistribution::new(2.0, 3.0)
            .unwrap()
            .confidence_interval();
        assert!(lower >= 0.0);
        assert!(upper <= 1.0);
        assert!(lower < upper);

        let dist_large = BetaDistribution::new(50.0, 40.0).unwrap();
        let mean = dist_large.mean();
        let (lower, upper) = dist_large.confidence_interval();
        assert!(lower < mean);
        assert!(mean < upper);
    }

    #[test]
    fn test_beta_distribution_convergence() {
        let mut rng = thread_rng();
        let dist = BetaDistribution::new(3.0, 7.0).unwrap();

        let n_samples: u32 = 10_000;
        let total: f64 = (0..n_samples).map(|_| dist.sample(&mut rng)).sum();
        let sample_mean = total / f64::from(n_samples);

        assert!((sample_mean - dist.mean()).abs() < 0.01);
    }

    #[test]
    fn test_mode_edge_cases() {
        let cases = [
            ((1.0, 1.0), Some(0.5)), // uniform
            ((0.5, 2.0), Some(0.0)),
            ((2.0, 0.5), Some(1.0)),
            ((0.5, 0.5), None), // density peaks at both ends
        ];
        for ((alpha, beta), expected) in cases {
            let dist = BetaDistribution::new(alpha, beta).unwrap();
            assert_eq!(dist.mode(), expected, "alpha={alpha}, beta={beta}");
        }
    }
}