1use crate::{Model, TrainError, TrainResult};
11use scirs2_core::ndarray::Array2;
12use std::collections::HashMap;
13
/// Common interface for ensembles that combine several models into a
/// single predictor.
pub trait Ensemble {
    /// Runs the ensemble on a batch of inputs (rows = samples) and
    /// returns the combined prediction matrix.
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;

    /// Number of member models participating in the ensemble.
    fn num_models(&self) -> usize;
}
28
/// Strategy used by a voting ensemble to combine member predictions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VotingMode {
    /// Each model casts one (optionally weighted) vote for its argmax class.
    Hard,
    /// Per-class scores are (optionally weighted) averages of model outputs.
    Soft,
}
41
/// Ensemble that combines classifiers by hard or soft voting.
#[derive(Debug)]
pub struct VotingEnsemble<M: Model> {
    // Member models; guaranteed non-empty by the constructor.
    models: Vec<M>,
    // How member predictions are combined.
    mode: VotingMode,
    // Optional per-model weights; when set, validated to sum to 1.0.
    weights: Option<Vec<f64>>,
}
52
53impl<M: Model> VotingEnsemble<M> {
54 pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
60 if models.is_empty() {
61 return Err(TrainError::InvalidParameter(
62 "Ensemble must have at least one model".to_string(),
63 ));
64 }
65 Ok(Self {
66 models,
67 mode,
68 weights: None,
69 })
70 }
71
72 pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
77 if weights.len() != self.models.len() {
78 return Err(TrainError::InvalidParameter(
79 "Number of weights must match number of models".to_string(),
80 ));
81 }
82
83 let sum: f64 = weights.iter().sum();
84 if (sum - 1.0).abs() > 1e-6 {
85 return Err(TrainError::InvalidParameter(
86 "Weights must sum to 1.0".to_string(),
87 ));
88 }
89
90 self.weights = Some(weights);
91 Ok(self)
92 }
93
94 pub fn mode(&self) -> VotingMode {
96 self.mode
97 }
98}
99
impl<M: Model> Ensemble for VotingEnsemble<M> {
    /// Combines member predictions by hard or soft voting.
    ///
    /// Hard: each model contributes a (weighted) vote for its argmax
    /// class; the winning class gets 1.0 in the output row, all others 0.
    /// Soft: per-class outputs are (weighted) averages of model outputs.
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        // Run every member model on the full batch up front.
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        // Constructor guarantees at least one model, so [0] cannot panic.
        let num_classes = all_predictions[0].ncols();
        let mut ensemble_pred = Array2::zeros((batch_size, num_classes));

        match self.mode {
            VotingMode::Hard => {
                for i in 0..batch_size {
                    let mut votes = vec![0.0; num_classes];

                    for (model_idx, pred) in all_predictions.iter().enumerate() {
                        let row = pred.row(i);
                        // Argmax of this model's row; on exact ties max_by
                        // keeps the LAST maximal index (std Iterator docs).
                        // NaNs compare Equal here, so they never win alone.
                        let class_idx = row
                            .iter()
                            .enumerate()
                            .max_by(|(_, a), (_, b)| {
                                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                            })
                            .map(|(idx, _)| idx)
                            .unwrap_or(0);

                        // Unweighted ensembles count every vote as 1.0.
                        let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                        votes[class_idx] += weight;
                    }

                    // Winner = FIRST class whose vote total matches the
                    // maximum within a tiny epsilon (breaks vote ties toward
                    // lower class indices).
                    let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let winning_class = votes
                        .iter()
                        .position(|&v| (v - max_votes).abs() < 1e-10)
                        .unwrap_or(0);

                    ensemble_pred[[i, winning_class]] = 1.0;
                }
            }
            VotingMode::Soft => {
                for i in 0..batch_size {
                    for j in 0..num_classes {
                        let mut weighted_sum = 0.0;

                        for (model_idx, pred) in all_predictions.iter().enumerate() {
                            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                            weighted_sum += pred[[i, j]] * weight;
                        }

                        // Explicit weights were validated to sum to 1.0, so no
                        // further normalization is needed; otherwise divide by
                        // the model count for a plain average.
                        let normalizer = if self.weights.is_some() {
                            1.0
                        } else {
                            self.models.len() as f64
                        };

                        ensemble_pred[[i, j]] = weighted_sum / normalizer;
                    }
                }
            }
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}
177
/// Ensemble that combines models by (weighted) averaging of their raw outputs.
#[derive(Debug)]
pub struct AveragingEnsemble<M: Model> {
    // Member models; guaranteed non-empty by the constructor.
    models: Vec<M>,
    // Optional per-model weights, normalized to sum to 1.0 on assignment.
    weights: Option<Vec<f64>>,
}
188
189impl<M: Model> AveragingEnsemble<M> {
190 pub fn new(models: Vec<M>) -> TrainResult<Self> {
195 if models.is_empty() {
196 return Err(TrainError::InvalidParameter(
197 "Ensemble must have at least one model".to_string(),
198 ));
199 }
200 Ok(Self {
201 models,
202 weights: None,
203 })
204 }
205
206 pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
211 if weights.len() != self.models.len() {
212 return Err(TrainError::InvalidParameter(
213 "Number of weights must match number of models".to_string(),
214 ));
215 }
216
217 let sum: f64 = weights.iter().sum();
219 if sum <= 0.0 {
220 return Err(TrainError::InvalidParameter(
221 "Weights must sum to a positive value".to_string(),
222 ));
223 }
224
225 let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
226 self.weights = Some(normalized_weights);
227 Ok(self)
228 }
229}
230
231impl<M: Model> Ensemble for AveragingEnsemble<M> {
232 fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
233 let mut all_predictions = Vec::with_capacity(self.models.len());
235 for model in &self.models {
236 let pred = model.forward(&input.view())?;
237 all_predictions.push(pred);
238 }
239
240 let shape = all_predictions[0].raw_dim();
242 let mut ensemble_pred = Array2::zeros(shape);
243
244 for (model_idx, pred) in all_predictions.iter().enumerate() {
245 let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
246
247 for i in 0..pred.nrows() {
248 for j in 0..pred.ncols() {
249 ensemble_pred[[i, j]] += pred[[i, j]] * weight;
250 }
251 }
252 }
253
254 if self.weights.is_none() {
256 ensemble_pred /= self.models.len() as f64;
257 }
258
259 Ok(ensemble_pred)
260 }
261
262 fn num_models(&self) -> usize {
263 self.models.len()
264 }
265}
266
/// Two-level stacking ensemble: base-model outputs are concatenated into a
/// feature vector that is fed to a meta model.
#[derive(Debug)]
pub struct StackingEnsemble<M: Model, Meta: Model> {
    // First-level models; guaranteed non-empty by the constructor.
    base_models: Vec<M>,
    // Second-level model consuming the concatenated base predictions.
    meta_model: Meta,
}
277
278impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
279 pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
285 if base_models.is_empty() {
286 return Err(TrainError::InvalidParameter(
287 "Ensemble must have at least one base model".to_string(),
288 ));
289 }
290 Ok(Self {
291 base_models,
292 meta_model,
293 })
294 }
295
296 pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
304 let batch_size = input.nrows();
305
306 let mut all_predictions = Vec::with_capacity(self.base_models.len());
308 for model in &self.base_models {
309 let pred = model.forward(&input.view())?;
310 all_predictions.push(pred);
311 }
312
313 let num_features_per_model = all_predictions[0].ncols();
315 let total_features = self.base_models.len() * num_features_per_model;
316
317 let mut meta_features = Array2::zeros((batch_size, total_features));
318
319 for (model_idx, pred) in all_predictions.iter().enumerate() {
320 let start_col = model_idx * num_features_per_model;
321
322 for i in 0..batch_size {
323 for j in 0..num_features_per_model {
324 meta_features[[i, start_col + j]] = pred[[i, j]];
325 }
326 }
327 }
328
329 Ok(meta_features)
330 }
331}
332
333impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
334 fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
335 let meta_features = self.generate_meta_features(input)?;
337
338 self.meta_model.forward(&meta_features.view())
340 }
341
342 fn num_models(&self) -> usize {
343 self.base_models.len() + 1 }
345}
346
/// Utilities for bootstrap aggregating (bagging): reproducible bootstrap
/// sampling and out-of-bag index computation.
#[derive(Debug)]
pub struct BaggingHelper {
    /// Number of estimators the bagging run will train (must be > 0).
    pub n_estimators: usize,
    /// Base RNG seed; each estimator derives its own seed from this.
    pub random_seed: u64,
}
357
358impl BaggingHelper {
359 pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
365 if n_estimators == 0 {
366 return Err(TrainError::InvalidParameter(
367 "n_estimators must be positive".to_string(),
368 ));
369 }
370 Ok(Self {
371 n_estimators,
372 random_seed,
373 })
374 }
375
376 pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
385 #[allow(unused_imports)]
386 use scirs2_core::random::{Rng, SeedableRng, StdRng};
387
388 let seed = self.random_seed.wrapping_add(estimator_idx as u64);
389 let mut rng = StdRng::seed_from_u64(seed);
390
391 (0..n_samples)
392 .map(|_| rng.gen_range(0..n_samples))
393 .collect()
394 }
395
396 pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
405 let bootstrap_set: std::collections::HashSet<usize> =
406 bootstrap_indices.iter().cloned().collect();
407
408 (0..n_samples)
409 .filter(|idx| !bootstrap_set.contains(idx))
410 .collect()
411 }
412}
413
/// A "model soup": one set of parameters produced by averaging the weights
/// of several trained models (uniform, greedy, or weighted recipe).
#[derive(Debug, Clone)]
pub struct ModelSoup {
    // Averaged parameters, keyed by parameter name.
    weights: HashMap<String, Array2<f64>>,
    // How many models were blended into this soup.
    num_models: usize,
    // Which recipe produced the soup.
    recipe: SoupRecipe,
}
449
/// The averaging recipe used to build a model soup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SoupRecipe {
    /// Plain average over all supplied models.
    Uniform,
    /// Greedy selection seeded with the best validation accuracy.
    Greedy,
    /// Average using caller-provided (normalized) weights.
    Weighted,
}
460
461impl ModelSoup {
462 pub fn uniform_soup(model_weights: Vec<HashMap<String, Array2<f64>>>) -> TrainResult<Self> {
486 if model_weights.is_empty() {
487 return Err(TrainError::InvalidParameter(
488 "At least one model required for soup".to_string(),
489 ));
490 }
491
492 let num_models = model_weights.len();
493 let mut averaged_weights = HashMap::new();
494
495 let param_names: Vec<String> = model_weights[0].keys().cloned().collect();
497
498 for param_name in param_names {
500 let shape = model_weights[0][¶m_name].raw_dim();
502 let mut averaged_param = Array2::zeros(shape);
503
504 for model_weight in &model_weights {
506 if let Some(param) = model_weight.get(¶m_name) {
507 averaged_param += param;
508 } else {
509 return Err(TrainError::InvalidParameter(format!(
510 "Parameter '{}' not found in all models",
511 param_name
512 )));
513 }
514 }
515
516 averaged_param /= num_models as f64;
518 averaged_weights.insert(param_name, averaged_param);
519 }
520
521 Ok(Self {
522 weights: averaged_weights,
523 num_models,
524 recipe: SoupRecipe::Uniform,
525 })
526 }
527
528 pub fn greedy_soup(
543 model_weights: Vec<HashMap<String, Array2<f64>>>,
544 val_accuracies: Vec<f64>,
545 ) -> TrainResult<Self> {
546 if model_weights.is_empty() {
547 return Err(TrainError::InvalidParameter(
548 "At least one model required for soup".to_string(),
549 ));
550 }
551
552 if model_weights.len() != val_accuracies.len() {
553 return Err(TrainError::InvalidParameter(
554 "Number of models must match number of validation accuracies".to_string(),
555 ));
556 }
557
558 let best_idx = val_accuracies
560 .iter()
561 .enumerate()
562 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
563 .map(|(idx, _)| idx)
564 .unwrap_or(0);
565
566 let mut soup_indices = vec![best_idx];
567 let mut best_accuracy = val_accuracies[best_idx];
568
569 loop {
571 let mut improved = false;
572 let mut best_addition = None;
573 let mut best_new_accuracy = best_accuracy;
574
575 for (idx, acc) in val_accuracies.iter().enumerate() {
577 if soup_indices.contains(&idx) {
578 continue;
579 }
580
581 let potential_accuracy = (*acc + best_accuracy) / 2.0;
584
585 if potential_accuracy > best_new_accuracy {
586 best_new_accuracy = potential_accuracy;
587 best_addition = Some(idx);
588 improved = true;
589 }
590 }
591
592 if improved {
593 if let Some(idx) = best_addition {
594 soup_indices.push(idx);
595 best_accuracy = best_new_accuracy;
596 } else {
597 break;
598 }
599 } else {
600 break;
601 }
602 }
603
604 let selected_weights: Vec<_> = soup_indices
606 .iter()
607 .map(|&idx| model_weights[idx].clone())
608 .collect();
609
610 let mut soup = Self::uniform_soup(selected_weights)?;
611 soup.recipe = SoupRecipe::Greedy;
612 soup.num_models = soup_indices.len();
613
614 Ok(soup)
615 }
616
617 pub fn weighted_soup(
626 model_weights: Vec<HashMap<String, Array2<f64>>>,
627 weights: Vec<f64>,
628 ) -> TrainResult<Self> {
629 if model_weights.is_empty() {
630 return Err(TrainError::InvalidParameter(
631 "At least one model required for soup".to_string(),
632 ));
633 }
634
635 if model_weights.len() != weights.len() {
636 return Err(TrainError::InvalidParameter(
637 "Number of models must match number of weights".to_string(),
638 ));
639 }
640
641 let sum: f64 = weights.iter().sum();
643 if sum <= 0.0 {
644 return Err(TrainError::InvalidParameter(
645 "Weights must sum to positive value".to_string(),
646 ));
647 }
648
649 let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
650
651 let num_models = model_weights.len();
653 let mut averaged_weights = HashMap::new();
654 let param_names: Vec<String> = model_weights[0].keys().cloned().collect();
655
656 for param_name in param_names {
657 let shape = model_weights[0][¶m_name].raw_dim();
658 let mut averaged_param = Array2::zeros(shape);
659
660 for (model_idx, model_weight) in model_weights.iter().enumerate() {
661 if let Some(param) = model_weight.get(¶m_name) {
662 averaged_param = averaged_param + param * normalized_weights[model_idx];
663 } else {
664 return Err(TrainError::InvalidParameter(format!(
665 "Parameter '{}' not found in all models",
666 param_name
667 )));
668 }
669 }
670
671 averaged_weights.insert(param_name, averaged_param);
672 }
673
674 Ok(Self {
675 weights: averaged_weights,
676 num_models,
677 recipe: SoupRecipe::Weighted,
678 })
679 }
680
681 pub fn weights(&self) -> &HashMap<String, Array2<f64>> {
683 &self.weights
684 }
685
686 pub fn num_models(&self) -> usize {
688 self.num_models
689 }
690
691 pub fn recipe(&self) -> SoupRecipe {
693 self.recipe
694 }
695
696 pub fn get_parameter(&self, name: &str) -> Option<&Array2<f64>> {
698 self.weights.get(name)
699 }
700
701 pub fn into_weights(self) -> HashMap<String, Array2<f64>> {
705 self.weights
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearModel;
    use scirs2_core::ndarray::array;

    // All ensemble tests share a tiny 2-input / 2-output linear model.
    fn create_test_model() -> LinearModel {
        LinearModel::new(2, 2)
    }

    #[test]
    fn test_voting_ensemble_hard() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).expect("unwrap");

        assert_eq!(ensemble.num_models(), 2);
        assert_eq!(ensemble.mode(), VotingMode::Hard);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        // One row per sample, one column per class.
        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_voting_ensemble_soft() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).expect("unwrap");

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        // Weights for voting must sum to 1.0.
        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft)
            .expect("unwrap")
            .with_weights(vec![0.7, 0.3])
            .expect("unwrap");

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_invalid_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).expect("unwrap");

        // Wrong weight count (1 weight for 2 models).
        let result = ensemble.with_weights(vec![0.5]);
        assert!(result.is_err());

        // Correct count but the sum (1.1) is not 1.0.
        let model3 = create_test_model();
        let model4 = create_test_model();
        let ensemble2 =
            VotingEnsemble::new(vec![model3, model4], VotingMode::Soft).expect("unwrap");
        let result = ensemble2.with_weights(vec![0.5, 0.6]);
        assert!(result.is_err());
    }

    #[test]
    fn test_averaging_ensemble() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2]).expect("unwrap");

        assert_eq!(ensemble.num_models(), 2);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_averaging_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        // Averaging weights are normalized internally, so 2:1 is valid.
        let ensemble = AveragingEnsemble::new(vec![model1, model2])
            .expect("unwrap")
            .with_weights(vec![2.0, 1.0])
            .expect("unwrap");

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_stacking_ensemble() {
        let base1 = create_test_model();
        let base2 = create_test_model();
        // Meta model consumes 2 base models x 2 outputs = 4 features.
        let meta = LinearModel::new(4, 2);
        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).expect("unwrap");

        // Two base models plus the meta model.
        assert_eq!(ensemble.num_models(), 3);
        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).expect("unwrap");

        assert_eq!(pred.nrows(), 1);
    }

    #[test]
    fn test_stacking_meta_features() {
        let base1 = create_test_model();
        let base2 = create_test_model();
        let meta = create_test_model();

        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).expect("unwrap");

        let input = array![[1.0, 0.0]];
        let meta_features = ensemble.generate_meta_features(&input).expect("unwrap");

        // 2 base models x 2 outputs each = 4 meta features per sample.
        assert_eq!(meta_features.shape(), &[1, 4]);
    }

    #[test]
    fn test_bagging_helper() {
        let helper = BaggingHelper::new(10, 42).expect("unwrap");

        let indices = helper.generate_bootstrap_indices(100, 0);
        assert_eq!(indices.len(), 100);

        // All drawn indices must be in range.
        assert!(indices.iter().all(|&i| i < 100));

        // With 100 draws-with-replacement, some samples are left out.
        let oob = helper.get_oob_indices(100, &indices);
        assert!(!oob.is_empty());

        // OOB indices are exactly those never drawn.
        for &idx in &oob {
            assert!(!indices.contains(&idx));
        }
    }

    #[test]
    fn test_bagging_helper_different_seeds() {
        let helper = BaggingHelper::new(10, 42).expect("unwrap");

        let indices1 = helper.generate_bootstrap_indices(50, 0);
        let indices2 = helper.generate_bootstrap_indices(50, 1);

        // Different estimator indices derive different RNG seeds.
        assert_ne!(indices1, indices2);
    }

    #[test]
    fn test_bagging_helper_invalid() {
        assert!(BaggingHelper::new(0, 42).is_err());
    }

    #[test]
    fn test_ensemble_empty_models() {
        let result = VotingEnsemble::<LinearModel>::new(vec![], VotingMode::Hard);
        assert!(result.is_err());

        let result = AveragingEnsemble::<LinearModel>::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_uniform_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0, 2.0]]);
        weights1.insert("b".to_string(), array![[0.5]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0, 4.0]]);
        weights2.insert("b".to_string(), array![[1.5]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");

        assert_eq!(soup.num_models(), 2);
        assert_eq!(soup.recipe(), SoupRecipe::Uniform);

        let w = soup.get_parameter("w").expect("unwrap");
        assert_eq!(w[[0, 0]], 2.0); // (1 + 3) / 2
        assert_eq!(w[[0, 1]], 3.0); // (2 + 4) / 2
        let b = soup.get_parameter("b").expect("unwrap");
        assert_eq!(b[[0, 0]], 1.0); // (0.5 + 1.5) / 2
    }

    #[test]
    fn test_uniform_soup_three_models() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let mut weights3 = HashMap::new();
        weights3.insert("w".to_string(), array![[3.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2, weights3]).expect("unwrap");

        let w = soup.get_parameter("w").expect("unwrap");
        assert_eq!(w[[0, 0]], 2.0); // (1 + 2 + 3) / 3
    }

    #[test]
    fn test_greedy_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let mut weights3 = HashMap::new();
        weights3.insert("w".to_string(), array![[3.0]]);

        // Model at index 1 has the best validation accuracy.
        let accuracies = vec![0.8, 0.9, 0.85];
        let soup =
            ModelSoup::greedy_soup(vec![weights1, weights2, weights3], accuracies).expect("unwrap");

        assert_eq!(soup.recipe(), SoupRecipe::Greedy);
        assert!(soup.num_models() >= 1); // at least the seed model
    }

    #[test]
    fn test_weighted_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0, 4.0]]);

        // Weights 2:1 are normalized to 2/3 and 1/3.
        let soup =
            ModelSoup::weighted_soup(vec![weights1, weights2], vec![2.0, 1.0]).expect("unwrap");

        assert_eq!(soup.recipe(), SoupRecipe::Weighted);

        let w = soup.get_parameter("w").expect("unwrap");
        // 1*2/3 + 3*1/3 = 5/3; 2*2/3 + 4*1/3 = 8/3.
        assert!((w[[0, 0]] - 1.6666666).abs() < 1e-5);
        assert!((w[[0, 1]] - 2.6666666).abs() < 1e-5);
    }

    #[test]
    fn test_soup_empty_models() {
        let result = ModelSoup::uniform_soup(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_soup_mismatched_parameters() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        // Different parameter name than model 1 — must be rejected.
        weights2.insert("b".to_string(), array![[2.0]]);
        let result = ModelSoup::uniform_soup(vec![weights1, weights2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_greedy_soup_mismatched_lengths() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        // One model but two accuracies — must be rejected.
        let result = ModelSoup::greedy_soup(vec![weights1], vec![0.8, 0.9]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_soup_invalid_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        // Weight sum is zero — not positive.
        let result =
            ModelSoup::weighted_soup(vec![weights1.clone(), weights2.clone()], vec![-1.0, 1.0]);
        assert!(result.is_err());

        // Weight count does not match the model count.
        let result = ModelSoup::weighted_soup(vec![weights1], vec![1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_soup_into_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
        let final_weights = soup.into_weights();

        assert_eq!(final_weights["w"][[0, 0]], 2.0);
    }

    #[test]
    fn test_soup_multidimensional_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("conv".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("conv".to_string(), array![[5.0, 6.0], [7.0, 8.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).expect("unwrap");
        let conv = soup.get_parameter("conv").expect("unwrap");

        assert_eq!(conv[[0, 0]], 3.0); // (1 + 5) / 2
        assert_eq!(conv[[0, 1]], 4.0); // (2 + 6) / 2
        assert_eq!(conv[[1, 0]], 5.0); // (3 + 7) / 2
        assert_eq!(conv[[1, 1]], 6.0); // (4 + 8) / 2
    }
}