// tensorlogic_train/ensemble.rs
//! Model ensembling utilities for combining multiple models.
//!
//! This module provides various ensemble strategies:
//! - Voting ensembles (hard and soft voting)
//! - Averaging ensembles (simple and weighted)
//! - Stacking ensembles (meta-learner)
//! - Bagging utilities
//! - Model soups (weight-space averaging)

10use crate::{Model, TrainError, TrainResult};
11use scirs2_core::ndarray::Array2;
12use std::collections::HashMap;
13
/// Trait for ensemble methods.
///
/// Implemented by every ensemble strategy in this module (voting,
/// averaging, stacking) so callers can combine models behind one interface.
pub trait Ensemble {
    /// Predict using the ensemble.
    ///
    /// # Arguments
    /// * `input` - Input data [batch_size, features]
    ///
    /// # Returns
    /// Ensemble predictions [batch_size, num_classes]
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;

    /// Get the number of models in the ensemble.
    fn num_models(&self) -> usize;
}
28
/// Voting ensemble for classification.
///
/// Combines predictions from multiple models using voting:
/// - Hard voting: Majority vote (class with most votes wins)
/// - Soft voting: Average predicted probabilities
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VotingMode {
    /// Hard voting (majority vote).
    Hard,
    /// Soft voting (average probabilities).
    Soft,
}
41
/// Voting ensemble configuration.
///
/// Construct with [`VotingEnsemble::new`] (validates that at least one
/// model is supplied); optional per-model weights are set via
/// [`VotingEnsemble::with_weights`].
#[derive(Debug)]
pub struct VotingEnsemble<M: Model> {
    /// Base models in the ensemble.
    models: Vec<M>,
    /// Voting mode (hard or soft).
    mode: VotingMode,
    /// Model weights (for weighted voting). `None` means uniform weighting.
    weights: Option<Vec<f64>>,
}
52
53impl<M: Model> VotingEnsemble<M> {
54    /// Create a new voting ensemble.
55    ///
56    /// # Arguments
57    /// * `models` - Base models to ensemble
58    /// * `mode` - Voting mode (hard or soft)
59    pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
60        if models.is_empty() {
61            return Err(TrainError::InvalidParameter(
62                "Ensemble must have at least one model".to_string(),
63            ));
64        }
65        Ok(Self {
66            models,
67            mode,
68            weights: None,
69        })
70    }
71
72    /// Set model weights for weighted voting.
73    ///
74    /// # Arguments
75    /// * `weights` - Weight for each model (must sum to 1.0)
76    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
77        if weights.len() != self.models.len() {
78            return Err(TrainError::InvalidParameter(
79                "Number of weights must match number of models".to_string(),
80            ));
81        }
82
83        let sum: f64 = weights.iter().sum();
84        if (sum - 1.0).abs() > 1e-6 {
85            return Err(TrainError::InvalidParameter(
86                "Weights must sum to 1.0".to_string(),
87            ));
88        }
89
90        self.weights = Some(weights);
91        Ok(self)
92    }
93
94    /// Get voting mode.
95    pub fn mode(&self) -> VotingMode {
96        self.mode
97    }
98}
99
impl<M: Model> Ensemble for VotingEnsemble<M> {
    /// Combine base-model outputs according to `self.mode`.
    ///
    /// Hard voting produces a one-hot row per sample (1.0 at the winning
    /// class, 0.0 elsewhere); soft voting produces the (weighted) mean of
    /// the models' output rows.
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        // Collect predictions from all models; any model failure aborts the
        // whole ensemble prediction via `?`.
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Get output shape from first prediction. All models are assumed to
        // emit the same number of columns — TODO confirm upstream.
        let num_classes = all_predictions[0].ncols();
        let mut ensemble_pred = Array2::zeros((batch_size, num_classes));

        match self.mode {
            VotingMode::Hard => {
                // Hard voting: count (possibly weighted) votes for each class
                for i in 0..batch_size {
                    let mut votes = vec![0.0; num_classes];

                    for (model_idx, pred) in all_predictions.iter().enumerate() {
                        // Get predicted class (argmax). NOTE(review):
                        // `partial_cmp(..).unwrap()` panics if a model emits NaN.
                        let row = pred.row(i);
                        let class_idx = row
                            .iter()
                            .enumerate()
                            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                            .map(|(idx, _)| idx)
                            .unwrap_or(0);

                        // Unweighted ensembles count each model's vote as 1.0.
                        let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                        votes[class_idx] += weight;
                    }

                    // Convert votes to one-hot prediction. Ties break toward
                    // the lowest class index (first position matching the max);
                    // `unwrap` is safe because `max_votes` came from `votes`.
                    let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let winning_class = votes
                        .iter()
                        .position(|&v| (v - max_votes).abs() < 1e-10)
                        .unwrap();

                    ensemble_pred[[i, winning_class]] = 1.0;
                }
            }
            VotingMode::Soft => {
                // Soft voting: average probabilities element-wise
                for i in 0..batch_size {
                    for j in 0..num_classes {
                        let mut weighted_sum = 0.0;

                        for (model_idx, pred) in all_predictions.iter().enumerate() {
                            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                            weighted_sum += pred[[i, j]] * weight;
                        }

                        // Explicit weights are validated to sum to 1.0 in
                        // `with_weights`, so no further division is needed then.
                        let normalizer = if self.weights.is_some() {
                            1.0 // Weights already sum to 1.0
                        } else {
                            self.models.len() as f64
                        };

                        ensemble_pred[[i, j]] = weighted_sum / normalizer;
                    }
                }
            }
        }

        Ok(ensemble_pred)
    }

    /// Number of base models combined by this ensemble.
    fn num_models(&self) -> usize {
        self.models.len()
    }
}
175
/// Averaging ensemble for regression.
///
/// Combines predictions by averaging (simple or weighted).
/// Construct with [`AveragingEnsemble::new`]; optional weights are set via
/// [`AveragingEnsemble::with_weights`] and are stored pre-normalized.
#[derive(Debug)]
pub struct AveragingEnsemble<M: Model> {
    /// Base models in the ensemble.
    models: Vec<M>,
    /// Model weights (for weighted averaging). `None` means a simple mean;
    /// when present, the weights are already normalized to sum to 1.0.
    weights: Option<Vec<f64>>,
}
186
187impl<M: Model> AveragingEnsemble<M> {
188    /// Create a new averaging ensemble.
189    ///
190    /// # Arguments
191    /// * `models` - Base models to ensemble
192    pub fn new(models: Vec<M>) -> TrainResult<Self> {
193        if models.is_empty() {
194            return Err(TrainError::InvalidParameter(
195                "Ensemble must have at least one model".to_string(),
196            ));
197        }
198        Ok(Self {
199            models,
200            weights: None,
201        })
202    }
203
204    /// Set model weights for weighted averaging.
205    ///
206    /// # Arguments
207    /// * `weights` - Weight for each model
208    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
209        if weights.len() != self.models.len() {
210            return Err(TrainError::InvalidParameter(
211                "Number of weights must match number of models".to_string(),
212            ));
213        }
214
215        // Normalize weights
216        let sum: f64 = weights.iter().sum();
217        if sum <= 0.0 {
218            return Err(TrainError::InvalidParameter(
219                "Weights must sum to a positive value".to_string(),
220            ));
221        }
222
223        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
224        self.weights = Some(normalized_weights);
225        Ok(self)
226    }
227}
228
229impl<M: Model> Ensemble for AveragingEnsemble<M> {
230    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
231        // Collect predictions from all models
232        let mut all_predictions = Vec::with_capacity(self.models.len());
233        for model in &self.models {
234            let pred = model.forward(&input.view())?;
235            all_predictions.push(pred);
236        }
237
238        // Average predictions
239        let shape = all_predictions[0].raw_dim();
240        let mut ensemble_pred = Array2::zeros(shape);
241
242        for (model_idx, pred) in all_predictions.iter().enumerate() {
243            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
244
245            for i in 0..pred.nrows() {
246                for j in 0..pred.ncols() {
247                    ensemble_pred[[i, j]] += pred[[i, j]] * weight;
248                }
249            }
250        }
251
252        // Normalize if using uniform weights
253        if self.weights.is_none() {
254            ensemble_pred /= self.models.len() as f64;
255        }
256
257        Ok(ensemble_pred)
258    }
259
260    fn num_models(&self) -> usize {
261        self.models.len()
262    }
263}
264
/// Stacking ensemble with a meta-learner.
///
/// Uses base models' predictions as features for a meta-model.
/// The meta-model's expected input width must equal
/// `num_base_models * num_classes` — see `generate_meta_features`.
#[derive(Debug)]
pub struct StackingEnsemble<M: Model, Meta: Model> {
    /// Base models (first level).
    base_models: Vec<M>,
    /// Meta-model (second level).
    meta_model: Meta,
}
275
276impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
277    /// Create a new stacking ensemble.
278    ///
279    /// # Arguments
280    /// * `base_models` - First-level base models
281    /// * `meta_model` - Second-level meta-learner
282    pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
283        if base_models.is_empty() {
284            return Err(TrainError::InvalidParameter(
285                "Ensemble must have at least one base model".to_string(),
286            ));
287        }
288        Ok(Self {
289            base_models,
290            meta_model,
291        })
292    }
293
294    /// Generate meta-features from base model predictions.
295    ///
296    /// # Arguments
297    /// * `input` - Input data
298    ///
299    /// # Returns
300    /// Meta-features [batch_size, num_base_models * num_classes]
301    pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
302        let batch_size = input.nrows();
303
304        // Collect predictions from all base models
305        let mut all_predictions = Vec::with_capacity(self.base_models.len());
306        for model in &self.base_models {
307            let pred = model.forward(&input.view())?;
308            all_predictions.push(pred);
309        }
310
311        // Concatenate predictions horizontally to form meta-features
312        let num_features_per_model = all_predictions[0].ncols();
313        let total_features = self.base_models.len() * num_features_per_model;
314
315        let mut meta_features = Array2::zeros((batch_size, total_features));
316
317        for (model_idx, pred) in all_predictions.iter().enumerate() {
318            let start_col = model_idx * num_features_per_model;
319
320            for i in 0..batch_size {
321                for j in 0..num_features_per_model {
322                    meta_features[[i, start_col + j]] = pred[[i, j]];
323                }
324            }
325        }
326
327        Ok(meta_features)
328    }
329}
330
331impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
332    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
333        // Generate meta-features from base models
334        let meta_features = self.generate_meta_features(input)?;
335
336        // Make final prediction with meta-model
337        self.meta_model.forward(&meta_features.view())
338    }
339
340    fn num_models(&self) -> usize {
341        self.base_models.len() + 1 // base models + meta model
342    }
343}
344
/// Bagging (Bootstrap Aggregating) utilities.
///
/// Generates bootstrap samples for training ensemble members.
/// Sampling is deterministic for a given `(random_seed, estimator_idx)`
/// pair — see `generate_bootstrap_indices`.
#[derive(Debug)]
pub struct BaggingHelper {
    /// Number of bootstrap samples.
    pub n_estimators: usize,
    /// Random seed for reproducibility.
    pub random_seed: u64,
}
355
356impl BaggingHelper {
357    /// Create a new bagging helper.
358    ///
359    /// # Arguments
360    /// * `n_estimators` - Number of bootstrap samples
361    /// * `random_seed` - Random seed
362    pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
363        if n_estimators == 0 {
364            return Err(TrainError::InvalidParameter(
365                "n_estimators must be positive".to_string(),
366            ));
367        }
368        Ok(Self {
369            n_estimators,
370            random_seed,
371        })
372    }
373
374    /// Generate bootstrap sample indices.
375    ///
376    /// # Arguments
377    /// * `n_samples` - Total number of samples
378    /// * `estimator_idx` - Index of the estimator (for seeding)
379    ///
380    /// # Returns
381    /// Bootstrap sample indices (with replacement)
382    pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
383        #[allow(unused_imports)]
384        use scirs2_core::random::{Rng, SeedableRng, StdRng};
385
386        let seed = self.random_seed.wrapping_add(estimator_idx as u64);
387        let mut rng = StdRng::seed_from_u64(seed);
388
389        (0..n_samples)
390            .map(|_| rng.gen_range(0..n_samples))
391            .collect()
392    }
393
394    /// Get out-of-bag (OOB) indices for an estimator.
395    ///
396    /// # Arguments
397    /// * `n_samples` - Total number of samples
398    /// * `bootstrap_indices` - Bootstrap sample indices
399    ///
400    /// # Returns
401    /// OOB sample indices (not in bootstrap sample)
402    pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
403        let bootstrap_set: std::collections::HashSet<usize> =
404            bootstrap_indices.iter().cloned().collect();
405
406        (0..n_samples)
407            .filter(|idx| !bootstrap_set.contains(idx))
408            .collect()
409    }
410}
411
/// Model Soup - Weight-space averaging for improved generalization.
///
/// From "Model soups: averaging weights of multiple fine-tuned models
/// improves accuracy without increasing inference time" (Wortsman et al., 2022).
///
/// Model soups average the *weights* of multiple models (not predictions), which can:
/// - Improve accuracy compared to individual models
/// - No inference cost (single model at test time)
/// - Work across different hyperparameters and random seeds
/// - Particularly effective for models fine-tuned from same initialization
///
/// Two main recipes:
/// - **Uniform Soup**: Simple average of all model weights
/// - **Greedy Soup**: Iteratively add models that improve validation performance
///
/// # Example
/// ```
/// use tensorlogic_train::{ModelSoup, SoupRecipe};
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::Array2;
///
/// // Collect weights from multiple fine-tuned models
/// // let model_weights = vec![weights1, weights2, weights3];
/// // let soup = ModelSoup::uniform_soup(model_weights);
/// // let averaged_weights = soup.weights();
/// ```
#[derive(Debug, Clone)]
pub struct ModelSoup {
    /// Averaged model weights, keyed by parameter name.
    weights: HashMap<String, Array2<f64>>,
    /// Number of models that contributed to the soup.
    num_models: usize,
    /// Recipe used to create the soup.
    recipe: SoupRecipe,
}
447
/// Recipe for creating model soups.
///
/// Records which construction strategy produced a [`ModelSoup`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SoupRecipe {
    /// Uniform averaging of all models
    Uniform,
    /// Greedy selection based on validation performance
    Greedy,
    /// Custom weighted averaging
    Weighted,
}
458
impl ModelSoup {
    /// Create a uniform soup by averaging all model weights equally.
    ///
    /// # Arguments
    /// * `model_weights` - Weights from multiple fine-tuned models
    ///
    /// # Returns
    /// Model soup with uniformly averaged weights
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` if `model_weights` is empty or
    /// if a parameter named in the first model is missing from a later model.
    ///
    /// # Example
    /// ```
    /// use tensorlogic_train::ModelSoup;
    /// use std::collections::HashMap;
    /// use scirs2_core::ndarray::array;
    ///
    /// let mut weights1 = HashMap::new();
    /// weights1.insert("w".to_string(), array![[1.0, 2.0]]);
    ///
    /// let mut weights2 = HashMap::new();
    /// weights2.insert("w".to_string(), array![[3.0, 4.0]]);
    ///
    /// let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
    /// // Averaged weights: [[2.0, 3.0]]
    /// ```
    pub fn uniform_soup(model_weights: Vec<HashMap<String, Array2<f64>>>) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        let num_models = model_weights.len();
        let mut averaged_weights = HashMap::new();

        // Get parameter names from first model. NOTE(review): parameters
        // present only in later models are silently ignored — only a
        // parameter missing from a *later* model is reported as an error.
        let param_names: Vec<String> = model_weights[0].keys().cloned().collect();

        // Average each parameter across all models
        for param_name in param_names {
            // Initialize with zeros, using the first model's shape; shapes
            // are assumed consistent across models — TODO confirm upstream.
            let shape = model_weights[0][&param_name].raw_dim();
            let mut averaged_param = Array2::zeros(shape);

            // Sum across all models
            for model_weight in &model_weights {
                if let Some(param) = model_weight.get(&param_name) {
                    averaged_param += param;
                } else {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' not found in all models",
                        param_name
                    )));
                }
            }

            // Divide by number of models
            averaged_param /= num_models as f64;
            averaged_weights.insert(param_name, averaged_param);
        }

        Ok(Self {
            weights: averaged_weights,
            num_models,
            recipe: SoupRecipe::Uniform,
        })
    }

    /// Create a greedy soup by iteratively adding models that improve validation performance.
    ///
    /// # Arguments
    /// * `model_weights` - Weights from multiple fine-tuned models
    /// * `val_accuracies` - Validation accuracy for each model
    ///
    /// # Returns
    /// Model soup with greedily selected and averaged weights
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` if `model_weights` is empty or
    /// its length differs from `val_accuracies`.
    ///
    /// # Algorithm
    /// 1. Start with best single model
    /// 2. Try adding each remaining model to soup
    /// 3. Keep additions that improve validation performance
    /// 4. Repeat until no improvement
    ///
    /// NOTE(review): step 3 estimates the candidate soup's accuracy as the
    /// mean of the current best and the candidate's accuracy, rather than
    /// re-evaluating the averaged weights on a validation set — a heuristic,
    /// not the paper's exact recipe.
    pub fn greedy_soup(
        model_weights: Vec<HashMap<String, Array2<f64>>>,
        val_accuracies: Vec<f64>,
    ) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        if model_weights.len() != val_accuracies.len() {
            return Err(TrainError::InvalidParameter(
                "Number of models must match number of validation accuracies".to_string(),
            ));
        }

        // Find best single model as starting point. NOTE(review):
        // `partial_cmp(..).unwrap()` panics if any accuracy is NaN.
        let best_idx = val_accuracies
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();

        let mut soup_indices = vec![best_idx];
        let mut best_accuracy = val_accuracies[best_idx];

        // Greedily add models that improve performance
        loop {
            let mut improved = false;
            let mut best_addition = None;
            let mut best_new_accuracy = best_accuracy;

            // Try adding each model not yet in soup
            for (idx, acc) in val_accuracies.iter().enumerate() {
                if soup_indices.contains(&idx) {
                    continue;
                }

                // Estimate accuracy if we add this model
                // (In practice, you'd evaluate on validation set, but we use provided accuracy)
                let potential_accuracy = (*acc + best_accuracy) / 2.0;

                if potential_accuracy > best_new_accuracy {
                    best_new_accuracy = potential_accuracy;
                    best_addition = Some(idx);
                    improved = true;
                }
            }

            // `improved == true` implies `best_addition` is `Some`, so the
            // inner `else` arm below is effectively unreachable; it exists
            // only as a defensive exit.
            if improved {
                if let Some(idx) = best_addition {
                    soup_indices.push(idx);
                    best_accuracy = best_new_accuracy;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        // Create soup from selected models
        let selected_weights: Vec<_> = soup_indices
            .iter()
            .map(|&idx| model_weights[idx].clone())
            .collect();

        // Reuse the uniform recipe for the actual averaging, then relabel.
        let mut soup = Self::uniform_soup(selected_weights)?;
        soup.recipe = SoupRecipe::Greedy;
        soup.num_models = soup_indices.len();

        Ok(soup)
    }

    /// Create a weighted soup with custom weights for each model.
    ///
    /// # Arguments
    /// * `model_weights` - Weights from multiple fine-tuned models
    /// * `weights` - Weight for each model (will be normalized to sum to 1)
    ///
    /// # Returns
    /// Model soup with weighted averaged parameters
    ///
    /// # Errors
    /// Returns `TrainError::InvalidParameter` if `model_weights` is empty,
    /// the counts differ, the weights do not sum to a positive value, or a
    /// parameter named in the first model is missing from a later model.
    pub fn weighted_soup(
        model_weights: Vec<HashMap<String, Array2<f64>>>,
        weights: Vec<f64>,
    ) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        if model_weights.len() != weights.len() {
            return Err(TrainError::InvalidParameter(
                "Number of models must match number of weights".to_string(),
            ));
        }

        // Normalize weights
        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to positive value".to_string(),
            ));
        }

        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();

        // Weighted average, keyed by the first model's parameter names (as
        // in `uniform_soup`, extra parameters in later models are ignored).
        let num_models = model_weights.len();
        let mut averaged_weights = HashMap::new();
        let param_names: Vec<String> = model_weights[0].keys().cloned().collect();

        for param_name in param_names {
            let shape = model_weights[0][&param_name].raw_dim();
            let mut averaged_param = Array2::zeros(shape);

            for (model_idx, model_weight) in model_weights.iter().enumerate() {
                if let Some(param) = model_weight.get(&param_name) {
                    averaged_param = averaged_param + param * normalized_weights[model_idx];
                } else {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' not found in all models",
                        param_name
                    )));
                }
            }

            averaged_weights.insert(param_name, averaged_param);
        }

        Ok(Self {
            weights: averaged_weights,
            num_models,
            recipe: SoupRecipe::Weighted,
        })
    }

    /// Get the averaged weights from the soup.
    pub fn weights(&self) -> &HashMap<String, Array2<f64>> {
        &self.weights
    }

    /// Get the number of models in the soup.
    pub fn num_models(&self) -> usize {
        self.num_models
    }

    /// Get the recipe used to create the soup.
    pub fn recipe(&self) -> SoupRecipe {
        self.recipe
    }

    /// Get a specific parameter by name.
    pub fn get_parameter(&self, name: &str) -> Option<&Array2<f64>> {
        self.weights.get(name)
    }

    /// Load weights into a model (consumes the soup).
    ///
    /// This is a convenience method that returns the weights for loading into a model.
    pub fn into_weights(self) -> HashMap<String, Array2<f64>> {
        self.weights
    }
}
706
707#[cfg(test)]
708mod tests {
709    use super::*;
710    use crate::LinearModel;
711    use scirs2_core::ndarray::array;
712
713    fn create_test_model() -> LinearModel {
714        // Create a 2-input, 2-output linear model
715        LinearModel::new(2, 2)
716    }
717
718    #[test]
719    fn test_voting_ensemble_hard() {
720        let model1 = create_test_model();
721        let model2 = create_test_model();
722
723        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).unwrap();
724
725        assert_eq!(ensemble.num_models(), 2);
726        assert_eq!(ensemble.mode(), VotingMode::Hard);
727
728        let input = array![[1.0, 0.0], [0.0, 1.0]];
729        let pred = ensemble.predict(&input).unwrap();
730
731        assert_eq!(pred.shape(), &[2, 2]);
732    }
733
734    #[test]
735    fn test_voting_ensemble_soft() {
736        let model1 = create_test_model();
737        let model2 = create_test_model();
738
739        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).unwrap();
740
741        let input = array![[1.0, 0.0]];
742        let pred = ensemble.predict(&input).unwrap();
743
744        assert_eq!(pred.shape(), &[1, 2]);
745    }
746
747    #[test]
748    fn test_voting_ensemble_with_weights() {
749        let model1 = create_test_model();
750        let model2 = create_test_model();
751
752        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft)
753            .unwrap()
754            .with_weights(vec![0.7, 0.3])
755            .unwrap();
756
757        let input = array![[1.0, 0.0]];
758        let pred = ensemble.predict(&input).unwrap();
759
760        assert_eq!(pred.shape(), &[1, 2]);
761    }
762
763    #[test]
764    fn test_voting_ensemble_invalid_weights() {
765        let model1 = create_test_model();
766        let model2 = create_test_model();
767
768        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).unwrap();
769
770        // Wrong number of weights
771        let result = ensemble.with_weights(vec![0.5]);
772        assert!(result.is_err());
773
774        // Weights don't sum to 1.0
775        let model3 = create_test_model();
776        let model4 = create_test_model();
777        let ensemble2 = VotingEnsemble::new(vec![model3, model4], VotingMode::Soft).unwrap();
778        let result = ensemble2.with_weights(vec![0.5, 0.6]);
779        assert!(result.is_err());
780    }
781
782    #[test]
783    fn test_averaging_ensemble() {
784        let model1 = create_test_model();
785        let model2 = create_test_model();
786
787        let ensemble = AveragingEnsemble::new(vec![model1, model2]).unwrap();
788
789        assert_eq!(ensemble.num_models(), 2);
790
791        let input = array![[1.0, 0.0], [0.0, 1.0]];
792        let pred = ensemble.predict(&input).unwrap();
793
794        assert_eq!(pred.shape(), &[2, 2]);
795    }
796
797    #[test]
798    fn test_averaging_ensemble_with_weights() {
799        let model1 = create_test_model();
800        let model2 = create_test_model();
801
802        let ensemble = AveragingEnsemble::new(vec![model1, model2])
803            .unwrap()
804            .with_weights(vec![2.0, 1.0])
805            .unwrap();
806
807        let input = array![[1.0, 0.0]];
808        let pred = ensemble.predict(&input).unwrap();
809
810        assert_eq!(pred.shape(), &[1, 2]);
811    }
812
813    #[test]
814    fn test_stacking_ensemble() {
815        let base1 = create_test_model(); // 2 inputs, 2 outputs
816        let base2 = create_test_model(); // 2 inputs, 2 outputs
817        let meta = LinearModel::new(4, 2); // 4 inputs (2 base models × 2 outputs), 2 outputs
818
819        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).unwrap();
820
821        assert_eq!(ensemble.num_models(), 3); // 2 base + 1 meta
822
823        let input = array![[1.0, 0.0]];
824        let pred = ensemble.predict(&input).unwrap();
825
826        // Meta-model takes concatenated predictions as input
827        assert_eq!(pred.nrows(), 1);
828    }
829
830    #[test]
831    fn test_stacking_meta_features() {
832        let base1 = create_test_model();
833        let base2 = create_test_model();
834        let meta = create_test_model();
835
836        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).unwrap();
837
838        let input = array![[1.0, 0.0]];
839        let meta_features = ensemble.generate_meta_features(&input).unwrap();
840
841        // Should concatenate predictions from 2 base models
842        // Each base model outputs 2 features, so total = 2 * 2 = 4
843        assert_eq!(meta_features.shape(), &[1, 4]);
844    }
845
846    #[test]
847    fn test_bagging_helper() {
848        let helper = BaggingHelper::new(10, 42).unwrap();
849
850        let indices = helper.generate_bootstrap_indices(100, 0);
851        assert_eq!(indices.len(), 100);
852
853        // All indices should be in valid range
854        assert!(indices.iter().all(|&i| i < 100));
855
856        // OOB indices should be the complement
857        let oob = helper.get_oob_indices(100, &indices);
858        assert!(!oob.is_empty());
859
860        for &idx in &oob {
861            assert!(!indices.contains(&idx));
862        }
863    }
864
865    #[test]
866    fn test_bagging_helper_different_seeds() {
867        let helper = BaggingHelper::new(10, 42).unwrap();
868
869        let indices1 = helper.generate_bootstrap_indices(50, 0);
870        let indices2 = helper.generate_bootstrap_indices(50, 1);
871
872        // Different seeds should produce different samples
873        assert_ne!(indices1, indices2);
874    }
875
876    #[test]
877    fn test_bagging_helper_invalid() {
878        assert!(BaggingHelper::new(0, 42).is_err());
879    }
880
881    #[test]
882    fn test_ensemble_empty_models() {
883        let result = VotingEnsemble::<LinearModel>::new(vec![], VotingMode::Hard);
884        assert!(result.is_err());
885
886        let result = AveragingEnsemble::<LinearModel>::new(vec![]);
887        assert!(result.is_err());
888    }
889
890    // Model Soup Tests
891    #[test]
892    fn test_uniform_soup() {
893        let mut weights1 = HashMap::new();
894        weights1.insert("w".to_string(), array![[1.0, 2.0]]);
895        weights1.insert("b".to_string(), array![[0.5]]);
896
897        let mut weights2 = HashMap::new();
898        weights2.insert("w".to_string(), array![[3.0, 4.0]]);
899        weights2.insert("b".to_string(), array![[1.5]]);
900
901        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
902
903        assert_eq!(soup.num_models(), 2);
904        assert_eq!(soup.recipe(), SoupRecipe::Uniform);
905
906        // Check averaged weights
907        let w = soup.get_parameter("w").unwrap();
908        assert_eq!(w[[0, 0]], 2.0); // (1.0 + 3.0) / 2
909        assert_eq!(w[[0, 1]], 3.0); // (2.0 + 4.0) / 2
910
911        let b = soup.get_parameter("b").unwrap();
912        assert_eq!(b[[0, 0]], 1.0); // (0.5 + 1.5) / 2
913    }
914
915    #[test]
916    fn test_uniform_soup_three_models() {
917        let mut weights1 = HashMap::new();
918        weights1.insert("w".to_string(), array![[1.0]]);
919
920        let mut weights2 = HashMap::new();
921        weights2.insert("w".to_string(), array![[2.0]]);
922
923        let mut weights3 = HashMap::new();
924        weights3.insert("w".to_string(), array![[3.0]]);
925
926        let soup = ModelSoup::uniform_soup(vec![weights1, weights2, weights3]).unwrap();
927
928        let w = soup.get_parameter("w").unwrap();
929        assert_eq!(w[[0, 0]], 2.0); // (1.0 + 2.0 + 3.0) / 3
930    }
931
932    #[test]
933    fn test_greedy_soup() {
934        let mut weights1 = HashMap::new();
935        weights1.insert("w".to_string(), array![[1.0]]);
936
937        let mut weights2 = HashMap::new();
938        weights2.insert("w".to_string(), array![[2.0]]);
939
940        let mut weights3 = HashMap::new();
941        weights3.insert("w".to_string(), array![[3.0]]);
942
943        let accuracies = vec![0.8, 0.9, 0.85]; // Model 2 is best
944
945        let soup = ModelSoup::greedy_soup(vec![weights1, weights2, weights3], accuracies).unwrap();
946
947        assert_eq!(soup.recipe(), SoupRecipe::Greedy);
948        assert!(soup.num_models() >= 1); // At least the best model
949    }
950
951    #[test]
952    fn test_weighted_soup() {
953        let mut weights1 = HashMap::new();
954        weights1.insert("w".to_string(), array![[1.0, 2.0]]);
955
956        let mut weights2 = HashMap::new();
957        weights2.insert("w".to_string(), array![[3.0, 4.0]]);
958
959        // Weight model 1 twice as much as model 2
960        let soup = ModelSoup::weighted_soup(vec![weights1, weights2], vec![2.0, 1.0]).unwrap();
961
962        assert_eq!(soup.recipe(), SoupRecipe::Weighted);
963
964        // Check weighted average: (2*1 + 1*3) / 3 = 5/3 ≈ 1.667
965        let w = soup.get_parameter("w").unwrap();
966        assert!((w[[0, 0]] - 1.6666666).abs() < 1e-5);
967        assert!((w[[0, 1]] - 2.6666666).abs() < 1e-5);
968    }
969
970    #[test]
971    fn test_soup_empty_models() {
972        let result = ModelSoup::uniform_soup(vec![]);
973        assert!(result.is_err());
974    }
975
976    #[test]
977    fn test_soup_mismatched_parameters() {
978        let mut weights1 = HashMap::new();
979        weights1.insert("w".to_string(), array![[1.0]]);
980
981        let mut weights2 = HashMap::new();
982        weights2.insert("b".to_string(), array![[2.0]]); // Different parameter name
983
984        let result = ModelSoup::uniform_soup(vec![weights1, weights2]);
985        assert!(result.is_err());
986    }
987
988    #[test]
989    fn test_greedy_soup_mismatched_lengths() {
990        let mut weights1 = HashMap::new();
991        weights1.insert("w".to_string(), array![[1.0]]);
992
993        let result = ModelSoup::greedy_soup(vec![weights1], vec![0.8, 0.9]);
994        assert!(result.is_err());
995    }
996
997    #[test]
998    fn test_weighted_soup_invalid_weights() {
999        let mut weights1 = HashMap::new();
1000        weights1.insert("w".to_string(), array![[1.0]]);
1001
1002        let mut weights2 = HashMap::new();
1003        weights2.insert("w".to_string(), array![[2.0]]);
1004
1005        // Negative weights
1006        let result =
1007            ModelSoup::weighted_soup(vec![weights1.clone(), weights2.clone()], vec![-1.0, 1.0]);
1008        assert!(result.is_err());
1009
1010        // Mismatched lengths
1011        let result = ModelSoup::weighted_soup(vec![weights1], vec![1.0, 2.0]);
1012        assert!(result.is_err());
1013    }
1014
1015    #[test]
1016    fn test_soup_into_weights() {
1017        let mut weights1 = HashMap::new();
1018        weights1.insert("w".to_string(), array![[1.0]]);
1019
1020        let mut weights2 = HashMap::new();
1021        weights2.insert("w".to_string(), array![[3.0]]);
1022
1023        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
1024        let final_weights = soup.into_weights();
1025
1026        assert_eq!(final_weights["w"][[0, 0]], 2.0);
1027    }
1028
1029    #[test]
1030    fn test_soup_multidimensional_weights() {
1031        let mut weights1 = HashMap::new();
1032        weights1.insert("conv".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
1033
1034        let mut weights2 = HashMap::new();
1035        weights2.insert("conv".to_string(), array![[5.0, 6.0], [7.0, 8.0]]);
1036
1037        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
1038        let conv = soup.get_parameter("conv").unwrap();
1039
1040        assert_eq!(conv[[0, 0]], 3.0); // (1 + 5) / 2
1041        assert_eq!(conv[[0, 1]], 4.0); // (2 + 6) / 2
1042        assert_eq!(conv[[1, 0]], 5.0); // (3 + 7) / 2
1043        assert_eq!(conv[[1, 1]], 6.0); // (4 + 8) / 2
1044    }
1045}