use crate::{Model, TrainError, TrainResult};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

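/// Shared interface for ensembles that combine several trained models into a
/// single predictor.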
pub trait Ensemble {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>>;

    fn num_models(&self) -> usize;
}

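/// How a [`VotingEnsemble`] combines member predictions: `Hard` gives each
/// model one (optionally weighted) vote for its argmax class, while `Soft`
/// averages the predicted class probabilities.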
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VotingMode {
    Hard,
    Soft,
}

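/// Voting ensemble over classifiers, supporting hard (majority) and soft
/// (probability-averaging) voting with optional per-model weights that must
/// sum to 1.0.
///
/// A minimal usage sketch (illustrative only; `model_a`, `model_b`, and
/// `inputs` are placeholders for any [`Model`] implementations and an input
/// batch):
///
/// ```ignore
/// let ensemble = VotingEnsemble::new(vec![model_a, model_b], VotingMode::Soft)?
///     .with_weights(vec![0.7, 0.3])?;
/// let probs = ensemble.predict(&inputs)?;
/// ```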
#[derive(Debug)]
pub struct VotingEnsemble<M: Model> {
    models: Vec<M>,
    mode: VotingMode,
    weights: Option<Vec<f64>>,
}

impl<M: Model> VotingEnsemble<M> {
    pub fn new(models: Vec<M>, mode: VotingMode) -> TrainResult<Self> {
        if models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one model".to_string(),
            ));
        }
        Ok(Self {
            models,
            mode,
            weights: None,
        })
    }

    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
        if weights.len() != self.models.len() {
            return Err(TrainError::InvalidParameter(
                "Number of weights must match number of models".to_string(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if (sum - 1.0).abs() > 1e-6 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to 1.0".to_string(),
            ));
        }

        self.weights = Some(weights);
        Ok(self)
    }

    pub fn mode(&self) -> VotingMode {
        self.mode
    }
}

impl<M: Model> Ensemble for VotingEnsemble<M> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        let num_classes = all_predictions[0].ncols();
        let mut ensemble_pred = Array2::zeros((batch_size, num_classes));

        match self.mode {
            VotingMode::Hard => {
                for i in 0..batch_size {
                    let mut votes = vec![0.0; num_classes];

                    for (model_idx, pred) in all_predictions.iter().enumerate() {
                        let row = pred.row(i);
                        let class_idx = row
                            .iter()
                            .enumerate()
                            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                            .map(|(idx, _)| idx)
                            .unwrap_or(0);

                        let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                        votes[class_idx] += weight;
                    }

                    let max_votes = votes.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let winning_class = votes
                        .iter()
                        .position(|&v| (v - max_votes).abs() < 1e-10)
                        .unwrap();

                    ensemble_pred[[i, winning_class]] = 1.0;
                }
            }
            VotingMode::Soft => {
                for i in 0..batch_size {
                    for j in 0..num_classes {
                        let mut weighted_sum = 0.0;

                        for (model_idx, pred) in all_predictions.iter().enumerate() {
                            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);
                            weighted_sum += pred[[i, j]] * weight;
                        }

                        let normalizer = if self.weights.is_some() {
                            1.0
                        } else {
                            self.models.len() as f64
                        };

                        ensemble_pred[[i, j]] = weighted_sum / normalizer;
                    }
                }
            }
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}

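/// Ensemble that averages raw model outputs (e.g. regression values or
/// probabilities). Weights supplied via `with_weights` are normalized to sum
/// to 1.0; without weights, a plain mean over the models is used.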
#[derive(Debug)]
pub struct AveragingEnsemble<M: Model> {
    models: Vec<M>,
    weights: Option<Vec<f64>>,
}

impl<M: Model> AveragingEnsemble<M> {
    pub fn new(models: Vec<M>) -> TrainResult<Self> {
        if models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one model".to_string(),
            ));
        }
        Ok(Self {
            models,
            weights: None,
        })
    }

    pub fn with_weights(mut self, weights: Vec<f64>) -> TrainResult<Self> {
        if weights.len() != self.models.len() {
            return Err(TrainError::InvalidParameter(
                "Number of weights must match number of models".to_string(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to a positive value".to_string(),
            ));
        }

        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();
        self.weights = Some(normalized_weights);
        Ok(self)
    }
}

impl<M: Model> Ensemble for AveragingEnsemble<M> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let mut all_predictions = Vec::with_capacity(self.models.len());
        for model in &self.models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        let shape = all_predictions[0].raw_dim();
        let mut ensemble_pred = Array2::zeros(shape);

        for (model_idx, pred) in all_predictions.iter().enumerate() {
            let weight = self.weights.as_ref().map(|w| w[model_idx]).unwrap_or(1.0);

            for i in 0..pred.nrows() {
                for j in 0..pred.ncols() {
                    ensemble_pred[[i, j]] += pred[[i, j]] * weight;
                }
            }
        }

        if self.weights.is_none() {
            ensemble_pred /= self.models.len() as f64;
        }

        Ok(ensemble_pred)
    }

    fn num_models(&self) -> usize {
        self.models.len()
    }
}

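/// Two-level (stacking) ensemble: the base models' predictions are
/// concatenated into meta-features, which the meta-model maps to the final
/// prediction.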
#[derive(Debug)]
pub struct StackingEnsemble<M: Model, Meta: Model> {
    base_models: Vec<M>,
    meta_model: Meta,
}

impl<M: Model, Meta: Model> StackingEnsemble<M, Meta> {
    pub fn new(base_models: Vec<M>, meta_model: Meta) -> TrainResult<Self> {
        if base_models.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Ensemble must have at least one base model".to_string(),
            ));
        }
        Ok(Self {
            base_models,
            meta_model,
        })
    }

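    /// Runs every base model on `input` and concatenates their outputs
    /// column-wise into a `(batch_size, num_base_models * output_dim)` matrix.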
    pub fn generate_meta_features(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let batch_size = input.nrows();

        let mut all_predictions = Vec::with_capacity(self.base_models.len());
        for model in &self.base_models {
            let pred = model.forward(&input.view())?;
            all_predictions.push(pred);
        }

        let num_features_per_model = all_predictions[0].ncols();
        let total_features = self.base_models.len() * num_features_per_model;

        let mut meta_features = Array2::zeros((batch_size, total_features));

        for (model_idx, pred) in all_predictions.iter().enumerate() {
            let start_col = model_idx * num_features_per_model;

            for i in 0..batch_size {
                for j in 0..num_features_per_model {
                    meta_features[[i, start_col + j]] = pred[[i, j]];
                }
            }
        }

        Ok(meta_features)
    }
}

impl<M: Model, Meta: Model> Ensemble for StackingEnsemble<M, Meta> {
    fn predict(&self, input: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let meta_features = self.generate_meta_features(input)?;

        self.meta_model.forward(&meta_features.view())
    }

    fn num_models(&self) -> usize {
        self.base_models.len() + 1
    }
}

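/// Helper for bagging (bootstrap aggregating): produces deterministic
/// bootstrap sample indices per estimator and the matching out-of-bag (OOB)
/// indices.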
#[derive(Debug)]
pub struct BaggingHelper {
    pub n_estimators: usize,
    pub random_seed: u64,
}

impl BaggingHelper {
    pub fn new(n_estimators: usize, random_seed: u64) -> TrainResult<Self> {
        if n_estimators == 0 {
            return Err(TrainError::InvalidParameter(
                "n_estimators must be positive".to_string(),
            ));
        }
        Ok(Self {
            n_estimators,
            random_seed,
        })
    }

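    /// Draws `n_samples` indices with replacement from `0..n_samples`, seeded
    /// deterministically from `random_seed` plus `estimator_idx`.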
    pub fn generate_bootstrap_indices(&self, n_samples: usize, estimator_idx: usize) -> Vec<usize> {
        #[allow(unused_imports)]
        use scirs2_core::random::{Rng, SeedableRng, StdRng};

        let seed = self.random_seed.wrapping_add(estimator_idx as u64);
        let mut rng = StdRng::seed_from_u64(seed);

        (0..n_samples)
            .map(|_| rng.gen_range(0..n_samples))
            .collect()
    }

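    /// Returns the indices in `0..n_samples` that never appear in
    /// `bootstrap_indices`, i.e. the out-of-bag samples for one estimator.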
    pub fn get_oob_indices(&self, n_samples: usize, bootstrap_indices: &[usize]) -> Vec<usize> {
        let bootstrap_set: std::collections::HashSet<usize> =
            bootstrap_indices.iter().cloned().collect();

        (0..n_samples)
            .filter(|idx| !bootstrap_set.contains(idx))
            .collect()
    }
}

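/// Model soup: a single parameter set obtained by averaging the weights of
/// several trained models (uniformly, greedily, or with explicit weights)
/// instead of ensembling their predictions at inference time.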
#[derive(Debug, Clone)]
pub struct ModelSoup {
    weights: HashMap<String, Array2<f64>>,
    num_models: usize,
    recipe: SoupRecipe,
}

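/// Recipe used to build a [`ModelSoup`].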
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SoupRecipe {
    Uniform,
    Greedy,
    Weighted,
}

impl ModelSoup {
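    /// Averages every parameter uniformly across the supplied models. All
    /// models must provide the same parameter names and shapes.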
    pub fn uniform_soup(model_weights: Vec<HashMap<String, Array2<f64>>>) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        let num_models = model_weights.len();
        let mut averaged_weights = HashMap::new();

        let param_names: Vec<String> = model_weights[0].keys().cloned().collect();

        for param_name in param_names {
            let shape = model_weights[0][&param_name].raw_dim();
            let mut averaged_param = Array2::zeros(shape);

            for model_weight in &model_weights {
                if let Some(param) = model_weight.get(&param_name) {
                    averaged_param += param;
                } else {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' not found in all models",
                        param_name
                    )));
                }
            }

            averaged_param /= num_models as f64;
            averaged_weights.insert(param_name, averaged_param);
        }

        Ok(Self {
            weights: averaged_weights,
            num_models,
            recipe: SoupRecipe::Uniform,
        })
    }

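    /// Greedy soup: starts from the model with the highest validation
    /// accuracy and keeps adding models while the estimated soup accuracy
    /// improves. The estimate is a simple heuristic (the mean of the current
    /// soup accuracy and the candidate's accuracy) rather than a fresh
    /// evaluation on validation data.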
    pub fn greedy_soup(
        model_weights: Vec<HashMap<String, Array2<f64>>>,
        val_accuracies: Vec<f64>,
    ) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        if model_weights.len() != val_accuracies.len() {
            return Err(TrainError::InvalidParameter(
                "Number of models must match number of validation accuracies".to_string(),
            ));
        }

        let best_idx = val_accuracies
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();

        let mut soup_indices = vec![best_idx];
        let mut best_accuracy = val_accuracies[best_idx];

        loop {
            let mut improved = false;
            let mut best_addition = None;
            let mut best_new_accuracy = best_accuracy;

            for (idx, acc) in val_accuracies.iter().enumerate() {
                if soup_indices.contains(&idx) {
                    continue;
                }

                let potential_accuracy = (*acc + best_accuracy) / 2.0;

                if potential_accuracy > best_new_accuracy {
                    best_new_accuracy = potential_accuracy;
                    best_addition = Some(idx);
                    improved = true;
                }
            }

            if improved {
                if let Some(idx) = best_addition {
                    soup_indices.push(idx);
                    best_accuracy = best_new_accuracy;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        let selected_weights: Vec<_> = soup_indices
            .iter()
            .map(|&idx| model_weights[idx].clone())
            .collect();

        let mut soup = Self::uniform_soup(selected_weights)?;
        soup.recipe = SoupRecipe::Greedy;
        soup.num_models = soup_indices.len();

        Ok(soup)
    }

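    /// Averages parameters with the supplied per-model weights, which are
    /// normalized to sum to 1.0; the raw weights must sum to a positive
    /// value.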
    pub fn weighted_soup(
        model_weights: Vec<HashMap<String, Array2<f64>>>,
        weights: Vec<f64>,
    ) -> TrainResult<Self> {
        if model_weights.is_empty() {
            return Err(TrainError::InvalidParameter(
                "At least one model required for soup".to_string(),
            ));
        }

        if model_weights.len() != weights.len() {
            return Err(TrainError::InvalidParameter(
                "Number of models must match number of weights".to_string(),
            ));
        }

        let sum: f64 = weights.iter().sum();
        if sum <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Weights must sum to a positive value".to_string(),
            ));
        }

        let normalized_weights: Vec<f64> = weights.iter().map(|w| w / sum).collect();

        let num_models = model_weights.len();
        let mut averaged_weights = HashMap::new();
        let param_names: Vec<String> = model_weights[0].keys().cloned().collect();

        for param_name in param_names {
            let shape = model_weights[0][&param_name].raw_dim();
            let mut averaged_param = Array2::zeros(shape);

            for (model_idx, model_weight) in model_weights.iter().enumerate() {
                if let Some(param) = model_weight.get(&param_name) {
                    averaged_param = averaged_param + param * normalized_weights[model_idx];
                } else {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' not found in all models",
                        param_name
                    )));
                }
            }

            averaged_weights.insert(param_name, averaged_param);
        }

        Ok(Self {
            weights: averaged_weights,
            num_models,
            recipe: SoupRecipe::Weighted,
        })
    }

    pub fn weights(&self) -> &HashMap<String, Array2<f64>> {
        &self.weights
    }

    pub fn num_models(&self) -> usize {
        self.num_models
    }

    pub fn recipe(&self) -> SoupRecipe {
        self.recipe
    }

    pub fn get_parameter(&self, name: &str) -> Option<&Array2<f64>> {
        self.weights.get(name)
    }

    pub fn into_weights(self) -> HashMap<String, Array2<f64>> {
        self.weights
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearModel;
    use scirs2_core::ndarray::array;

    fn create_test_model() -> LinearModel {
        LinearModel::new(2, 2)
    }

    #[test]
    fn test_voting_ensemble_hard() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Hard).unwrap();

        assert_eq!(ensemble.num_models(), 2);
        assert_eq!(ensemble.mode(), VotingMode::Hard);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_voting_ensemble_soft() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).unwrap();

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft)
            .unwrap()
            .with_weights(vec![0.7, 0.3])
            .unwrap();

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_voting_ensemble_invalid_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = VotingEnsemble::new(vec![model1, model2], VotingMode::Soft).unwrap();

        let result = ensemble.with_weights(vec![0.5]);
        assert!(result.is_err());

        let model3 = create_test_model();
        let model4 = create_test_model();
        let ensemble2 = VotingEnsemble::new(vec![model3, model4], VotingMode::Soft).unwrap();
        let result = ensemble2.with_weights(vec![0.5, 0.6]);
        assert!(result.is_err());
    }

    #[test]
    fn test_averaging_ensemble() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2]).unwrap();

        assert_eq!(ensemble.num_models(), 2);

        let input = array![[1.0, 0.0], [0.0, 1.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[2, 2]);
    }

    #[test]
    fn test_averaging_ensemble_with_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = AveragingEnsemble::new(vec![model1, model2])
            .unwrap()
            .with_weights(vec![2.0, 1.0])
            .unwrap();

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.shape(), &[1, 2]);
    }

    #[test]
    fn test_stacking_ensemble() {
        let base1 = create_test_model();
        let base2 = create_test_model();
        let meta = LinearModel::new(4, 2);
        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).unwrap();

        assert_eq!(ensemble.num_models(), 3);

        let input = array![[1.0, 0.0]];
        let pred = ensemble.predict(&input).unwrap();

        assert_eq!(pred.nrows(), 1);
    }

    #[test]
    fn test_stacking_meta_features() {
        let base1 = create_test_model();
        let base2 = create_test_model();
        let meta = create_test_model();

        let ensemble = StackingEnsemble::new(vec![base1, base2], meta).unwrap();

        let input = array![[1.0, 0.0]];
        let meta_features = ensemble.generate_meta_features(&input).unwrap();

        assert_eq!(meta_features.shape(), &[1, 4]);
    }

    #[test]
    fn test_bagging_helper() {
        let helper = BaggingHelper::new(10, 42).unwrap();

        let indices = helper.generate_bootstrap_indices(100, 0);
        assert_eq!(indices.len(), 100);

        assert!(indices.iter().all(|&i| i < 100));

        let oob = helper.get_oob_indices(100, &indices);
        assert!(!oob.is_empty());

        for &idx in &oob {
            assert!(!indices.contains(&idx));
        }
    }

    #[test]
    fn test_bagging_helper_different_seeds() {
        let helper = BaggingHelper::new(10, 42).unwrap();

        let indices1 = helper.generate_bootstrap_indices(50, 0);
        let indices2 = helper.generate_bootstrap_indices(50, 1);

        assert_ne!(indices1, indices2);
    }

    #[test]
    fn test_bagging_helper_invalid() {
        assert!(BaggingHelper::new(0, 42).is_err());
    }

    #[test]
    fn test_ensemble_empty_models() {
        let result = VotingEnsemble::<LinearModel>::new(vec![], VotingMode::Hard);
        assert!(result.is_err());

        let result = AveragingEnsemble::<LinearModel>::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_uniform_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0, 2.0]]);
        weights1.insert("b".to_string(), array![[0.5]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0, 4.0]]);
        weights2.insert("b".to_string(), array![[1.5]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();

        assert_eq!(soup.num_models(), 2);
        assert_eq!(soup.recipe(), SoupRecipe::Uniform);

        let w = soup.get_parameter("w").unwrap();
        assert_eq!(w[[0, 0]], 2.0);
        assert_eq!(w[[0, 1]], 3.0);

        let b = soup.get_parameter("b").unwrap();
        assert_eq!(b[[0, 0]], 1.0);
    }

    #[test]
    fn test_uniform_soup_three_models() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let mut weights3 = HashMap::new();
        weights3.insert("w".to_string(), array![[3.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2, weights3]).unwrap();

        let w = soup.get_parameter("w").unwrap();
        assert_eq!(w[[0, 0]], 2.0);
    }

    #[test]
    fn test_greedy_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let mut weights3 = HashMap::new();
        weights3.insert("w".to_string(), array![[3.0]]);

        let accuracies = vec![0.8, 0.9, 0.85];
        let soup = ModelSoup::greedy_soup(vec![weights1, weights2, weights3], accuracies).unwrap();

        assert_eq!(soup.recipe(), SoupRecipe::Greedy);
        assert!(soup.num_models() >= 1);
    }

    #[test]
    fn test_weighted_soup() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0, 4.0]]);

        let soup = ModelSoup::weighted_soup(vec![weights1, weights2], vec![2.0, 1.0]).unwrap();

        assert_eq!(soup.recipe(), SoupRecipe::Weighted);

        let w = soup.get_parameter("w").unwrap();
        assert!((w[[0, 0]] - 1.6666666).abs() < 1e-5);
        assert!((w[[0, 1]] - 2.6666666).abs() < 1e-5);
    }

    #[test]
    fn test_soup_empty_models() {
        let result = ModelSoup::uniform_soup(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_soup_mismatched_parameters() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("b".to_string(), array![[2.0]]);

        let result = ModelSoup::uniform_soup(vec![weights1, weights2]);
        assert!(result.is_err());
    }

    #[test]
    fn test_greedy_soup_mismatched_lengths() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let result = ModelSoup::greedy_soup(vec![weights1], vec![0.8, 0.9]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_soup_invalid_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[2.0]]);

        let result =
            ModelSoup::weighted_soup(vec![weights1.clone(), weights2.clone()], vec![-1.0, 1.0]);
        assert!(result.is_err());

        let result = ModelSoup::weighted_soup(vec![weights1], vec![1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_soup_into_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("w".to_string(), array![[1.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("w".to_string(), array![[3.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
        let final_weights = soup.into_weights();

        assert_eq!(final_weights["w"][[0, 0]], 2.0);
    }

    #[test]
    fn test_soup_multidimensional_weights() {
        let mut weights1 = HashMap::new();
        weights1.insert("conv".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut weights2 = HashMap::new();
        weights2.insert("conv".to_string(), array![[5.0, 6.0], [7.0, 8.0]]);

        let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
        let conv = soup.get_parameter("conv").unwrap();

        assert_eq!(conv[[0, 0]], 3.0);
        assert_eq!(conv[[0, 1]], 4.0);
        assert_eq!(conv[[1, 0]], 5.0);
        assert_eq!(conv[[1, 1]], 6.0);
    }
}