1use ::ndarray::{Array, Dimension, Ix2, IxDyn};
7use rand::rngs::StdRng;
8use rand::{Rng, SeedableRng};
9use rand_distr::{Distribution, Uniform};
10use std::cell::RefCell;
11use std::convert::Infallible;
12
13#[derive(Debug)]
19pub struct Random<R = rand::rngs::ThreadRng> {
20 pub(crate) rng: R,
21}
22
23impl Default for Random<rand::rngs::ThreadRng> {
24 fn default() -> Self {
25 Random { rng: rand::rng() }
26 }
27}
28
29impl Random<StdRng> {
30 pub fn seed(seed: u64) -> Random<StdRng> {
35 Random {
36 rng: StdRng::seed_from_u64(seed),
37 }
38 }
39}
40
41impl SeedableRng for Random<StdRng> {
43 type Seed = <StdRng as SeedableRng>::Seed;
44
45 fn from_seed(seed: Self::Seed) -> Self {
46 Random {
47 rng: StdRng::from_seed(seed),
48 }
49 }
50
51 fn seed_from_u64(state: u64) -> Self {
52 Random {
53 rng: StdRng::seed_from_u64(state),
54 }
55 }
56}
57
58pub fn seeded_rng(seed: u64) -> Random<StdRng> {
62 Random::seed_from_u64(seed)
63}
64
65pub fn thread_rng() -> Random<rand::rngs::ThreadRng> {
69 Random::default()
70}
71
72impl<R: Rng> rand::TryRng for Random<R> {
76 type Error = Infallible;
77
78 fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
79 Ok(self.rng.next_u32())
80 }
81
82 fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
83 Ok(self.rng.next_u64())
84 }
85
86 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
87 self.rng.fill_bytes(dest);
88 Ok(())
89 }
90}
91
92impl<R: Rng> Random<R> {
94 pub fn sample<T, D: Distribution<T>>(&mut self, distribution: D) -> T {
96 distribution.sample(&mut self.rng)
97 }
98
99 pub fn random_range<T, B>(&mut self, range: B) -> T
101 where
102 T: rand_distr::uniform::SampleUniform,
103 B: rand_distr::uniform::SampleRange<T>,
104 {
105 rand::RngExt::random_range(&mut self.rng, range)
106 }
107
108 pub fn random_bool(&mut self, p: f64) -> bool {
110 rand::RngExt::random_bool(&mut self.rng, p)
111 }
112
113 pub fn random<T>(&mut self) -> T
115 where
116 rand_distr::StandardUniform: rand_distr::Distribution<T>,
117 {
118 rand::RngExt::random(&mut self.rng)
119 }
120
121 pub fn gen_range<T, B>(&mut self, range: B) -> T
123 where
124 T: rand_distr::uniform::SampleUniform,
125 B: rand_distr::uniform::SampleRange<T>,
126 {
127 self.random_range(range)
128 }
129
130 pub fn gen_bool(&mut self, p: f64) -> bool {
132 self.random_bool(p)
133 }
134
135 pub fn fill<T>(&mut self, slice: &mut [T])
137 where
138 rand_distr::StandardUniform: rand_distr::Distribution<T>,
139 {
140 for item in slice.iter_mut() {
141 *item = rand::RngExt::random(&mut self.rng);
142 }
143 }
144
145 pub fn sample_vec<T, D>(&mut self, distribution: D, size: usize) -> Vec<T>
147 where
148 D: Distribution<T> + Copy,
149 {
150 (0..size).map(|_| self.sample(distribution)).collect()
151 }
152
153 pub fn sample_array<T, Dim, D>(&mut self, shape: Dim, distribution: D) -> Array<T, Dim>
155 where
156 Dim: Dimension,
157 D: Distribution<T> + Copy,
158 {
159 let size = shape.size();
160 let values: Vec<T> = (0..size).map(|_| self.sample(distribution)).collect();
161 Array::from_shape_vec(shape, values).expect("Operation failed")
162 }
163
164 pub fn rng_mut(&mut self) -> &mut R {
166 &mut self.rng
167 }
168
169 pub fn rng(&self) -> &R {
171 &self.rng
172 }
173}
174
175pub trait DistributionExt<T>: Distribution<T> + Sized {
180 fn random_array<R: Rng, Dim: Dimension>(&self, rng: &mut Random<R>, shape: Dim) -> Array<T, Dim>
182 where
183 Self: Copy,
184 {
185 rng.sample_array(shape, *self)
186 }
187
188 fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, size: usize) -> Vec<T>
190 where
191 Self: Copy,
192 {
193 rng.sample_vec(*self, size)
194 }
195}
196
197impl<D, T> DistributionExt<T> for D where D: Distribution<T> {}
199
200thread_local! {
201 static THREAD_RNG: RefCell<Random> = RefCell::new(Random::default());
202}
203
204#[allow(dead_code)]
206pub fn get_rng<F, R>(f: F) -> R
207where
208 F: FnOnce(&mut Random) -> R,
209{
210 THREAD_RNG.with(|rng| f(&mut rng.borrow_mut()))
211}
212
213pub mod scientific {
215 use super::*;
216
217 pub struct ReproducibleSequence {
219 seed: u64,
220 sequence_id: u64,
221 }
222
223 impl ReproducibleSequence {
224 pub fn new(seed: u64) -> Self {
226 Self {
227 seed,
228 sequence_id: 0,
229 }
230 }
231
232 pub fn next_rng(&mut self) -> Random<StdRng> {
234 let combined_seed = self.seed.wrapping_mul(31).wrapping_add(self.sequence_id);
235 self.sequence_id += 1;
236 Random::seed(combined_seed)
237 }
238
239 pub fn reset(&mut self) {
241 self.sequence_id = 0;
242 }
243
244 pub fn position(&self) -> u64 {
246 self.sequence_id
247 }
248 }
249
250 #[derive(Debug, Clone)]
252 pub struct DeterministicState {
253 pub seed: u64,
254 pub call_count: u64,
255 }
256
257 impl DeterministicState {
258 pub fn new(seed: u64) -> Self {
260 Self {
261 seed,
262 call_count: 0,
263 }
264 }
265
266 pub fn next_rng(&mut self) -> Random<StdRng> {
268 let rng_seed = self.seed.wrapping_mul(31).wrapping_add(self.call_count);
269 self.call_count += 1;
270 Random::seed(rng_seed)
271 }
272
273 pub fn current_state(&self) -> (u64, u64) {
275 (self.seed, self.call_count)
276 }
277
278 pub fn position(&self) -> u64 {
280 self.call_count
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed generator can sample, and the sample lies in the
    /// requested half-open range.
    #[test]
    fn test_random_creation() {
        let mut rng = Random::default();
        let val = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        assert!((0.0..1.0).contains(&val));
    }

    /// Two generators built from the same seed produce the same first sample.
    #[test]
    fn test_seeded_rng() {
        let mut rng1 = seeded_rng(42);
        let mut rng2 = seeded_rng(42);

        let val1 = rng1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2 = rng2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        // Exact equality is intended: both streams must be bit-identical.
        assert_eq!(val1, val2);
    }

    /// The thread-local wrapper samples within the requested range.
    #[test]
    fn test_thread_rng() {
        let mut rng = thread_rng();
        let val = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        assert!((0.0..1.0).contains(&val));
    }

    /// Sequences with the same base seed agree position by position, while
    /// consecutive positions within one sequence differ.
    #[test]
    fn test_reproducible_sequence() {
        let mut seq1 = scientific::ReproducibleSequence::new(123);
        let mut seq2 = scientific::ReproducibleSequence::new(123);

        let mut rng1_1 = seq1.next_rng();
        let mut rng1_2 = seq1.next_rng();

        let mut rng2_1 = seq2.next_rng();
        let mut rng2_2 = seq2.next_rng();

        let val1_1 = rng1_1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val1_2 = rng1_2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        let val2_1 = rng2_1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2_2 = rng2_2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1_1, val2_1);
        assert_eq!(val1_2, val2_2);
        assert_ne!(val1_1, val1_2);
    }

    /// States with the same seed hand out matching generators and advance in
    /// lock-step.
    #[test]
    fn test_deterministic_state() {
        let mut state1 = scientific::DeterministicState::new(456);
        let mut state2 = scientific::DeterministicState::new(456);

        let mut rng1 = state1.next_rng();
        let mut rng2 = state2.next_rng();

        let val1 = rng1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2 = rng2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1, val2);
        assert_eq!(state1.position(), state2.position());
    }

    /// `sample_array` honours the requested shape and the distribution's range.
    #[test]
    fn test_sample_array() {
        let mut rng = seeded_rng(789);
        let array =
            rng.sample_array(Ix2(3, 3), Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(array.shape(), &[3, 3]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    /// The `DistributionExt` helpers produce collections of the requested size.
    #[test]
    fn test_distribution_ext() {
        let mut rng = seeded_rng(101112);
        let distribution = Uniform::new(-1.0, 1.0).expect("Operation failed");

        let vec = distribution.sample_vec(&mut rng, 10);
        assert_eq!(vec.len(), 10);
        assert!(vec.iter().all(|&x| (-1.0..1.0).contains(&x)));

        let array = distribution.random_array(&mut rng, Ix2(2, 5));
        assert_eq!(array.shape(), &[2, 5]);
    }
}