saorsa_core/adaptive/
beta_distribution.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: saorsalabs@gmail.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! # Beta Distribution Implementation
15//!
16//! This module provides a proper Beta distribution implementation for Thompson Sampling
17//! in the Multi-Armed Bandit routing optimization system.
18//!
19//! ## Features
//! - Beta sampling as a ratio of two Gamma variates (Marsaglia–Tsang rejection sampler)
//! - Closed-form fast paths for the uniform case and for α = 1 or β = 1
//! - Parameter validation and bounds checking
//! - Sampling through any caller-supplied `rand::Rng`, so the caller chooses the RNG (e.g. a thread-local one)
24
25use rand::Rng;
26
/// Parameters of a Beta(α, β) distribution.
///
/// Both fields are public, so the positivity/finiteness checks performed by
/// `BetaDistribution::new` are only guaranteed for values built through that
/// constructor and not modified afterwards.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BetaDistribution {
    /// Alpha parameter (successes + 1)
    pub alpha: f64,
    /// Beta parameter (failures + 1)
    pub beta: f64,
}
35
36impl BetaDistribution {
37    /// Create a new Beta distribution
38    pub fn new(alpha: f64, beta: f64) -> Result<Self, BetaError> {
39        if alpha <= 0.0 || beta <= 0.0 {
40            return Err(BetaError::InvalidParameters {
41                alpha,
42                beta,
43                reason: "Alpha and beta must be positive".to_string(),
44            });
45        }
46
47        if !alpha.is_finite() || !beta.is_finite() {
48            return Err(BetaError::InvalidParameters {
49                alpha,
50                beta,
51                reason: "Parameters must be finite".to_string(),
52            });
53        }
54
55        Ok(Self { alpha, beta })
56    }
57
58    /// Sample from the Beta distribution
59    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
60        // Special cases for efficiency
61        if self.alpha == 1.0 && self.beta == 1.0 {
62            // Uniform distribution
63            return rng.r#gen::<f64>();
64        }
65
66        if self.alpha == 1.0 {
67            // Beta(1, β) = 1 - U^(1/β) where U ~ Uniform(0,1)
68            let u: f64 = rng.r#gen::<f64>();
69            return 1.0 - u.powf(1.0 / self.beta);
70        }
71
72        if self.beta == 1.0 {
73            // Beta(α, 1) = U^(1/α) where U ~ Uniform(0,1)
74            let u: f64 = rng.r#gen::<f64>();
75            return u.powf(1.0 / self.alpha);
76        }
77
78        // General case: use Gamma distribution method
79        // Beta(α, β) = Gamma(α) / (Gamma(α) + Gamma(β))
80        let gamma_alpha = sample_gamma(self.alpha, rng);
81        let gamma_beta = sample_gamma(self.beta, rng);
82
83        gamma_alpha / (gamma_alpha + gamma_beta)
84    }
85
86    /// Get the mean of the distribution
87    pub fn mean(&self) -> f64 {
88        self.alpha / (self.alpha + self.beta)
89    }
90
91    /// Get the variance of the distribution
92    pub fn variance(&self) -> f64 {
93        let sum = self.alpha + self.beta;
94        (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
95    }
96
97    /// Get the mode of the distribution (if it exists)
98    pub fn mode(&self) -> Option<f64> {
99        if self.alpha > 1.0 && self.beta > 1.0 {
100            Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
101        } else if self.alpha == 1.0 && self.beta == 1.0 {
102            // Uniform distribution, any value in [0,1] is a mode
103            Some(0.5)
104        } else if self.alpha < 1.0 && self.beta < 1.0 {
105            // Bimodal at 0 and 1
106            None
107        } else if self.alpha < 1.0 {
108            Some(0.0)
109        } else if self.beta < 1.0 {
110            Some(1.0)
111        } else {
112            None
113        }
114    }
115
116    /// Update parameters based on success/failure
117    pub fn update(&mut self, success: bool) {
118        if success {
119            self.alpha += 1.0;
120        } else {
121            self.beta += 1.0;
122        }
123    }
124
125    /// Get the 95% confidence interval
126    pub fn confidence_interval(&self) -> (f64, f64) {
127        // Use normal approximation for large parameters
128        if self.alpha > 30.0 && self.beta > 30.0 {
129            let mean = self.mean();
130            let std_dev = self.variance().sqrt();
131            let margin = 1.96 * std_dev; // 95% CI
132            ((mean - margin).max(0.0), (mean + margin).min(1.0))
133        } else {
134            // For small parameters, use quantile approximation
135            // This is a simplified version; for exact quantiles use incomplete beta function
136            let lower = self.alpha / (self.alpha + self.beta + 2.0);
137            let upper = (self.alpha + 1.0) / (self.alpha + self.beta + 1.0);
138            (lower, upper)
139        }
140    }
141}
142
143/// Sample from Gamma distribution using Marsaglia and Tsang's method
144#[allow(clippy::many_single_char_names)]
145fn sample_gamma<R: Rng>(shape: f64, rng: &mut R) -> f64 {
146    if shape < 1.0 {
147        // Use Johnk's algorithm for shape < 1
148        let u: f64 = rng.r#gen::<f64>();
149        sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape)
150    } else {
151        // Marsaglia and Tsang's method for shape >= 1
152        let d = shape - 1.0 / 3.0;
153        let c = 1.0 / (9.0 * d).sqrt();
154
155        loop {
156            let mut x;
157            let mut v;
158
159            loop {
160                x = rng.gen_range(-1.0..1.0); // Standard normal approximation
161                v = 1.0 + c * x;
162                if v > 0.0 {
163                    break;
164                }
165            }
166
167            v = v * v * v;
168            let u: f64 = rng.r#gen::<f64>();
169
170            if u < 1.0 - 0.0331 * x * x * x * x {
171                return d * v;
172            }
173
174            if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
175                return d * v;
176            }
177        }
178    }
179}
180
/// Errors that can occur when working with a Beta distribution.
#[derive(Debug, Clone, PartialEq)]
pub enum BetaError {
    /// The supplied parameters do not describe a valid Beta distribution.
    InvalidParameters {
        /// The alpha value that was supplied.
        alpha: f64,
        /// The beta value that was supplied.
        beta: f64,
        /// Human-readable explanation of why the pair was rejected.
        reason: String,
    },
}
191
192impl std::fmt::Display for BetaError {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        match self {
195            BetaError::InvalidParameters {
196                alpha,
197                beta,
198                reason,
199            } => {
200                write!(
201                    f,
202                    "Invalid Beta parameters (α={}, β={}): {}",
203                    alpha, beta, reason
204                )
205            }
206        }
207    }
208}
209
// Empty impl: `BetaError` carries no underlying source error, so the trait's
// default methods are used as-is.
impl std::error::Error for BetaError {}
211
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    /// True when `value` lies in the closed unit interval (false for NaN).
    fn in_unit_interval(value: f64) -> bool {
        (0.0..=1.0).contains(&value)
    }

    #[test]
    fn test_beta_distribution_creation() {
        // A well-formed parameter pair is accepted.
        assert!(BetaDistribution::new(2.0, 3.0).is_ok());

        // Zero, negative, infinite and NaN parameters are all rejected.
        let rejected = [
            (0.0, 1.0),
            (1.0, -1.0),
            (f64::INFINITY, 1.0),
            (1.0, f64::NAN),
        ];
        for (alpha, beta) in rejected {
            assert!(BetaDistribution::new(alpha, beta).is_err());
        }
    }

    #[test]
    fn test_beta_distribution_sampling() {
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();
        let mut rng = thread_rng();

        // Every draw must be a finite number inside [0, 1].
        for _ in 0..1000 {
            let drawn = dist.sample(&mut rng);
            assert!(drawn >= 0.0);
            assert!(drawn <= 1.0);
            assert!(drawn.is_finite());
        }
    }

    #[test]
    fn test_beta_distribution_special_cases() {
        let mut rng = thread_rng();

        // Beta(1, 1) is uniform, so the empirical mean should sit near 0.5.
        let uniform = BetaDistribution::new(1.0, 1.0).unwrap();
        let draws = 1000;
        let total: f64 = (0..draws).map(|_| uniform.sample(&mut rng)).sum();
        let mean = total / draws as f64;
        assert!((mean - 0.5).abs() < 0.05);

        // Beta(1, β) fast path stays inside the unit interval.
        let beta_1_b = BetaDistribution::new(1.0, 3.0).unwrap();
        for _ in 0..100 {
            assert!(in_unit_interval(beta_1_b.sample(&mut rng)));
        }

        // Beta(α, 1) fast path stays inside the unit interval.
        let beta_a_1 = BetaDistribution::new(3.0, 1.0).unwrap();
        for _ in 0..100 {
            assert!(in_unit_interval(beta_a_1.sample(&mut rng)));
        }
    }

    #[test]
    fn test_beta_distribution_moments() {
        let dist = BetaDistribution::new(2.0, 5.0).unwrap();

        // Mean: α / (α + β)
        assert_eq!(dist.mean(), 2.0 / 7.0);

        // Variance: αβ / ((α + β)² (α + β + 1))
        let expected_variance = (2.0 * 5.0) / (7.0 * 7.0 * 8.0);
        let variance_error = (dist.variance() - expected_variance).abs();
        assert!(variance_error < 1e-10);

        // Mode: (α - 1) / (α + β - 2) = 1/5
        assert_eq!(dist.mode().unwrap(), 1.0 / 5.0);
    }

    #[test]
    fn test_beta_parameter_updates() {
        let mut dist = BetaDistribution::new(1.0, 1.0).unwrap();

        // A success bumps alpha only.
        dist.update(true);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 1.0);

        // A failure bumps beta only.
        dist.update(false);
        assert_eq!(dist.alpha, 2.0);
        assert_eq!(dist.beta, 2.0);
    }

    #[test]
    fn test_beta_confidence_interval() {
        // Small parameters take the heuristic branch.
        let (lower, upper) = BetaDistribution::new(2.0, 3.0)
            .unwrap()
            .confidence_interval();
        assert!(lower >= 0.0);
        assert!(upper <= 1.0);
        assert!(lower < upper);

        // Large parameters take the normal-approximation branch, which must
        // bracket the mean.
        let dist_large = BetaDistribution::new(50.0, 40.0).unwrap();
        let mean = dist_large.mean();
        let (lower, upper) = dist_large.confidence_interval();
        assert!(lower < mean);
        assert!(mean < upper);
    }

    #[test]
    fn test_beta_distribution_convergence() {
        // The empirical mean of many draws should approach the analytic mean.
        let dist = BetaDistribution::new(3.0, 7.0).unwrap();
        let expected_mean = dist.mean();
        let mut rng = thread_rng();

        let n_samples = 10000;
        let total: f64 = (0..n_samples).map(|_| dist.sample(&mut rng)).sum();
        let sample_mean = total / n_samples as f64;

        // Absolute tolerance of 0.01 around the expected mean.
        assert!((sample_mean - expected_mean).abs() < 0.01);
    }

    #[test]
    fn test_mode_edge_cases() {
        // (alpha, beta, expected mode)
        let cases = [
            (1.0, 1.0, Some(0.5)), // uniform
            (0.5, 2.0, Some(0.0)), // mode at 0
            (2.0, 0.5, Some(1.0)), // mode at 1
            (0.5, 0.5, None),      // bimodal, no single mode
        ];

        for (alpha, beta, expected) in cases {
            let dist = BetaDistribution::new(alpha, beta).unwrap();
            assert_eq!(dist.mode(), expected);
        }
    }
}