scirs2_core/random/parallel.rs

//! Thread-local RNG pools for high-performance parallel applications
//!
//! This module provides parallel-safe random number generation utilities
//! that are essential for high-performance scientific computing applications
//! requiring concurrent random number generation.
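//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest; it assumes `ThreadLocalRngPool`
//! and `rand_distr::Uniform` are in scope, mirroring the unit tests at the bottom of
//! this module):
//!
//! ```ignore
//! // Sketch only: a deterministic pool, where the same base seed reproduces the same draws.
//! let pool = ThreadLocalRngPool::new(42);
//! let x: f64 = pool.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("valid range")));
//! assert!((0.0..1.0).contains(&x));
//! ```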

use crate::random::core::Random;
use ::ndarray::{Array, Dimension, IxDyn};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::Distribution;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

/// Thread-local random number generator pool
///
/// Provides deterministic, thread-safe random number generation by handing out
/// independently seeded RNG instances derived from a shared base seed and an
/// atomic counter.
#[derive(Debug)]
pub struct ThreadLocalRngPool {
    seed_counter: Arc<AtomicUsize>,
    base_seed: u64,
}

impl ThreadLocalRngPool {
    /// Create a new thread-local RNG pool with a specific seed
    ///
    /// This ensures deterministic behavior across parallel executions
    /// when the same base seed is used.
    pub fn new(seed: u64) -> Self {
        Self {
            seed_counter: Arc::new(AtomicUsize::new(0)),
            base_seed: seed,
        }
    }

    /// Create a thread-local RNG pool with a seed derived from system time
    pub fn new_time_seeded() -> Self {
        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(42);
        Self::new(seed)
    }

    /// Get a thread-local RNG
    ///
    /// Each call to this method returns a new RNG instance seeded with
    /// a deterministic value based on the base seed and thread counter.
    pub fn get_rng(&self) -> Random<StdRng> {
        let thread_id = self.seed_counter.fetch_add(1, Ordering::Relaxed);
        let seed = self.base_seed.wrapping_add(thread_id as u64);
        Random::seed(seed)
    }

    /// Execute a closure with a thread-local RNG
    ///
    /// This is the preferred way to use the thread-local RNG pool: the RNG is
    /// scoped to the closure and its seeding stays consistent with the pool's counter.
    pub fn with_rng<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut Random<StdRng>) -> R,
    {
        let mut rng = self.get_rng();
        f(&mut rng)
    }

    /// Get the base seed used by this pool
    pub fn base_seed(&self) -> u64 {
        self.base_seed
    }

    /// Get the current thread counter value
    pub fn thread_counter(&self) -> usize {
        self.seed_counter.load(Ordering::Relaxed)
    }

    /// Reset the thread counter (useful for reproducible testing)
    pub fn reset_counter(&self) {
        self.seed_counter.store(0, Ordering::Relaxed);
    }
}

impl Default for ThreadLocalRngPool {
    fn default() -> Self {
        Self::new_time_seeded()
    }
}

/// Parallel random number generation utilities
///
/// Provides high-level functions for generating random numbers in parallel
/// with automatic fallback to sequential generation when parallel features
/// are not available.
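///
/// # Example
///
/// A minimal sketch (not compiled as a doctest; it assumes `ParallelRng`,
/// `ThreadLocalRngPool`, and `rand_distr::Uniform` are in scope):
///
/// ```ignore
/// let pool = ThreadLocalRngPool::new(456);
/// // Uses Rayon when the "parallel" feature is enabled, otherwise runs sequentially.
/// let samples: Vec<f64> =
///     ParallelRng::parallel_sample(Uniform::new(0.0, 1.0).expect("valid range"), 1_000, &pool);
/// assert_eq!(samples.len(), 1_000);
/// ```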
pub struct ParallelRng;

impl ParallelRng {
    /// Generate parallel random samples using Rayon (when available)
    ///
    /// When the "parallel" feature is enabled, this uses Rayon for parallel
    /// generation. Otherwise, it falls back to sequential generation.
    #[cfg(feature = "parallel")]
    pub fn parallel_sample<D, T>(distribution: D, count: usize, pool: &ThreadLocalRngPool) -> Vec<T>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
    {
        use crate::parallel_ops::{IntoParallelIterator, ParallelIterator};

        (0..count)
            .into_par_iter()
            .map(|_| pool.with_rng(|rng| rng.sample(distribution)))
            .collect()
    }

    /// Generate parallel random arrays using Rayon (when available)
    #[cfg(feature = "parallel")]
    pub fn parallel_sample_array<D, T, Sh>(
        distribution: D,
        shape: Sh,
        pool: &ThreadLocalRngPool,
    ) -> Array<T, IxDyn>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send + Clone,
        Sh: Into<IxDyn>,
    {
        let shape = shape.into();
        let size = shape.size();
        let samples = Self::parallel_sample(distribution, size, pool);
        Array::from_shape_vec(shape, samples).expect("sample count should match the requested shape")
    }

    /// Sequential fallback when parallel feature is not enabled
    #[cfg(not(feature = "parallel"))]
    pub fn parallel_sample<D, T>(distribution: D, count: usize, pool: &ThreadLocalRngPool) -> Vec<T>
    where
        D: Distribution<T> + Copy,
    {
        pool.with_rng(|rng| rng.sample_vec(distribution, count))
    }

    /// Sequential fallback when parallel feature is not enabled
    #[cfg(not(feature = "parallel"))]
    pub fn parallel_sample_array<D, T, Sh>(
        distribution: D,
        shape: Sh,
        pool: &ThreadLocalRngPool,
    ) -> Array<T, IxDyn>
    where
        D: Distribution<T> + Copy,
        T: Send + Clone,
        Sh: Into<IxDyn> + crate::ndarray::Dimension,
    {
        pool.with_rng(|rng| rng.sample_array(shape.into(), distribution))
    }

    /// Generate parallel random samples with chunked processing
    ///
    /// Divides the work into chunks to balance load and reduce overhead.
    /// This is particularly useful for very large sample sizes.
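    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest; assumes `rand_distr::Uniform`
    /// is in scope). 100 samples with a chunk size of 25 are generated as 4 chunks:
    ///
    /// ```ignore
    /// let pool = ThreadLocalRngPool::new(789);
    /// let samples: Vec<f64> = ParallelRng::parallel_sample_chunked(
    ///     Uniform::new(0.0, 1.0).expect("valid range"),
    ///     100, // total samples
    ///     25,  // chunk size
    ///     &pool,
    /// );
    /// assert_eq!(samples.len(), 100);
    /// ```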
    pub fn parallel_sample_chunked<D, T>(
        distribution: D,
        count: usize,
        chunk_size: usize,
        pool: &ThreadLocalRngPool,
    ) -> Vec<T>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
    {
        let num_chunks = (count + chunk_size - 1) / chunk_size;
        let mut result = Vec::with_capacity(count);

        #[cfg(feature = "parallel")]
        {
            use crate::parallel_ops::{IntoParallelIterator, ParallelIterator};

            let chunks: Vec<Vec<T>> = (0..num_chunks)
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = std::cmp::min(start + chunk_size, count);
                    let chunk_count = end - start;

                    pool.with_rng(|rng| rng.sample_vec(distribution, chunk_count))
                })
                .collect();

            for chunk in chunks {
                result.extend(chunk);
            }
        }

        #[cfg(not(feature = "parallel"))]
        {
            for chunk_idx in 0..num_chunks {
                let start = chunk_idx * chunk_size;
                let end = std::cmp::min(start + chunk_size, count);
                let chunk_count = end - start;

                let chunk = pool.with_rng(|rng| rng.sample_vec(distribution, chunk_count));
                result.extend(chunk);
            }
        }

        result
    }

    /// Generate parallel bootstrap samples
    ///
    /// Useful for statistical bootstrap resampling in parallel.
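    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest). Each bootstrap sample is a
    /// resample of `data` drawn with replacement, so it has the same length as `data`:
    ///
    /// ```ignore
    /// let data = vec![1, 2, 3, 4, 5];
    /// let pool = ThreadLocalRngPool::new(7);
    /// let resamples = ParallelRng::parallel_bootstrap(&data, 10, &pool);
    /// assert_eq!(resamples.len(), 10);
    /// assert!(resamples.iter().all(|s| s.len() == data.len()));
    /// ```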
    pub fn parallel_bootstrap<T>(
        data: &[T],
        n_bootstrap: usize,
        pool: &ThreadLocalRngPool,
    ) -> Vec<Vec<T>>
    where
        T: Clone + Send + Sync,
    {
        #[cfg(feature = "parallel")]
        {
            use crate::parallel_ops::{IntoParallelIterator, ParallelIterator};

            (0..n_bootstrap)
                .into_par_iter()
                .map(|_| {
                    pool.with_rng(|rng| {
                        (0..data.len())
                            .map(|_| {
                                let idx = rng.random_range(0..data.len());
                                data[idx].clone()
                            })
                            .collect()
                    })
                })
                .collect()
        }

        #[cfg(not(feature = "parallel"))]
        {
            (0..n_bootstrap)
                .map(|_| {
                    pool.with_rng(|rng| {
                        (0..data.len())
                            .map(|_| {
                                let idx = rng.random_range(0..data.len());
                                data[idx].clone()
                            })
                            .collect()
                    })
                })
                .collect()
        }
    }
}

/// Worker-aware parallel RNG for distributed computing
///
/// Coordinates multiple workers in a distributed system by deriving a
/// distinct, deterministic seed for each worker from a shared base seed.
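///
/// # Example
///
/// A minimal sketch (not compiled as a doctest): each worker in a 4-worker job
/// builds its own pool from the same base seed and a unique worker id.
///
/// ```ignore
/// let base_seed = 12345;
/// for worker_id in 0..4 {
///     // Sketch only: in a real deployment each process constructs just its own pool.
///     let pool = DistributedRngPool::new(worker_id, 4, base_seed);
///     assert_eq!(pool.worker_info(), (worker_id, 4));
/// }
/// ```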
#[derive(Debug)]
pub struct DistributedRngPool {
    worker_id: usize,
    total_workers: usize,
    base_pool: ThreadLocalRngPool,
}

impl DistributedRngPool {
    /// Create a new distributed RNG pool
    ///
    /// # Parameters
    /// * `worker_id` - Unique identifier for this worker (0-based)
    /// * `total_workers` - Total number of workers in the system
    /// * `base_seed` - Base seed for the entire distributed system
    pub fn new(worker_id: usize, total_workers: usize, base_seed: u64) -> Self {
        assert!(
            worker_id < total_workers,
            "Worker ID must be less than total workers"
        );

        // Create a unique seed for this worker
        let worker_seed = base_seed
            .wrapping_mul(total_workers as u64)
            .wrapping_add(worker_id as u64);

        Self {
            worker_id,
            total_workers,
            base_pool: ThreadLocalRngPool::new(worker_seed),
        }
    }

    /// Get a thread-local RNG for this worker
    pub fn get_rng(&self) -> Random<StdRng> {
        self.base_pool.get_rng()
    }

    /// Execute a closure with a worker-local RNG
    pub fn with_rng<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut Random<StdRng>) -> R,
    {
        self.base_pool.with_rng(f)
    }

    /// Get worker information
    pub fn worker_info(&self) -> (usize, usize) {
        (self.worker_id, self.total_workers)
    }

    /// Generate samples allocated to this worker
    ///
    /// Automatically calculates the portion of samples this worker should generate
    /// based on its worker ID and the total number of workers.
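    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest; assumes `rand_distr::Uniform`
    /// is in scope). With 10 total samples and 3 workers, worker 0 gets 4 samples
    /// and workers 1 and 2 get 3 each:
    ///
    /// ```ignore
    /// let pool = DistributedRngPool::new(0, 3, 12345);
    /// let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("valid range"), 10);
    /// assert_eq!(samples.len(), 4); // 10 / 3 = 3 base samples, plus one of the 10 % 3 = 1 extras
    /// ```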
    pub fn worker_sample<D, T>(&self, distribution: D, total_samples: usize) -> Vec<T>
    where
        D: Distribution<T> + Copy,
    {
        let samples_per_worker = total_samples / self.total_workers;
        let extra_samples = total_samples % self.total_workers;

        let my_samples = if self.worker_id < extra_samples {
            samples_per_worker + 1
        } else {
            samples_per_worker
        };

        self.with_rng(|rng| rng.sample_vec(distribution, my_samples))
    }
}

/// Batch processing utilities for large-scale random number generation
pub struct BatchRng;

impl BatchRng {
    /// Process random number generation in batches
    ///
    /// Useful for memory-efficient processing of very large datasets.
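    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest; assumes `rand_distr::Uniform`
    /// is in scope). 100 samples processed in batches of 25 yield 4 batch results:
    ///
    /// ```ignore
    /// let pool = ThreadLocalRngPool::new(1);
    /// let batch_sums: Vec<f64> = BatchRng::process_batches(
    ///     Uniform::new(0.0, 1.0).expect("valid range"),
    ///     100, // total samples
    ///     25,  // batch size
    ///     &pool,
    ///     |batch| batch.iter().sum(), // per-batch reduction keeps memory bounded
    /// );
    /// assert_eq!(batch_sums.len(), 4);
    /// ```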
    pub fn process_batches<D, T, F, R>(
        distribution: D,
        total_samples: usize,
        batch_size: usize,
        pool: &ThreadLocalRngPool,
        mut processor: F,
    ) -> Vec<R>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
        F: FnMut(Vec<T>) -> R,
        R: Send,
    {
        let num_batches = (total_samples + batch_size - 1) / batch_size;
        let mut results = Vec::with_capacity(num_batches);

        for batch_idx in 0..num_batches {
            let start = batch_idx * batch_size;
            let end = std::cmp::min(start + batch_size, total_samples);
            let current_batch_size = end - start;

            let batch_samples =
                ParallelRng::parallel_sample(distribution, current_batch_size, pool);
            let result = processor(batch_samples);
            results.push(result);
        }

        results
    }

    /// Stream random numbers with callback processing
    ///
    /// Generates random numbers in chunks and processes them with a callback,
    /// useful for streaming applications where memory usage must be controlled.
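    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest; assumes `rand_distr::Uniform`
    /// is in scope). Chunks are handed to the callback as they are generated, so
    /// only one chunk is held in memory at a time:
    ///
    /// ```ignore
    /// let pool = ThreadLocalRngPool::new(2);
    /// let mut total = 0usize;
    /// BatchRng::stream_samples(
    ///     Uniform::new(0.0, 1.0).expect("valid range"),
    ///     100, // total samples
    ///     30,  // chunk size
    ///     &pool,
    ///     |chunk: &[f64]| total += chunk.len(),
    /// );
    /// assert_eq!(total, 100);
    /// ```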
    pub fn stream_samples<D, T, F>(
        distribution: D,
        total_samples: usize,
        chunk_size: usize,
        pool: &ThreadLocalRngPool,
        mut callback: F,
    ) where
        D: Distribution<T> + Copy,
        F: FnMut(&[T]),
    {
        let num_chunks = (total_samples + chunk_size - 1) / chunk_size;

        for chunk_idx in 0..num_chunks {
            let start = chunk_idx * chunk_size;
            let end = std::cmp::min(start + chunk_size, total_samples);
            let current_chunk_size = end - start;

            let chunk_samples =
                pool.with_rng(|rng| rng.sample_vec(distribution, current_chunk_size));
            callback(&chunk_samples);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand_distr::Uniform;

    #[test]
    fn test_thread_local_rng_pool() {
        let pool = ThreadLocalRngPool::new(42);
        assert_eq!(pool.base_seed(), 42);
        assert_eq!(pool.thread_counter(), 0);

        let value =
            pool.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed")));
        assert!((0.0..1.0).contains(&value));
        assert_eq!(pool.thread_counter(), 1);
    }

    #[test]
    fn test_thread_local_rng_pool_deterministic() {
        let pool1 = ThreadLocalRngPool::new(123);
        let pool2 = ThreadLocalRngPool::new(123);

        let value1 =
            pool1.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed")));
        let value2 =
            pool2.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed")));

        assert_eq!(value1, value2);
    }

    #[test]
    fn test_thread_local_rng_pool_reset() {
        let pool = ThreadLocalRngPool::new(42);

        pool.with_rng(|_| {});
        pool.with_rng(|_| {});
        assert_eq!(pool.thread_counter(), 2);

        pool.reset_counter();
        assert_eq!(pool.thread_counter(), 0);
    }

    #[test]
    fn test_parallel_rng_sequential() {
        let pool = ThreadLocalRngPool::new(456);
        let samples = ParallelRng::parallel_sample(
            Uniform::new(0.0, 1.0).expect("Operation failed"),
            100,
            &pool,
        );

        assert_eq!(samples.len(), 100);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_parallel_rng_chunked() {
        let pool = ThreadLocalRngPool::new(789);
        let samples = ParallelRng::parallel_sample_chunked(
            Uniform::new(0.0, 1.0).expect("Operation failed"),
            100,
            25,
            &pool,
        );

        assert_eq!(samples.len(), 100);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_parallel_bootstrap() {
        let data = vec![1, 2, 3, 4, 5];
        let pool = ThreadLocalRngPool::new(101112);

        let bootstrap_samples = ParallelRng::parallel_bootstrap(&data, 10, &pool);
        assert_eq!(bootstrap_samples.len(), 10);
        assert!(bootstrap_samples
            .iter()
            .all(|sample| sample.len() == data.len()));
        assert!(bootstrap_samples
            .iter()
            .flatten()
            .all(|&x| data.contains(&x)));
    }

    #[test]
    fn test_distributed_rng_pool() {
        let pool = DistributedRngPool::new(0, 4, 12345);
        assert_eq!(pool.worker_info(), (0, 4));

        let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("Operation failed"), 100);
        assert_eq!(samples.len(), 25); // 100 / 4 = 25 samples per worker
    }

    #[test]
    fn test_distributed_rng_pool_uneven_distribution() {
        // Test with total samples not evenly divisible by workers
        let pool = DistributedRngPool::new(0, 3, 12345);
        let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("Operation failed"), 10);
        assert_eq!(samples.len(), 4); // First worker gets 4 samples (10/3 + 1)

        let pool = DistributedRngPool::new(2, 3, 12345);
        let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("Operation failed"), 10);
        assert_eq!(samples.len(), 3); // Last worker gets 3 samples (10/3)
    }

    #[test]
    fn test_batch_processing() {
        let pool = ThreadLocalRngPool::new(131415);
        let results = BatchRng::process_batches(
            Uniform::new(0.0, 1.0).expect("Operation failed"),
            100,
            25,
            &pool,
            |batch| batch.len(),
        );

        assert_eq!(results.len(), 4); // 100 samples / 25 per batch = 4 batches
        assert_eq!(results.iter().sum::<usize>(), 100);
    }

    #[test]
    fn test_stream_samples() {
        let pool = ThreadLocalRngPool::new(161718);
        let mut total_processed = 0;

        BatchRng::stream_samples(
            Uniform::new(0.0, 1.0).expect("Operation failed"),
            100,
            30,
            &pool,
            |chunk| {
                total_processed += chunk.len();
                assert!(chunk.iter().all(|&x| (0.0..1.0).contains(&x)));
            },
        );

        assert_eq!(total_processed, 100);
    }
}