1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8use scirs2_core::numeric::{Float, NumCast, Zero};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::uniform::SampleUniform;
11use scirs2_core::random::{rngs::StdRng, SeedableRng};
12use scirs2_core::random::{Distribution, StandardNormal};
13
14#[allow(dead_code)]
38pub fn random_sample<T, D>(
39 size: usize,
40 distribution: &D,
41 seed: Option<u64>,
42) -> StatsResult<Array1<T>>
43where
44 T: Copy + Zero,
45 D: Distribution<T>,
46{
47 if size == 0 {
48 return Err(StatsError::InvalidArgument(
49 "Size must be positive".to_string(),
50 ));
51 }
52
53 let mut rng: StdRng = match seed {
54 Some(seed_value) => {
55 let mut seed_bytes = [0u8; 32];
57 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
58 SeedableRng::from_seed(seed_bytes)
59 }
60 None => {
61 {
63 let mut system_rng = scirs2_core::random::thread_rng();
64 let seed: [u8; 32] = system_rng.random();
65 SeedableRng::from_seed(seed)
66 }
67 }
68 };
69
70 let mut result = Array1::zeros(size);
71 for i in 0..size {
72 result[i] = distribution.sample(&mut rng);
73 }
74
75 Ok(result)
76}
77
78#[allow(dead_code)]
106pub fn uniform<F>(low: F, high: F, size: usize, seed: Option<u64>) -> StatsResult<Array1<F>>
107where
108 F: Float + NumCast + Zero + SampleUniform + std::fmt::Display,
109{
110 if size == 0 {
111 return Err(StatsError::InvalidArgument(
112 "Size must be positive".to_string(),
113 ));
114 }
115
116 if low >= high {
117 return Err(StatsError::InvalidArgument(
118 "Upper bound must be greater than lower bound".to_string(),
119 ));
120 }
121
122 let distribution = scirs2_core::random::Uniform::new(low, high).map_err(|e| {
123 StatsError::ComputationError(format!("Failed to create uniform distribution: {}", e))
124 })?;
125 random_sample(size, &distribution, seed)
126}
127
128#[allow(dead_code)]
156pub fn randint(low: i64, high: i64, size: usize, seed: Option<u64>) -> StatsResult<Array1<i64>> {
157 if size == 0 {
158 return Err(StatsError::InvalidArgument(
159 "Size must be positive".to_string(),
160 ));
161 }
162
163 if low >= high {
164 return Err(StatsError::InvalidArgument(
165 "Upper bound must be greater than lower bound".to_string(),
166 ));
167 }
168
169 let distribution = scirs2_core::random::Uniform::new_inclusive(low, high - 1).map_err(|e| {
170 StatsError::ComputationError(format!("Failed to create uniform distribution: {}", e))
171 })?;
172 random_sample(size, &distribution, seed)
173}
174
175#[allow(dead_code)]
203pub fn randn(size: usize, seed: Option<u64>) -> StatsResult<Array1<f64>> {
204 if size == 0 {
205 return Err(StatsError::InvalidArgument(
206 "Size must be positive".to_string(),
207 ));
208 }
209
210 let mut rng: StdRng = match seed {
211 Some(seed_value) => {
212 let mut seed_bytes = [0u8; 32];
214 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
215 SeedableRng::from_seed(seed_bytes)
216 }
217 None => {
218 {
220 let mut system_rng = scirs2_core::random::thread_rng();
221 let seed: [u8; 32] = system_rng.random();
222 SeedableRng::from_seed(seed)
223 }
224 }
225 };
226
227 let distribution = StandardNormal;
228 let mut result = Array1::zeros(size);
229 for i in 0..size {
230 result[i] = distribution.sample(&mut rng);
231 }
232
233 Ok(result)
234}
235
236#[allow(dead_code)]
268pub fn choice<T>(
269 a: &ArrayView1<T>,
270 size: usize,
271 replace: bool,
272 p: Option<&ArrayView1<f64>>,
273 seed: Option<u64>,
274) -> StatsResult<Array1<T>>
275where
276 T: Copy,
277{
278 let n = a.len();
279
280 if n == 0 {
281 return Err(StatsError::InvalidArgument(
282 "Input array cannot be empty".to_string(),
283 ));
284 }
285
286 if size == 0 {
287 return Err(StatsError::InvalidArgument(
288 "Size must be positive".to_string(),
289 ));
290 }
291
292 if !replace && size > n {
293 return Err(StatsError::InvalidArgument(
294 "Cannot take a larger sample than population when 'replace=false'".to_string(),
295 ));
296 }
297
298 let mut rng: StdRng = match seed {
299 Some(seed_value) => {
300 let mut seed_bytes = [0u8; 32];
302 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
303 SeedableRng::from_seed(seed_bytes)
304 }
305 None => {
306 {
308 let mut system_rng = scirs2_core::random::thread_rng();
309 let seed: [u8; 32] = system_rng.random();
310 SeedableRng::from_seed(seed)
311 }
312 }
313 };
314
315 let mut result = Vec::with_capacity(size);
316
317 if let Some(weights) = p {
318 if weights.len() != n {
320 return Err(StatsError::DimensionMismatch(
321 "Length of weights must match length of array".to_string(),
322 ));
323 }
324
325 let sum: f64 = weights.iter().sum();
327 if (sum - 1.0).abs() > 1e-10 {
328 return Err(StatsError::InvalidArgument(
329 "Weights must sum to 1.0".to_string(),
330 ));
331 }
332
333 let mut cumulative = Vec::with_capacity(n);
335 let mut cum_sum = 0.0;
336
337 for &w in weights.iter() {
338 if w < 0.0 {
339 return Err(StatsError::InvalidArgument(
340 "Weights must be non-negative".to_string(),
341 ));
342 }
343 cum_sum += w;
344 cumulative.push(cum_sum);
345 }
346
347 if replace {
348 for _ in 0..size {
350 let r: f64 = rng.random();
351
352 let mut low = 0;
354 let mut high = n - 1;
355
356 while low < high {
357 let mid = (low + high) / 2;
358 if r > cumulative[mid] {
359 low = mid + 1;
360 } else {
361 high = mid;
362 }
363 }
364
365 result.push(a[low]);
366 }
367 } else {
368 let mut indices: Vec<usize> = (0..n).collect();
370
371 for i in 0..size {
372 let mut remaining_weights = vec![0.0; n - i];
374 let mut total_weight = 0.0;
375
376 for j in 0..n - i {
377 remaining_weights[j] = weights[indices[j]];
378 total_weight += remaining_weights[j];
379 }
380
381 for w in remaining_weights.iter_mut() {
383 *w /= total_weight;
384 }
385
386 let mut cum_weights = vec![0.0; n - i];
388 let mut cum_sum = 0.0;
389
390 for j in 0..n - i {
391 cum_sum += remaining_weights[j];
392 cum_weights[j] = cum_sum;
393 }
394
395 let r: f64 = rng.random();
397 let mut selected = 0;
398
399 for (j, &weight) in cum_weights.iter().enumerate().take(n - i) {
400 if r <= weight {
401 selected = j;
402 break;
403 }
404 }
405
406 result.push(a[indices[selected]]);
407
408 indices.swap(selected, n - i - 1);
410 }
411 }
412 } else {
413 if replace {
415 let uniform = scirs2_core::random::Uniform::new(0, n).expect("Operation failed");
417
418 for _ in 0..size {
419 let idx = uniform.sample(&mut rng);
420 result.push(a[idx]);
421 }
422 } else {
423 let mut indices: Vec<usize> = (0..n).collect();
425
426 for i in 0..size {
427 let j = rng.random_range(i..n);
428 indices.swap(i, j);
429 result.push(a[indices[i]]);
430 }
431 }
432 }
433
434 Ok(Array1::from(result))
435}
436
437#[allow(dead_code)]
467pub fn permutation<T>(x: &ArrayView1<T>, seed: Option<u64>) -> StatsResult<Array1<T>>
468where
469 T: Copy,
470{
471 let n = x.len();
472
473 if n == 0 {
474 return Err(StatsError::InvalidArgument(
475 "Input array cannot be empty".to_string(),
476 ));
477 }
478
479 let mut rng: StdRng = match seed {
480 Some(seed_value) => {
481 let mut seed_bytes = [0u8; 32];
483 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
484 SeedableRng::from_seed(seed_bytes)
485 }
486 None => {
487 {
489 let mut system_rng = scirs2_core::random::thread_rng();
490 let seed: [u8; 32] = system_rng.random();
491 SeedableRng::from_seed(seed)
492 }
493 }
494 };
495
496 let mut result = Array1::from_iter(x.iter().cloned());
498
499 for i in (1..n).rev() {
501 let j = rng.random_range(0..i);
502 result.swap(i, j);
503 }
504
505 Ok(result)
506}
507
508#[allow(dead_code)]
536pub fn permutation_int(n: usize, seed: Option<u64>) -> StatsResult<Array1<usize>> {
537 if n == 0 {
538 return Err(StatsError::InvalidArgument(
539 "Length must be positive".to_string(),
540 ));
541 }
542
543 let mut rng: StdRng = match seed {
544 Some(seed_value) => {
545 let mut seed_bytes = [0u8; 32];
547 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
548 SeedableRng::from_seed(seed_bytes)
549 }
550 None => {
551 {
553 let mut system_rng = scirs2_core::random::thread_rng();
554 let seed: [u8; 32] = system_rng.random();
555 SeedableRng::from_seed(seed)
556 }
557 }
558 };
559
560 let mut result = Array1::from_iter(0..n);
562
563 for i in (1..n).rev() {
565 let j = rng.random_range(0..i);
566 result.swap(i, j);
567 }
568
569 Ok(result)
570}
571
572#[allow(dead_code)]
602pub fn random_binary_matrix(
603 n_rows: usize,
604 n_cols: usize,
605 density: f64,
606 seed: Option<u64>,
607) -> StatsResult<Array2<i32>> {
608 if n_rows == 0 || n_cols == 0 {
609 return Err(StatsError::InvalidArgument(
610 "Dimensions must be positive".to_string(),
611 ));
612 }
613
614 if !(0.0..=1.0).contains(&density) {
615 return Err(StatsError::InvalidArgument(
616 "Density must be between 0 and 1".to_string(),
617 ));
618 }
619
620 let mut rng: StdRng = match seed {
621 Some(seed_value) => {
622 let mut seed_bytes = [0u8; 32];
624 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
625 SeedableRng::from_seed(seed_bytes)
626 }
627 None => {
628 {
630 let mut system_rng = scirs2_core::random::thread_rng();
631 let seed: [u8; 32] = system_rng.random();
632 SeedableRng::from_seed(seed)
633 }
634 }
635 };
636
637 let mut result = Array2::zeros((n_rows, n_cols));
638
639 for i in 0..n_rows {
640 for j in 0..n_cols {
641 if rng.random::<f64>() < density {
642 result[[i, j]] = 1;
643 }
644 }
645 }
646
647 Ok(result)
648}
649
650#[allow(dead_code)]
678pub fn bootstrap_sample<T>(
679 x: &ArrayView1<T>,
680 n_samples_: usize,
681 seed: Option<u64>,
682) -> StatsResult<Array2<T>>
683where
684 T: Copy + scirs2_core::numeric::Zero,
685{
686 let n = x.len();
687
688 if n == 0 {
689 return Err(StatsError::InvalidArgument(
690 "Input array cannot be empty".to_string(),
691 ));
692 }
693
694 if n_samples_ == 0 {
695 return Err(StatsError::InvalidArgument(
696 "Number of _samples must be positive".to_string(),
697 ));
698 }
699
700 let mut rng: StdRng = match seed {
701 Some(seed_value) => {
702 let mut seed_bytes = [0u8; 32];
704 seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
705 SeedableRng::from_seed(seed_bytes)
706 }
707 None => {
708 {
710 let mut system_rng = scirs2_core::random::thread_rng();
711 let seed: [u8; 32] = system_rng.random();
712 SeedableRng::from_seed(seed)
713 }
714 }
715 };
716
717 let uniform = scirs2_core::random::Uniform::new(0, n).expect("Operation failed");
718
719 let mut result = Array2::zeros((n_samples_, n));
720
721 for i in 0..n_samples_ {
722 for j in 0..n {
723 let idx = uniform.sample(&mut rng);
724 result[[i, j]] = x[idx];
725 }
726 }
727
728 Ok(result)
729}
730
#[cfg(test)]
mod tests {
    // Unit tests for the sampling helpers. Seeded calls keep each test deterministic.
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_random_sample() {
        let dist = scirs2_core::random::Uniform::new(0.0, 1.0).expect("Operation failed");
        let drawn = random_sample(100, &dist, Some(42)).expect("Operation failed");

        assert_eq!(drawn.len(), 100);
        assert!(drawn.iter().all(|v| (0.0..1.0).contains(v)));

        let empty_request =
            random_sample::<f64, scirs2_core::random::Uniform<f64>>(0, &dist, None);
        assert!(empty_request.is_err());
    }

    #[test]
    fn test_uniform() {
        let drawn = uniform(10.0, 20.0, 50, Some(42)).expect("Operation failed");

        assert_eq!(drawn.len(), 50);
        assert!(drawn.iter().all(|v| (10.0..20.0).contains(v)));

        // Degenerate range, reversed range, and zero size are all rejected.
        for (low, high, size) in [(5.0, 5.0, 10), (10.0, 5.0, 10), (0.0, 1.0, 0)] {
            assert!(uniform(low, high, size, None).is_err());
        }
    }

    #[test]
    fn test_randint() {
        let drawn = randint(1, 101, 100, Some(42)).expect("Operation failed");

        assert_eq!(drawn.len(), 100);
        assert!(drawn.iter().all(|v| (1..=100).contains(v)));

        for (low, high, size) in [(5, 5, 10), (10, 5, 10), (0, 10, 0)] {
            assert!(randint(low, high, size, None).is_err());
        }
    }

    #[test]
    fn test_randn() {
        let drawn = randn(1000, Some(42)).expect("Operation failed");

        assert_eq!(drawn.len(), 1000);

        let mean = drawn.iter().sum::<f64>() / 1000.0;
        let variance = drawn.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / 1000.0;

        assert!(mean.abs() < 0.1);
        assert_relative_eq!(variance, 1.0, epsilon = 0.2);

        assert!(randn(0, None).is_err());
    }

    #[test]
    fn test_choice() {
        let options = array![10, 20, 30, 40, 50];

        let with_replacement =
            choice(&options.view(), 10, true, None, Some(42)).expect("Operation failed");
        assert_eq!(with_replacement.len(), 10);
        assert!(with_replacement
            .iter()
            .all(|c| options.iter().any(|x| x == c)));

        let distinct =
            choice(&options.view(), 3, false, None, Some(123)).expect("Operation failed");
        assert_eq!(distinct.len(), 3);
        for (i, first) in distinct.iter().enumerate() {
            for second in distinct.iter().skip(i + 1) {
                assert_ne!(first, second);
            }
        }

        let weights = array![0.1, 0.2, 0.3, 0.2, 0.2];
        let weighted = choice(&options.view(), 5, true, Some(&weights.view()), Some(42))
            .expect("Operation failed");
        assert_eq!(weighted.len(), 5);

        assert!(choice(&options.view(), 0, true, None, None).is_err());
        assert!(choice(&options.view(), 10, false, None, None).is_err());

        let wrong_weights = array![0.5, 0.5];
        assert!(choice(&options.view(), 2, true, Some(&wrong_weights.view()), None).is_err());

        let neg_weights = array![-0.1, 0.2, 0.3, 0.3, 0.3];
        assert!(choice(&options.view(), 2, true, Some(&neg_weights.view()), None).is_err());

        let empty: Array1<i32> = array![];
        assert!(choice(&empty.view(), 1, true, None, None).is_err());
    }

    #[test]
    fn test_permutation() {
        let arr = array![1, 2, 3, 4, 5];
        let shuffled = permutation(&arr.view(), Some(42)).expect("Operation failed");

        assert_eq!(shuffled.len(), arr.len());
        assert!(arr.iter().all(|v| shuffled.iter().any(|x| x == v)));

        let empty: Array1<i32> = array![];
        assert!(permutation(&empty.view(), None).is_err());
    }

    #[test]
    fn test_permutation_int() {
        let shuffled = permutation_int(10, Some(42)).expect("Operation failed");

        assert_eq!(shuffled.len(), 10);
        assert!((0..10).all(|i| shuffled.iter().any(|&x| x == i)));

        assert!(permutation_int(0, None).is_err());
    }

    #[test]
    fn test_random_binary_matrix() {
        let matrix = random_binary_matrix(5, 5, 0.5, Some(42)).expect("Operation failed");

        assert_eq!(matrix.shape(), &[5, 5]);
        assert!(matrix.iter().all(|&v| v == 0 || v == 1));

        let ones = matrix.iter().filter(|&&v| v == 1).count();
        let observed_density = ones as f64 / 25.0;
        assert!(observed_density > 0.2 && observed_density < 0.8);

        for (rows, cols, bad_density) in [(0, 5, 0.5), (5, 0, 0.5), (5, 5, -0.1), (5, 5, 1.1)] {
            assert!(random_binary_matrix(rows, cols, bad_density, None).is_err());
        }
    }

    #[test]
    fn test_bootstrap_sample() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let resampled = bootstrap_sample(&data.view(), 10, Some(42)).expect("Operation failed");

        assert_eq!(resampled.shape(), &[10, 5]);

        assert!(bootstrap_sample(&data.view(), 0, None).is_err());

        let empty: Array1<f64> = array![];
        assert!(bootstrap_sample(&empty.view(), 10, None).is_err());
    }
}