// saorsa_core/adaptive/beta_distribution.rs

// Copyright 2024 Saorsa Labs Limited
//
// This software is dual-licensed under:
// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
// - Commercial License
//
// For AGPL-3.0 license, see LICENSE-AGPL-3.0
// For commercial licensing, contact: david@saorsalabs.com
//
// Unless required by applicable law or agreed to in writing, software
// distributed under these licenses is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

//! # Beta Distribution Implementation
//!
//! This module provides a proper Beta distribution implementation for Thompson Sampling
//! in the Multi-Armed Bandit routing optimization system.
//!
//! ## Features
//! - Exact Beta distribution sampling (ratio of Gamma variates drawn with an
//!   acceptance-rejection method)
//! - Fast paths for special cases (uniform, α = 1, β = 1)
//! - Parameter validation and bounds checking
//! - Thread-safe sampling: `sample` takes `&self` and a caller-supplied RNG
//!   (for example a thread-local one)

use rand::Rng;
use statrs::distribution::{Beta as StatBeta, ContinuousCDF};

/// Shape parameters of a Beta(α, β) distribution on the unit interval.
///
/// Both fields are public, so values written directly bypass the checks done
/// by `BetaDistribution::new`; callers that mutate the fields are responsible
/// for keeping them positive and finite.
#[derive(Debug, Clone, Copy)]
pub struct BetaDistribution {
    /// Alpha shape parameter (successes + 1 when starting from Beta(1, 1),
    /// since each recorded success adds 1.0).
    pub alpha: f64,
    /// Beta shape parameter (failures + 1 when starting from Beta(1, 1),
    /// since each recorded failure adds 1.0).
    pub beta: f64,
}

37impl BetaDistribution {
38    /// Create a new Beta distribution
39    pub fn new(alpha: f64, beta: f64) -> Result<Self, BetaError> {
40        if alpha <= 0.0 || beta <= 0.0 {
41            return Err(BetaError::InvalidParameters {
42                alpha,
43                beta,
44                reason: "Alpha and beta must be positive".to_string(),
45            });
46        }
47
48        if !alpha.is_finite() || !beta.is_finite() {
49            return Err(BetaError::InvalidParameters {
50                alpha,
51                beta,
52                reason: "Parameters must be finite".to_string(),
53            });
54        }
55
56        Ok(Self { alpha, beta })
57    }
58
59    /// Sample from the Beta distribution
60    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
61        // Special cases for efficiency
62        if self.alpha == 1.0 && self.beta == 1.0 {
63            // Uniform distribution
64            return rng.r#gen::<f64>();
65        }
66
67        if self.alpha == 1.0 {
68            // Beta(1, β) = 1 - U^(1/β) where U ~ Uniform(0,1)
69            let u: f64 = rng.r#gen::<f64>();
70            return 1.0 - u.powf(1.0 / self.beta);
71        }
72
73        if self.beta == 1.0 {
74            // Beta(α, 1) = U^(1/α) where U ~ Uniform(0,1)
75            let u: f64 = rng.r#gen::<f64>();
76            return u.powf(1.0 / self.alpha);
77        }
78
79        // General case: use Gamma distribution method
80        // Beta(α, β) = Gamma(α) / (Gamma(α) + Gamma(β))
81        let gamma_alpha = sample_gamma(self.alpha, rng);
82        let gamma_beta = sample_gamma(self.beta, rng);
83
84        gamma_alpha / (gamma_alpha + gamma_beta)
85    }
86
87    /// Get the mean of the distribution
88    pub fn mean(&self) -> f64 {
89        self.alpha / (self.alpha + self.beta)
90    }
91
92    /// Get the variance of the distribution
93    pub fn variance(&self) -> f64 {
94        let sum = self.alpha + self.beta;
95        (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
96    }
97
98    /// Get the mode of the distribution (if it exists)
99    pub fn mode(&self) -> Option<f64> {
100        if self.alpha > 1.0 && self.beta > 1.0 {
101            Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
102        } else if self.alpha == 1.0 && self.beta == 1.0 {
103            // Uniform distribution, any value in [0,1] is a mode
104            Some(0.5)
105        } else if self.alpha < 1.0 && self.beta < 1.0 {
106            // Bimodal at 0 and 1
107            None
108        } else if self.alpha < 1.0 {
109            Some(0.0)
110        } else if self.beta < 1.0 {
111            Some(1.0)
112        } else {
113            None
114        }
115    }
116
117    /// Update parameters based on success/failure
118    pub fn update(&mut self, success: bool) {
119        if success {
120            self.alpha += 1.0;
121        } else {
122            self.beta += 1.0;
123        }
124    }
125
126    /// Get the 95% confidence interval
127    pub fn confidence_interval(&self) -> (f64, f64) {
128        const LOWER_QUANTILE: f64 = 0.05;
129        const UPPER_QUANTILE: f64 = 0.95;
130
131        match StatBeta::new(self.alpha, self.beta) {
132            Ok(beta) => {
133                let lower = beta.inverse_cdf(LOWER_QUANTILE).clamp(0.0, 1.0);
134                let upper = beta.inverse_cdf(UPPER_QUANTILE).clamp(0.0, 1.0);
135                (lower, upper)
136            }
137            Err(_) => (0.0, 1.0),
138        }
139    }
140}
141
142/// Sample from Gamma distribution using Marsaglia and Tsang's method
143#[allow(clippy::many_single_char_names)]
144fn sample_gamma<R: Rng>(shape: f64, rng: &mut R) -> f64 {
145    if shape < 1.0 {
146        // Use Johnk's algorithm for shape < 1
147        let u: f64 = rng.r#gen::<f64>();
148        sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape)
149    } else {
150        // Marsaglia and Tsang's method for shape >= 1
151        let d = shape - 1.0 / 3.0;
152        let c = 1.0 / (9.0 * d).sqrt();
153
154        loop {
155            let mut x;
156            let mut v;
157
158            loop {
159                // Generate a standard normal sample using Box-Muller
160                let (z, ok) = standard_normal(rng);
161                if ok {
162                    x = z;
163                } else {
164                    continue;
165                }
166                v = 1.0 + c * x;
167                if v > 0.0 {
168                    break;
169                }
170            }
171
172            v = v * v * v;
173            let u: f64 = rng.r#gen::<f64>();
174
175            if u < 1.0 - 0.0331 * x * x * x * x {
176                return d * v;
177            }
178
179            if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
180                return d * v;
181            }
182        }
183    }
184}
185
186/// Generate a single standard normal N(0,1) value via Box-Muller transform.
187/// Returns (z, true) on success; (0.0, false) if a retry is needed due to log(0).
188fn standard_normal<R: Rng>(rng: &mut R) -> (f64, bool) {
189    let u1: f64 = rng.r#gen::<f64>();
190    let u2: f64 = rng.r#gen::<f64>();
191    // Avoid u1 == 0 which would cause ln(0)
192    if u1 <= f64::MIN_POSITIVE {
193        return (0.0, false);
194    }
195    let r = (-2.0_f64 * u1.ln()).sqrt();
196    let theta = 2.0 * std::f64::consts::PI * u2;
197    (r * theta.cos(), true)
198}
199
/// Errors that can occur when working with a Beta distribution.
#[derive(Debug, Clone)]
pub enum BetaError {
    /// The supplied shape parameters were rejected.
    InvalidParameters {
        /// The alpha value that was supplied.
        alpha: f64,
        /// The beta value that was supplied.
        beta: f64,
        /// Human-readable explanation of why the pair was rejected.
        reason: String,
    },
}

211impl std::fmt::Display for BetaError {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        match self {
214            BetaError::InvalidParameters {
215                alpha,
216                beta,
217                reason,
218            } => {
219                write!(
220                    f,
221                    "Invalid Beta parameters (α={}, β={}): {}",
222                    alpha, beta, reason
223                )
224            }
225        }
226    }
227}
228
229impl std::error::Error for BetaError {}
230
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    /// `new` accepts positive finite parameters and rejects everything else.
    #[test]
    fn test_beta_distribution_creation() {
        // Valid parameters
        let dist = BetaDistribution::new(2.0, 3.0);
        assert!(dist.is_ok());

        // Invalid parameters: zero, negative, infinite, NaN
        assert!(BetaDistribution::new(0.0, 1.0).is_err());
        assert!(BetaDistribution::new(1.0, -1.0).is_err());
        assert!(BetaDistribution::new(f64::INFINITY, 1.0).is_err());
        assert!(BetaDistribution::new(1.0, f64::NAN).is_err());
    }

    /// Samples from the general (Gamma-ratio) path stay inside the unit interval.
    #[test]
    fn test_beta_distribution_sampling() {
        let mut rng = thread_rng();
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        // Sample should be in [0, 1]
        for _ in 0..1000 {
            let sample = dist.sample(&mut rng);
            assert!(sample >= 0.0);
            assert!(sample <= 1.0);
            assert!(sample.is_finite());
        }
    }

    /// The closed-form fast paths (uniform, α = 1, β = 1) produce valid samples.
    #[test]
    fn test_beta_distribution_special_cases() {
        let mut rng = thread_rng();

        // Uniform distribution
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        let samples: Vec<f64> = (0..1000).map(|_| uniform.sample(&mut rng)).collect();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - 0.5).abs() < 0.05); // Should be close to 0.5

        // Beta(1, β)
        let beta_1_b = BetaDistribution::new(1.0, 3.0).unwrap();
        for _ in 0..100 {
            let sample = beta_1_b.sample(&mut rng);
            assert!((0.0..=1.0).contains(&sample));
        }

        // Beta(α, 1)
        let beta_a_1 = BetaDistribution::new(3.0, 1.0).unwrap();
        for _ in 0..100 {
            let sample = beta_a_1.sample(&mut rng);
            assert!((0.0..=1.0).contains(&sample));
        }
    }

    /// Mean, variance and interior mode match their closed-form expressions.
    #[test]
    fn test_beta_distribution_moments() {
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        // Test mean
        assert_eq!(dist.mean(), 2.0 / 7.0);

        // Test variance
        let expected_variance = (2.0 * 5.0) / (7.0 * 7.0 * 8.0);
        assert!((dist.variance() - expected_variance).abs() < 1e-10);

        // Test mode
        let mode = dist.mode().unwrap();
        assert_eq!(mode, 1.0 / 5.0); // (α-1)/(α+β-2) = 1/5
    }

    /// `update` bumps exactly one parameter per observation.
    #[test]
    fn test_beta_parameter_updates() {
        let mut dist = BetaDistribution::new(1.0, 1.0).unwrap();

        // Success increases alpha
        dist.update(true);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 1.0);

        // Failure increases beta
        dist.update(false);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 2.0);
    }

    /// The interval is ordered, stays inside [0, 1], and brackets the mean.
    #[test]
    fn test_beta_confidence_interval() {
        // Small parameters
        let dist_small = BetaDistribution::new(2.0, 3.0).unwrap();
        let (lower, upper) = dist_small.confidence_interval();
        assert!(lower >= 0.0);
        assert!(upper <= 1.0);
        assert!(lower < upper);

        // Larger parameters: the interval must still bracket the mean
        let dist_large = BetaDistribution::new(50.0, 40.0).unwrap();
        let (lower, upper) = dist_large.confidence_interval();
        let mean = dist_large.mean();
        assert!(lower < mean);
        assert!(mean < upper);
    }

    /// The empirical mean of many samples approaches the analytic mean.
    #[test]
    fn test_beta_distribution_convergence() {
        let mut rng = thread_rng();
        let dist = BetaDistribution::new(3.0, 7.0).unwrap();
        let expected_mean = dist.mean();

        let n_samples = 10000;
        let samples: Vec<f64> = (0..n_samples).map(|_| dist.sample(&mut rng)).collect();
        let sample_mean = samples.iter().sum::<f64>() / n_samples as f64;

        // Should land within an absolute tolerance of 0.01 of the expected mean
        assert!((sample_mean - expected_mean).abs() < 0.01);
    }

    /// Mode behaviour at and beyond the boundaries of the interior case.
    #[test]
    fn test_mode_edge_cases() {
        // Uniform distribution
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        assert_eq!(uniform.mode(), Some(0.5));

        // Mode at 0
        let mode_0 = BetaDistribution::new(0.5, 2.0).unwrap();
        assert_eq!(mode_0.mode(), Some(0.0));

        // Mode at 1
        let mode_1 = BetaDistribution::new(2.0, 0.5).unwrap();
        assert_eq!(mode_1.mode(), Some(1.0));

        // Bimodal (no single mode)
        let bimodal = BetaDistribution::new(0.5, 0.5).unwrap();
        assert_eq!(bimodal.mode(), None);
    }
}