Skip to main content

sklears_core/
mock_objects.rs

//! Mock objects for testing complex machine learning interactions.
//!
//! This module provides configurable mock implementations of machine learning
//! components to enable sophisticated testing scenarios, particularly for:
//!
//! - Integration testing between multiple ML components
//! - Behavior verification in ensemble methods
//! - Error condition simulation and recovery testing
//! - Performance benchmarking with controlled behavior
//! - Pipeline testing with predictable components
//!
//! # Key Features
//!
//! - Configurable mock estimators with predictable behavior
//! - Controllable failure modes for error testing
//! - Artificial delays for benchmarking scenarios
//! - State tracking for behavior verification
//! - Builder pattern for easy mock configuration
//!
//! # Examples
//!
//! ```rust,no_run
//! use sklears_core::mock_objects::{MockEstimator, MockBehavior, MockConfig};
//! use sklears_core::traits::{Predict, Fit};
//! use scirs2_core::ndarray::{Array1, Array2};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a mock classifier that always predicts class 1
//! let mock = MockEstimator::builder()
//!     .with_behavior(MockBehavior::ConstantPrediction(1.0))
//!     .with_fit_delay(std::time::Duration::from_millis(10))
//!     .build();
//!
//! // Use it like any other estimator
//! let features = Array2::zeros((100, 10));
//! let targets = Array1::zeros(100);
//!
//! let trained = mock.fit(&features.view(), &targets.view())?;
//! let predictions = trained.predict(&features.view())?;
//!
//! // All predictions should be 1.0
//! assert!(predictions.iter().all(|&p| p == 1.0));
//! # Ok(())
//! # }
//! ```
46use crate::error::{Result, SklearsError};
47use crate::traits::{Estimator, Fit, Predict, PredictProba, Score, Transform};
48// SciRS2 Policy: Using scirs2_core::ndarray and scirs2_core::random (COMPLIANT)
49use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
50use scirs2_core::random::Random;
51
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use std::sync::{Arc, Mutex};
55use std::time::{Duration, Instant};
56
57/// Mock estimator with configurable behavior for testing
58#[derive(Debug, Clone)]
59pub struct MockEstimator {
60    config: MockConfig,
61    state: Arc<Mutex<MockState>>,
62}
63
64/// Configuration for mock estimator behavior
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct MockConfig {
67    /// Behavior pattern for predictions
68    pub behavior: MockBehavior,
69    /// Artificial delay during fit operation
70    pub fit_delay: Duration,
71    /// Artificial delay during predict operation
72    pub predict_delay: Duration,
73    /// Whether to simulate fit failures
74    pub fit_failure_probability: f64,
75    /// Whether to simulate predict failures
76    pub predict_failure_probability: f64,
77    /// Maximum number of fit calls before failure
78    pub max_fit_calls: Option<usize>,
79    /// Random seed for reproducible behavior
80    pub random_seed: u64,
81}
82
83impl Default for MockConfig {
84    fn default() -> Self {
85        Self {
86            behavior: MockBehavior::ConstantPrediction(0.0),
87            fit_delay: Duration::from_millis(0),
88            predict_delay: Duration::from_millis(0),
89            fit_failure_probability: 0.0,
90            predict_failure_probability: 0.0,
91            max_fit_calls: None,
92            random_seed: 42,
93        }
94    }
95}
96
97/// Different mock behavior patterns
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub enum MockBehavior {
100    /// Always return the same prediction value
101    ConstantPrediction(f64),
102    /// Return predictions based on feature sum
103    FeatureSum,
104    /// Return random predictions (using seed)
105    Random { min: f64, max: f64 },
106    /// Return predictions based on a simple linear model
107    LinearModel { weights: Vec<f64>, bias: f64 },
108    /// Return values from a predefined sequence
109    Sequence(Vec<f64>),
110    /// Mirror the target values during training
111    MirrorTargets,
112    /// Always predict the class with highest frequency in training
113    MajorityClass,
114    /// Simulate overfitting by perfect training accuracy, poor test accuracy
115    Overfitting {
116        train_accuracy: f64,
117        test_accuracy: f64,
118    },
119}
120
121/// Internal state tracking for mock estimator
122#[derive(Debug, Default)]
123struct MockState {
124    fit_count: usize,
125    predict_count: usize,
126    last_fit_time: Option<Instant>,
127    last_predict_time: Option<Instant>,
128    training_targets: Option<Array1<f64>>,
129    fitted: bool,
130    fit_call_history: Vec<Instant>,
131    predict_call_history: Vec<Instant>,
132    performance_metrics: HashMap<String, f64>,
133}
134
135impl MockEstimator {
136    /// Create a new mock estimator with default configuration
137    pub fn new() -> Self {
138        Self::with_config(MockConfig::default())
139    }
140
141    /// Create a mock estimator with custom configuration
142    pub fn with_config(config: MockConfig) -> Self {
143        Self {
144            config,
145            state: Arc::new(Mutex::new(MockState::default())),
146        }
147    }
148
149    /// Create a builder for configuring mock estimator
150    pub fn builder() -> MockEstimatorBuilder {
151        MockEstimatorBuilder::new()
152    }
153
154    /// Get the current configuration
155    pub fn config(&self) -> &MockConfig {
156        &self.config
157    }
158
159    /// Get mock state information for testing
160    pub fn mock_state(&self) -> MockStateSnapshot {
161        let state = self.state.lock().unwrap_or_else(|e| e.into_inner());
162        MockStateSnapshot {
163            fit_count: state.fit_count,
164            predict_count: state.predict_count,
165            fitted: state.fitted,
166            fit_call_history: state.fit_call_history.clone(),
167            predict_call_history: state.predict_call_history.clone(),
168            performance_metrics: state.performance_metrics.clone(),
169        }
170    }
171
172    /// Reset the mock state (useful for test setup)
173    pub fn reset_state(&self) {
174        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
175        *state = MockState::default();
176    }
177
178    /// Simulate a specific error condition
179    pub fn simulate_error(&self, error_type: MockErrorType) -> Result<()> {
180        match error_type {
181            MockErrorType::FitFailure => {
182                Err(SklearsError::FitError("Simulated fit failure".to_string()))
183            }
184            MockErrorType::PredictFailure => Err(SklearsError::PredictError(
185                "Simulated predict failure".to_string(),
186            )),
187            MockErrorType::InvalidInput => Err(SklearsError::InvalidInput(
188                "Simulated invalid input".to_string(),
189            )),
190            MockErrorType::NotFitted => Err(SklearsError::NotFitted {
191                operation: "predict".to_string(),
192            }),
193        }
194    }
195}
196
197impl Default for MockEstimator {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203/// Builder for configuring mock estimators
204#[derive(Debug)]
205pub struct MockEstimatorBuilder {
206    config: MockConfig,
207}
208
209impl MockEstimatorBuilder {
210    /// Create a new builder
211    pub fn new() -> Self {
212        Self {
213            config: MockConfig::default(),
214        }
215    }
216
217    /// Set the behavior pattern
218    pub fn with_behavior(mut self, behavior: MockBehavior) -> Self {
219        self.config.behavior = behavior;
220        self
221    }
222
223    /// Set the fit delay
224    pub fn with_fit_delay(mut self, delay: Duration) -> Self {
225        self.config.fit_delay = delay;
226        self
227    }
228
229    /// Set the predict delay
230    pub fn with_predict_delay(mut self, delay: Duration) -> Self {
231        self.config.predict_delay = delay;
232        self
233    }
234
235    /// Set fit failure probability
236    pub fn with_fit_failure_probability(mut self, probability: f64) -> Self {
237        self.config.fit_failure_probability = probability.clamp(0.0, 1.0);
238        self
239    }
240
241    /// Set predict failure probability
242    pub fn with_predict_failure_probability(mut self, probability: f64) -> Self {
243        self.config.predict_failure_probability = probability.clamp(0.0, 1.0);
244        self
245    }
246
247    /// Set maximum number of fit calls before failure
248    pub fn with_max_fit_calls(mut self, max_calls: usize) -> Self {
249        self.config.max_fit_calls = Some(max_calls);
250        self
251    }
252
253    /// Set random seed for reproducible behavior
254    pub fn with_random_seed(mut self, seed: u64) -> Self {
255        self.config.random_seed = seed;
256        self
257    }
258
259    /// Build the mock estimator
260    pub fn build(self) -> MockEstimator {
261        MockEstimator::with_config(self.config)
262    }
263}
264
265impl Default for MockEstimatorBuilder {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271/// Snapshot of mock state for testing
272#[derive(Debug, Clone)]
273pub struct MockStateSnapshot {
274    pub fit_count: usize,
275    pub predict_count: usize,
276    pub fitted: bool,
277    pub fit_call_history: Vec<Instant>,
278    pub predict_call_history: Vec<Instant>,
279    pub performance_metrics: HashMap<String, f64>,
280}
281
282/// Types of errors that can be simulated
283#[derive(Debug, Clone, Copy)]
284pub enum MockErrorType {
285    FitFailure,
286    PredictFailure,
287    InvalidInput,
288    NotFitted,
289}
290
291/// Trained mock estimator
292#[derive(Debug, Clone)]
293pub struct TrainedMockEstimator {
294    estimator: MockEstimator,
295    training_data_shape: (usize, usize),
296}
297
298impl Estimator for MockEstimator {
299    type Config = MockConfig;
300    type Error = crate::error::SklearsError;
301    type Float = f64;
302
303    fn config(&self) -> &Self::Config {
304        &self.config
305    }
306}
307
308impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockEstimator {
309    type Fitted = TrainedMockEstimator;
310
311    fn fit(self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
312        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
313
314        // Track fit call
315        state.fit_count += 1;
316        state.last_fit_time = Some(Instant::now());
317        state.fit_call_history.push(Instant::now());
318
319        // Check for max fit calls limit
320        if let Some(max_calls) = self.config.max_fit_calls {
321            if state.fit_count > max_calls {
322                return Err(SklearsError::FitError(format!(
323                    "Maximum fit calls ({max_calls}) exceeded"
324                )));
325            }
326        }
327
328        // Simulate fit failure probability
329        if self.config.fit_failure_probability > 0.0 {
330            let mut rng = Random::seed(self.config.random_seed + state.fit_count as u64);
331            if rng.gen_range(0.0..1.0) < self.config.fit_failure_probability {
332                return Err(SklearsError::FitError(
333                    "Simulated random fit failure".to_string(),
334                ));
335            }
336        }
337
338        // Validate input dimensions
339        if x.nrows() != y.len() {
340            return Err(SklearsError::ShapeMismatch {
341                expected: format!("({}, n_features)", y.len()),
342                actual: format!("({}, {})", x.nrows(), x.ncols()),
343            });
344        }
345
346        // Store training targets for certain behaviors
347        match self.config.behavior {
348            MockBehavior::MirrorTargets | MockBehavior::MajorityClass => {
349                state.training_targets = Some(y.to_owned());
350            }
351            _ => {}
352        }
353
354        // Simulate fit delay
355        if !self.config.fit_delay.is_zero() {
356            std::thread::sleep(self.config.fit_delay);
357        }
358
359        state.fitted = true;
360        drop(state); // Release lock before creating output
361
362        Ok(TrainedMockEstimator {
363            estimator: self.clone(),
364            training_data_shape: (x.nrows(), x.ncols()),
365        })
366    }
367}
368
369impl<'a> Predict<ArrayView2<'a, f64>, Array1<f64>> for TrainedMockEstimator {
370    fn predict(&self, x: &ArrayView2<'a, f64>) -> Result<Array1<f64>> {
371        let mut state = self
372            .estimator
373            .state
374            .lock()
375            .unwrap_or_else(|e| e.into_inner());
376
377        // Track predict call
378        state.predict_count += 1;
379        state.last_predict_time = Some(Instant::now());
380        state.predict_call_history.push(Instant::now());
381
382        // Simulate predict failure probability
383        if self.estimator.config.predict_failure_probability > 0.0 {
384            let mut rng =
385                Random::seed(self.estimator.config.random_seed + state.predict_count as u64);
386            if rng.gen_range(0.0..1.0) < self.estimator.config.predict_failure_probability {
387                return Err(SklearsError::PredictError(
388                    "Simulated random predict failure".to_string(),
389                ));
390            }
391        }
392
393        // Validate input dimensions
394        if x.ncols() != self.training_data_shape.1 {
395            return Err(SklearsError::FeatureMismatch {
396                expected: self.training_data_shape.1,
397                actual: x.ncols(),
398            });
399        }
400
401        // Simulate predict delay
402        if !self.estimator.config.predict_delay.is_zero() {
403            std::thread::sleep(self.estimator.config.predict_delay);
404        }
405
406        // Generate predictions based on behavior
407        let predictions = match &self.estimator.config.behavior {
408            MockBehavior::ConstantPrediction(value) => Array1::from_elem(x.nrows(), *value),
409            MockBehavior::FeatureSum => {
410                Array1::from_iter(x.rows().into_iter().map(|row| row.sum()))
411            }
412            MockBehavior::Random { min, max } => {
413                let mut rng = Random::seed(self.estimator.config.random_seed);
414                Array1::from_iter((0..x.nrows()).map(|_| rng.gen_range(*min..*max)))
415            }
416            MockBehavior::LinearModel { weights, bias } => {
417                if weights.len() != x.ncols() {
418                    return Err(SklearsError::InvalidInput(
419                        "Weight dimension mismatch".to_string(),
420                    ));
421                }
422                Array1::from_iter(x.rows().into_iter().map(|row| {
423                    let dot_product: f64 = row.iter().zip(weights.iter()).map(|(x, w)| x * w).sum();
424                    dot_product + bias
425                }))
426            }
427            MockBehavior::Sequence(values) => {
428                Array1::from_iter((0..x.nrows()).map(|i| values[i % values.len()]))
429            }
430            MockBehavior::MirrorTargets => {
431                if let Some(ref targets) = state.training_targets {
432                    // Return targets corresponding to input indices (simplified)
433                    Array1::from_iter((0..x.nrows()).map(|i| targets[i % targets.len()]))
434                } else {
435                    Array1::zeros(x.nrows())
436                }
437            }
438            MockBehavior::MajorityClass => {
439                if let Some(ref targets) = state.training_targets {
440                    // Find most common class
441                    let mut counts = HashMap::new();
442                    for &target in targets {
443                        *counts.entry(target as i32).or_insert(0) += 1;
444                    }
445                    let majority_class = counts
446                        .into_iter()
447                        .max_by_key(|(_, count)| *count)
448                        .map(|(class, _)| class as f64)
449                        .unwrap_or(0.0);
450                    Array1::from_elem(x.nrows(), majority_class)
451                } else {
452                    Array1::zeros(x.nrows())
453                }
454            }
455            MockBehavior::Overfitting {
456                train_accuracy: _,
457                test_accuracy,
458            } => {
459                // Simulate poor generalization
460                let mut rng = Random::seed(self.estimator.config.random_seed);
461                Array1::from_iter((0..x.nrows()).map(|_| {
462                    if rng.gen_range(0.0..1.0) < *test_accuracy {
463                        1.0 // Correct prediction
464                    } else {
465                        0.0 // Incorrect prediction
466                    }
467                }))
468            }
469        };
470
471        Ok(predictions)
472    }
473}
474
475impl<'a> PredictProba<ArrayView2<'a, f64>, Array2<f64>> for TrainedMockEstimator {
476    fn predict_proba(&self, x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
477        // Convert predictions to probabilities (simplified for binary classification)
478        let predictions = self.predict(x)?;
479        let mut probabilities = Array2::zeros((x.nrows(), 2));
480
481        for (i, &pred) in predictions.iter().enumerate() {
482            let prob_positive = (pred.tanh() + 1.0) / 2.0; // Map to [0, 1]
483            probabilities[[i, 0]] = 1.0 - prob_positive;
484            probabilities[[i, 1]] = prob_positive;
485        }
486
487        Ok(probabilities)
488    }
489}
490
491impl<'a> Score<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for TrainedMockEstimator {
492    type Float = f64;
493    fn score(&self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<f64> {
494        let predictions = self.predict(x)?;
495
496        // Calculate R² score for regression or accuracy for classification
497        match &self.estimator.config.behavior {
498            MockBehavior::Overfitting {
499                train_accuracy,
500                test_accuracy: _,
501            } => {
502                // Return perfect score for training data simulation
503                Ok(*train_accuracy)
504            }
505            _ => {
506                // Simple accuracy calculation (assuming classification)
507                let correct = predictions
508                    .iter()
509                    .zip(y.iter())
510                    .map(|(pred, actual)| {
511                        if (pred - actual).abs() < 0.5 {
512                            1.0
513                        } else {
514                            0.0
515                        }
516                    })
517                    .sum::<f64>();
518                Ok(correct / y.len() as f64)
519            }
520        }
521    }
522}
523
/// Configurable stand-in transformer for exercising transformation pipelines.
#[derive(Clone, Debug)]
pub struct MockTransformer {
    config: MockTransformConfig,
    // Set by `fit`; `transform` returns `NotFitted` while this is false.
    fitted: bool,
    // `(n_samples, n_features)` of the data passed to the most recent `fit`.
    input_shape: Option<(usize, usize)>,
}

/// Settings for a [`MockTransformer`].
#[derive(Clone, Debug)]
pub struct MockTransformConfig {
    /// Which transformation `transform` applies.
    pub transform_type: MockTransformType,
    /// Requested output width.
    /// NOTE(review): `transform` in this module never reads this value —
    /// confirm whether it is meant to constrain the output.
    pub output_features: Option<usize>,
    /// Time to sleep inside every `transform` call.
    pub transform_delay: Duration,
}

/// Transformations a [`MockTransformer`] can simulate.
#[derive(Clone, Debug)]
pub enum MockTransformType {
    /// Returns an unchanged copy of the input.
    Identity,
    /// Multiplies every element by the given factor.
    Scale(f64),
    /// Adds the given offset to every element.
    Shift(f64),
    /// Keeps the leading `ceil(n_features * keep_ratio)` columns, with the
    /// count clamped to at least one and at most `n_features`.
    FeatureReduction {
        /// Fraction of columns to keep.
        keep_ratio: f64,
    },
    /// Tiles the input columns side by side `expansion_factor` times.
    FeatureExpansion {
        /// Number of copies of the input columns in the output.
        expansion_factor: usize,
    },
    /// Centres and scales using the mean and population standard deviation of
    /// *all* elements (not per column). If that deviation is zero, the data is
    /// only centred.
    Standardization,
}
556
557impl MockTransformer {
558    /// Create a new mock transformer
559    pub fn new(transform_type: MockTransformType) -> Self {
560        Self {
561            config: MockTransformConfig {
562                transform_type,
563                output_features: None,
564                transform_delay: Duration::from_millis(0),
565            },
566            fitted: false,
567            input_shape: None,
568        }
569    }
570
571    /// Create a mock transformer builder
572    pub fn builder() -> MockTransformerBuilder {
573        MockTransformerBuilder::new()
574    }
575}
576
577/// Builder for mock transformers
578#[derive(Debug)]
579pub struct MockTransformerBuilder {
580    transform_type: MockTransformType,
581    output_features: Option<usize>,
582    transform_delay: Duration,
583}
584
585impl MockTransformerBuilder {
586    pub fn new() -> Self {
587        Self {
588            transform_type: MockTransformType::Identity,
589            output_features: None,
590            transform_delay: Duration::from_millis(0),
591        }
592    }
593
594    pub fn with_transform_type(mut self, transform_type: MockTransformType) -> Self {
595        self.transform_type = transform_type;
596        self
597    }
598
599    pub fn with_output_features(mut self, features: usize) -> Self {
600        self.output_features = Some(features);
601        self
602    }
603
604    pub fn with_transform_delay(mut self, delay: Duration) -> Self {
605        self.transform_delay = delay;
606        self
607    }
608
609    pub fn build(self) -> MockTransformer {
610        MockTransformer {
611            config: MockTransformConfig {
612                transform_type: self.transform_type,
613                output_features: self.output_features,
614                transform_delay: self.transform_delay,
615            },
616            fitted: false,
617            input_shape: None,
618        }
619    }
620}
621
622impl Default for MockTransformerBuilder {
623    fn default() -> Self {
624        Self::new()
625    }
626}
627
628impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockTransformer {
629    type Fitted = MockTransformer;
630
631    fn fit(self, x: &ArrayView2<'a, f64>, _y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
632        let mut fitted = self.clone();
633        fitted.fitted = true;
634        fitted.input_shape = Some((x.nrows(), x.ncols()));
635        Ok(fitted)
636    }
637}
638
639impl<'a> Transform<ArrayView2<'a, f64>, Array2<f64>> for MockTransformer {
640    fn transform(&self, x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
641        if !self.fitted {
642            return Err(SklearsError::NotFitted {
643                operation: "transform".to_string(),
644            });
645        }
646
647        // Simulate transform delay
648        if !self.config.transform_delay.is_zero() {
649            std::thread::sleep(self.config.transform_delay);
650        }
651
652        match &self.config.transform_type {
653            MockTransformType::Identity => Ok(x.to_owned()),
654            MockTransformType::Scale(factor) => Ok(x * *factor),
655            MockTransformType::Shift(offset) => Ok(x + *offset),
656            MockTransformType::FeatureReduction { keep_ratio } => {
657                let keep_features = ((x.ncols() as f64) * keep_ratio).ceil() as usize;
658                let keep_features = keep_features.max(1).min(x.ncols());
659                Ok(x.slice(s![.., 0..keep_features]).to_owned())
660            }
661            MockTransformType::FeatureExpansion { expansion_factor } => {
662                let new_features = x.ncols() * expansion_factor;
663                let mut expanded = Array2::zeros((x.nrows(), new_features));
664
665                // Tile the original features
666                for i in 0..*expansion_factor {
667                    let start_col = i * x.ncols();
668                    let end_col = start_col + x.ncols();
669                    expanded.slice_mut(s![.., start_col..end_col]).assign(x);
670                }
671                Ok(expanded)
672            }
673            MockTransformType::Standardization => {
674                // Simple standardization simulation
675                let mean = x.mean().unwrap_or(0.0);
676                let std = x.std(0.0);
677                if std == 0.0 {
678                    Ok(x - mean)
679                } else {
680                    Ok((x - mean) / std)
681                }
682            }
683        }
684    }
685}
686
687/// Mock ensemble for testing ensemble methods
688#[derive(Debug)]
689#[allow(dead_code)]
690pub struct MockEnsemble {
691    estimators: Vec<MockEstimator>,
692    voting_strategy: VotingStrategy,
693    fitted: bool,
694}
695
696/// Voting strategies for mock ensemble
697#[derive(Debug, Clone)]
698pub enum VotingStrategy {
699    MajorityVote,
700    AverageVote,
701    WeightedVote(Vec<f64>),
702}
703
704impl MockEnsemble {
705    /// Create a new mock ensemble
706    pub fn new(estimators: Vec<MockEstimator>, voting_strategy: VotingStrategy) -> Self {
707        Self {
708            estimators,
709            voting_strategy,
710            fitted: false,
711        }
712    }
713
714    /// Get the number of base estimators
715    pub fn n_estimators(&self) -> usize {
716        self.estimators.len()
717    }
718
719    /// Get voting strategy
720    pub fn voting_strategy(&self) -> &VotingStrategy {
721        &self.voting_strategy
722    }
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Fits a clone of `mock`, leaving the caller a handle that shares its state.
    fn fit_clone(
        mock: &MockEstimator,
        features: &Array2<f64>,
        targets: &Array1<f64>,
    ) -> TrainedMockEstimator {
        mock.clone()
            .fit(&features.view(), &targets.view())
            .expect("model fitting should succeed")
    }

    /// Fits `transformer` on `data` (with dummy targets), then transforms `data`.
    fn fit_then_transform(transformer: MockTransformer, data: &Array2<f64>) -> Array2<f64> {
        let targets: Array1<f64> = Array1::zeros(data.nrows());
        transformer
            .fit(&data.view(), &targets.view())
            .expect("expected valid value")
            .transform(&data.view())
            .expect("transform should succeed")
    }

    #[test]
    fn test_mock_estimator_constant_prediction() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(5.0))
            .build();
        let features: Array2<f64> = Array2::zeros((10, 3));
        let targets: Array1<f64> = Array1::zeros(10);

        let predictions = fit_clone(&mock, &features, &targets)
            .predict(&features.view())
            .expect("prediction should succeed");

        assert_eq!(predictions.len(), 10);
        assert!(predictions.iter().all(|&p| p == 5.0));
    }

    #[test]
    fn test_mock_estimator_state_tracking() {
        let mock = MockEstimator::new();
        let features: Array2<f64> = Array2::zeros((5, 2));
        let targets: Array1<f64> = Array1::zeros(5);

        let before = mock.mock_state();
        assert_eq!(before.fit_count, 0);
        assert_eq!(before.predict_count, 0);
        assert!(!before.fitted);

        // The clone passed to `fit` shares state with `mock`.
        let trained = fit_clone(&mock, &features, &targets);
        let after_fit = mock.mock_state();
        assert_eq!(after_fit.fit_count, 1);
        assert!(after_fit.fitted);

        let _predictions = trained
            .predict(&features.view())
            .expect("prediction should succeed");
        assert_eq!(mock.mock_state().predict_count, 1);
    }

    #[test]
    fn test_mock_estimator_feature_sum() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::FeatureSum)
            .build();
        let features = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid array shape");
        let targets: Array1<f64> = Array1::zeros(2);

        let predictions = fit_clone(&mock, &features, &targets)
            .predict(&features.view())
            .expect("prediction should succeed");

        assert_eq!(predictions[0], 6.0); // 1 + 2 + 3
        assert_eq!(predictions[1], 15.0); // 4 + 5 + 6
    }

    #[test]
    fn test_mock_estimator_linear_model() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::LinearModel {
                weights: vec![1.0, 2.0, 3.0],
                bias: 1.0,
            })
            .build();
        let features =
            Array2::from_shape_vec((1, 3), vec![1.0, 1.0, 1.0]).expect("valid array shape");
        let targets: Array1<f64> = Array1::zeros(1);

        let predictions = fit_clone(&mock, &features, &targets)
            .predict(&features.view())
            .expect("prediction should succeed");

        assert_eq!(predictions[0], 7.0); // 1*1 + 2*1 + 3*1 + 1
    }

    #[test]
    fn test_mock_transformer_identity() {
        let data =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid array shape");

        let transformed =
            fit_then_transform(MockTransformer::new(MockTransformType::Identity), &data);

        assert_eq!(transformed, data);
    }

    #[test]
    fn test_mock_transformer_scale() {
        let data =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid array shape");
        let expected =
            Array2::from_shape_vec((2, 2), vec![2.0, 4.0, 6.0, 8.0]).expect("valid array shape");

        let transformed =
            fit_then_transform(MockTransformer::new(MockTransformType::Scale(2.0)), &data);

        assert_eq!(transformed, expected);
    }

    #[test]
    fn test_mock_estimator_failure_simulation() {
        // Probability 1.0 means every fit fails.
        let mock = MockEstimator::builder()
            .with_fit_failure_probability(1.0)
            .build();
        let features: Array2<f64> = Array2::zeros((5, 2));
        let targets: Array1<f64> = Array1::zeros(5);

        assert!(mock.fit(&features.view(), &targets.view()).is_err());
    }

    #[test]
    fn test_mock_estimator_max_fit_calls() {
        let mock = MockEstimator::builder().with_max_fit_calls(2).build();
        let features: Array2<f64> = Array2::zeros((5, 2));
        let targets: Array1<f64> = Array1::zeros(5);

        let outcomes: Vec<bool> = (0..3)
            .map(|_| mock.clone().fit(&features.view(), &targets.view()).is_ok())
            .collect();

        // First two fits succeed; the third exceeds the limit.
        assert_eq!(outcomes, vec![true, true, false]);
    }

    #[test]
    fn test_mock_estimator_predict_proba() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(0.0))
            .build();
        let features: Array2<f64> = Array2::zeros((3, 2));
        let targets: Array1<f64> = Array1::zeros(3);

        let probabilities = fit_clone(&mock, &features, &targets)
            .predict_proba(&features.view())
            .expect("expected valid value");

        assert_eq!(probabilities.shape(), &[3, 2]);
        // Each row must be a distribution over the two classes.
        assert!(probabilities
            .rows()
            .into_iter()
            .all(|row| (row.sum() - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_mock_ensemble_creation() {
        let members: Vec<MockEstimator> = [1.0, 2.0]
            .iter()
            .map(|&value| {
                MockEstimator::builder()
                    .with_behavior(MockBehavior::ConstantPrediction(value))
                    .build()
            })
            .collect();

        let ensemble = MockEnsemble::new(members, VotingStrategy::AverageVote);

        assert_eq!(ensemble.n_estimators(), 2);
        assert!(matches!(
            ensemble.voting_strategy(),
            VotingStrategy::AverageVote
        ));
    }
}