1use crate::error::{DatasetsError, Result};
9use scirs2_core::ndarray::Array1;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::rngs::StdRng;
13use scirs2_core::random::seq::SliceRandom;
14use scirs2_core::random::Uniform;
15use std::collections::HashMap;
16
17#[allow(dead_code)]
48pub fn random_sample(
49 n_samples: usize,
50 sample_size: usize,
51 replace: bool,
52 random_seed: Option<u64>,
53) -> Result<Vec<usize>> {
54 if n_samples == 0 {
55 return Err(DatasetsError::InvalidFormat(
56 "Number of _samples must be > 0".to_string(),
57 ));
58 }
59
60 if sample_size == 0 {
61 return Err(DatasetsError::InvalidFormat(
62 "Sample _size must be > 0".to_string(),
63 ));
64 }
65
66 if !replace && sample_size > n_samples {
67 return Err(DatasetsError::InvalidFormat(format!(
68 "Cannot sample {sample_size} items from {n_samples} without replacement"
69 )));
70 }
71
72 let mut rng = match random_seed {
73 Some(_seed) => StdRng::seed_from_u64(_seed),
74 None => {
75 let mut r = thread_rng();
76 StdRng::seed_from_u64(r.next_u64())
77 }
78 };
79
80 let mut indices = Vec::with_capacity(sample_size);
81
82 if replace {
83 for _ in 0..sample_size {
85 indices.push(rng.sample(Uniform::new(0, n_samples).unwrap()));
86 }
87 } else {
88 let mut available: Vec<usize> = (0..n_samples).collect();
90 available.shuffle(&mut rng);
91 indices.extend_from_slice(&available[0..sample_size]);
92 }
93
94 Ok(indices)
95}
96
97#[allow(dead_code)]
131pub fn stratified_sample(
132 targets: &Array1<f64>,
133 sample_size: usize,
134 random_seed: Option<u64>,
135) -> Result<Vec<usize>> {
136 if targets.is_empty() {
137 return Err(DatasetsError::InvalidFormat(
138 "Targets array cannot be empty".to_string(),
139 ));
140 }
141
142 if sample_size == 0 {
143 return Err(DatasetsError::InvalidFormat(
144 "Sample _size must be > 0".to_string(),
145 ));
146 }
147
148 if sample_size > targets.len() {
149 return Err(DatasetsError::InvalidFormat(format!(
150 "Cannot sample {} items from {} total samples",
151 sample_size,
152 targets.len()
153 )));
154 }
155
156 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
158 for (i, &target) in targets.iter().enumerate() {
159 let class = target.round() as i64;
160 class_indices.entry(class).or_default().push(i);
161 }
162
163 let mut rng = match random_seed {
164 Some(_seed) => StdRng::seed_from_u64(_seed),
165 None => {
166 let mut r = thread_rng();
167 StdRng::seed_from_u64(r.next_u64())
168 }
169 };
170
171 let mut stratified_indices = Vec::new();
172 let n_classes = class_indices.len();
173 let base_samples_per_class = sample_size / n_classes;
174 let remainder = sample_size % n_classes;
175
176 let mut class_list: Vec<_> = class_indices.keys().cloned().collect();
177 class_list.sort();
178
179 for (i, &class) in class_list.iter().enumerate() {
180 let class_samples = class_indices.get(&class).unwrap();
181 let samples_for_this_class = if i < remainder {
182 base_samples_per_class + 1
183 } else {
184 base_samples_per_class
185 };
186
187 if samples_for_this_class > class_samples.len() {
188 return Err(DatasetsError::InvalidFormat(format!(
189 "Class {} has only {} samples but needs {} for stratified sampling",
190 class,
191 class_samples.len(),
192 samples_for_this_class
193 )));
194 }
195
196 let sampled_indices = random_sample(
198 class_samples.len(),
199 samples_for_this_class,
200 false,
201 Some(rng.next_u64()),
202 )?;
203
204 for &idx in &sampled_indices {
205 stratified_indices.push(class_samples[idx]);
206 }
207 }
208
209 stratified_indices.shuffle(&mut rng);
210 Ok(stratified_indices)
211}
212
213#[allow(dead_code)]
252pub fn importance_sample(
253 weights: &Array1<f64>,
254 sample_size: usize,
255 replace: bool,
256 random_seed: Option<u64>,
257) -> Result<Vec<usize>> {
258 if weights.is_empty() {
259 return Err(DatasetsError::InvalidFormat(
260 "Weights array cannot be empty".to_string(),
261 ));
262 }
263
264 if sample_size == 0 {
265 return Err(DatasetsError::InvalidFormat(
266 "Sample _size must be > 0".to_string(),
267 ));
268 }
269
270 if !replace && sample_size > weights.len() {
271 return Err(DatasetsError::InvalidFormat(format!(
272 "Cannot sample {} items from {} without replacement",
273 sample_size,
274 weights.len()
275 )));
276 }
277
278 for &weight in weights.iter() {
280 if weight < 0.0 {
281 return Err(DatasetsError::InvalidFormat(
282 "All weights must be non-negative".to_string(),
283 ));
284 }
285 }
286
287 let weight_sum: f64 = weights.sum();
288 if weight_sum <= 0.0 {
289 return Err(DatasetsError::InvalidFormat(
290 "Sum of weights must be positive".to_string(),
291 ));
292 }
293
294 let mut rng = match random_seed {
295 Some(_seed) => StdRng::seed_from_u64(_seed),
296 None => {
297 let mut r = thread_rng();
298 StdRng::seed_from_u64(r.next_u64())
299 }
300 };
301
302 let mut indices = Vec::with_capacity(sample_size);
303 let mut available_weights = weights.clone();
304 let mut available_indices: Vec<usize> = (0..weights.len()).collect();
305
306 for _ in 0..sample_size {
307 let current_sum = available_weights.sum();
308 if current_sum <= 0.0 {
309 break;
310 }
311
312 let random_value = rng.gen_range(0.0..current_sum);
314
315 let mut cumulative_weight = 0.0;
317 let mut selected_idx = 0;
318
319 for (i, &weight) in available_weights.iter().enumerate() {
320 cumulative_weight += weight;
321 if random_value <= cumulative_weight {
322 selected_idx = i;
323 break;
324 }
325 }
326
327 let original_idx = available_indices[selected_idx];
328 indices.push(original_idx);
329
330 if !replace {
331 available_weights = Array1::from_iter(
333 available_weights
334 .iter()
335 .enumerate()
336 .filter(|(i_, _)| *i_ != selected_idx)
337 .map(|(_, &w)| w),
338 );
339 available_indices.remove(selected_idx);
340 }
341 }
342
343 Ok(indices)
344}
345
346#[allow(dead_code)]
377pub fn bootstrap_sample(
378 n_samples: usize,
379 n_bootstrap_samples: usize,
380 random_seed: Option<u64>,
381) -> Result<Vec<usize>> {
382 random_sample(n_samples, n_bootstrap_samples, true, random_seed)
383}
384
385#[allow(dead_code)]
411pub fn multiple_bootstrap_samples(
412 n_samples: usize,
413 sample_size: usize,
414 n_bootstrap_rounds: usize,
415 random_seed: Option<u64>,
416) -> Result<Vec<Vec<usize>>> {
417 if n_bootstrap_rounds == 0 {
418 return Err(DatasetsError::InvalidFormat(
419 "Number of bootstrap _rounds must be > 0".to_string(),
420 ));
421 }
422
423 let mut rng = match random_seed {
424 Some(_seed) => StdRng::seed_from_u64(_seed),
425 None => {
426 let mut r = thread_rng();
427 StdRng::seed_from_u64(r.next_u64())
428 }
429 };
430
431 let mut bootstrap_samples = Vec::with_capacity(n_bootstrap_rounds);
432
433 for _ in 0..n_bootstrap_rounds {
434 let sample = random_sample(n_samples, sample_size, true, Some(rng.next_u64()))?;
435 bootstrap_samples.push(sample);
436 }
437
438 Ok(bootstrap_samples)
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashSet;

    /// Collects indices into a set so a test can count the distinct values.
    fn distinct(indices: &[usize]) -> HashSet<usize> {
        indices.iter().copied().collect()
    }

    #[test]
    fn test_random_sample_without_replacement() {
        let indices = random_sample(10, 5, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 10));

        // Without replacement every index must be distinct.
        assert_eq!(distinct(&indices).len(), 5);
    }

    #[test]
    fn test_random_sample_with_replacement() {
        let indices = random_sample(5, 10, true, Some(42)).unwrap();

        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i < 5));

        // Ten draws from five values cannot give more than five distinct values.
        assert!(distinct(&indices).len() <= 5);
    }

    #[test]
    fn test_random_sample_is_reproducible_with_seed() {
        let first = random_sample(10, 5, false, Some(7)).unwrap();
        let second = random_sample(10, 5, false, Some(7)).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_random_sample_invalid_params() {
        // Empty population.
        assert!(random_sample(0, 5, false, None).is_err());

        // Zero sample size.
        assert!(random_sample(10, 0, false, None).is_err());

        // More draws than items without replacement.
        assert!(random_sample(5, 10, false, None).is_err());
    }

    #[test]
    fn test_stratified_sample() {
        // Three classes with 2, 3 and 3 members.
        let targets = array![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        let indices = stratified_sample(&targets, 6, Some(42)).unwrap();

        assert_eq!(indices.len(), 6);
        assert_eq!(distinct(&indices).len(), 6);

        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &idx in &indices {
            let class = targets[idx] as i32;
            *class_counts.entry(class).or_insert(0) += 1;
        }

        // Six draws over three classes means exactly two per class.
        assert_eq!(class_counts.len(), 3);
        assert!(class_counts.values().all(|&count| count == 2));
    }

    #[test]
    fn test_stratified_sample_insufficient_samples() {
        // Only two samples exist, so four cannot be drawn.
        let targets = array![0.0, 1.0];
        assert!(stratified_sample(&targets, 4, Some(42)).is_err());
    }

    #[test]
    fn test_importance_sample() {
        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0];
        let indices = importance_sample(&weights, 3, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&i| i < 6));

        // Without replacement every index must be distinct.
        assert_eq!(distinct(&indices).len(), 3);
    }

    #[test]
    fn test_importance_sample_negative_weights() {
        let weights = array![0.5, -0.1, 0.3];
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_importance_sample_zero_weights() {
        let weights = array![0.0, 0.0, 0.0];
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_bootstrap_sample() {
        let indices = bootstrap_sample(20, 20, Some(42)).unwrap();

        assert_eq!(indices.len(), 20);
        assert!(indices.iter().all(|&i| i < 20));

        // Sampling with replacement is expected to repeat at least one index.
        assert!(distinct(&indices).len() < 20);
    }

    #[test]
    fn test_multiple_bootstrap_samples() {
        let samples = multiple_bootstrap_samples(10, 8, 5, Some(42)).unwrap();

        assert_eq!(samples.len(), 5);
        assert!(samples.iter().all(|sample| sample.len() == 8));
        assert!(samples.iter().all(|sample| sample.iter().all(|&i| i < 10)));

        // Rounds use different derived seeds, so they should not be identical.
        assert_ne!(samples[0], samples[1]);
    }

    #[test]
    fn test_multiple_bootstrap_samples_invalid_params() {
        assert!(multiple_bootstrap_samples(10, 10, 0, None).is_err());
    }
}