1use crate::{UtilsError, UtilsResult};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::{Rng, SeedableRng};
10use std::collections::HashMap;
11
/// Bootstrap resampler: draws indices with replacement from a population and reports
/// which indices were never drawn (the out-of-bag set).
#[derive(Debug, Clone)]
pub struct Bootstrap {
    /// How many indices to draw per sample; `None` means "as many as the population size".
    n_samples: Option<usize>,
    /// Seed for the random number generator; `None` falls back to a fixed default seed.
    random_state: Option<u64>,
}
20
21impl Bootstrap {
22 pub fn new(n_samples: Option<usize>, random_state: Option<u64>) -> Self {
28 Self {
29 n_samples,
30 random_state,
31 }
32 }
33
34 pub fn sample(&self, n_population: usize) -> UtilsResult<(Vec<usize>, Vec<usize>)> {
42 if n_population == 0 {
43 return Err(UtilsError::InvalidParameter(
44 "Population size must be positive".to_string(),
45 ));
46 }
47
48 let n_samples = self.n_samples.unwrap_or(n_population);
49 let mut rng = self
50 .random_state
51 .map(StdRng::seed_from_u64)
52 .unwrap_or_else(|| StdRng::seed_from_u64(42));
53
54 let mut in_bag = Vec::with_capacity(n_samples);
56 let mut in_bag_set = vec![false; n_population];
57
58 for _ in 0..n_samples {
59 let idx = rng.gen_range(0..n_population);
60 in_bag.push(idx);
61 in_bag_set[idx] = true;
62 }
63
64 let out_of_bag: Vec<usize> = (0..n_population).filter(|&i| !in_bag_set[i]).collect();
66
67 Ok((in_bag, out_of_bag))
68 }
69
70 pub fn sample_multiple(
79 &self,
80 n_population: usize,
81 n_bootstraps: usize,
82 ) -> UtilsResult<Vec<(Vec<usize>, Vec<usize>)>> {
83 let mut samples = Vec::with_capacity(n_bootstraps);
84
85 for i in 0..n_bootstraps {
86 let seed = self.random_state.map(|s| s + i as u64);
88 let sampler = Bootstrap::new(self.n_samples, seed);
89 samples.push(sampler.sample(n_population)?);
90 }
91
92 Ok(samples)
93 }
94}
95
96impl Default for Bootstrap {
97 fn default() -> Self {
98 Self::new(None, Some(42))
99 }
100}
101
102#[derive(Clone, Debug)]
104pub struct BaggingPredictor {
105 aggregation: AggregationStrategy,
106}
107
/// How per-estimator predictions are combined by [`BaggingPredictor`].
///
/// The enum is fieldless, so it is `Copy` and has total equality (`Eq`, `Hash`), which
/// lets it be used as a map key or compared without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// Arithmetic mean across estimators.
    Mean,
    /// Median across estimators.
    Median,
    /// Most frequent class label; only meaningful for classification outputs.
    MajorityVote,
    /// Mean weighted by caller-supplied per-estimator weights.
    WeightedMean,
}
120
121impl BaggingPredictor {
122 pub fn new(aggregation: AggregationStrategy) -> Self {
124 Self { aggregation }
125 }
126
127 pub fn aggregate_regression(
136 &self,
137 predictions: &Array2<f64>,
138 weights: Option<&Array1<f64>>,
139 ) -> UtilsResult<Array1<f64>> {
140 if predictions.nrows() == 0 || predictions.ncols() == 0 {
141 return Err(UtilsError::InvalidParameter(
142 "Predictions array cannot be empty".to_string(),
143 ));
144 }
145
146 match &self.aggregation {
147 AggregationStrategy::Mean => Ok(predictions
148 .mean_axis(scirs2_core::ndarray::Axis(1))
149 .unwrap()),
150 AggregationStrategy::Median => {
151 let mut result = Array1::zeros(predictions.nrows());
152 for (i, row) in predictions.outer_iter().enumerate() {
153 let mut sorted: Vec<f64> = row.to_vec();
154 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
155 let mid = sorted.len() / 2;
156 result[i] = if sorted.len() % 2 == 0 {
157 (sorted[mid - 1] + sorted[mid]) / 2.0
158 } else {
159 sorted[mid]
160 };
161 }
162 Ok(result)
163 }
164 AggregationStrategy::WeightedMean => {
165 let weights = weights.ok_or_else(|| {
166 UtilsError::InvalidParameter("Weights required for weighted mean".to_string())
167 })?;
168
169 if weights.len() != predictions.ncols() {
170 return Err(UtilsError::InvalidParameter(
171 "Number of weights must match number of estimators".to_string(),
172 ));
173 }
174
175 let weight_sum: f64 = weights.sum();
176 if weight_sum <= 0.0 {
177 return Err(UtilsError::InvalidParameter(
178 "Weight sum must be positive".to_string(),
179 ));
180 }
181
182 let normalized_weights = weights / weight_sum;
183 Ok(predictions.dot(&normalized_weights))
184 }
185 AggregationStrategy::MajorityVote => Err(UtilsError::InvalidParameter(
186 "Use aggregate_classification for majority voting".to_string(),
187 )),
188 }
189 }
190
191 pub fn aggregate_classification(
199 &self,
200 predictions: &Array2<usize>,
201 ) -> UtilsResult<Array1<usize>> {
202 if predictions.nrows() == 0 || predictions.ncols() == 0 {
203 return Err(UtilsError::InvalidParameter(
204 "Predictions array cannot be empty".to_string(),
205 ));
206 }
207
208 let mut result = Array1::zeros(predictions.nrows());
209
210 for (i, row) in predictions.outer_iter().enumerate() {
211 let mut vote_counts: HashMap<usize, usize> = HashMap::new();
213 for &pred in row.iter() {
214 *vote_counts.entry(pred).or_insert(0) += 1;
215 }
216
217 let (predicted_class, _) = vote_counts
219 .iter()
220 .max_by_key(|(_, &count)| count)
221 .ok_or_else(|| UtilsError::InvalidParameter("No votes found".to_string()))?;
222
223 result[i] = *predicted_class;
224 }
225
226 Ok(result)
227 }
228
229 pub fn aggregate_probabilities(
238 &self,
239 probabilities: &[Array2<f64>],
240 ) -> UtilsResult<Array2<f64>> {
241 if probabilities.is_empty() {
242 return Err(UtilsError::InvalidParameter(
243 "Probabilities array cannot be empty".to_string(),
244 ));
245 }
246
247 let (n_samples, n_classes) = probabilities[0].dim();
248
249 for probs in probabilities.iter() {
251 if probs.dim() != (n_samples, n_classes) {
252 return Err(UtilsError::InvalidParameter(
253 "All probability matrices must have the same shape".to_string(),
254 ));
255 }
256 }
257
258 let mut result = Array2::zeros((n_samples, n_classes));
260 for probs in probabilities {
261 result += probs;
262 }
263 result /= probabilities.len() as f64;
264
265 Ok(result)
266 }
267}
268
269impl Default for BaggingPredictor {
270 fn default() -> Self {
271 Self::new(AggregationStrategy::Mean)
272 }
273}
274
/// Stateless helpers that score out-of-bag predictions against ground truth.
#[derive(Debug, Clone)]
pub struct OOBScoreEstimator;
280
281impl OOBScoreEstimator {
282 pub fn oob_score_regression(
291 y_true: &Array1<f64>,
292 oob_predictions: &Array1<f64>,
293 ) -> UtilsResult<f64> {
294 if y_true.len() != oob_predictions.len() {
295 return Err(UtilsError::InvalidParameter(
296 "y_true and predictions must have same length".to_string(),
297 ));
298 }
299
300 if y_true.is_empty() {
301 return Err(UtilsError::InvalidParameter(
302 "Cannot compute score on empty array".to_string(),
303 ));
304 }
305
306 let y_mean = y_true.mean().unwrap();
308 let ss_tot: f64 = y_true.iter().map(|&y| (y - y_mean).powi(2)).sum();
309 let ss_res: f64 = y_true
310 .iter()
311 .zip(oob_predictions.iter())
312 .map(|(&y, &pred)| (y - pred).powi(2))
313 .sum();
314
315 if ss_tot <= 0.0 {
316 Ok(0.0)
317 } else {
318 Ok(1.0 - ss_res / ss_tot)
319 }
320 }
321
322 pub fn oob_accuracy(
331 y_true: &Array1<usize>,
332 oob_predictions: &Array1<usize>,
333 ) -> UtilsResult<f64> {
334 if y_true.len() != oob_predictions.len() {
335 return Err(UtilsError::InvalidParameter(
336 "y_true and predictions must have same length".to_string(),
337 ));
338 }
339
340 if y_true.is_empty() {
341 return Err(UtilsError::InvalidParameter(
342 "Cannot compute score on empty array".to_string(),
343 ));
344 }
345
346 let correct: usize = y_true
347 .iter()
348 .zip(oob_predictions.iter())
349 .filter(|(&y, &pred)| y == pred)
350 .count();
351
352 Ok(correct as f64 / y_true.len() as f64)
353 }
354}
355
/// Stateless helpers for building stacking ensembles (cross-validation fold generation).
#[derive(Debug, Clone)]
pub struct StackingHelper;
359
360impl StackingHelper {
361 pub fn generate_cv_folds(
371 n_samples: usize,
372 n_folds: usize,
373 random_state: Option<u64>,
374 ) -> UtilsResult<Vec<(Vec<usize>, Vec<usize>)>> {
375 if n_folds < 2 {
376 return Err(UtilsError::InvalidParameter(
377 "n_folds must be at least 2".to_string(),
378 ));
379 }
380
381 if n_samples < n_folds {
382 return Err(UtilsError::InvalidParameter(
383 "n_samples must be >= n_folds".to_string(),
384 ));
385 }
386
387 let mut indices: Vec<usize> = (0..n_samples).collect();
389 let mut rng = random_state
390 .map(StdRng::seed_from_u64)
391 .unwrap_or_else(|| StdRng::seed_from_u64(42));
392
393 for i in (1..indices.len()).rev() {
395 let j = rng.gen_range(0..=i);
396 indices.swap(i, j);
397 }
398
399 let fold_sizes = Self::compute_fold_sizes(n_samples, n_folds);
401 let mut folds = Vec::with_capacity(n_folds);
402 let mut start = 0;
403
404 for size in fold_sizes {
405 let test_indices = indices[start..start + size].to_vec();
406 let train_indices: Vec<usize> = indices
407 .iter()
408 .enumerate()
409 .filter(|(i, _)| *i < start || *i >= start + size)
410 .map(|(_, &idx)| idx)
411 .collect();
412
413 folds.push((train_indices, test_indices));
414 start += size;
415 }
416
417 Ok(folds)
418 }
419
420 fn compute_fold_sizes(n_samples: usize, n_folds: usize) -> Vec<usize> {
421 let base_size = n_samples / n_folds;
422 let remainder = n_samples % n_folds;
423
424 (0..n_folds)
425 .map(|i| {
426 if i < remainder {
427 base_size + 1
428 } else {
429 base_size
430 }
431 })
432 .collect()
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_bootstrap_sample() {
        let bootstrap = Bootstrap::new(Some(10), Some(42));
        let (in_bag, out_of_bag) = bootstrap.sample(10).unwrap();

        assert_eq!(in_bag.len(), 10);
        // Drawing 10 of 10 with replacement leaves some indices undrawn for this seed.
        assert!(!out_of_bag.is_empty());
        assert!(out_of_bag.len() < 10);

        for &idx in in_bag.iter().chain(&out_of_bag) {
            assert!(idx < 10);
        }
    }

    #[test]
    fn test_bootstrap_out_of_bag_is_complement_of_in_bag() {
        let (in_bag, out_of_bag) = Bootstrap::new(None, Some(7)).sample(20).unwrap();
        let drawn: std::collections::HashSet<usize> = in_bag.iter().copied().collect();

        // Out-of-bag indices are exactly the never-drawn ones, in ascending order.
        let expected: Vec<usize> = (0..20).filter(|i| !drawn.contains(i)).collect();
        assert_eq!(out_of_bag, expected);
    }

    #[test]
    fn test_bootstrap_reproducible_and_rejects_empty_population() {
        let first = Bootstrap::new(None, Some(3)).sample(15).unwrap();
        let second = Bootstrap::new(None, Some(3)).sample(15).unwrap();
        assert_eq!(first, second);

        assert!(Bootstrap::default().sample(0).is_err());
    }

    #[test]
    fn test_bootstrap_multiple() {
        let bootstrap = Bootstrap::new(None, Some(42));
        let samples = bootstrap.sample_multiple(10, 5).unwrap();

        assert_eq!(samples.len(), 5);

        for (in_bag, out_of_bag) in &samples {
            assert_eq!(in_bag.len(), 10);
            assert!(out_of_bag.len() <= 10);
        }

        // Replicate i is seeded with random_state + i.
        assert_eq!(samples[0], bootstrap.sample(10).unwrap());
        assert_eq!(samples[1], Bootstrap::new(None, Some(43)).sample(10).unwrap());
    }

    #[test]
    fn test_bagging_mean_aggregation() {
        let predictor = BaggingPredictor::new(AggregationStrategy::Mean);

        // Rows are samples, columns are estimators.
        let predictions = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 2.0, 2.0, 2.0],
            [1.0, 3.0, 2.0, 4.0]
        ];

        let result = predictor.aggregate_regression(&predictions, None).unwrap();

        assert_abs_diff_eq!(result[0], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[2], 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_bagging_median_aggregation() {
        let predictor = BaggingPredictor::new(AggregationStrategy::Median);

        // Even number of estimators: median is the mean of the two middle values,
        // so the outlier 100.0 does not affect it.
        let predictions = array![[1.0, 2.0, 100.0, 3.0], [1.0, 2.0, 3.0, 4.0]];

        let result = predictor.aggregate_regression(&predictions, None).unwrap();

        assert_abs_diff_eq!(result[0], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 2.5, epsilon = 1e-10);

        // Odd number of estimators: median is the middle value.
        let predictions2 = array![[1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 1.0, 2.0, 3.0, 100.0]];

        let result2 = predictor.aggregate_regression(&predictions2, None).unwrap();
        assert_abs_diff_eq!(result2[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result2[1], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_bagging_weighted_mean() {
        let predictor = BaggingPredictor::new(AggregationStrategy::WeightedMean);

        let predictions = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let weights = array![0.5, 0.3, 0.2];
        let result = predictor
            .aggregate_regression(&predictions, Some(&weights))
            .unwrap();

        // 0.5 * 1 + 0.3 * 2 + 0.2 * 3 = 1.7
        assert_abs_diff_eq!(result[0], 1.7, epsilon = 1e-10);
        // 0.5 * 4 + 0.3 * 5 + 0.2 * 6 = 4.7
        assert_abs_diff_eq!(result[1], 4.7, epsilon = 1e-10);
    }

    #[test]
    fn test_regression_aggregation_errors() {
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let empty = Array2::<f64>::zeros((0, 3));

        let mean = BaggingPredictor::new(AggregationStrategy::Mean);
        assert!(mean.aggregate_regression(&empty, None).is_err());

        let weighted = BaggingPredictor::new(AggregationStrategy::WeightedMean);
        assert!(weighted.aggregate_regression(&predictions, None).is_err());
        assert!(weighted
            .aggregate_regression(&predictions, Some(&array![1.0, 2.0, 3.0]))
            .is_err());
        assert!(weighted
            .aggregate_regression(&predictions, Some(&array![1.0, -1.0]))
            .is_err());

        let vote = BaggingPredictor::new(AggregationStrategy::MajorityVote);
        assert!(vote.aggregate_regression(&predictions, None).is_err());
    }

    #[test]
    fn test_majority_vote() {
        let predictor = BaggingPredictor::new(AggregationStrategy::MajorityVote);

        let predictions = array![[0, 0, 1, 0, 0], [1, 1, 0, 1, 1], [2, 2, 2, 0, 1]];

        let result = predictor.aggregate_classification(&predictions).unwrap();

        assert_eq!(result[0], 0);
        assert_eq!(result[1], 1);
        assert_eq!(result[2], 2);
    }

    #[test]
    fn test_aggregate_probabilities() {
        let predictor = BaggingPredictor::default();

        let probs1 = array![[0.8, 0.2], [0.3, 0.7]];
        let probs2 = array![[0.6, 0.4], [0.4, 0.6]];

        let result = predictor
            .aggregate_probabilities(&[probs1, probs2])
            .unwrap();

        assert_abs_diff_eq!(result[[0, 0]], 0.7, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 0.3, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 0]], 0.35, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 1]], 0.65, epsilon = 1e-10);
    }

    #[test]
    fn test_classification_and_probability_errors() {
        let predictor = BaggingPredictor::default();

        assert!(predictor
            .aggregate_classification(&Array2::<usize>::zeros((2, 0)))
            .is_err());
        assert!(predictor.aggregate_probabilities(&[]).is_err());

        let probs1 = array![[0.5, 0.5]];
        let probs2 = array![[0.2, 0.3, 0.5]];
        assert!(predictor.aggregate_probabilities(&[probs1, probs2]).is_err());
    }

    #[test]
    fn test_oob_score_regression() {
        let y_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_pred = array![1.1, 1.9, 3.1, 3.9, 5.1];

        let score = OOBScoreEstimator::oob_score_regression(&y_true, &y_pred).unwrap();

        // SS_res = 0.05, SS_tot = 10.0, so R² = 0.995.
        assert!(score > 0.95);
    }

    #[test]
    fn test_oob_accuracy() {
        let y_true = array![0, 1, 2, 0, 1];
        let y_pred = array![0, 1, 2, 0, 2];
        let accuracy = OOBScoreEstimator::oob_accuracy(&y_true, &y_pred).unwrap();

        assert_abs_diff_eq!(accuracy, 0.8, epsilon = 1e-10);
    }

    #[test]
    fn test_oob_errors() {
        assert!(
            OOBScoreEstimator::oob_score_regression(&array![1.0, 2.0], &array![1.0]).is_err()
        );
        assert!(OOBScoreEstimator::oob_accuracy(&array![0usize, 1], &array![0usize]).is_err());

        let empty: Array1<f64> = Array1::zeros(0);
        assert!(OOBScoreEstimator::oob_score_regression(&empty, &empty).is_err());
    }

    #[test]
    fn test_stacking_cv_folds() {
        let folds = StackingHelper::generate_cv_folds(10, 3, Some(42)).unwrap();

        assert_eq!(folds.len(), 3);

        let mut all_test_indices: Vec<usize> = Vec::new();
        for (train, test) in &folds {
            assert!(!train.is_empty());
            assert!(!test.is_empty());
            assert_eq!(train.len() + test.len(), 10);
            all_test_indices.extend(test);
        }

        // The remainder (10 % 3 = 1) goes to the first fold.
        let sizes: Vec<usize> = folds.iter().map(|(_, test)| test.len()).collect();
        assert_eq!(sizes, vec![4, 3, 3]);

        // Test sets partition the full index range.
        all_test_indices.sort_unstable();
        assert_eq!(all_test_indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_stacking_cv_fold_errors() {
        assert!(StackingHelper::generate_cv_folds(10, 1, None).is_err());
        assert!(StackingHelper::generate_cv_folds(2, 3, None).is_err());
    }
}