tensorlogic_train/ensemble.rs

//! Model ensembling utilities for combining multiple models.
//!
//! This module provides various ensemble strategies:
//! - Voting ensembles (hard and soft voting)
//! - Averaging ensembles (simple and weighted)
//! - Stacking ensembles (meta-learner)
//! - Bagging utilities
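//!
//! # Example
//!
//! A minimal sketch of soft voting (not compiled as a doctest); it reuses
//! the `LinearModel` from this crate's tests, and the exact import paths
//! may differ in your setup:
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//!
//! // Two independently trained base models with 2 inputs and 2 outputs.
//! let ensemble = VotingEnsemble::new(
//!     vec![LinearModel::new(2, 2), LinearModel::new(2, 2)],
//!     VotingMode::Soft,
//! )?;
//!
//! // Averaged class probabilities, shape [batch_size, num_classes].
//! let probs = ensemble.predict(&array![[1.0, 0.0]])?;
//! assert_eq!(probs.shape(), &[1, 2]);
//! ```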

use crate::{Model, TrainError, TrainResult};
use scirs2_core::ndarray::Array2;

/// Trait for ensemble methods.
pub trait Ensemble {
    /// Predict using the ensemble.
    ///
    /// # Arguments
    /// * `input` - Input data [batch_size, features]
    ///
    /// # Returns
    /// Ensemble predictions [batch_size, num_outputs] (for classifiers,
    /// one column per class)
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;

    /// Get the number of models in the ensemble.
    fn num_models(&self) -> usize;
}

/// Voting strategy for a [`VotingEnsemble`]:
/// - Hard voting: majority vote (the class with the most votes wins)
/// - Soft voting: average the predicted class probabilities
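///
/// A small worked example of the difference, assuming two models and two
/// classes: given per-model probabilities `[0.6, 0.4]` and `[0.1, 0.9]`,
/// soft voting averages to `[0.35, 0.65]` and picks class 1, while hard
/// voting counts one argmax vote per model (one for class 0, one for
/// class 1), and the tie is broken toward the lower class index, so
/// class 0 wins.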
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VotingMode {
    /// Hard voting (majority vote).
    Hard,
    /// Soft voting (average probabilities).
    Soft,
}

/// Voting ensemble for classification.
///
/// Combines predictions from multiple base models by hard or soft voting,
/// optionally weighted per model.
#[derive(Debug)]
pub struct VotingEnsemble<M: Model> {
    /// Base models in the ensemble.
    models: Vec<M>,
    /// Voting mode (hard or soft).
    mode: VotingMode,
    /// Model weights (for weighted voting).
    weights: Option<Vec<f64>>,
}

impl<M: Model> VotingEnsemble<M> {
    /// Create a new voting ensemble.
    ///
    /// # Arguments
    /// * `models` - Base models to ensemble
    /// * `mode` - Voting mode (hard or soft)
    pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
        if models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one model".to_string(),
            ));
        }
        Ok(Self {
            models,
            mode,
            weights: None,
        })
    }

    /// Set model weights for weighted voting.
    ///
    /// # Arguments
    /// * `weights` - Weight for each model (must sum to 1.0)
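    ///
    /// # Example
    ///
    /// A sketch (not compiled as a doctest); `models` stands for any
    /// `Vec` of this crate's models:
    ///
    /// ```ignore
    /// let ensemble = VotingEnsemble::new(models, VotingMode::Soft)?
    ///     .with_weights(vec![0.7, 0.3])?; // ok: sums to 1.0
    /// // .with_weights(vec![0.5, 0.6])    // rejected: sums to 1.1
    /// ```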
    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
        if weights.len() != self.models.len() {
            return Err(TrainError::InvalidParameter(
                "Number of weights must match number of models".to_string(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if (sum - 1.0).abs() > 1e-6 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to 1.0".to_string(),
            ));
        }

        self.weights = Some(weights);
        Ok(self)
    }

    /// Get voting mode.
    pub fn mode(&self) -> VotingMode {
        self.mode
    }
}

impl<M: Model> Ensemble for VotingEnsemble<M> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        // Collect predictions from all models
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Get output shape from first prediction
        let num_classes = all_predictions[0].ncols();
        let mut ensemble_pred = Array2::zeros((batch_size, num_classes));

        match self.mode {
            VotingMode::Hard => {
                // Hard voting: count (optionally weighted) votes per class
                for i in 0..batch_size {
                    let mut votes = vec![0.0; num_classes];

                    for (model_idx, pred) in all_predictions.iter().enumerate() {
                        // Get predicted class (argmax); treat incomparable
                        // (NaN) values as equal rather than panicking
                        let row = pred.row(i);
                        let class_idx = row
                            .iter()
                            .enumerate()
                            .max_by(|(_, a), (_, b)| {
                                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(idx, _)| idx)
                            .unwrap_or(0);

                        let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                        votes[class_idx] += weight;
                    }

                    // Convert votes to a one-hot prediction; ties go to the
                    // lowest class index
                    let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let winning_class = votes
                        .iter()
                        .position(|&v| (v - max_votes).abs() < 1e-10)
                        .unwrap_or(0);

                    ensemble_pred[[i, winning_class]] = 1.0;
                }
            }
            VotingMode::Soft => {
                // Soft voting: (weighted) average of predicted probabilities.
                // Explicit weights already sum to 1.0 (enforced in
                // `with_weights`), so only uniform weights need rescaling.
                let normalizer = if self.weights.is_some() {
                    1.0
                } else {
                    self.models.len() as f64
                };

                for i in 0..batch_size {
                    for j in 0..num_classes {
                        let mut weighted_sum = 0.0;

                        for (model_idx, pred) in all_predictions.iter().enumerate() {
                            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                            weighted_sum += pred[[i, j]] * weight;
                        }

                        ensemble_pred[[i, j]] = weighted_sum / normalizer;
                    }
                }
            }
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}

/// Averaging ensemble for regression.
///
/// Combines predictions by averaging (simple or weighted).
#[derive(Debug)]
pub struct AveragingEnsemble<M: Model> {
    /// Base models in the ensemble.
    models: Vec<M>,
    /// Model weights (for weighted averaging).
    weights: Option<Vec<f64>>,
}

impl<M: Model> AveragingEnsemble<M> {
    /// Create a new averaging ensemble.
    ///
    /// # Arguments
    /// * `models` - Base models to ensemble
    pub fn new(models: Vec<M>) -> TrainResult<Self> {
        if models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one model".to_string(),
            ));
        }
        Ok(Self {
            models,
            weights: None,
        })
    }

    /// Set model weights for weighted averaging.
    ///
    /// # Arguments
    /// * `weights` - Weight for each model
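    ///
    /// Unlike [`VotingEnsemble::with_weights`], these weights do not need
    /// to sum to 1.0; they are normalized internally. A sketch (not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// // Effective weights become 2/3 and 1/3 after normalization.
    /// let ensemble = AveragingEnsemble::new(models)?
    ///     .with_weights(vec![2.0, 1.0])?;
    /// ```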
    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
        if weights.len() != self.models.len() {
            return Err(TrainError::InvalidParameter(
                "Number of weights must match number of models".to_string(),
            ));
        }

        // Normalize weights
        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to a positive value".to_string(),
            ));
        }

        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
        self.weights = Some(normalized_weights);
        Ok(self)
    }
}

impl<M: Model> Ensemble for AveragingEnsemble<M> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        // Collect predictions from all models
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Average predictions
        let shape = all_predictions[0].raw_dim();
        let mut ensemble_pred = Array2::zeros(shape);

        for (model_idx, pred) in all_predictions.iter().enumerate() {
            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);

            for i in 0..pred.nrows() {
                for j in 0..pred.ncols() {
                    ensemble_pred[[i, j]] += pred[[i, j]] * weight;
                }
            }
        }

        // Divide by the model count only for uniform weights; explicit
        // weights are already normalized in `with_weights`
        if self.weights.is_none() {
            ensemble_pred /= self.models.len() as f64;
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}

/// Stacking ensemble with a meta-learner.
///
/// Uses base models' predictions as features for a meta-model.
#[derive(Debug)]
pub struct StackingEnsemble<M: Model, Meta: Model> {
    /// Base models (first level).
    base_models: Vec<M>,
    /// Meta-model (second level).
    meta_model: Meta,
}

impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
    /// Create a new stacking ensemble.
    ///
    /// # Arguments
    /// * `base_models` - First-level base models
    /// * `meta_model` - Second-level meta-learner
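    ///
    /// # Example
    ///
    /// A sketch of the dimension contract (not compiled as a doctest),
    /// using the `LinearModel` from this crate's tests: the meta-model's
    /// input width must equal `num_base_models * num_outputs_per_base_model`:
    ///
    /// ```ignore
    /// let bases = vec![LinearModel::new(2, 2), LinearModel::new(2, 2)];
    /// let meta = LinearModel::new(4, 2); // 2 base models × 2 outputs each
    /// let stack = StackingEnsemble::new(bases, meta)?;
    /// ```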
    pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
        if base_models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one base model".to_string(),
            ));
        }
        Ok(Self {
            base_models,
            meta_model,
        })
    }

    /// Generate meta-features from base model predictions.
    ///
    /// # Arguments
    /// * `input` - Input data
    ///
    /// # Returns
    /// Meta-features [batch_size, num_base_models * num_classes]
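    ///
    /// Columns are laid out model by model: model 0's outputs occupy the
    /// first `num_classes` columns, model 1's the next `num_classes`, and
    /// so on.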
    pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        // Collect predictions from all base models
        let mut all_predictions = Vec::with_capacity(self.base_models.len());
        for model in &self.base_models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Concatenate predictions horizontally to form meta-features
        let num_features_per_model = all_predictions[0].ncols();
        let total_features = self.base_models.len() * num_features_per_model;

        let mut meta_features = Array2::zeros((batch_size, total_features));

        for (model_idx, pred) in all_predictions.iter().enumerate() {
            let start_col = model_idx * num_features_per_model;

            for i in 0..batch_size {
                for j in 0..num_features_per_model {
                    meta_features[[i, start_col + j]] = pred[[i, j]];
                }
            }
        }

        Ok(meta_features)
    }
}

impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        // Generate meta-features from base models
        let meta_features = self.generate_meta_features(input)?;

        // Make final prediction with meta-model
        self.meta_model.forward(&meta_features.view())
    }

    fn num_models(&self) -> usize {
        self.base_models.len() + 1 // base models + meta model
    }
}

/// Bagging (Bootstrap Aggregating) utilities.
///
/// Generates bootstrap samples for training ensemble members.
#[derive(Debug)]
pub struct BaggingHelper {
    /// Number of bootstrap samples.
    pub n_estimators: usize,
    /// Random seed for reproducibility.
    pub random_seed: u64,
}

impl BaggingHelper {
    /// Create a new bagging helper.
    ///
    /// # Arguments
    /// * `n_estimators` - Number of bootstrap samples
    /// * `random_seed` - Random seed
    pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
        if n_estimators == 0 {
            return Err(TrainError::InvalidParameter(
                "n_estimators must be positive".to_string(),
            ));
        }
        Ok(Self {
            n_estimators,
            random_seed,
        })
    }

    /// Generate bootstrap sample indices.
    ///
    /// # Arguments
    /// * `n_samples` - Total number of samples
    /// * `estimator_idx` - Index of the estimator (for seeding)
    ///
    /// # Returns
    /// Bootstrap sample indices (with replacement)
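    ///
    /// # Example
    ///
    /// A sketch (not compiled as a doctest) of pairing each bootstrap
    /// sample with its out-of-bag complement:
    ///
    /// ```ignore
    /// let helper = BaggingHelper::new(10, 42)?;
    /// for est in 0..helper.n_estimators {
    ///     let boot = helper.generate_bootstrap_indices(100, est);
    ///     let oob = helper.get_oob_indices(100, &boot);
    ///     // Train estimator `est` on `boot`; evaluate it on `oob`.
    /// }
    /// ```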
    pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
        #[allow(unused_imports)]
        use scirs2_core::random::{Rng, SeedableRng, StdRng};

        // Derive a distinct, reproducible seed per estimator
        let seed = self.random_seed.wrapping_add(estimator_idx as u64);
        let mut rng = StdRng::seed_from_u64(seed);

        (0..n_samples)
            .map(|_| rng.gen_range(0..n_samples))
            .collect()
    }

    /// Get out-of-bag (OOB) indices for an estimator.
    ///
    /// # Arguments
    /// * `n_samples` - Total number of samples
    /// * `bootstrap_indices` - Bootstrap sample indices
    ///
    /// # Returns
    /// OOB sample indices (not in bootstrap sample)
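    ///
    /// For a bootstrap sample of size `n` drawn from `n` points, each
    /// sample is out-of-bag with probability `(1 - 1/n)^n`, which
    /// approaches `1/e ≈ 36.8%` for large `n`.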
    pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
        let bootstrap_set: std::collections::HashSet<usize> =
            bootstrap_indices.iter().cloned().collect();

        (0..n_samples)
            .filter(|idx| !bootstrap_set.contains(idx))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearModel;
    use scirs2_core::ndarray::array;

    fn create_test_model() -> LinearModel {
        // Create a 2-input, 2-output linear model
        LinearModel::new(2, 2)
    }

    #[test]
    fn test_voting_ensemble_hard() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).unwrap();

        assert_eq!(ensemble.num_models(), 2);
        assert_eq!(ensemble.mode(), VotingMode::Hard);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_voting_ensemble_soft() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).unwrap();

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft)
            .unwrap()
            .with_weights(vec![0.7, 0.3])
            .unwrap();

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_invalid_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).unwrap();

        // Wrong number of weights
        let result = ensemble.with_weights(vec![0.5]);
        assert!(result.is_err());

        // Weights don't sum to 1.0
        let model3 = create_test_model();
        let model4 = create_test_model();
        let ensemble2 = VotingEnsemble::new(vec![model3, model4], VotingMode::Soft).unwrap();
        let result = ensemble2.with_weights(vec![0.5, 0.6]);
        assert!(result.is_err());
    }

    #[test]
    fn test_averaging_ensemble() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2]).unwrap();

        assert_eq!(ensemble.num_models(), 2);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_averaging_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2])
            .unwrap()
            .with_weights(vec![2.0, 1.0])
            .unwrap();

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_stacking_ensemble() {
        let base1 = create_test_model(); // 2 inputs, 2 outputs
        let base2 = create_test_model(); // 2 inputs, 2 outputs
        let meta = LinearModel::new(4, 2); // 4 inputs (2 base models × 2 outputs), 2 outputs

        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).unwrap();

        assert_eq!(ensemble.num_models(), 3); // 2 base + 1 meta

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        // Meta-model takes concatenated predictions as input
        assert_eq!(pred.nrows(), 1);
    }

    #[test]
    fn test_stacking_meta_features() {
        let base1 = create_test_model();
        let base2 = create_test_model();
        // Only `generate_meta_features` is exercised here, so the
        // meta-model's input width does not need to match
        let meta = create_test_model();

        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).unwrap();

        let input = array![[1.0, 0.0]];
        let meta_features = ensemble.generate_meta_features(&input).unwrap();

        // Should concatenate predictions from 2 base models
        // Each base model outputs 2 features, so total = 2 * 2 = 4
        assert_eq!(meta_features.shape(), &[1, 4]);
    }

    #[test]
    fn test_bagging_helper() {
        let helper = BaggingHelper::new(10, 42).unwrap();

        let indices = helper.generate_bootstrap_indices(100, 0);
        assert_eq!(indices.len(), 100);

        // All indices should be in valid range
        assert!(indices.iter().all(|&i| i < 100));

        // OOB indices should be the complement
        let oob = helper.get_oob_indices(100, &indices);
        assert!(!oob.is_empty());

        for &idx in &oob {
            assert!(!indices.contains(&idx));
        }
    }

    #[test]
    fn test_bagging_helper_different_seeds() {
        let helper = BaggingHelper::new(10, 42).unwrap();

        let indices1 = helper.generate_bootstrap_indices(50, 0);
        let indices2 = helper.generate_bootstrap_indices(50, 1);

        // Different estimator indices derive different seeds, so the
        // samples should differ (with overwhelming probability)
        assert_ne!(indices1, indices2);
    }

    #[test]
    fn test_bagging_helper_invalid() {
        assert!(BaggingHelper::new(0, 42).is_err());
    }

    #[test]
    fn test_ensemble_empty_models() {
        let result = VotingEnsemble::<LinearModel>::new(vec![], VotingMode::Hard);
        assert!(result.is_err());

        let result = AveragingEnsemble::<LinearModel>::new(vec![]);
        assert!(result.is_err());
    }
}