
tensorlogic_train/ensemble.rs

//! Model ensembling utilities for combining multiple models.
//!
//! This module provides various ensemble strategies:
//! - Voting ensembles (hard and soft voting)
//! - Averaging ensembles (simple and weighted)
//! - Stacking ensembles (meta-learner)
//! - Bagging utilities
//! - Model soups (weight-space averaging)

use crate::{Model, TrainError, TrainResult};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

/// Trait for ensemble methods.
pub trait Ensemble {
    /// Predict using the ensemble.
    ///
    /// # Arguments
    /// * `input` - Input data [batch_size, features]
    ///
    /// # Returns
    /// Ensemble predictions [batch_size, num_classes]
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;

    /// Get the number of models in the ensemble.
    fn num_models(&self) -> usize;
}

/// Voting mode used by a [`VotingEnsemble`]:
/// - Hard voting: majority vote (the class with the most votes wins)
/// - Soft voting: average of the predicted probabilities
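///
/// For example, given two models that predict probabilities `[0.6, 0.4]` and
/// `[0.2, 0.8]` for one sample: soft voting averages them to `[0.4, 0.6]` and
/// picks class 1, while hard voting counts one vote for class 0 and one for
/// class 1 and breaks the tie in favor of the lower class index.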
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VotingMode {
    /// Hard voting (majority vote).
    Hard,
    /// Soft voting (average probabilities).
    Soft,
}

/// Voting ensemble for classification.
///
/// Combines predictions from multiple models using voting; see [`VotingMode`]
/// for the hard and soft voting strategies.
#[derive(Debug)]
pub struct VotingEnsemble<M: Model> {
    /// Base models in the ensemble.
    models: Vec<M>,
    /// Voting mode (hard or soft).
    mode: VotingMode,
    /// Model weights (for weighted voting).
    weights: Option<Vec<f64>>,
}

impl<M: Model> VotingEnsemble<M> {
    /// Create a new voting ensemble.
    ///
    /// # Arguments
    /// * `models` - Base models to ensemble
    /// * `mode` - Voting mode (hard or soft)
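    ///
    /// # Example
    /// A minimal sketch; any type implementing [`Model`] (such as the crate's
    /// `LinearModel`) can serve as a base model.
    /// ```
    /// // use tensorlogic_train::{Ensemble, LinearModel, VotingEnsemble, VotingMode};
    /// //
    /// // let ensemble = VotingEnsemble::new(
    /// //     vec![LinearModel::new(2, 2), LinearModel::new(2, 2)],
    /// //     VotingMode::Soft,
    /// // )?;
    /// // let predictions = ensemble.predict(&input)?; // [batch_size, num_classes]
    /// ```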
    pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
        if models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one model".to_string(),
            ));
        }
        Ok(Self {
            models,
            mode,
            weights: None,
        })
    }

    /// Set model weights for weighted voting.
    ///
    /// # Arguments
    /// * `weights` - Weight for each model (must sum to 1.0)
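    ///
    /// For example, `vec![0.7, 0.3]` gives the first model 70% of the vote,
    /// while `vec![0.5, 0.6]` is rejected because the weights do not sum to 1.0.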
    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
        if weights.len() != self.models.len() {
            return Err(TrainError::InvalidParameter(
                "Number of weights must match number of models".to_string(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if (sum - 1.0).abs() > 1e-6 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to 1.0".to_string(),
            ));
        }

        self.weights = Some(weights);
        Ok(self)
    }

    /// Get voting mode.
    pub fn mode(&self) -> VotingMode {
        self.mode
    }
}

impl<M: Model> Ensemble for VotingEnsemble<M> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        // Collect predictions from all models
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Get output shape from first prediction
        let num_classes = all_predictions[0].ncols();
        let mut ensemble_pred = Array2::zeros((batch_size, num_classes));

        match self.mode {
            VotingMode::Hard => {
                // Hard voting: count votes for each class
                for i in 0..batch_size {
                    let mut votes = vec![0.0; num_classes];

                    for (model_idx, pred) in all_predictions.iter().enumerate() {
                        // Get predicted class (argmax)
                        let row = pred.row(i);
                        let class_idx = row
                            .iter()
                            .enumerate()
                            .max_by(|(_, a), (_, b)| {
                                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(idx, _)| idx)
                            .unwrap_or(0);

                        let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                        votes[class_idx] += weight;
                    }

                    // Convert votes to one-hot prediction
                    // (ties are broken in favor of the lowest class index)
                    let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let winning_class = votes
                        .iter()
                        .position(|&v| (v - max_votes).abs() < 1e-10)
                        .unwrap_or(0);

                    ensemble_pred[[i, winning_class]] = 1.0;
                }
            }
            VotingMode::Soft => {
                // Soft voting: average probabilities
                for i in 0..batch_size {
                    for j in 0..num_classes {
                        let mut weighted_sum = 0.0;

                        for (model_idx, pred) in all_predictions.iter().enumerate() {
                            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                            weighted_sum += pred[[i, j]] * weight;
                        }

                        let normalizer = if self.weights.is_some() {
                            1.0 // Weights already sum to 1.0
                        } else {
                            self.models.len() as f64
                        };

                        ensemble_pred[[i, j]] = weighted_sum / normalizer;
                    }
                }
            }
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}

/// Averaging ensemble for regression.
///
/// Combines predictions by averaging (simple or weighted).
#[derive(Debug)]
pub struct AveragingEnsemble<M: Model> {
    /// Base models in the ensemble.
    models: Vec<M>,
    /// Model weights (for weighted averaging).
    weights: Option<Vec<f64>>,
}

impl<M: Model> AveragingEnsemble<M> {
    /// Create a new averaging ensemble.
    ///
    /// # Arguments
    /// * `models` - Base models to ensemble
    pub fn new(models: Vec<M>) -> TrainResult<Self> {
        if models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one model".to_string(),
            ));
        }
        Ok(Self {
            models,
            weights: None,
        })
    }

    /// Set model weights for weighted averaging.
    ///
    /// # Arguments
    /// * `weights` - Weight for each model (normalized internally to sum to 1.0)
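    ///
    /// For example, `vec![2.0, 1.0]` is normalized to `[2.0 / 3.0, 1.0 / 3.0]`,
    /// so the first model contributes twice as much to the average as the second.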
    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
        if weights.len() != self.models.len() {
            return Err(TrainError::InvalidParameter(
                "Number of weights must match number of models".to_string(),
            ));
        }

        // Normalize weights
        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to a positive value".to_string(),
            ));
        }

        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
        self.weights = Some(normalized_weights);
        Ok(self)
    }
}

impl<M: Model> Ensemble for AveragingEnsemble<M> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        // Collect predictions from all models
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Average predictions
        let shape = all_predictions[0].raw_dim();
        let mut ensemble_pred = Array2::zeros(shape);

        for (model_idx, pred) in all_predictions.iter().enumerate() {
            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);

            for i in 0..pred.nrows() {
                for j in 0..pred.ncols() {
                    ensemble_pred[[i, j]] += pred[[i, j]] * weight;
                }
            }
        }

        // Normalize if using uniform weights
        if self.weights.is_none() {
            ensemble_pred /= self.models.len() as f64;
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}

/// Stacking ensemble with a meta-learner.
///
/// Uses base models' predictions as features for a meta-model.
#[derive(Debug)]
pub struct StackingEnsemble<M: Model, Meta: Model> {
    /// Base models (first level).
    base_models: Vec<M>,
    /// Meta-model (second level).
    meta_model: Meta,
}

impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
    /// Create a new stacking ensemble.
    ///
    /// # Arguments
    /// * `base_models` - First-level base models
    /// * `meta_model` - Second-level meta-learner
    pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
        if base_models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one base model".to_string(),
            ));
        }
        Ok(Self {
            base_models,
            meta_model,
        })
    }

    /// Generate meta-features from base model predictions.
    ///
    /// # Arguments
    /// * `input` - Input data
    ///
    /// # Returns
    /// Meta-features [batch_size, num_base_models * num_classes]
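    ///
    /// For example, 2 base models that each output 2 classes yield meta-features
    /// of shape `[batch_size, 4]`: each model's prediction block is laid out side
    /// by side, so the meta-model must accept `num_base_models * num_classes`
    /// input features.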
    pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        // Collect predictions from all base models
        let mut all_predictions = Vec::with_capacity(self.base_models.len());
        for model in &self.base_models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Concatenate predictions horizontally to form meta-features
        let num_features_per_model = all_predictions[0].ncols();
        let total_features = self.base_models.len() * num_features_per_model;

        let mut meta_features = Array2::zeros((batch_size, total_features));

        for (model_idx, pred) in all_predictions.iter().enumerate() {
            let start_col = model_idx * num_features_per_model;

            for i in 0..batch_size {
                for j in 0..num_features_per_model {
                    meta_features[[i, start_col + j]] = pred[[i, j]];
                }
            }
        }

        Ok(meta_features)
    }
}

impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        // Generate meta-features from base models
        let meta_features = self.generate_meta_features(input)?;

        // Make final prediction with meta-model
        self.meta_model.forward(&meta_features.view())
    }

    fn num_models(&self) -> usize {
        self.base_models.len() + 1 // base models + meta model
    }
}

/// Bagging (Bootstrap Aggregating) utilities.
///
/// Generates bootstrap samples for training ensemble members.
#[derive(Debug)]
pub struct BaggingHelper {
    /// Number of bootstrap samples.
    pub n_estimators: usize,
    /// Random seed for reproducibility.
    pub random_seed: u64,
}

impl BaggingHelper {
    /// Create a new bagging helper.
    ///
    /// # Arguments
    /// * `n_estimators` - Number of bootstrap samples
    /// * `random_seed` - Random seed
    pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
        if n_estimators == 0 {
            return Err(TrainError::InvalidParameter(
                "n_estimators must be positive".to_string(),
            ));
        }
        Ok(Self {
            n_estimators,
            random_seed,
        })
    }

    /// Generate bootstrap sample indices.
    ///
    /// # Arguments
    /// * `n_samples` - Total number of samples
    /// * `estimator_idx` - Index of the estimator (for seeding)
    ///
    /// # Returns
    /// Bootstrap sample indices (with replacement)
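    ///
    /// # Example
    /// A minimal sketch, assuming `BaggingHelper` is re-exported from the crate
    /// root like `ModelSoup`:
    /// ```
    /// // use tensorlogic_train::BaggingHelper;
    /// //
    /// // let helper = BaggingHelper::new(10, 42)?;
    /// // // Draw 100 indices with replacement for estimator 0, then take the
    /// // // out-of-bag complement for validation.
    /// // let bootstrap = helper.generate_bootstrap_indices(100, 0);
    /// // let oob = helper.get_oob_indices(100, &bootstrap);
    /// ```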
    pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
        #[allow(unused_imports)]
        use scirs2_core::random::{Rng, SeedableRng, StdRng};

        let seed = self.random_seed.wrapping_add(estimator_idx as u64);
        let mut rng = StdRng::seed_from_u64(seed);

        (0..n_samples)
            .map(|_| rng.gen_range(0..n_samples))
            .collect()
    }

    /// Get out-of-bag (OOB) indices for an estimator.
    ///
    /// # Arguments
    /// * `n_samples` - Total number of samples
    /// * `bootstrap_indices` - Bootstrap sample indices
    ///
    /// # Returns
    /// OOB sample indices (not in bootstrap sample)
    pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
        let bootstrap_set: std::collections::HashSet<usize> =
            bootstrap_indices.iter().cloned().collect();

        (0..n_samples)
            .filter(|idx| !bootstrap_set.contains(idx))
            .collect()
    }
}

/// Model Soup - Weight-space averaging for improved generalization.
///
/// From "Model soups: averaging weights of multiple fine-tuned models
/// improves accuracy without increasing inference time" (Wortsman et al., 2022).
///
/// Model soups average the *weights* of multiple models (not their predictions), which can:
/// - Improve accuracy over the individual models
/// - Add no inference cost (a single model at test time)
/// - Work across different hyperparameters and random seeds
/// - Be particularly effective for models fine-tuned from the same initialization
///
/// Two main recipes:
/// - **Uniform Soup**: Simple average of all model weights
/// - **Greedy Soup**: Iteratively add models that improve validation performance
///
/// # Example
/// ```
/// use tensorlogic_train::{ModelSoup, SoupRecipe};
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::Array2;
///
/// // Collect weights from multiple fine-tuned models
/// // let model_weights = vec![weights1, weights2, weights3];
/// // let soup = ModelSoup::uniform_soup(model_weights);
/// // let averaged_weights = soup.weights();
/// ```
#[derive(Debug, Clone)]
pub struct ModelSoup {
    /// Averaged model weights
    weights: HashMap<String, Array2<f64>>,
    /// Number of models in the soup
    num_models: usize,
    /// Recipe used to create the soup
    recipe: SoupRecipe,
}

/// Recipe for creating model soups
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SoupRecipe {
    /// Uniform averaging of all models
    Uniform,
    /// Greedy selection based on validation performance
    Greedy,
    /// Custom weighted averaging
    Weighted,
}

impl ModelSoup {
    /// Create a uniform soup by averaging all model weights equally.
    ///
    /// # Arguments
    /// * `model_weights` - Weights from multiple fine-tuned models
    ///
    /// # Returns
    /// Model soup with uniformly averaged weights
    ///
    /// # Example
    /// ```
    /// use tensorlogic_train::ModelSoup;
    /// use std::collections::HashMap;
    /// use scirs2_core::ndarray::array;
    ///
    /// let mut weights1 = HashMap::new();
    /// weights1.insert("w".to_string(), array![[1.0, 2.0]]);
    ///
    /// let mut weights2 = HashMap::new();
    /// weights2.insert("w".to_string(), array![[3.0, 4.0]]);
    ///
    /// let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
    /// // Averaged weights: [[2.0, 3.0]]
    /// ```
    pub fn uniform_soup(model_weights: Vec<HashMap<String, Array2<f64>>>) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        let num_models = model_weights.len();
        let mut averaged_weights = HashMap::new();

        // Get parameter names from first model
        let param_names: Vec<String> = model_weights[0].keys().cloned().collect();

        // Average each parameter across all models
        for param_name in param_names {
            // Initialize with zeros
            let shape = model_weights[0][&param_name].raw_dim();
            let mut averaged_param = Array2::zeros(shape);

            // Sum across all models
            for model_weight in &model_weights {
                if let Some(param) = model_weight.get(&param_name) {
                    averaged_param += param;
                } else {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' not found in all models",
                        param_name
                    )));
                }
            }

            // Divide by number of models
            averaged_param /= num_models as f64;
            averaged_weights.insert(param_name, averaged_param);
        }

        Ok(Self {
            weights: averaged_weights,
            num_models,
            recipe: SoupRecipe::Uniform,
        })
    }

    /// Create a greedy soup by iteratively adding models that improve validation performance.
    ///
    /// # Arguments
    /// * `model_weights` - Weights from multiple fine-tuned models
    /// * `val_accuracies` - Validation accuracy for each model
    ///
    /// # Returns
    /// Model soup with greedily selected and averaged weights
    ///
    /// # Algorithm
    /// 1. Start with best single model
    /// 2. Try adding each remaining model to soup
    /// 3. Keep additions that improve validation performance
    /// 4. Repeat until no improvement
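    ///
    /// # Example
    /// A minimal sketch; the selection below uses the provided accuracies as a
    /// stand-in for re-evaluating the soup on a validation set.
    /// ```
    /// use tensorlogic_train::ModelSoup;
    /// use std::collections::HashMap;
    /// use scirs2_core::ndarray::array;
    ///
    /// let mut weights1 = HashMap::new();
    /// weights1.insert("w".to_string(), array![[1.0]]);
    ///
    /// let mut weights2 = HashMap::new();
    /// weights2.insert("w".to_string(), array![[2.0]]);
    ///
    /// // Model 2 has the better validation accuracy and seeds the soup.
    /// let soup = ModelSoup::greedy_soup(vec![weights1, weights2], vec![0.7, 0.9]).expect("unwrap");
    /// assert!(soup.num_models() >= 1);
    /// ```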
    pub fn greedy_soup(
        model_weights: Vec<HashMap<String, Array2<f64>>>,
        val_accuracies: Vec<f64>,
    ) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        if model_weights.len() != val_accuracies.len() {
            return Err(TrainError::InvalidParameter(
                "Number of models must match number of validation accuracies".to_string(),
            ));
        }

        // Find best single model as starting point
        let best_idx = val_accuracies
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        let mut soup_indices = vec![best_idx];
        let mut best_accuracy = val_accuracies[best_idx];

        // Greedily add models that improve performance
        loop {
            let mut improved = false;
            let mut best_addition = None;
            let mut best_new_accuracy = best_accuracy;

            // Try adding each model not yet in soup
            for (idx, acc) in val_accuracies.iter().enumerate() {
                if soup_indices.contains(&idx) {
                    continue;
                }

                // Estimate the soup's accuracy if this model were added.
                // (In practice you would re-evaluate the averaged soup on a
                // validation set; here the mean of the candidate's and the
                // current best accuracy is used as a cheap proxy.)
                let potential_accuracy = (*acc + best_accuracy) / 2.0;

                if potential_accuracy > best_new_accuracy {
                    best_new_accuracy = potential_accuracy;
                    best_addition = Some(idx);
                    improved = true;
                }
            }

            if improved {
                if let Some(idx) = best_addition {
                    soup_indices.push(idx);
                    best_accuracy = best_new_accuracy;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        // Create soup from selected models
        let selected_weights: Vec<_> = soup_indices
            .iter()
            .map(|&idx| model_weights[idx].clone())
            .collect();

        let mut soup = Self::uniform_soup(selected_weights)?;
        soup.recipe = SoupRecipe::Greedy;
        soup.num_models = soup_indices.len();

        Ok(soup)
    }

    /// Create a weighted soup with custom weights for each model.
    ///
    /// # Arguments
    /// * `model_weights` - Weights from multiple fine-tuned models
    /// * `weights` - Weight for each model (will be normalized to sum to 1)
    ///
    /// # Returns
    /// Model soup with weighted averaged parameters
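    ///
    /// # Example
    /// A minimal sketch with a 2:1 weighting:
    /// ```
    /// use tensorlogic_train::ModelSoup;
    /// use std::collections::HashMap;
    /// use scirs2_core::ndarray::array;
    ///
    /// let mut weights1 = HashMap::new();
    /// weights1.insert("w".to_string(), array![[1.0, 2.0]]);
    ///
    /// let mut weights2 = HashMap::new();
    /// weights2.insert("w".to_string(), array![[3.0, 4.0]]);
    ///
    /// let soup = ModelSoup::weighted_soup(vec![weights1, weights2], vec![2.0, 1.0]).expect("unwrap");
    /// let w = soup.get_parameter("w").expect("unwrap");
    /// // (2.0 * 1.0 + 1.0 * 3.0) / 3.0 = 5/3 and (2.0 * 2.0 + 1.0 * 4.0) / 3.0 = 8/3
    /// assert!((w[[0, 0]] - 5.0 / 3.0).abs() < 1e-6);
    /// assert!((w[[0, 1]] - 8.0 / 3.0).abs() < 1e-6);
    /// ```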
    pub fn weighted_soup(
        model_weights: Vec<HashMap<String, Array2<f64>>>,
        weights: Vec<f64>,
    ) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        if model_weights.len() != weights.len() {
            return Err(TrainError::InvalidParameter(
                "Number of models must match number of weights".to_string(),
            ));
        }

        // Normalize weights
        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to positive value".to_string(),
            ));
        }

        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();

        // Weighted average
        let num_models = model_weights.len();
        let mut averaged_weights = HashMap::new();
        let param_names: Vec<String> = model_weights[0].keys().cloned().collect();

        for param_name in param_names {
            let shape = model_weights[0][&param_name].raw_dim();
            let mut averaged_param = Array2::zeros(shape);

            for (model_idx, model_weight) in model_weights.iter().enumerate() {
                if let Some(param) = model_weight.get(&param_name) {
                    averaged_param = averaged_param + param * normalized_weights[model_idx];
                } else {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' not found in all models",
                        param_name
                    )));
                }
            }

            averaged_weights.insert(param_name, averaged_param);
        }

        Ok(Self {
            weights: averaged_weights,
            num_models,
            recipe: SoupRecipe::Weighted,
        })
    }

    /// Get the averaged weights from the soup.
    pub fn weights(&self) -> &HashMap<String, Array2<f64>> {
        &self.weights
    }

    /// Get the number of models in the soup.
    pub fn num_models(&self) -> usize {
        self.num_models
    }

    /// Get the recipe used to create the soup.
    pub fn recipe(&self) -> SoupRecipe {
        self.recipe
    }

    /// Get a specific parameter by name.
    pub fn get_parameter(&self, name: &str) -> Option<&Array2<f64>> {
        self.weights.get(name)
    }

    /// Consume the soup and return the averaged weights.
    ///
    /// This is a convenience method for handing the weights off to a model's
    /// loading routine.
    pub fn into_weights(self) -> HashMap<String, Array2<f64>> {
        self.weights
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearModel;
    use scirs2_core::ndarray::array;

    fn create_test_model() -> LinearModel {
        // Create a 2-input, 2-output linear model
        LinearModel::new(2, 2)
    }

    #[test]
    fn test_voting_ensemble_hard() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).expect("unwrap");

        assert_eq!(ensemble.num_models(), 2);
        assert_eq!(ensemble.mode(), VotingMode::Hard);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_voting_ensemble_soft() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).expect("unwrap");

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft)
            .expect("unwrap")
            .with_weights(vec![0.7, 0.3])
            .expect("unwrap");

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_invalid_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).expect("unwrap");

        // Wrong number of weights
        let result = ensemble.with_weights(vec![0.5]);
        assert!(result.is_err());

        // Weights don't sum to 1.0
        let model3 = create_test_model();
        let model4 = create_test_model();
        let ensemble2 =
            VotingEnsemble::new(vec![model3, model4], VotingMode::Soft).expect("unwrap");
        let result = ensemble2.with_weights(vec![0.5, 0.6]);
        assert!(result.is_err());
    }

    #[test]
    fn test_averaging_ensemble() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2]).expect("unwrap");

        assert_eq!(ensemble.num_models(), 2);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_averaging_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2])
            .expect("unwrap")
            .with_weights(vec![2.0, 1.0])
            .expect("unwrap");

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_stacking_ensemble() {
        let base1 = create_test_model(); // 2 inputs, 2 outputs
        let base2 = create_test_model(); // 2 inputs, 2 outputs
        let meta = LinearModel::new(4, 2); // 4 inputs (2 base models × 2 outputs), 2 outputs

        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).expect("unwrap");

        assert_eq!(ensemble.num_models(), 3); // 2 base + 1 meta

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        // Meta-model takes concatenated predictions as input
        assert_eq!(pred.nrows(), 1);
    }

    #[test]
    fn test_stacking_meta_features() {
        let base1 = create_test_model();
        let base2 = create_test_model();
        let meta = create_test_model();

        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).expect("unwrap");

        let input = array![[1.0, 0.0]];
        let meta_features = ensemble.generate_meta_features(&input).expect("unwrap");

        // Should concatenate predictions from 2 base models
        // Each base model outputs 2 features, so total = 2 * 2 = 4
        assert_eq!(meta_features.shape(), &[1, 4]);
    }

    #[test]
    fn test_bagging_helper() {
        let helper = BaggingHelper::new(10, 42).expect("unwrap");

        let indices = helper.generate_bootstrap_indices(100, 0);
        assert_eq!(indices.len(), 100);

        // All indices should be in valid range
        assert!(indices.iter().all(|&i| i < 100));

        // OOB indices should be the complement
        let oob = helper.get_oob_indices(100, &indices);
        assert!(!oob.is_empty());

        for &idx in &oob {
            assert!(!indices.contains(&idx));
        }
    }

    #[test]
    fn test_bagging_helper_different_seeds() {
        let helper = BaggingHelper::new(10, 42).expect("unwrap");

        let indices1 = helper.generate_bootstrap_indices(50, 0);
        let indices2 = helper.generate_bootstrap_indices(50, 1);

        // Different seeds should produce different samples
        assert_ne!(indices1, indices2);
    }

    #[test]
    fn test_bagging_helper_invalid() {
        assert!(BaggingHelper::new(0, 42).is_err());
    }

    #[test]
    fn test_ensemble_empty_models() {
        let result = VotingEnsemble::<LinearModel>::new(vec![], VotingMode::Hard);
        assert!(result.is_err());

        let result = AveragingEnsemble::<LinearModel>::new(vec![]);
        assert!(result.is_err());
    }

    // Model Soup Tests
    #[test]
    fn test_uniform_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0, 2.0]]);
        weights1.insert("b".to_string(), array![[0.5]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0, 4.0]]);
        weights2.insert("b".to_string(), array![[1.5]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");

        assert_eq!(soup.num_models(), 2);
        assert_eq!(soup.recipe(), SoupRecipe::Uniform);

        // Check averaged weights
        let w = soup.get_parameter("w").expect("unwrap");
        assert_eq!(w[[0, 0]], 2.0); // (1.0 + 3.0) / 2
        assert_eq!(w[[0, 1]], 3.0); // (2.0 + 4.0) / 2

        let b = soup.get_parameter("b").expect("unwrap");
        assert_eq!(b[[0, 0]], 1.0); // (0.5 + 1.5) / 2
    }

    #[test]
    fn test_uniform_soup_three_models() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let mut weights3 = HashMap::new();
        weights3.insert("w".to_string(), array![[3.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2, weights3]).expect("unwrap");

        let w = soup.get_parameter("w").expect("unwrap");
        assert_eq!(w[[0, 0]], 2.0); // (1.0 + 2.0 + 3.0) / 3
    }

    #[test]
    fn test_greedy_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let mut weights3 = HashMap::new();
        weights3.insert("w".to_string(), array![[3.0]]);

        let accuracies = vec![0.8, 0.9, 0.85]; // Model 2 is best

        let soup =
            ModelSoup::greedy_soup(vec![weights1, weights2, weights3], accuracies).expect("unwrap");

        assert_eq!(soup.recipe(), SoupRecipe::Greedy);
        assert!(soup.num_models() >= 1); // At least the best model
    }

    #[test]
    fn test_weighted_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0, 4.0]]);

        // Weight model 1 twice as much as model 2
        let soup =
            ModelSoup::weighted_soup(vec![weights1, weights2], vec![2.0, 1.0]).expect("unwrap");

        assert_eq!(soup.recipe(), SoupRecipe::Weighted);

        // Check weighted average: (2*1 + 1*3) / 3 = 5/3 ≈ 1.667
        let w = soup.get_parameter("w").expect("unwrap");
        assert!((w[[0, 0]] - 1.6666666).abs() < 1e-5);
        assert!((w[[0, 1]] - 2.6666666).abs() < 1e-5);
    }

    #[test]
    fn test_soup_empty_models() {
        let result = ModelSoup::uniform_soup(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_soup_mismatched_parameters() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("b".to_string(), array![[2.0]]); // Different parameter name

        let result = ModelSoup::uniform_soup(vec![weights1, weights2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_greedy_soup_mismatched_lengths() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let result = ModelSoup::greedy_soup(vec![weights1], vec![0.8, 0.9]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_soup_invalid_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        // Negative weights
        let result =
            ModelSoup::weighted_soup(vec![weights1.clone(), weights2.clone()], vec![-1.0, 1.0]);
        assert!(result.is_err());

        // Mismatched lengths
        let result = ModelSoup::weighted_soup(vec![weights1], vec![1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_soup_into_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
        let final_weights = soup.into_weights();

        assert_eq!(final_weights["w"][[0, 0]], 2.0);
    }

    #[test]
    fn test_soup_multidimensional_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("conv".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("conv".to_string(), array![[5.0, 6.0], [7.0, 8.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
        let conv = soup.get_parameter("conv").expect("unwrap");

        assert_eq!(conv[[0, 0]], 3.0); // (1 + 5) / 2
        assert_eq!(conv[[0, 1]], 4.0); // (2 + 6) / 2
        assert_eq!(conv[[1, 0]], 5.0); // (3 + 7) / 2
        assert_eq!(conv[[1, 1]], 6.0); // (4 + 8) / 2
    }
}