scirs2_core/random/
core.rs1use ::ndarray::{Array, Dimension, Ix2, IxDyn};
7use rand::prelude::*;
8use rand::rngs::StdRng;
9use rand::{Rng, RngCore, SeedableRng};
10use rand_distr::{Distribution, Uniform};
11use std::cell::RefCell;
12
13#[derive(Debug)]
19pub struct Random<R = rand::rngs::ThreadRng> {
20 pub(crate) rng: R,
21}
22
23impl Default for Random<rand::rngs::ThreadRng> {
24 fn default() -> Self {
25 Random {
26 rng: rand::thread_rng(),
27 }
28 }
29}
30
31impl Random<StdRng> {
32 pub fn seed(seed: u64) -> Random<StdRng> {
37 Random {
38 rng: StdRng::seed_from_u64(seed),
39 }
40 }
41}
42
43impl SeedableRng for Random<StdRng> {
45 type Seed = <StdRng as SeedableRng>::Seed;
46
47 fn from_seed(seed: Self::Seed) -> Self {
48 Random {
49 rng: StdRng::from_seed(seed),
50 }
51 }
52
53 fn seed_from_u64(state: u64) -> Self {
54 Random {
55 rng: StdRng::seed_from_u64(state),
56 }
57 }
58}
59
60pub fn seeded_rng(seed: u64) -> Random<StdRng> {
64 Random::seed_from_u64(seed)
65}
66
67pub fn thread_rng() -> Random<rand::rngs::ThreadRng> {
71 Random::default()
72}
73
74impl<R: RngCore> RngCore for Random<R> {
76 fn next_u32(&mut self) -> u32 {
77 self.rng.next_u32()
78 }
79
80 fn next_u64(&mut self) -> u64 {
81 self.rng.next_u64()
82 }
83
84 fn fill_bytes(&mut self, dest: &mut [u8]) {
85 self.rng.fill_bytes(dest)
86 }
87}
88
89impl<R: Rng> Random<R> {
91 pub fn sample<T, D: Distribution<T>>(&mut self, distribution: D) -> T {
93 distribution.sample(&mut self.rng)
94 }
95
96 pub fn gen_range<T, B>(&mut self, range: B) -> T
98 where
99 T: rand_distr::uniform::SampleUniform,
100 B: rand_distr::uniform::SampleRange<T>,
101 {
102 self.rng.gen_range(range)
103 }
104
105 pub fn gen_bool(&mut self, p: f64) -> bool {
107 self.rng.gen_bool(p)
108 }
109
110 pub fn fill<T>(&mut self, slice: &mut [T])
112 where
113 rand_distr::StandardUniform: rand_distr::Distribution<T>,
114 {
115 for item in slice.iter_mut() {
116 *item = self.rng.gen();
117 }
118 }
119
120 pub fn sample_vec<T, D>(&mut self, distribution: D, size: usize) -> Vec<T>
122 where
123 D: Distribution<T> + Copy,
124 {
125 (0..size).map(|_| self.sample(distribution)).collect()
126 }
127
128 pub fn sample_array<T, Dim, D>(&mut self, shape: Dim, distribution: D) -> Array<T, Dim>
130 where
131 Dim: Dimension,
132 D: Distribution<T> + Copy,
133 {
134 let size = shape.size();
135 let values: Vec<T> = (0..size).map(|_| self.sample(distribution)).collect();
136 Array::from_shape_vec(shape, values).expect("Operation failed")
137 }
138
139 pub fn rng_mut(&mut self) -> &mut R {
141 &mut self.rng
142 }
143
144 pub fn rng(&self) -> &R {
146 &self.rng
147 }
148}
149
150pub trait DistributionExt<T>: Distribution<T> + Sized {
155 fn random_array<R: Rng, Dim: Dimension>(&self, rng: &mut Random<R>, shape: Dim) -> Array<T, Dim>
157 where
158 Self: Copy,
159 {
160 rng.sample_array(shape, *self)
161 }
162
163 fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, size: usize) -> Vec<T>
165 where
166 Self: Copy,
167 {
168 rng.sample_vec(*self, size)
169 }
170}
171
172impl<D, T> DistributionExt<T> for D where D: Distribution<T> {}
174
175thread_local! {
176 static THREAD_RNG: RefCell<Random> = RefCell::new(Random::default());
177}
178
179#[allow(dead_code)]
181pub fn get_rng<F, R>(f: F) -> R
182where
183 F: FnOnce(&mut Random) -> R,
184{
185 THREAD_RNG.with(|rng| f(&mut rng.borrow_mut()))
186}
187
188pub mod scientific {
190 use super::*;
191
192 pub struct ReproducibleSequence {
194 seed: u64,
195 sequence_id: u64,
196 }
197
198 impl ReproducibleSequence {
199 pub fn new(seed: u64) -> Self {
201 Self {
202 seed,
203 sequence_id: 0,
204 }
205 }
206
207 pub fn next_rng(&mut self) -> Random<StdRng> {
209 let combined_seed = self.seed.wrapping_mul(31).wrapping_add(self.sequence_id);
210 self.sequence_id += 1;
211 Random::seed(combined_seed)
212 }
213
214 pub fn reset(&mut self) {
216 self.sequence_id = 0;
217 }
218
219 pub fn position(&self) -> u64 {
221 self.sequence_id
222 }
223 }
224
225 #[derive(Debug, Clone)]
227 pub struct DeterministicState {
228 pub seed: u64,
229 pub call_count: u64,
230 }
231
232 impl DeterministicState {
233 pub fn new(seed: u64) -> Self {
235 Self {
236 seed,
237 call_count: 0,
238 }
239 }
240
241 pub fn next_rng(&mut self) -> Random<StdRng> {
243 let rng_seed = self.seed.wrapping_mul(31).wrapping_add(self.call_count);
244 self.call_count += 1;
245 Random::seed(rng_seed)
246 }
247
248 pub fn current_state(&self) -> (u64, u64) {
250 (self.seed, self.call_count)
251 }
252
253 pub fn position(&self) -> u64 {
255 self.call_count
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_creation() {
        let mut rng = Random::default();
        let val = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_seeded_rng() {
        let mut rng1 = seeded_rng(42);
        let mut rng2 = seeded_rng(42);

        let val1 = rng1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2 = rng2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1, val2);
    }

    #[test]
    fn test_thread_rng() {
        let mut rng = thread_rng();
        let val = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_get_rng() {
        assert!(get_rng(|rng| rng.gen_bool(1.0)));
        assert!(!get_rng(|rng| rng.gen_bool(0.0)));
    }

    #[test]
    fn test_gen_range_and_gen_bool() {
        let mut rng1 = seeded_rng(7);
        let mut rng2 = seeded_rng(7);

        let a: Vec<i32> = (0..16).map(|_| rng1.gen_range(0..10)).collect();
        let b: Vec<i32> = (0..16).map(|_| rng2.gen_range(0..10)).collect();
        assert_eq!(a, b);
        assert!(a.iter().all(|x| (0..10).contains(x)));

        assert!(!rng1.gen_bool(0.0));
        assert!(rng1.gen_bool(1.0));
    }

    #[test]
    fn test_fill_is_reproducible() {
        let mut a = [0u64; 8];
        let mut b = [0u64; 8];
        seeded_rng(11).fill(&mut a[..]);
        seeded_rng(11).fill(&mut b[..]);
        assert_eq!(a, b);
        assert!(a.iter().any(|&x| x != 0));
    }

    #[test]
    fn test_reproducible_sequence() {
        let mut seq1 = scientific::ReproducibleSequence::new(123);
        let mut seq2 = scientific::ReproducibleSequence::new(123);

        let mut rng1_1 = seq1.next_rng();
        let mut rng1_2 = seq1.next_rng();

        let mut rng2_1 = seq2.next_rng();
        let mut rng2_2 = seq2.next_rng();

        let val1_1 = rng1_1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val1_2 = rng1_2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        let val2_1 = rng2_1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2_2 = rng2_2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1_1, val2_1);
        assert_eq!(val1_2, val2_2);
        assert_ne!(val1_1, val1_2);
    }

    #[test]
    fn test_reproducible_sequence_reset() {
        let mut seq = scientific::ReproducibleSequence::new(5);
        let first = seq
            .next_rng()
            .sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let _ = seq.next_rng();
        assert_eq!(seq.position(), 2);

        seq.reset();
        assert_eq!(seq.position(), 0);
        let again = seq
            .next_rng()
            .sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        assert_eq!(first, again);
    }

    #[test]
    fn test_deterministic_state() {
        let mut state1 = scientific::DeterministicState::new(456);
        let mut state2 = scientific::DeterministicState::new(456);

        let mut rng1 = state1.next_rng();
        let mut rng2 = state2.next_rng();

        let val1 = rng1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2 = rng2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1, val2);
        assert_eq!(state1.position(), state2.position());
        assert_eq!(state1.current_state(), (456, 1));
    }

    #[test]
    fn test_sample_array() {
        let mut rng = seeded_rng(789);
        let array = rng.sample_array(Ix2(3, 3), Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(array.shape(), &[3, 3]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_distribution_ext() {
        let mut rng = seeded_rng(101112);
        let distribution = Uniform::new(-1.0, 1.0).expect("Operation failed");

        let vec = distribution.sample_vec(&mut rng, 10);
        assert_eq!(vec.len(), 10);
        assert!(vec.iter().all(|&x| (-1.0..1.0).contains(&x)));

        let array = distribution.random_array(&mut rng, Ix2(2, 5));
        assert_eq!(array.shape(), &[2, 5]);
    }
}
347}