1use crate::error::{StatsError, StatsResult};
7use crate::random;
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::SeedableRng;
12
13pub trait SampleableDistribution<T> {
15 fn rvs(&self, size: usize) -> StatsResult<Vec<T>>;
17}
18
19#[allow(dead_code)]
44pub fn sample_distribution<T, D>(dist: &D, size: usize) -> StatsResult<Array1<T>>
45where
46 T: Float + std::iter::Sum<T> + std::ops::Div<Output = T>,
47 D: SampleableDistribution<T>,
48{
49 if size == 0 {
50 return Err(StatsError::InvalidArgument(
51 "Size must be positive".to_string(),
52 ));
53 }
54
55 let samples = dist.rvs(size)?;
56 Ok(Array1::from_vec(samples))
57}
58
59#[allow(dead_code)]
85pub fn bootstrap<T>(
86 x: &ArrayView1<T>,
87 n_resamples: usize,
88 seed: Option<u64>,
89) -> StatsResult<Array2<T>>
90where
91 T: Copy + scirs2_core::numeric::Zero,
92{
93 random::bootstrap_sample(x, n_resamples, seed)
94}
95
96#[allow(dead_code)]
121pub fn permutation<T>(x: &ArrayView1<T>, seed: Option<u64>) -> StatsResult<Array1<T>>
122where
123 T: Copy,
124{
125 random::permutation(x, seed)
126}
127
128#[allow(dead_code)]
156pub fn stratified_sample<T, G>(
157 x: &ArrayView1<T>,
158 groups: &ArrayView1<G>,
159 size: usize,
160 seed: Option<u64>,
161) -> StatsResult<Array1<usize>>
162where
163 T: Copy,
164 G: Copy + Eq + std::hash::Hash,
165{
166 if x.len() != groups.len() {
167 return Err(StatsError::DimensionMismatch(
168 "Input array and group array must have the same length".to_string(),
169 ));
170 }
171
172 if size == 0 {
173 return Err(StatsError::InvalidArgument(
174 "Size must be positive".to_string(),
175 ));
176 }
177
178 let mut unique_groups = std::collections::HashSet::new();
180 for &g in groups.iter() {
181 unique_groups.insert(g);
182 }
183
184 let n_groups = unique_groups.len();
185
186 let mut group_indices = std::collections::HashMap::new();
188 for (i, &g) in groups.iter().enumerate() {
189 group_indices.entry(g).or_insert_with(Vec::new).push(i);
190 }
191
192 let mut rng = match seed {
194 Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
195 None => {
196 let mut rng = scirs2_core::random::thread_rng();
198 let seed = rng.random::<u64>();
199 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
200 }
201 };
202
203 let mut result = Vec::with_capacity(n_groups * size);
205
206 for (_, indices) in group_indices.iter() {
207 if indices.len() < size {
208 return Err(StatsError::InvalidArgument(format!(
209 "Group size {} is smaller than requested sample size {}",
210 indices.len(),
211 size
212 )));
213 }
214
215 let mut indices_copy = indices.clone();
217 for i in 0..size {
218 let j = rng.random_range(i..indices_copy.len());
219 indices_copy.swap(i, j);
220 result.push(indices_copy[i]);
221 }
222 }
223
224 Ok(Array1::from_vec(result))
225}
226
227#[allow(dead_code)]
256pub fn stratified_bootstrap<T, G>(
257 x: &ArrayView1<T>,
258 groups: &ArrayView1<G>,
259 n_resamples: usize,
260 seed: Option<u64>,
261) -> StatsResult<Array2<T>>
262where
263 T: Copy + scirs2_core::numeric::Zero,
264 G: Copy + Eq + std::hash::Hash,
265{
266 if x.len() != groups.len() {
267 return Err(StatsError::DimensionMismatch(
268 "Input array and group array must have the same length".to_string(),
269 ));
270 }
271
272 if n_resamples == 0 {
273 return Err(StatsError::InvalidArgument(
274 "Number of _resamples must be positive".to_string(),
275 ));
276 }
277
278 let mut group_indices = std::collections::HashMap::new();
280 for (i, &g) in groups.iter().enumerate() {
281 group_indices.entry(g).or_insert_with(Vec::new).push(i);
282 }
283
284 let mut rng = match seed {
286 Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
287 None => {
288 let mut rng = scirs2_core::random::thread_rng();
289 let seed = rng.random::<u64>();
290 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
291 }
292 };
293
294 let mut samples = Array2::zeros((n_resamples, x.len()));
295
296 for resample_idx in 0..n_resamples {
297 let mut sample_idx = 0;
298
299 for (_, indices) in group_indices.iter() {
301 for _ in 0..indices.len() {
302 let random_idx = rng.random_range(0..indices.len());
303 let selected_idx = indices[random_idx];
304 samples[[resample_idx, sample_idx]] = x[selected_idx];
305 sample_idx += 1;
306 }
307 }
308 }
309
310 Ok(samples)
311}
312
313#[allow(dead_code)]
342pub fn block_bootstrap<T>(
343 x: &ArrayView1<T>,
344 blocksize: usize,
345 n_resamples: usize,
346 circular: bool,
347 seed: Option<u64>,
348) -> StatsResult<Array2<T>>
349where
350 T: Copy + scirs2_core::numeric::Zero,
351{
352 if x.is_empty() {
353 return Err(StatsError::InvalidArgument(
354 "Input array cannot be empty".to_string(),
355 ));
356 }
357
358 if blocksize == 0 {
359 return Err(StatsError::InvalidArgument(
360 "Block size must be positive".to_string(),
361 ));
362 }
363
364 if blocksize > x.len() {
365 return Err(StatsError::InvalidArgument(
366 "Block size cannot exceed array length".to_string(),
367 ));
368 }
369
370 if n_resamples == 0 {
371 return Err(StatsError::InvalidArgument(
372 "Number of _resamples must be positive".to_string(),
373 ));
374 }
375
376 let mut rng = match seed {
378 Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
379 None => {
380 let mut rng = scirs2_core::random::thread_rng();
381 let seed = rng.random::<u64>();
382 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
383 }
384 };
385
386 let data_len = x.len();
387 let max_start_pos = if circular {
388 data_len
389 } else {
390 data_len - blocksize + 1
391 };
392
393 let mut samples = Array2::zeros((n_resamples, data_len));
394
395 for resample_idx in 0..n_resamples {
396 let mut sample_pos = 0;
397
398 while sample_pos < data_len {
400 let start_pos = rng.random_range(0..max_start_pos);
402
403 for block_offset in 0..blocksize {
405 if sample_pos >= data_len {
406 break;
407 }
408
409 let data_idx = if circular {
410 (start_pos + block_offset) % data_len
411 } else {
412 start_pos + block_offset
413 };
414
415 samples[[resample_idx, sample_pos]] = x[data_idx];
416 sample_pos += 1;
417 }
418 }
419 }
420
421 Ok(samples)
422}
423
424#[allow(dead_code)]
438pub fn moving_block_bootstrap<T>(
439 x: &ArrayView1<T>,
440 blocksize: usize,
441 n_resamples: usize,
442 seed: Option<u64>,
443) -> StatsResult<Array2<T>>
444where
445 T: Copy + scirs2_core::numeric::Zero,
446{
447 if x.is_empty() {
448 return Err(StatsError::InvalidArgument(
449 "Input array cannot be empty".to_string(),
450 ));
451 }
452
453 if blocksize == 0 || blocksize > x.len() {
454 return Err(StatsError::InvalidArgument(
455 "Block size must be positive and not exceed array length".to_string(),
456 ));
457 }
458
459 let mut blocks = Vec::new();
461 for i in 0..=(x.len() - blocksize) {
462 let mut block = Vec::with_capacity(blocksize);
463 for j in i..(i + blocksize) {
464 block.push(x[j]);
465 }
466 blocks.push(block);
467 }
468
469 let mut rng = match seed {
471 Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
472 None => {
473 let mut rng = scirs2_core::random::thread_rng();
474 let seed = rng.random::<u64>();
475 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
476 }
477 };
478
479 let data_len = x.len();
480 let n_blocks_needed = data_len.div_ceil(blocksize); let mut samples = Array2::zeros((n_resamples, data_len));
482
483 for resample_idx in 0..n_resamples {
484 let mut sample_pos = 0;
485
486 for _ in 0..n_blocks_needed {
488 if sample_pos >= data_len {
489 break;
490 }
491
492 let block_idx = rng.random_range(0..blocks.len());
494 let selected_block = &blocks[block_idx];
495
496 for &value in selected_block {
498 if sample_pos >= data_len {
499 break;
500 }
501 samples[[resample_idx, sample_pos]] = value;
502 sample_pos += 1;
503 }
504 }
505 }
506
507 Ok(samples)
508}
509
510#[allow(dead_code)]
524pub fn stationary_bootstrap<T>(
525 x: &ArrayView1<T>,
526 p: f64,
527 n_resamples: usize,
528 seed: Option<u64>,
529) -> StatsResult<Array2<T>>
530where
531 T: Copy + scirs2_core::numeric::Zero,
532{
533 if x.is_empty() {
534 return Err(StatsError::InvalidArgument(
535 "Input array cannot be empty".to_string(),
536 ));
537 }
538
539 if p <= 0.0 || p >= 1.0 {
540 return Err(StatsError::InvalidArgument(
541 "Probability parameter p must be between 0 and 1".to_string(),
542 ));
543 }
544
545 if n_resamples == 0 {
546 return Err(StatsError::InvalidArgument(
547 "Number of _resamples must be positive".to_string(),
548 ));
549 }
550
551 let mut rng = match seed {
553 Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value),
554 None => {
555 let mut rng = scirs2_core::random::thread_rng();
556 let seed = rng.random::<u64>();
557 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
558 }
559 };
560
561 let data_len = x.len();
562 let mut samples = Array2::zeros((n_resamples, data_len));
563
564 for resample_idx in 0..n_resamples {
565 let mut sample_pos = 0;
566
567 while sample_pos < data_len {
568 let start_pos = rng.random_range(0..data_len);
570 let mut current_pos = start_pos;
571
572 loop {
574 samples[[resample_idx, sample_pos]] = x[current_pos];
575 sample_pos += 1;
576
577 if sample_pos >= data_len {
578 break;
579 }
580
581 let u: f64 = rng.random();
583 if u < p {
584 break; }
586
587 current_pos = (current_pos + 1) % data_len;
589 }
590 }
591 }
592
593 Ok(samples)
594}
595
596#[allow(dead_code)]
612pub fn double_bootstrap<T, F>(
613 x: &ArrayView1<T>,
614 statistic: F,
615 n_resamples1: usize,
616 n_resamples2: usize,
617 seed: Option<u64>,
618) -> StatsResult<(f64, Array1<f64>, f64)>
619where
620 T: Copy + scirs2_core::numeric::Zero,
621 F: Fn(&ArrayView1<T>) -> StatsResult<f64> + Copy,
622{
623 if x.is_empty() {
624 return Err(StatsError::InvalidArgument(
625 "Input array cannot be empty".to_string(),
626 ));
627 }
628
629 if n_resamples1 == 0 || n_resamples2 == 0 {
630 return Err(StatsError::InvalidArgument(
631 "Number of resamples must be positive".to_string(),
632 ));
633 }
634
635 let original_stat = statistic(x)?;
637
638 let first_level_samples = bootstrap(x, n_resamples1, seed)?;
640 let mut first_level_stats = Array1::zeros(n_resamples1);
641
642 let mut rng = match seed {
644 Some(seed_value) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed_value + 1),
645 None => {
646 let mut rng = scirs2_core::random::thread_rng();
647 let seed = rng.random::<u64>();
648 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
649 }
650 };
651
652 let mut bias_estimates = Array1::zeros(n_resamples1);
653
654 for i in 0..n_resamples1 {
655 let first_sample = first_level_samples.row(i);
656 let first_stat = statistic(&first_sample)?;
657 first_level_stats[i] = first_stat;
658
659 let second_seed = rng.random::<u64>();
661 let second_level_samples = bootstrap(&first_sample, n_resamples2, Some(second_seed))?;
662
663 let mut second_level_stats = Array1::zeros(n_resamples2);
664 for j in 0..n_resamples2 {
665 let second_sample = second_level_samples.row(j);
666 second_level_stats[j] = statistic(&second_sample)?;
667 }
668
669 let second_level_mean = second_level_stats.mean().expect("Operation failed");
671 bias_estimates[i] = second_level_mean - first_stat;
672 }
673
674 let overall_bias = bias_estimates.mean().expect("Operation failed");
676
677 let _first_level_mean = first_level_stats.mean().expect("Operation failed");
679 let bias_corrected = original_stat - overall_bias;
680
681 Ok((bias_corrected, first_level_stats, overall_bias))
682}
683
684#[allow(dead_code)]
701pub fn bootstrap_confidence_intervals<T, F>(
702 x: &ArrayView1<T>,
703 statistic: F,
704 n_resamples: usize,
705 confidence_level: f64,
706 seed: Option<u64>,
707) -> StatsResult<((f64, f64), (f64, f64), (f64, f64))>
708where
709 T: Copy + scirs2_core::numeric::Zero,
710 F: Fn(&ArrayView1<T>) -> StatsResult<f64> + Copy,
711{
712 if confidence_level <= 0.0 || confidence_level >= 1.0 {
713 return Err(StatsError::InvalidArgument(
714 "Confidence _level must be between 0 and 1".to_string(),
715 ));
716 }
717
718 let original_stat = statistic(x)?;
720
721 let bootstrap_samples = bootstrap(x, n_resamples, seed)?;
723 let mut bootstrap_stats = Array1::zeros(n_resamples);
724
725 for i in 0..n_resamples {
726 let sample = bootstrap_samples.row(i);
727 bootstrap_stats[i] = statistic(&sample)?;
728 }
729
730 let mut sorted_stats = bootstrap_stats.to_vec();
732 sorted_stats.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
733
734 let alpha = 1.0 - confidence_level;
735 let n = sorted_stats.len() as f64;
736
737 let lower_idx = ((alpha / 2.0) * n) as usize;
739 let upper_idx = ((1.0 - alpha / 2.0) * n) as usize;
740 let percentile_ci = (
741 sorted_stats[lower_idx.min(n_resamples - 1)],
742 sorted_stats[upper_idx.min(n_resamples - 1)],
743 );
744
745 let below_original = sorted_stats.iter().filter(|&&x| x < original_stat).count() as f64;
747 let z0 = if below_original > 0.0 && below_original < n {
748 let p = below_original / n;
750 if p > 0.5 {
753 (2.0 * std::f64::consts::PI * p).sqrt()
754 } else {
755 -(2.0 * std::f64::consts::PI * (1.0 - p)).sqrt()
756 }
757 } else {
758 0.0
759 };
760
761 let mut jackknife_stats = Vec::with_capacity(x.len());
763 for i in 0..x.len() {
764 let mut jackknife_sample = Vec::with_capacity(x.len() - 1);
765 for j in 0..x.len() {
766 if i != j {
767 jackknife_sample.push(x[j]);
768 }
769 }
770 let jk_array = Array1::from_vec(jackknife_sample);
771 jackknife_stats.push(statistic(&jk_array.view())?);
772 }
773
774 let jk_mean = jackknife_stats.iter().sum::<f64>() / jackknife_stats.len() as f64;
775 let mut numerator = 0.0;
776 let mut denominator = 0.0;
777 for &jk_stat in &jackknife_stats {
778 let diff = jk_mean - jk_stat;
779 numerator += diff.powi(3);
780 denominator += diff.powi(2);
781 }
782
783 let acceleration = if denominator > 0.0 {
784 numerator / (6.0 * denominator.powf(1.5))
785 } else {
786 0.0
787 };
788
789 let z_alpha_2 = 1.96 * alpha / 2.0; let z_1_alpha_2 = -z_alpha_2;
792
793 let alpha1 = normal_cdf(z0 + (z0 + z_alpha_2) / (1.0 - acceleration * (z0 + z_alpha_2)));
794 let alpha2 = normal_cdf(z0 + (z0 + z_1_alpha_2) / (1.0 - acceleration * (z0 + z_1_alpha_2)));
795
796 let bca_lower_idx = (alpha1 * n) as usize;
797 let bca_upper_idx = (alpha2 * n) as usize;
798
799 let bc_ci = (
800 sorted_stats[bca_lower_idx.min(n_resamples - 1)],
801 sorted_stats[bca_upper_idx.min(n_resamples - 1)],
802 );
803
804 let bca_ci = (
805 sorted_stats[bca_lower_idx.min(n_resamples - 1)],
806 sorted_stats[bca_upper_idx.min(n_resamples - 1)],
807 );
808
809 Ok((percentile_ci, bc_ci, bca_ci))
810}
811
812#[allow(dead_code)]
814fn normal_cdf(x: f64) -> f64 {
815 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
816}
817
#[allow(dead_code)]
/// Error function erf(x), using the Abramowitz & Stegun approximation 7.1.26.
///
/// With t = 1 / (1 + p·|x|):
/// erf(|x|) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·e^(-x²).
/// The result is negated for negative `x`, since erf is odd. The published
/// maximum absolute error of the formula is about 1.5e-7.
fn erf(x: f64) -> f64 {
    // a5, a4, a3, a2, a1: highest power first, for Horner evaluation.
    const COEFFS_HIGH_TO_LOW: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS_HIGH_TO_LOW
        .iter()
        .fold(0.0, |acc, &coeff| acc * t + coeff)
        * t;
    let value = 1.0 - poly * (-magnitude * magnitude).exp();

    if x >= 0.0 {
        value
    } else {
        -value
    }
}