sklears_core/
mock_objects.rs

//! Mock objects for testing complex machine learning interactions
//!
//! This module provides configurable mock implementations of machine learning
//! components to enable sophisticated testing scenarios, particularly for:
//!
//! - Integration testing between multiple ML components
//! - Behavior verification in ensemble methods
//! - Error condition simulation and recovery testing
//! - Performance benchmarking with controlled behavior
//! - Pipeline testing with predictable components
//!
//! # Key Features
//!
//! - Configurable mock estimators with predictable behavior
//! - Controllable failure modes for error testing
//! - Artificial fit/predict delays for benchmarking
//! - Call tracking (shared by all clones of a mock) for behavior verification
//! - Builder pattern for easy mock configuration
//!
//! # Examples
//!
//! ```rust
//! use sklears_core::mock_objects::{MockBehavior, MockEstimator};
//! use sklears_core::traits::{Fit, Predict};
//! use scirs2_core::ndarray::{Array1, Array2};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a mock classifier that always predicts class 1
//! let mock = MockEstimator::builder()
//!     .with_behavior(MockBehavior::ConstantPrediction(1.0))
//!     .with_fit_delay(std::time::Duration::from_millis(10))
//!     .build();
//!
//! // Use it like any other estimator
//! let features = Array2::<f64>::zeros((100, 10));
//! let targets = Array1::<f64>::zeros(100);
//!
//! let trained = mock.fit(&features.view(), &targets.view())?;
//! let predictions = trained.predict(&features.view())?;
//!
//! // All predictions should be 1.0
//! assert!(predictions.iter().all(|&p| p == 1.0));
//! # Ok(())
//! # }
//! ```
45use crate::error::{Result, SklearsError};
46use crate::traits::{Estimator, Fit, Predict, PredictProba, Score, Transform};
47// SciRS2 Policy: Using scirs2_core::ndarray and scirs2_core::random (COMPLIANT)
48use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
49use scirs2_core::random::Random;
50use scirs2_core::Rng;
51use serde::{Deserialize, Serialize};
52use std::collections::HashMap;
53use std::sync::{Arc, Mutex};
54use std::time::{Duration, Instant};
55
/// Mock estimator with configurable behavior for testing
///
/// `Clone` is shallow for the tracking state: `state` is an `Arc<Mutex<_>>`, so every
/// clone, and every `TrainedMockEstimator` produced by `fit` (which keeps the
/// estimator), records its calls into the same `MockState`. This lets a test call
/// `mock.clone().fit(..)` and then inspect `mock.mock_state()`.
#[derive(Debug, Clone)]
pub struct MockEstimator {
    /// Behavior, delays and failure settings read by `fit` and `predict`.
    config: MockConfig,
    /// Call-tracking state shared by all clones.
    state: Arc<Mutex<MockState>>,
}
62
/// Configuration for mock estimator behavior
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MockConfig {
    /// Behavior pattern used to generate predictions
    pub behavior: MockBehavior,
    /// Artificial delay (thread sleep) applied on every successful `fit`; zero disables it
    pub fit_delay: Duration,
    /// Artificial delay (thread sleep) applied on every `predict` that passes validation;
    /// zero disables it
    pub predict_delay: Duration,
    /// Probability in `[0, 1]` that a `fit` call fails with a simulated `FitError`;
    /// `0.0` disables the check. The builder clamps this value, direct assignment does not.
    pub fit_failure_probability: f64,
    /// Probability in `[0, 1]` that a `predict` call fails with a simulated `PredictError`;
    /// `0.0` disables the check. The builder clamps this value, direct assignment does not.
    pub predict_failure_probability: f64,
    /// Number of `fit` calls allowed before further calls return a `FitError`.
    /// Calls are counted across all clones sharing the tracking state; `None` = unlimited.
    pub max_fit_calls: Option<usize>,
    /// Base seed for the deterministic RNGs used by failure simulation and by the
    /// `Random` / `Overfitting` behaviors
    pub random_seed: u64,
}
81
82impl Default for MockConfig {
83    fn default() -> Self {
84        Self {
85            behavior: MockBehavior::ConstantPrediction(0.0),
86            fit_delay: Duration::from_millis(0),
87            predict_delay: Duration::from_millis(0),
88            fit_failure_probability: 0.0,
89            predict_failure_probability: 0.0,
90            max_fit_calls: None,
91            random_seed: 42,
92        }
93    }
94}
95
/// Different mock behavior patterns
///
/// Each variant decides how `TrainedMockEstimator::predict` produces one value per
/// input row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MockBehavior {
    /// Always return the same prediction value
    ConstantPrediction(f64),
    /// Each row's prediction is the sum of its feature values
    FeatureSum,
    /// Random predictions drawn from `min..max` using an RNG seeded with `random_seed`.
    /// The RNG is re-seeded on every `predict` call, so repeated calls return the same
    /// values. Keep `min < max`; NOTE(review): an empty range is assumed to be rejected
    /// by `gen_range` — confirm.
    Random { min: f64, max: f64 },
    /// Predict `dot(row, weights) + bias`; `weights.len()` must equal the number of
    /// features, otherwise `predict` returns `InvalidInput`
    LinearModel { weights: Vec<f64>, bias: f64 },
    /// Cycle through a predefined sequence: row `i` gets `values[i % values.len()]`.
    /// The sequence must be non-empty.
    Sequence(Vec<f64>),
    /// Mirror the training targets by position: row `i` gets target `i % n_train`
    /// (not matched by content). Predicts zeros if no targets were stored.
    MirrorTargets,
    /// Always predict the most frequent training target (targets are truncated to
    /// `i32` before counting). Predicts zeros if no targets were stored.
    MajorityClass,
    /// Simulate overfitting: `score` reports `train_accuracy`, while `predict` emits
    /// `1.0` with probability `test_accuracy` and `0.0` otherwise
    Overfitting {
        train_accuracy: f64,
        test_accuracy: f64,
    },
}
119
/// Internal state tracking for mock estimator
///
/// Lives behind the `Arc<Mutex<_>>` shared by every clone of a `MockEstimator`, so the
/// counters accumulate across clones and their trained estimators.
#[derive(Debug, Default)]
struct MockState {
    /// `fit` calls recorded, including calls that then returned an error
    fit_count: usize,
    /// `predict` calls recorded, including calls that then returned an error
    predict_count: usize,
    /// Time of the most recent `fit` call
    last_fit_time: Option<Instant>,
    /// Time of the most recent `predict` call
    last_predict_time: Option<Instant>,
    /// Targets captured by `fit` for the `MirrorTargets` and `MajorityClass` behaviors
    training_targets: Option<Array1<f64>>,
    /// Set when a `fit` call completes; cleared only by `reset_state`
    fitted: bool,
    /// Timestamp of every recorded `fit` call, in call order
    fit_call_history: Vec<Instant>,
    /// Timestamp of every recorded `predict` call, in call order
    predict_call_history: Vec<Instant>,
    /// NOTE(review): nothing in this file writes to this map; it is only copied into
    /// snapshots — confirm whether other code populates it.
    performance_metrics: HashMap<String, f64>,
}
133
134impl MockEstimator {
135    /// Create a new mock estimator with default configuration
136    pub fn new() -> Self {
137        Self::with_config(MockConfig::default())
138    }
139
140    /// Create a mock estimator with custom configuration
141    pub fn with_config(config: MockConfig) -> Self {
142        Self {
143            config,
144            state: Arc::new(Mutex::new(MockState::default())),
145        }
146    }
147
148    /// Create a builder for configuring mock estimator
149    pub fn builder() -> MockEstimatorBuilder {
150        MockEstimatorBuilder::new()
151    }
152
153    /// Get the current configuration
154    pub fn config(&self) -> &MockConfig {
155        &self.config
156    }
157
158    /// Get mock state information for testing
159    pub fn mock_state(&self) -> MockStateSnapshot {
160        let state = self.state.lock().unwrap();
161        MockStateSnapshot {
162            fit_count: state.fit_count,
163            predict_count: state.predict_count,
164            fitted: state.fitted,
165            fit_call_history: state.fit_call_history.clone(),
166            predict_call_history: state.predict_call_history.clone(),
167            performance_metrics: state.performance_metrics.clone(),
168        }
169    }
170
171    /// Reset the mock state (useful for test setup)
172    pub fn reset_state(&self) {
173        let mut state = self.state.lock().unwrap();
174        *state = MockState::default();
175    }
176
177    /// Simulate a specific error condition
178    pub fn simulate_error(&self, error_type: MockErrorType) -> Result<()> {
179        match error_type {
180            MockErrorType::FitFailure => {
181                Err(SklearsError::FitError("Simulated fit failure".to_string()))
182            }
183            MockErrorType::PredictFailure => Err(SklearsError::PredictError(
184                "Simulated predict failure".to_string(),
185            )),
186            MockErrorType::InvalidInput => Err(SklearsError::InvalidInput(
187                "Simulated invalid input".to_string(),
188            )),
189            MockErrorType::NotFitted => Err(SklearsError::NotFitted {
190                operation: "predict".to_string(),
191            }),
192        }
193    }
194}
195
196impl Default for MockEstimator {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
/// Builder for configuring mock estimators
///
/// Starts from `MockConfig::default()`; each `with_*` method overrides one field and
/// `build` produces a `MockEstimator` with fresh tracking state.
#[derive(Debug)]
pub struct MockEstimatorBuilder {
    /// Configuration accumulated so far
    config: MockConfig,
}
207
208impl MockEstimatorBuilder {
209    /// Create a new builder
210    pub fn new() -> Self {
211        Self {
212            config: MockConfig::default(),
213        }
214    }
215
216    /// Set the behavior pattern
217    pub fn with_behavior(mut self, behavior: MockBehavior) -> Self {
218        self.config.behavior = behavior;
219        self
220    }
221
222    /// Set the fit delay
223    pub fn with_fit_delay(mut self, delay: Duration) -> Self {
224        self.config.fit_delay = delay;
225        self
226    }
227
228    /// Set the predict delay
229    pub fn with_predict_delay(mut self, delay: Duration) -> Self {
230        self.config.predict_delay = delay;
231        self
232    }
233
234    /// Set fit failure probability
235    pub fn with_fit_failure_probability(mut self, probability: f64) -> Self {
236        self.config.fit_failure_probability = probability.clamp(0.0, 1.0);
237        self
238    }
239
240    /// Set predict failure probability
241    pub fn with_predict_failure_probability(mut self, probability: f64) -> Self {
242        self.config.predict_failure_probability = probability.clamp(0.0, 1.0);
243        self
244    }
245
246    /// Set maximum number of fit calls before failure
247    pub fn with_max_fit_calls(mut self, max_calls: usize) -> Self {
248        self.config.max_fit_calls = Some(max_calls);
249        self
250    }
251
252    /// Set random seed for reproducible behavior
253    pub fn with_random_seed(mut self, seed: u64) -> Self {
254        self.config.random_seed = seed;
255        self
256    }
257
258    /// Build the mock estimator
259    pub fn build(self) -> MockEstimator {
260        MockEstimator::with_config(self.config)
261    }
262}
263
264impl Default for MockEstimatorBuilder {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
/// Snapshot of mock state for testing
///
/// A point-in-time copy produced by `MockEstimator::mock_state`; later calls on the
/// estimator do not update an existing snapshot.
#[derive(Debug, Clone)]
pub struct MockStateSnapshot {
    /// Total `fit` calls recorded, including ones that returned an error
    pub fit_count: usize,
    /// Total `predict` calls recorded, including ones that returned an error
    pub predict_count: usize,
    /// Whether a `fit` call has completed since the state was created or last reset
    pub fitted: bool,
    /// Timestamps of all recorded `fit` calls, in call order
    pub fit_call_history: Vec<Instant>,
    /// Timestamps of all recorded `predict` calls, in call order
    pub predict_call_history: Vec<Instant>,
    /// Copy of the state's metrics map
    pub performance_metrics: HashMap<String, f64>,
}
280
/// Types of errors that can be simulated
///
/// Passed to `MockEstimator::simulate_error`, which maps each variant to an `SklearsError`.
#[derive(Debug, Clone, Copy)]
pub enum MockErrorType {
    /// Maps to `SklearsError::FitError`
    FitFailure,
    /// Maps to `SklearsError::PredictError`
    PredictFailure,
    /// Maps to `SklearsError::InvalidInput`
    InvalidInput,
    /// Maps to `SklearsError::NotFitted` for operation `"predict"`
    NotFitted,
}
289
/// Trained mock estimator
///
/// Returned by `MockEstimator::fit`. It keeps the estimator, whose tracking state is
/// shared with the original mock, so `predict` calls show up in `mock_state()`.
#[derive(Debug, Clone)]
pub struct TrainedMockEstimator {
    /// The estimator that was fitted (configuration plus shared tracking state)
    estimator: MockEstimator,
    /// `(n_samples, n_features)` of the training data; `predict` rejects inputs whose
    /// feature count differs from `.1`
    training_data_shape: (usize, usize),
}
296
297impl Estimator for MockEstimator {
298    type Config = MockConfig;
299    type Error = crate::error::SklearsError;
300    type Float = f64;
301
302    fn config(&self) -> &Self::Config {
303        &self.config
304    }
305}
306
307impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockEstimator {
308    type Fitted = TrainedMockEstimator;
309
310    fn fit(self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
311        let mut state = self.state.lock().unwrap();
312
313        // Track fit call
314        state.fit_count += 1;
315        state.last_fit_time = Some(Instant::now());
316        state.fit_call_history.push(Instant::now());
317
318        // Check for max fit calls limit
319        if let Some(max_calls) = self.config.max_fit_calls {
320            if state.fit_count > max_calls {
321                return Err(SklearsError::FitError(format!(
322                    "Maximum fit calls ({max_calls}) exceeded"
323                )));
324            }
325        }
326
327        // Simulate fit failure probability
328        if self.config.fit_failure_probability > 0.0 {
329            let mut rng = Random::seed(self.config.random_seed + state.fit_count as u64);
330            if rng.gen::<f64>() < self.config.fit_failure_probability {
331                return Err(SklearsError::FitError(
332                    "Simulated random fit failure".to_string(),
333                ));
334            }
335        }
336
337        // Validate input dimensions
338        if x.nrows() != y.len() {
339            return Err(SklearsError::ShapeMismatch {
340                expected: format!("({}, n_features)", y.len()),
341                actual: format!("({}, {})", x.nrows(), x.ncols()),
342            });
343        }
344
345        // Store training targets for certain behaviors
346        match self.config.behavior {
347            MockBehavior::MirrorTargets | MockBehavior::MajorityClass => {
348                state.training_targets = Some(y.to_owned());
349            }
350            _ => {}
351        }
352
353        // Simulate fit delay
354        if !self.config.fit_delay.is_zero() {
355            std::thread::sleep(self.config.fit_delay);
356        }
357
358        state.fitted = true;
359        drop(state); // Release lock before creating output
360
361        Ok(TrainedMockEstimator {
362            estimator: self.clone(),
363            training_data_shape: (x.nrows(), x.ncols()),
364        })
365    }
366}
367
368impl<'a> Predict<ArrayView2<'a, f64>, Array1<f64>> for TrainedMockEstimator {
369    fn predict(&self, x: &ArrayView2<'a, f64>) -> Result<Array1<f64>> {
370        let mut state = self.estimator.state.lock().unwrap();
371
372        // Track predict call
373        state.predict_count += 1;
374        state.last_predict_time = Some(Instant::now());
375        state.predict_call_history.push(Instant::now());
376
377        // Simulate predict failure probability
378        if self.estimator.config.predict_failure_probability > 0.0 {
379            let mut rng =
380                Random::seed(self.estimator.config.random_seed + state.predict_count as u64);
381            if rng.gen::<f64>() < self.estimator.config.predict_failure_probability {
382                return Err(SklearsError::PredictError(
383                    "Simulated random predict failure".to_string(),
384                ));
385            }
386        }
387
388        // Validate input dimensions
389        if x.ncols() != self.training_data_shape.1 {
390            return Err(SklearsError::FeatureMismatch {
391                expected: self.training_data_shape.1,
392                actual: x.ncols(),
393            });
394        }
395
396        // Simulate predict delay
397        if !self.estimator.config.predict_delay.is_zero() {
398            std::thread::sleep(self.estimator.config.predict_delay);
399        }
400
401        // Generate predictions based on behavior
402        let predictions = match &self.estimator.config.behavior {
403            MockBehavior::ConstantPrediction(value) => Array1::from_elem(x.nrows(), *value),
404            MockBehavior::FeatureSum => {
405                Array1::from_iter(x.rows().into_iter().map(|row| row.sum()))
406            }
407            MockBehavior::Random { min, max } => {
408                let mut rng = Random::seed(self.estimator.config.random_seed);
409                Array1::from_iter((0..x.nrows()).map(|_| rng.gen_range(*min..*max)))
410            }
411            MockBehavior::LinearModel { weights, bias } => {
412                if weights.len() != x.ncols() {
413                    return Err(SklearsError::InvalidInput(
414                        "Weight dimension mismatch".to_string(),
415                    ));
416                }
417                Array1::from_iter(x.rows().into_iter().map(|row| {
418                    let dot_product: f64 = row.iter().zip(weights.iter()).map(|(x, w)| x * w).sum();
419                    dot_product + bias
420                }))
421            }
422            MockBehavior::Sequence(values) => {
423                Array1::from_iter((0..x.nrows()).map(|i| values[i % values.len()]))
424            }
425            MockBehavior::MirrorTargets => {
426                if let Some(ref targets) = state.training_targets {
427                    // Return targets corresponding to input indices (simplified)
428                    Array1::from_iter((0..x.nrows()).map(|i| targets[i % targets.len()]))
429                } else {
430                    Array1::zeros(x.nrows())
431                }
432            }
433            MockBehavior::MajorityClass => {
434                if let Some(ref targets) = state.training_targets {
435                    // Find most common class
436                    let mut counts = HashMap::new();
437                    for &target in targets {
438                        *counts.entry(target as i32).or_insert(0) += 1;
439                    }
440                    let majority_class = counts
441                        .into_iter()
442                        .max_by_key(|(_, count)| *count)
443                        .map(|(class, _)| class as f64)
444                        .unwrap_or(0.0);
445                    Array1::from_elem(x.nrows(), majority_class)
446                } else {
447                    Array1::zeros(x.nrows())
448                }
449            }
450            MockBehavior::Overfitting {
451                train_accuracy: _,
452                test_accuracy,
453            } => {
454                // Simulate poor generalization
455                let mut rng = Random::seed(self.estimator.config.random_seed);
456                Array1::from_iter((0..x.nrows()).map(|_| {
457                    if rng.gen::<f64>() < *test_accuracy {
458                        1.0 // Correct prediction
459                    } else {
460                        0.0 // Incorrect prediction
461                    }
462                }))
463            }
464        };
465
466        Ok(predictions)
467    }
468}
469
470impl<'a> PredictProba<ArrayView2<'a, f64>, Array2<f64>> for TrainedMockEstimator {
471    fn predict_proba(&self, x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
472        // Convert predictions to probabilities (simplified for binary classification)
473        let predictions = self.predict(x)?;
474        let mut probabilities = Array2::zeros((x.nrows(), 2));
475
476        for (i, &pred) in predictions.iter().enumerate() {
477            let prob_positive = (pred.tanh() + 1.0) / 2.0; // Map to [0, 1]
478            probabilities[[i, 0]] = 1.0 - prob_positive;
479            probabilities[[i, 1]] = prob_positive;
480        }
481
482        Ok(probabilities)
483    }
484}
485
486impl<'a> Score<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for TrainedMockEstimator {
487    type Float = f64;
488    fn score(&self, x: &ArrayView2<'a, f64>, y: &ArrayView1<'a, f64>) -> Result<f64> {
489        let predictions = self.predict(x)?;
490
491        // Calculate R² score for regression or accuracy for classification
492        match &self.estimator.config.behavior {
493            MockBehavior::Overfitting {
494                train_accuracy,
495                test_accuracy: _,
496            } => {
497                // Return perfect score for training data simulation
498                Ok(*train_accuracy)
499            }
500            _ => {
501                // Simple accuracy calculation (assuming classification)
502                let correct = predictions
503                    .iter()
504                    .zip(y.iter())
505                    .map(|(pred, actual)| {
506                        if (pred - actual).abs() < 0.5 {
507                            1.0
508                        } else {
509                            0.0
510                        }
511                    })
512                    .sum::<f64>();
513                Ok(correct / y.len() as f64)
514            }
515        }
516    }
517}
518
/// Mock transformer for testing transformation pipelines
///
/// Unlike `MockEstimator`, this type carries no shared call-tracking state. `fit`
/// returns a transformer with `fitted` set and `input_shape` recorded.
#[derive(Debug, Clone)]
pub struct MockTransformer {
    /// Transformation kind, output-width hint and delay
    config: MockTransformConfig,
    /// Set by `fit`; `transform` returns `NotFitted` while this is false
    fitted: bool,
    /// `(n_samples, n_features)` of the data passed to `fit`
    input_shape: Option<(usize, usize)>,
}
526
/// Configuration for mock transformer
#[derive(Debug, Clone)]
pub struct MockTransformConfig {
    /// Which transformation `transform` applies
    pub transform_type: MockTransformType,
    /// Requested output width.
    /// NOTE(review): nothing in this file reads this field; the output width is decided
    /// by `transform_type` alone — confirm before relying on it.
    pub output_features: Option<usize>,
    /// Artificial delay (thread sleep) applied on each `transform`; zero disables it
    pub transform_delay: Duration,
}
534
/// Types of transformations to simulate
#[derive(Debug, Clone)]
pub enum MockTransformType {
    /// Identity transformation (no change)
    Identity,
    /// Multiply every element by a constant
    Scale(f64),
    /// Add a constant to every element
    Shift(f64),
    /// Keep only the leading `ceil(n_features * keep_ratio)` columns, clamped to at least
    /// one and at most `n_features`
    FeatureReduction { keep_ratio: f64 },
    /// Tile the input columns `expansion_factor` times side by side
    FeatureExpansion { expansion_factor: usize },
    /// Simulate standardization with a single mean and population standard deviation
    /// computed over all elements (not per column). If that deviation is zero, only the
    /// mean is subtracted.
    Standardization,
}
551
552impl MockTransformer {
553    /// Create a new mock transformer
554    pub fn new(transform_type: MockTransformType) -> Self {
555        Self {
556            config: MockTransformConfig {
557                transform_type,
558                output_features: None,
559                transform_delay: Duration::from_millis(0),
560            },
561            fitted: false,
562            input_shape: None,
563        }
564    }
565
566    /// Create a mock transformer builder
567    pub fn builder() -> MockTransformerBuilder {
568        MockTransformerBuilder::new()
569    }
570}
571
/// Builder for mock transformers
///
/// Starts from the identity transform with no output-feature hint and zero delay;
/// `build` yields an unfitted `MockTransformer`.
#[derive(Debug)]
pub struct MockTransformerBuilder {
    /// Transformation to apply
    transform_type: MockTransformType,
    /// Optional output-width hint copied into the config
    output_features: Option<usize>,
    /// Artificial delay per `transform` call
    transform_delay: Duration,
}
579
580impl MockTransformerBuilder {
581    pub fn new() -> Self {
582        Self {
583            transform_type: MockTransformType::Identity,
584            output_features: None,
585            transform_delay: Duration::from_millis(0),
586        }
587    }
588
589    pub fn with_transform_type(mut self, transform_type: MockTransformType) -> Self {
590        self.transform_type = transform_type;
591        self
592    }
593
594    pub fn with_output_features(mut self, features: usize) -> Self {
595        self.output_features = Some(features);
596        self
597    }
598
599    pub fn with_transform_delay(mut self, delay: Duration) -> Self {
600        self.transform_delay = delay;
601        self
602    }
603
604    pub fn build(self) -> MockTransformer {
605        MockTransformer {
606            config: MockTransformConfig {
607                transform_type: self.transform_type,
608                output_features: self.output_features,
609                transform_delay: self.transform_delay,
610            },
611            fitted: false,
612            input_shape: None,
613        }
614    }
615}
616
617impl Default for MockTransformerBuilder {
618    fn default() -> Self {
619        Self::new()
620    }
621}
622
623impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>> for MockTransformer {
624    type Fitted = MockTransformer;
625
626    fn fit(self, x: &ArrayView2<'a, f64>, _y: &ArrayView1<'a, f64>) -> Result<Self::Fitted> {
627        let mut fitted = self.clone();
628        fitted.fitted = true;
629        fitted.input_shape = Some((x.nrows(), x.ncols()));
630        Ok(fitted)
631    }
632}
633
634impl<'a> Transform<ArrayView2<'a, f64>, Array2<f64>> for MockTransformer {
635    fn transform(&self, x: &ArrayView2<'a, f64>) -> Result<Array2<f64>> {
636        if !self.fitted {
637            return Err(SklearsError::NotFitted {
638                operation: "transform".to_string(),
639            });
640        }
641
642        // Simulate transform delay
643        if !self.config.transform_delay.is_zero() {
644            std::thread::sleep(self.config.transform_delay);
645        }
646
647        match &self.config.transform_type {
648            MockTransformType::Identity => Ok(x.to_owned()),
649            MockTransformType::Scale(factor) => Ok(x * *factor),
650            MockTransformType::Shift(offset) => Ok(x + *offset),
651            MockTransformType::FeatureReduction { keep_ratio } => {
652                let keep_features = ((x.ncols() as f64) * keep_ratio).ceil() as usize;
653                let keep_features = keep_features.max(1).min(x.ncols());
654                Ok(x.slice(s![.., 0..keep_features]).to_owned())
655            }
656            MockTransformType::FeatureExpansion { expansion_factor } => {
657                let new_features = x.ncols() * expansion_factor;
658                let mut expanded = Array2::zeros((x.nrows(), new_features));
659
660                // Tile the original features
661                for i in 0..*expansion_factor {
662                    let start_col = i * x.ncols();
663                    let end_col = start_col + x.ncols();
664                    expanded.slice_mut(s![.., start_col..end_col]).assign(x);
665                }
666                Ok(expanded)
667            }
668            MockTransformType::Standardization => {
669                // Simple standardization simulation
670                let mean = x.mean().unwrap_or(0.0);
671                let std = x.std(0.0);
672                if std == 0.0 {
673                    Ok(x - mean)
674                } else {
675                    Ok((x - mean) / std)
676                }
677            }
678        }
679    }
680}
681
/// Mock ensemble for testing ensemble methods
///
/// Currently only a container: it stores base estimators and a voting strategy and
/// exposes accessors. NOTE(review): no fitting or voting logic is visible in this file.
#[derive(Debug)]
// `fitted` is stored but never read yet; the allow keeps the dead_code lint quiet.
#[allow(dead_code)]
pub struct MockEnsemble {
    /// Base estimators, in insertion order
    estimators: Vec<MockEstimator>,
    /// How base predictions are meant to be combined
    voting_strategy: VotingStrategy,
    /// Initialised to false by `new`
    fitted: bool,
}
690
/// Voting strategies for mock ensemble
///
/// NOTE(review): no code in this file consumes these variants yet; the descriptions
/// below are inferred from the names only — confirm against the ensemble implementation.
#[derive(Debug, Clone)]
pub enum VotingStrategy {
    /// Presumably: choose the most frequent base prediction
    MajorityVote,
    /// Presumably: average the base predictions
    AverageVote,
    /// Presumably: weighted average with one weight per base estimator
    WeightedVote(Vec<f64>),
}
698
699impl MockEnsemble {
700    /// Create a new mock ensemble
701    pub fn new(estimators: Vec<MockEstimator>, voting_strategy: VotingStrategy) -> Self {
702        Self {
703            estimators,
704            voting_strategy,
705            fitted: false,
706        }
707    }
708
709    /// Get the number of base estimators
710    pub fn n_estimators(&self) -> usize {
711        self.estimators.len()
712    }
713
714    /// Get voting strategy
715    pub fn voting_strategy(&self) -> &VotingStrategy {
716        &self.voting_strategy
717    }
718}
719
// Unit tests for the mocks in this module. Several tests call `mock.clone().fit(..)` and
// then inspect `mock`, relying on clones sharing the same tracking state.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // ConstantPrediction returns the configured value for every row.
    #[test]
    fn test_mock_estimator_constant_prediction() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(5.0))
            .build();

        let features = Array2::zeros((10, 3));
        let targets = Array1::zeros(10);

        let trained = mock.clone().fit(&features.view(), &targets.view()).unwrap();
        let predictions = trained.predict(&features.view()).unwrap();

        assert_eq!(predictions.len(), 10);
        assert!(predictions.iter().all(|&p| p == 5.0));
    }

    // fit/predict calls on clones are visible through the original mock's snapshot.
    #[test]
    fn test_mock_estimator_state_tracking() {
        let mock = MockEstimator::new();
        let features = Array2::zeros((5, 2));
        let targets = Array1::zeros(5);

        // Initial state
        let state = mock.mock_state();
        assert_eq!(state.fit_count, 0);
        assert_eq!(state.predict_count, 0);
        assert!(!state.fitted);

        // After fitting
        let trained = mock.clone().fit(&features.view(), &targets.view()).unwrap();
        let state = mock.mock_state();
        assert_eq!(state.fit_count, 1);
        assert!(state.fitted);

        // After predicting
        let _ = trained.predict(&features.view()).unwrap();
        let state = mock.mock_state();
        assert_eq!(state.predict_count, 1);
    }

    // FeatureSum predicts the row sums.
    #[test]
    fn test_mock_estimator_feature_sum() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::FeatureSum)
            .build();

        let features = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::zeros(2);

        let trained = mock.clone().fit(&features.view(), &targets.view()).unwrap();
        let predictions = trained.predict(&features.view()).unwrap();

        assert_eq!(predictions[0], 6.0); // 1 + 2 + 3
        assert_eq!(predictions[1], 15.0); // 4 + 5 + 6
    }

    // LinearModel computes dot(row, weights) + bias.
    #[test]
    fn test_mock_estimator_linear_model() {
        let weights = vec![1.0, 2.0, 3.0];
        let bias = 1.0;

        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::LinearModel { weights, bias })
            .build();

        let features = Array2::from_shape_vec((1, 3), vec![1.0, 1.0, 1.0]).unwrap();
        let targets = Array1::zeros(1);

        let trained = mock.fit(&features.view(), &targets.view()).unwrap();
        let predictions = trained.predict(&features.view()).unwrap();

        assert_eq!(predictions[0], 7.0); // 1*1 + 2*1 + 3*1 + 1
    }

    // Identity transform returns the input unchanged.
    #[test]
    fn test_mock_transformer_identity() {
        let transformer = MockTransformer::new(MockTransformType::Identity);
        let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array1::zeros(2);

        let fitted = transformer
            .clone()
            .fit(&data.view(), &targets.view())
            .unwrap();
        let transformed = fitted.transform(&data.view()).unwrap();

        assert_eq!(transformed, data);
    }

    // Scale multiplies every element by the factor.
    #[test]
    fn test_mock_transformer_scale() {
        let transformer = MockTransformer::new(MockTransformType::Scale(2.0));
        let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let targets = Array1::zeros(2);

        let fitted = transformer
            .clone()
            .fit(&data.view(), &targets.view())
            .unwrap();
        let transformed = fitted.transform(&data.view()).unwrap();

        let expected = Array2::from_shape_vec((2, 2), vec![2.0, 4.0, 6.0, 8.0]).unwrap();
        assert_eq!(transformed, expected);
    }

    // A failure probability of 1.0 makes every fit fail.
    #[test]
    fn test_mock_estimator_failure_simulation() {
        let mock = MockEstimator::builder()
            .with_fit_failure_probability(1.0) // Always fail
            .build();

        let features = Array2::zeros((5, 2));
        let targets = Array1::zeros(5);

        let result = mock.clone().fit(&features.view(), &targets.view());
        assert!(result.is_err());
    }

    // The fit-call limit is shared across clones, so the third clone's fit fails.
    #[test]
    fn test_mock_estimator_max_fit_calls() {
        let mock = MockEstimator::builder().with_max_fit_calls(2).build();

        let features = Array2::zeros((5, 2));
        let targets = Array1::zeros(5);

        // First two fits should succeed
        assert!(mock.clone().fit(&features.view(), &targets.view()).is_ok());
        assert!(mock.clone().fit(&features.view(), &targets.view()).is_ok());

        // Third fit should fail
        assert!(mock.clone().fit(&features.view(), &targets.view()).is_err());
    }

    // predict_proba yields two columns whose rows sum to one.
    #[test]
    fn test_mock_estimator_predict_proba() {
        let mock = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(0.0))
            .build();

        let features = Array2::zeros((3, 2));
        let targets = Array1::zeros(3);

        let trained = mock.clone().fit(&features.view(), &targets.view()).unwrap();
        let probabilities = trained.predict_proba(&features.view()).unwrap();

        assert_eq!(probabilities.shape(), &[3, 2]);
        // All predictions should have probabilities that sum to 1
        for row in probabilities.rows() {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-10);
        }
    }

    // The ensemble container reports its size and voting strategy.
    #[test]
    fn test_mock_ensemble_creation() {
        let est1 = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(1.0))
            .build();
        let est2 = MockEstimator::builder()
            .with_behavior(MockBehavior::ConstantPrediction(2.0))
            .build();

        let ensemble = MockEnsemble::new(vec![est1, est2], VotingStrategy::AverageVote);

        assert_eq!(ensemble.n_estimators(), 2);
        assert!(matches!(
            ensemble.voting_strategy(),
            VotingStrategy::AverageVote
        ));
    }
}