1use crate::{UtilsError, UtilsResult};
4use lazy_static::lazy_static;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::random::{Rng, SeedableRng};
7use std::collections::HashMap;
8use std::sync::Mutex;
9
10lazy_static! {
11 static ref GLOBAL_RNG: Mutex<StdRng> = Mutex::new(StdRng::seed_from_u64(42));
12}
13
14pub fn set_random_state(seed: u64) {
16 let mut rng = GLOBAL_RNG.lock().unwrap();
17 *rng = StdRng::seed_from_u64(seed);
18}
19
20pub fn get_rng(seed: Option<u64>) -> StdRng {
22 match seed {
23 Some(s) => StdRng::seed_from_u64(s),
24 None => {
25 let mut rng = GLOBAL_RNG.lock().unwrap();
26 StdRng::seed_from_u64(rng.gen::<u64>())
27 }
28 }
29}
30
31pub fn random_indices(
33 n_samples: usize,
34 size: usize,
35 replace: bool,
36 seed: Option<u64>,
37) -> UtilsResult<Vec<usize>> {
38 if !replace && size > n_samples {
39 return Err(UtilsError::InvalidParameter(format!(
40 "Cannot sample {size} items from {n_samples} without replacement"
41 )));
42 }
43
44 let mut rng = get_rng(seed);
45 let mut indices = Vec::with_capacity(size);
46
47 if replace {
48 for _ in 0..size {
50 indices.push(rng.gen_range(0..n_samples));
51 }
52 } else {
53 let mut available: Vec<usize> = (0..n_samples).collect();
55 for _ in 0..size {
56 let idx = rng.gen_range(0..available.len());
57 indices.push(available.swap_remove(idx));
58 }
59 }
60
61 Ok(indices)
62}
63
64pub fn shuffle_indices(indices: &mut [usize], seed: Option<u64>) {
66 let mut rng = get_rng(seed);
67 for i in (1..indices.len()).rev() {
68 let j = rng.gen_range(0..=i);
69 indices.swap(i, j);
70 }
71}
72
73pub fn random_permutation(n: usize, seed: Option<u64>) -> Vec<usize> {
75 let mut indices: Vec<usize> = (0..n).collect();
76 shuffle_indices(&mut indices, seed);
77 indices
78}
79
80pub fn train_test_split_indices(
82 n_samples: usize,
83 test_size: f64,
84 shuffle: bool,
85 seed: Option<u64>,
86) -> UtilsResult<(Vec<usize>, Vec<usize>)> {
87 if test_size <= 0.0 || test_size >= 1.0 {
88 return Err(UtilsError::InvalidParameter(format!(
89 "test_size must be in (0, 1), got {test_size}"
90 )));
91 }
92
93 let test_samples = (n_samples as f64 * test_size).round() as usize;
94 let train_samples = n_samples - test_samples;
95
96 let indices = if shuffle {
97 random_permutation(n_samples, seed)
98 } else {
99 (0..n_samples).collect()
100 };
101
102 let train_indices = indices[..train_samples].to_vec();
103 let test_indices = indices[train_samples..].to_vec();
104
105 Ok((train_indices, test_indices))
106}
107
108pub fn random_weights(n: usize, seed: Option<u64>) -> Vec<f64> {
110 let mut rng = get_rng(seed);
111 let mut weights: Vec<f64> = (0..n).map(|_| rng.gen::<f64>()).collect();
112 let sum: f64 = weights.iter().sum();
113
114 if sum > 0.0 {
115 for w in &mut weights {
116 *w /= sum;
117 }
118 } else {
119 let uniform_weight = 1.0 / n as f64;
121 weights.fill(uniform_weight);
122 }
123
124 weights
125}
126
127pub fn bootstrap_indices(n_samples: usize, seed: Option<u64>) -> Vec<usize> {
129 random_indices(n_samples, n_samples, true, seed).unwrap()
130}
131
132pub fn k_fold_indices(
134 n_samples: usize,
135 n_splits: usize,
136 shuffle: bool,
137 seed: Option<u64>,
138) -> UtilsResult<Vec<(Vec<usize>, Vec<usize>)>> {
139 if n_splits < 2 {
140 return Err(UtilsError::InvalidParameter(format!(
141 "n_splits must be at least 2, got {n_splits}"
142 )));
143 }
144
145 if n_splits > n_samples {
146 return Err(UtilsError::InvalidParameter(format!(
147 "n_splits {n_splits} cannot be greater than the number of samples {n_samples}"
148 )));
149 }
150
151 let indices = if shuffle {
152 random_permutation(n_samples, seed)
153 } else {
154 (0..n_samples).collect()
155 };
156
157 let mut folds = Vec::with_capacity(n_splits);
158 let fold_sizes: Vec<usize> = (0..n_splits)
159 .map(|i| (n_samples + n_splits - i - 1) / n_splits)
160 .collect();
161
162 let mut start = 0;
163 for fold_size in fold_sizes {
164 let end = start + fold_size;
165 let test_indices = indices[start..end].to_vec();
166 let mut train_indices = Vec::with_capacity(n_samples - fold_size);
167 train_indices.extend(&indices[..start]);
168 train_indices.extend(&indices[end..]);
169
170 folds.push((train_indices, test_indices));
171 start = end;
172 }
173
174 Ok(folds)
175}
176
177pub fn stratified_split_indices(
179 labels: &[i32],
180 test_size: f64,
181 seed: Option<u64>,
182) -> UtilsResult<(Vec<usize>, Vec<usize>)> {
183 if test_size <= 0.0 || test_size >= 1.0 {
184 return Err(UtilsError::InvalidParameter(format!(
185 "test_size must be in (0, 1), got {test_size}"
186 )));
187 }
188
189 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
191 for (idx, &label) in labels.iter().enumerate() {
192 class_indices.entry(label).or_default().push(idx);
193 }
194
195 let mut train_indices = Vec::new();
196 let mut test_indices = Vec::new();
197
198 for indices in class_indices.values() {
200 let n_class = indices.len();
201 let n_test = (n_class as f64 * test_size).round() as usize;
202 let n_test = n_test.max(1).min(n_class - 1); let (class_train, class_test) =
205 train_test_split_indices(n_class, n_test as f64 / n_class as f64, true, seed)?;
206
207 train_indices.extend(class_train.iter().map(|&i| indices[i]));
208 test_indices.extend(class_test.iter().map(|&i| indices[i]));
209 }
210
211 Ok((train_indices, test_indices))
212}
213
214pub fn reservoir_sampling<T: Clone>(
216 items: impl Iterator<Item = T>,
217 k: usize,
218 seed: Option<u64>,
219) -> Vec<T> {
220 if k == 0 {
221 return Vec::new();
222 }
223
224 let mut rng = get_rng(seed);
225 let mut reservoir = Vec::with_capacity(k);
226
227 for (i, item) in items.enumerate() {
228 if i < k {
229 reservoir.push(item);
231 } else {
232 let j = rng.gen_range(0..=i);
234 if j < k {
235 reservoir[j] = item;
236 }
237 }
238 }
239
240 reservoir
241}
242
243pub fn weighted_sampling_without_replacement(
245 weights: &[f64],
246 k: usize,
247 seed: Option<u64>,
248) -> UtilsResult<Vec<usize>> {
249 if weights.is_empty() {
250 return Err(UtilsError::EmptyInput);
251 }
252
253 if k > weights.len() {
254 return Err(UtilsError::InvalidParameter(format!(
255 "Cannot sample {} items from {} weights without replacement",
256 k,
257 weights.len()
258 )));
259 }
260
261 let sum: f64 = weights.iter().sum();
262 if sum <= 0.0 {
263 return Err(UtilsError::InvalidParameter(
264 "Sum of weights must be positive".to_string(),
265 ));
266 }
267
268 let mut rng = get_rng(seed);
269 let mut cumsum = Vec::with_capacity(weights.len());
270 let mut running_sum = 0.0;
271
272 for &w in weights {
273 running_sum += w;
274 cumsum.push(running_sum / sum);
275 }
276
277 let mut selected = Vec::new();
278 let mut used = vec![false; weights.len()];
279
280 for _ in 0..k {
281 loop {
282 let r: f64 = rng.gen::<f64>();
283 let idx = cumsum
284 .binary_search_by(|&x| x.partial_cmp(&r).unwrap())
285 .unwrap_or_else(|i| i);
286
287 if idx < weights.len() && !used[idx] {
288 used[idx] = true;
289 selected.push(idx);
290 break;
291 }
292 }
293 }
294
295 Ok(selected)
296}
297
298pub fn importance_sampling(
300 weights: &[f64],
301 n_samples: usize,
302 seed: Option<u64>,
303) -> UtilsResult<Vec<usize>> {
304 if weights.is_empty() {
305 return Err(UtilsError::EmptyInput);
306 }
307
308 let sum: f64 = weights.iter().sum();
309 if sum <= 0.0 {
310 return Err(UtilsError::InvalidParameter(
311 "Sum of weights must be positive".to_string(),
312 ));
313 }
314
315 let mut rng = get_rng(seed);
316 let mut cumsum = Vec::with_capacity(weights.len());
317 let mut running_sum = 0.0;
318
319 for &w in weights {
320 running_sum += w;
321 cumsum.push(running_sum / sum);
322 }
323
324 let mut samples = Vec::with_capacity(n_samples);
325
326 for _ in 0..n_samples {
327 let r: f64 = rng.gen::<f64>();
328 let idx = cumsum
329 .binary_search_by(|&x| x.partial_cmp(&r).unwrap())
330 .unwrap_or_else(|i| i);
331 samples.push(idx.min(weights.len() - 1));
332 }
333
334 Ok(samples)
335}
336
337pub struct DistributionSampler {
339 rng: StdRng,
340}
341
342impl DistributionSampler {
343 pub fn new(seed: Option<u64>) -> Self {
345 Self { rng: get_rng(seed) }
346 }
347
348 pub fn normal(&mut self, mean: f64, std: f64, n: usize) -> UtilsResult<Vec<f64>> {
350 if std <= 0.0 {
351 return Err(UtilsError::InvalidParameter(
352 "Standard deviation must be positive".to_string(),
353 ));
354 }
355
356 let mut samples = Vec::with_capacity(n);
357 for _ in 0..n {
358 let u1 = self.rng.gen::<f64>();
360 let u2 = self.rng.gen::<f64>();
361 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
362 samples.push(mean + std * z);
363 }
364 Ok(samples)
365 }
366
367 pub fn uniform(&mut self, low: f64, high: f64, n: usize) -> UtilsResult<Vec<f64>> {
369 if low >= high {
370 return Err(UtilsError::InvalidParameter(
371 "Low bound must be less than high bound".to_string(),
372 ));
373 }
374
375 let samples = (0..n)
376 .map(|_| {
377 let u = self.rng.gen::<f64>();
378 low + (high - low) * u
379 })
380 .collect();
381 Ok(samples)
382 }
383
384 pub fn beta(&mut self, alpha: f64, beta: f64, n: usize) -> UtilsResult<Vec<f64>> {
386 if alpha <= 0.0 || beta <= 0.0 {
387 return Err(UtilsError::InvalidParameter(
388 "Beta parameters must be positive".to_string(),
389 ));
390 }
391
392 let mut samples = Vec::with_capacity(n);
393 for _ in 0..n {
394 let x = self.gamma_sample(alpha);
396 let y = self.gamma_sample(beta);
397 samples.push(x / (x + y));
398 }
399 Ok(samples)
400 }
401
402 fn gamma_sample(&mut self, shape: f64) -> f64 {
404 if shape < 1.0 {
406 let u = self.rng.gen::<f64>();
407 u.powf(1.0 / shape)
408 } else {
409 let d = shape - 1.0 / 3.0;
411 let c = 1.0 / (9.0 * d).sqrt();
412 loop {
413 let x = self.normal_sample();
414 let v = (1.0 + c * x).powi(3);
415 if v > 0.0 {
416 let u = self.rng.gen::<f64>();
417 if u < 1.0 - 0.0331 * x.powi(4)
418 || u.ln() < 0.5 * x.powi(2) + d * (1.0 - v + v.ln())
419 {
420 return d * v;
421 }
422 }
423 }
424 }
425 }
426
427 fn normal_sample(&mut self) -> f64 {
429 let u1 = self.rng.gen::<f64>();
430 let u2 = self.rng.gen::<f64>();
431 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
432 }
433
434 pub fn gamma(&mut self, shape: f64, scale: f64, n: usize) -> UtilsResult<Vec<f64>> {
436 if shape <= 0.0 || scale <= 0.0 {
437 return Err(UtilsError::InvalidParameter(
438 "Gamma parameters must be positive".to_string(),
439 ));
440 }
441
442 let samples = (0..n).map(|_| self.gamma_sample(shape) * scale).collect();
443 Ok(samples)
444 }
445
446 pub fn multivariate_normal_diag(
448 &mut self,
449 mean: &[f64],
450 variances: &[f64],
451 n: usize,
452 ) -> UtilsResult<Vec<Vec<f64>>> {
453 if mean.len() != variances.len() {
454 return Err(UtilsError::ShapeMismatch {
455 expected: vec![mean.len()],
456 actual: vec![variances.len()],
457 });
458 }
459
460 for &var in variances {
461 if var <= 0.0 {
462 return Err(UtilsError::InvalidParameter(
463 "All variances must be positive".to_string(),
464 ));
465 }
466 }
467
468 let mut samples = Vec::with_capacity(n);
469
470 for _ in 0..n {
471 let mut sample = Vec::with_capacity(mean.len());
472 for (&m, &v) in mean.iter().zip(variances.iter()) {
473 let z = self.normal_sample();
474 sample.push(m + z * v.sqrt());
475 }
476 samples.push(sample);
477 }
478
479 Ok(samples)
480 }
481
482 pub fn truncated_normal(
484 &mut self,
485 mean: f64,
486 std: f64,
487 low: f64,
488 high: f64,
489 n: usize,
490 ) -> UtilsResult<Vec<f64>> {
491 if std <= 0.0 {
492 return Err(UtilsError::InvalidParameter(
493 "Standard deviation must be positive".to_string(),
494 ));
495 }
496
497 if low >= high {
498 return Err(UtilsError::InvalidParameter(
499 "Low bound must be less than high bound".to_string(),
500 ));
501 }
502
503 let mut samples = Vec::with_capacity(n);
504
505 for _ in 0..n {
506 loop {
507 let sample = mean + std * self.normal_sample();
508 if sample >= low && sample <= high {
509 samples.push(sample);
510 break;
511 }
512 }
513 }
514
515 Ok(samples)
516 }
517
518 pub fn mixture_normal(
520 &mut self,
521 components: &[(f64, f64, f64)], n: usize,
523 ) -> UtilsResult<Vec<f64>> {
524 if components.is_empty() {
525 return Err(UtilsError::EmptyInput);
526 }
527
528 let total_weight: f64 = components.iter().map(|(w, _, _)| w).sum();
530 if total_weight <= 0.0 {
531 return Err(UtilsError::InvalidParameter(
532 "Total mixture weight must be positive".to_string(),
533 ));
534 }
535
536 for &(_, _, std) in components {
537 if std <= 0.0 {
538 return Err(UtilsError::InvalidParameter(
539 "All standard deviations must be positive".to_string(),
540 ));
541 }
542 }
543
544 let mut cumulative_weights = Vec::with_capacity(components.len());
545 let mut sum = 0.0;
546 for &(weight, _, _) in components {
547 sum += weight / total_weight;
548 cumulative_weights.push(sum);
549 }
550
551 let mut samples = Vec::with_capacity(n);
552
553 for _ in 0..n {
554 let r: f64 = self.rng.gen::<f64>();
556 let component_idx = cumulative_weights
557 .binary_search_by(|&x| x.partial_cmp(&r).unwrap())
558 .unwrap_or_else(|i| i);
559
560 let (_, mean, std) = components[component_idx];
561 samples.push(mean + std * self.normal_sample());
562 }
563
564 Ok(samples)
565 }
566}
567
568pub struct ThreadSafeRng {
570 rng: Mutex<StdRng>,
571}
572
573impl ThreadSafeRng {
574 pub fn new(seed: Option<u64>) -> Self {
576 Self {
577 rng: Mutex::new(get_rng(seed)),
578 }
579 }
580
581 pub fn gen(&self) -> f64 {
583 let mut rng = self.rng.lock().unwrap();
584 rng.gen::<f64>()
585 }
586
587 pub fn random_range(&self, n: usize) -> usize {
589 let mut rng = self.rng.lock().unwrap();
590 rng.gen_range(0..n)
591 }
592
593 pub fn sample_indices(
595 &self,
596 n_samples: usize,
597 size: usize,
598 replace: bool,
599 ) -> UtilsResult<Vec<usize>> {
600 random_indices(n_samples, size, replace, None)
601 }
602
603 pub fn get_state(&self) -> [u8; 32] {
605 let _rng = self.rng.lock().unwrap();
606 [0u8; 32] }
609
610 pub fn set_state(&self, _state: [u8; 32]) {
612 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_set_random_state() {
        set_random_state(42);
        let indices1 = random_indices(100, 10, false, None).unwrap();

        set_random_state(42);
        let indices2 = random_indices(100, 10, false, None).unwrap();

        assert_eq!(indices1, indices2);
    }

    #[test]
    fn test_random_indices_without_replacement() {
        let indices = random_indices(10, 5, false, Some(42)).unwrap();
        assert_eq!(indices.len(), 5);

        // Sampling without replacement must never repeat an index.
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), 5);

        assert!(indices.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_random_indices_with_replacement() {
        let indices = random_indices(5, 10, true, Some(42)).unwrap();
        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&idx| idx < 5));
    }

    #[test]
    fn test_random_indices_rejects_oversized_sample_without_replacement() {
        assert!(random_indices(5, 6, false, Some(42)).is_err());
    }

    #[test]
    fn test_random_permutation_is_seeded_permutation() {
        let perm = random_permutation(20, Some(7));
        assert_eq!(perm, random_permutation(20, Some(7)));

        let mut sorted = perm.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..20).collect::<Vec<usize>>());
    }

    #[test]
    fn test_shuffle_indices_trivial_inputs() {
        let mut empty: [usize; 0] = [];
        shuffle_indices(&mut empty, Some(1));

        let mut single = [3_usize];
        shuffle_indices(&mut single, Some(1));
        assert_eq!(single, [3]);
    }

    #[test]
    fn test_train_test_split_indices() {
        let (train, test) = train_test_split_indices(100, 0.2, true, Some(42)).unwrap();

        assert_eq!(train.len() + test.len(), 100);
        assert!((test.len() as f64 / 100.0 - 0.2).abs() < 0.1);

        // Train and test together must cover every index exactly once.
        let mut all_indices = train.clone();
        all_indices.extend(&test);
        all_indices.sort_unstable();
        all_indices.dedup();
        assert_eq!(all_indices.len(), 100);
    }

    #[test]
    fn test_train_test_split_without_shuffle_preserves_order() {
        let (train, test) = train_test_split_indices(10, 0.3, false, None).unwrap();
        assert_eq!(train, (0..7).collect::<Vec<usize>>());
        assert_eq!(test, vec![7, 8, 9]);
    }

    #[test]
    fn test_train_test_split_rejects_out_of_range_test_size() {
        assert!(train_test_split_indices(10, 0.0, true, Some(1)).is_err());
        assert!(train_test_split_indices(10, 1.0, true, Some(1)).is_err());
        assert!(train_test_split_indices(10, -0.5, true, Some(1)).is_err());
    }

    #[test]
    fn test_random_weights() {
        let weights = random_weights(5, Some(42));
        assert_eq!(weights.len(), 5);

        let sum: f64 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        assert!(weights.iter().all(|&w| w >= 0.0));
    }

    #[test]
    fn test_bootstrap_indices() {
        let indices = bootstrap_indices(10, Some(42));
        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_k_fold_indices_unshuffled() {
        let folds = k_fold_indices(10, 3, false, None).unwrap();

        let test_sets: Vec<Vec<usize>> = folds.iter().map(|(_, test)| test.clone()).collect();
        assert_eq!(
            test_sets,
            vec![vec![0, 1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );

        for (train, test) in &folds {
            assert_eq!(train.len() + test.len(), 10);
            assert!(test.iter().all(|i| !train.contains(i)));
        }
    }

    #[test]
    fn test_k_fold_indices_rejects_bad_split_counts() {
        assert!(k_fold_indices(10, 1, false, None).is_err());
        assert!(k_fold_indices(3, 4, false, None).is_err());
    }

    #[test]
    fn test_stratified_split() {
        let labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
        let (train, test) = stratified_split_indices(&labels, 0.3, Some(42)).unwrap();

        assert_eq!(train.len() + test.len(), 10);

        // Every class must be represented on both sides of the split.
        let train_labels: Vec<i32> = train.iter().map(|&i| labels[i]).collect();
        let test_labels: Vec<i32> = test.iter().map(|&i| labels[i]).collect();

        for &class in &[0, 1, 2] {
            assert!(train_labels.contains(&class));
            assert!(test_labels.contains(&class));
        }
    }

    #[test]
    fn test_reservoir_sampling() {
        let sample = reservoir_sampling(0..100, 10, Some(42));

        assert_eq!(sample.len(), 10);
        assert!(sample.iter().all(|&item| item < 100));
    }

    #[test]
    fn test_reservoir_sampling_edge_cases() {
        assert!(reservoir_sampling(0..10, 0, Some(1)).is_empty());
        // A stream shorter than `k` is returned whole, in order.
        assert_eq!(reservoir_sampling(0..5, 10, Some(1)), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_importance_sampling() {
        let weights = [0.1, 0.3, 0.6];
        let samples = importance_sampling(&weights, 1000, Some(42)).unwrap();

        assert_eq!(samples.len(), 1000);

        let mut counts = [0; 3];
        for &idx in &samples {
            counts[idx] += 1;
        }

        // Higher weights should be drawn more often.
        assert!(counts[2] > counts[1]);
        assert!(counts[1] > counts[0]);
    }

    #[test]
    fn test_importance_sampling_rejects_invalid_input() {
        assert!(importance_sampling(&[], 10, Some(1)).is_err());
        assert!(importance_sampling(&[0.0, 0.0], 10, Some(1)).is_err());
    }

    #[test]
    fn test_weighted_sampling_without_replacement() {
        let weights = [1.0, 2.0, 3.0, 4.0];
        let sample = weighted_sampling_without_replacement(&weights, 3, Some(42)).unwrap();

        assert_eq!(sample.len(), 3);

        let mut sorted = sample.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), 3);
    }

    #[test]
    fn test_weighted_sampling_rejects_invalid_input() {
        assert!(weighted_sampling_without_replacement(&[], 0, Some(1)).is_err());
        assert!(weighted_sampling_without_replacement(&[1.0, 2.0], 3, Some(1)).is_err());
        assert!(weighted_sampling_without_replacement(&[0.0, 0.0], 1, Some(1)).is_err());
    }

    #[test]
    fn test_distribution_sampler() {
        let mut sampler = DistributionSampler::new(Some(42));

        let normal_samples = sampler.normal(0.0, 1.0, 100).unwrap();
        assert_eq!(normal_samples.len(), 100);

        let uniform_samples = sampler.uniform(0.0, 1.0, 100).unwrap();
        assert_eq!(uniform_samples.len(), 100);
        assert!(uniform_samples.iter().all(|s| (0.0..1.0).contains(s)));

        let beta_samples = sampler.beta(2.0, 3.0, 100).unwrap();
        assert_eq!(beta_samples.len(), 100);
        assert!(beta_samples.iter().all(|s| (0.0..=1.0).contains(s)));

        let gamma_samples = sampler.gamma(2.0, 1.0, 100).unwrap();
        assert_eq!(gamma_samples.len(), 100);
        assert!(gamma_samples.iter().all(|&s| s >= 0.0));

        let mean = [0.0, 1.0];
        let variances = [1.0, 2.0];
        let mv_samples = sampler
            .multivariate_normal_diag(&mean, &variances, 50)
            .unwrap();
        assert_eq!(mv_samples.len(), 50);
        assert!(mv_samples.iter().all(|sample| sample.len() == 2));

        let truncated_samples = sampler.truncated_normal(0.0, 1.0, -1.0, 1.0, 50).unwrap();
        assert_eq!(truncated_samples.len(), 50);
        assert!(truncated_samples.iter().all(|s| (-1.0..=1.0).contains(s)));

        let components = [(0.3, -1.0, 0.5), (0.7, 1.0, 0.5)];
        let mixture_samples = sampler.mixture_normal(&components, 100).unwrap();
        assert_eq!(mixture_samples.len(), 100);
    }

    #[test]
    fn test_distribution_sampler_rejects_invalid_parameters() {
        let mut sampler = DistributionSampler::new(Some(1));

        assert!(sampler.normal(0.0, 0.0, 5).is_err());
        assert!(sampler.uniform(1.0, 1.0, 5).is_err());
        assert!(sampler.beta(0.0, 1.0, 5).is_err());
        assert!(sampler.gamma(1.0, -1.0, 5).is_err());
        assert!(sampler
            .multivariate_normal_diag(&[0.0], &[1.0, 2.0], 5)
            .is_err());
        assert!(sampler.truncated_normal(0.0, 1.0, 1.0, -1.0, 5).is_err());
        assert!(sampler.mixture_normal(&[], 5).is_err());
    }

    #[test]
    fn test_thread_safe_rng() {
        let rng = ThreadSafeRng::new(Some(42));

        let val1 = rng.gen();
        let val2 = rng.gen();
        assert_ne!(val1, val2);

        let idx = rng.random_range(10);
        assert!(idx < 10);
    }

    #[test]
    fn test_thread_safe_rng_is_reproducible() {
        let first = ThreadSafeRng::new(Some(9));
        let second = ThreadSafeRng::new(Some(9));

        for _ in 0..5 {
            assert_eq!(first.gen().to_bits(), second.gen().to_bits());
        }
        assert_eq!(first.random_range(100), second.random_range(100));
    }
}