sklears_utils/
random.rs

1//! Random number generation utilities
2
3use crate::{UtilsError, UtilsResult};
4use lazy_static::lazy_static;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::random::{Rng, SeedableRng};
7use std::collections::HashMap;
8use std::sync::Mutex;
9
10lazy_static! {
11    static ref GLOBAL_RNG: Mutex<StdRng> = Mutex::new(StdRng::seed_from_u64(42));
12}
13
14/// Set the global random seed for reproducible results
15pub fn set_random_state(seed: u64) {
16    let mut rng = GLOBAL_RNG.lock().unwrap();
17    *rng = StdRng::seed_from_u64(seed);
18}
19
20/// Get a random number generator with the specified seed
21pub fn get_rng(seed: Option<u64>) -> StdRng {
22    match seed {
23        Some(s) => StdRng::seed_from_u64(s),
24        None => {
25            let mut rng = GLOBAL_RNG.lock().unwrap();
26            StdRng::seed_from_u64(rng.gen::<u64>())
27        }
28    }
29}
30
31/// Generate random indices for sampling
32pub fn random_indices(
33    n_samples: usize,
34    size: usize,
35    replace: bool,
36    seed: Option<u64>,
37) -> UtilsResult<Vec<usize>> {
38    if !replace && size > n_samples {
39        return Err(UtilsError::InvalidParameter(format!(
40            "Cannot sample {size} items from {n_samples} without replacement"
41        )));
42    }
43
44    let mut rng = get_rng(seed);
45    let mut indices = Vec::with_capacity(size);
46
47    if replace {
48        // Sampling with replacement
49        for _ in 0..size {
50            indices.push(rng.gen_range(0..n_samples));
51        }
52    } else {
53        // Sampling without replacement
54        let mut available: Vec<usize> = (0..n_samples).collect();
55        for _ in 0..size {
56            let idx = rng.gen_range(0..available.len());
57            indices.push(available.swap_remove(idx));
58        }
59    }
60
61    Ok(indices)
62}
63
64/// Shuffle an array of indices in place
65pub fn shuffle_indices(indices: &mut [usize], seed: Option<u64>) {
66    let mut rng = get_rng(seed);
67    for i in (1..indices.len()).rev() {
68        let j = rng.gen_range(0..=i);
69        indices.swap(i, j);
70    }
71}
72
73/// Generate a random permutation of indices
74pub fn random_permutation(n: usize, seed: Option<u64>) -> Vec<usize> {
75    let mut indices: Vec<usize> = (0..n).collect();
76    shuffle_indices(&mut indices, seed);
77    indices
78}
79
80/// Split indices into train/test sets
81pub fn train_test_split_indices(
82    n_samples: usize,
83    test_size: f64,
84    shuffle: bool,
85    seed: Option<u64>,
86) -> UtilsResult<(Vec<usize>, Vec<usize>)> {
87    if test_size <= 0.0 || test_size >= 1.0 {
88        return Err(UtilsError::InvalidParameter(format!(
89            "test_size must be in (0, 1), got {test_size}"
90        )));
91    }
92
93    let test_samples = (n_samples as f64 * test_size).round() as usize;
94    let train_samples = n_samples - test_samples;
95
96    let indices = if shuffle {
97        random_permutation(n_samples, seed)
98    } else {
99        (0..n_samples).collect()
100    };
101
102    let train_indices = indices[..train_samples].to_vec();
103    let test_indices = indices[train_samples..].to_vec();
104
105    Ok((train_indices, test_indices))
106}
107
108/// Generate random weights that sum to 1
109pub fn random_weights(n: usize, seed: Option<u64>) -> Vec<f64> {
110    let mut rng = get_rng(seed);
111    let mut weights: Vec<f64> = (0..n).map(|_| rng.gen::<f64>()).collect();
112    let sum: f64 = weights.iter().sum();
113
114    if sum > 0.0 {
115        for w in &mut weights {
116            *w /= sum;
117        }
118    } else {
119        // Fallback to uniform weights
120        let uniform_weight = 1.0 / n as f64;
121        weights.fill(uniform_weight);
122    }
123
124    weights
125}
126
127/// Bootstrap sampling: sample n_samples with replacement
128pub fn bootstrap_indices(n_samples: usize, seed: Option<u64>) -> Vec<usize> {
129    random_indices(n_samples, n_samples, true, seed).unwrap()
130}
131
132/// Generate k-fold cross-validation indices
133pub fn k_fold_indices(
134    n_samples: usize,
135    n_splits: usize,
136    shuffle: bool,
137    seed: Option<u64>,
138) -> UtilsResult<Vec<(Vec<usize>, Vec<usize>)>> {
139    if n_splits < 2 {
140        return Err(UtilsError::InvalidParameter(format!(
141            "n_splits must be at least 2, got {n_splits}"
142        )));
143    }
144
145    if n_splits > n_samples {
146        return Err(UtilsError::InvalidParameter(format!(
147            "n_splits {n_splits} cannot be greater than the number of samples {n_samples}"
148        )));
149    }
150
151    let indices = if shuffle {
152        random_permutation(n_samples, seed)
153    } else {
154        (0..n_samples).collect()
155    };
156
157    let mut folds = Vec::with_capacity(n_splits);
158    let fold_sizes: Vec<usize> = (0..n_splits)
159        .map(|i| (n_samples + n_splits - i - 1) / n_splits)
160        .collect();
161
162    let mut start = 0;
163    for fold_size in fold_sizes {
164        let end = start + fold_size;
165        let test_indices = indices[start..end].to_vec();
166        let mut train_indices = Vec::with_capacity(n_samples - fold_size);
167        train_indices.extend(&indices[..start]);
168        train_indices.extend(&indices[end..]);
169
170        folds.push((train_indices, test_indices));
171        start = end;
172    }
173
174    Ok(folds)
175}
176
177/// Generate stratified train/test split indices
178pub fn stratified_split_indices(
179    labels: &[i32],
180    test_size: f64,
181    seed: Option<u64>,
182) -> UtilsResult<(Vec<usize>, Vec<usize>)> {
183    if test_size <= 0.0 || test_size >= 1.0 {
184        return Err(UtilsError::InvalidParameter(format!(
185            "test_size must be in (0, 1), got {test_size}"
186        )));
187    }
188
189    // Group indices by class
190    let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
191    for (idx, &label) in labels.iter().enumerate() {
192        class_indices.entry(label).or_default().push(idx);
193    }
194
195    let mut train_indices = Vec::new();
196    let mut test_indices = Vec::new();
197
198    // Split each class separately
199    for indices in class_indices.values() {
200        let n_class = indices.len();
201        let n_test = (n_class as f64 * test_size).round() as usize;
202        let n_test = n_test.max(1).min(n_class - 1); // Ensure at least 1 in each split
203
204        let (class_train, class_test) =
205            train_test_split_indices(n_class, n_test as f64 / n_class as f64, true, seed)?;
206
207        train_indices.extend(class_train.iter().map(|&i| indices[i]));
208        test_indices.extend(class_test.iter().map(|&i| indices[i]));
209    }
210
211    Ok((train_indices, test_indices))
212}
213
214/// Reservoir sampling: efficiently sample k items from a stream of unknown size
215pub fn reservoir_sampling<T: Clone>(
216    items: impl Iterator<Item = T>,
217    k: usize,
218    seed: Option<u64>,
219) -> Vec<T> {
220    if k == 0 {
221        return Vec::new();
222    }
223
224    let mut rng = get_rng(seed);
225    let mut reservoir = Vec::with_capacity(k);
226
227    for (i, item) in items.enumerate() {
228        if i < k {
229            // Fill the reservoir for the first k items
230            reservoir.push(item);
231        } else {
232            // Randomly replace items in the reservoir
233            let j = rng.gen_range(0..=i);
234            if j < k {
235                reservoir[j] = item;
236            }
237        }
238    }
239
240    reservoir
241}
242
243/// Weighted sampling without replacement using systematic sampling
244pub fn weighted_sampling_without_replacement(
245    weights: &[f64],
246    k: usize,
247    seed: Option<u64>,
248) -> UtilsResult<Vec<usize>> {
249    if weights.is_empty() {
250        return Err(UtilsError::EmptyInput);
251    }
252
253    if k > weights.len() {
254        return Err(UtilsError::InvalidParameter(format!(
255            "Cannot sample {} items from {} weights without replacement",
256            k,
257            weights.len()
258        )));
259    }
260
261    let sum: f64 = weights.iter().sum();
262    if sum <= 0.0 {
263        return Err(UtilsError::InvalidParameter(
264            "Sum of weights must be positive".to_string(),
265        ));
266    }
267
268    let mut rng = get_rng(seed);
269    let mut cumsum = Vec::with_capacity(weights.len());
270    let mut running_sum = 0.0;
271
272    for &w in weights {
273        running_sum += w;
274        cumsum.push(running_sum / sum);
275    }
276
277    let mut selected = Vec::new();
278    let mut used = vec![false; weights.len()];
279
280    for _ in 0..k {
281        loop {
282            let r: f64 = rng.gen::<f64>();
283            let idx = cumsum
284                .binary_search_by(|&x| x.partial_cmp(&r).unwrap())
285                .unwrap_or_else(|i| i);
286
287            if idx < weights.len() && !used[idx] {
288                used[idx] = true;
289                selected.push(idx);
290                break;
291            }
292        }
293    }
294
295    Ok(selected)
296}
297
298/// Importance sampling: sample indices according to importance weights
299pub fn importance_sampling(
300    weights: &[f64],
301    n_samples: usize,
302    seed: Option<u64>,
303) -> UtilsResult<Vec<usize>> {
304    if weights.is_empty() {
305        return Err(UtilsError::EmptyInput);
306    }
307
308    let sum: f64 = weights.iter().sum();
309    if sum <= 0.0 {
310        return Err(UtilsError::InvalidParameter(
311            "Sum of weights must be positive".to_string(),
312        ));
313    }
314
315    let mut rng = get_rng(seed);
316    let mut cumsum = Vec::with_capacity(weights.len());
317    let mut running_sum = 0.0;
318
319    for &w in weights {
320        running_sum += w;
321        cumsum.push(running_sum / sum);
322    }
323
324    let mut samples = Vec::with_capacity(n_samples);
325
326    for _ in 0..n_samples {
327        let r: f64 = rng.gen::<f64>();
328        let idx = cumsum
329            .binary_search_by(|&x| x.partial_cmp(&r).unwrap())
330            .unwrap_or_else(|i| i);
331        samples.push(idx.min(weights.len() - 1));
332    }
333
334    Ok(samples)
335}
336
337/// Advanced distribution sampling utilities
338pub struct DistributionSampler {
339    rng: StdRng,
340}
341
342impl DistributionSampler {
343    /// Create a new distribution sampler with optional seed
344    pub fn new(seed: Option<u64>) -> Self {
345        Self { rng: get_rng(seed) }
346    }
347
348    /// Sample from normal distribution
349    pub fn normal(&mut self, mean: f64, std: f64, n: usize) -> UtilsResult<Vec<f64>> {
350        if std <= 0.0 {
351            return Err(UtilsError::InvalidParameter(
352                "Standard deviation must be positive".to_string(),
353            ));
354        }
355
356        let mut samples = Vec::with_capacity(n);
357        for _ in 0..n {
358            // Box-Muller transform
359            let u1 = self.rng.gen::<f64>();
360            let u2 = self.rng.gen::<f64>();
361            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
362            samples.push(mean + std * z);
363        }
364        Ok(samples)
365    }
366
367    /// Sample from uniform distribution
368    pub fn uniform(&mut self, low: f64, high: f64, n: usize) -> UtilsResult<Vec<f64>> {
369        if low >= high {
370            return Err(UtilsError::InvalidParameter(
371                "Low bound must be less than high bound".to_string(),
372            ));
373        }
374
375        let samples = (0..n)
376            .map(|_| {
377                let u = self.rng.gen::<f64>();
378                low + (high - low) * u
379            })
380            .collect();
381        Ok(samples)
382    }
383
384    /// Sample from beta distribution
385    pub fn beta(&mut self, alpha: f64, beta: f64, n: usize) -> UtilsResult<Vec<f64>> {
386        if alpha <= 0.0 || beta <= 0.0 {
387            return Err(UtilsError::InvalidParameter(
388                "Beta parameters must be positive".to_string(),
389            ));
390        }
391
392        let mut samples = Vec::with_capacity(n);
393        for _ in 0..n {
394            // Simple gamma ratio method
395            let x = self.gamma_sample(alpha);
396            let y = self.gamma_sample(beta);
397            samples.push(x / (x + y));
398        }
399        Ok(samples)
400    }
401
402    /// Helper function to sample from gamma distribution
403    fn gamma_sample(&mut self, shape: f64) -> f64 {
404        // Simple approximation for gamma sampling
405        if shape < 1.0 {
406            let u = self.rng.gen::<f64>();
407            u.powf(1.0 / shape)
408        } else {
409            // Marsaglia and Tsang's method (simplified)
410            let d = shape - 1.0 / 3.0;
411            let c = 1.0 / (9.0 * d).sqrt();
412            loop {
413                let x = self.normal_sample();
414                let v = (1.0 + c * x).powi(3);
415                if v > 0.0 {
416                    let u = self.rng.gen::<f64>();
417                    if u < 1.0 - 0.0331 * x.powi(4)
418                        || u.ln() < 0.5 * x.powi(2) + d * (1.0 - v + v.ln())
419                    {
420                        return d * v;
421                    }
422                }
423            }
424        }
425    }
426
427    /// Helper function to sample from standard normal
428    fn normal_sample(&mut self) -> f64 {
429        let u1 = self.rng.gen::<f64>();
430        let u2 = self.rng.gen::<f64>();
431        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
432    }
433
434    /// Sample from gamma distribution
435    pub fn gamma(&mut self, shape: f64, scale: f64, n: usize) -> UtilsResult<Vec<f64>> {
436        if shape <= 0.0 || scale <= 0.0 {
437            return Err(UtilsError::InvalidParameter(
438                "Gamma parameters must be positive".to_string(),
439            ));
440        }
441
442        let samples = (0..n).map(|_| self.gamma_sample(shape) * scale).collect();
443        Ok(samples)
444    }
445
446    /// Sample from multivariate normal distribution (diagonal covariance)
447    pub fn multivariate_normal_diag(
448        &mut self,
449        mean: &[f64],
450        variances: &[f64],
451        n: usize,
452    ) -> UtilsResult<Vec<Vec<f64>>> {
453        if mean.len() != variances.len() {
454            return Err(UtilsError::ShapeMismatch {
455                expected: vec![mean.len()],
456                actual: vec![variances.len()],
457            });
458        }
459
460        for &var in variances {
461            if var <= 0.0 {
462                return Err(UtilsError::InvalidParameter(
463                    "All variances must be positive".to_string(),
464                ));
465            }
466        }
467
468        let mut samples = Vec::with_capacity(n);
469
470        for _ in 0..n {
471            let mut sample = Vec::with_capacity(mean.len());
472            for (&m, &v) in mean.iter().zip(variances.iter()) {
473                let z = self.normal_sample();
474                sample.push(m + z * v.sqrt());
475            }
476            samples.push(sample);
477        }
478
479        Ok(samples)
480    }
481
482    /// Sample from truncated normal distribution
483    pub fn truncated_normal(
484        &mut self,
485        mean: f64,
486        std: f64,
487        low: f64,
488        high: f64,
489        n: usize,
490    ) -> UtilsResult<Vec<f64>> {
491        if std <= 0.0 {
492            return Err(UtilsError::InvalidParameter(
493                "Standard deviation must be positive".to_string(),
494            ));
495        }
496
497        if low >= high {
498            return Err(UtilsError::InvalidParameter(
499                "Low bound must be less than high bound".to_string(),
500            ));
501        }
502
503        let mut samples = Vec::with_capacity(n);
504
505        for _ in 0..n {
506            loop {
507                let sample = mean + std * self.normal_sample();
508                if sample >= low && sample <= high {
509                    samples.push(sample);
510                    break;
511                }
512            }
513        }
514
515        Ok(samples)
516    }
517
518    /// Sample from mixture distribution
519    pub fn mixture_normal(
520        &mut self,
521        components: &[(f64, f64, f64)], // (weight, mean, std)
522        n: usize,
523    ) -> UtilsResult<Vec<f64>> {
524        if components.is_empty() {
525            return Err(UtilsError::EmptyInput);
526        }
527
528        // Validate and normalize weights
529        let total_weight: f64 = components.iter().map(|(w, _, _)| w).sum();
530        if total_weight <= 0.0 {
531            return Err(UtilsError::InvalidParameter(
532                "Total mixture weight must be positive".to_string(),
533            ));
534        }
535
536        for &(_, _, std) in components {
537            if std <= 0.0 {
538                return Err(UtilsError::InvalidParameter(
539                    "All standard deviations must be positive".to_string(),
540                ));
541            }
542        }
543
544        let mut cumulative_weights = Vec::with_capacity(components.len());
545        let mut sum = 0.0;
546        for &(weight, _, _) in components {
547            sum += weight / total_weight;
548            cumulative_weights.push(sum);
549        }
550
551        let mut samples = Vec::with_capacity(n);
552
553        for _ in 0..n {
554            // Choose component
555            let r: f64 = self.rng.gen::<f64>();
556            let component_idx = cumulative_weights
557                .binary_search_by(|&x| x.partial_cmp(&r).unwrap())
558                .unwrap_or_else(|i| i);
559
560            let (_, mean, std) = components[component_idx];
561            samples.push(mean + std * self.normal_sample());
562        }
563
564        Ok(samples)
565    }
566}
567
568/// Thread-safe random state management
569pub struct ThreadSafeRng {
570    rng: Mutex<StdRng>,
571}
572
573impl ThreadSafeRng {
574    /// Create a new thread-safe RNG with optional seed
575    pub fn new(seed: Option<u64>) -> Self {
576        Self {
577            rng: Mutex::new(get_rng(seed)),
578        }
579    }
580
581    /// Generate a random number in range [0, 1)
582    pub fn gen(&self) -> f64 {
583        let mut rng = self.rng.lock().unwrap();
584        rng.gen::<f64>()
585    }
586
587    /// Generate a random integer in range [0, n)
588    pub fn random_range(&self, n: usize) -> usize {
589        let mut rng = self.rng.lock().unwrap();
590        rng.gen_range(0..n)
591    }
592
593    /// Generate random indices for sampling
594    pub fn sample_indices(
595        &self,
596        n_samples: usize,
597        size: usize,
598        replace: bool,
599    ) -> UtilsResult<Vec<usize>> {
600        random_indices(n_samples, size, replace, None)
601    }
602
603    /// Serialize the current random state
604    pub fn get_state(&self) -> [u8; 32] {
605        let _rng = self.rng.lock().unwrap();
606        // This is a simplified version - in practice you'd want to properly serialize the RNG state
607        [0u8; 32] // Placeholder
608    }
609
610    /// Deserialize and set random state
611    pub fn set_state(&self, _state: [u8; 32]) {
612        // Placeholder - in practice you'd restore the RNG state
613    }
614}
615
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of distinct values in `values` (sort + dedup on a copy).
    fn distinct_count(values: &[usize]) -> usize {
        let mut copy = values.to_vec();
        copy.sort();
        copy.dedup();
        copy.len()
    }

    #[test]
    fn test_set_random_state() {
        set_random_state(42);
        let first = random_indices(100, 10, false, None).unwrap();

        set_random_state(42);
        let second = random_indices(100, 10, false, None).unwrap();

        // Reseeding the module-wide generator must replay the same draws.
        assert_eq!(first, second);
    }

    #[test]
    fn test_random_indices_without_replacement() {
        let picked = random_indices(10, 5, false, Some(42)).unwrap();
        assert_eq!(picked.len(), 5);

        // Without replacement every index is distinct and in range.
        assert_eq!(distinct_count(&picked), 5);
        assert!(picked.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_random_indices_with_replacement() {
        let picked = random_indices(5, 10, true, Some(42)).unwrap();
        assert_eq!(picked.len(), 10);
        assert!(picked.iter().all(|&idx| idx < 5));
    }

    #[test]
    fn test_train_test_split_indices() {
        let (train, test) = train_test_split_indices(100, 0.2, true, Some(42)).unwrap();

        assert_eq!(train.len() + test.len(), 100);
        let test_fraction = test.len() as f64 / 100.0;
        assert!((test_fraction - 0.2).abs() < 0.1);

        // Train and test together cover every index exactly once.
        let combined: Vec<usize> = train.iter().chain(&test).copied().collect();
        assert_eq!(distinct_count(&combined), 100);
    }

    #[test]
    fn test_random_weights() {
        let weights = random_weights(5, Some(42));
        assert_eq!(weights.len(), 5);

        let total: f64 = weights.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
        assert!(weights.iter().all(|&w| w >= 0.0));
    }

    #[test]
    fn test_bootstrap_indices() {
        let picked = bootstrap_indices(10, Some(42));
        assert_eq!(picked.len(), 10);
        assert!(picked.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_stratified_split() {
        let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
        let (train, test) = stratified_split_indices(&labels, 0.3, Some(42)).unwrap();

        assert_eq!(train.len() + test.len(), 10);

        // Every class must show up on both sides of the split.
        let label_of = |i: &usize| labels[*i];
        let train_labels: Vec<i32> = train.iter().map(label_of).collect();
        let test_labels: Vec<i32> = test.iter().map(label_of).collect();
        for class in [0, 1, 2] {
            assert!(train_labels.contains(&class));
            assert!(test_labels.contains(&class));
        }
    }

    #[test]
    fn test_reservoir_sampling() {
        let stream: Vec<i32> = (0..100).collect();
        let sample = reservoir_sampling(stream.into_iter(), 10, Some(42));

        assert_eq!(sample.len(), 10);
        assert!(sample.iter().all(|&item| item < 100));
    }

    #[test]
    fn test_importance_sampling() {
        let weights = vec![0.1, 0.3, 0.6]; // biased towards the last item
        let samples = importance_sampling(&weights, 1000, Some(42)).unwrap();
        assert_eq!(samples.len(), 1000);

        let mut counts = [0; 3];
        samples.iter().for_each(|&idx| counts[idx] += 1);

        // Sampling frequency should follow the weight ordering.
        assert!(counts[2] > counts[1]);
        assert!(counts[1] > counts[0]);
    }

    #[test]
    fn test_weighted_sampling_without_replacement() {
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let sample = weighted_sampling_without_replacement(&weights, 3, Some(42)).unwrap();

        assert_eq!(sample.len(), 3);
        assert_eq!(distinct_count(&sample), 3);
    }

    #[test]
    fn test_distribution_sampler() {
        let mut sampler = DistributionSampler::new(Some(42));

        let normal_draws = sampler.normal(0.0, 1.0, 100).unwrap();
        assert_eq!(normal_draws.len(), 100);

        let uniform_draws = sampler.uniform(0.0, 1.0, 100).unwrap();
        assert_eq!(uniform_draws.len(), 100);
        assert!(uniform_draws.iter().all(|x| (0.0..1.0).contains(x)));

        let beta_draws = sampler.beta(2.0, 3.0, 100).unwrap();
        assert_eq!(beta_draws.len(), 100);
        assert!(beta_draws.iter().all(|x| (0.0..=1.0).contains(x)));

        let gamma_draws = sampler.gamma(2.0, 1.0, 100).unwrap();
        assert_eq!(gamma_draws.len(), 100);
        assert!(gamma_draws.iter().all(|&x| x >= 0.0));

        let mean = vec![0.0, 1.0];
        let variances = vec![1.0, 2.0];
        let mv_draws = sampler
            .multivariate_normal_diag(&mean, &variances, 50)
            .unwrap();
        assert_eq!(mv_draws.len(), 50);
        assert!(mv_draws.iter().all(|row| row.len() == 2));

        let truncated_draws = sampler.truncated_normal(0.0, 1.0, -1.0, 1.0, 50).unwrap();
        assert_eq!(truncated_draws.len(), 50);
        assert!(truncated_draws.iter().all(|x| (-1.0..=1.0).contains(x)));

        let components = vec![(0.3, -1.0, 0.5), (0.7, 1.0, 0.5)];
        let mixture_draws = sampler.mixture_normal(&components, 100).unwrap();
        assert_eq!(mixture_draws.len(), 100);
    }

    #[test]
    fn test_thread_safe_rng() {
        let rng = ThreadSafeRng::new(Some(42));

        let first = rng.gen();
        let second = rng.gen();
        assert_ne!(first, second);

        assert!(rng.random_range(10) < 10);
    }
}