1use ::ndarray::{Array, Dimension, Ix2, IxDyn};
7use rand::rngs::StdRng;
8use rand::{Rng, RngCore, SeedableRng};
9use rand_distr::{Distribution, Uniform};
10use std::cell::RefCell;
11
12#[derive(Debug)]
18pub struct Random<R = rand::rngs::ThreadRng> {
19 pub(crate) rng: R,
20}
21
22impl Default for Random<rand::rngs::ThreadRng> {
23 fn default() -> Self {
24 Random { rng: rand::rng() }
25 }
26}
27
28impl Random<StdRng> {
29 pub fn seed(seed: u64) -> Random<StdRng> {
34 Random {
35 rng: StdRng::seed_from_u64(seed),
36 }
37 }
38}
39
40impl SeedableRng for Random<StdRng> {
42 type Seed = <StdRng as SeedableRng>::Seed;
43
44 fn from_seed(seed: Self::Seed) -> Self {
45 Random {
46 rng: StdRng::from_seed(seed),
47 }
48 }
49
50 fn seed_from_u64(state: u64) -> Self {
51 Random {
52 rng: StdRng::seed_from_u64(state),
53 }
54 }
55}
56
57pub fn seeded_rng(seed: u64) -> Random<StdRng> {
61 Random::seed_from_u64(seed)
62}
63
64pub fn thread_rng() -> Random<rand::rngs::ThreadRng> {
68 Random::default()
69}
70
71impl<R: RngCore> RngCore for Random<R> {
73 fn next_u32(&mut self) -> u32 {
74 self.rng.next_u32()
75 }
76
77 fn next_u64(&mut self) -> u64 {
78 self.rng.next_u64()
79 }
80
81 fn fill_bytes(&mut self, dest: &mut [u8]) {
82 self.rng.fill_bytes(dest)
83 }
84}
85
86impl<R: Rng> Random<R> {
88 pub fn sample<T, D: Distribution<T>>(&mut self, distribution: D) -> T {
90 distribution.sample(&mut self.rng)
91 }
92
93 pub fn random_range<T, B>(&mut self, range: B) -> T
95 where
96 T: rand_distr::uniform::SampleUniform,
97 B: rand_distr::uniform::SampleRange<T>,
98 {
99 rand::Rng::random_range(&mut self.rng, range)
100 }
101
102 pub fn random_bool(&mut self, p: f64) -> bool {
104 rand::Rng::random_bool(&mut self.rng, p)
105 }
106
107 pub fn random<T>(&mut self) -> T
109 where
110 rand_distr::StandardUniform: rand_distr::Distribution<T>,
111 {
112 rand::Rng::random(&mut self.rng)
113 }
114
115 pub fn gen_range<T, B>(&mut self, range: B) -> T
117 where
118 T: rand_distr::uniform::SampleUniform,
119 B: rand_distr::uniform::SampleRange<T>,
120 {
121 self.random_range(range)
122 }
123
124 pub fn gen_bool(&mut self, p: f64) -> bool {
126 self.random_bool(p)
127 }
128
129 pub fn fill<T>(&mut self, slice: &mut [T])
131 where
132 rand_distr::StandardUniform: rand_distr::Distribution<T>,
133 {
134 for item in slice.iter_mut() {
135 *item = rand::Rng::random(&mut self.rng);
136 }
137 }
138
139 pub fn sample_vec<T, D>(&mut self, distribution: D, size: usize) -> Vec<T>
141 where
142 D: Distribution<T> + Copy,
143 {
144 (0..size).map(|_| self.sample(distribution)).collect()
145 }
146
147 pub fn sample_array<T, Dim, D>(&mut self, shape: Dim, distribution: D) -> Array<T, Dim>
149 where
150 Dim: Dimension,
151 D: Distribution<T> + Copy,
152 {
153 let size = shape.size();
154 let values: Vec<T> = (0..size).map(|_| self.sample(distribution)).collect();
155 Array::from_shape_vec(shape, values).expect("Operation failed")
156 }
157
158 pub fn rng_mut(&mut self) -> &mut R {
160 &mut self.rng
161 }
162
163 pub fn rng(&self) -> &R {
165 &self.rng
166 }
167}
168
169pub trait DistributionExt<T>: Distribution<T> + Sized {
174 fn random_array<R: Rng, Dim: Dimension>(&self, rng: &mut Random<R>, shape: Dim) -> Array<T, Dim>
176 where
177 Self: Copy,
178 {
179 rng.sample_array(shape, *self)
180 }
181
182 fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, size: usize) -> Vec<T>
184 where
185 Self: Copy,
186 {
187 rng.sample_vec(*self, size)
188 }
189}
190
191impl<D, T> DistributionExt<T> for D where D: Distribution<T> {}
193
194thread_local! {
195 static THREAD_RNG: RefCell<Random> = RefCell::new(Random::default());
196}
197
198#[allow(dead_code)]
200pub fn get_rng<F, R>(f: F) -> R
201where
202 F: FnOnce(&mut Random) -> R,
203{
204 THREAD_RNG.with(|rng| f(&mut rng.borrow_mut()))
205}
206
207pub mod scientific {
209 use super::*;
210
211 pub struct ReproducibleSequence {
213 seed: u64,
214 sequence_id: u64,
215 }
216
217 impl ReproducibleSequence {
218 pub fn new(seed: u64) -> Self {
220 Self {
221 seed,
222 sequence_id: 0,
223 }
224 }
225
226 pub fn next_rng(&mut self) -> Random<StdRng> {
228 let combined_seed = self.seed.wrapping_mul(31).wrapping_add(self.sequence_id);
229 self.sequence_id += 1;
230 Random::seed(combined_seed)
231 }
232
233 pub fn reset(&mut self) {
235 self.sequence_id = 0;
236 }
237
238 pub fn position(&self) -> u64 {
240 self.sequence_id
241 }
242 }
243
244 #[derive(Debug, Clone)]
246 pub struct DeterministicState {
247 pub seed: u64,
248 pub call_count: u64,
249 }
250
251 impl DeterministicState {
252 pub fn new(seed: u64) -> Self {
254 Self {
255 seed,
256 call_count: 0,
257 }
258 }
259
260 pub fn next_rng(&mut self) -> Random<StdRng> {
262 let rng_seed = self.seed.wrapping_mul(31).wrapping_add(self.call_count);
263 self.call_count += 1;
264 Random::seed(rng_seed)
265 }
266
267 pub fn current_state(&self) -> (u64, u64) {
269 (self.seed, self.call_count)
270 }
271
272 pub fn position(&self) -> u64 {
274 self.call_count
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The `[0, 1)` uniform distribution used by most tests.
    fn unit_uniform() -> Uniform<f64> {
        Uniform::new(0.0, 1.0).expect("Operation failed")
    }

    #[test]
    fn test_random_creation() {
        let mut rng = Random::default();
        let val: f64 = rng.sample(unit_uniform());
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_seeded_rng() {
        let mut rng1 = seeded_rng(42);
        let mut rng2 = seeded_rng(42);

        let val1: f64 = rng1.sample(unit_uniform());
        let val2: f64 = rng2.sample(unit_uniform());

        assert_eq!(val1, val2);
    }

    #[test]
    fn test_thread_rng() {
        let mut rng = thread_rng();
        let val: f64 = rng.sample(unit_uniform());
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_reproducible_sequence() {
        let mut seq1 = scientific::ReproducibleSequence::new(123);
        let mut seq2 = scientific::ReproducibleSequence::new(123);

        let mut rng1_1 = seq1.next_rng();
        let mut rng1_2 = seq1.next_rng();

        let mut rng2_1 = seq2.next_rng();
        let mut rng2_2 = seq2.next_rng();

        let val1_1: f64 = rng1_1.sample(unit_uniform());
        let val1_2: f64 = rng1_2.sample(unit_uniform());

        let val2_1: f64 = rng2_1.sample(unit_uniform());
        let val2_2: f64 = rng2_2.sample(unit_uniform());

        assert_eq!(val1_1, val2_1);
        assert_eq!(val1_2, val2_2);
        assert_ne!(val1_1, val1_2);
    }

    #[test]
    fn test_reproducible_sequence_reset() {
        let mut seq = scientific::ReproducibleSequence::new(7);
        let mut first_rng = seq.next_rng();
        let first: f64 = first_rng.sample(unit_uniform());
        let _ = seq.next_rng();
        assert_eq!(seq.position(), 2);

        seq.reset();
        assert_eq!(seq.position(), 0);
        let mut replayed_rng = seq.next_rng();
        let replayed: f64 = replayed_rng.sample(unit_uniform());
        assert_eq!(first, replayed);
    }

    #[test]
    fn test_deterministic_state() {
        let mut state1 = scientific::DeterministicState::new(456);
        let mut state2 = scientific::DeterministicState::new(456);
        assert_eq!(state1.current_state(), (456, 0));

        let mut rng1 = state1.next_rng();
        let mut rng2 = state2.next_rng();

        let val1: f64 = rng1.sample(unit_uniform());
        let val2: f64 = rng2.sample(unit_uniform());

        assert_eq!(val1, val2);
        assert_eq!(state1.position(), state2.position());
        assert_eq!(state1.current_state(), (456, 1));
    }

    #[test]
    fn test_sample_array() {
        let mut rng = seeded_rng(789);
        let array = rng.sample_array(Ix2(3, 3), unit_uniform());

        assert_eq!(array.shape(), &[3, 3]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_distribution_ext() {
        let mut rng = seeded_rng(101112);
        let distribution = Uniform::new(-1.0, 1.0).expect("Operation failed");

        let vec: Vec<f64> = distribution.sample_vec(&mut rng, 10);
        assert_eq!(vec.len(), 10);
        assert!(vec.iter().all(|&x| (-1.0..1.0).contains(&x)));

        let array = distribution.random_array(&mut rng, Ix2(2, 5));
        assert_eq!(array.shape(), &[2, 5]);
        assert!(array.iter().all(|&x| (-1.0..1.0).contains(&x)));
    }

    #[test]
    fn test_sample_vec_mean() {
        let mut rng = seeded_rng(31337);
        let values: Vec<f64> = rng.sample_vec(unit_uniform(), 10_000);
        assert_eq!(values.len(), 10_000);
        let mean = values.iter().sum::<f64>() / values.len() as f64;
        // Std. error of the mean is ~0.003 here, so 0.05 is a very loose bound.
        assert_abs_diff_eq!(mean, 0.5, epsilon = 0.05);
    }

    #[test]
    fn test_random_range_and_bool() {
        let mut rng = seeded_rng(2024);
        for _ in 0..100 {
            let v: i32 = rng.random_range(0..10);
            assert!((0..10).contains(&v));
        }
        assert!(!rng.random_bool(0.0));
        assert!(rng.random_bool(1.0));

        // `gen_range` is an alias of `random_range`.
        let mut a = seeded_rng(99);
        let mut b = seeded_rng(99);
        let x: u32 = a.gen_range(0..1000);
        let y: u32 = b.random_range(0..1000);
        assert_eq!(x, y);
    }

    #[test]
    fn test_fill_is_reproducible() {
        let mut a = [0u64; 4];
        let mut b = [0u64; 4];
        let mut rng_a = seeded_rng(5);
        let mut rng_b = seeded_rng(5);
        rng_a.fill(&mut a[..]);
        rng_b.fill(&mut b[..]);
        assert_eq!(a, b);
    }
}
366}