1use crate::{Model, TrainError, TrainResult};
10use scirs2_core::ndarray::Array2;
11
12pub trait Ensemble {
14 fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;
22
23 fn num_models(&self) -> usize;
25}
26
/// Strategy used by a voting ensemble to combine member predictions.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the mode can be
/// used as a map key or set member (the enum has no fields, so equality is
/// total).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VotingMode {
    /// Each model votes for its highest-scoring class and the class with the
    /// largest (optionally weighted) vote total wins.
    Hard,
    /// Member outputs are combined element-wise as an (optionally weighted)
    /// average.
    Soft,
}
39
40#[derive(Debug)]
42pub struct VotingEnsemble<M: Model> {
43 models: Vec<M>,
45 mode: VotingMode,
47 weights: Option<Vec<f64>>,
49}
50
51impl<M: Model> VotingEnsemble<M> {
52 pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
58 if models.is_empty() {
59 return Err(TrainError::InvalidParameter(
60 "Ensemble must have at least one model".to_string(),
61 ));
62 }
63 Ok(Self {
64 models,
65 mode,
66 weights: None,
67 })
68 }
69
70 pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
75 if weights.len() != self.models.len() {
76 return Err(TrainError::InvalidParameter(
77 "Number of weights must match number of models".to_string(),
78 ));
79 }
80
81 let sum: f64 = weights.iter().sum();
82 if (sum - 1.0).abs() > 1e-6 {
83 return Err(TrainError::InvalidParameter(
84 "Weights must sum to 1.0".to_string(),
85 ));
86 }
87
88 self.weights = Some(weights);
89 Ok(self)
90 }
91
92 pub fn mode(&self) -> VotingMode {
94 self.mode
95 }
96}
97
98impl<M: Model> Ensemble for VotingEnsemble<M> {
99 fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
100 let batch_size = input.nrows();
101
102 let mut all_predictions = Vec::with_capacity(self.models.len());
104 for model in &self.models {
105 let pred = model.forward(&input.view())?;
106 all_predictions.push(pred);
107 }
108
109 let num_classes = all_predictions[0].ncols();
111 let mut ensemble_pred = Array2::zeros((batch_size, num_classes));
112
113 match self.mode {
114 VotingMode::Hard => {
115 for i in 0..batch_size {
117 let mut votes = vec![0.0; num_classes];
118
119 for (model_idx, pred) in all_predictions.iter().enumerate() {
120 let row = pred.row(i);
122 let class_idx = row
123 .iter()
124 .enumerate()
125 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
126 .map(|(idx, _)| idx)
127 .unwrap_or(0);
128
129 let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
130 votes[class_idx] += weight;
131 }
132
133 let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
135 let winning_class = votes
136 .iter()
137 .position(|&v| (v - max_votes).abs() < 1e-10)
138 .unwrap();
139
140 ensemble_pred[[i, winning_class]] = 1.0;
141 }
142 }
143 VotingMode::Soft => {
144 for i in 0..batch_size {
146 for j in 0..num_classes {
147 let mut weighted_sum = 0.0;
148
149 for (model_idx, pred) in all_predictions.iter().enumerate() {
150 let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
151 weighted_sum += pred[[i, j]] * weight;
152 }
153
154 let normalizer = if self.weights.is_some() {
155 1.0 } else {
157 self.models.len() as f64
158 };
159
160 ensemble_pred[[i, j]] = weighted_sum / normalizer;
161 }
162 }
163 }
164 }
165
166 Ok(ensemble_pred)
167 }
168
169 fn num_models(&self) -> usize {
170 self.models.len()
171 }
172}
173
174#[derive(Debug)]
178pub struct AveragingEnsemble<M: Model> {
179 models: Vec<M>,
181 weights: Option<Vec<f64>>,
183}
184
185impl<M: Model> AveragingEnsemble<M> {
186 pub fn new(models: Vec<M>) -> TrainResult<Self> {
191 if models.is_empty() {
192 return Err(TrainError::InvalidParameter(
193 "Ensemble must have at least one model".to_string(),
194 ));
195 }
196 Ok(Self {
197 models,
198 weights: None,
199 })
200 }
201
202 pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
207 if weights.len() != self.models.len() {
208 return Err(TrainError::InvalidParameter(
209 "Number of weights must match number of models".to_string(),
210 ));
211 }
212
213 let sum: f64 = weights.iter().sum();
215 if sum <= 0.0 {
216 return Err(TrainError::InvalidParameter(
217 "Weights must sum to a positive value".to_string(),
218 ));
219 }
220
221 let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
222 self.weights = Some(normalized_weights);
223 Ok(self)
224 }
225}
226
227impl<M: Model> Ensemble for AveragingEnsemble<M> {
228 fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
229 let mut all_predictions = Vec::with_capacity(self.models.len());
231 for model in &self.models {
232 let pred = model.forward(&input.view())?;
233 all_predictions.push(pred);
234 }
235
236 let shape = all_predictions[0].raw_dim();
238 let mut ensemble_pred = Array2::zeros(shape);
239
240 for (model_idx, pred) in all_predictions.iter().enumerate() {
241 let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
242
243 for i in 0..pred.nrows() {
244 for j in 0..pred.ncols() {
245 ensemble_pred[[i, j]] += pred[[i, j]] * weight;
246 }
247 }
248 }
249
250 if self.weights.is_none() {
252 ensemble_pred /= self.models.len() as f64;
253 }
254
255 Ok(ensemble_pred)
256 }
257
258 fn num_models(&self) -> usize {
259 self.models.len()
260 }
261}
262
263#[derive(Debug)]
267pub struct StackingEnsemble<M: Model, Meta: Model> {
268 base_models: Vec<M>,
270 meta_model: Meta,
272}
273
274impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
275 pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
281 if base_models.is_empty() {
282 return Err(TrainError::InvalidParameter(
283 "Ensemble must have at least one base model".to_string(),
284 ));
285 }
286 Ok(Self {
287 base_models,
288 meta_model,
289 })
290 }
291
292 pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
300 let batch_size = input.nrows();
301
302 let mut all_predictions = Vec::with_capacity(self.base_models.len());
304 for model in &self.base_models {
305 let pred = model.forward(&input.view())?;
306 all_predictions.push(pred);
307 }
308
309 let num_features_per_model = all_predictions[0].ncols();
311 let total_features = self.base_models.len() * num_features_per_model;
312
313 let mut meta_features = Array2::zeros((batch_size, total_features));
314
315 for (model_idx, pred) in all_predictions.iter().enumerate() {
316 let start_col = model_idx * num_features_per_model;
317
318 for i in 0..batch_size {
319 for j in 0..num_features_per_model {
320 meta_features[[i, start_col + j]] = pred[[i, j]];
321 }
322 }
323 }
324
325 Ok(meta_features)
326 }
327}
328
329impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
330 fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
331 let meta_features = self.generate_meta_features(input)?;
333
334 self.meta_model.forward(&meta_features.view())
336 }
337
338 fn num_models(&self) -> usize {
339 self.base_models.len() + 1 }
341}
342
/// Settings for drawing bootstrap samples when bagging.
#[derive(Debug)]
pub struct BaggingHelper {
    /// Number of estimators the helper is configured for.
    pub n_estimators: usize,
    /// Base seed; each estimator's sampling seed is derived from it.
    pub random_seed: u64,
}
353
354impl BaggingHelper {
355 pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
361 if n_estimators == 0 {
362 return Err(TrainError::InvalidParameter(
363 "n_estimators must be positive".to_string(),
364 ));
365 }
366 Ok(Self {
367 n_estimators,
368 random_seed,
369 })
370 }
371
372 pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
381 #[allow(unused_imports)]
382 use scirs2_core::random::{Rng, SeedableRng, StdRng};
383
384 let seed = self.random_seed.wrapping_add(estimator_idx as u64);
385 let mut rng = StdRng::seed_from_u64(seed);
386
387 (0..n_samples)
388 .map(|_| rng.gen_range(0..n_samples))
389 .collect()
390 }
391
392 pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
401 let bootstrap_set: std::collections::HashSet<usize> =
402 bootstrap_indices.iter().cloned().collect();
403
404 (0..n_samples)
405 .filter(|idx| !bootstrap_set.contains(idx))
406 .collect()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearModel;
    use scirs2_core::ndarray::array;

    /// Builds `LinearModel::new(2, 2)`, the member model shared by these tests.
    fn create_test_model() -> LinearModel {
        LinearModel::new(2, 2)
    }

    /// Two freshly built test models, ready to hand to an ensemble constructor.
    fn model_pair() -> Vec<LinearModel> {
        vec![create_test_model(), create_test_model()]
    }

    /// Hard voting keeps the model count, the mode and the batch shape.
    #[test]
    fn test_voting_ensemble_hard() {
        let ensemble = VotingEnsemble::new(model_pair(), VotingMode::Hard).unwrap();

        assert_eq!(ensemble.num_models(), 2);
        assert_eq!(ensemble.mode(), VotingMode::Hard);

        let batch = array![[1.0, 0.0], [0.0, 1.0]];
        let output = ensemble.predict(&batch).unwrap();

        assert_eq!(output.shape(), &[2, 2]);
    }

    /// Soft voting returns one row per input row.
    #[test]
    fn test_voting_ensemble_soft() {
        let ensemble = VotingEnsemble::new(model_pair(), VotingMode::Soft).unwrap();

        let batch = array![[1.0, 0.0]];
        let output = ensemble.predict(&batch).unwrap();

        assert_eq!(output.shape(), &[1, 2]);
    }

    /// Weights that sum to one are accepted and prediction still works.
    #[test]
    fn test_voting_ensemble_with_weights() {
        let unweighted = VotingEnsemble::new(model_pair(), VotingMode::Soft).unwrap();
        let ensemble = unweighted.with_weights(vec![0.7, 0.3]).unwrap();

        let batch = array![[1.0, 0.0]];
        let output = ensemble.predict(&batch).unwrap();

        assert_eq!(output.shape(), &[1, 2]);
    }

    /// A wrong weight count and a wrong weight sum are both rejected.
    #[test]
    fn test_voting_ensemble_invalid_weights() {
        let too_few = VotingEnsemble::new(model_pair(), VotingMode::Soft)
            .unwrap()
            .with_weights(vec![0.5]);
        assert!(too_few.is_err());

        let bad_sum = VotingEnsemble::new(model_pair(), VotingMode::Soft)
            .unwrap()
            .with_weights(vec![0.5, 0.6]);
        assert!(bad_sum.is_err());
    }

    /// Plain averaging keeps the model count and the batch shape.
    #[test]
    fn test_averaging_ensemble() {
        let ensemble = AveragingEnsemble::new(model_pair()).unwrap();

        assert_eq!(ensemble.num_models(), 2);

        let batch = array![[1.0, 0.0], [0.0, 1.0]];
        let output = ensemble.predict(&batch).unwrap();

        assert_eq!(output.shape(), &[2, 2]);
    }

    /// Unnormalized weights are accepted by the averaging ensemble.
    #[test]
    fn test_averaging_ensemble_with_weights() {
        let unweighted = AveragingEnsemble::new(model_pair()).unwrap();
        let ensemble = unweighted.with_weights(vec![2.0, 1.0]).unwrap();

        let batch = array![[1.0, 0.0]];
        let output = ensemble.predict(&batch).unwrap();

        assert_eq!(output.shape(), &[1, 2]);
    }

    /// Stacking counts the meta model and predicts one row per input row.
    #[test]
    fn test_stacking_ensemble() {
        let bases = model_pair();
        let meta = LinearModel::new(4, 2);
        let ensemble = StackingEnsemble::new(bases, meta).unwrap();

        assert_eq!(ensemble.num_models(), 3);

        let batch = array![[1.0, 0.0]];
        let output = ensemble.predict(&batch).unwrap();

        assert_eq!(output.nrows(), 1);
    }

    /// Meta features concatenate both base outputs column-wise.
    #[test]
    fn test_stacking_meta_features() {
        let bases = model_pair();
        let meta = create_test_model();
        let ensemble = StackingEnsemble::new(bases, meta).unwrap();

        let batch = array![[1.0, 0.0]];
        let stacked = ensemble.generate_meta_features(&batch).unwrap();

        assert_eq!(stacked.shape(), &[1, 4]);
    }

    /// Bootstrap indices stay in range and out-of-bag indices are disjoint
    /// from them.
    #[test]
    fn test_bagging_helper() {
        let helper = BaggingHelper::new(10, 42).unwrap();

        let sampled = helper.generate_bootstrap_indices(100, 0);
        assert_eq!(sampled.len(), 100);
        assert!(sampled.iter().all(|&i| i < 100));

        let oob = helper.get_oob_indices(100, &sampled);
        assert!(!oob.is_empty());
        assert!(oob.iter().all(|idx| !sampled.contains(idx)));
    }

    /// Different estimator indices give different bootstrap samples.
    #[test]
    fn test_bagging_helper_different_seeds() {
        let helper = BaggingHelper::new(10, 42).unwrap();

        let first = helper.generate_bootstrap_indices(50, 0);
        let second = helper.generate_bootstrap_indices(50, 1);

        assert_ne!(first, second);
    }

    /// Zero estimators is rejected.
    #[test]
    fn test_bagging_helper_invalid() {
        assert!(BaggingHelper::new(0, 42).is_err());
    }

    /// Ensembles cannot be built from an empty model list.
    #[test]
    fn test_ensemble_empty_models() {
        let voting = VotingEnsemble::<LinearModel>::new(Vec::new(), VotingMode::Hard);
        assert!(voting.is_err());

        let averaging = AveragingEnsemble::<LinearModel>::new(Vec::new());
        assert!(averaging.is_err());
    }
}