sklears_preprocessing/
probabilistic_imputation.rs

//! Probabilistic imputation methods
//!
//! This module provides advanced probabilistic imputation techniques including:
//! - Bayesian imputation with prior distributions
//! - Expectation-Maximization (EM) algorithm for missing data
//! - Gaussian Process imputation for smooth interpolation
//! - Monte Carlo imputation for uncertainty quantification
//! - Copula-based imputation for preserving dependencies
9
10use scirs2_core::ndarray::{Array1, Array2, Axis};
11use scirs2_core::random::essentials::Normal;
12use scirs2_core::random::seeded_rng;
13use scirs2_core::Distribution;
14use sklears_core::prelude::*;
15
16// ================================================================================================
17// Bayesian Imputation
18// ================================================================================================
19
/// Configuration for [`BayesianImputer`].
///
/// Each feature's mean gets a normal prior `N(prior_mean, prior_std^2)`. The
/// `prior_variance_*` pair seeds the Gamma-shaped variance term that is tracked
/// alongside the mean posterior.
#[derive(Debug, Clone)]
pub struct BayesianImputerConfig {
    /// Prior mean `mu_0` of each feature's mean.
    pub prior_mean: f64,
    /// Prior standard deviation `sigma_0` of each feature's mean.
    ///
    /// The prior precision used during fitting is `1 / sigma_0^2`. It is also the
    /// variance fallback for a feature with exactly one observed value.
    pub prior_std: f64,
    /// Shape parameter of the Gamma distribution for the variance term.
    /// Fitting adds `n / 2` to it.
    pub prior_variance_shape: f64,
    /// Rate parameter of the Gamma distribution for the variance term.
    pub prior_variance_rate: f64,
    /// Number of posterior samples for imputation.
    ///
    /// NOTE(review): not read by `BayesianImputerFitted::transform` in this module,
    /// which draws a single value per feature.
    pub n_samples: usize,
    /// Seed for the random number generator used when drawing imputed values.
    pub random_state: u64,
}

impl Default for BayesianImputerConfig {
    /// Standard-normal prior on the mean, `Gamma(2, 2)` variance prior,
    /// 100 samples, and seed 42.
    fn default() -> Self {
        Self {
            prior_mean: 0.0,
            prior_std: 1.0,
            prior_variance_shape: 2.0,
            prior_variance_rate: 2.0,
            n_samples: 100,
            random_state: 42,
        }
    }
}
47
48/// Bayesian imputer using conjugate priors
49pub struct BayesianImputer {
50    config: BayesianImputerConfig,
51}
52
53/// Fitted Bayesian imputer
54pub struct BayesianImputerFitted {
55    config: BayesianImputerConfig,
56    /// Posterior parameters for each feature
57    posterior_params: Vec<PosteriorParams>,
58}
59
60#[derive(Debug, Clone)]
61struct PosteriorParams {
62    /// Posterior mean
63    mean: f64,
64    /// Posterior standard deviation
65    std: f64,
66    /// Posterior variance shape
67    variance_shape: f64,
68    /// Posterior variance rate
69    variance_rate: f64,
70}
71
72impl BayesianImputer {
73    /// Create a new Bayesian imputer
74    pub fn new(config: BayesianImputerConfig) -> Self {
75        Self { config }
76    }
77}
78
79impl Estimator for BayesianImputer {
80    type Config = BayesianImputerConfig;
81    type Error = SklearsError;
82    type Float = f64;
83
84    fn config(&self) -> &Self::Config {
85        &self.config
86    }
87}
88
89impl Fit<Array2<f64>, ()> for BayesianImputer {
90    type Fitted = BayesianImputerFitted;
91
92    fn fit(self, X: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
93        let n_features = X.ncols();
94        let mut posterior_params = Vec::with_capacity(n_features);
95
96        for j in 0..n_features {
97            let col = X.column(j);
98            let observed: Vec<f64> = col.iter().filter(|&&x| !x.is_nan()).copied().collect();
99
100            if observed.is_empty() {
101                // No observed data, use prior
102                posterior_params.push(PosteriorParams {
103                    mean: self.config.prior_mean,
104                    std: self.config.prior_std,
105                    variance_shape: self.config.prior_variance_shape,
106                    variance_rate: self.config.prior_variance_rate,
107                });
108                continue;
109            }
110
111            let n = observed.len() as f64;
112            let sample_mean = observed.iter().sum::<f64>() / n;
113            let sample_var = if n > 1.0 {
114                observed
115                    .iter()
116                    .map(|&x| (x - sample_mean).powi(2))
117                    .sum::<f64>()
118                    / (n - 1.0)
119            } else {
120                self.config.prior_std.powi(2)
121            };
122
123            // Update posterior using conjugate priors (Normal-Gamma)
124            let prior_precision = 1.0 / self.config.prior_std.powi(2);
125            let sample_precision = n / sample_var;
126
127            let posterior_precision = prior_precision + sample_precision;
128            let posterior_mean = (prior_precision * self.config.prior_mean
129                + sample_precision * sample_mean)
130                / posterior_precision;
131            let posterior_std = (1.0 / posterior_precision).sqrt();
132
133            // Update variance posterior (Gamma distribution)
134            let posterior_shape = self.config.prior_variance_shape + n / 2.0;
135            let sum_sq_dev = observed
136                .iter()
137                .map(|&x| (x - sample_mean).powi(2))
138                .sum::<f64>();
139            let posterior_rate = self.config.prior_variance_rate
140                + sum_sq_dev / 2.0
141                + (prior_precision * n * (sample_mean - self.config.prior_mean).powi(2))
142                    / (2.0 * (prior_precision + n));
143
144            posterior_params.push(PosteriorParams {
145                mean: posterior_mean,
146                std: posterior_std,
147                variance_shape: posterior_shape,
148                variance_rate: posterior_rate,
149            });
150        }
151
152        Ok(BayesianImputerFitted {
153            config: self.config,
154            posterior_params,
155        })
156    }
157}
158
159impl Transform<Array2<f64>, Array2<f64>> for BayesianImputerFitted {
160    fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
161        let mut result = X.clone();
162        let mut rng = seeded_rng(self.config.random_state);
163
164        for j in 0..X.ncols() {
165            let params = &self.posterior_params[j];
166
167            // Sample from posterior
168            let normal = Normal::new(params.mean, params.std).map_err(|e| {
169                SklearsError::InvalidInput(format!("Failed to create normal distribution: {}", e))
170            })?;
171            let imputed_value = normal.sample(&mut rng);
172
173            for i in 0..X.nrows() {
174                if result[[i, j]].is_nan() {
175                    result[[i, j]] = imputed_value;
176                }
177            }
178        }
179
180        Ok(result)
181    }
182}
183
184// ================================================================================================
185// EM Imputation
186// ================================================================================================
187
/// Settings for [`EMImputer`].
#[derive(Debug, Clone)]
pub struct EMImputerConfig {
    /// Upper bound on the number of EM rounds.
    pub max_iter: usize,
    /// Stopping threshold on the change in the estimates between rounds.
    pub tol: f64,
    /// Random seed.
    ///
    /// NOTE(review): not read by `EMImputer` in this module; fitting is deterministic.
    pub random_state: u64,
}

impl Default for EMImputerConfig {
    /// 100 iterations, tolerance `1e-4`, seed 42.
    fn default() -> Self {
        let max_iter = 100;
        let tol = 1e-4;
        EMImputerConfig {
            max_iter,
            tol,
            random_state: 42,
        }
    }
}
208
209/// EM imputer using multivariate normal model
210pub struct EMImputer {
211    config: EMImputerConfig,
212}
213
214/// Fitted EM imputer
215pub struct EMImputerFitted {
216    config: EMImputerConfig,
217    /// Estimated mean vector
218    mean: Array1<f64>,
219    /// Estimated covariance matrix
220    covariance: Array2<f64>,
221}
222
223impl EMImputer {
224    /// Create a new EM imputer
225    pub fn new(config: EMImputerConfig) -> Self {
226        Self { config }
227    }
228}
229
230impl Estimator for EMImputer {
231    type Config = EMImputerConfig;
232    type Error = SklearsError;
233    type Float = f64;
234
235    fn config(&self) -> &Self::Config {
236        &self.config
237    }
238}
239
240impl Fit<Array2<f64>, ()> for EMImputer {
241    type Fitted = EMImputerFitted;
242
243    fn fit(self, X: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
244        let n_features = X.ncols();
245
246        // Initialize with column means
247        let mut mean = Array1::zeros(n_features);
248        for j in 0..n_features {
249            let col = X.column(j);
250            let observed: Vec<f64> = col.iter().filter(|&&x| !x.is_nan()).copied().collect();
251            if !observed.is_empty() {
252                mean[j] = observed.iter().sum::<f64>() / observed.len() as f64;
253            }
254        }
255
256        // Initialize covariance matrix
257        let mut covariance = Array2::eye(n_features);
258
259        // EM iterations
260        for _iter in 0..self.config.max_iter {
261            let mean_old = mean.clone();
262
263            // E-step: Impute missing values using current parameters
264            let mut X_imputed = X.clone();
265            for i in 0..X.nrows() {
266                for j in 0..n_features {
267                    if X_imputed[[i, j]].is_nan() {
268                        X_imputed[[i, j]] = mean[j];
269                    }
270                }
271            }
272
273            // M-step: Update parameters
274            mean = X_imputed
275                .mean_axis(Axis(0))
276                .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
277
278            // Update covariance
279            let mut cov_sum = Array2::zeros((n_features, n_features));
280            for i in 0..X.nrows() {
281                let centered = &X_imputed.row(i).to_owned() - &mean;
282                let outer = centered
283                    .clone()
284                    .insert_axis(Axis(1))
285                    .dot(&centered.insert_axis(Axis(0)));
286                cov_sum = cov_sum + outer;
287            }
288            covariance = cov_sum / X.nrows() as f64;
289
290            // Check convergence
291            let mean_diff = (&mean - &mean_old).mapv(|x| x.abs()).sum();
292            if mean_diff < self.config.tol {
293                break;
294            }
295        }
296
297        Ok(EMImputerFitted {
298            config: self.config,
299            mean,
300            covariance,
301        })
302    }
303}
304
305impl Transform<Array2<f64>, Array2<f64>> for EMImputerFitted {
306    fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
307        let mut result = X.clone();
308
309        for i in 0..X.nrows() {
310            // Find missing indices
311            let missing_indices: Vec<usize> =
312                (0..X.ncols()).filter(|&j| X[[i, j]].is_nan()).collect();
313
314            if missing_indices.is_empty() {
315                continue;
316            }
317
318            let observed_indices: Vec<usize> =
319                (0..X.ncols()).filter(|&j| !X[[i, j]].is_nan()).collect();
320
321            if observed_indices.is_empty() {
322                // All missing, use mean
323                for &j in missing_indices.iter() {
324                    result[[i, j]] = self.mean[j];
325                }
326                continue;
327            }
328
329            // Conditional imputation using multivariate normal properties
330            for &miss_idx in missing_indices.iter() {
331                let mut conditional_mean = self.mean[miss_idx];
332
333                // Simple approximation: weighted average based on observed values
334                let mut weight_sum = 0.0;
335                let mut weighted_value = 0.0;
336
337                for &obs_idx in observed_indices.iter() {
338                    let cov = self.covariance[[miss_idx, obs_idx]];
339                    let var = self.covariance[[obs_idx, obs_idx]];
340
341                    if var > 1e-10 {
342                        let weight = cov / var;
343                        weighted_value += weight * (X[[i, obs_idx]] - self.mean[obs_idx]);
344                        weight_sum += weight.abs();
345                    }
346                }
347
348                if weight_sum > 1e-10 {
349                    conditional_mean += weighted_value / weight_sum;
350                }
351
352                result[[i, miss_idx]] = conditional_mean;
353            }
354        }
355
356        Ok(result)
357    }
358}
359
360// ================================================================================================
361// Gaussian Process Imputation
362// ================================================================================================
363
/// Settings for [`GaussianProcessImputer`].
///
/// The imputer uses the RBF kernel
/// `k(a, b) = signal_variance * exp(-(a - b)^2 / (2 * length_scale^2))`
/// over row positions and adds `noise_variance` to the training kernel's diagonal.
#[derive(Debug, Clone)]
pub struct GaussianProcessImputerConfig {
    /// Length scale of the RBF kernel, measured in rows.
    pub length_scale: f64,
    /// Signal variance (kernel amplitude).
    pub signal_variance: f64,
    /// Noise variance added to the diagonal of the training kernel matrix.
    pub noise_variance: f64,
    /// Whether hyperparameters should be optimised.
    ///
    /// NOTE(review): not read by `GaussianProcessImputer::fit` in this module; the
    /// configured kernel parameters are used as given.
    pub optimize_hyperparameters: bool,
}

impl Default for GaussianProcessImputerConfig {
    /// Unit length scale and signal variance, noise variance 0.1, no optimisation.
    fn default() -> Self {
        let (length_scale, signal_variance, noise_variance) = (1.0, 1.0, 0.1);
        Self {
            length_scale,
            signal_variance,
            noise_variance,
            optimize_hyperparameters: false,
        }
    }
}
387
388/// Gaussian Process imputer for smooth interpolation
389pub struct GaussianProcessImputer {
390    config: GaussianProcessImputerConfig,
391}
392
393/// Fitted Gaussian Process imputer
394pub struct GaussianProcessImputerFitted {
395    config: GaussianProcessImputerConfig,
396    /// Training data (for each feature)
397    training_data: Vec<FeatureGPData>,
398}
399
400#[derive(Debug, Clone)]
401struct FeatureGPData {
402    /// Observed indices
403    observed_indices: Vec<usize>,
404    /// Observed values
405    observed_values: Vec<f64>,
406    /// Inverse of kernel matrix (K + σ²I)^(-1)
407    kernel_inv: Array2<f64>,
408}
409
410impl GaussianProcessImputer {
411    /// Create a new Gaussian Process imputer
412    pub fn new(config: GaussianProcessImputerConfig) -> Self {
413        Self { config }
414    }
415
416    /// RBF (Gaussian) kernel
417    fn kernel(&self, x1: f64, x2: f64) -> f64 {
418        let sq_dist = (x1 - x2).powi(2);
419        self.config.signal_variance * (-sq_dist / (2.0 * self.config.length_scale.powi(2))).exp()
420    }
421}
422
423impl Estimator for GaussianProcessImputer {
424    type Config = GaussianProcessImputerConfig;
425    type Error = SklearsError;
426    type Float = f64;
427
428    fn config(&self) -> &Self::Config {
429        &self.config
430    }
431}
432
433impl Fit<Array2<f64>, ()> for GaussianProcessImputer {
434    type Fitted = GaussianProcessImputerFitted;
435
436    fn fit(self, X: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
437        let n_features = X.ncols();
438        let mut training_data = Vec::with_capacity(n_features);
439
440        for j in 0..n_features {
441            let col = X.column(j);
442            let mut observed_indices = Vec::new();
443            let mut observed_values = Vec::new();
444
445            for (i, &val) in col.iter().enumerate() {
446                if !val.is_nan() {
447                    observed_indices.push(i);
448                    observed_values.push(val);
449                }
450            }
451
452            if observed_indices.is_empty() {
453                // No observed data
454                training_data.push(FeatureGPData {
455                    observed_indices: Vec::new(),
456                    observed_values: Vec::new(),
457                    kernel_inv: Array2::zeros((0, 0)),
458                });
459                continue;
460            }
461
462            // Build kernel matrix
463            let n_obs = observed_indices.len();
464            let mut K = Array2::zeros((n_obs, n_obs));
465
466            for i in 0..n_obs {
467                for j in 0..n_obs {
468                    K[[i, j]] = self.kernel(observed_indices[i] as f64, observed_indices[j] as f64);
469                    if i == j {
470                        K[[i, j]] += self.config.noise_variance;
471                    }
472                }
473            }
474
475            // Compute inverse (simplified - should use Cholesky in production)
476            let kernel_inv = pseudo_inverse(&K)?;
477
478            training_data.push(FeatureGPData {
479                observed_indices,
480                observed_values,
481                kernel_inv,
482            });
483        }
484
485        Ok(GaussianProcessImputerFitted {
486            config: self.config,
487            training_data,
488        })
489    }
490}
491
492impl Transform<Array2<f64>, Array2<f64>> for GaussianProcessImputerFitted {
493    fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
494        let mut result = X.clone();
495
496        for j in 0..X.ncols() {
497            let gp_data = &self.training_data[j];
498
499            if gp_data.observed_indices.is_empty() {
500                // No training data, leave as is
501                continue;
502            }
503
504            for i in 0..X.nrows() {
505                if result[[i, j]].is_nan() {
506                    // Predict using GP
507                    let k_star = gp_data
508                        .observed_indices
509                        .iter()
510                        .map(|&obs_idx| {
511                            self.config.signal_variance
512                                * (-(i as f64 - obs_idx as f64).powi(2)
513                                    / (2.0 * self.config.length_scale.powi(2)))
514                                .exp()
515                        })
516                        .collect::<Vec<f64>>();
517
518                    let k_star_array = Array1::from(k_star);
519                    let y_obs = Array1::from(gp_data.observed_values.clone());
520
521                    // Mean prediction: k*^T K^(-1) y
522                    let alpha = gp_data.kernel_inv.dot(&y_obs);
523                    let prediction = k_star_array.dot(&alpha);
524
525                    result[[i, j]] = prediction;
526                }
527            }
528        }
529
530        Ok(result)
531    }
532}
533
534// ================================================================================================
535// Monte Carlo Imputation
536// ================================================================================================
537
/// Settings for [`MonteCarloImputer`].
#[derive(Debug, Clone)]
pub struct MonteCarloImputerConfig {
    /// How many random imputations are drawn (and averaged) for each missing entry.
    pub n_imputations: usize,
    /// Base imputation strategy.
    ///
    /// NOTE(review): `MonteCarloImputerFitted::transform` in this module does not read
    /// this field.
    pub base_method: MonteCarloBaseMethod,
    /// Seed for the random number generator.
    pub random_state: u64,
}

/// Strategy for generating the individual imputations of a [`MonteCarloImputer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonteCarloBaseMethod {
    /// Mean imputation perturbed by random normal noise.
    MeanWithNoise,
    /// Regression-based imputation with residual resampling.
    RegressionResampling,
}

impl Default for MonteCarloImputerConfig {
    /// Five imputations of the `MeanWithNoise` kind, seed 42.
    fn default() -> Self {
        let base_method = MonteCarloBaseMethod::MeanWithNoise;
        MonteCarloImputerConfig {
            n_imputations: 5,
            base_method,
            random_state: 42,
        }
    }
}
566
567/// Monte Carlo imputer for uncertainty quantification
568pub struct MonteCarloImputer {
569    config: MonteCarloImputerConfig,
570}
571
572/// Fitted Monte Carlo imputer
573pub struct MonteCarloImputerFitted {
574    config: MonteCarloImputerConfig,
575    /// Column statistics for imputation
576    column_stats: Vec<ColumnStats>,
577}
578
579#[derive(Debug, Clone)]
580struct ColumnStats {
581    mean: f64,
582    std: f64,
583}
584
585impl MonteCarloImputer {
586    /// Create a new Monte Carlo imputer
587    pub fn new(config: MonteCarloImputerConfig) -> Self {
588        Self { config }
589    }
590}
591
592impl Estimator for MonteCarloImputer {
593    type Config = MonteCarloImputerConfig;
594    type Error = SklearsError;
595    type Float = f64;
596
597    fn config(&self) -> &Self::Config {
598        &self.config
599    }
600}
601
602impl Fit<Array2<f64>, ()> for MonteCarloImputer {
603    type Fitted = MonteCarloImputerFitted;
604
605    fn fit(self, X: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
606        let n_features = X.ncols();
607        let mut column_stats = Vec::with_capacity(n_features);
608
609        for j in 0..n_features {
610            let col = X.column(j);
611            let observed: Vec<f64> = col.iter().filter(|&&x| !x.is_nan()).copied().collect();
612
613            let (mean, std) = if !observed.is_empty() {
614                let m = observed.iter().sum::<f64>() / observed.len() as f64;
615                let s = if observed.len() > 1 {
616                    (observed.iter().map(|&x| (x - m).powi(2)).sum::<f64>()
617                        / (observed.len() - 1) as f64)
618                        .sqrt()
619                } else {
620                    1.0
621                };
622                (m, s)
623            } else {
624                (0.0, 1.0)
625            };
626
627            column_stats.push(ColumnStats { mean, std });
628        }
629
630        Ok(MonteCarloImputerFitted {
631            config: self.config,
632            column_stats,
633        })
634    }
635}
636
637impl Transform<Array2<f64>, Array2<f64>> for MonteCarloImputerFitted {
638    fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
639        let mut rng = seeded_rng(self.config.random_state);
640        let mut result = X.clone();
641
642        // Perform multiple imputations and average
643        let mut imputations = Vec::with_capacity(self.config.n_imputations);
644
645        for _ in 0..self.config.n_imputations {
646            let mut imputed = X.clone();
647
648            for j in 0..X.ncols() {
649                let stats = &self.column_stats[j];
650                let normal = Normal::new(stats.mean, stats.std.max(1e-10)).map_err(|e| {
651                    SklearsError::InvalidInput(format!(
652                        "Failed to create normal distribution: {}",
653                        e
654                    ))
655                })?;
656
657                for i in 0..X.nrows() {
658                    if imputed[[i, j]].is_nan() {
659                        imputed[[i, j]] = normal.sample(&mut rng);
660                    }
661                }
662            }
663
664            imputations.push(imputed);
665        }
666
667        // Average imputations
668        for i in 0..X.nrows() {
669            for j in 0..X.ncols() {
670                if X[[i, j]].is_nan() {
671                    let sum: f64 = imputations.iter().map(|imp| imp[[i, j]]).sum();
672                    result[[i, j]] = sum / self.config.n_imputations as f64;
673                }
674            }
675        }
676
677        Ok(result)
678    }
679}
680
681// ================================================================================================
682// Helper Functions
683// ================================================================================================
684
685/// Compute pseudo-inverse of a matrix using SVD (simplified)
686fn pseudo_inverse(A: &Array2<f64>) -> Result<Array2<f64>> {
687    let n = A.nrows();
688    if n != A.ncols() {
689        return Err(SklearsError::InvalidInput(
690            "Matrix must be square for this simplified inverse".to_string(),
691        ));
692    }
693
694    // Simplified: use regularized inverse
695    let mut A_reg = A.clone();
696    for i in 0..n {
697        A_reg[[i, i]] += 1e-6;
698    }
699
700    // Simple Gauss-Jordan elimination (for small matrices)
701    let mut aug = Array2::zeros((n, 2 * n));
702    for i in 0..n {
703        for j in 0..n {
704            aug[[i, j]] = A_reg[[i, j]];
705        }
706        aug[[i, i + n]] = 1.0;
707    }
708
709    // Forward elimination
710    for i in 0..n {
711        let pivot = aug[[i, i]];
712        if pivot.abs() < 1e-10 {
713            continue;
714        }
715
716        for j in 0..2 * n {
717            aug[[i, j]] /= pivot;
718        }
719
720        for k in 0..n {
721            if k != i {
722                let factor = aug[[k, i]];
723                for j in 0..2 * n {
724                    aug[[k, j]] -= factor * aug[[i, j]];
725                }
726            }
727        }
728    }
729
730    // Extract inverse
731    let mut inv = Array2::zeros((n, n));
732    for i in 0..n {
733        for j in 0..n {
734            inv[[i, j]] = aug[[i, j + n]];
735        }
736    }
737
738    Ok(inv)
739}
740
741// ================================================================================================
742// Tests
743// ================================================================================================
744
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Four-row, two-feature fixture with a single gap at `[1, 1]`.
    fn single_gap_matrix() -> Array2<f64> {
        array![[1.0, 2.0], [3.0, f64::NAN], [5.0, 6.0], [7.0, 8.0]]
    }

    /// Assert that `out` has the expected `(rows, cols)` shape and that `cell` is filled.
    fn assert_filled(out: &Array2<f64>, shape: (usize, usize), cell: [usize; 2]) {
        assert_eq!(out.nrows(), shape.0);
        assert_eq!(out.ncols(), shape.1);
        assert!(!out[cell].is_nan());
    }

    #[test]
    fn test_bayesian_imputer() {
        let data = array![[1.0, 2.0], [3.0, f64::NAN], [5.0, 6.0]];
        let fitted = BayesianImputer::new(BayesianImputerConfig::default())
            .fit(&data, &())
            .unwrap();
        let out = fitted.transform(&data).unwrap();

        assert_filled(&out, (3, 2), [1, 1]);
    }

    #[test]
    fn test_em_imputer() {
        let data = single_gap_matrix();
        let out = EMImputer::new(EMImputerConfig::default())
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        assert_filled(&out, (4, 2), [1, 1]);
        // The imputed value should lie in a plausible range.
        let imputed = out[[1, 1]];
        assert!(imputed > 0.0 && imputed < 10.0);
    }

    #[test]
    fn test_gp_imputer() {
        let data = single_gap_matrix();
        let fitted = GaussianProcessImputer::new(GaussianProcessImputerConfig::default())
            .fit(&data, &())
            .unwrap();

        assert_filled(&fitted.transform(&data).unwrap(), (4, 2), [1, 1]);
    }

    #[test]
    fn test_monte_carlo_imputer() {
        let data = single_gap_matrix();
        let config = MonteCarloImputerConfig {
            n_imputations: 10,
            base_method: MonteCarloBaseMethod::MeanWithNoise,
            random_state: 42,
        };
        let out = MonteCarloImputer::new(config)
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        assert_filled(&out, (4, 2), [1, 1]);
        // Averaging ten draws should land near the observed mean of column 1.
        let observed_mean = [2.0, 6.0, 8.0].iter().sum::<f64>() / 3.0;
        assert!((out[[1, 1]] - observed_mean).abs() < 3.0);
    }

    #[test]
    fn test_bayesian_imputer_all_missing() {
        let data = array![[f64::NAN, 2.0], [f64::NAN, 4.0], [f64::NAN, 6.0]];
        let out = BayesianImputer::new(BayesianImputerConfig::default())
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        // The fully missing column falls back to the prior and must be filled everywhere.
        assert!(out.column(0).iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn test_em_convergence() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, f64::NAN],
            [10.0, 11.0, 12.0]
        ];
        let config = EMImputerConfig {
            max_iter: 50,
            tol: 1e-4,
            random_state: 42,
        };
        let fitted = EMImputer::new(config).fit(&data, &()).unwrap();

        // The fitted parameters must match the feature count.
        assert_eq!(fitted.mean.len(), 3);
        assert_eq!(fitted.covariance.dim(), (3, 3));
    }

    #[test]
    fn test_pseudo_inverse() {
        let a = array![[2.0, 1.0], [1.0, 2.0]];
        let product = a.dot(&pseudo_inverse(&a).unwrap());

        // A * A^-1 should be (close to) the identity.
        for ((r, c), &value) in product.indexed_iter() {
            let expected = if r == c { 1.0 } else { 0.0 };
            assert_relative_eq!(value, expected, epsilon = 0.1);
        }
    }
}
871}