//! Utilities for random sampling in parallel and distributed settings:
//! deterministic per-thread RNG pools, parallel and chunked sampling
//! helpers, bootstrap resampling, and batch/stream processing.

use crate::random::core::Random;
use ::ndarray::{Array, Dimension, IxDyn};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::Distribution;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

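/// A pool that derives a fresh, deterministically seeded RNG for each
/// caller by combining a base seed with an atomic counter. With the same
/// base seed, the n-th RNG handed out is always seeded identically, so
/// runs are reproducible as long as RNGs are requested in the same order.
///
/// A minimal usage sketch (as in the tests below; `Uniform::new` returns a
/// `Result` in recent `rand_distr` versions):
///
/// ```ignore
/// let pool = ThreadLocalRngPool::new(42);
/// let x: f64 = pool.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).unwrap()));
/// assert!((0.0..1.0).contains(&x));
/// ```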
#[derive(Debug)]
pub struct ThreadLocalRngPool {
    seed_counter: Arc<AtomicUsize>,
    base_seed: u64,
}

impl ThreadLocalRngPool {
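    /// Creates a pool whose RNG seeds are all derived from `seed`.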
    pub fn new(seed: u64) -> Self {
        Self {
            seed_counter: Arc::new(AtomicUsize::new(0)),
            base_seed: seed,
        }
    }

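    /// Creates a pool seeded from the current Unix time in seconds, falling
    /// back to the fixed seed `42` if the system clock reads before the epoch.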
    pub fn new_time_seeded() -> Self {
        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(42);
        Self::new(seed)
    }

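    /// Returns a new RNG seeded with `base_seed` plus the current counter
    /// value (wrapping on overflow), then increments the counter. Each call
    /// therefore yields a distinct, reproducible stream.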
    pub fn get_rng(&self) -> Random<StdRng> {
        let thread_id = self.seed_counter.fetch_add(1, Ordering::Relaxed);
        let seed = self.base_seed.wrapping_add(thread_id as u64);
        Random::seed(seed)
    }

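    /// Runs `f` with a freshly derived RNG and returns its result.
    ///
    /// A small sketch of the call shape (the closure receives a
    /// `&mut Random<StdRng>`):
    ///
    /// ```ignore
    /// let pool = ThreadLocalRngPool::new(7);
    /// let pair = pool.with_rng(|rng| (rng.random_range(0..10), rng.random_range(0..10)));
    /// ```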
    pub fn with_rng<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut Random<StdRng>) -> R,
    {
        let mut rng = self.get_rng();
        f(&mut rng)
    }

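    /// Returns the base seed this pool was constructed with.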
    pub fn base_seed(&self) -> u64 {
        self.base_seed
    }

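    /// Returns how many RNGs have been handed out since construction or the
    /// last reset.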
    pub fn thread_counter(&self) -> usize {
        self.seed_counter.load(Ordering::Relaxed)
    }

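    /// Resets the seed counter to zero, so subsequent `get_rng` calls repeat
    /// the same seed sequence from the beginning.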
    pub fn reset_counter(&self) {
        self.seed_counter.store(0, Ordering::Relaxed);
    }
}

impl Default for ThreadLocalRngPool {
    fn default() -> Self {
        Self::new_time_seeded()
    }
}

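/// Namespace for parallel sampling helpers. Each method has a sequential
/// fallback compiled when the `parallel` feature is disabled, so callers
/// can use the same API either way.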
pub struct ParallelRng;

impl ParallelRng {
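    /// Draws `count` samples from `distribution`, one per parallel task,
    /// each task pulling its own RNG from `pool`.
    ///
    /// Note: which task receives which counter value depends on scheduling,
    /// so parallel results are reproducible in length and distribution but
    /// not bit-for-bit across runs.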
    #[cfg(feature = "parallel")]
    pub fn parallel_sample<D, T>(distribution: D, count: usize, pool: &ThreadLocalRngPool) -> Vec<T>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
    {
        use crate::parallel_ops::{IntoParallelIterator, ParallelIterator};

        (0..count)
            .into_par_iter()
            .map(|_| pool.with_rng(|rng| rng.sample(distribution)))
            .collect()
    }

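    /// Fills an array of the given shape with samples drawn in parallel;
    /// the flat sample vector is reshaped, so element order follows the
    /// array's logical (row-major) order.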
    #[cfg(feature = "parallel")]
    pub fn parallel_sample_array<D, T, Sh>(
        distribution: D,
        shape: Sh,
        pool: &ThreadLocalRngPool,
    ) -> Array<T, IxDyn>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send + Clone,
        Sh: Into<IxDyn>,
    {
        let shape = shape.into();
        let size = shape.size();
        let samples = Self::parallel_sample(distribution, size, pool);
        Array::from_shape_vec(shape, samples).expect("sample count matches shape size")
    }

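    /// Sequential fallback for `parallel_sample`: draws all `count` samples
    /// from a single pooled RNG. The trait bounds mirror the parallel
    /// variant so toggling the `parallel` feature does not change the API.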
    #[cfg(not(feature = "parallel"))]
    pub fn parallel_sample<D, T>(distribution: D, count: usize, pool: &ThreadLocalRngPool) -> Vec<T>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
    {
        pool.with_rng(|rng| rng.sample_vec(distribution, count))
    }

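    /// Sequential fallback for `parallel_sample_array`; the bounds again
    /// mirror the parallel variant.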
    #[cfg(not(feature = "parallel"))]
    pub fn parallel_sample_array<D, T, Sh>(
        distribution: D,
        shape: Sh,
        pool: &ThreadLocalRngPool,
    ) -> Array<T, IxDyn>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send + Clone,
        Sh: Into<IxDyn>,
    {
        pool.with_rng(|rng| rng.sample_array(shape.into(), distribution))
    }

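    /// Draws `count` samples in chunks of at most `chunk_size`. Each chunk
    /// uses one RNG from the pool, amortizing RNG construction over
    /// `chunk_size` samples; chunks are generated in parallel when the
    /// `parallel` feature is enabled, and concatenated in chunk order
    /// either way.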
    pub fn parallel_sample_chunked<D, T>(
        distribution: D,
        count: usize,
        chunk_size: usize,
        pool: &ThreadLocalRngPool,
    ) -> Vec<T>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
    {
        // Ceiling division: the last chunk may be smaller than `chunk_size`.
        let num_chunks = (count + chunk_size - 1) / chunk_size;
        let mut result = Vec::with_capacity(count);

        #[cfg(feature = "parallel")]
        {
            use crate::parallel_ops::{IntoParallelIterator, ParallelIterator};

            let chunks: Vec<Vec<T>> = (0..num_chunks)
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = std::cmp::min(start + chunk_size, count);
                    let chunk_count = end - start;

                    pool.with_rng(|rng| rng.sample_vec(distribution, chunk_count))
                })
                .collect();

            for chunk in chunks {
                result.extend(chunk);
            }
        }

        #[cfg(not(feature = "parallel"))]
        {
            for chunk_idx in 0..num_chunks {
                let start = chunk_idx * chunk_size;
                let end = std::cmp::min(start + chunk_size, count);
                let chunk_count = end - start;

                let chunk = pool.with_rng(|rng| rng.sample_vec(distribution, chunk_count));
                result.extend(chunk);
            }
        }

        result
    }

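    /// Generates `n_bootstrap` bootstrap resamples of `data`: each resample
    /// draws `data.len()` elements uniformly with replacement. Panics if
    /// `data` is empty (an index is sampled from an empty range).
    ///
    /// A small sketch of typical use (computing the per-resample statistic
    /// is the caller's job; only the resampling happens here):
    ///
    /// ```ignore
    /// let pool = ThreadLocalRngPool::new(1);
    /// let resamples = ParallelRng::parallel_bootstrap(&[1.0, 2.0, 3.0], 100, &pool);
    /// let means: Vec<f64> = resamples
    ///     .iter()
    ///     .map(|s| s.iter().sum::<f64>() / s.len() as f64)
    ///     .collect();
    /// ```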
    pub fn parallel_bootstrap<T>(
        data: &[T],
        n_bootstrap: usize,
        pool: &ThreadLocalRngPool,
    ) -> Vec<Vec<T>>
    where
        T: Clone + Send + Sync,
    {
        #[cfg(feature = "parallel")]
        {
            use crate::parallel_ops::{IntoParallelIterator, ParallelIterator};

            (0..n_bootstrap)
                .into_par_iter()
                .map(|_| {
                    pool.with_rng(|rng| {
                        (0..data.len())
                            .map(|_| {
                                let idx = rng.random_range(0..data.len());
                                data[idx].clone()
                            })
                            .collect()
                    })
                })
                .collect()
        }

        #[cfg(not(feature = "parallel"))]
        {
            (0..n_bootstrap)
                .map(|_| {
                    pool.with_rng(|rng| {
                        (0..data.len())
                            .map(|_| {
                                let idx = rng.random_range(0..data.len());
                                data[idx].clone()
                            })
                            .collect()
                    })
                })
                .collect()
        }
    }
}

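/// An RNG pool for one worker in a multi-worker job. Every worker derives a
/// distinct seed from the shared base seed, its worker ID, and the worker
/// count, so workers produce different but reproducible streams.
///
/// A minimal sketch for a 4-worker job sharing base seed 12345 (`worker_id`
/// is supplied by whatever launches the workers):
///
/// ```ignore
/// let pool = DistributedRngPool::new(worker_id, 4, 12345);
/// let my_share = pool.worker_sample(Uniform::new(0.0, 1.0).unwrap(), 1_000_000);
/// ```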
#[derive(Debug)]
pub struct DistributedRngPool {
    worker_id: usize,
    total_workers: usize,
    base_pool: ThreadLocalRngPool,
}

impl DistributedRngPool {
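    /// Creates the pool for `worker_id` (zero-based) out of `total_workers`,
    /// deriving the worker seed as `base_seed * total_workers + worker_id`
    /// (with wrapping arithmetic), which gives each worker a distinct seed
    /// for a fixed base seed and worker count.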
    pub fn new(worker_id: usize, total_workers: usize, base_seed: u64) -> Self {
        assert!(
            worker_id < total_workers,
            "Worker ID must be less than total workers"
        );

        let worker_seed = base_seed
            .wrapping_mul(total_workers as u64)
            .wrapping_add(worker_id as u64);

        Self {
            worker_id,
            total_workers,
            base_pool: ThreadLocalRngPool::new(worker_seed),
        }
    }

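    /// Returns a fresh RNG from this worker's underlying pool.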
    pub fn get_rng(&self) -> Random<StdRng> {
        self.base_pool.get_rng()
    }

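    /// Runs `f` with a fresh RNG from this worker's pool.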
    pub fn with_rng<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut Random<StdRng>) -> R,
    {
        self.base_pool.with_rng(f)
    }

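    /// Returns `(worker_id, total_workers)` for this pool.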
    pub fn worker_info(&self) -> (usize, usize) {
        (self.worker_id, self.total_workers)
    }

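    /// Draws this worker's share of `total_samples`: the samples are split
    /// as evenly as possible, with the first `total_samples % total_workers`
    /// workers taking one extra sample each.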
    pub fn worker_sample<D, T>(&self, distribution: D, total_samples: usize) -> Vec<T>
    where
        D: Distribution<T> + Copy,
    {
        let samples_per_worker = total_samples / self.total_workers;
        let extra_samples = total_samples % self.total_workers;

        // Low-numbered workers absorb the remainder so the per-worker
        // shares sum to `total_samples` exactly.
        let my_samples = if self.worker_id < extra_samples {
            samples_per_worker + 1
        } else {
            samples_per_worker
        };

        self.with_rng(|rng| rng.sample_vec(distribution, my_samples))
    }
}

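/// Helpers for generating samples in fixed-size batches, either collecting
/// one result per batch or streaming chunks through a callback.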
pub struct BatchRng;

impl BatchRng {
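    /// Generates `total_samples` samples in batches of at most `batch_size`,
    /// applies `processor` to each batch, and collects the per-batch results.
    ///
    /// A small sketch computing per-batch means (illustrative only):
    ///
    /// ```ignore
    /// let pool = ThreadLocalRngPool::new(0);
    /// let batch_means = BatchRng::process_batches(
    ///     Uniform::new(0.0, 1.0).unwrap(), 10_000, 1_000, &pool,
    ///     |batch| batch.iter().sum::<f64>() / batch.len() as f64,
    /// );
    /// ```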
    pub fn process_batches<D, T, F, R>(
        distribution: D,
        total_samples: usize,
        batch_size: usize,
        pool: &ThreadLocalRngPool,
        mut processor: F,
    ) -> Vec<R>
    where
        D: Distribution<T> + Copy + Send + Sync,
        T: Send,
        F: FnMut(Vec<T>) -> R,
        R: Send,
    {
        let num_batches = (total_samples + batch_size - 1) / batch_size;
        let mut results = Vec::with_capacity(num_batches);

        for batch_idx in 0..num_batches {
            let start = batch_idx * batch_size;
            let end = std::cmp::min(start + batch_size, total_samples);
            let current_batch_size = end - start;

            let batch_samples =
                ParallelRng::parallel_sample(distribution, current_batch_size, pool);
            let result = processor(batch_samples);
            results.push(result);
        }

        results
    }

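    /// Streams `total_samples` samples through `callback` in chunks of at
    /// most `chunk_size`, so only one chunk is held in memory at a time.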
    pub fn stream_samples<D, T, F>(
        distribution: D,
        total_samples: usize,
        chunk_size: usize,
        pool: &ThreadLocalRngPool,
        mut callback: F,
    ) where
        D: Distribution<T> + Copy,
        F: FnMut(&[T]),
    {
        let num_chunks = (total_samples + chunk_size - 1) / chunk_size;

        for chunk_idx in 0..num_chunks {
            let start = chunk_idx * chunk_size;
            let end = std::cmp::min(start + chunk_size, total_samples);
            let current_chunk_size = end - start;

            let chunk_samples =
                pool.with_rng(|rng| rng.sample_vec(distribution, current_chunk_size));
            callback(&chunk_samples);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand_distr::Uniform;

    #[test]
    fn test_thread_local_rng_pool() {
        let pool = ThreadLocalRngPool::new(42);
        assert_eq!(pool.base_seed(), 42);
        assert_eq!(pool.thread_counter(), 0);

        let value =
            pool.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("valid uniform range")));
        assert!((0.0..1.0).contains(&value));
        assert_eq!(pool.thread_counter(), 1);
    }

    #[test]
    fn test_thread_local_rng_pool_deterministic() {
        let pool1 = ThreadLocalRngPool::new(123);
        let pool2 = ThreadLocalRngPool::new(123);

        let value1 =
            pool1.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("valid uniform range")));
        let value2 =
            pool2.with_rng(|rng| rng.sample(Uniform::new(0.0, 1.0).expect("valid uniform range")));

        assert_eq!(value1, value2);
    }

    #[test]
    fn test_thread_local_rng_pool_reset() {
        let pool = ThreadLocalRngPool::new(42);

        pool.with_rng(|_| {});
        pool.with_rng(|_| {});
        assert_eq!(pool.thread_counter(), 2);

        pool.reset_counter();
        assert_eq!(pool.thread_counter(), 0);
    }

    #[test]
    fn test_parallel_rng_sequential() {
        let pool = ThreadLocalRngPool::new(456);
        let samples = ParallelRng::parallel_sample(
            Uniform::new(0.0, 1.0).expect("valid uniform range"),
            100,
            &pool,
        );

        assert_eq!(samples.len(), 100);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_parallel_rng_chunked() {
        let pool = ThreadLocalRngPool::new(789);
        let samples = ParallelRng::parallel_sample_chunked(
            Uniform::new(0.0, 1.0).expect("valid uniform range"),
            100,
            25,
            &pool,
        );

        assert_eq!(samples.len(), 100);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_parallel_bootstrap() {
        let data = vec![1, 2, 3, 4, 5];
        let pool = ThreadLocalRngPool::new(101112);

        let bootstrap_samples = ParallelRng::parallel_bootstrap(&data, 10, &pool);
        assert_eq!(bootstrap_samples.len(), 10);
        assert!(bootstrap_samples
            .iter()
            .all(|sample| sample.len() == data.len()));
        assert!(bootstrap_samples
            .iter()
            .flatten()
            .all(|&x| data.contains(&x)));
    }

    #[test]
    fn test_distributed_rng_pool() {
        let pool = DistributedRngPool::new(0, 4, 12345);
        assert_eq!(pool.worker_info(), (0, 4));

        // 100 samples over 4 workers: each worker gets exactly 25.
        let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("valid uniform range"), 100);
        assert_eq!(samples.len(), 25);
    }

    #[test]
    fn test_distributed_rng_pool_uneven_distribution() {
        // 10 samples over 3 workers: worker 0 takes the extra sample.
        let pool = DistributedRngPool::new(0, 3, 12345);
        let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("valid uniform range"), 10);
        assert_eq!(samples.len(), 4);

        // Workers past the remainder get the base share.
        let pool = DistributedRngPool::new(2, 3, 12345);
        let samples = pool.worker_sample(Uniform::new(0.0, 1.0).expect("valid uniform range"), 10);
        assert_eq!(samples.len(), 3);
    }

    #[test]
    fn test_batch_processing() {
        let pool = ThreadLocalRngPool::new(131415);
        let results = BatchRng::process_batches(
            Uniform::new(0.0, 1.0).expect("valid uniform range"),
            100,
            25,
            &pool,
            |batch| batch.len(),
        );

        // 100 samples in batches of 25 gives 4 batches of 25 each.
        assert_eq!(results.len(), 4);
        assert_eq!(results.iter().sum::<usize>(), 100);
    }

    #[test]
    fn test_stream_samples() {
        let pool = ThreadLocalRngPool::new(161718);
        let mut total_processed = 0;

        BatchRng::stream_samples(
            Uniform::new(0.0, 1.0).expect("valid uniform range"),
            100,
            30,
            &pool,
            |chunk| {
                total_processed += chunk.len();
                assert!(chunk.iter().all(|&x| (0.0..1.0).contains(&x)));
            },
        );

        assert_eq!(total_processed, 100);
    }
}