1use crate::error::{CoreError, Result};
16use crate::tensor::Tensor;
17use crate::{Float, Integer, Scalar};
18
/// One step of the SplitMix64 generator.
///
/// Advances `state` in place by a fixed odd increment (wrapping on overflow)
/// and returns a bit-mixed version of the new state. Used to expand a single
/// 64-bit seed into several well-mixed words.
#[inline]
fn splitmix64(state: &mut u64) -> u64 {
    const INCREMENT: u64 = 0x9e37_79b9_7f4a_7c15;
    const MIX_A: u64 = 0xbf58_476d_1ce4_e5b9;
    const MIX_B: u64 = 0x94d0_49bb_1331_11eb;

    let advanced = state.wrapping_add(INCREMENT);
    *state = advanced;

    // Two xor-shift/multiply rounds followed by a final xor-shift.
    let round1 = (advanced ^ (advanced >> 30)).wrapping_mul(MIX_A);
    let round2 = (round1 ^ (round1 >> 27)).wrapping_mul(MIX_B);
    round2 ^ (round2 >> 31)
}
32
33pub struct Rng {
52 s: [u64; 4],
53 spare_normal: Option<f64>,
55}
56
57impl Rng {
58 pub fn new(seed: u64) -> Self {
62 let mut sm = seed;
63 let s = [
64 splitmix64(&mut sm),
65 splitmix64(&mut sm),
66 splitmix64(&mut sm),
67 splitmix64(&mut sm),
68 ];
69 Self {
70 s,
71 spare_normal: None,
72 }
73 }
74
75 pub fn seed(&mut self, seed: u64) {
87 *self = Self::new(seed);
88 }
89
90 pub fn fork(&mut self, n: usize) -> Vec<Self> {
105 (0..n).map(|_| Self::new(self.next_u64())).collect()
106 }
107
108 #[inline]
119 pub fn next_u64(&mut self) -> u64 {
120 let result = (self.s[1].wrapping_mul(5)).rotate_left(7).wrapping_mul(9);
121
122 let t = self.s[1] << 17;
123
124 self.s[2] ^= self.s[0];
125 self.s[3] ^= self.s[1];
126 self.s[1] ^= self.s[2];
127 self.s[0] ^= self.s[3];
128
129 self.s[2] ^= t;
130 self.s[3] = self.s[3].rotate_left(45);
131
132 result
133 }
134
135 #[inline]
139 pub fn next_f64(&mut self) -> f64 {
140 (self.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
141 }
142
143 pub fn next_normal_f64(&mut self) -> f64 {
158 ziggurat_normal(self)
159 }
160}
161
/// Number of equal-area strips in the ziggurat. The sampler selects a strip
/// with a 7-bit index, so this must stay at 128.
const ZIG_N: usize = 128;
/// Right edge of the base strip's rectangle; the tail sampler covers `|x| > ZIG_R`.
const ZIG_R: f64 = 3.442619855899;
/// Common area of every strip for the unnormalised density `exp(-x^2 / 2)`.
const ZIG_V: f64 = 9.91256303526217e-3;

/// Builds the ziggurat lookup tables `(xtab, ytab)` for `exp(-x^2 / 2)`.
///
/// The layout is the one `ziggurat_normal` indexes with: strip `i`
/// (`0 <= i < ZIG_N`) has outer width `xtab[i]`, inner width `xtab[i + 1]`,
/// and spans the densities `ytab[i]..ytab[i + 1]`.
///
/// * `xtab[0] = ZIG_V / f(ZIG_R)` is the pseudo-width of the base strip
///   (rectangle plus tail), and `ytab[0] = 0` is its floor.
/// * `xtab[1] = ZIG_R`; widths then shrink strictly with the index.
/// * `xtab[ZIG_N] = 0` and `ytab[ZIG_N] = 1` are the apex of the density.
///
/// The apex is pinned to its exact value rather than taken from the
/// recursion, because there the recursion evaluates `sqrt(-2 * ln(~1))`,
/// which rounding can turn into a tiny positive number or NaN.
fn zig_tables() -> ([f64; ZIG_N + 1], [f64; ZIG_N + 1]) {
    let density = |x: f64| (-0.5 * x * x).exp();

    let mut xtab = [0.0f64; ZIG_N + 1];
    let mut ytab = [0.0f64; ZIG_N + 1];

    // Base strip: its rectangle reaches ZIG_R; the remaining area is the tail.
    xtab[0] = ZIG_V / density(ZIG_R);
    ytab[0] = 0.0;
    xtab[1] = ZIG_R;
    ytab[1] = density(ZIG_R);

    // Equal-area recursion: xtab[i - 1] * (f(xtab[i]) - f(xtab[i - 1])) == ZIG_V.
    for i in 2..ZIG_N {
        let outer = xtab[i - 1];
        xtab[i] = (-2.0 * (ZIG_V / outer + density(outer)).ln()).sqrt();
        ytab[i] = density(xtab[i]);
    }

    // Apex of the density, stored exactly.
    xtab[ZIG_N] = 0.0;
    ytab[ZIG_N] = 1.0;

    (xtab, ytab)
}
198
199fn zig_tail(rng: &mut Rng, positive: bool) -> f64 {
201 loop {
202 let x = -rng.next_f64().ln() / ZIG_R; let y = -rng.next_f64().ln();
204 if 2.0 * y >= x * x {
205 return if positive { ZIG_R + x } else { -(ZIG_R + x) };
206 }
207 }
208}
209
210fn ziggurat_normal(rng: &mut Rng) -> f64 {
212 use std::sync::OnceLock;
216 static TABLES: OnceLock<([f64; ZIG_N + 1], [f64; ZIG_N + 1])> = OnceLock::new();
217 let (xtab, ytab) = TABLES.get_or_init(zig_tables);
218
219 loop {
220 let u = rng.next_u64();
221 let i = (u & 0x7F) as usize; let sign = if u & 0x80 != 0 { 1.0 } else { -1.0 };
223 let u_float = (u >> 8) as f64 / ((1u64 << 56) as f64);
225 let x = u_float * xtab[i];
226
227 if x < xtab[i + 1] {
229 return sign * x;
230 }
231
232 if i == 0 {
234 return zig_tail(rng, sign > 0.0);
235 }
236
237 let y = ytab[i + 1] + (ytab[i] - ytab[i + 1]) * rng.next_f64();
239 if y < (-0.5 * x * x).exp() {
240 return sign * x;
241 }
242 }
243}
244
245pub fn uniform<T: Float>(rng: &mut Rng, shape: Vec<usize>) -> Tensor<T> {
262 let numel: usize = shape.iter().product();
263 let data: Vec<T> = (0..numel).map(|_| T::from_f64(rng.next_f64())).collect();
264 Tensor::from_vec(data, shape).expect("shape product matches data length")
265}
266
267pub fn uniform_range<T: Float>(
280 rng: &mut Rng,
281 shape: Vec<usize>,
282 low: T,
283 high: T,
284) -> Result<Tensor<T>> {
285 if low >= high {
286 return Err(CoreError::InvalidArgument {
287 reason: "uniform_range requires low < high",
288 });
289 }
290 let range = high - low;
291 let numel: usize = shape.iter().product();
292 let data: Vec<T> = (0..numel)
293 .map(|_| low + T::from_f64(rng.next_f64()) * range)
294 .collect();
295 Ok(Tensor::from_vec(data, shape).expect("shape product matches data length"))
296}
297
298pub fn normal<T: Float>(rng: &mut Rng, shape: Vec<usize>, mean: T, std_dev: T) -> Tensor<T> {
311 let numel: usize = shape.iter().product();
312 let data: Vec<T> = (0..numel)
313 .map(|_| mean + std_dev * T::from_f64(rng.next_normal_f64()))
314 .collect();
315 Tensor::from_vec(data, shape).expect("shape product matches data length")
316}
317
318pub fn standard_normal<T: Float>(rng: &mut Rng, shape: Vec<usize>) -> Tensor<T> {
329 normal(rng, shape, T::zero(), T::one())
330}
331
332pub fn randint<T: Integer>(rng: &mut Rng, shape: Vec<usize>, low: T, high: T) -> Result<Tensor<T>> {
345 if low >= high {
346 return Err(CoreError::InvalidArgument {
347 reason: "randint requires low < high",
348 });
349 }
350 let range = int_range_as_usize(low, high);
352 let numel: usize = shape.iter().product();
353 let data: Vec<T> = (0..numel)
354 .map(|_| {
355 let idx = (rng.next_f64() * range as f64) as usize;
356 low + T::from_usize(idx.min(range - 1))
357 })
358 .collect();
359 Ok(Tensor::from_vec(data, shape).expect("shape product matches data length"))
360}
361
362fn int_range_as_usize<T: Integer>(low: T, high: T) -> usize {
367 (high - low).to_usize()
368}
369
370pub fn bernoulli<T: Scalar>(rng: &mut Rng, shape: Vec<usize>, p: f64) -> Result<Tensor<T>> {
383 if !(0.0..=1.0).contains(&p) {
384 return Err(CoreError::InvalidArgument {
385 reason: "bernoulli requires p in [0, 1]",
386 });
387 }
388 let numel: usize = shape.iter().product();
389 let data: Vec<T> = (0..numel)
390 .map(|_| {
391 if rng.next_f64() < p {
392 T::one()
393 } else {
394 T::zero()
395 }
396 })
397 .collect();
398 Ok(Tensor::from_vec(data, shape).expect("shape product matches data length"))
399}
400
401pub fn shuffle<T: Scalar>(rng: &mut Rng, tensor: &mut Tensor<T>) {
418 let n = tensor.numel();
419 if n <= 1 {
420 return;
421 }
422 let data = tensor.as_mut_slice();
423 for i in (1..n).rev() {
424 let j = (rng.next_f64() * (i + 1) as f64) as usize;
426 let j = j.min(i);
428 data.swap(i, j);
429 }
430}
431
432pub fn choice<T: Scalar>(
451 rng: &mut Rng,
452 tensor: &Tensor<T>,
453 n: usize,
454 replace: bool,
455) -> Result<Tensor<T>> {
456 if tensor.ndim() != 1 {
457 return Err(CoreError::InvalidArgument {
458 reason: "choice requires a 1-D tensor",
459 });
460 }
461 let len = tensor.numel();
462
463 if !replace && n > len {
464 return Err(CoreError::InvalidArgument {
465 reason: "choice without replacement: n > tensor length",
466 });
467 }
468
469 let src = tensor.as_slice();
470
471 if replace {
472 let data: Vec<T> = (0..n)
473 .map(|_| {
474 let idx = (rng.next_f64() * len as f64) as usize;
475 src[idx.min(len - 1)]
476 })
477 .collect();
478 Tensor::from_vec(data, vec![n])
479 } else {
480 let mut indices: Vec<usize> = (0..len).collect();
482 for i in 0..n {
483 let j = i + (rng.next_f64() * (len - i) as f64) as usize;
484 let j = j.min(len - 1);
485 indices.swap(i, j);
486 }
487 let data: Vec<T> = indices[..n].iter().map(|&i| src[i]).collect();
488 Tensor::from_vec(data, vec![n])
489 }
490}
491
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Returns the first `count` raw outputs of a generator seeded with `seed`.
    fn first_outputs(seed: u64, count: usize) -> Vec<u64> {
        let mut rng = Rng::new(seed);
        (0..count).map(|_| rng.next_u64()).collect()
    }

    // Equal seeds must give equal streams.
    #[test]
    fn test_rng_reproducibility() {
        assert_eq!(first_outputs(12345, 100), first_outputs(12345, 100));
    }

    // Different seeds should give different streams.
    #[test]
    fn test_rng_different_seeds() {
        assert_ne!(first_outputs(1, 10), first_outputs(2, 10));
    }

    #[test]
    fn test_next_f64_range() {
        let mut rng = Rng::new(42);
        (0..10_000).map(|_| rng.next_f64()).for_each(|v| {
            assert!((0.0..1.0).contains(&v), "next_f64 out of range: {v}");
        });
    }

    // Re-seeding restarts the stream from the beginning.
    #[test]
    fn test_reseed() {
        let mut rng = Rng::new(99);
        let before = rng.next_u64();
        rng.seed(99);
        let after = rng.next_u64();
        assert_eq!(before, after);
    }

    #[test]
    fn test_fork() {
        let mut parent = Rng::new(42);
        let mut children = parent.fork(4);
        assert_eq!(children.len(), 4);

        let firsts: Vec<u64> = children.iter_mut().map(|child| child.next_u64()).collect();
        for (pos, a) in firsts.iter().enumerate() {
            for b in &firsts[pos + 1..] {
                assert_ne!(a, b, "child RNGs should be independent");
            }
        }
    }

    #[test]
    fn test_fork_reproducible() {
        let mut parent_a = Rng::new(42);
        let mut parent_b = Rng::new(42);
        let forked_a = parent_a.fork(3);
        let forked_b = parent_b.fork(3);

        for (mut a, mut b) in forked_a.into_iter().zip(forked_b) {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    #[test]
    fn test_uniform_shape() {
        let mut rng = Rng::new(0);
        let tensor = uniform::<f64>(&mut rng, vec![3, 4, 5]);
        assert_eq!(tensor.numel(), 60);
        assert_eq!(tensor.shape(), &[3, 4, 5]);
    }

    #[test]
    fn test_uniform_range_values() {
        let mut rng = Rng::new(0);
        let tensor = uniform::<f64>(&mut rng, vec![1000]);
        assert!(tensor.as_slice().iter().all(|v| (0.0..1.0).contains(v)));
    }

    #[test]
    fn test_uniform_range_bounds() {
        let mut rng = Rng::new(7);
        let tensor = uniform_range::<f64>(&mut rng, vec![5000], 2.0, 5.0).unwrap();
        for &v in tensor.as_slice() {
            assert!((2.0..5.0).contains(&v), "value {v} out of [2, 5)");
        }
    }

    // Reversed and degenerate bounds are both rejected.
    #[test]
    fn test_uniform_range_invalid() {
        let mut rng = Rng::new(0);
        for &(low, high) in &[(5.0, 2.0), (3.0, 3.0)] {
            assert!(uniform_range::<f64>(&mut rng, vec![10], low, high).is_err());
        }
    }

    #[test]
    fn test_standard_normal_stats() {
        let mut rng = Rng::new(42);
        let tensor = standard_normal::<f64>(&mut rng, vec![100_000]);
        let samples = tensor.as_slice();

        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        let squared_dev: f64 = samples.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>();
        let std = (squared_dev / samples.len() as f64).sqrt();

        assert!(
            mean.abs() < 0.02,
            "standard normal mean too far from 0: {mean}"
        );
        assert!(
            (std - 1.0).abs() < 0.02,
            "standard normal std too far from 1: {std}"
        );
    }

    #[test]
    fn test_normal_custom() {
        let mut rng = Rng::new(42);
        let tensor = normal::<f64>(&mut rng, vec![50_000], 10.0, 2.0);
        let samples = tensor.as_slice();

        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(
            (mean - 10.0).abs() < 0.1,
            "normal(10, 2) mean too far from 10: {mean}"
        );
    }

    #[test]
    fn test_randint_range() {
        let mut rng = Rng::new(0);
        let tensor = randint::<i32>(&mut rng, vec![10_000], 5, 10).unwrap();
        for &v in tensor.as_slice() {
            assert!((5..10).contains(&v), "randint value {v} not in [5, 10)");
        }
    }

    // Reversed and empty integer ranges are both rejected.
    #[test]
    fn test_randint_invalid() {
        let mut rng = Rng::new(0);
        for &(low, high) in &[(10, 5), (5, 5)] {
            assert!(randint::<i32>(&mut rng, vec![10], low, high).is_err());
        }
    }

    #[test]
    fn test_bernoulli_values() {
        let mut rng = Rng::new(0);
        let tensor = bernoulli::<f64>(&mut rng, vec![10_000], 0.3).unwrap();

        let mut ones = 0usize;
        for &v in tensor.as_slice() {
            assert!(v == 0.0 || v == 1.0, "bernoulli value {v} not 0 or 1");
            if v == 1.0 {
                ones += 1;
            }
        }

        let freq = ones as f64 / 10_000.0;
        assert!(
            (freq - 0.3).abs() < 0.03,
            "bernoulli frequency {freq} too far from 0.3"
        );
    }

    #[test]
    fn test_bernoulli_invalid() {
        let mut rng = Rng::new(0);
        for &p in &[-0.1, 1.1] {
            assert!(bernoulli::<f64>(&mut rng, vec![10], p).is_err());
        }
    }

    // Shuffling permutes: the multiset of elements is unchanged.
    #[test]
    fn test_shuffle_preserves_elements() {
        let mut rng = Rng::new(42);
        let mut tensor = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], vec![10]).unwrap();
        shuffle(&mut rng, &mut tensor);

        let mut restored = tensor.as_slice().to_vec();
        restored.sort_unstable();
        assert_eq!(restored, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_shuffle_modifies_order() {
        let mut rng = Rng::new(42);
        let initial = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut tensor = Tensor::from_vec(initial.clone(), vec![10]).unwrap();
        shuffle(&mut rng, &mut tensor);
        assert_ne!(tensor.as_slice(), &initial[..]);
    }

    #[test]
    fn test_choice_with_replacement() {
        let mut rng = Rng::new(0);
        let population = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let drawn = choice(&mut rng, &population, 100, true).unwrap();
        assert_eq!(drawn.shape(), &[100]);

        let allowed = [10.0, 20.0, 30.0, 40.0, 50.0];
        for &v in drawn.as_slice() {
            assert!(allowed.contains(&v), "unexpected value {v}");
        }
    }

    // Without replacement every drawn element must be distinct.
    #[test]
    fn test_choice_without_replacement() {
        let mut rng = Rng::new(0);
        let population = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
        let drawn = choice(&mut rng, &population, 3, false).unwrap();
        assert_eq!(drawn.shape(), &[3]);

        let mut unique = drawn.as_slice().to_vec();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_choice_without_replacement_too_many() {
        let mut rng = Rng::new(0);
        let population = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let outcome = choice(&mut rng, &population, 5, false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_choice_not_1d() {
        let mut rng = Rng::new(0);
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let outcome = choice(&mut rng, &matrix, 2, true);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_uniform_f32() {
        let mut rng = Rng::new(42);
        let tensor = uniform::<f32>(&mut rng, vec![100]);
        assert_eq!(tensor.shape(), &[100]);
        for &v in tensor.as_slice() {
            assert!((0.0..1.0).contains(&v));
        }
    }
}