// waremax_core/rng.rs

//! Seeded random number generator for deterministic simulation.

use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Exp, Gamma, Poisson};

7/// Seeded RNG wrapper for deterministic simulation
8#[derive(Debug)]
9pub struct SimRng {
10    rng: ChaCha8Rng,
11}
12
13impl SimRng {
14    /// Create a new RNG with the given seed
15    pub fn new(seed: u64) -> Self {
16        Self {
17            rng: ChaCha8Rng::seed_from_u64(seed),
18        }
19    }
20
21    /// Generate a random value in the given range
22    pub fn gen_range<T, R>(&mut self, range: R) -> T
23    where
24        T: rand::distributions::uniform::SampleUniform,
25        R: rand::distributions::uniform::SampleRange<T>,
26    {
27        self.rng.gen_range(range)
28    }
29
30    /// Generate a random f64 in [0, 1)
31    pub fn gen_f64(&mut self) -> f64 {
32        self.rng.gen()
33    }
34
35    /// Generate a random bool with the given probability of true
36    pub fn gen_bool(&mut self, p: f64) -> bool {
37        self.rng.gen_bool(p)
38    }
39
40    /// Generate an exponential random variable
41    ///
42    /// Used for inter-arrival times in Poisson processes.
43    /// Mean = 1/rate
44    pub fn exponential(&mut self, rate: f64) -> f64 {
45        if rate <= 0.0 {
46            return f64::INFINITY;
47        }
48        let exp = Exp::new(rate).unwrap();
49        exp.sample(&mut self.rng)
50    }
51
52    /// Generate a Poisson random variable
53    ///
54    /// Returns the number of events in a unit interval given the rate.
55    pub fn poisson(&mut self, lambda: f64) -> u32 {
56        if lambda <= 0.0 {
57            return 0;
58        }
59        let pois = Poisson::new(lambda).unwrap();
60        pois.sample(&mut self.rng) as u32
61    }
62
63    /// Generate from a negative binomial distribution
64    ///
65    /// Used for order line counts. Returns at least 1.
66    /// Uses Gamma-Poisson mixture.
67    pub fn negbin(&mut self, mean: f64, dispersion: f64) -> u32 {
68        if mean <= 0.0 || dispersion <= 0.0 {
69            return 1;
70        }
71
72        // Negative binomial via Gamma-Poisson mixture
73        // r = dispersion, p = r / (r + mean)
74        let r = dispersion;
75        let p = r / (r + mean);
76        let gamma_shape = r;
77        let gamma_scale = (1.0 - p) / p;
78
79        let gamma = Gamma::new(gamma_shape, gamma_scale).unwrap();
80        let lambda = gamma.sample(&mut self.rng);
81
82        if lambda <= 0.0 {
83            return 1;
84        }
85
86        let pois = Poisson::new(lambda).unwrap();
87        let value = pois.sample(&mut self.rng) as u32;
88
89        // Return at least 1
90        value.max(1)
91    }
92
93    /// Generate from a Zipf distribution
94    ///
95    /// Used for SKU popularity. Returns an index in [0, n).
96    pub fn zipf(&mut self, n: usize, alpha: f64) -> usize {
97        if n == 0 {
98            return 0;
99        }
100        if n == 1 {
101            return 0;
102        }
103
104        // Calculate normalization constant (Hurwitz zeta approximation)
105        let mut h_sum = 0.0;
106        for k in 1..=n {
107            h_sum += 1.0 / (k as f64).powf(alpha);
108        }
109
110        // Generate uniform and find corresponding rank
111        let u: f64 = self.rng.gen();
112        let target = u * h_sum;
113
114        let mut cumsum = 0.0;
115        for k in 1..=n {
116            cumsum += 1.0 / (k as f64).powf(alpha);
117            if cumsum >= target {
118                return k - 1; // Return 0-indexed
119            }
120        }
121
122        n - 1
123    }
124
125    /// Generate a normal random variable
126    pub fn normal(&mut self, mean: f64, stddev: f64) -> f64 {
127        use rand_distr::Normal;
128        let normal = Normal::new(mean, stddev).unwrap();
129        normal.sample(&mut self.rng)
130    }
131
132    /// Generate a lognormal random variable
133    pub fn lognormal(&mut self, mean: f64, stddev: f64) -> f64 {
134        use rand_distr::LogNormal;
135        // Convert mean/stddev to mu/sigma for lognormal
136        let variance = stddev * stddev;
137        let mu = (mean * mean / (mean * mean + variance).sqrt()).ln();
138        let sigma = (1.0 + variance / (mean * mean)).ln().sqrt();
139        let lognormal = LogNormal::new(mu, sigma).unwrap();
140        lognormal.sample(&mut self.rng)
141    }
142
143    /// Generate a uniform random variable
144    pub fn uniform(&mut self, min: f64, max: f64) -> f64 {
145        self.rng.gen_range(min..max)
146    }
147
148    /// Choose a random element from a slice
149    pub fn choose<'a, T>(&mut self, slice: &'a [T]) -> Option<&'a T> {
150        if slice.is_empty() {
151            None
152        } else {
153            let idx = self.gen_range(0..slice.len());
154            Some(&slice[idx])
155        }
156    }
157
158    /// Shuffle a slice in place
159    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
160        use rand::seq::SliceRandom;
161        slice.shuffle(&mut self.rng);
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_determinism() {
        // Two generators built from the same seed must agree draw for draw.
        let mut left = SimRng::new(42);
        let mut right = SimRng::new(42);

        let left_draws: Vec<f64> = (0..100).map(|_| left.gen_f64()).collect();
        let right_draws: Vec<f64> = (0..100).map(|_| right.gen_f64()).collect();
        assert_eq!(left_draws, right_draws);
    }

    #[test]
    fn test_exponential() {
        const SAMPLES: usize = 10_000;
        let mut rng = SimRng::new(42);
        let rate = 2.0;

        // The sample mean of Exp(rate) should land near 1 / rate.
        let total: f64 = (0..SAMPLES).map(|_| rng.exponential(rate)).sum();
        let sample_mean = total / SAMPLES as f64;

        assert!((sample_mean - 1.0 / rate).abs() < 0.1);
    }

    #[test]
    fn test_negbin() {
        let mut rng = SimRng::new(42);

        // negbin clamps its output, so no draw may ever be below 1.
        let all_positive = (0..100).map(|_| rng.negbin(2.2, 1.3)).all(|v| v >= 1);
        assert!(all_positive);
    }

    #[test]
    fn test_zipf() {
        const N: usize = 100;
        let mut rng = SimRng::new(42);
        let mut counts = [0u32; N];

        for _ in 0..10_000 {
            let idx = rng.zipf(N, 1.0);
            assert!(idx < N);
            counts[idx] += 1;
        }

        // With alpha = 1 the head of the distribution must outnumber the tail.
        assert!(counts[0] > counts[N - 1]);
    }
}