Skip to main content

scirs2_core/random/
core.rs

1//! Core random number generation functionality for SCIRS2 ecosystem
2//!
3//! This module provides the foundational Random struct and traits that serve as the
4//! basis for all random number generation across the SCIRS2 scientific computing ecosystem.
5
6use ::ndarray::{Array, Dimension, Ix2, IxDyn};
7use rand::rngs::StdRng;
8use rand::{Rng, SeedableRng};
9use rand_distr::{Distribution, Uniform};
10use std::cell::RefCell;
11use std::convert::Infallible;
12
13/// Enhanced random number generator for scientific computing
14///
15/// This is the core random number generator used throughout the SCIRS2 ecosystem.
16/// It provides deterministic, high-quality random number generation with support
17/// for seeding, thread-local storage, and scientific reproducibility.
18#[derive(Debug)]
19pub struct Random<R = rand::rngs::ThreadRng> {
20    pub(crate) rng: R,
21}
22
23impl Default for Random<rand::rngs::ThreadRng> {
24    fn default() -> Self {
25        Random { rng: rand::rng() }
26    }
27}
28
29impl Random<StdRng> {
30    /// Create a new random number generator with a specific seed
31    ///
32    /// This ensures deterministic behavior across runs, which is critical
33    /// for scientific reproducibility and testing.
34    pub fn seed(seed: u64) -> Random<StdRng> {
35        Random {
36            rng: StdRng::seed_from_u64(seed),
37        }
38    }
39}
40
41// Implement SeedableRng for Random<StdRng> to support ecosystem requirements
42impl SeedableRng for Random<StdRng> {
43    type Seed = <StdRng as SeedableRng>::Seed;
44
45    fn from_seed(seed: Self::Seed) -> Self {
46        Random {
47            rng: StdRng::from_seed(seed),
48        }
49    }
50
51    fn seed_from_u64(state: u64) -> Self {
52        Random {
53            rng: StdRng::seed_from_u64(state),
54        }
55    }
56}
57
58/// Create a seeded random number generator (convenience function)
59///
60/// This is the primary way to create deterministic RNGs across the SCIRS2 ecosystem.
61pub fn seeded_rng(seed: u64) -> Random<StdRng> {
62    Random::seed_from_u64(seed)
63}
64
65/// Get a thread-local random number generator (convenience function)
66///
67/// This provides fast access to a thread-local RNG for performance-critical code.
68pub fn thread_rng() -> Random<rand::rngs::ThreadRng> {
69    Random::default()
70}
71
72// Implement TryRng for Random to forward to inner RNG
73// In rand_core 0.10, TryRng is the base trait; Rng and RngCore are auto-implemented
74// for TryRng<Error = Infallible>.
75impl<R: Rng> rand::TryRng for Random<R> {
76    type Error = Infallible;
77
78    fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
79        Ok(self.rng.next_u32())
80    }
81
82    fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
83        Ok(self.rng.next_u64())
84    }
85
86    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
87        self.rng.fill_bytes(dest);
88        Ok(())
89    }
90}
91
92// Implement Distribution sampling methods
93impl<R: Rng> Random<R> {
94    /// Sample from a distribution
95    pub fn sample<T, D: Distribution<T>>(&mut self, distribution: D) -> T {
96        distribution.sample(&mut self.rng)
97    }
98
99    /// Generate a random value in a range
100    pub fn random_range<T, B>(&mut self, range: B) -> T
101    where
102        T: rand_distr::uniform::SampleUniform,
103        B: rand_distr::uniform::SampleRange<T>,
104    {
105        rand::RngExt::random_range(&mut self.rng, range)
106    }
107
108    /// Generate a random boolean
109    pub fn random_bool(&mut self, p: f64) -> bool {
110        rand::RngExt::random_bool(&mut self.rng, p)
111    }
112
113    /// Generate a random value of the inferred type
114    pub fn random<T>(&mut self) -> T
115    where
116        rand_distr::StandardUniform: rand_distr::Distribution<T>,
117    {
118        rand::RngExt::random(&mut self.rng)
119    }
120
121    /// Backward-compat alias for `random_range`
122    pub fn gen_range<T, B>(&mut self, range: B) -> T
123    where
124        T: rand_distr::uniform::SampleUniform,
125        B: rand_distr::uniform::SampleRange<T>,
126    {
127        self.random_range(range)
128    }
129
130    /// Backward-compat alias for `random_bool`
131    pub fn gen_bool(&mut self, p: f64) -> bool {
132        self.random_bool(p)
133    }
134
135    /// Fill a slice with random values
136    pub fn fill<T>(&mut self, slice: &mut [T])
137    where
138        rand_distr::StandardUniform: rand_distr::Distribution<T>,
139    {
140        for item in slice.iter_mut() {
141            *item = rand::RngExt::random(&mut self.rng);
142        }
143    }
144
145    /// Generate a vector of random values
146    pub fn sample_vec<T, D>(&mut self, distribution: D, size: usize) -> Vec<T>
147    where
148        D: Distribution<T> + Copy,
149    {
150        (0..size).map(|_| self.sample(distribution)).collect()
151    }
152
153    /// Generate a random array with specified shape and distribution
154    pub fn sample_array<T, Dim, D>(&mut self, shape: Dim, distribution: D) -> Array<T, Dim>
155    where
156        Dim: Dimension,
157        D: Distribution<T> + Copy,
158    {
159        let size = shape.size();
160        let values: Vec<T> = (0..size).map(|_| self.sample(distribution)).collect();
161        Array::from_shape_vec(shape, values).expect("Operation failed")
162    }
163
164    /// Access the underlying RNG (for advanced use cases)
165    pub fn rng_mut(&mut self) -> &mut R {
166        &mut self.rng
167    }
168
169    /// Access the underlying RNG (read-only)
170    pub fn rng(&self) -> &R {
171        &self.rng
172    }
173}
174
175/// Extension trait for distributions to create arrays directly
176///
177/// This provides a consistent interface for generating random arrays
178/// across the SCIRS2 ecosystem.
179pub trait DistributionExt<T>: Distribution<T> + Sized {
180    /// Create a random array with values from this distribution
181    fn random_array<R: Rng, Dim: Dimension>(&self, rng: &mut Random<R>, shape: Dim) -> Array<T, Dim>
182    where
183        Self: Copy,
184    {
185        rng.sample_array(shape, *self)
186    }
187
188    /// Create a random vector with values from this distribution
189    fn sample_vec<R: Rng>(&self, rng: &mut Random<R>, size: usize) -> Vec<T>
190    where
191        Self: Copy,
192    {
193        rng.sample_vec(*self, size)
194    }
195}
196
197// Implement the extension trait for all distributions
198impl<D, T> DistributionExt<T> for D where D: Distribution<T> {}
199
200thread_local! {
201    static THREAD_RNG: RefCell<Random> = RefCell::new(Random::default());
202}
203
204/// Get a reference to the thread-local random number generator
205#[allow(dead_code)]
206pub fn get_rng<F, R>(f: F) -> R
207where
208    F: FnOnce(&mut Random) -> R,
209{
210    THREAD_RNG.with(|rng| f(&mut rng.borrow_mut()))
211}
212
213/// Scientific random number generation utilities
214pub mod scientific {
215    use super::*;
216
217    /// Generate reproducible random sequences for scientific experiments
218    pub struct ReproducibleSequence {
219        seed: u64,
220        sequence_id: u64,
221    }
222
223    impl ReproducibleSequence {
224        /// Create a new reproducible sequence
225        pub fn new(seed: u64) -> Self {
226            Self {
227                seed,
228                sequence_id: 0,
229            }
230        }
231
232        /// Get the next RNG in the sequence
233        pub fn next_rng(&mut self) -> Random<StdRng> {
234            let combined_seed = self.seed.wrapping_mul(31).wrapping_add(self.sequence_id);
235            self.sequence_id += 1;
236            Random::seed(combined_seed)
237        }
238
239        /// Reset the sequence
240        pub fn reset(&mut self) {
241            self.sequence_id = 0;
242        }
243
244        /// Get current sequence position
245        pub fn position(&self) -> u64 {
246            self.sequence_id
247        }
248    }
249
250    /// Deterministic random state for reproducible experiments
251    #[derive(Debug, Clone)]
252    pub struct DeterministicState {
253        pub seed: u64,
254        pub call_count: u64,
255    }
256
257    impl DeterministicState {
258        /// Create a new deterministic state
259        pub fn new(seed: u64) -> Self {
260            Self {
261                seed,
262                call_count: 0,
263            }
264        }
265
266        /// Create an RNG from this state and advance the counter
267        pub fn next_rng(&mut self) -> Random<StdRng> {
268            let rng_seed = self.seed.wrapping_mul(31).wrapping_add(self.call_count);
269            self.call_count += 1;
270            Random::seed(rng_seed)
271        }
272
273        /// Get current state without advancing
274        pub fn current_state(&self) -> (u64, u64) {
275            (self.seed, self.call_count)
276        }
277
278        /// Get current position in the sequence
279        pub fn position(&self) -> u64 {
280            self.call_count
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_creation() {
        let mut rng = Random::default();
        let _val = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
    }

    #[test]
    fn test_seeded_rng() {
        let mut rng1 = seeded_rng(42);
        let mut rng2 = seeded_rng(42);

        let val1 = rng1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2 = rng2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1, val2);
    }

    #[test]
    fn test_seed_constructors_agree() {
        let x: u64 = Random::seed(3).random();
        let y: u64 = seeded_rng(3).random();
        let z: u64 = <Random<StdRng> as SeedableRng>::seed_from_u64(3).random();
        assert_eq!(x, y);
        assert_eq!(y, z);
    }

    #[test]
    fn test_thread_rng() {
        let mut rng = thread_rng();
        let val = rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        assert!((0.0..1.0).contains(&val));
    }

    #[test]
    fn test_random_range_bounds() {
        let mut rng = seeded_rng(2024);
        for _ in 0..100 {
            let v: i32 = rng.random_range(-5..5);
            assert!((-5..5).contains(&v));
        }
    }

    #[test]
    fn test_backward_compat_aliases() {
        let mut a = seeded_rng(11);
        let mut b = seeded_rng(11);
        let x: u32 = a.random_range(0..1000);
        let y: u32 = b.gen_range(0..1000);
        assert_eq!(x, y);
        assert_eq!(a.random_bool(0.5), b.gen_bool(0.5));
    }

    #[test]
    fn test_fill_is_reproducible() {
        let mut a = [0u32; 16];
        let mut b = [0u32; 16];
        seeded_rng(5).fill(&mut a);
        seeded_rng(5).fill(&mut b);
        assert_eq!(a, b);
    }

    #[test]
    fn test_reproducible_sequence() {
        let mut seq1 = scientific::ReproducibleSequence::new(123);
        let mut seq2 = scientific::ReproducibleSequence::new(123);

        let mut rng1_1 = seq1.next_rng();
        let mut rng1_2 = seq1.next_rng();

        let mut rng2_1 = seq2.next_rng();
        let mut rng2_2 = seq2.next_rng();

        let val1_1 = rng1_1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val1_2 = rng1_2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        let val2_1 = rng2_1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2_2 = rng2_2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1_1, val2_1);
        assert_eq!(val1_2, val2_2);
        assert_ne!(val1_1, val1_2);
    }

    #[test]
    fn test_reproducible_sequence_reset() {
        let mut seq = scientific::ReproducibleSequence::new(7);
        let first: u64 = seq.next_rng().random();
        let _ = seq.next_rng();
        assert_eq!(seq.position(), 2);

        seq.reset();
        assert_eq!(seq.position(), 0);
        let replay: u64 = seq.next_rng().random();
        assert_eq!(first, replay);
    }

    #[test]
    fn test_deterministic_state() {
        let mut state1 = scientific::DeterministicState::new(456);
        let mut state2 = scientific::DeterministicState::new(456);

        let mut rng1 = state1.next_rng();
        let mut rng2 = state2.next_rng();

        let val1 = rng1.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));
        let val2 = rng2.sample(Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(val1, val2);
        assert_eq!(state1.position(), state2.position());
    }

    #[test]
    fn test_deterministic_state_current_state() {
        let mut state = scientific::DeterministicState::new(99);
        assert_eq!(state.current_state(), (99, 0));
        let _ = state.next_rng();
        let _ = state.next_rng();
        assert_eq!(state.current_state(), (99, 2));
        assert_eq!(state.position(), 2);
    }

    #[test]
    fn test_sample_array() {
        let mut rng = seeded_rng(789);
        let array = rng.sample_array(Ix2(3, 3), Uniform::new(0.0, 1.0).expect("Operation failed"));

        assert_eq!(array.shape(), &[3, 3]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_distribution_ext() {
        let mut rng = seeded_rng(101112);
        let distribution = Uniform::new(-1.0, 1.0).expect("Operation failed");

        let vec = distribution.sample_vec(&mut rng, 10);
        assert_eq!(vec.len(), 10);
        assert!(vec.iter().all(|&x| (-1.0..1.0).contains(&x)));

        let array = distribution.random_array(&mut rng, Ix2(2, 5));
        assert_eq!(array.shape(), &[2, 5]);
    }
}
368
369        let array = distribution.random_array(&mut rng, Ix2(2, 5));
370        assert_eq!(array.shape(), &[2, 5]);
371    }
372}