Skip to main content

scivex_core/random/
mod.rs

1//! Pseudo-random number generation and random tensor creation.
2//!
3//! Provides a fast, high-quality PRNG (`Rng`) based on the xoshiro256\*\*
4//! algorithm and convenience functions for creating tensors filled with random
5//! values drawn from common distributions.
6//!
7//! # Design
8//!
9//! - **Zero external dependencies** — the PRNG is implemented from scratch.
10//! - **Explicit state** — all functions take `&mut Rng`; there is no hidden
11//!   global or thread-local state.
12//! - Seeding uses `SplitMix64` to expand a single `u64` into the 4-word
13//!   xoshiro256\*\* state (avoids the zero-state trap).
14
15use crate::error::{CoreError, Result};
16use crate::tensor::Tensor;
17use crate::{Float, Integer, Scalar};
18
19// ---------------------------------------------------------------------------
20// SplitMix64 — used only for seeding
21// ---------------------------------------------------------------------------
22
/// Perform one `SplitMix64` step: bump `state` by the 64-bit golden-ratio
/// increment and return a bit-mixed copy of the new value.
///
/// Only used to expand a single seed into the xoshiro256\*\* state words.
#[inline]
fn splitmix64(state: &mut u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
    const MIX_1: u64 = 0xbf58_476d_1ce4_e5b9;
    const MIX_2: u64 = 0x94d0_49bb_1331_11eb;

    *state = state.wrapping_add(GOLDEN_GAMMA);
    let a = (*state ^ (*state >> 30)).wrapping_mul(MIX_1);
    let b = (a ^ (a >> 27)).wrapping_mul(MIX_2);
    b ^ (b >> 31)
}
32
33// ---------------------------------------------------------------------------
34// Rng — xoshiro256**
35// ---------------------------------------------------------------------------
36
37/// A fast, high-quality pseudo-random number generator.
38///
39/// Uses the xoshiro256\*\* algorithm (Blackman & Vigna), which has a period of
40/// 2^256 − 1 and passes all `BigCrush` tests.
41///
42/// # Examples
43///
44/// ```
45/// use scivex_core::random::Rng;
46///
47/// let mut rng = Rng::new(42);
48/// let value = rng.next_f64(); // uniform in [0, 1)
49/// assert!((0.0..1.0).contains(&value));
50/// ```
51pub struct Rng {
52    s: [u64; 4],
53    /// Cached spare normal value from Box-Muller (None when empty).
54    spare_normal: Option<f64>,
55}
56
57impl Rng {
58    /// Create a new PRNG seeded from a single `u64`.
59    ///
60    /// The seed is expanded into the 4-word internal state via `SplitMix64`.
61    pub fn new(seed: u64) -> Self {
62        let mut sm = seed;
63        let s = [
64            splitmix64(&mut sm),
65            splitmix64(&mut sm),
66            splitmix64(&mut sm),
67            splitmix64(&mut sm),
68        ];
69        Self {
70            s,
71            spare_normal: None,
72        }
73    }
74
75    /// Re-seed the generator, discarding all previous state.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use scivex_core::random::Rng;
81    /// let mut rng = Rng::new(1);
82    /// let first = rng.next_u64();
83    /// rng.seed(1);
84    /// assert_eq!(rng.next_u64(), first);
85    /// ```
86    pub fn seed(&mut self, seed: u64) {
87        *self = Self::new(seed);
88    }
89
90    /// Create `n` independent child RNGs by advancing the state.
91    ///
92    /// Each child receives a unique seed derived from the parent's state.
93    /// This is useful for parallel workloads where each thread needs its
94    /// own RNG to avoid contention.
95    ///
96    /// # Examples
97    ///
98    /// ```
99    /// use scivex_core::random::Rng;
100    /// let mut rng = Rng::new(42);
101    /// let children = rng.fork(4);
102    /// assert_eq!(children.len(), 4);
103    /// ```
104    pub fn fork(&mut self, n: usize) -> Vec<Self> {
105        (0..n).map(|_| Self::new(self.next_u64())).collect()
106    }
107
108    /// Generate the next random `u64`.
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// use scivex_core::random::Rng;
114    /// let mut rng = Rng::new(1);
115    /// let v = rng.next_u64(); // some pseudo-random u64
116    /// let _ = v; // value is deterministic but not checked here
117    /// ```
118    #[inline]
119    pub fn next_u64(&mut self) -> u64 {
120        let result = (self.s[1].wrapping_mul(5)).rotate_left(7).wrapping_mul(9);
121
122        let t = self.s[1] << 17;
123
124        self.s[2] ^= self.s[0];
125        self.s[3] ^= self.s[1];
126        self.s[1] ^= self.s[2];
127        self.s[0] ^= self.s[3];
128
129        self.s[2] ^= t;
130        self.s[3] = self.s[3].rotate_left(45);
131
132        result
133    }
134
135    /// Generate a random `f64` uniformly distributed in [0, 1).
136    ///
137    /// Uses the upper 53 bits of `next_u64` divided by 2^53.
138    #[inline]
139    pub fn next_f64(&mut self) -> f64 {
140        (self.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
141    }
142
143    /// Generate a standard normal (N(0,1)) `f64` via the Ziggurat algorithm.
144    ///
145    /// ~97% of samples require only a multiply and comparison (no
146    /// transcendentals), making this much faster than Box-Muller.
147    ///
148    /// # Examples
149    ///
150    /// ```
151    /// use scivex_core::random::Rng;
152    /// let mut rng = Rng::new(0);
153    /// // Draw 1000 samples and verify mean is near 0
154    /// let mean: f64 = (0..1000).map(|_| rng.next_normal_f64()).sum::<f64>() / 1000.0;
155    /// assert!(mean.abs() < 0.2);
156    /// ```
157    pub fn next_normal_f64(&mut self) -> f64 {
158        ziggurat_normal(self)
159    }
160}
161
162// ---------------------------------------------------------------------------
163// Ziggurat algorithm for normal distribution
164// ---------------------------------------------------------------------------
165
/// Number of layers in the Ziggurat: one base layer (rectangle plus tail)
/// and `ZIG_N - 1` stacked rectangles. The sampler picks a layer from the low
/// bits of a `u64`, so this must remain a power of two.
const ZIG_N: usize = 128;
/// `r`: right edge of the base rectangle; the tail is `x > r`.
const ZIG_R: f64 = 3.442619855899;
/// `v`: common area of every layer. For the base layer this is the base
/// rectangle `r · f(r)` plus the area of the tail.
const ZIG_V: f64 = 9.91256303526217e-3;

/// Build the Ziggurat tables for the unnormalised density `f(x) = exp(-x²/2)`.
///
/// Layers are ordered from the bottom of the ziggurat to the apex:
///
/// - `xtab[0] = v / f(r)`: virtual width of the base layer (rectangle + tail),
/// - `xtab[1] = r`,
/// - `xtab[i]` for `2 <= i < ZIG_N`: right edge of layer `i`, strictly decreasing,
/// - `xtab[ZIG_N] = 0`: the apex.
///
/// `ytab[i] = f(xtab[i])`, so `ytab` increases up to `ytab[ZIG_N] = 1`.
/// Layer `i >= 1` is the rectangle `[0, xtab[i]] × [ytab[i], ytab[i + 1]]` of
/// area `v`. Any point with `x < xtab[i + 1]` lies under the curve, which is
/// what the sampler's fast path relies on.
fn zig_tables() -> ([f64; ZIG_N + 1], [f64; ZIG_N + 1]) {
    let f = |x: f64| (-0.5 * x * x).exp();

    let mut xtab = [0.0f64; ZIG_N + 1];
    xtab[0] = ZIG_V / f(ZIG_R);
    xtab[1] = ZIG_R;
    for i in 1..ZIG_N - 1 {
        // Pick xtab[i + 1] so that layer i, spanning heights
        // [f(xtab[i]), f(xtab[i + 1])] with width xtab[i], has area v.
        xtab[i + 1] = (-2.0 * (ZIG_V / xtab[i] + f(xtab[i])).ln()).sqrt();
    }
    // The apex is exactly x = 0. One more recursion step would only reproduce
    // ~0, or NaN if rounding pushed the log argument above 1.
    xtab[ZIG_N] = 0.0;

    let ytab = xtab.map(f);
    (xtab, ytab)
}
198
199/// Sample from the tail of the normal distribution (|x| > R).
200fn zig_tail(rng: &mut Rng, positive: bool) -> f64 {
201    loop {
202        let x = -rng.next_f64().ln() / ZIG_R; // exponential with rate R
203        let y = -rng.next_f64().ln();
204        if 2.0 * y >= x * x {
205            return if positive { ZIG_R + x } else { -(ZIG_R + x) };
206        }
207    }
208}
209
210/// Ziggurat normal: O(1) expected time, ~97% fast-path.
211fn ziggurat_normal(rng: &mut Rng) -> f64 {
212    // We compute tables once per call. In a tight loop the compiler will
213    // usually const-fold or cache these, but for absolute best perf we use
214    // a static lazy init.
215    use std::sync::OnceLock;
216    static TABLES: OnceLock<([f64; ZIG_N + 1], [f64; ZIG_N + 1])> = OnceLock::new();
217    let (xtab, ytab) = TABLES.get_or_init(zig_tables);
218
219    loop {
220        let u = rng.next_u64();
221        let i = (u & 0x7F) as usize; // bottom 7 bits → layer index [0, 127]
222        let sign = if u & 0x80 != 0 { 1.0 } else { -1.0 };
223        // Use remaining bits for a uniform float in [0, 1).
224        let u_float = (u >> 8) as f64 / ((1u64 << 56) as f64);
225        let x = u_float * xtab[i];
226
227        // Fast accept: x falls strictly inside rectangle i.
228        if x < xtab[i + 1] {
229            return sign * x;
230        }
231
232        // Bottom layer includes the tail.
233        if i == 0 {
234            return zig_tail(rng, sign > 0.0);
235        }
236
237        // Slow accept: sample within the wedge between rectangles.
238        let y = ytab[i + 1] + (ytab[i] - ytab[i + 1]) * rng.next_f64();
239        if y < (-0.5 * x * x).exp() {
240            return sign * x;
241        }
242    }
243}
244
245// ---------------------------------------------------------------------------
246// Free functions — random tensor generation
247// ---------------------------------------------------------------------------
248
249/// Create a tensor filled with values uniformly distributed in [0, 1).
250///
251/// # Examples
252///
253/// ```
254/// use scivex_core::random::{Rng, uniform};
255///
256/// let mut rng = Rng::new(0);
257/// let t = uniform::<f64>(&mut rng, vec![2, 3]);
258/// assert_eq!(t.shape(), &[2, 3]);
259/// assert!(t.iter().all(|&x| (0.0..1.0).contains(&x)));
260/// ```
261pub fn uniform<T: Float>(rng: &mut Rng, shape: Vec<usize>) -> Tensor<T> {
262    let numel: usize = shape.iter().product();
263    let data: Vec<T> = (0..numel).map(|_| T::from_f64(rng.next_f64())).collect();
264    Tensor::from_vec(data, shape).expect("shape product matches data length")
265}
266
267/// Create a tensor filled with values uniformly distributed in [`low`, `high`).
268///
269/// Returns an error if `low >= high`.
270///
271/// # Examples
272///
273/// ```
274/// use scivex_core::random::{Rng, uniform_range};
275/// let mut rng = Rng::new(0);
276/// let t = uniform_range::<f64>(&mut rng, vec![5], 2.0_f64, 5.0_f64).unwrap();
277/// assert!(t.iter().all(|&x| (2.0_f64..5.0_f64).contains(&x)));
278/// ```
279pub fn uniform_range<T: Float>(
280    rng: &mut Rng,
281    shape: Vec<usize>,
282    low: T,
283    high: T,
284) -> Result<Tensor<T>> {
285    if low >= high {
286        return Err(CoreError::InvalidArgument {
287            reason: "uniform_range requires low < high",
288        });
289    }
290    let range = high - low;
291    let numel: usize = shape.iter().product();
292    let data: Vec<T> = (0..numel)
293        .map(|_| low + T::from_f64(rng.next_f64()) * range)
294        .collect();
295    Ok(Tensor::from_vec(data, shape).expect("shape product matches data length"))
296}
297
298/// Create a tensor of samples from a Gaussian distribution.
299///
300/// Uses the Box-Muller transform internally.
301///
302/// # Examples
303///
304/// ```
305/// use scivex_core::random::{Rng, normal};
306/// let mut rng = Rng::new(1);
307/// let t = normal::<f64>(&mut rng, vec![3], 0.0_f64, 1.0_f64);
308/// assert_eq!(t.shape(), &[3]);
309/// ```
310pub fn normal<T: Float>(rng: &mut Rng, shape: Vec<usize>, mean: T, std_dev: T) -> Tensor<T> {
311    let numel: usize = shape.iter().product();
312    let data: Vec<T> = (0..numel)
313        .map(|_| mean + std_dev * T::from_f64(rng.next_normal_f64()))
314        .collect();
315    Tensor::from_vec(data, shape).expect("shape product matches data length")
316}
317
318/// Create a tensor of samples from the standard normal distribution N(0, 1).
319///
320/// # Examples
321///
322/// ```
323/// use scivex_core::random::{Rng, standard_normal};
324/// let mut rng = Rng::new(7);
325/// let t = standard_normal::<f64>(&mut rng, vec![4]);
326/// assert_eq!(t.shape(), &[4]);
327/// ```
328pub fn standard_normal<T: Float>(rng: &mut Rng, shape: Vec<usize>) -> Tensor<T> {
329    normal(rng, shape, T::zero(), T::one())
330}
331
332/// Create a tensor of random integers in [`low`, `high`).
333///
334/// Returns an error if `low >= high`.
335///
336/// # Examples
337///
338/// ```
339/// use scivex_core::random::{Rng, randint};
340/// let mut rng = Rng::new(0);
341/// let t = randint::<i32>(&mut rng, vec![10], 0, 5).unwrap();
342/// assert!(t.iter().all(|&x| (0..5).contains(&x)));
343/// ```
344pub fn randint<T: Integer>(rng: &mut Rng, shape: Vec<usize>, low: T, high: T) -> Result<Tensor<T>> {
345    if low >= high {
346        return Err(CoreError::InvalidArgument {
347            reason: "randint requires low < high",
348        });
349    }
350    // Find the range as a usize via binary search on from_usize.
351    let range = int_range_as_usize(low, high);
352    let numel: usize = shape.iter().product();
353    let data: Vec<T> = (0..numel)
354        .map(|_| {
355            let idx = (rng.next_f64() * range as f64) as usize;
356            low + T::from_usize(idx.min(range - 1))
357        })
358        .collect();
359    Ok(Tensor::from_vec(data, shape).expect("shape product matches data length"))
360}
361
362/// Compute the number of integers in [low, high) as a `usize`.
363/// used for random integer generation where ranges are typically small, the
364/// linear scan is acceptable. For large ranges a binary search on `from_usize`
365/// is unsafe because signed types wrap on overflow.
366fn int_range_as_usize<T: Integer>(low: T, high: T) -> usize {
367    (high - low).to_usize()
368}
369
370/// Create a tensor of Bernoulli random variables (0 or 1) with probability `p`.
371///
372/// Returns an error if `p` is not in [0, 1].
373///
374/// # Examples
375///
376/// ```
377/// use scivex_core::random::{Rng, bernoulli};
378/// let mut rng = Rng::new(0);
379/// let t = bernoulli::<f64>(&mut rng, vec![10], 0.5).unwrap();
380/// assert!(t.iter().all(|&x| x == 0.0 || x == 1.0));
381/// ```
382pub fn bernoulli<T: Scalar>(rng: &mut Rng, shape: Vec<usize>, p: f64) -> Result<Tensor<T>> {
383    if !(0.0..=1.0).contains(&p) {
384        return Err(CoreError::InvalidArgument {
385            reason: "bernoulli requires p in [0, 1]",
386        });
387    }
388    let numel: usize = shape.iter().product();
389    let data: Vec<T> = (0..numel)
390        .map(|_| {
391            if rng.next_f64() < p {
392                T::one()
393            } else {
394                T::zero()
395            }
396        })
397        .collect();
398    Ok(Tensor::from_vec(data, shape).expect("shape product matches data length"))
399}
400
401/// Shuffle the elements of a tensor in-place using the Fisher-Yates algorithm.
402///
403/// Operates on the flat (storage-order) data regardless of shape.
404///
405/// # Examples
406///
407/// ```
408/// use scivex_core::random::{Rng, shuffle};
409/// use scivex_core::Tensor;
410/// let mut rng = Rng::new(0);
411/// let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5], vec![5]).unwrap();
412/// shuffle(&mut rng, &mut t);
413/// let mut sorted = t.as_slice().to_vec();
414/// sorted.sort_unstable();
415/// assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
416/// ```
417pub fn shuffle<T: Scalar>(rng: &mut Rng, tensor: &mut Tensor<T>) {
418    let n = tensor.numel();
419    if n <= 1 {
420        return;
421    }
422    let data = tensor.as_mut_slice();
423    for i in (1..n).rev() {
424        // Generate j in [0, i] using rejection-free modular reduction.
425        let j = (rng.next_f64() * (i + 1) as f64) as usize;
426        // Clamp to valid range (floating-point edge case).
427        let j = j.min(i);
428        data.swap(i, j);
429    }
430}
431
432/// Sample `n` elements from a 1-D tensor.
433///
434/// - `replace = true`: sampling with replacement (may contain duplicates).
435/// - `replace = false`: sampling without replacement. Returns an error if
436///   `n > tensor.numel()`.
437///
438/// Returns an error if the tensor is not 1-D.
439///
440/// # Examples
441///
442/// ```
443/// use scivex_core::random::{Rng, choice};
444/// use scivex_core::Tensor;
445/// let mut rng = Rng::new(0);
446/// let t = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
447/// let sample = choice(&mut rng, &t, 3, false).unwrap();
448/// assert_eq!(sample.shape(), &[3]);
449/// ```
450pub fn choice<T: Scalar>(
451    rng: &mut Rng,
452    tensor: &Tensor<T>,
453    n: usize,
454    replace: bool,
455) -> Result<Tensor<T>> {
456    if tensor.ndim() != 1 {
457        return Err(CoreError::InvalidArgument {
458            reason: "choice requires a 1-D tensor",
459        });
460    }
461    let len = tensor.numel();
462
463    if !replace && n > len {
464        return Err(CoreError::InvalidArgument {
465            reason: "choice without replacement: n > tensor length",
466        });
467    }
468
469    let src = tensor.as_slice();
470
471    if replace {
472        let data: Vec<T> = (0..n)
473            .map(|_| {
474                let idx = (rng.next_f64() * len as f64) as usize;
475                src[idx.min(len - 1)]
476            })
477            .collect();
478        Tensor::from_vec(data, vec![n])
479    } else {
480        // Fisher-Yates partial shuffle on index array
481        let mut indices: Vec<usize> = (0..len).collect();
482        for i in 0..n {
483            let j = i + (rng.next_f64() * (len - i) as f64) as usize;
484            let j = j.min(len - 1);
485            indices.swap(i, j);
486        }
487        let data: Vec<T> = indices[..n].iter().map(|&i| src[i]).collect();
488        Tensor::from_vec(data, vec![n])
489    }
490}
491
492// ===========================================================================
493// Tests
494// ===========================================================================
495
#[cfg(test)]
// Exact float comparisons are intentional: Bernoulli outputs are exactly 0.0
// or 1.0, and `choice` returns bit-for-bit copies of its input values.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_rng_reproducibility() {
        let mut rng1 = Rng::new(12345);
        let mut rng2 = Rng::new(12345);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn test_rng_different_seeds() {
        let mut rng1 = Rng::new(1);
        let mut rng2 = Rng::new(2);
        // Extremely unlikely to be equal
        let seq1: Vec<u64> = (0..10).map(|_| rng1.next_u64()).collect();
        let seq2: Vec<u64> = (0..10).map(|_| rng2.next_u64()).collect();
        assert_ne!(seq1, seq2);
    }

    #[test]
    fn test_next_f64_range() {
        let mut rng = Rng::new(42);
        for _ in 0..10_000 {
            let v = rng.next_f64();
            assert!((0.0..1.0).contains(&v), "next_f64 out of range: {v}");
        }
    }

    #[test]
    fn test_reseed() {
        let mut rng = Rng::new(99);
        let first = rng.next_u64();
        rng.seed(99);
        let second = rng.next_u64();
        assert_eq!(first, second);
    }

    #[test]
    fn test_fork() {
        let mut rng = Rng::new(42);
        let children = rng.fork(4);
        assert_eq!(children.len(), 4);
        // All children should produce different sequences.
        let vals: Vec<u64> = children.into_iter().map(|mut r| r.next_u64()).collect();
        for i in 0..vals.len() {
            for j in (i + 1)..vals.len() {
                assert_ne!(vals[i], vals[j], "child RNGs should be independent");
            }
        }
    }

    #[test]
    fn test_fork_reproducible() {
        let mut rng1 = Rng::new(42);
        let children1 = rng1.fork(3);
        let mut rng2 = Rng::new(42);
        let children2 = rng2.fork(3);
        // Same seed → same children.
        for (mut c1, mut c2) in children1.into_iter().zip(children2) {
            assert_eq!(c1.next_u64(), c2.next_u64());
        }
    }

    #[test]
    fn test_fork_zero_children() {
        let mut rng = Rng::new(5);
        assert!(rng.fork(0).is_empty());
    }

    #[test]
    fn test_uniform_shape() {
        let mut rng = Rng::new(0);
        let t = uniform::<f64>(&mut rng, vec![3, 4, 5]);
        assert_eq!(t.shape(), &[3, 4, 5]);
        assert_eq!(t.numel(), 60);
    }

    #[test]
    fn test_uniform_values() {
        let mut rng = Rng::new(0);
        let t = uniform::<f64>(&mut rng, vec![1000]);
        for &v in t.as_slice() {
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn test_uniform_range_bounds() {
        let mut rng = Rng::new(7);
        let t = uniform_range::<f64>(&mut rng, vec![5000], 2.0, 5.0).unwrap();
        for &v in t.as_slice() {
            assert!((2.0..5.0).contains(&v), "value {v} out of [2, 5)");
        }
    }

    #[test]
    fn test_uniform_range_negative_bounds() {
        let mut rng = Rng::new(11);
        let t = uniform_range::<f64>(&mut rng, vec![2000], -3.0, -1.0).unwrap();
        for &v in t.as_slice() {
            assert!((-3.0..-1.0).contains(&v), "value {v} out of [-3, -1)");
        }
    }

    #[test]
    fn test_uniform_range_invalid() {
        let mut rng = Rng::new(0);
        assert!(uniform_range::<f64>(&mut rng, vec![10], 5.0, 2.0).is_err());
        assert!(uniform_range::<f64>(&mut rng, vec![10], 3.0, 3.0).is_err());
    }

    #[test]
    fn test_standard_normal_stats() {
        let mut rng = Rng::new(42);
        let t = standard_normal::<f64>(&mut rng, vec![100_000]);
        let data = t.as_slice();

        let mean = data.iter().sum::<f64>() / data.len() as f64;
        let var = data.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / data.len() as f64;
        let std = var.sqrt();

        assert!(
            mean.abs() < 0.02,
            "standard normal mean too far from 0: {mean}"
        );
        assert!(
            (std - 1.0).abs() < 0.02,
            "standard normal std too far from 1: {std}"
        );
    }

    #[test]
    fn test_normal_custom() {
        let mut rng = Rng::new(42);
        let t = normal::<f64>(&mut rng, vec![50_000], 10.0, 2.0);
        let data = t.as_slice();

        let mean = data.iter().sum::<f64>() / data.len() as f64;
        assert!(
            (mean - 10.0).abs() < 0.1,
            "normal(10, 2) mean too far from 10: {mean}"
        );
    }

    #[test]
    fn test_randint_range() {
        let mut rng = Rng::new(0);
        let t = randint::<i32>(&mut rng, vec![10_000], 5, 10).unwrap();
        for &v in t.as_slice() {
            assert!((5..10).contains(&v), "randint value {v} not in [5, 10)");
        }
    }

    #[test]
    fn test_randint_hits_every_value() {
        let mut rng = Rng::new(3);
        let t = randint::<i32>(&mut rng, vec![2000], 0, 5).unwrap();
        let mut seen = [false; 5];
        for &v in t.as_slice() {
            seen[usize::try_from(v).expect("randint value is non-negative")] = true;
        }
        assert!(seen.iter().all(|&s| s), "randint missed a value: {seen:?}");
    }

    #[test]
    fn test_randint_invalid() {
        let mut rng = Rng::new(0);
        assert!(randint::<i32>(&mut rng, vec![10], 10, 5).is_err());
        assert!(randint::<i32>(&mut rng, vec![10], 5, 5).is_err());
    }

    #[test]
    fn test_bernoulli_values() {
        let mut rng = Rng::new(0);
        let t = bernoulli::<f64>(&mut rng, vec![10_000], 0.3).unwrap();
        for &v in t.as_slice() {
            assert!(v == 0.0 || v == 1.0, "bernoulli value {v} not 0 or 1");
        }
        // Check frequency is approximately p
        let ones = t.as_slice().iter().filter(|&&x| x == 1.0).count();
        let freq = ones as f64 / 10_000.0;
        assert!(
            (freq - 0.3).abs() < 0.03,
            "bernoulli frequency {freq} too far from 0.3"
        );
    }

    #[test]
    fn test_bernoulli_extremes() {
        let mut rng = Rng::new(0);
        let zeros = bernoulli::<f64>(&mut rng, vec![500], 0.0).unwrap();
        assert!(zeros.as_slice().iter().all(|&x| x == 0.0));
        let ones = bernoulli::<f64>(&mut rng, vec![500], 1.0).unwrap();
        assert!(ones.as_slice().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_bernoulli_invalid() {
        let mut rng = Rng::new(0);
        assert!(bernoulli::<f64>(&mut rng, vec![10], -0.1).is_err());
        assert!(bernoulli::<f64>(&mut rng, vec![10], 1.1).is_err());
        assert!(bernoulli::<f64>(&mut rng, vec![10], f64::NAN).is_err());
    }

    #[test]
    fn test_shuffle_preserves_elements() {
        let mut rng = Rng::new(42);
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], vec![10]).unwrap();
        shuffle(&mut rng, &mut t);

        let mut sorted = t.as_slice().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_shuffle_modifies_order() {
        let mut rng = Rng::new(42);
        let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut t = Tensor::from_vec(original.clone(), vec![10]).unwrap();
        shuffle(&mut rng, &mut t);
        // Very unlikely to remain in original order with 10 elements
        assert_ne!(t.as_slice(), &original[..]);
    }

    #[test]
    fn test_shuffle_single_element() {
        let mut rng = Rng::new(0);
        let mut t = Tensor::from_vec(vec![7], vec![1]).unwrap();
        shuffle(&mut rng, &mut t);
        assert_eq!(t.as_slice(), &[7][..]);
    }

    #[test]
    fn test_choice_with_replacement() {
        let mut rng = Rng::new(0);
        let t = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let sample = choice(&mut rng, &t, 100, true).unwrap();
        assert_eq!(sample.shape(), &[100]);
        // All values should be from the original tensor
        let valid = [10.0, 20.0, 30.0, 40.0, 50.0];
        for &v in sample.as_slice() {
            assert!(valid.contains(&v), "unexpected value {v}");
        }
    }

    #[test]
    fn test_choice_without_replacement() {
        let mut rng = Rng::new(0);
        let t = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
        let sample = choice(&mut rng, &t, 3, false).unwrap();
        assert_eq!(sample.shape(), &[3]);

        // No duplicates
        let data = sample.as_slice();
        let mut dedup = data.to_vec();
        dedup.sort_unstable();
        dedup.dedup();
        assert_eq!(dedup.len(), 3);
    }

    #[test]
    fn test_choice_without_replacement_full_permutation() {
        let mut rng = Rng::new(9);
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![6]).unwrap();
        let sample = choice(&mut rng, &t, 6, false).unwrap();
        let mut sorted = sample.as_slice().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_choice_without_replacement_too_many() {
        let mut rng = Rng::new(0);
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(choice(&mut rng, &t, 5, false).is_err());
    }

    #[test]
    fn test_choice_not_1d() {
        let mut rng = Rng::new(0);
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(choice(&mut rng, &t, 2, true).is_err());
    }

    #[test]
    fn test_uniform_f32() {
        let mut rng = Rng::new(42);
        let t = uniform::<f32>(&mut rng, vec![100]);
        assert_eq!(t.shape(), &[100]);
        for &v in t.as_slice() {
            assert!((0.0..1.0).contains(&v));
        }
    }
}