scirs2_core/random/
core.rs

1//! Core random number generation functionality for SCIRS2 ecosystem
2//!
3//! This module provides the foundational Random struct and traits that serve as the
4//! basis for all random number generation across the SCIRS2 scientific computing ecosystem.
5
6use ::ndarray::{Array, Dimension, Ix2, IxDyn};
7use rand::prelude::*;
8use rand::rngs::StdRng;
9use rand::{Rng, RngCore, SeedableRng};
10use rand_distr::{Distribution, Uniform};
11use std::cell::RefCell;
12
13/// Enhanced random number generator for scientific computing
14///
15/// This is the core random number generator used throughout the SCIRS2 ecosystem.
16/// It provides deterministic, high-quality random number generation with support
17/// for seeding, thread-local storage, and scientific reproducibility.
18#[derive(Debug)]
19pub struct Random<R = rand::rngs::ThreadRng> {
20    pub(crate) rng: R,
21}
22
23impl Default for Random<rand::rngs::ThreadRng> {
24    fn default() -> Self {
25        Random {
26            rng: rand::thread_rng(),
27        }
28    }
29}
30
31impl Random<StdRng> {
32    /// Create a new random number generator with a specific seed
33    ///
34    /// This ensures deterministic behavior across runs, which is critical
35    /// for scientific reproducibility and testing.
36    pub fn seed(seed: u64) -> Random<StdRng> {
37        Random {
38            rng: StdRng::seed_from_u64(seed),
39        }
40    }
41}
42
43// Implement SeedableRng for Random<StdRng> to support ecosystem requirements
44impl SeedableRng for Random<StdRng> {
45    type Seed = <StdRng as SeedableRng>::Seed;
46
47    fn from_seed(seed: Self::Seed) -> Self {
48        Random {
49            rng: StdRng::from_seed(seed),
50        }
51    }
52
53    fn seed_from_u64(state: u64) -> Self {
54        Random {
55            rng: StdRng::seed_from_u64(state),
56        }
57    }
58}
59
60/// Create a seeded random number generator (convenience function)
61///
62/// This is the primary way to create deterministic RNGs across the SCIRS2 ecosystem.
63pub fn seeded_rng(seed: u64) -> Random<StdRng> {
64    Random::seed_from_u64(seed)
65}
66
67/// Get a thread-local random number generator (convenience function)
68///
69/// This provides fast access to a thread-local RNG for performance-critical code.
70pub fn thread_rng() -> Random<rand::rngs::ThreadRng> {
71    Random::default()
72}
73
74// Implement RngCore for Random to forward to inner RNG
75impl<R: RngCore> RngCore for Random<R> {
76    fn next_u32(&mut self) -> u32 {
77        self.rng.next_u32()
78    }
79
80    fn next_u64(&mut self) -> u64 {
81        self.rng.next_u64()
82    }
83
84    fn fill_bytes(&mut self, dest: &mut [u8]) {
85        self.rng.fill_bytes(dest)
86    }
87}
88
89// Implement Distribution sampling methods
90impl<R: Rng> Random<R> {
91    /// Sample from a distribution
92    pub fn sample<T, D: Distribution<T>>(&mut self, distribution: D) -> T {
93        distribution.sample(&mut self.rng)
94    }
95
96    /// Generate a random value in a range
97    pub fn gen_range<T, B>(&mut self, range: B) -> T
98    where
99        T: rand_distr::uniform::SampleUniform,
100        B: rand_distr::uniform::SampleRange<T>,
101    {
102        self.rng.gen_range(range)
103    }
104
105    /// Generate a random boolean
106    pub fn gen_bool(&mut self, p: f64) -> bool {
107        self.rng.gen_bool(p)
108    }
109
110    /// Fill a slice with random values
111    pub fn fill<T>(&mut self, slice: &mut [T])
112    where
113        rand_distr::StandardUniform: rand_distr::Distribution<T>,
114    {
115        for item in slice.iter_mut() {
116            *item = self.rng.gen();
117        }
118    }
119
120    /// Generate a vector of random values
121    pub fn sample_vec<T, D>(&mut self, distribution: D, size: usize) -> Vec<T>
122    where
123        D: Distribution<T> + Copy,
124    {
125        (0..size).map(|_| self.sample(distribution)).collect()
126    }
127
128    /// Generate a random array with specified shape and distribution
129    pub fn sample_array<T, Dim, D>(&mut self, shape: Dim, distribution: D) -> Array<T, Dim>
130    where
131        Dim: Dimension,
132        D: Distribution<T> + Copy,
133    {
134        let size = shape.size();
135        let values: Vec<T> = (0..size).map(|_| self.sample(distribution)).collect();
136        Array::from_shape_vec(shape, values).expect("Operation failed")
137    }
138
139    /// Access the underlying RNG (for advanced use cases)
140    pub fn rng_mut(&mut self) -> &mut R {
141        &mut self.rng
142    }
143
144    /// Access the underlying RNG (read-only)
145    pub fn rng(&self) -> &R {
146        &self.rng
147    }
148}
149
150/// Extension trait for distributions to create arrays directly
151///
152/// This provides a consistent interface for generating random arrays
153/// across the SCIRS2 ecosystem.
154pub trait DistributionExt<T>: Distribution<T> + Sized {
155    /// Create a random array with values from this distribution
156    fn random_array<R: Rng, Dim: Dimension>(&self, rng: &mut Random<R>, shape: Dim) -> Array<T, Dim>
157    where
158        Self: Copy,
159    {
160        rng.sample_array(shape, *self)
161    }
162
163    /// Create a random vector with values from this distribution
164    fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, size: usize) -> Vec<T>
165    where
166        Self: Copy,
167    {
168        rng.sample_vec(*self, size)
169    }
170}
171
172// Implement the extension trait for all distributions
173impl<D, T> DistributionExt<T> for D where D: Distribution<T> {}
174
175thread_local! {
176    static THREAD_RNG: RefCell<Random> = RefCell::new(Random::default());
177}
178
179/// Get a reference to the thread-local random number generator
180#[allow(dead_code)]
181pub fn get_rng<F, R>(f: F) -> R
182where
183    F: FnOnce(&mut Random) -> R,
184{
185    THREAD_RNG.with(|rng| f(&mut rng.borrow_mut()))
186}
187
188/// Scientific random number generation utilities
189pub mod scientific {
190    use super::*;
191
192    /// Generate reproducible random sequences for scientific experiments
193    pub struct ReproducibleSequence {
194        seed: u64,
195        sequence_id: u64,
196    }
197
198    impl ReproducibleSequence {
199        /// Create a new reproducible sequence
200        pub fn new(seed: u64) -> Self {
201            Self {
202                seed,
203                sequence_id: 0,
204            }
205        }
206
207        /// Get the next RNG in the sequence
208        pub fn next_rng(&mut self) -> Random<StdRng> {
209            let combined_seed = self.seed.wrapping_mul(31).wrapping_add(self.sequence_id);
210            self.sequence_id += 1;
211            Random::seed(combined_seed)
212        }
213
214        /// Reset the sequence
215        pub fn reset(&mut self) {
216            self.sequence_id = 0;
217        }
218
219        /// Get current sequence position
220        pub fn position(&self) -> u64 {
221            self.sequence_id
222        }
223    }
224
225    /// Deterministic random state for reproducible experiments
226    #[derive(Debug, Clone)]
227    pub struct DeterministicState {
228        pub seed: u64,
229        pub call_count: u64,
230    }
231
232    impl DeterministicState {
233        /// Create a new deterministic state
234        pub fn new(seed: u64) -> Self {
235            Self {
236                seed,
237                call_count: 0,
238            }
239        }
240
241        /// Create an RNG from this state and advance the counter
242        pub fn next_rng(&mut self) -> Random<StdRng> {
243            let rng_seed = self.seed.wrapping_mul(31).wrapping_add(self.call_count);
244            self.call_count += 1;
245            Random::seed(rng_seed)
246        }
247
248        /// Get current state without advancing
249        pub fn current_state(&self) -> (u64, u64) {
250            (self.seed, self.call_count)
251        }
252
253        /// Get current position in the sequence
254        pub fn position(&self) -> u64 {
255            self.call_count
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform distribution over `[0, 1)` shared by most tests.
    fn unit_uniform() -> Uniform<f64> {
        Uniform::new(0.0, 1.0).expect("Operation failed")
    }

    #[test]
    fn test_random_creation() {
        let mut rng = Random::default();
        let val = rng.sample(unit_uniform());
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_seeded_rng() {
        let mut rng1 = seeded_rng(42);
        let mut rng2 = seeded_rng(42);

        let val1 = rng1.sample(unit_uniform());
        let val2 = rng2.sample(unit_uniform());

        assert_eq!(val1, val2);
    }

    #[test]
    fn test_thread_rng() {
        let mut rng = thread_rng();
        let val = rng.sample(unit_uniform());
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_reproducible_sequence() {
        let mut seq1 = scientific::ReproducibleSequence::new(123);
        let mut seq2 = scientific::ReproducibleSequence::new(123);

        let mut rng1_1 = seq1.next_rng();
        let mut rng1_2 = seq1.next_rng();

        let mut rng2_1 = seq2.next_rng();
        let mut rng2_2 = seq2.next_rng();

        let val1_1 = rng1_1.sample(unit_uniform());
        let val1_2 = rng1_2.sample(unit_uniform());

        let val2_1 = rng2_1.sample(unit_uniform());
        let val2_2 = rng2_2.sample(unit_uniform());

        // Same base seed and position give the same stream; consecutive
        // positions give different streams.
        assert_eq!(val1_1, val2_1);
        assert_eq!(val1_2, val2_2);
        assert_ne!(val1_1, val1_2);

        assert_eq!(seq1.position(), 2);
        seq1.reset();
        assert_eq!(seq1.position(), 0);
    }

    #[test]
    fn test_deterministic_state() {
        let mut state1 = scientific::DeterministicState::new(456);
        let mut state2 = scientific::DeterministicState::new(456);

        let mut rng1 = state1.next_rng();
        let mut rng2 = state2.next_rng();

        let val1 = rng1.sample(unit_uniform());
        let val2 = rng2.sample(unit_uniform());

        assert_eq!(val1, val2);
        assert_eq!(state1.position(), state2.position());
        // One generator was created, so the counter advanced exactly once.
        assert_eq!(state1.current_state(), (456, 1));
    }

    #[test]
    fn test_sample_array() {
        let mut rng = seeded_rng(789);
        let array = rng.sample_array(Ix2(3, 3), unit_uniform());

        assert_eq!(array.shape(), &[3, 3]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_distribution_ext() {
        let mut rng = seeded_rng(101112);
        let distribution = Uniform::new(-1.0, 1.0).expect("Operation failed");

        let vec = distribution.sample_vec(&mut rng, 10);
        assert_eq!(vec.len(), 10);
        assert!(vec.iter().all(|&x| (-1.0..1.0).contains(&x)));

        let array = distribution.random_array(&mut rng, Ix2(2, 5));
        assert_eq!(array.shape(), &[2, 5]);
        assert!(array.iter().all(|&x| (-1.0..1.0).contains(&x)));
    }
}