saorsa_core/adaptive/
beta_distribution.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: saorsalabs@gmail.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! # Beta Distribution Implementation
15//!
16//! This module provides a proper Beta distribution implementation for Thompson Sampling
17//! in the Multi-Armed Bandit routing optimization system.
18//!
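//! ## Features
//! - Beta sampling as a ratio of two Gamma variates, each drawn with Marsaglia and Tsang's rejection method
//! - Closed-form fast paths when alpha and/or beta equals 1 (including the uniform Beta(1, 1))
//! - Parameter validation at construction (both parameters must be positive and finite)
//! - Randomness comes from a caller-supplied `rand::Rng`; the module keeps no RNG state of its own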
24
25use rand::Rng;
26
/// Beta distribution parameters
///
/// Both fields are public, so they can be assigned directly; doing so bypasses the
/// positivity and finiteness checks that `BetaDistribution::new` performs.
#[derive(Debug, Clone, Copy)]
pub struct BetaDistribution {
    /// Alpha parameter (successes + 1); incremented by `update(true)`
    pub alpha: f64,
    /// Beta parameter (failures + 1); incremented by `update(false)`
    pub beta: f64,
}
35
36impl BetaDistribution {
37    /// Create a new Beta distribution
38    pub fn new(alpha: f64, beta: f64) -> Result<Self, BetaError> {
39        if alpha <= 0.0 || beta <= 0.0 {
40            return Err(BetaError::InvalidParameters {
41                alpha,
42                beta,
43                reason: "Alpha and beta must be positive".to_string(),
44            });
45        }
46
47        if !alpha.is_finite() || !beta.is_finite() {
48            return Err(BetaError::InvalidParameters {
49                alpha,
50                beta,
51                reason: "Parameters must be finite".to_string(),
52            });
53        }
54
55        Ok(Self { alpha, beta })
56    }
57
58    /// Sample from the Beta distribution
59    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
60        // Special cases for efficiency
61        if self.alpha == 1.0 && self.beta == 1.0 {
62            // Uniform distribution
63            return rng.r#gen::<f64>();
64        }
65
66        if self.alpha == 1.0 {
67            // Beta(1, β) = 1 - U^(1/β) where U ~ Uniform(0,1)
68            let u: f64 = rng.r#gen::<f64>();
69            return 1.0 - u.powf(1.0 / self.beta);
70        }
71
72        if self.beta == 1.0 {
73            // Beta(α, 1) = U^(1/α) where U ~ Uniform(0,1)
74            let u: f64 = rng.r#gen::<f64>();
75            return u.powf(1.0 / self.alpha);
76        }
77
78        // General case: use Gamma distribution method
79        // Beta(α, β) = Gamma(α) / (Gamma(α) + Gamma(β))
80        let gamma_alpha = sample_gamma(self.alpha, rng);
81        let gamma_beta = sample_gamma(self.beta, rng);
82
83        gamma_alpha / (gamma_alpha + gamma_beta)
84    }
85
86    /// Get the mean of the distribution
87    pub fn mean(&self) -> f64 {
88        self.alpha / (self.alpha + self.beta)
89    }
90
91    /// Get the variance of the distribution
92    pub fn variance(&self) -> f64 {
93        let sum = self.alpha + self.beta;
94        (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
95    }
96
97    /// Get the mode of the distribution (if it exists)
98    pub fn mode(&self) -> Option<f64> {
99        if self.alpha > 1.0 && self.beta > 1.0 {
100            Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
101        } else if self.alpha == 1.0 && self.beta == 1.0 {
102            // Uniform distribution, any value in [0,1] is a mode
103            Some(0.5)
104        } else if self.alpha < 1.0 && self.beta < 1.0 {
105            // Bimodal at 0 and 1
106            None
107        } else if self.alpha < 1.0 {
108            Some(0.0)
109        } else if self.beta < 1.0 {
110            Some(1.0)
111        } else {
112            None
113        }
114    }
115
116    /// Update parameters based on success/failure
117    pub fn update(&mut self, success: bool) {
118        if success {
119            self.alpha += 1.0;
120        } else {
121            self.beta += 1.0;
122        }
123    }
124
125    /// Get the 95% confidence interval
126    pub fn confidence_interval(&self) -> (f64, f64) {
127        // Use normal approximation for large parameters
128        if self.alpha > 30.0 && self.beta > 30.0 {
129            let mean = self.mean();
130            let std_dev = self.variance().sqrt();
131            let margin = 1.96 * std_dev; // 95% CI
132            ((mean - margin).max(0.0), (mean + margin).min(1.0))
133        } else {
134            // For small parameters, use quantile approximation
135            // This is a simplified version; for exact quantiles use incomplete beta function
136            let lower = self.alpha / (self.alpha + self.beta + 2.0);
137            let upper = (self.alpha + 1.0) / (self.alpha + self.beta + 1.0);
138            (lower, upper)
139        }
140    }
141}
142
143/// Sample from Gamma distribution using Marsaglia and Tsang's method
144#[allow(clippy::many_single_char_names)]
145fn sample_gamma<R: Rng>(shape: f64, rng: &mut R) -> f64 {
146    if shape < 1.0 {
147        // Use Johnk's algorithm for shape < 1
148        let u: f64 = rng.r#gen::<f64>();
149        sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape)
150    } else {
151        // Marsaglia and Tsang's method for shape >= 1
152        let d = shape - 1.0 / 3.0;
153        let c = 1.0 / (9.0 * d).sqrt();
154
155        loop {
156            let mut x;
157            let mut v;
158
159            loop {
160                // Generate a standard normal sample using Box-Muller
161                let (z, ok) = standard_normal(rng);
162                if ok {
163                    x = z;
164                } else {
165                    continue;
166                }
167                v = 1.0 + c * x;
168                if v > 0.0 {
169                    break;
170                }
171            }
172
173            v = v * v * v;
174            let u: f64 = rng.r#gen::<f64>();
175
176            if u < 1.0 - 0.0331 * x * x * x * x {
177                return d * v;
178            }
179
180            if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
181                return d * v;
182            }
183        }
184    }
185}
186
187/// Generate a single standard normal N(0,1) value via Box-Muller transform.
188/// Returns (z, true) on success; (0.0, false) if a retry is needed due to log(0).
189fn standard_normal<R: Rng>(rng: &mut R) -> (f64, bool) {
190    let u1: f64 = rng.r#gen::<f64>();
191    let u2: f64 = rng.r#gen::<f64>();
192    // Avoid u1 == 0 which would cause ln(0)
193    if u1 <= f64::MIN_POSITIVE {
194        return (0.0, false);
195    }
196    let r = (-2.0_f64 * u1.ln()).sqrt();
197    let theta = 2.0 * std::f64::consts::PI * u2;
198    (r * theta.cos(), true)
199}
200
/// Errors that can occur with Beta distribution
#[derive(Debug, Clone)]
pub enum BetaError {
    /// Invalid parameters provided
    InvalidParameters {
        /// The alpha value that was supplied.
        alpha: f64,
        /// The beta value that was supplied.
        beta: f64,
        /// Human-readable explanation of why the parameters were rejected.
        reason: String,
    },
}
211
212impl std::fmt::Display for BetaError {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            BetaError::InvalidParameters {
216                alpha,
217                beta,
218                reason,
219            } => {
220                write!(
221                    f,
222                    "Invalid Beta parameters (α={}, β={}): {}",
223                    alpha, beta, reason
224                )
225            }
226        }
227    }
228}
229
// Empty impl: no `Error` methods are overridden, so the trait's defaults apply.
impl std::error::Error for BetaError {}
231
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    /// Builds a distribution from parameters the test knows to be valid.
    fn valid(alpha: f64, beta: f64) -> BetaDistribution {
        BetaDistribution::new(alpha, beta).unwrap()
    }

    /// True when `value` lies in the closed interval [0, 1] (false for NaN).
    fn in_unit_interval(value: f64) -> bool {
        (0.0..=1.0).contains(&value)
    }

    #[test]
    fn test_beta_distribution_creation() {
        // Accepted
        assert!(BetaDistribution::new(2.0, 3.0).is_ok());

        // Rejected: zero, negative, infinite and NaN parameters
        let rejected = [
            (0.0, 1.0),
            (1.0, -1.0),
            (f64::INFINITY, 1.0),
            (1.0, f64::NAN),
        ];
        for (alpha, beta) in rejected {
            assert!(BetaDistribution::new(alpha, beta).is_err());
        }
    }

    #[test]
    fn test_beta_distribution_sampling() {
        let dist = valid(2.0, 5.0);
        let mut rng = thread_rng();

        // Every draw must be a finite number inside [0, 1]
        for _ in 0..1000 {
            let drawn = dist.sample(&mut rng);
            assert!(in_unit_interval(drawn));
            assert!(drawn.is_finite());
        }
    }

    #[test]
    fn test_beta_distribution_special_cases() {
        let mut rng = thread_rng();

        // Beta(1, 1) is uniform, so the empirical mean should sit near 0.5
        let uniform = valid(1.0, 1.0);
        let draws: Vec<f64> = (0..1000).map(|_| uniform.sample(&mut rng)).collect();
        let empirical_mean = draws.iter().sum::<f64>() / draws.len() as f64;
        assert!((empirical_mean - 0.5).abs() < 0.05);

        // Fast paths Beta(1, β) and Beta(α, 1) must stay inside [0, 1]
        for dist in [valid(1.0, 3.0), valid(3.0, 1.0)] {
            for _ in 0..100 {
                assert!(in_unit_interval(dist.sample(&mut rng)));
            }
        }
    }

    #[test]
    fn test_beta_distribution_moments() {
        let dist = valid(2.0, 5.0);

        // Mean: α / (α + β)
        assert_eq!(dist.mean(), 2.0 / 7.0);

        // Variance: αβ / ((α + β)² (α + β + 1))
        let want_variance = (2.0 * 5.0) / (7.0 * 7.0 * 8.0);
        assert!((dist.variance() - want_variance).abs() < 1e-10);

        // Mode: (α - 1) / (α + β - 2) = 1/5
        assert_eq!(dist.mode().unwrap(), 1.0 / 5.0);
    }

    #[test]
    fn test_beta_parameter_updates() {
        let mut dist = valid(1.0, 1.0);

        // A success bumps alpha only
        dist.update(true);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 1.0);

        // A failure bumps beta only
        dist.update(false);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 2.0);
    }

    #[test]
    fn test_beta_confidence_interval() {
        // Small parameters: bounds stay inside [0, 1] and are ordered
        let (low, high) = valid(2.0, 3.0).confidence_interval();
        assert!(low >= 0.0);
        assert!(high <= 1.0);
        assert!(low < high);

        // Large parameters (normal approximation): the mean lies strictly inside
        let large = valid(50.0, 40.0);
        let (low, high) = large.confidence_interval();
        let centre = large.mean();
        assert!(low < centre);
        assert!(centre < high);
    }

    #[test]
    fn test_beta_distribution_convergence() {
        // The empirical mean of many draws should approach the analytic mean
        let dist = valid(3.0, 7.0);
        let want = dist.mean();
        let mut rng = thread_rng();

        let n_samples = 10000;
        let draws: Vec<f64> = (0..n_samples).map(|_| dist.sample(&mut rng)).collect();
        let got = draws.iter().sum::<f64>() / n_samples as f64;

        // Absolute tolerance of 0.01
        assert!((got - want).abs() < 0.01);
    }

    #[test]
    fn test_mode_edge_cases() {
        // (alpha, beta, expected mode)
        let cases = [
            (1.0, 1.0, Some(0.5)), // uniform
            (0.5, 2.0, Some(0.0)), // mode at 0
            (2.0, 0.5, Some(1.0)), // mode at 1
            (0.5, 0.5, None),      // bimodal, no single mode
        ];
        for (alpha, beta, expected) in cases {
            assert_eq!(valid(alpha, beta).mode(), expected);
        }
    }
}