1use crate::error::{Result, SklearsError};
47use crate::traits::{Estimator, Fit, Predict, PredictProba, Score, Transform};
48use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
50use scirs2_core::random::Random;
51
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use std::sync::{Arc, Mutex};
55use std::time::{Duration, Instant};
56
57#[derive(Debug, Clone)]
59pub struct MockEstimator {
60 config: MockConfig,
61 state: Arc<Mutex<MockState>>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct MockConfig {
67 pub behavior: MockBehavior,
69 pub fit_delay: Duration,
71 pub predict_delay: Duration,
73 pub fit_failure_probability: f64,
75 pub predict_failure_probability: f64,
77 pub max_fit_calls: Option<usize>,
79 pub random_seed: u64,
81}
82
83impl Default for MockConfig {
84 fn default() -> Self {
85 Self {
86 behavior: MockBehavior::ConstantPrediction(0.0),
87 fit_delay: Duration::from_millis(0),
88 predict_delay: Duration::from_millis(0),
89 fit_failure_probability: 0.0,
90 predict_failure_probability: 0.0,
91 max_fit_calls: None,
92 random_seed: 42,
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99pub enum MockBehavior {
100 ConstantPrediction(f64),
102 FeatureSum,
104 Random { min: f64, max: f64 },
106 LinearModel { weights: Vec<f64>, bias: f64 },
108 Sequence(Vec<f64>),
110 MirrorTargets,
112 MajorityClass,
114 Overfitting {
116 train_accuracy: f64,
117 test_accuracy: f64,
118 },
119}
120
121#[derive(Debug, Default)]
123struct MockState {
124 fit_count: usize,
125 predict_count: usize,
126 last_fit_time: Option<Instant>,
127 last_predict_time: Option<Instant>,
128 training_targets: Option<Array1<f64>>,
129 fitted: bool,
130 fit_call_history: Vec<Instant>,
131 predict_call_history: Vec<Instant>,
132 performance_metrics: HashMap<String, f64>,
133}
134
135impl MockEstimator {
136 pub fn new() -> Self {
138 Self::with_config(MockConfig::default())
139 }
140
141 pub fn with_config(config: MockConfig) -> Self {
143 Self {
144 config,
145 state: Arc::new(Mutex::new(MockState::default())),
146 }
147 }
148
149 pub fn builder() -> MockEstimatorBuilder {
151 MockEstimatorBuilder::new()
152 }
153
154 pub fn config(&self) -> &MockConfig {
156 &self.config
157 }
158
159 pub fn mock_state(&self) -> MockStateSnapshot {
161 let state = self.state.lock().unwrap_or_else(|e| e.into_inner());
162 MockStateSnapshot {
163 fit_count: state.fit_count,
164 predict_count: state.predict_count,
165 fitted: state.fitted,
166 fit_call_history: state.fit_call_history.clone(),
167 predict_call_history: state.predict_call_history.clone(),
168 performance_metrics: state.performance_metrics.clone(),
169 }
170 }
171
172 pub fn reset_state(&self) {
174 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
175 *state = MockState::default();
176 }
177
178 pub fn simulate_error(&self, error_type: MockErrorType) -> Result<()> {
180 match error_type {
181 MockErrorType::FitFailure => {
182 Err(SklearsError::FitError("Simulated fit failure".to_string()))
183 }
184 MockErrorType::PredictFailure => Err(SklearsError::PredictError(
185 "Simulated predict failure".to_string(),
186 )),
187 MockErrorType::InvalidInput => Err(SklearsError::InvalidInput(
188 "Simulated invalid input".to_string(),
189 )),
190 MockErrorType::NotFitted => Err(SklearsError::NotFitted {
191 operation: "predict".to_string(),
192 }),
193 }
194 }
195}
196
197impl Default for MockEstimator {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203#[derive(Debug)]
205pub struct MockEstimatorBuilder {
206 config: MockConfig,
207}
208
209impl MockEstimatorBuilder {
210 pub fn new() -> Self {
212 Self {
213 config: MockConfig::default(),
214 }
215 }
216
217 pub fn with_behavior(mut self, behavior: MockBehavior) -> Self {
219 self.config.behavior = behavior;
220 self
221 }
222
223 pub fn with_fit_delay(mut self, delay: Duration) -> Self {
225 self.config.fit_delay = delay;
226 self
227 }
228
229 pub fn with_predict_delay(mut self, delay: Duration) -> Self {
231 self.config.predict_delay = delay;
232 self
233 }
234
235 pub fn with_fit_failure_probability(mut self, probability: f64) -> Self {
237 self.config.fit_failure_probability = probability.clamp(0.0, 1.0);
238 self
239 }
240
241 pub fn with_predict_failure_probability(mut self, probability: f64) -> Self {
243 self.config.predict_failure_probability = probability.clamp(0.0, 1.0);
244 self
245 }
246
247 pub fn with_max_fit_calls(mut self, max_calls: usize) -> Self {
249 self.config.max_fit_calls = Some(max_calls);
250 self
251 }
252
253 pub fn with_random_seed(mut self, seed: u64) -> Self {
255 self.config.random_seed = seed;
256 self
257 }
258
259 pub fn build(self) -> MockEstimator {
261 MockEstimator::with_config(self.config)
262 }
263}
264
265impl Default for MockEstimatorBuilder {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
/// A point-in-time copy of a [`MockEstimator`]'s call-tracking state.
#[derive(Debug, Clone)]
pub struct MockStateSnapshot {
    /// Number of `fit` calls so far, failed calls included.
    pub fit_count: usize,
    /// Number of `predict` calls so far, failed calls included.
    pub predict_count: usize,
    /// Whether a `fit` call has succeeded.
    pub fitted: bool,
    /// Timestamps of all `fit` calls, oldest first.
    pub fit_call_history: Vec<Instant>,
    /// Timestamps of all `predict` calls, oldest first.
    pub predict_call_history: Vec<Instant>,
    /// Copy of the free-form metrics map.
    pub performance_metrics: HashMap<String, f64>,
}
281
/// Which error [`MockEstimator::simulate_error`] should produce.
#[derive(Debug, Clone, Copy)]
pub enum MockErrorType {
    /// Produces a `FitError`.
    FitFailure,
    /// Produces a `PredictError`.
    PredictFailure,
    /// Produces an `InvalidInput` error.
    InvalidInput,
    /// Produces a `NotFitted` error for the `predict` operation.
    NotFitted,
}
290
291#[derive(Debug, Clone)]
293pub struct TrainedMockEstimator {
294 estimator: MockEstimator,
295 training_data_shape: (usize, usize),
296}
297
298impl Estimator for MockEstimator {
299 type Config = MockConfig;
300 type Error = crate::error::SklearsError;
301 type Float = f64;
302
303 fn config(&self) -> &Self::Config {
304 &self.config
305 }
306}
307
308impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockEstimator {
309 type Fitted = TrainedMockEstimator;
310
311 fn fit(self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
312 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
313
314 state.fit_count += 1;
316 state.last_fit_time = Some(Instant::now());
317 state.fit_call_history.push(Instant::now());
318
319 if let Some(max_calls) = self.config.max_fit_calls {
321 if state.fit_count > max_calls {
322 return Err(SklearsError::FitError(format!(
323 "Maximum fit calls ({max_calls}) exceeded"
324 )));
325 }
326 }
327
328 if self.config.fit_failure_probability > 0.0 {
330 let mut rng = Random::seed(self.config.random_seed + state.fit_count as u64);
331 if rng.gen_range(0.0..1.0) < self.config.fit_failure_probability {
332 return Err(SklearsError::FitError(
333 "Simulated random fit failure".to_string(),
334 ));
335 }
336 }
337
338 if x.nrows() != y.len() {
340 return Err(SklearsError::ShapeMismatch {
341 expected: format!("({}, n_features)", y.len()),
342 actual: format!("({}, {})", x.nrows(), x.ncols()),
343 });
344 }
345
346 match self.config.behavior {
348 MockBehavior::MirrorTargets | MockBehavior::MajorityClass => {
349 state.training_targets = Some(y.to_owned());
350 }
351 _ => {}
352 }
353
354 if !self.config.fit_delay.is_zero() {
356 std::thread::sleep(self.config.fit_delay);
357 }
358
359 state.fitted = true;
360 drop(state); Ok(TrainedMockEstimator {
363 estimator: self.clone(),
364 training_data_shape: (x.nrows(), x.ncols()),
365 })
366 }
367}
368
369impl<'a> Predict<ArrayView2<'a, f64>, Array1<f64>> for TrainedMockEstimator {
370 fn predict(&self, x: &ArrayView2<'a, f64>) -> Result<Array1<f64>> {
371 let mut state = self
372 .estimator
373 .state
374 .lock()
375 .unwrap_or_else(|e| e.into_inner());
376
377 state.predict_count += 1;
379 state.last_predict_time = Some(Instant::now());
380 state.predict_call_history.push(Instant::now());
381
382 if self.estimator.config.predict_failure_probability > 0.0 {
384 let mut rng =
385 Random::seed(self.estimator.config.random_seed + state.predict_count as u64);
386 if rng.gen_range(0.0..1.0) < self.estimator.config.predict_failure_probability {
387 return Err(SklearsError::PredictError(
388 "Simulated random predict failure".to_string(),
389 ));
390 }
391 }
392
393 if x.ncols() != self.training_data_shape.1 {
395 return Err(SklearsError::FeatureMismatch {
396 expected: self.training_data_shape.1,
397 actual: x.ncols(),
398 });
399 }
400
401 if !self.estimator.config.predict_delay.is_zero() {
403 std::thread::sleep(self.estimator.config.predict_delay);
404 }
405
406 let predictions = match &self.estimator.config.behavior {
408 MockBehavior::ConstantPrediction(value) => Array1::from_elem(x.nrows(), *value),
409 MockBehavior::FeatureSum => {
410 Array1::from_iter(x.rows().into_iter().map(|row| row.sum()))
411 }
412 MockBehavior::Random { min, max } => {
413 let mut rng = Random::seed(self.estimator.config.random_seed);
414 Array1::from_iter((0..x.nrows()).map(|_| rng.gen_range(*min..*max)))
415 }
416 MockBehavior::LinearModel { weights, bias } => {
417 if weights.len() != x.ncols() {
418 return Err(SklearsError::InvalidInput(
419 "Weight dimension mismatch".to_string(),
420 ));
421 }
422 Array1::from_iter(x.rows().into_iter().map(|row| {
423 let dot_product: f64 = row.iter().zip(weights.iter()).map(|(x, w)| x * w).sum();
424 dot_product + bias
425 }))
426 }
427 MockBehavior::Sequence(values) => {
428 Array1::from_iter((0..x.nrows()).map(|i| values[i % values.len()]))
429 }
430 MockBehavior::MirrorTargets => {
431 if let Some(ref targets) = state.training_targets {
432 Array1::from_iter((0..x.nrows()).map(|i| targets[i % targets.len()]))
434 } else {
435 Array1::zeros(x.nrows())
436 }
437 }
438 MockBehavior::MajorityClass => {
439 if let Some(ref targets) = state.training_targets {
440 let mut counts = HashMap::new();
442 for &target in targets {
443 *counts.entry(target as i32).or_insert(0) += 1;
444 }
445 let majority_class = counts
446 .into_iter()
447 .max_by_key(|(_, count)| *count)
448 .map(|(class, _)| class as f64)
449 .unwrap_or(0.0);
450 Array1::from_elem(x.nrows(), majority_class)
451 } else {
452 Array1::zeros(x.nrows())
453 }
454 }
455 MockBehavior::Overfitting {
456 train_accuracy: _,
457 test_accuracy,
458 } => {
459 let mut rng = Random::seed(self.estimator.config.random_seed);
461 Array1::from_iter((0..x.nrows()).map(|_| {
462 if rng.gen_range(0.0..1.0) < *test_accuracy {
463 1.0 } else {
465 0.0 }
467 }))
468 }
469 };
470
471 Ok(predictions)
472 }
473}
474
475impl<'a> PredictProba<ArrayView2<'a, f64>, Array2<f64>> for TrainedMockEstimator {
476 fn predict_proba(&self, x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
477 let predictions = self.predict(x)?;
479 let mut probabilities = Array2::zeros((x.nrows(), 2));
480
481 for (i, &pred) in predictions.iter().enumerate() {
482 let prob_positive = (pred.tanh() + 1.0) / 2.0; probabilities[[i, 0]] = 1.0 - prob_positive;
484 probabilities[[i, 1]] = prob_positive;
485 }
486
487 Ok(probabilities)
488 }
489}
490
491impl<'a> Score<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for TrainedMockEstimator {
492 type Float = f64;
493 fn score(&self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<f64> {
494 let predictions = self.predict(x)?;
495
496 match &self.estimator.config.behavior {
498 MockBehavior::Overfitting {
499 train_accuracy,
500 test_accuracy: _,
501 } => {
502 Ok(*train_accuracy)
504 }
505 _ => {
506 let correct = predictions
508 .iter()
509 .zip(y.iter())
510 .map(|(pred, actual)| {
511 if (pred - actual).abs() < 0.5 {
512 1.0
513 } else {
514 0.0
515 }
516 })
517 .sum::<f64>();
518 Ok(correct / y.len() as f64)
519 }
520 }
521 }
522}
523
/// A configurable stand-in transformer for exercising preprocessing pipelines.
#[derive(Debug, Clone)]
pub struct MockTransformer {
    /// Which transformation to apply and how slowly.
    config: MockTransformConfig,
    /// Set by `fit`; `transform` refuses to run while this is `false`.
    fitted: bool,
    /// `(n_samples, n_features)` of the matrix seen by `fit`.
    input_shape: Option<(usize, usize)>,
}

/// Settings for a [`MockTransformer`].
#[derive(Debug, Clone)]
pub struct MockTransformConfig {
    /// The transformation `transform` performs.
    pub transform_type: MockTransformType,
    /// Requested output width; stored, but not consulted by `transform` in this file.
    pub output_features: Option<usize>,
    /// How long `transform` sleeps before doing any work.
    pub transform_delay: Duration,
}

/// The transformation performed by a fitted [`MockTransformer`].
#[derive(Debug, Clone)]
pub enum MockTransformType {
    /// Return the input unchanged.
    Identity,
    /// Multiply every element by the factor.
    Scale(f64),
    /// Add the offset to every element.
    Shift(f64),
    /// Keep the leading `ceil(n_features * keep_ratio)` columns
    /// (at least one when any exist, never more than all of them).
    FeatureReduction { keep_ratio: f64 },
    /// Tile the input horizontally `expansion_factor` times.
    FeatureExpansion { expansion_factor: usize },
    /// Subtract the mean of all elements and divide by their population standard
    /// deviation (global, not per column); the division is skipped when it is zero.
    Standardization,
}
556
557impl MockTransformer {
558 pub fn new(transform_type: MockTransformType) -> Self {
560 Self {
561 config: MockTransformConfig {
562 transform_type,
563 output_features: None,
564 transform_delay: Duration::from_millis(0),
565 },
566 fitted: false,
567 input_shape: None,
568 }
569 }
570
571 pub fn builder() -> MockTransformerBuilder {
573 MockTransformerBuilder::new()
574 }
575}
576
577#[derive(Debug)]
579pub struct MockTransformerBuilder {
580 transform_type: MockTransformType,
581 output_features: Option<usize>,
582 transform_delay: Duration,
583}
584
585impl MockTransformerBuilder {
586 pub fn new() -> Self {
587 Self {
588 transform_type: MockTransformType::Identity,
589 output_features: None,
590 transform_delay: Duration::from_millis(0),
591 }
592 }
593
594 pub fn with_transform_type(mut self, transform_type: MockTransformType) -> Self {
595 self.transform_type = transform_type;
596 self
597 }
598
599 pub fn with_output_features(mut self, features: usize) -> Self {
600 self.output_features = Some(features);
601 self
602 }
603
604 pub fn with_transform_delay(mut self, delay: Duration) -> Self {
605 self.transform_delay = delay;
606 self
607 }
608
609 pub fn build(self) -> MockTransformer {
610 MockTransformer {
611 config: MockTransformConfig {
612 transform_type: self.transform_type,
613 output_features: self.output_features,
614 transform_delay: self.transform_delay,
615 },
616 fitted: false,
617 input_shape: None,
618 }
619 }
620}
621
622impl Default for MockTransformerBuilder {
623 fn default() -> Self {
624 Self::new()
625 }
626}
627
628impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockTransformer {
629 type Fitted = MockTransformer;
630
631 fn fit(self, x: &ArrayView2<'a, f64>, _y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
632 let mut fitted = self.clone();
633 fitted.fitted = true;
634 fitted.input_shape = Some((x.nrows(), x.ncols()));
635 Ok(fitted)
636 }
637}
638
639impl<'a> Transform<ArrayView2<'a, f64>, Array2<f64>> for MockTransformer {
640 fn transform(&self, x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
641 if !self.fitted {
642 return Err(SklearsError::NotFitted {
643 operation: "transform".to_string(),
644 });
645 }
646
647 if !self.config.transform_delay.is_zero() {
649 std::thread::sleep(self.config.transform_delay);
650 }
651
652 match &self.config.transform_type {
653 MockTransformType::Identity => Ok(x.to_owned()),
654 MockTransformType::Scale(factor) => Ok(x * *factor),
655 MockTransformType::Shift(offset) => Ok(x + *offset),
656 MockTransformType::FeatureReduction { keep_ratio } => {
657 let keep_features = ((x.ncols() as f64) * keep_ratio).ceil() as usize;
658 let keep_features = keep_features.max(1).min(x.ncols());
659 Ok(x.slice(s![.., 0..keep_features]).to_owned())
660 }
661 MockTransformType::FeatureExpansion { expansion_factor } => {
662 let new_features = x.ncols() * expansion_factor;
663 let mut expanded = Array2::zeros((x.nrows(), new_features));
664
665 for i in 0..*expansion_factor {
667 let start_col = i * x.ncols();
668 let end_col = start_col + x.ncols();
669 expanded.slice_mut(s![.., start_col..end_col]).assign(x);
670 }
671 Ok(expanded)
672 }
673 MockTransformType::Standardization => {
674 let mean = x.mean().unwrap_or(0.0);
676 let std = x.std(0.0);
677 if std == 0.0 {
678 Ok(x - mean)
679 } else {
680 Ok((x - mean) / std)
681 }
682 }
683 }
684 }
685}
686
687#[derive(Debug)]
689#[allow(dead_code)]
690pub struct MockEnsemble {
691 estimators: Vec<MockEstimator>,
692 voting_strategy: VotingStrategy,
693 fitted: bool,
694}
695
/// Combination rule carried by a [`MockEnsemble`].
///
/// The meanings below are inferred from the variant names; no combining logic is
/// implemented alongside this type.
#[derive(Debug, Clone)]
pub enum VotingStrategy {
    /// Presumably: pick the most common member prediction.
    MajorityVote,
    /// Presumably: average the member predictions.
    AverageVote,
    /// Presumably: weighted average with one weight per member.
    WeightedVote(Vec<f64>),
}
703
704impl MockEnsemble {
705 pub fn new(estimators: Vec<MockEstimator>, voting_strategy: VotingStrategy) -> Self {
707 Self {
708 estimators,
709 voting_strategy,
710 fitted: false,
711 }
712 }
713
714 pub fn n_estimators(&self) -> usize {
716 self.estimators.len()
717 }
718
719 pub fn voting_strategy(&self) -> &VotingStrategy {
721 &self.voting_strategy
722 }
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Fits a clone of `mock` (sharing its state) and panics if fitting fails.
    fn fit_clone(
        mock: &MockEstimator,
        features: &Array2<f64>,
        targets: &Array1<f64>,
    ) -> TrainedMockEstimator {
        mock.clone()
            .fit(&features.view(), &targets.view())
            .expect("model fitting should succeed")
    }

    /// Fits a clone of `transformer` on `data` and transforms that same data.
    fn fit_transform(transformer: &MockTransformer, data: &Array2<f64>) -> Array2<f64> {
        let targets = Array1::zeros(data.nrows());
        transformer
            .clone()
            .fit(&data.view(), &targets.view())
            .expect("expected valid value")
            .transform(&data.view())
            .expect("transform should succeed")
    }

    #[test]
    fn test_mock_estimator_constant_prediction() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(5.0))
            .build();
        let features = Array2::zeros((10, 3));
        let targets = Array1::zeros(10);

        let predictions = fit_clone(&mock, &features, &targets)
            .predict(&features.view())
            .expect("prediction should succeed");

        assert_eq!(predictions.len(), 10);
        assert!(predictions.iter().all(|&p| p == 5.0));
    }

    #[test]
    fn test_mock_estimator_state_tracking() {
        let mock = MockEstimator::new();
        let features = Array2::zeros((5, 2));
        let targets = Array1::zeros(5);

        let before = mock.mock_state();
        assert_eq!(before.fit_count, 0);
        assert_eq!(before.predict_count, 0);
        assert!(!before.fitted);

        // The fitted clone shares the Arc'd state, so `mock` observes its calls.
        let trained = fit_clone(&mock, &features, &targets);
        let after_fit = mock.mock_state();
        assert_eq!(after_fit.fit_count, 1);
        assert!(after_fit.fitted);

        let _ = trained
            .predict(&features.view())
            .expect("prediction should succeed");
        assert_eq!(mock.mock_state().predict_count, 1);
    }

    #[test]
    fn test_mock_estimator_feature_sum() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::FeatureSum)
            .build();
        let features = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid array shape");

        let predictions = fit_clone(&mock, &features, &Array1::zeros(2))
            .predict(&features.view())
            .expect("prediction should succeed");

        // Row sums: 1 + 2 + 3 and 4 + 5 + 6.
        assert_eq!(predictions[0], 6.0);
        assert_eq!(predictions[1], 15.0);
    }

    #[test]
    fn test_mock_estimator_linear_model() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::LinearModel {
                weights: vec![1.0, 2.0, 3.0],
                bias: 1.0,
            })
            .build();
        let features =
            Array2::from_shape_vec((1, 3), vec![1.0, 1.0, 1.0]).expect("valid array shape");

        let predictions = fit_clone(&mock, &features, &Array1::zeros(1))
            .predict(&features.view())
            .expect("prediction should succeed");

        // 1*1 + 2*1 + 3*1 + bias 1.
        assert_eq!(predictions[0], 7.0);
    }

    #[test]
    fn test_mock_transformer_identity() {
        let data =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid array shape");

        let transformed =
            fit_transform(&MockTransformer::new(MockTransformType::Identity), &data);

        assert_eq!(transformed, data);
    }

    #[test]
    fn test_mock_transformer_scale() {
        let data =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid array shape");
        let expected =
            Array2::from_shape_vec((2, 2), vec![2.0, 4.0, 6.0, 8.0]).expect("valid array shape");

        let transformed =
            fit_transform(&MockTransformer::new(MockTransformType::Scale(2.0)), &data);

        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_mock_estimator_failure_simulation() {
        // A probability of 1.0 makes every fit fail.
        let mock = MockEstimator::builder()
            .with_fit_failure_probability(1.0)
            .build();
        let features = Array2::zeros((5, 2));
        let targets = Array1::zeros(5);

        let outcome = mock.clone().fit(&features.view(), &targets.view());

        assert!(outcome.is_err());
    }

    #[test]
    fn test_mock_estimator_max_fit_calls() {
        let mock = MockEstimator::builder().with_max_fit_calls(2).build();
        let features = Array2::zeros((5, 2));
        let targets = Array1::zeros(5);

        // The first two fits are allowed; the third exceeds the cap.
        for attempt in 1..=3 {
            let outcome = mock.clone().fit(&features.view(), &targets.view());
            assert_eq!(outcome.is_ok(), attempt <= 2, "unexpected outcome on attempt {attempt}");
        }
    }

    #[test]
    fn test_mock_estimator_predict_proba() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(0.0))
            .build();
        let features = Array2::zeros((3, 2));
        let targets = Array1::zeros(3);

        let probabilities = fit_clone(&mock, &features, &targets)
            .predict_proba(&features.view())
            .expect("expected valid value");

        assert_eq!(probabilities.shape(), &[3, 2]);
        assert!(probabilities
            .rows()
            .into_iter()
            .all(|row| (row.sum() - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_mock_ensemble_creation() {
        let members: Vec<MockEstimator> = [1.0, 2.0]
            .into_iter()
            .map(|value| {
                MockEstimator::builder()
                    .with_behavior(MockBehavior::ConstantPrediction(value))
                    .build()
            })
            .collect();

        let ensemble = MockEnsemble::new(members, VotingStrategy::AverageVote);

        assert_eq!(ensemble.n_estimators(), 2);
        assert!(matches!(
            ensemble.voting_strategy(),
            VotingStrategy::AverageVote
        ));
    }
}