scirs2_transform/impute.rs

//! Missing value imputation utilities
//!
//! This module provides methods for handling missing values in datasets,
//! which is a crucial preprocessing step for machine learning.

use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use num_traits::{Float, NumCast};
use scirs2_core::parallel_ops::*;

use crate::error::{Result, TransformError};

/// Strategy for imputing missing values
#[derive(Debug, Clone, PartialEq)]
pub enum ImputeStrategy {
    /// Replace missing values with the mean of the feature
    Mean,
    /// Replace missing values with the median of the feature
    Median,
    /// Replace missing values with the most frequent value
    MostFrequent,
    /// Replace missing values with a constant value
    Constant(f64),
}

/// SimpleImputer for filling missing values
///
/// This transformer fills missing values using simple strategies like mean,
/// median, most frequent value, or a constant value.
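///
/// # Examples
///
/// A minimal usage sketch (fenced `ignore`, since the exact public import
/// path depends on this crate's re-exports):
///
/// ```ignore
/// use ndarray::array;
///
/// let data = array![[1.0, 2.0], [f64::NAN, 4.0], [3.0, f64::NAN]];
/// let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Mean);
/// let filled = imputer.fit_transform(&data).unwrap();
/// // The NaN in column 0 is replaced by that column's mean: (1.0 + 3.0) / 2 = 2.0
/// assert_eq!(filled[[1, 0]], 2.0);
/// ```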
pub struct SimpleImputer {
    /// Strategy for imputation
    strategy: ImputeStrategy,
    /// Missing value indicator (what value is considered missing)
    missing_values: f64,
    /// Values used for imputation (computed during fit)
    statistics_: Option<Array1<f64>>,
}

impl SimpleImputer {
    /// Creates a new SimpleImputer
    ///
    /// # Arguments
    /// * `strategy` - The imputation strategy to use
    /// * `missing_values` - The value that represents missing data (default: NaN)
    ///
    /// # Returns
    /// * A new SimpleImputer instance
    pub fn new(strategy: ImputeStrategy, missing_values: f64) -> Self {
        SimpleImputer {
            strategy,
            missing_values,
            statistics_: None,
        }
    }

    /// Creates a SimpleImputer with NaN as the missing value indicator
    ///
    /// # Arguments
    /// * `strategy` - The imputation strategy to use
    ///
    /// # Returns
    /// * A new SimpleImputer instance
    #[allow(dead_code)]
    pub fn with_strategy(strategy: ImputeStrategy) -> Self {
        Self::new(strategy, f64::NAN)
    }

    /// Fits the SimpleImputer to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        let mut statistics = Array1::zeros(n_features);

        for j in 0..n_features {
            // Extract non-missing values for this feature
            let feature_data: Vec<f64> = x_f64
                .column(j)
                .iter()
                .filter(|&&val| !self.is_missing(val))
                .copied()
                .collect();

            if feature_data.is_empty() {
                return Err(TransformError::InvalidInput(format!(
                    "All values are missing in feature {j}"
                )));
            }

            statistics[j] = match &self.strategy {
                ImputeStrategy::Mean => {
                    feature_data.iter().sum::<f64>() / feature_data.len() as f64
                }
                ImputeStrategy::Median => {
                    let mut sorted_data = feature_data.clone();
                    sorted_data
                        .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    let n = sorted_data.len();
                    if n % 2 == 0 {
                        (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
                    } else {
                        sorted_data[n / 2]
                    }
                }
                ImputeStrategy::MostFrequent => {
                    // For numerical data, find the value that appears most often.
                    // Bit patterns serve as hash keys because f64 does not
                    // implement Hash; this is a simplified implementation.
                    let mut counts = std::collections::HashMap::new();
                    for &val in &feature_data {
                        *counts.entry(val.to_bits()).or_insert(0) += 1;
                    }

                    let most_frequent_bits = counts
                        .into_iter()
                        .max_by_key(|(_, count)| *count)
                        .map(|(bits, _)| bits)
                        .unwrap_or(0);

                    f64::from_bits(most_frequent_bits)
                }
                ImputeStrategy::Constant(value) => *value,
            };
        }

        self.statistics_ = Some(statistics);
        Ok(())
    }

    /// Transforms the input data using the fitted SimpleImputer
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data with imputed values
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if self.statistics_.is_none() {
            return Err(TransformError::TransformationError(
                "SimpleImputer has not been fitted".to_string(),
            ));
        }

        let statistics = self.statistics_.as_ref().unwrap();

        if n_features != statistics.len() {
            return Err(TransformError::InvalidInput(format!(
                "x has {} features, but SimpleImputer was fitted with {} features",
                n_features,
                statistics.len()
            )));
        }

        let mut transformed = Array2::zeros((n_samples, n_features));

        for i in 0..n_samples {
            for j in 0..n_features {
                let value = x_f64[[i, j]];
                if self.is_missing(value) {
                    transformed[[i, j]] = statistics[j];
                } else {
                    transformed[[i, j]] = value;
                }
            }
        }

        Ok(transformed)
    }

    /// Fits the SimpleImputer to the input data and transforms it
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data with imputed values
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the statistics computed during fitting
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The statistics for each feature
    #[allow(dead_code)]
    pub fn statistics(&self) -> Option<&Array1<f64>> {
        self.statistics_.as_ref()
    }

    /// Checks if a value is considered missing
    ///
    /// # Arguments
    /// * `value` - The value to check
    ///
    /// # Returns
    /// * `bool` - True if the value is missing, false otherwise
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

/// Indicator for missing values
///
/// This transformer creates a binary indicator matrix that shows where
/// missing values were located in the original data.
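///
/// # Examples
///
/// A minimal sketch (fenced `ignore`; import paths depend on the crate's
/// re-exports):
///
/// ```ignore
/// use ndarray::array;
///
/// let data = array![[1.0, f64::NAN], [2.0, 3.0]];
/// let mut indicator = MissingIndicator::with_nan();
/// let mask = indicator.fit_transform(&data).unwrap();
/// // Only column 1 contains missing values, so the output has one column
/// assert_eq!(mask.shape(), &[2, 1]);
/// assert_eq!(mask[[0, 0]], 1.0); // row 0, column 1 was NaN
/// assert_eq!(mask[[1, 0]], 0.0);
/// ```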
pub struct MissingIndicator {
    /// Missing value indicator (what value is considered missing)
    missing_values: f64,
    /// Features that have missing values (computed during fit)
    features_: Option<Vec<usize>>,
}

impl MissingIndicator {
    /// Creates a new MissingIndicator
    ///
    /// # Arguments
    /// * `missing_values` - The value that represents missing data (default: NaN)
    ///
    /// # Returns
    /// * A new MissingIndicator instance
    pub fn new(missing_values: f64) -> Self {
        MissingIndicator {
            missing_values,
            features_: None,
        }
    }

    /// Creates a MissingIndicator with NaN as the missing value indicator
    pub fn with_nan() -> Self {
        Self::new(f64::NAN)
    }

    /// Fits the MissingIndicator to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));

        let n_features = x_f64.shape()[1];
        let mut features_with_missing = Vec::new();

        for j in 0..n_features {
            let has_missing = x_f64.column(j).iter().any(|&val| self.is_missing(val));
            if has_missing {
                features_with_missing.push(j);
            }
        }

        self.features_ = Some(features_with_missing);
        Ok(())
    }

    /// Transforms the input data to create missing value indicators
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Binary indicator matrix, shape (n_samples, n_features_with_missing)
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];

        if self.features_.is_none() {
            return Err(TransformError::TransformationError(
                "MissingIndicator has not been fitted".to_string(),
            ));
        }

        let features_with_missing = self.features_.as_ref().unwrap();
        let n_output_features = features_with_missing.len();

        let mut indicators = Array2::zeros((n_samples, n_output_features));

        for i in 0..n_samples {
            for (out_j, &orig_j) in features_with_missing.iter().enumerate() {
                if self.is_missing(x_f64[[i, orig_j]]) {
                    indicators[[i, out_j]] = 1.0;
                }
            }
        }

        Ok(indicators)
    }

    /// Fits the MissingIndicator to the input data and transforms it
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Binary indicator matrix
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the features that have missing values
    ///
    /// # Returns
    /// * `Option<&Vec<usize>>` - Indices of features with missing values
    pub fn features(&self) -> Option<&Vec<usize>> {
        self.features_.as_ref()
    }

    /// Checks if a value is considered missing
    ///
    /// # Arguments
    /// * `value` - The value to check
    ///
    /// # Returns
    /// * `bool` - True if the value is missing, false otherwise
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

/// Distance metric for k-nearest neighbors search
#[derive(Debug, Clone, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance (L2 norm)
    Euclidean,
    /// Manhattan distance (L1 norm)
    Manhattan,
}

/// Weighting scheme for k-nearest neighbors imputation
#[derive(Debug, Clone, PartialEq)]
pub enum WeightingScheme {
    /// All neighbors contribute equally
    Uniform,
    /// Weight by inverse distance (closer neighbors have more influence)
    Distance,
}

/// K-Nearest Neighbors Imputer for filling missing values
///
/// This transformer fills missing values using k-nearest neighbors.
/// For each sample, the missing features are imputed from the nearest
/// neighbors that have a value for that feature.
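///
/// # Examples
///
/// A minimal sketch (fenced `ignore`; import paths depend on the crate's
/// re-exports):
///
/// ```ignore
/// use ndarray::array;
///
/// let data = array![[1.0, 1.0], [f64::NAN, 2.0], [3.0, 3.0]];
/// let mut imputer = KNNImputer::with_n_neighbors(2);
/// let filled = imputer.fit_transform(&data).unwrap();
/// // The missing entry is the uniform average of its two nearest
/// // neighbors' values for feature 0: (1.0 + 3.0) / 2 = 2.0
/// assert_eq!(filled[[1, 0]], 2.0);
/// ```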
pub struct KNNImputer {
    /// Number of nearest neighbors to use
    n_neighbors: usize,
    /// Distance metric to use for finding neighbors
    metric: DistanceMetric,
    /// Weighting scheme for aggregating neighbor values
    weights: WeightingScheme,
    /// Missing value indicator (what value is considered missing)
    missing_values: f64,
    /// Training data (stored to find neighbors during transform)
    x_train_: Option<Array2<f64>>,
}

impl KNNImputer {
    /// Creates a new KNNImputer
    ///
    /// # Arguments
    /// * `n_neighbors` - Number of neighboring samples to use for imputation
    /// * `metric` - Distance metric for finding neighbors
    /// * `weights` - Weight function used in imputation
    /// * `missing_values` - The value that represents missing data (default: NaN)
    ///
    /// # Returns
    /// * A new KNNImputer instance
    pub fn new(
        n_neighbors: usize,
        metric: DistanceMetric,
        weights: WeightingScheme,
        missing_values: f64,
    ) -> Self {
        KNNImputer {
            n_neighbors,
            metric,
            weights,
            missing_values,
            x_train_: None,
        }
    }

    /// Creates a KNNImputer with default parameters
    ///
    /// Uses 5 neighbors, Euclidean distance, uniform weighting, and NaN as the missing value
    pub fn with_defaults() -> Self {
        Self::new(
            5,
            DistanceMetric::Euclidean,
            WeightingScheme::Uniform,
            f64::NAN,
        )
    }

    /// Creates a KNNImputer with the specified number of neighbors and defaults for other parameters
    pub fn with_n_neighbors(n_neighbors: usize) -> Self {
        Self::new(
            n_neighbors,
            DistanceMetric::Euclidean,
            WeightingScheme::Uniform,
            f64::NAN,
        )
    }

    /// Creates a KNNImputer with distance weighting
    pub fn with_distance_weighting(n_neighbors: usize) -> Self {
        Self::new(
            n_neighbors,
            DistanceMetric::Euclidean,
            WeightingScheme::Distance,
            f64::NAN,
        )
    }

    /// Fits the KNNImputer to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));

        // Validate that we have enough samples for k-nearest neighbors
        let n_samples = x_f64.shape()[0];
        if n_samples < self.n_neighbors {
            return Err(TransformError::InvalidInput(format!(
                "Number of samples ({}) must be >= n_neighbors ({})",
                n_samples, self.n_neighbors
            )));
        }

        // Store training data for neighbor search during transform
        self.x_train_ = Some(x_f64);
        Ok(())
    }

    /// Transforms the input data by imputing missing values
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));

        if self.x_train_.is_none() {
            return Err(TransformError::TransformationError(
                "KNNImputer must be fitted before transform".to_string(),
            ));
        }

        let x_train = self.x_train_.as_ref().unwrap();
        let (n_samples, n_features) = x_f64.dim();

        if n_features != x_train.shape()[1] {
            return Err(TransformError::InvalidInput(format!(
                "Number of features in transform data ({}) doesn't match training data ({})",
                n_features,
                x_train.shape()[1]
            )));
        }

        let mut result = x_f64.clone();

        // Process each sample
        for i in 0..n_samples {
            let sample = x_f64.row(i);

            // Find features that need imputation
            let missing_features: Vec<usize> = (0..n_features)
                .filter(|&j| self.is_missing(sample[j]))
                .collect();

            if missing_features.is_empty() {
                continue; // No missing values in this sample
            }

            // Find the k-nearest neighbors for this sample (excluding itself)
            let neighbors =
                self.find_nearest_neighbors_excluding(&sample.to_owned(), x_train, i)?;

            // Impute each missing feature
            for &feature_idx in &missing_features {
                let imputed_value = self.impute_feature(feature_idx, &neighbors, x_train)?;
                result[[i, feature_idx]] = imputed_value;
            }
        }

        Ok(result)
    }

    /// Fits the imputer and transforms the data in one step
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Find the k-nearest neighbors (with their distances) for a given sample,
    /// excluding a specific index
    fn find_nearest_neighbors_excluding(
        &self,
        sample: &Array1<f64>,
        x_train: &Array2<f64>,
        exclude_idx: usize,
    ) -> Result<Vec<(usize, f64)>> {
        let n_train_samples = x_train.shape()[0];

        // Compute distances to all training samples (excluding the specified index)
        let distances: Vec<(usize, f64)> = (0..n_train_samples)
            .into_par_iter()
            .filter(|&i| i != exclude_idx)
            .map(|i| {
                let train_sample = x_train.row(i);
                let distance = self.compute_distance(sample, &train_sample.to_owned());
                (i, distance)
            })
            .collect();

        // Sort by distance and keep the k nearest, retaining the distances so
        // that distance weighting can use them during imputation
        let mut sorted_distances = distances;
        sorted_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        Ok(sorted_distances
            .into_iter()
            .take(self.n_neighbors)
            .collect())
    }

    /// Compute distance between two samples, handling missing values
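    ///
    /// Features where either sample is missing are skipped, and the accumulated
    /// distance is normalized by the number of valid features so that samples
    /// with different missingness patterns stay comparable. For example, with
    /// Euclidean distance, comparing [1, NaN, 3] and [1, 2, 5] uses only
    /// features 0 and 2: sqrt((0^2 + 2^2) / 2) = sqrt(2).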
    fn compute_distance(&self, sample1: &Array1<f64>, sample2: &Array1<f64>) -> f64 {
        let n_features = sample1.len();
        let mut distance = 0.0;
        let mut valid_features = 0;

        for i in 0..n_features {
            let val1 = sample1[i];
            let val2 = sample2[i];

            // Skip features where either sample has missing values
            if self.is_missing(val1) || self.is_missing(val2) {
                continue;
            }

            valid_features += 1;
            let diff = val1 - val2;

            match self.metric {
                DistanceMetric::Euclidean => {
                    distance += diff * diff;
                }
                DistanceMetric::Manhattan => {
                    distance += diff.abs();
                }
            }
        }

        // Handle the case where there are no valid features to compare
        if valid_features == 0 {
            return f64::INFINITY;
        }

        // Normalize by the number of valid features to make distances comparable
        distance /= valid_features as f64;

        match self.metric {
            DistanceMetric::Euclidean => distance.sqrt(),
            DistanceMetric::Manhattan => distance,
        }
    }

    /// Impute a single feature using the k-nearest neighbors
    fn impute_feature(
        &self,
        feature_idx: usize,
        neighbors: &[(usize, f64)],
        x_train: &Array2<f64>,
    ) -> Result<f64> {
        let mut values = Vec::new();
        let mut weights = Vec::new();

        // Collect non-missing values from neighbors for this feature
        for &(neighbor_idx, distance) in neighbors {
            let neighbor_value = x_train[[neighbor_idx, feature_idx]];

            if !self.is_missing(neighbor_value) {
                values.push(neighbor_value);

                // Compute the weight based on the weighting scheme
                let weight = match self.weights {
                    WeightingScheme::Uniform => 1.0,
                    // Inverse distance: closer neighbors get larger weights.
                    // The small offset guards against division by zero when a
                    // neighbor coincides exactly with the sample.
                    WeightingScheme::Distance => 1.0 / (distance + f64::EPSILON),
                };
                weights.push(weight);
            }
        }

        if values.is_empty() {
            return Err(TransformError::TransformationError(format!(
                "No valid neighbors found for feature {feature_idx} imputation"
            )));
        }

        // Compute the weighted average
        let total_weight: f64 = weights.iter().sum();
        if total_weight == 0.0 {
            return Err(TransformError::TransformationError(
                "Total weight is zero for imputation".to_string(),
            ));
        }

        let weighted_sum: f64 = values
            .iter()
            .zip(weights.iter())
            .map(|(&val, &weight)| val * weight)
            .sum();

        Ok(weighted_sum / total_weight)
    }

    /// Checks if a value is considered missing
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    /// Returns the number of neighbors used for imputation
    pub fn n_neighbors(&self) -> usize {
        self.n_neighbors
    }

    /// Returns the distance metric used
    pub fn metric(&self) -> &DistanceMetric {
        &self.metric
    }

    /// Returns the weighting scheme used
    pub fn weights(&self) -> &WeightingScheme {
        &self.weights
    }
}

/// Simple regression model for MICE imputation
///
/// This is a basic linear regression implementation for use in the MICE algorithm.
/// It uses a simple least squares approach with optional regularization.
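///
/// Fitting solves the ridge-regularized normal equations
/// (X^T X + alpha*I) * beta = X^T y by Gaussian elimination with partial
/// pivoting, which is adequate for the small systems that arise inside MICE.
/// Note that when an intercept column is included, the regularization is
/// applied to the intercept coefficient as well.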
#[derive(Debug, Clone)]
struct SimpleRegressor {
    /// Regression coefficients (including the intercept as the first element)
    coefficients: Option<Array1<f64>>,
    /// Whether to include an intercept term
    include_intercept: bool,
    /// Regularization parameter (ridge regression)
    alpha: f64,
}

impl SimpleRegressor {
    /// Create a new simple regressor
    fn new(include_intercept: bool, alpha: f64) -> Self {
        Self {
            coefficients: None,
            include_intercept,
            alpha,
        }
    }

    /// Fit the regressor to the data
    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(TransformError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Add an intercept column if needed
        let x_design = if self.include_intercept {
            let mut x_with_intercept = Array2::ones((n_samples, n_features + 1));
            x_with_intercept.slice_mut(ndarray::s![.., 1..]).assign(x);
            x_with_intercept
        } else {
            x.to_owned()
        };

        // Solve the normal equations: (X^T X + alpha*I) * beta = X^T y
        let xtx = x_design.t().dot(&x_design);
        let xty = x_design.t().dot(y);

        // Add regularization
        let mut regularized_xtx = xtx;
        let n_coeffs = regularized_xtx.shape()[0];
        for i in 0..n_coeffs {
            regularized_xtx[[i, i]] += self.alpha;
        }

        // Solve using simple Gaussian elimination (adequate for small problems)
        self.coefficients = Some(self.solve_linear_system(&regularized_xtx, &xty)?);

        Ok(())
    }

    /// Predict using the fitted regressor
    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
        let coeffs = self.coefficients.as_ref().ok_or_else(|| {
            TransformError::TransformationError(
                "Regressor must be fitted before prediction".to_string(),
            )
        })?;

        let x_design = if self.include_intercept {
            let (n_samples, n_features) = x.dim();
            let mut x_with_intercept = Array2::ones((n_samples, n_features + 1));
            x_with_intercept.slice_mut(ndarray::s![.., 1..]).assign(x);
            x_with_intercept
        } else {
            x.to_owned()
        };

        Ok(x_design.dot(coeffs))
    }

    /// Simple linear system solver using Gaussian elimination with partial pivoting
    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
        let n = a.shape()[0];
        let mut aug_matrix = Array2::zeros((n, n + 1));

        // Create the augmented matrix [A|b]
        aug_matrix.slice_mut(ndarray::s![.., ..n]).assign(a);
        aug_matrix.slice_mut(ndarray::s![.., n]).assign(b);

        // Forward elimination
        for i in 0..n {
            // Find the pivot (largest absolute value in the column)
            let mut max_row = i;
            for k in (i + 1)..n {
                if aug_matrix[[k, i]].abs() > aug_matrix[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            // Swap rows
            if max_row != i {
                for j in 0..=n {
                    let temp = aug_matrix[[i, j]];
                    aug_matrix[[i, j]] = aug_matrix[[max_row, j]];
                    aug_matrix[[max_row, j]] = temp;
                }
            }

            // Check for a singular matrix
            if aug_matrix[[i, i]].abs() < 1e-12 {
                return Err(TransformError::TransformationError(
                    "Singular matrix in regression".to_string(),
                ));
            }

            // Normalize the pivot row so the diagonal element is 1
            let pivot = aug_matrix[[i, i]];
            for j in i..=n {
                aug_matrix[[i, j]] /= pivot;
            }

            // Eliminate the column in all other rows
            for k in 0..n {
                if k != i {
                    let factor = aug_matrix[[k, i]];
                    for j in i..=n {
                        aug_matrix[[k, j]] -= factor * aug_matrix[[i, j]];
                    }
                }
            }
        }

        // Extract the solution
        let mut solution = Array1::zeros(n);
        for i in 0..n {
            solution[i] = aug_matrix[[i, n]];
        }

        Ok(solution)
    }
}

/// Iterative Imputer using the MICE (Multiple Imputation by Chained Equations) algorithm
///
/// This transformer iteratively models each feature with missing values as a function
/// of the other features. The algorithm performs multiple rounds of imputation where each
/// feature is predicted using the other features in a round-robin fashion.
///
/// MICE is particularly useful when:
/// - There are multiple features with missing values
/// - The missing patterns are complex
/// - You want to model relationships between features
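///
/// # Examples
///
/// A minimal sketch (fenced `ignore`; import paths depend on the crate's
/// re-exports):
///
/// ```ignore
/// use ndarray::array;
///
/// // Feature 1 is exactly 2 * feature 0 on the complete rows, a
/// // relationship the chained regressions can exploit.
/// let data = array![
///     [1.0, 2.0],
///     [2.0, 4.0],
///     [3.0, f64::NAN],
///     [f64::NAN, 8.0],
/// ];
/// let mut imputer = IterativeImputer::with_max_iter(5);
/// let filled = imputer.fit_transform(&data).unwrap();
/// assert!(!filled[[2, 1]].is_nan());
/// assert!(!filled[[3, 0]].is_nan());
/// ```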
pub struct IterativeImputer {
    /// Maximum number of iterations to perform
    max_iter: usize,
    /// Convergence tolerance (change in imputed values between iterations)
    tolerance: f64,
    /// Initial strategy for the first round of imputation
    initial_strategy: ImputeStrategy,
    /// Random seed for reproducibility (stored but not used by the current
    /// deterministic regressor)
    random_seed: Option<u64>,
    /// Missing value indicator
    missing_values: f64,
    /// Regularization parameter for regression
    alpha: f64,
    /// Minimum improvement to continue iterating
    min_improvement: f64,

    // Internal state
    /// Training data for fitting predictors
    x_train_: Option<Array2<f64>>,
    /// Indices of features that had missing values during fitting
    missing_features_: Option<Vec<usize>>,
    /// Initial imputation values for features
    initial_values_: Option<Array1<f64>>,
    /// Whether the imputer has been fitted
    is_fitted_: bool,
}

impl IterativeImputer {
    /// Creates a new IterativeImputer
    ///
    /// # Arguments
    /// * `max_iter` - Maximum number of iterations
    /// * `tolerance` - Convergence tolerance
    /// * `initial_strategy` - Strategy for initial imputation
    /// * `missing_values` - Value representing missing data
    /// * `alpha` - Regularization parameter for regression
    ///
    /// # Returns
    /// * A new IterativeImputer instance
    pub fn new(
        max_iter: usize,
        tolerance: f64,
        initial_strategy: ImputeStrategy,
        missing_values: f64,
        alpha: f64,
    ) -> Self {
        IterativeImputer {
            max_iter,
            tolerance,
            initial_strategy,
            random_seed: None,
            missing_values,
            alpha,
            min_improvement: 1e-6,
            x_train_: None,
            missing_features_: None,
            initial_values_: None,
            is_fitted_: false,
        }
    }

    /// Creates an IterativeImputer with default parameters
    ///
    /// Uses 10 iterations, 1e-3 tolerance, mean initial strategy, NaN missing values,
    /// and 1e-6 regularization.
    pub fn with_defaults() -> Self {
        Self::new(10, 1e-3, ImputeStrategy::Mean, f64::NAN, 1e-6)
    }

    /// Creates an IterativeImputer with the specified maximum iterations and defaults
    /// for other parameters
    pub fn with_max_iter(max_iter: usize) -> Self {
        Self::new(max_iter, 1e-3, ImputeStrategy::Mean, f64::NAN, 1e-6)
    }

    /// Sets the random seed for reproducible results
    pub fn with_random_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }

    /// Sets the regularization parameter
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the minimum improvement threshold
    pub fn with_min_improvement(mut self, min_improvement: f64) -> Self {
        self.min_improvement = min_improvement;
        self
    }

    /// Fits the IterativeImputer to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));
        let (n_samples, n_features) = x_f64.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        // Find features that have missing values
        let missing_features: Vec<usize> = (0..n_features)
            .filter(|&j| x_f64.column(j).iter().any(|&val| self.is_missing(val)))
            .collect();

        if missing_features.is_empty() {
            // No missing values, store the data as-is
            self.x_train_ = Some(x_f64);
            self.missing_features_ = Some(Vec::new());
            self.initial_values_ = Some(Array1::zeros(0));
            self.is_fitted_ = true;
            return Ok(());
        }

        // Compute initial imputation values for each feature
        let mut initial_values = Array1::zeros(n_features);
        for &feature_idx in &missing_features {
            let feature_data: Vec<f64> = x_f64
                .column(feature_idx)
                .iter()
                .filter(|&&val| !self.is_missing(val))
                .copied()
                .collect();

            if feature_data.is_empty() {
                return Err(TransformError::InvalidInput(format!(
                    "All values are missing in feature {feature_idx}"
                )));
            }

            initial_values[feature_idx] = match &self.initial_strategy {
                ImputeStrategy::Mean => {
                    feature_data.iter().sum::<f64>() / feature_data.len() as f64
                }
                ImputeStrategy::Median => {
                    let mut sorted_data = feature_data;
                    // Avoid panicking on incomparable values (e.g. NaN when the
                    // missing indicator is not NaN)
                    sorted_data
                        .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    let len = sorted_data.len();
                    if len % 2 == 0 {
                        (sorted_data[len / 2 - 1] + sorted_data[len / 2]) / 2.0
                    } else {
                        sorted_data[len / 2]
                    }
                }
                ImputeStrategy::MostFrequent => {
                    // For continuous data, use the mean as an approximation
                    feature_data.iter().sum::<f64>() / feature_data.len() as f64
                }
                ImputeStrategy::Constant(value) => *value,
            };
        }

        self.x_train_ = Some(x_f64);
        self.missing_features_ = Some(missing_features);
        self.initial_values_ = Some(initial_values);
        self.is_fitted_ = true;

        Ok(())
    }

    /// Transforms the input data by imputing missing values using MICE
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
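    ///
    /// Each round proceeds as follows: missing entries are first filled with the
    /// initial-strategy values; then, for every feature that had missing values,
    /// a ridge regressor is fitted on the rows where that feature is observed
    /// (using all other features as predictors) and its predictions overwrite
    /// the imputed entries. Rounds repeat until the largest change in any
    /// imputed value falls below `tolerance`, the total change falls below
    /// `min_improvement`, or `max_iter` is reached.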
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        if !self.is_fitted_ {
            return Err(TransformError::TransformationError(
                "IterativeImputer must be fitted before transform".to_string(),
            ));
        }

        let x_f64 = x.mapv(|val| num_traits::cast::<S::Elem, f64>(val).unwrap_or(0.0));
        let missing_features = self.missing_features_.as_ref().unwrap();

        if missing_features.is_empty() {
            // No missing values in the training data, return as-is
            return Ok(x_f64);
        }

        let initial_values = self.initial_values_.as_ref().unwrap();
        let (n_samples, n_features) = x_f64.dim();

        // Start with the initial imputation
        let mut imputed_data = x_f64.clone();
        self.apply_initial_imputation(&mut imputed_data, initial_values)?;

        // MICE iterations
        for iteration in 0..self.max_iter {
            let mut max_change = 0.0;
            let old_imputed_data = imputed_data.clone();

            // Iterate through each feature with missing values
            for &feature_idx in missing_features {
                // Find samples with missing values for this feature
                let missing_mask: Vec<bool> = (0..n_samples)
                    .map(|i| self.is_missing(x_f64[[i, feature_idx]]))
                    .collect();

                if !missing_mask.iter().any(|&m| m) {
                    continue; // No missing values for this feature
                }

                // Prepare the predictors (all other features)
                let predictor_indices: Vec<usize> =
                    (0..n_features).filter(|&i| i != feature_idx).collect();

                // Create training data from samples without missing values for this feature
                let (train_x, train_y) = self.prepare_training_data(
                    &imputed_data,
                    feature_idx,
                    &predictor_indices,
                    &missing_mask,
                )?;

                if train_x.is_empty() {
                    continue; // Cannot train a predictor
                }

                // Fit the predictor
                let mut regressor = SimpleRegressor::new(true, self.alpha);
                regressor.fit(&train_x, &train_y)?;

                // Predict the missing values
                let test_x =
                    self.prepare_test_data(&imputed_data, &predictor_indices, &missing_mask)?;

                if !test_x.is_empty() {
                    let predictions = regressor.predict(&test_x)?;

                    // Update the imputed values
                    let mut pred_idx = 0;
                    for i in 0..n_samples {
                        if missing_mask[i] {
                            let old_value = imputed_data[[i, feature_idx]];
                            let new_value = predictions[pred_idx];
                            imputed_data[[i, feature_idx]] = new_value;

                            let change = (new_value - old_value).abs();
                            max_change = max_change.max(change);
                            pred_idx += 1;
                        }
                    }
                }
            }

            // Check convergence
            if max_change < self.tolerance {
                break;
            }

            // Check for the minimum improvement
            if iteration > 0 {
                let total_change = self.compute_total_change(&old_imputed_data, &imputed_data);
                if total_change < self.min_improvement {
                    break;
                }
            }
        }

        Ok(imputed_data)
    }

    /// Fits the imputer and transforms the data in one step
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Apply the initial imputation using the specified strategy
    fn apply_initial_imputation(
        &self,
        data: &mut Array2<f64>,
        initial_values: &Array1<f64>,
    ) -> Result<()> {
        let (n_samples, n_features) = data.dim();

        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(data[[i, j]]) {
                    data[[i, j]] = initial_values[j];
                }
            }
        }

        Ok(())
    }

    /// Prepare training data for a specific feature
    fn prepare_training_data(
        &self,
        data: &Array2<f64>,
        target_feature: usize,
        predictor_indices: &[usize],
        missing_mask: &[bool],
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        let n_samples = data.shape()[0];
        let n_predictors = predictor_indices.len();

        // Count non-missing samples
        let non_missing_count = missing_mask.iter().filter(|&&m| !m).count();

        if non_missing_count == 0 {
            return Ok((Array2::zeros((0, n_predictors)), Array1::zeros(0)));
        }

        let mut train_x = Array2::zeros((non_missing_count, n_predictors));
        let mut train_y = Array1::zeros(non_missing_count);

        let mut train_idx = 0;
        for i in 0..n_samples {
            if !missing_mask[i] {
                // Copy predictor features
                for (pred_j, &orig_j) in predictor_indices.iter().enumerate() {
                    train_x[[train_idx, pred_j]] = data[[i, orig_j]];
                }
                // Copy the target feature
                train_y[train_idx] = data[[i, target_feature]];
                train_idx += 1;
            }
        }

        Ok((train_x, train_y))
    }

    /// Prepare test data for prediction
    fn prepare_test_data(
        &self,
        data: &Array2<f64>,
        predictor_indices: &[usize],
        missing_mask: &[bool],
    ) -> Result<Array2<f64>> {
        let n_samples = data.shape()[0];
        let n_predictors = predictor_indices.len();

        // Count missing samples
        let missing_count = missing_mask.iter().filter(|&&m| m).count();

        if missing_count == 0 {
            return Ok(Array2::zeros((0, n_predictors)));
        }

        let mut test_x = Array2::zeros((missing_count, n_predictors));

        let mut test_idx = 0;
        for i in 0..n_samples {
            if missing_mask[i] {
                // Copy predictor features
                for (pred_j, &orig_j) in predictor_indices.iter().enumerate() {
                    test_x[[test_idx, pred_j]] = data[[i, orig_j]];
                }
                test_idx += 1;
            }
        }

        Ok(test_x)
    }

    /// Compute the total change between two imputation iterations
    /// (the Frobenius norm of the difference)
    fn compute_total_change(&self, old_data: &Array2<f64>, new_data: &Array2<f64>) -> f64 {
        let diff = new_data - old_data;
        diff.iter().map(|&x| x * x).sum::<f64>().sqrt()
    }

    /// Check if a value is considered missing
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array;

    #[test]
    fn test_simple_imputer_mean() {
        // Create test data with NaN values
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0,
                f64::NAN, 5.0, 6.0,
                7.0, f64::NAN, 9.0,
                10.0, 11.0, f64::NAN,
            ],
        )
        .unwrap();

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Mean);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Check that the shape is preserved
        assert_eq!(transformed.shape(), &[4, 3]);

        // Check that mean values were used for imputation
        // Column 0: mean of [1.0, 7.0, 10.0] = 6.0
        // Column 1: mean of [2.0, 5.0, 11.0] = 6.0
        // Column 2: mean of [3.0, 6.0, 9.0] = 6.0

        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 6.0, epsilon = 1e-10); // Imputed
        assert_abs_diff_eq!(transformed[[2, 0]], 7.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[3, 0]], 10.0, epsilon = 1e-10);

        assert_abs_diff_eq!(transformed[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 1]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 6.0, epsilon = 1e-10); // Imputed
        assert_abs_diff_eq!(transformed[[3, 1]], 11.0, epsilon = 1e-10);

        assert_abs_diff_eq!(transformed[[0, 2]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 2]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 2]], 9.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[3, 2]], 6.0, epsilon = 1e-10); // Imputed
    }

    #[test]
    fn test_simple_imputer_median() {
        // Create test data with NaN values
        let data = Array::from_shape_vec(
            (5, 2),
            vec![
                1.0, 10.0,
                f64::NAN, 20.0,
                3.0, f64::NAN,
                4.0, 40.0,
                5.0, 50.0,
            ],
        )
        .unwrap();

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Median);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Check that the shape is preserved
        assert_eq!(transformed.shape(), &[5, 2]);

        // Column 0: median of [1.0, 3.0, 4.0, 5.0] = 3.5
        // Column 1: median of [10.0, 20.0, 40.0, 50.0] = 30.0

        assert_abs_diff_eq!(transformed[[1, 0]], 3.5, epsilon = 1e-10); // Imputed
        assert_abs_diff_eq!(transformed[[2, 1]], 30.0, epsilon = 1e-10); // Imputed
    }

    #[test]
    fn test_simple_imputer_constant() {
        // Create test data with NaN values
        let data =
            Array::from_shape_vec((3, 2), vec![1.0, f64::NAN, f64::NAN, 3.0, 4.0, 5.0]).unwrap();

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Constant(99.0));
        let transformed = imputer.fit_transform(&data).unwrap();

        // Check that the constant value was used for imputation
        assert_abs_diff_eq!(transformed[[0, 1]], 99.0, epsilon = 1e-10); // Imputed
        assert_abs_diff_eq!(transformed[[1, 0]], 99.0, epsilon = 1e-10); // Imputed

        // Non-missing values should remain unchanged
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_missing_indicator() {
        // Create test data with NaN values
        let data = Array::from_shape_vec(
            (3, 4),
            vec![
                1.0, f64::NAN, 3.0, 4.0,
                f64::NAN, 6.0, f64::NAN, 8.0,
                9.0, 10.0, 11.0, f64::NAN,
            ],
        )
        .unwrap();

        let mut indicator = MissingIndicator::with_nan();
        let indicators = indicator.fit_transform(&data).unwrap();

        // All features have missing values, so the output shape should be (3, 4)
        assert_eq!(indicators.shape(), &[3, 4]);

        // Check the indicators
        assert_abs_diff_eq!(indicators[[0, 0]], 0.0, epsilon = 1e-10); // Not missing
        assert_abs_diff_eq!(indicators[[0, 1]], 1.0, epsilon = 1e-10); // Missing
        assert_abs_diff_eq!(indicators[[0, 2]], 0.0, epsilon = 1e-10); // Not missing
        assert_abs_diff_eq!(indicators[[0, 3]], 0.0, epsilon = 1e-10); // Not missing

        assert_abs_diff_eq!(indicators[[1, 0]], 1.0, epsilon = 1e-10); // Missing
        assert_abs_diff_eq!(indicators[[1, 1]], 0.0, epsilon = 1e-10); // Not missing
        assert_abs_diff_eq!(indicators[[1, 2]], 1.0, epsilon = 1e-10); // Missing
        assert_abs_diff_eq!(indicators[[1, 3]], 0.0, epsilon = 1e-10); // Not missing

        assert_abs_diff_eq!(indicators[[2, 0]], 0.0, epsilon = 1e-10); // Not missing
        assert_abs_diff_eq!(indicators[[2, 1]], 0.0, epsilon = 1e-10); // Not missing
        assert_abs_diff_eq!(indicators[[2, 2]], 0.0, epsilon = 1e-10); // Not missing
        assert_abs_diff_eq!(indicators[[2, 3]], 1.0, epsilon = 1e-10); // Missing
    }

    #[test]
    fn test_imputer_errors() {
        // Test the error when all values are missing in a feature
        let data = Array::from_shape_vec((2, 2), vec![f64::NAN, 1.0, f64::NAN, 2.0]).unwrap();

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Mean);
        assert!(imputer.fit(&data).is_err());
    }

    #[test]
    fn test_knn_imputer_basic() {
        // Create test data with missing values
        // Dataset:
        // [1.0, 2.0, 3.0]
        // [4.0, NaN, 6.0]
        // [7.0, 8.0, NaN]
        // [10.0, 11.0, 12.0]
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0,
                4.0, f64::NAN, 6.0,
                7.0, 8.0, f64::NAN,
                10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Check that the shape is preserved
        assert_eq!(transformed.shape(), &[4, 3]);

        // Check that non-missing values are unchanged
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 2]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[3, 0]], 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[3, 1]], 11.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[3, 2]], 12.0, epsilon = 1e-10);

        // Missing values should have been imputed (values depend on the neighbors chosen)
        assert!(!transformed[[1, 1]].is_nan()); // Should be imputed
        assert!(!transformed[[2, 2]].is_nan()); // Should be imputed
    }

    #[test]
    fn test_knn_imputer_simple_case() {
        // Simple test case where the neighbors are easy to determine
        let data = Array::from_shape_vec((3, 2), vec![1.0, 1.0, f64::NAN, 2.0, 3.0, 3.0]).unwrap();

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        // The missing value in [?, 2.0] should be imputed from the nearest neighbors,
        // which are [1.0, 1.0] and [3.0, 3.0], so the imputed value for feature 0
        // should be close to 2.0 (the average of 1.0 and 3.0)
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-1);
    }

    #[test]
    fn test_knn_imputer_manhattan_distance() {
        let data =
            Array::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, f64::NAN, 2.0, 2.0, 10.0, 10.0])
                .unwrap();

        let mut imputer = KNNImputer::new(
            2,
            DistanceMetric::Manhattan,
            WeightingScheme::Uniform,
            f64::NAN,
        );
        let transformed = imputer.fit_transform(&data).unwrap();

        // With Manhattan distance, the closest neighbors to [1.0, ?] should be
        // [0.0, 0.0] and [2.0, 2.0], not [10.0, 10.0]
        assert!(!transformed[[1, 1]].is_nan());
        // The imputed value should be reasonable (around 1.0)
        assert!(transformed[[1, 1]] < 5.0); // Should not be close to 10.0
    }

    #[test]
    fn test_knn_imputer_validation_errors() {
        // Test with insufficient samples
        let small_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let mut imputer = KNNImputer::with_n_neighbors(5); // More neighbors than samples
        assert!(imputer.fit(&small_data).is_err());

        // Test transform without fit
        let data =
            Array::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let unfitted_imputer = KNNImputer::with_n_neighbors(2);
        assert!(unfitted_imputer.transform(&data).is_err());
    }

    #[test]
    fn test_knn_imputer_no_missing_values() {
        // Test data with no missing values
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Should be unchanged
        assert_eq!(transformed, data);
    }

    #[test]
    fn test_knn_imputer_accessors() {
        let imputer = KNNImputer::new(
            3,
            DistanceMetric::Manhattan,
            WeightingScheme::Distance,
            -999.0,
        );

        assert_eq!(imputer.n_neighbors(), 3);
        assert_eq!(imputer.metric(), &DistanceMetric::Manhattan);
        assert_eq!(imputer.weights(), &WeightingScheme::Distance);
    }

1558    #[test]
1559    fn test_knn_imputer_multiple_missing_features() {
1560        // Test sample with multiple missing features
1561        let data = Array::from_shape_vec(
1562            (4, 3),
1563            vec![
1564                1.0,
1565                2.0,
1566                3.0,
1567                f64::NAN,
1568                f64::NAN,
1569                6.0,
1570                7.0,
1571                8.0,
1572                9.0,
1573                10.0,
1574                11.0,
1575                12.0,
1576            ],
1577        )
1578        .unwrap();
1579
1580        let mut imputer = KNNImputer::with_n_neighbors(2);
1581        let transformed = imputer.fit_transform(&data).unwrap();
1582
1583        // Both missing values should be imputed
1584        assert!(!transformed[[1, 0]].is_nan());
1585        assert!(!transformed[[1, 1]].is_nan());
1586        // Non-missing value should remain unchanged
1587        assert_abs_diff_eq!(transformed[[1, 2]], 6.0, epsilon = 1e-10);
1588    }

    #[test]
    fn test_iterative_imputer_basic() {
        // Create test data with missing values that have relationships
        // Dataset with correlated features:
        // Feature 0: [1.0, 2.0, 3.0, NaN]
        // Feature 1: [2.0, 4.0, NaN, 8.0] (exactly 2 * feature 0 wherever both are observed)
        let data = Array::from_shape_vec(
            (4, 2),
            vec![1.0, 2.0, 2.0, 4.0, 3.0, f64::NAN, f64::NAN, 8.0],
        )
        .unwrap();

        let mut imputer = IterativeImputer::with_max_iter(5);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Check that missing values have been imputed
        assert!(!transformed[[2, 1]].is_nan()); // Feature 1 in row 2
        assert!(!transformed[[3, 0]].is_nan()); // Feature 0 in row 3

        // Non-missing values should remain unchanged
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[3, 1]], 8.0, epsilon = 1e-10);

        // Check that the imputed values are reasonable given the linear relationship:
        // feature 1 should be approximately 2 * feature 0
        let imputed_f1_row2 = transformed[[2, 1]];
        let expected_f1_row2 = 2.0 * transformed[[2, 0]]; // 2 * 3.0 = 6.0
        assert!((imputed_f1_row2 - expected_f1_row2).abs() < 1.0); // Allow some tolerance

        let imputed_f0_row3 = transformed[[3, 0]];
        let expected_f0_row3 = transformed[[3, 1]] / 2.0; // 8.0 / 2.0 = 4.0
        assert!((imputed_f0_row3 - expected_f0_row3).abs() < 1.0); // Allow some tolerance
    }
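
    // Added, hedged example: `with_random_seed` (exercised in the builder test
    // below) should make repeated runs reproducible. A minimal sketch assuming
    // `fit_transform` is deterministic for a fixed seed; it reuses the
    // correlated dataset from `test_iterative_imputer_basic`.
    #[test]
    fn test_iterative_imputer_seed_reproducibility() {
        let data = Array::from_shape_vec(
            (4, 2),
            vec![1.0, 2.0, 2.0, 4.0, 3.0, f64::NAN, f64::NAN, 8.0],
        )
        .unwrap();

        let mut imputer_a = IterativeImputer::with_defaults().with_random_seed(42);
        let mut imputer_b = IterativeImputer::with_defaults().with_random_seed(42);

        let first = imputer_a.fit_transform(&data).unwrap();
        let second = imputer_b.fit_transform(&data).unwrap();

        // Identical seeds should yield identical imputations
        assert_eq!(first, second);
    }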

    #[test]
    fn test_iterative_imputer_no_missing_values() {
        // Test with data that has no missing values
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let mut imputer = IterativeImputer::with_defaults();
        let transformed = imputer.fit_transform(&data).unwrap();

        // Data should remain unchanged
        for i in 0..3 {
            for j in 0..2 {
                assert_abs_diff_eq!(transformed[[i, j]], data[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_iterative_imputer_convergence() {
        // Data with an exact linear structure (column 1 = 2 * column 0,
        // column 2 = 3 * column 0), so the iterations should converge quickly
        let data = Array::from_shape_vec(
            (5, 3),
            vec![
                1.0,
                2.0,
                3.0,
                2.0,
                f64::NAN,
                6.0,
                3.0,
                6.0,
                f64::NAN,
                4.0,
                8.0,
                12.0,
                f64::NAN,
                10.0,
                15.0,
            ],
        )
        .unwrap();

        let mut imputer = IterativeImputer::new(
            20,   // max_iter
            1e-4, // tolerance
            ImputeStrategy::Mean,
            f64::NAN,
            1e-6, // alpha
        );

        let transformed = imputer.fit_transform(&data).unwrap();

        // All missing values should be imputed
        for i in 0..5 {
            for j in 0..3 {
                assert!(!transformed[[i, j]].is_nan());
            }
        }
    }

    #[test]
    fn test_iterative_imputer_different_strategies() {
        let data = Array::from_shape_vec(
            (4, 2),
            vec![1.0, f64::NAN, 2.0, 4.0, 3.0, 6.0, f64::NAN, 8.0],
        )
        .unwrap();

        // Test with median initial strategy
        let mut imputer_median =
            IterativeImputer::new(5, 1e-3, ImputeStrategy::Median, f64::NAN, 1e-6);
        let transformed_median = imputer_median.fit_transform(&data).unwrap();
        assert!(!transformed_median[[0, 1]].is_nan());
        assert!(!transformed_median[[3, 0]].is_nan());

        // Test with constant initial strategy
        let mut imputer_constant =
            IterativeImputer::new(5, 1e-3, ImputeStrategy::Constant(999.0), f64::NAN, 1e-6);
        let transformed_constant = imputer_constant.fit_transform(&data).unwrap();
        assert!(!transformed_constant[[0, 1]].is_nan());
        assert!(!transformed_constant[[3, 0]].is_nan());
    }

    #[test]
    fn test_iterative_imputer_builder_methods() {
        let imputer = IterativeImputer::with_defaults()
            .with_random_seed(42)
            .with_alpha(1e-3)
            .with_min_improvement(1e-5);

        assert_eq!(imputer.random_seed, Some(42));
        assert_abs_diff_eq!(imputer.alpha, 1e-3, epsilon = 1e-10);
        assert_abs_diff_eq!(imputer.min_improvement, 1e-5, epsilon = 1e-10);
    }

    #[test]
    fn test_iterative_imputer_errors() {
        // Test error when not fitted
        let imputer = IterativeImputer::with_defaults();
        let test_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(imputer.transform(&test_data).is_err());

        // Test error when all values are missing in a feature
        let bad_data =
            Array::from_shape_vec((3, 2), vec![f64::NAN, 1.0, f64::NAN, 2.0, f64::NAN, 3.0])
                .unwrap();
        let mut imputer = IterativeImputer::with_defaults();
        assert!(imputer.fit(&bad_data).is_err());
    }

    #[test]
    fn test_simple_regressor() {
        // Test the internal SimpleRegressor
        let x = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
        let y = Array::from_vec(vec![5.0, 8.0, 11.0]); // y = 2*x1 + x2 + 1

        let mut regressor = SimpleRegressor::new(true, 1e-6);
        regressor.fit(&x, &y).unwrap();

        let test_x = Array::from_shape_vec((2, 2), vec![4.0, 5.0, 5.0, 6.0]).unwrap();
        let predictions = regressor.predict(&test_x).unwrap();

        // Check that predictions are reasonable
        assert_eq!(predictions.len(), 2);
        assert!(!predictions[0].is_nan());
        assert!(!predictions[1].is_nan());
    }
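
    // Added, hedged follow-up: the training points above lie on the line
    // x2 = x1 + 1, and so do the test points, so any near-exact linear fit must
    // reproduce y = 2*x1 + x2 + 1 there: 14.0 for [4.0, 5.0] and 17.0 for
    // [5.0, 6.0]. A minimal sketch assuming `SimpleRegressor::new(true, 1e-6)`
    // fits an intercept with a near-zero ridge penalty; the tolerance is
    // deliberately loose.
    #[test]
    fn test_simple_regressor_collinear_predictions() {
        let x = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0]).unwrap();
        let y = Array::from_vec(vec![5.0, 8.0, 11.0]);

        let mut regressor = SimpleRegressor::new(true, 1e-6);
        regressor.fit(&x, &y).unwrap();

        let test_x = Array::from_shape_vec((2, 2), vec![4.0, 5.0, 5.0, 6.0]).unwrap();
        let predictions = regressor.predict(&test_x).unwrap();

        assert!((predictions[0] - 14.0).abs() < 1.0);
        assert!((predictions[1] - 17.0).abs() < 1.0);
    }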
}