1use crate::error::{DatasetsError, Result};
9use ndarray::Array1;
10use rand::prelude::*;
11use rand::rng;
12use rand::rngs::StdRng;
13use std::collections::HashMap;
14
15pub fn random_sample(
46 n_samples: usize,
47 sample_size: usize,
48 replace: bool,
49 random_seed: Option<u64>,
50) -> Result<Vec<usize>> {
51 if n_samples == 0 {
52 return Err(DatasetsError::InvalidFormat(
53 "Number of samples must be > 0".to_string(),
54 ));
55 }
56
57 if sample_size == 0 {
58 return Err(DatasetsError::InvalidFormat(
59 "Sample size must be > 0".to_string(),
60 ));
61 }
62
63 if !replace && sample_size > n_samples {
64 return Err(DatasetsError::InvalidFormat(format!(
65 "Cannot sample {} items from {} without replacement",
66 sample_size, n_samples
67 )));
68 }
69
70 let mut rng = match random_seed {
71 Some(seed) => StdRng::seed_from_u64(seed),
72 None => {
73 let mut r = rng();
74 StdRng::seed_from_u64(r.next_u64())
75 }
76 };
77
78 let mut indices = Vec::with_capacity(sample_size);
79
80 if replace {
81 for _ in 0..sample_size {
83 indices.push(rng.random_range(0..n_samples));
84 }
85 } else {
86 let mut available: Vec<usize> = (0..n_samples).collect();
88 available.shuffle(&mut rng);
89 indices.extend_from_slice(&available[0..sample_size]);
90 }
91
92 Ok(indices)
93}
94
95pub fn stratified_sample(
129 targets: &Array1<f64>,
130 sample_size: usize,
131 random_seed: Option<u64>,
132) -> Result<Vec<usize>> {
133 if targets.is_empty() {
134 return Err(DatasetsError::InvalidFormat(
135 "Targets array cannot be empty".to_string(),
136 ));
137 }
138
139 if sample_size == 0 {
140 return Err(DatasetsError::InvalidFormat(
141 "Sample size must be > 0".to_string(),
142 ));
143 }
144
145 if sample_size > targets.len() {
146 return Err(DatasetsError::InvalidFormat(format!(
147 "Cannot sample {} items from {} total samples",
148 sample_size,
149 targets.len()
150 )));
151 }
152
153 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
155 for (i, &target) in targets.iter().enumerate() {
156 let class = target.round() as i64;
157 class_indices.entry(class).or_default().push(i);
158 }
159
160 let mut rng = match random_seed {
161 Some(seed) => StdRng::seed_from_u64(seed),
162 None => {
163 let mut r = rng();
164 StdRng::seed_from_u64(r.next_u64())
165 }
166 };
167
168 let mut stratified_indices = Vec::new();
169 let n_classes = class_indices.len();
170 let base_samples_per_class = sample_size / n_classes;
171 let remainder = sample_size % n_classes;
172
173 let mut class_list: Vec<_> = class_indices.keys().cloned().collect();
174 class_list.sort();
175
176 for (i, &class) in class_list.iter().enumerate() {
177 let class_samples = class_indices.get(&class).unwrap();
178 let samples_for_this_class = if i < remainder {
179 base_samples_per_class + 1
180 } else {
181 base_samples_per_class
182 };
183
184 if samples_for_this_class > class_samples.len() {
185 return Err(DatasetsError::InvalidFormat(format!(
186 "Class {} has only {} samples but needs {} for stratified sampling",
187 class,
188 class_samples.len(),
189 samples_for_this_class
190 )));
191 }
192
193 let sampled_indices = random_sample(
195 class_samples.len(),
196 samples_for_this_class,
197 false,
198 Some(rng.next_u64()),
199 )?;
200
201 for &idx in &sampled_indices {
202 stratified_indices.push(class_samples[idx]);
203 }
204 }
205
206 stratified_indices.shuffle(&mut rng);
207 Ok(stratified_indices)
208}
209
210pub fn importance_sample(
249 weights: &Array1<f64>,
250 sample_size: usize,
251 replace: bool,
252 random_seed: Option<u64>,
253) -> Result<Vec<usize>> {
254 if weights.is_empty() {
255 return Err(DatasetsError::InvalidFormat(
256 "Weights array cannot be empty".to_string(),
257 ));
258 }
259
260 if sample_size == 0 {
261 return Err(DatasetsError::InvalidFormat(
262 "Sample size must be > 0".to_string(),
263 ));
264 }
265
266 if !replace && sample_size > weights.len() {
267 return Err(DatasetsError::InvalidFormat(format!(
268 "Cannot sample {} items from {} without replacement",
269 sample_size,
270 weights.len()
271 )));
272 }
273
274 for &weight in weights.iter() {
276 if weight < 0.0 {
277 return Err(DatasetsError::InvalidFormat(
278 "All weights must be non-negative".to_string(),
279 ));
280 }
281 }
282
283 let weight_sum: f64 = weights.sum();
284 if weight_sum <= 0.0 {
285 return Err(DatasetsError::InvalidFormat(
286 "Sum of weights must be positive".to_string(),
287 ));
288 }
289
290 let mut rng = match random_seed {
291 Some(seed) => StdRng::seed_from_u64(seed),
292 None => {
293 let mut r = rng();
294 StdRng::seed_from_u64(r.next_u64())
295 }
296 };
297
298 let mut indices = Vec::with_capacity(sample_size);
299 let mut available_weights = weights.clone();
300 let mut available_indices: Vec<usize> = (0..weights.len()).collect();
301
302 for _ in 0..sample_size {
303 let current_sum = available_weights.sum();
304 if current_sum <= 0.0 {
305 break;
306 }
307
308 let random_value = rng.random_range(0.0..current_sum);
310
311 let mut cumulative_weight = 0.0;
313 let mut selected_idx = 0;
314
315 for (i, &weight) in available_weights.iter().enumerate() {
316 cumulative_weight += weight;
317 if random_value <= cumulative_weight {
318 selected_idx = i;
319 break;
320 }
321 }
322
323 let original_idx = available_indices[selected_idx];
324 indices.push(original_idx);
325
326 if !replace {
327 available_weights = Array1::from_iter(
329 available_weights
330 .iter()
331 .enumerate()
332 .filter(|(i, _)| *i != selected_idx)
333 .map(|(_, &w)| w),
334 );
335 available_indices.remove(selected_idx);
336 }
337 }
338
339 Ok(indices)
340}
341
342pub fn bootstrap_sample(
373 n_samples: usize,
374 n_bootstrap_samples: usize,
375 random_seed: Option<u64>,
376) -> Result<Vec<usize>> {
377 random_sample(n_samples, n_bootstrap_samples, true, random_seed)
378}
379
380pub fn multiple_bootstrap_samples(
406 n_samples: usize,
407 sample_size: usize,
408 n_bootstrap_rounds: usize,
409 random_seed: Option<u64>,
410) -> Result<Vec<Vec<usize>>> {
411 if n_bootstrap_rounds == 0 {
412 return Err(DatasetsError::InvalidFormat(
413 "Number of bootstrap rounds must be > 0".to_string(),
414 ));
415 }
416
417 let mut rng = match random_seed {
418 Some(seed) => StdRng::seed_from_u64(seed),
419 None => {
420 let mut r = rng();
421 StdRng::seed_from_u64(r.next_u64())
422 }
423 };
424
425 let mut bootstrap_samples = Vec::with_capacity(n_bootstrap_rounds);
426
427 for _ in 0..n_bootstrap_rounds {
428 let sample = random_sample(n_samples, sample_size, true, Some(rng.next_u64()))?;
429 bootstrap_samples.push(sample);
430 }
431
432 Ok(bootstrap_samples)
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use std::collections::HashSet;

    #[test]
    fn test_random_sample_without_replacement() {
        let indices = random_sample(10, 5, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 10));

        // Without replacement every index must be distinct.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 5);
    }

    #[test]
    fn test_random_sample_with_replacement() {
        let indices = random_sample(5, 10, true, Some(42)).unwrap();

        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i < 5));

        // Ten draws from a population of five can cover at most five values.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() <= 5);
    }

    #[test]
    fn test_random_sample_is_reproducible_with_seed() {
        let first = random_sample(10, 5, false, Some(7)).unwrap();
        let second = random_sample(10, 5, false, Some(7)).unwrap();

        assert_eq!(first, second);
    }

    #[test]
    fn test_random_sample_invalid_params() {
        // Empty population.
        assert!(random_sample(0, 5, false, None).is_err());

        // Empty sample.
        assert!(random_sample(10, 0, false, None).is_err());

        // More items than available without replacement.
        assert!(random_sample(5, 10, false, None).is_err());
    }

    #[test]
    fn test_stratified_sample() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        let indices = stratified_sample(&targets, 6, Some(42)).unwrap();

        assert_eq!(indices.len(), 6);
        assert!(indices.iter().all(|&i| i < targets.len()));

        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 6);

        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &idx in &indices {
            let class = targets[idx] as i32;
            *class_counts.entry(class).or_insert(0) += 1;
        }

        // Six samples over three classes: exactly two from each class.
        assert_eq!(class_counts.len(), 3);
        assert!(class_counts.values().all(|&count| count == 2));
    }

    #[test]
    fn test_stratified_sample_insufficient_samples() {
        let targets = array![0.0, 1.0];

        // Asking for more samples than there are targets must fail.
        assert!(stratified_sample(&targets, 4, Some(42)).is_err());
    }

    #[test]
    fn test_importance_sample() {
        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0];
        let indices = importance_sample(&weights, 3, false, Some(42)).unwrap();

        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&i| i < 6));

        // Without replacement every index must be distinct.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 3);
    }

    #[test]
    fn test_importance_sample_negative_weights() {
        let weights = array![0.5, -0.1, 0.3];

        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_importance_sample_zero_weights() {
        let weights = array![0.0, 0.0, 0.0];

        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_bootstrap_sample() {
        let indices = bootstrap_sample(20, 20, Some(42)).unwrap();

        assert_eq!(indices.len(), 20);
        assert!(indices.iter().all(|&i| i < 20));

        // Sampling with replacement should repeat at least one index for this seed.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() < 20);
    }

    #[test]
    fn test_multiple_bootstrap_samples() {
        let samples = multiple_bootstrap_samples(10, 8, 5, Some(42)).unwrap();

        assert_eq!(samples.len(), 5);
        assert!(samples.iter().all(|sample| sample.len() == 8));
        assert!(samples.iter().all(|sample| sample.iter().all(|&i| i < 10)));

        // Each round uses its own derived seed, so rounds should differ.
        assert_ne!(samples[0], samples[1]);
    }

    #[test]
    fn test_multiple_bootstrap_samples_invalid_params() {
        assert!(multiple_bootstrap_samples(10, 10, 0, None).is_err());
    }
}