Skip to main content

u_numflow/
random.rs

1//! Random number generation, shuffling, and weighted sampling.
2//!
3//! Provides seeded RNG construction, Fisher-Yates shuffle, and
4//! weighted random sampling utilities.
5//!
6//! # Reproducibility
7//!
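//! For reproducible experiments, use [`create_rng`] with a fixed seed.
//! The underlying generator (rand's `SmallRng`) yields the same sequence for
//! a given seed only on the same target and with the same version of the
//! `rand` crate: its algorithm depends on the target's pointer width and may
//! change between `rand` releases.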
11
12use rand::Rng;
13
14/// Creates a fast, seeded random number generator.
15///
16/// Uses `SmallRng` (Xoshiro256++) for high performance.
17/// The sequence is deterministic for a given seed on the same platform.
18///
19/// # Examples
20/// ```
21/// use u_numflow::random::create_rng;
22/// use rand::Rng;
23/// let mut rng = create_rng(42);
24/// let x: f64 = rng.random();
25/// assert!(x >= 0.0 && x < 1.0);
26/// ```
27pub fn create_rng(seed: u64) -> rand::rngs::SmallRng {
28    use rand::SeedableRng;
29    rand::rngs::SmallRng::seed_from_u64(seed)
30}
31
32/// Fisher-Yates (Durstenfeld) in-place shuffle.
33///
34/// Produces a uniformly random permutation: each of the n! permutations
35/// is equally likely.
36///
37/// # Algorithm
38/// Modern variant due to Durstenfeld (1964), popularized by Knuth as
39/// "Algorithm P". Iterates backwards, swapping each element with a
40/// uniformly chosen earlier (or same) position.
41///
42/// Reference: Knuth (1997), *TAOCP* Vol. 2, §3.4.2, Algorithm P.
43///
44/// # Complexity
45/// Time: O(n), Space: O(1) (in-place)
46///
47/// # Examples
48/// ```
49/// use u_numflow::random::{create_rng, shuffle};
50/// let mut v = vec![1, 2, 3, 4, 5];
51/// let mut rng = create_rng(42);
52/// shuffle(&mut v, &mut rng);
53/// // v is now a permutation of [1, 2, 3, 4, 5]
54/// v.sort();
55/// assert_eq!(v, vec![1, 2, 3, 4, 5]);
56/// ```
57pub fn shuffle<T, R: Rng>(slice: &mut [T], rng: &mut R) {
58    let n = slice.len();
59    if n <= 1 {
60        return;
61    }
62    for i in (1..n).rev() {
63        let j = rng.random_range(0..=i);
64        slice.swap(i, j);
65    }
66}
67
68/// Returns a shuffled index permutation of `[0, n)`.
69///
70/// Generates a random permutation of indices without modifying the
71/// original data. Useful when you need to iterate over data in random
72/// order without cloning.
73///
74/// # Complexity
75/// Time: O(n), Space: O(n)
76///
77/// # Examples
78/// ```
79/// use u_numflow::random::{create_rng, shuffled_indices};
80/// let mut rng = create_rng(42);
81/// let indices = shuffled_indices(5, &mut rng);
82/// assert_eq!(indices.len(), 5);
83/// let mut sorted = indices.clone();
84/// sorted.sort();
85/// assert_eq!(sorted, vec![0, 1, 2, 3, 4]);
86/// ```
87pub fn shuffled_indices<R: Rng>(n: usize, rng: &mut R) -> Vec<usize> {
88    let mut indices: Vec<usize> = (0..n).collect();
89    shuffle(&mut indices, rng);
90    indices
91}
92
93/// Selects a random index weighted by the given weights.
94///
95/// Uses the CDF binary search method. For repeated sampling from the
96/// same weights, prefer [`WeightedSampler`].
97///
98/// # Complexity
99/// Time: O(n) per sample
100///
101/// # Returns
102/// - `None` if `weights` is empty or all weights are zero.
103///
104/// # Examples
105/// ```
106/// use u_numflow::random::{create_rng, weighted_choose};
107/// let mut rng = create_rng(42);
108/// let weights = [1.0, 2.0, 3.0]; // index 2 is most likely
109/// let idx = weighted_choose(&weights, &mut rng).unwrap();
110/// assert!(idx < 3);
111/// ```
112pub fn weighted_choose<R: Rng>(weights: &[f64], rng: &mut R) -> Option<usize> {
113    if weights.is_empty() {
114        return None;
115    }
116
117    let total: f64 = weights.iter().filter(|w| **w > 0.0).sum();
118    if total <= 0.0 {
119        return None;
120    }
121
122    let threshold = rng.random_range(0.0..total);
123    let mut cumulative = 0.0;
124    for (i, &w) in weights.iter().enumerate() {
125        if w > 0.0 {
126            cumulative += w;
127            if cumulative > threshold {
128                return Some(i);
129            }
130        }
131    }
132
133    // Fallback (floating-point edge case)
134    Some(weights.len() - 1)
135}
136
137/// Pre-computed weighted sampler for O(log n) repeated sampling.
138///
139/// Builds a cumulative distribution table from weights, then uses
140/// binary search for each sample.
141///
142/// # Algorithm
143/// CDF-based weighted sampling with binary search.
144///
145/// # Complexity
146/// - Construction: O(n)
147/// - Sampling: O(log n)
148///
149/// # Examples
150/// ```
151/// use u_numflow::random::{create_rng, WeightedSampler};
152/// let weights = vec![1.0, 2.0, 3.0, 4.0];
153/// let sampler = WeightedSampler::new(&weights).unwrap();
154/// let mut rng = create_rng(42);
155/// let idx = sampler.sample(&mut rng);
156/// assert!(idx < 4);
157/// ```
158pub struct WeightedSampler {
159    cumulative: Vec<f64>,
160    total: f64,
161}
162
163impl WeightedSampler {
164    /// Creates a new weighted sampler from the given weights.
165    ///
166    /// # Returns
167    /// - `None` if `weights` is empty or all weights are zero/negative.
168    pub fn new(weights: &[f64]) -> Option<Self> {
169        if weights.is_empty() {
170            return None;
171        }
172
173        let mut cumulative = Vec::with_capacity(weights.len());
174        let mut total = 0.0;
175        for &w in weights {
176            if w > 0.0 {
177                total += w;
178            }
179            cumulative.push(total);
180        }
181
182        if total <= 0.0 {
183            return None;
184        }
185
186        Some(Self { cumulative, total })
187    }
188
189    /// Samples a random index according to the weights.
190    ///
191    /// # Complexity
192    /// O(log n) via binary search.
193    pub fn sample<R: Rng>(&self, rng: &mut R) -> usize {
194        let threshold = rng.random_range(0.0..self.total);
195        match self.cumulative.binary_search_by(|c| {
196            c.partial_cmp(&threshold)
197                .expect("cumulative values are finite")
198        }) {
199            Ok(i) => i,
200            Err(i) => i.min(self.cumulative.len() - 1),
201        }
202    }
203
204    /// Returns the number of categories.
205    pub fn len(&self) -> usize {
206        self.cumulative.len()
207    }
208
209    /// Returns true if there are no categories.
210    pub fn is_empty(&self) -> bool {
211        self.cumulative.is_empty()
212    }
213
214    /// Returns the total weight.
215    pub fn total_weight(&self) -> f64 {
216        self.total
217    }
218}
219
220// ============================================================================
221// Tests
222// ============================================================================
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_rng_deterministic() {
        let mut rng1 = create_rng(42);
        let mut rng2 = create_rng(42);
        let vals1: Vec<f64> = (0..10).map(|_| rng1.random()).collect();
        let vals2: Vec<f64> = (0..10).map(|_| rng2.random()).collect();
        assert_eq!(vals1, vals2);
    }

    #[test]
    fn test_shuffle_preserves_elements() {
        let mut v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut rng = create_rng(123);
        shuffle(&mut v, &mut rng);
        v.sort();
        assert_eq!(v, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_shuffle_empty() {
        let mut v: Vec<i32> = vec![];
        let mut rng = create_rng(0);
        shuffle(&mut v, &mut rng); // should not panic
    }

    #[test]
    fn test_shuffle_single() {
        let mut v = vec![42];
        let mut rng = create_rng(0);
        shuffle(&mut v, &mut rng);
        assert_eq!(v, vec![42]);
    }

    #[test]
    fn test_shuffle_actually_shuffles() {
        // The seed is fixed, so this is deterministic. A random seed would hit
        // the identity permutation of 10 elements with probability
        // 1/10! ≈ 2.8e-7.
        let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut v = original.clone();
        let mut rng = create_rng(42);
        shuffle(&mut v, &mut rng);
        assert_ne!(v, original, "shuffle should change order (probabilistic)");
    }

    #[test]
    fn test_shuffled_indices() {
        let mut rng = create_rng(42);
        let indices = shuffled_indices(10, &mut rng);
        assert_eq!(indices.len(), 10);
        let mut sorted = indices.clone();
        sorted.sort();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_shuffled_indices_zero() {
        let mut rng = create_rng(7);
        assert!(shuffled_indices(0, &mut rng).is_empty());
    }

    #[test]
    fn test_weighted_choose_basic() {
        let mut rng = create_rng(42);
        let weights = [0.0, 0.0, 1.0]; // only index 2 has weight
        for _ in 0..100 {
            assert_eq!(weighted_choose(&weights, &mut rng), Some(2));
        }
    }

    #[test]
    fn test_weighted_choose_ignores_non_positive() {
        let mut rng = create_rng(42);
        let weights = [-5.0, 0.0, f64::NAN, 2.0]; // only index 3 is eligible
        for _ in 0..100 {
            assert_eq!(weighted_choose(&weights, &mut rng), Some(3));
        }
    }

    #[test]
    fn test_weighted_choose_empty() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_all_zero() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[0.0, 0.0], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_all_negative() {
        let mut rng = create_rng(42);
        assert_eq!(weighted_choose(&[-1.0, -2.0], &mut rng), None);
    }

    #[test]
    fn test_weighted_choose_distribution() {
        let mut rng = create_rng(42);
        let weights = [1.0, 3.0]; // index 1 should be ~3x more likely
        let mut counts = [0u32; 2];
        let n = 10000;
        for _ in 0..n {
            let idx = weighted_choose(&weights, &mut rng).unwrap();
            counts[idx] += 1;
        }
        let ratio = counts[1] as f64 / counts[0] as f64;
        assert!(
            (ratio - 3.0).abs() < 0.5,
            "expected ratio ~3.0, got {ratio}"
        );
    }

    #[test]
    fn test_weighted_sampler_basic() {
        let sampler = WeightedSampler::new(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(sampler.len(), 3);
        assert!(!sampler.is_empty());
        assert!((sampler.total_weight() - 6.0).abs() < 1e-15);
    }

    #[test]
    fn test_weighted_sampler_deterministic_weight() {
        let sampler = WeightedSampler::new(&[0.0, 0.0, 1.0]).unwrap();
        let mut rng = create_rng(42);
        for _ in 0..100 {
            assert_eq!(sampler.sample(&mut rng), 2);
        }
    }

    #[test]
    fn test_weighted_sampler_ignores_non_positive() {
        let sampler = WeightedSampler::new(&[-5.0, 0.0, f64::NAN, 2.0]).unwrap();
        assert_eq!(sampler.len(), 4);
        let mut rng = create_rng(42);
        for _ in 0..100 {
            assert_eq!(sampler.sample(&mut rng), 3);
        }
    }

    #[test]
    fn test_weighted_sampler_distribution() {
        let sampler = WeightedSampler::new(&[1.0, 3.0]).unwrap();
        let mut rng = create_rng(42);
        let mut counts = [0u32; 2];
        let n = 10000;
        for _ in 0..n {
            counts[sampler.sample(&mut rng)] += 1;
        }
        let ratio = counts[1] as f64 / counts[0] as f64;
        assert!(
            (ratio - 3.0).abs() < 0.5,
            "expected ratio ~3.0, got {ratio}"
        );
    }

    #[test]
    fn test_weighted_sampler_matches_weighted_choose() {
        // Both draw one threshold from [0, total) per sample and pick the
        // first index whose running sum exceeds it, so identically seeded
        // RNGs should yield identical picks.
        let weights = [1.0, 0.0, 2.0, 3.0];
        let sampler = WeightedSampler::new(&weights).unwrap();
        let mut rng_a = create_rng(9);
        let mut rng_b = create_rng(9);
        for _ in 0..1000 {
            assert_eq!(
                Some(sampler.sample(&mut rng_a)),
                weighted_choose(&weights, &mut rng_b)
            );
        }
    }

    #[test]
    fn test_weighted_sampler_empty() {
        assert!(WeightedSampler::new(&[]).is_none());
    }

    #[test]
    fn test_weighted_sampler_all_non_positive() {
        assert!(WeightedSampler::new(&[0.0, -1.0, 0.0]).is_none());
    }
}
357
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        #[test]
        fn shuffle_is_permutation(
            seed in 0_u64..10000,
            data in proptest::collection::vec(0_i32..1000, 0..50),
        ) {
            let mut shuffled = data.clone();
            let mut rng = create_rng(seed);
            shuffle(&mut shuffled, &mut rng);
            let mut sorted_orig = data;
            let mut sorted_shuf = shuffled;
            sorted_orig.sort();
            sorted_shuf.sort();
            prop_assert_eq!(sorted_orig, sorted_shuf);
        }

        #[test]
        fn shuffled_indices_is_permutation(
            seed in 0_u64..10000,
            n in 0_usize..100,
        ) {
            let mut rng = create_rng(seed);
            let mut indices = shuffled_indices(n, &mut rng);
            prop_assert_eq!(indices.len(), n);
            indices.sort_unstable();
            prop_assert_eq!(indices, (0..n).collect::<Vec<_>>());
        }

        #[test]
        fn weighted_choose_returns_valid_index(
            seed in 0_u64..10000,
            weights in proptest::collection::vec(-10.0_f64..10.0, 1..20),
        ) {
            let has_positive = weights.iter().any(|&w| w > 0.0);
            let mut rng = create_rng(seed);
            match weighted_choose(&weights, &mut rng) {
                Some(idx) => {
                    prop_assert!(has_positive);
                    prop_assert!(idx < weights.len());
                    // A non-positive weight must never be selected.
                    prop_assert!(weights[idx] > 0.0);
                }
                None => {
                    prop_assert!(!has_positive);
                }
            }
        }

        #[test]
        fn weighted_sampler_returns_positive_weight_index(
            seed in 0_u64..10000,
            weights in proptest::collection::vec(-10.0_f64..10.0, 1..20),
        ) {
            let has_positive = weights.iter().any(|&w| w > 0.0);
            match WeightedSampler::new(&weights) {
                Some(sampler) => {
                    prop_assert!(has_positive);
                    prop_assert_eq!(sampler.len(), weights.len());
                    let mut rng = create_rng(seed);
                    for _ in 0..10 {
                        let idx = sampler.sample(&mut rng);
                        prop_assert!(idx < weights.len());
                        prop_assert!(weights[idx] > 0.0);
                    }
                }
                None => {
                    prop_assert!(!has_positive);
                }
            }
        }
    }
}