scirs2_transform/
impute.rs

1//! Missing value imputation utilities
2//!
3//! This module provides methods for handling missing values in datasets,
4//! which is a crucial preprocessing step for machine learning.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8use scirs2_core::parallel_ops::*;
9
10use crate::error::{Result, TransformError};
11
/// How a [`SimpleImputer`] chooses the value that replaces a missing entry.
///
/// The statistic is computed per feature (column) from that feature's
/// non-missing values, except for `Constant`, which uses the supplied value.
#[derive(Clone, Debug, PartialEq)]
pub enum ImputeStrategy {
    /// Use the arithmetic mean of the observed values
    Mean,
    /// Use the median of the observed values
    Median,
    /// Use the value that occurs most often among the observed values
    MostFrequent,
    /// Use the given fixed value for every feature
    Constant(f64),
}
24
25/// SimpleImputer for filling missing values
26///
27/// This transformer fills missing values using simple strategies like mean,
28/// median, most frequent value, or a constant value.
29pub struct SimpleImputer {
30    /// Strategy for imputation
31    strategy: ImputeStrategy,
32    /// Missing value indicator (what value is considered missing)
33    missingvalues: f64,
34    /// Values used for imputation (computed during fit)
35    statistics_: Option<Array1<f64>>,
36}
37
38impl SimpleImputer {
39    /// Creates a new SimpleImputer
40    ///
41    /// # Arguments
42    /// * `strategy` - The imputation strategy to use
43    /// * `missingvalues` - The value that represents missing data (default: NaN)
44    ///
45    /// # Returns
46    /// * A new SimpleImputer instance
47    pub fn new(strategy: ImputeStrategy, missingvalues: f64) -> Self {
48        SimpleImputer {
49            strategy,
50            missingvalues,
51            statistics_: None,
52        }
53    }
54
55    /// Creates a SimpleImputer with NaN as missing value indicator
56    ///
57    /// # Arguments
58    /// * `strategy` - The imputation strategy to use
59    ///
60    /// # Returns
61    /// * A new SimpleImputer instance
62    #[allow(dead_code)]
63    pub fn with_strategy(strategy: ImputeStrategy) -> Self {
64        Self::new(strategy, f64::NAN)
65    }
66
67    /// Fits the SimpleImputer to the input data
68    ///
69    /// # Arguments
70    /// * `x` - The input data, shape (n_samples, n_features)
71    ///
72    /// # Returns
73    /// * `Result<()>` - Ok if successful, Err otherwise
74    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
75    where
76        S: Data,
77        S::Elem: Float + NumCast,
78    {
79        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
80
81        let n_samples = x_f64.shape()[0];
82        let n_features = x_f64.shape()[1];
83
84        if n_samples == 0 || n_features == 0 {
85            return Err(TransformError::InvalidInput("Empty input data".to_string()));
86        }
87
88        let mut statistics = Array1::zeros(n_features);
89
90        for j in 0..n_features {
91            // Extract non-missing values for this feature
92            let feature_data: Vec<f64> = x_f64
93                .column(j)
94                .iter()
95                .filter(|&&val| !self.is_missing(val))
96                .copied()
97                .collect();
98
99            if feature_data.is_empty() {
100                return Err(TransformError::InvalidInput(format!(
101                    "All values are missing in feature {j}"
102                )));
103            }
104
105            statistics[j] = match &self.strategy {
106                ImputeStrategy::Mean => {
107                    feature_data.iter().sum::<f64>() / feature_data.len() as f64
108                }
109                ImputeStrategy::Median => {
110                    let mut sorted_data = feature_data.clone();
111                    sorted_data
112                        .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
113                    let n = sorted_data.len();
114                    if n.is_multiple_of(2) {
115                        (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
116                    } else {
117                        sorted_data[n / 2]
118                    }
119                }
120                ImputeStrategy::MostFrequent => {
121                    // For numerical data, we'll find the value that appears most frequently
122                    // This is a simplified implementation
123                    let mut counts = std::collections::HashMap::new();
124                    for &val in &feature_data {
125                        *counts.entry(val.to_bits()).or_insert(0) += 1;
126                    }
127
128                    let most_frequent_bits = counts
129                        .into_iter()
130                        .max_by_key(|(_, count)| *count)
131                        .map(|(bits_, _)| bits_)
132                        .unwrap_or(0);
133
134                    f64::from_bits(most_frequent_bits)
135                }
136                ImputeStrategy::Constant(value) => *value,
137            };
138        }
139
140        self.statistics_ = Some(statistics);
141        Ok(())
142    }
143
144    /// Transforms the input data using the fitted SimpleImputer
145    ///
146    /// # Arguments
147    /// * `x` - The input data, shape (n_samples, n_features)
148    ///
149    /// # Returns
150    /// * `Result<Array2<f64>>` - The transformed data with imputed values
151    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
152    where
153        S: Data,
154        S::Elem: Float + NumCast,
155    {
156        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
157
158        let n_samples = x_f64.shape()[0];
159        let n_features = x_f64.shape()[1];
160
161        if self.statistics_.is_none() {
162            return Err(TransformError::TransformationError(
163                "SimpleImputer has not been fitted".to_string(),
164            ));
165        }
166
167        let statistics = self.statistics_.as_ref().unwrap();
168
169        if n_features != statistics.len() {
170            return Err(TransformError::InvalidInput(format!(
171                "x has {} features, but SimpleImputer was fitted with {} features",
172                n_features,
173                statistics.len()
174            )));
175        }
176
177        let mut transformed = Array2::zeros((n_samples, n_features));
178
179        for i in 0..n_samples {
180            for j in 0..n_features {
181                let value = x_f64[[i, j]];
182                if self.is_missing(value) {
183                    transformed[[i, j]] = statistics[j];
184                } else {
185                    transformed[[i, j]] = value;
186                }
187            }
188        }
189
190        Ok(transformed)
191    }
192
193    /// Fits the SimpleImputer to the input data and transforms it
194    ///
195    /// # Arguments
196    /// * `x` - The input data, shape (n_samples, n_features)
197    ///
198    /// # Returns
199    /// * `Result<Array2<f64>>` - The transformed data with imputed values
200    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
201    where
202        S: Data,
203        S::Elem: Float + NumCast,
204    {
205        self.fit(x)?;
206        self.transform(x)
207    }
208
209    /// Returns the statistics computed during fitting
210    ///
211    /// # Returns
212    /// * `Option<&Array1<f64>>` - The statistics for each feature
213    #[allow(dead_code)]
214    pub fn statistics(&self) -> Option<&Array1<f64>> {
215        self.statistics_.as_ref()
216    }
217
218    /// Checks if a value is considered missing
219    ///
220    /// # Arguments
221    /// * `value` - The value to check
222    ///
223    /// # Returns
224    /// * `bool` - True if the value is missing, false otherwise
225    fn is_missing(&self, value: f64) -> bool {
226        if self.missingvalues.is_nan() {
227            value.is_nan()
228        } else {
229            (value - self.missingvalues).abs() < f64::EPSILON
230        }
231    }
232}
233
/// Indicator for missing values
///
/// This transformer creates a binary indicator matrix that shows where
/// missing values were located in the original data. Only features that had
/// at least one missing value during fitting get an indicator column.
#[derive(Debug, Clone)]
pub struct MissingIndicator {
    /// Missing value indicator (what value is considered missing)
    missingvalues: f64,
    /// Indices of features that had missing values (computed during fit; `None` until fitted)
    features_: Option<Vec<usize>>,
}
244
245impl MissingIndicator {
246    /// Creates a new MissingIndicator
247    ///
248    /// # Arguments
249    /// * `missingvalues` - The value that represents missing data (default: NaN)
250    ///
251    /// # Returns
252    /// * A new MissingIndicator instance
253    pub fn new(missingvalues: f64) -> Self {
254        MissingIndicator {
255            missingvalues,
256            features_: None,
257        }
258    }
259
260    /// Creates a MissingIndicator with NaN as missing value indicator
261    pub fn with_nan() -> Self {
262        Self::new(f64::NAN)
263    }
264
265    /// Fits the MissingIndicator to the input data
266    ///
267    /// # Arguments
268    /// * `x` - The input data, shape (n_samples, n_features)
269    ///
270    /// # Returns
271    /// * `Result<()>` - Ok if successful, Err otherwise
272    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
273    where
274        S: Data,
275        S::Elem: Float + NumCast,
276    {
277        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
278
279        let n_features = x_f64.shape()[1];
280        let mut features_with_missing = Vec::new();
281
282        for j in 0..n_features {
283            let has_missing = x_f64.column(j).iter().any(|&val| self.is_missing(val));
284            if has_missing {
285                features_with_missing.push(j);
286            }
287        }
288
289        self.features_ = Some(features_with_missing);
290        Ok(())
291    }
292
293    /// Transforms the input data to create missing value indicators
294    ///
295    /// # Arguments
296    /// * `x` - The input data, shape (n_samples, n_features)
297    ///
298    /// # Returns
299    /// * `Result<Array2<f64>>` - Binary indicator matrix, shape (n_samples, n_features_with_missing)
300    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
301    where
302        S: Data,
303        S::Elem: Float + NumCast,
304    {
305        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
306
307        let n_samples = x_f64.shape()[0];
308
309        if self.features_.is_none() {
310            return Err(TransformError::TransformationError(
311                "MissingIndicator has not been fitted".to_string(),
312            ));
313        }
314
315        let features_with_missing = self.features_.as_ref().unwrap();
316        let n_output_features = features_with_missing.len();
317
318        let mut indicators = Array2::zeros((n_samples, n_output_features));
319
320        for i in 0..n_samples {
321            for (out_j, &orig_j) in features_with_missing.iter().enumerate() {
322                if self.is_missing(x_f64[[i, orig_j]]) {
323                    indicators[[i, out_j]] = 1.0;
324                }
325            }
326        }
327
328        Ok(indicators)
329    }
330
331    /// Fits the MissingIndicator to the input data and transforms it
332    ///
333    /// # Arguments
334    /// * `x` - The input data, shape (n_samples, n_features)
335    ///
336    /// # Returns
337    /// * `Result<Array2<f64>>` - Binary indicator matrix
338    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
339    where
340        S: Data,
341        S::Elem: Float + NumCast,
342    {
343        self.fit(x)?;
344        self.transform(x)
345    }
346
347    /// Returns the features that have missing values
348    ///
349    /// # Returns
350    /// * `Option<&Vec<usize>>` - Indices of features with missing values
351    pub fn features(&self) -> Option<&Vec<usize>> {
352        self.features_.as_ref()
353    }
354
355    /// Checks if a value is considered missing
356    ///
357    /// # Arguments
358    /// * `value` - The value to check
359    ///
360    /// # Returns
361    /// * `bool` - True if the value is missing, false otherwise
362    fn is_missing(&self, value: f64) -> bool {
363        if self.missingvalues.is_nan() {
364            value.is_nan()
365        } else {
366            (value - self.missingvalues).abs() < f64::EPSILON
367        }
368    }
369}
370
/// Metric used by [`KNNImputer`] to rank candidate neighbors.
#[derive(Clone, Debug, PartialEq)]
pub enum DistanceMetric {
    /// Straight-line (L2) distance
    Euclidean,
    /// Sum of absolute coordinate differences (L1 distance)
    Manhattan,
}
379
/// How [`KNNImputer`] combines the values of the selected neighbors.
#[derive(Clone, Debug, PartialEq)]
pub enum WeightingScheme {
    /// Plain average: every neighbor counts the same
    Uniform,
    /// Inverse-distance average: nearer neighbors count for more
    Distance,
}
388
389/// K-Nearest Neighbors Imputer for filling missing values
390///
391/// This transformer fills missing values using k-nearest neighbors.
392/// For each sample, the missing features are imputed from the nearest
393/// neighbors that have a value for that feature.
394pub struct KNNImputer {
395    /// Number of nearest neighbors to use
396    _nneighbors: usize,
397    /// Distance metric to use for finding neighbors
398    metric: DistanceMetric,
399    /// Weighting scheme for aggregating neighbor values
400    weights: WeightingScheme,
401    /// Missing value indicator (what value is considered missing)
402    missingvalues: f64,
403    /// Training data (stored to find neighbors during transform)
404    x_train_: Option<Array2<f64>>,
405}
406
407impl KNNImputer {
408    /// Creates a new KNNImputer
409    ///
410    /// # Arguments
411    /// * `_nneighbors` - Number of neighboring samples to use for imputation
412    /// * `metric` - Distance metric for finding neighbors
413    /// * `weights` - Weight function used in imputation
414    /// * `missingvalues` - The value that represents missing data (default: NaN)
415    ///
416    /// # Returns
417    /// * A new KNNImputer instance
418    pub fn new(
419        _nneighbors: usize,
420        metric: DistanceMetric,
421        weights: WeightingScheme,
422        missingvalues: f64,
423    ) -> Self {
424        KNNImputer {
425            _nneighbors,
426            metric,
427            weights,
428            missingvalues,
429            x_train_: None,
430        }
431    }
432
433    /// Creates a KNNImputer with default parameters
434    ///
435    /// Uses 5 neighbors, Euclidean distance, uniform weighting, and NaN as missing values
436    pub fn with_defaults() -> Self {
437        Self::new(
438            5,
439            DistanceMetric::Euclidean,
440            WeightingScheme::Uniform,
441            f64::NAN,
442        )
443    }
444
445    /// Creates a KNNImputer with specified number of neighbors and defaults for other parameters
446    pub fn with_n_neighbors(_nneighbors: usize) -> Self {
447        Self::new(
448            _nneighbors,
449            DistanceMetric::Euclidean,
450            WeightingScheme::Uniform,
451            f64::NAN,
452        )
453    }
454
455    /// Creates a KNNImputer with distance weighting
456    pub fn with_distance_weighting(_nneighbors: usize) -> Self {
457        Self::new(
458            _nneighbors,
459            DistanceMetric::Euclidean,
460            WeightingScheme::Distance,
461            f64::NAN,
462        )
463    }
464
465    /// Fits the KNNImputer to the input data
466    ///
467    /// # Arguments
468    /// * `x` - The input data, shape (n_samples, n_features)
469    ///
470    /// # Returns
471    /// * `Result<()>` - Ok if successful, Err otherwise
472    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
473    where
474        S: Data,
475        S::Elem: Float + NumCast,
476    {
477        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
478
479        // Validate that we have enough samples for k-nearest neighbors
480        let n_samples = x_f64.shape()[0];
481        if n_samples < self._nneighbors {
482            return Err(TransformError::InvalidInput(format!(
483                "Number of samples ({}) must be >= _nneighbors ({})",
484                n_samples, self._nneighbors
485            )));
486        }
487
488        // Store training data for neighbor search during transform
489        self.x_train_ = Some(x_f64);
490        Ok(())
491    }
492
493    /// Transforms the input data by imputing missing values
494    ///
495    /// # Arguments
496    /// * `x` - The input data, shape (n_samples, n_features)
497    ///
498    /// # Returns
499    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
500    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
501    where
502        S: Data,
503        S::Elem: Float + NumCast,
504    {
505        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
506
507        if self.x_train_.is_none() {
508            return Err(TransformError::TransformationError(
509                "KNNImputer must be fitted before transform".to_string(),
510            ));
511        }
512
513        let x_train = self.x_train_.as_ref().unwrap();
514        let (n_samples, n_features) = x_f64.dim();
515
516        if n_features != x_train.shape()[1] {
517            return Err(TransformError::InvalidInput(format!(
518                "Number of features in transform data ({}) doesn't match training data ({})",
519                n_features,
520                x_train.shape()[1]
521            )));
522        }
523
524        let mut result = x_f64.clone();
525
526        // Process each sample
527        for i in 0..n_samples {
528            let sample = x_f64.row(i);
529
530            // Find features that need imputation
531            let missing_features: Vec<usize> = (0..n_features)
532                .filter(|&j| self.is_missing(sample[j]))
533                .collect();
534
535            if missing_features.is_empty() {
536                continue; // No missing values in this sample
537            }
538
539            // Find k-nearest neighbors for this sample (excluding itself)
540            let neighbors =
541                self.find_nearest_neighbors_excluding(&sample.to_owned(), x_train, i)?;
542
543            // Impute each missing feature
544            for &feature_idx in &missing_features {
545                let imputed_value = self.impute_feature(feature_idx, &neighbors, x_train)?;
546                result[[i, feature_idx]] = imputed_value;
547            }
548        }
549
550        Ok(result)
551    }
552
553    /// Fits the imputer and transforms the data in one step
554    ///
555    /// # Arguments
556    /// * `x` - The input data, shape (n_samples, n_features)
557    ///
558    /// # Returns
559    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
560    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
561    where
562        S: Data,
563        S::Elem: Float + NumCast,
564    {
565        self.fit(x)?;
566        self.transform(x)
567    }
568
569    /// Find k-nearest neighbors for a given sample, excluding a specific index
570    fn find_nearest_neighbors_excluding(
571        &self,
572        sample: &Array1<f64>,
573        x_train: &Array2<f64>,
574        exclude_idx: usize,
575    ) -> Result<Vec<usize>> {
576        let n_train_samples = x_train.shape()[0];
577
578        // Compute distances to all training samples (excluding the specified index)
579        let distances: Vec<(usize, f64)> = (0..n_train_samples)
580            .into_par_iter()
581            .filter(|&i| i != exclude_idx)
582            .map(|i| {
583                let train_sample = x_train.row(i);
584                let distance = self.compute_distance(sample, &train_sample.to_owned());
585                (i, distance)
586            })
587            .collect();
588
589        // Sort by distance and take k nearest
590        let mut sorted_distances = distances;
591        sorted_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
592
593        let neighbors: Vec<usize> = sorted_distances
594            .into_iter()
595            .take(self._nneighbors)
596            .map(|(idx_, _)| idx_)
597            .collect();
598
599        Ok(neighbors)
600    }
601
602    /// Compute distance between two samples, handling missing values
603    fn compute_distance(&self, sample1: &Array1<f64>, sample2: &Array1<f64>) -> f64 {
604        let n_features = sample1.len();
605        let mut distance = 0.0;
606        let mut valid_features = 0;
607
608        for i in 0..n_features {
609            let val1 = sample1[i];
610            let val2 = sample2[i];
611
612            // Skip features where either sample has missing values
613            if self.is_missing(val1) || self.is_missing(val2) {
614                continue;
615            }
616
617            valid_features += 1;
618            let diff = val1 - val2;
619
620            match self.metric {
621                DistanceMetric::Euclidean => {
622                    distance += diff * diff;
623                }
624                DistanceMetric::Manhattan => {
625                    distance += diff.abs();
626                }
627            }
628        }
629
630        // Handle case where no valid features for comparison
631        if valid_features == 0 {
632            return f64::INFINITY;
633        }
634
635        // Normalize by number of valid features to make distances comparable
636        distance /= valid_features as f64;
637
638        match self.metric {
639            DistanceMetric::Euclidean => distance.sqrt(),
640            DistanceMetric::Manhattan => distance,
641        }
642    }
643
644    /// Impute a single feature using the k-nearest neighbors
645    fn impute_feature(
646        &self,
647        feature_idx: usize,
648        neighbors: &[usize],
649        x_train: &Array2<f64>,
650    ) -> Result<f64> {
651        let mut values = Vec::new();
652        let mut weights = Vec::new();
653
654        // Collect non-missing values from neighbors for this feature
655        for &neighbor_idx in neighbors {
656            let neighbor_value = x_train[[neighbor_idx, feature_idx]];
657
658            if !self.is_missing(neighbor_value) {
659                values.push(neighbor_value);
660
661                // Compute weight based on weighting scheme
662                let weight = match self.weights {
663                    WeightingScheme::Uniform => 1.0,
664                    WeightingScheme::Distance => {
665                        // For distance weighting, we need to recompute distance
666                        // This is a simplified version - could be optimized by storing distances
667                        1.0 // Placeholder - in practice, would use inverse distance
668                    }
669                };
670                weights.push(weight);
671            }
672        }
673
674        if values.is_empty() {
675            return Err(TransformError::TransformationError(format!(
676                "No valid neighbors found for feature {feature_idx} imputation"
677            )));
678        }
679
680        // Compute weighted average
681        let total_weight: f64 = weights.iter().sum();
682        if total_weight == 0.0 {
683            return Err(TransformError::TransformationError(
684                "Total weight is zero for imputation".to_string(),
685            ));
686        }
687
688        let weighted_sum: f64 = values
689            .iter()
690            .zip(weights.iter())
691            .map(|(&val, &weight)| val * weight)
692            .sum();
693
694        Ok(weighted_sum / total_weight)
695    }
696
697    /// Checks if a value is considered missing
698    fn is_missing(&self, value: f64) -> bool {
699        if self.missingvalues.is_nan() {
700            value.is_nan()
701        } else {
702            (value - self.missingvalues).abs() < f64::EPSILON
703        }
704    }
705
706    /// Returns the number of neighbors used for imputation
707    pub fn _nneighbors(&self) -> usize {
708        self._nneighbors
709    }
710
711    /// Returns the distance metric used
712    pub fn metric(&self) -> &DistanceMetric {
713        &self.metric
714    }
715
716    /// Returns the weighting scheme used
717    pub fn weights(&self) -> &WeightingScheme {
718        &self.weights
719    }
720}
721
722/// Simple regression model for MICE imputation
723///
724/// This is a basic linear regression implementation for use in the MICE algorithm.
725/// It uses a simple least squares approach with optional regularization.
726#[derive(Debug, Clone)]
727struct SimpleRegressor {
728    /// Regression coefficients (including intercept as first element)
729    coefficients: Option<Array1<f64>>,
730    /// Whether to include an intercept term
731    includeintercept: bool,
732    /// Regularization parameter (ridge regression)
733    alpha: f64,
734}
735
736impl SimpleRegressor {
737    /// Create a new simple regressor
738    fn new(includeintercept: bool, alpha: f64) -> Self {
739        Self {
740            coefficients: None,
741            includeintercept,
742            alpha,
743        }
744    }
745
746    /// Fit the regressor to the data
747    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
748        let (n_samples, n_features) = x.dim();
749
750        if n_samples != y.len() {
751            return Err(TransformError::InvalidInput(
752                "X and y must have the same number of samples".to_string(),
753            ));
754        }
755
756        // Add intercept column if needed
757        let x_design = if self.includeintercept {
758            let mut x_with_intercept = Array2::ones((n_samples, n_features + 1));
759            x_with_intercept
760                .slice_mut(scirs2_core::ndarray::s![.., 1..])
761                .assign(x);
762            x_with_intercept
763        } else {
764            x.to_owned()
765        };
766
767        // Solve normal equations: (X^T X + alpha*I) * beta = X^T y
768        let xtx = x_design.t().dot(&x_design);
769        let xty = x_design.t().dot(y);
770
771        // Add regularization
772        let mut regularized_xtx = xtx;
773        let n_coeffs = regularized_xtx.shape()[0];
774        for i in 0..n_coeffs {
775            regularized_xtx[[i, i]] += self.alpha;
776        }
777
778        // Solve using simple Gaussian elimination (for small problems)
779        self.coefficients = Some(self.solve_linear_system(&regularized_xtx, &xty)?);
780
781        Ok(())
782    }
783
784    /// Predict using the fitted regressor
785    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
786        let coeffs = self.coefficients.as_ref().ok_or_else(|| {
787            TransformError::TransformationError(
788                "Regressor must be fitted before prediction".to_string(),
789            )
790        })?;
791
792        let x_design = if self.includeintercept {
793            let (n_samples, n_features) = x.dim();
794            let mut x_with_intercept = Array2::ones((n_samples, n_features + 1));
795            x_with_intercept
796                .slice_mut(scirs2_core::ndarray::s![.., 1..])
797                .assign(x);
798            x_with_intercept
799        } else {
800            x.to_owned()
801        };
802
803        Ok(x_design.dot(coeffs))
804    }
805
806    /// Simple linear system solver using Gaussian elimination
807    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
808        let n = a.shape()[0];
809        let mut aug_matrix = Array2::zeros((n, n + 1));
810
811        // Create augmented matrix [A|b]
812        aug_matrix
813            .slice_mut(scirs2_core::ndarray::s![.., ..n])
814            .assign(a);
815        aug_matrix
816            .slice_mut(scirs2_core::ndarray::s![.., n])
817            .assign(b);
818
819        // Forward elimination
820        for i in 0..n {
821            // Find pivot
822            let mut max_row = i;
823            for k in (i + 1)..n {
824                if aug_matrix[[k, i]].abs() > aug_matrix[[max_row, i]].abs() {
825                    max_row = k;
826                }
827            }
828
829            // Swap rows
830            if max_row != i {
831                for j in 0..=n {
832                    let temp = aug_matrix[[i, j]];
833                    aug_matrix[[i, j]] = aug_matrix[[max_row, j]];
834                    aug_matrix[[max_row, j]] = temp;
835                }
836            }
837
838            // Check for singular matrix
839            if aug_matrix[[i, i]].abs() < 1e-12 {
840                return Err(TransformError::TransformationError(
841                    "Singular matrix in regression".to_string(),
842                ));
843            }
844
845            // Make diagonal element 1
846            let pivot = aug_matrix[[i, i]];
847            for j in i..=n {
848                aug_matrix[[i, j]] /= pivot;
849            }
850
851            // Eliminate column
852            for k in 0..n {
853                if k != i {
854                    let factor = aug_matrix[[k, i]];
855                    for j in i..=n {
856                        aug_matrix[[k, j]] -= factor * aug_matrix[[i, j]];
857                    }
858                }
859            }
860        }
861
862        // Extract solution
863        let mut solution = Array1::zeros(n);
864        for i in 0..n {
865            solution[i] = aug_matrix[[i, n]];
866        }
867
868        Ok(solution)
869    }
870}
871
872/// Iterative Imputer using the MICE (Multiple Imputation by Chained Equations) algorithm
873///
874/// This transformer iteratively models each feature with missing values as a function
875/// of other features. The algorithm performs multiple rounds of imputation where each
876/// feature is predicted using the other features in a round-robin fashion.
877///
878/// MICE is particularly useful when:
879/// - There are multiple features with missing values
880/// - The missing patterns are complex
881/// - You want to model relationships between features
882pub struct IterativeImputer {
883    /// Maximum number of iterations to perform
884    max_iter: usize,
885    /// Convergence tolerance (change in imputed values between iterations)
886    tolerance: f64,
887    /// Initial strategy for first round of imputation
888    initial_strategy: ImputeStrategy,
889    /// Random seed for reproducibility
890    random_seed: Option<u64>,
891    /// Missing value indicator
892    missingvalues: f64,
893    /// Regularization parameter for regression
894    alpha: f64,
895    /// Minimum improvement to continue iterating
896    min_improvement: f64,
897
898    // Internal state
899    /// Training data for fitting predictors
900    x_train_: Option<Array2<f64>>,
901    /// Indices of features that had missing values during fitting
902    missing_features_: Option<Vec<usize>>,
903    /// Initial imputation values for features
904    initial_values_: Option<Array1<f64>>,
905    /// Whether the imputer has been fitted
906    is_fitted_: bool,
907}
908
909impl IterativeImputer {
910    /// Creates a new IterativeImputer
911    ///
912    /// # Arguments
913    /// * `max_iter` - Maximum number of iterations
914    /// * `tolerance` - Convergence tolerance
915    /// * `initial_strategy` - Strategy for initial imputation
916    /// * `missingvalues` - Value representing missing data
917    /// * `alpha` - Regularization parameter for regression
918    ///
919    /// # Returns
920    /// * A new IterativeImputer instance
921    pub fn new(
922        max_iter: usize,
923        tolerance: f64,
924        initial_strategy: ImputeStrategy,
925        missingvalues: f64,
926        alpha: f64,
927    ) -> Self {
928        IterativeImputer {
929            max_iter,
930            tolerance,
931            initial_strategy,
932            random_seed: None,
933            missingvalues,
934            alpha,
935            min_improvement: 1e-6,
936            x_train_: None,
937            missing_features_: None,
938            initial_values_: None,
939            is_fitted_: false,
940        }
941    }
942
943    /// Creates an IterativeImputer with default parameters
944    ///
945    /// Uses 10 iterations, 1e-3 tolerance, mean initial strategy, NaN missing values,
946    /// and 1e-6 regularization.
947    pub fn with_defaults() -> Self {
948        Self::new(10, 1e-3, ImputeStrategy::Mean, f64::NAN, 1e-6)
949    }
950
951    /// Creates an IterativeImputer with specified max iterations and defaults for other parameters
952    pub fn with_max_iter(_maxiter: usize) -> Self {
953        Self::new(_maxiter, 1e-3, ImputeStrategy::Mean, f64::NAN, 1e-6)
954    }
955
956    /// Set the random seed for reproducible results
957    pub fn with_random_seed(mut self, seed: u64) -> Self {
958        self.random_seed = Some(seed);
959        self
960    }
961
962    /// Set the regularization parameter
963    pub fn with_alpha(mut self, alpha: f64) -> Self {
964        self.alpha = alpha;
965        self
966    }
967
968    /// Set the minimum improvement threshold
969    pub fn with_min_improvement(mut self, minimprovement: f64) -> Self {
970        self.min_improvement = minimprovement;
971        self
972    }
973
974    /// Fits the IterativeImputer to the input data
975    ///
976    /// # Arguments
977    /// * `x` - The input data, shape (n_samples, n_features)
978    ///
979    /// # Returns
980    /// * `Result<()>` - Ok if successful, Err otherwise
981    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
982    where
983        S: Data,
984        S::Elem: Float + NumCast,
985    {
986        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
987        let (n_samples, n_features) = x_f64.dim();
988
989        if n_samples == 0 || n_features == 0 {
990            return Err(TransformError::InvalidInput("Empty input data".to_string()));
991        }
992
993        // Find features that have missing values
994        let missing_features: Vec<usize> = (0..n_features)
995            .filter(|&j| x_f64.column(j).iter().any(|&val| self.is_missing(val)))
996            .collect();
997
998        if missing_features.is_empty() {
999            // No missing values, store data as-is
1000            self.x_train_ = Some(x_f64);
1001            self.missing_features_ = Some(Vec::new());
1002            self.initial_values_ = Some(Array1::zeros(0));
1003            self.is_fitted_ = true;
1004            return Ok(());
1005        }
1006
1007        // Compute initial imputation values for each feature
1008        let mut initial_values = Array1::zeros(n_features);
1009        for &feature_idx in &missing_features {
1010            let feature_data: Vec<f64> = x_f64
1011                .column(feature_idx)
1012                .iter()
1013                .filter(|&&val| !self.is_missing(val))
1014                .copied()
1015                .collect();
1016
1017            if feature_data.is_empty() {
1018                return Err(TransformError::InvalidInput(format!(
1019                    "All values are missing in feature {feature_idx}"
1020                )));
1021            }
1022
1023            initial_values[feature_idx] = match &self.initial_strategy {
1024                ImputeStrategy::Mean => {
1025                    feature_data.iter().sum::<f64>() / feature_data.len() as f64
1026                }
1027                ImputeStrategy::Median => {
1028                    let mut sorted_data = feature_data;
1029                    sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
1030                    let len = sorted_data.len();
1031                    if len.is_multiple_of(2) {
1032                        (sorted_data[len / 2 - 1] + sorted_data[len / 2]) / 2.0
1033                    } else {
1034                        sorted_data[len / 2]
1035                    }
1036                }
1037                ImputeStrategy::MostFrequent => {
1038                    // For continuous data, use mean as approximation
1039                    feature_data.iter().sum::<f64>() / feature_data.len() as f64
1040                }
1041                ImputeStrategy::Constant(value) => *value,
1042            };
1043        }
1044
1045        self.x_train_ = Some(x_f64);
1046        self.missing_features_ = Some(missing_features);
1047        self.initial_values_ = Some(initial_values);
1048        self.is_fitted_ = true;
1049
1050        Ok(())
1051    }
1052
1053    /// Transforms the input data by imputing missing values using MICE
1054    ///
1055    /// # Arguments
1056    /// * `x` - The input data, shape (n_samples, n_features)
1057    ///
1058    /// # Returns
1059    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
1060    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1061    where
1062        S: Data,
1063        S::Elem: Float + NumCast,
1064    {
1065        if !self.is_fitted_ {
1066            return Err(TransformError::TransformationError(
1067                "IterativeImputer must be fitted before transform".to_string(),
1068            ));
1069        }
1070
1071        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
1072        let missing_features = self.missing_features_.as_ref().unwrap();
1073
1074        if missing_features.is_empty() {
1075            // No missing values in training data, return as-is
1076            return Ok(x_f64);
1077        }
1078
1079        let initial_values = self.initial_values_.as_ref().unwrap();
1080        let (n_samples, n_features) = x_f64.dim();
1081
1082        // Start with initial imputation
1083        let mut imputed_data = x_f64.clone();
1084        self.apply_initial_imputation(&mut imputed_data, initial_values)?;
1085
1086        // MICE iterations
1087        for iteration in 0..self.max_iter {
1088            let mut max_change = 0.0;
1089            let old_imputed_data = imputed_data.clone();
1090
1091            // Iterate through each feature with missing values
1092            for &feature_idx in missing_features {
1093                // Find samples with missing values for this feature
1094                let missing_mask: Vec<bool> = (0..n_samples)
1095                    .map(|i| self.is_missing(x_f64[[i, feature_idx]]))
1096                    .collect();
1097
1098                if !missing_mask.iter().any(|&x| x) {
1099                    continue; // No missing values for this feature
1100                }
1101
1102                // Prepare predictors (all other features)
1103                let predictor_indices: Vec<usize> =
1104                    (0..n_features).filter(|&i| i != feature_idx).collect();
1105
1106                // Create training data from samples without missing values for this feature
1107                let (train_x, train_y) = self.prepare_training_data(
1108                    &imputed_data,
1109                    feature_idx,
1110                    &predictor_indices,
1111                    &missing_mask,
1112                )?;
1113
1114                if train_x.is_empty() {
1115                    continue; // Cannot train predictor
1116                }
1117
1118                // Fit predictor
1119                let mut regressor = SimpleRegressor::new(true, self.alpha);
1120                regressor.fit(&train_x, &train_y)?;
1121
1122                // Predict missing values
1123                let test_x =
1124                    self.prepare_test_data(&imputed_data, &predictor_indices, &missing_mask)?;
1125
1126                if !test_x.is_empty() {
1127                    let predictions = regressor.predict(&test_x)?;
1128
1129                    // Update imputed values
1130                    let mut pred_idx = 0;
1131                    for i in 0..n_samples {
1132                        if missing_mask[i] {
1133                            let old_value = imputed_data[[i, feature_idx]];
1134                            let new_value = predictions[pred_idx];
1135                            imputed_data[[i, feature_idx]] = new_value;
1136
1137                            let change = (new_value - old_value).abs();
1138                            max_change = max_change.max(change);
1139                            pred_idx += 1;
1140                        }
1141                    }
1142                }
1143            }
1144
1145            // Check convergence
1146            if max_change < self.tolerance {
1147                break;
1148            }
1149
1150            // Check for minimum improvement
1151            if iteration > 0 {
1152                let total_change = self.compute_total_change(&old_imputed_data, &imputed_data);
1153                if total_change < self.min_improvement {
1154                    break;
1155                }
1156            }
1157        }
1158
1159        Ok(imputed_data)
1160    }
1161
1162    /// Fits the imputer and transforms the data in one step
1163    ///
1164    /// # Arguments
1165    /// * `x` - The input data, shape (n_samples, n_features)
1166    ///
1167    /// # Returns
1168    /// * `Result<Array2<f64>>` - Transformed data with missing values imputed
1169    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1170    where
1171        S: Data,
1172        S::Elem: Float + NumCast,
1173    {
1174        self.fit(x)?;
1175        self.transform(x)
1176    }
1177
1178    /// Apply initial imputation using the specified strategy
1179    fn apply_initial_imputation(
1180        &self,
1181        data: &mut Array2<f64>,
1182        initial_values: &Array1<f64>,
1183    ) -> Result<()> {
1184        let (n_samples, n_features) = data.dim();
1185
1186        for i in 0..n_samples {
1187            for j in 0..n_features {
1188                if self.is_missing(data[[i, j]]) {
1189                    data[[i, j]] = initial_values[j];
1190                }
1191            }
1192        }
1193
1194        Ok(())
1195    }
1196
1197    /// Prepare training data for a specific feature
1198    fn prepare_training_data(
1199        &self,
1200        data: &Array2<f64>,
1201        target_feature: usize,
1202        predictor_indices: &[usize],
1203        missing_mask: &[bool],
1204    ) -> Result<(Array2<f64>, Array1<f64>)> {
1205        let n_samples = data.shape()[0];
1206        let n_predictors = predictor_indices.len();
1207
1208        // Count non-missing samples
1209        let non_missing_count = missing_mask.iter().filter(|&&x| !x).count();
1210
1211        if non_missing_count == 0 {
1212            return Ok((Array2::zeros((0, n_predictors)), Array1::zeros(0)));
1213        }
1214
1215        let mut train_x = Array2::zeros((non_missing_count, n_predictors));
1216        let mut train_y = Array1::zeros(non_missing_count);
1217
1218        let mut train_idx = 0;
1219        for i in 0..n_samples {
1220            if !missing_mask[i] {
1221                // Copy predictor features
1222                for (pred_j, &orig_j) in predictor_indices.iter().enumerate() {
1223                    train_x[[train_idx, pred_j]] = data[[i, orig_j]];
1224                }
1225                // Copy target _feature
1226                train_y[train_idx] = data[[i, target_feature]];
1227                train_idx += 1;
1228            }
1229        }
1230
1231        Ok((train_x, train_y))
1232    }
1233
1234    /// Prepare test data for prediction
1235    fn prepare_test_data(
1236        &self,
1237        data: &Array2<f64>,
1238        predictor_indices: &[usize],
1239        missing_mask: &[bool],
1240    ) -> Result<Array2<f64>> {
1241        let n_samples = data.shape()[0];
1242        let n_predictors = predictor_indices.len();
1243
1244        // Count missing samples
1245        let missing_count = missing_mask.iter().filter(|&&x| x).count();
1246
1247        if missing_count == 0 {
1248            return Ok(Array2::zeros((0, n_predictors)));
1249        }
1250
1251        let mut test_x = Array2::zeros((missing_count, n_predictors));
1252
1253        let mut test_idx = 0;
1254        for i in 0..n_samples {
1255            if missing_mask[i] {
1256                // Copy predictor features
1257                for (pred_j, &orig_j) in predictor_indices.iter().enumerate() {
1258                    test_x[[test_idx, pred_j]] = data[[i, orig_j]];
1259                }
1260                test_idx += 1;
1261            }
1262        }
1263
1264        Ok(test_x)
1265    }
1266
1267    /// Compute total change between two imputation iterations
1268    fn compute_total_change(&self, old_data: &Array2<f64>, newdata: &Array2<f64>) -> f64 {
1269        let diff = newdata - old_data;
1270        diff.iter().map(|&x| x * x).sum::<f64>().sqrt()
1271    }
1272
1273    /// Check if a value is considered missing
1274    fn is_missing(&self, value: f64) -> bool {
1275        if self.missingvalues.is_nan() {
1276            value.is_nan()
1277        } else {
1278            (value - self.missingvalues).abs() < f64::EPSILON
1279        }
1280    }
1281}
1282
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array, Array2};

    /// Shorthand for a missing cell in the fixtures below.
    const NAN: f64 = f64::NAN;
    /// Tolerance for comparisons that are expected to be exact.
    const EPS: f64 = 1e-10;

    /// Builds a `rows x cols` matrix from row-major `values`.
    fn matrix(rows: usize, cols: usize, values: Vec<f64>) -> Array2<f64> {
        Array::from_shape_vec((rows, cols), values).unwrap()
    }

    /// Asserts `actual[[row, col]] == expected` (within `EPS`) for each listed cell.
    fn assert_cells<S>(actual: &ArrayBase<S, Ix2>, cells: &[(usize, usize, f64)])
    where
        S: Data<Elem = f64>,
    {
        for &(row, col, expected) in cells {
            assert_abs_diff_eq!(actual[[row, col]], expected, epsilon = EPS);
        }
    }

    #[test]
    fn test_simple_imputer_mean() {
        let data = matrix(
            4,
            3,
            vec![
                1.0, 2.0, 3.0, //
                NAN, 5.0, 6.0, //
                7.0, NAN, 9.0, //
                10.0, 11.0, NAN,
            ],
        );

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Mean);
        let transformed = imputer.fit_transform(&data).unwrap();

        assert_eq!(transformed.shape(), &[4, 3]);

        // Every observed-column mean is 6.0:
        //   (1 + 7 + 10) / 3, (2 + 5 + 11) / 3, (3 + 6 + 9) / 3
        assert_cells(
            &transformed,
            &[
                (0, 0, 1.0),
                (1, 0, 6.0),
                (2, 0, 7.0),
                (3, 0, 10.0),
                (0, 1, 2.0),
                (1, 1, 5.0),
                (2, 1, 6.0),
                (3, 1, 11.0),
                (0, 2, 3.0),
                (1, 2, 6.0),
                (2, 2, 9.0),
                (3, 2, 6.0),
            ],
        );
    }

    #[test]
    fn test_simple_imputer_median() {
        let data = matrix(
            5,
            2,
            vec![1.0, 10.0, NAN, 20.0, 3.0, NAN, 4.0, 40.0, 5.0, 50.0],
        );

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Median);
        let transformed = imputer.fit_transform(&data).unwrap();

        assert_eq!(transformed.shape(), &[5, 2]);

        // Column 0: median of [1, 3, 4, 5] is 3.5; column 1: median of [10, 20, 40, 50] is 30.
        assert_cells(&transformed, &[(1, 0, 3.5), (2, 1, 30.0)]);
    }

    #[test]
    fn test_simple_imputer_constant() {
        let data = matrix(3, 2, vec![1.0, NAN, NAN, 3.0, 4.0, 5.0]);

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Constant(99.0));
        let transformed = imputer.fit_transform(&data).unwrap();

        // Gaps receive the constant; observed cells are left untouched.
        assert_cells(
            &transformed,
            &[
                (0, 1, 99.0),
                (1, 0, 99.0),
                (0, 0, 1.0),
                (1, 1, 3.0),
                (2, 0, 4.0),
                (2, 1, 5.0),
            ],
        );
    }

    #[test]
    fn test_missing_indicator() {
        let data = matrix(
            3,
            4,
            vec![
                1.0, NAN, 3.0, 4.0, //
                NAN, 6.0, NAN, 8.0, //
                9.0, 10.0, 11.0, NAN,
            ],
        );

        let mut indicator = MissingIndicator::with_nan();
        let indicators = indicator.fit_transform(&data).unwrap();

        // Every column has a gap, so all four columns are reported.
        assert_eq!(indicators.shape(), &[3, 4]);

        // 1.0 marks a missing input cell, 0.0 an observed one.
        let expected = [
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        for (row, values) in expected.iter().enumerate() {
            for (col, &value) in values.iter().enumerate() {
                assert_abs_diff_eq!(indicators[[row, col]], value, epsilon = EPS);
            }
        }
    }

    #[test]
    fn test_imputer_errors() {
        // Column 0 has no observed value at all.
        let data = matrix(2, 2, vec![NAN, 1.0, NAN, 2.0]);

        let mut imputer = SimpleImputer::with_strategy(ImputeStrategy::Mean);
        assert!(imputer.fit(&data).is_err());
    }

    #[test]
    fn test_knn_imputer_basic() {
        let data = matrix(
            4,
            3,
            vec![
                1.0, 2.0, 3.0, //
                4.0, NAN, 6.0, //
                7.0, 8.0, NAN, //
                10.0, 11.0, 12.0,
            ],
        );

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        assert_eq!(transformed.shape(), &[4, 3]);

        // Fully observed rows pass through unchanged.
        assert_cells(
            &transformed,
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (3, 0, 10.0),
                (3, 1, 11.0),
                (3, 2, 12.0),
            ],
        );

        // Exact fills depend on which neighbours are chosen; only require a value.
        for &(row, col) in &[(1, 1), (2, 2)] {
            assert!(!transformed[[row, col]].is_nan());
        }
    }

    #[test]
    fn test_knn_imputer_simple_case() {
        let data = matrix(3, 2, vec![1.0, 1.0, NAN, 2.0, 3.0, 3.0]);

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Neighbours of [?, 2.0] are [1, 1] and [3, 3], so the fill sits near their mean.
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-1);
    }

    #[test]
    fn test_knn_imputer_manhattan_distance() {
        let data = matrix(4, 2, vec![0.0, 0.0, 1.0, NAN, 2.0, 2.0, 10.0, 10.0]);

        let mut imputer = KNNImputer::new(
            2,
            DistanceMetric::Manhattan,
            WeightingScheme::Uniform,
            f64::NAN,
        );
        let transformed = imputer.fit_transform(&data).unwrap();
        let filled = transformed[[1, 1]];

        // Under L1 the nearest rows to [1, ?] are [0, 0] and [2, 2]; the
        // outlier [10, 10] must not pull the fill towards 10.
        assert!(!filled.is_nan());
        assert!(filled < 5.0);
    }

    #[test]
    fn test_knn_imputer_validation_errors() {
        // Asking for more neighbours than there are samples is rejected by fit.
        let tiny = matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let mut greedy = KNNImputer::with_n_neighbors(5);
        assert!(greedy.fit(&tiny).is_err());

        // Transforming before fitting is an error.
        let data = matrix(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let unfitted = KNNImputer::with_n_neighbors(2);
        assert!(unfitted.transform(&data).is_err());
    }

    #[test]
    fn test_knn_imputer_no_missing_values() {
        let data = matrix(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        // Complete input comes back verbatim.
        assert_eq!(transformed, data);
    }

    #[test]
    fn test_knn_imputer_accessors() {
        let imputer = KNNImputer::new(
            3,
            DistanceMetric::Manhattan,
            WeightingScheme::Distance,
            -999.0,
        );

        assert_eq!(imputer._nneighbors(), 3);
        assert_eq!(imputer.metric(), &DistanceMetric::Manhattan);
        assert_eq!(imputer.weights(), &WeightingScheme::Distance);
    }

    #[test]
    fn test_knn_imputer_multiple_missing_features() {
        // Row 1 is missing two of its three features.
        let data = matrix(
            4,
            3,
            vec![
                1.0, 2.0, 3.0, //
                NAN, NAN, 6.0, //
                7.0, 8.0, 9.0, //
                10.0, 11.0, 12.0,
            ],
        );

        let mut imputer = KNNImputer::with_n_neighbors(2);
        let transformed = imputer.fit_transform(&data).unwrap();

        for col in 0..2 {
            assert!(!transformed[[1, col]].is_nan());
        }
        assert_abs_diff_eq!(transformed[[1, 2]], 6.0, epsilon = EPS);
    }

    #[test]
    fn test_iterative_imputer_basic() {
        // Feature 1 is exactly twice feature 0 wherever both are observed.
        let data = matrix(4, 2, vec![1.0, 2.0, 2.0, 4.0, 3.0, NAN, NAN, 8.0]);

        let mut imputer = IterativeImputer::with_max_iter(5);
        let transformed = imputer.fit_transform(&data).unwrap();

        assert!(!transformed[[2, 1]].is_nan());
        assert!(!transformed[[3, 0]].is_nan());

        // Observed cells are preserved.
        assert_cells(
            &transformed,
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (1, 0, 2.0),
                (1, 1, 4.0),
                (2, 0, 3.0),
                (3, 1, 8.0),
            ],
        );

        // Imputed cells should roughly honour the 2:1 relationship.
        let row2_gap = transformed[[2, 1]] - 2.0 * transformed[[2, 0]];
        assert!(row2_gap.abs() < 1.0);

        let row3_gap = transformed[[3, 0]] - transformed[[3, 1]] / 2.0;
        assert!(row3_gap.abs() < 1.0);
    }

    #[test]
    fn test_iterative_imputer_no_missing_values() {
        let data = matrix(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let mut imputer = IterativeImputer::with_defaults();
        let transformed = imputer.fit_transform(&data).unwrap();

        for ((row, col), &expected) in data.indexed_iter() {
            assert_abs_diff_eq!(transformed[[row, col]], expected, epsilon = EPS);
        }
    }

    #[test]
    fn test_iterative_imputer_convergence() {
        // Columns follow f1 = 2 * f0 and f2 = 3 * f0, so rounds should settle quickly.
        let data = matrix(
            5,
            3,
            vec![
                1.0, 2.0, 3.0, //
                2.0, NAN, 6.0, //
                3.0, 6.0, NAN, //
                4.0, 8.0, 12.0, //
                NAN, 10.0, 15.0,
            ],
        );

        let max_iter = 20;
        let tolerance = 1e-4;
        let alpha = 1e-6;
        let mut imputer =
            IterativeImputer::new(max_iter, tolerance, ImputeStrategy::Mean, f64::NAN, alpha);

        let transformed = imputer.fit_transform(&data).unwrap();

        assert!(transformed.iter().all(|value| !value.is_nan()));
    }

    #[test]
    fn test_iterative_imputer_different_strategies() {
        let data = matrix(4, 2, vec![1.0, NAN, 2.0, 4.0, 3.0, 6.0, NAN, 8.0]);

        // Whatever the initial fill, the regression rounds must produce values.
        for strategy in [ImputeStrategy::Median, ImputeStrategy::Constant(999.0)] {
            let mut imputer = IterativeImputer::new(5, 1e-3, strategy, f64::NAN, 1e-6);
            let transformed = imputer.fit_transform(&data).unwrap();
            assert!(!transformed[[0, 1]].is_nan());
            assert!(!transformed[[3, 0]].is_nan());
        }
    }

    #[test]
    fn test_iterative_imputer_builder_methods() {
        let imputer = IterativeImputer::with_defaults()
            .with_random_seed(42)
            .with_alpha(1e-3)
            .with_min_improvement(1e-5);

        assert_eq!(imputer.random_seed, Some(42));
        assert_abs_diff_eq!(imputer.alpha, 1e-3, epsilon = EPS);
        assert_abs_diff_eq!(imputer.min_improvement, 1e-5, epsilon = EPS);
    }

    #[test]
    fn test_iterative_imputer_errors() {
        // Transforming before fitting is an error.
        let unfitted = IterativeImputer::with_defaults();
        let complete = matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        assert!(unfitted.transform(&complete).is_err());

        // A column with no observed value cannot be fitted.
        let all_missing_first_column = matrix(3, 2, vec![NAN, 1.0, NAN, 2.0, NAN, 3.0]);
        let mut imputer = IterativeImputer::with_defaults();
        assert!(imputer.fit(&all_missing_first_column).is_err());
    }

    #[test]
    fn test_simple_regressor() {
        // Targets follow y = 2*x1 + x2 + 1 exactly.
        let x = matrix(3, 2, vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0]);
        let y = Array::from_vec(vec![5.0, 8.0, 11.0]);

        let mut regressor = SimpleRegressor::new(true, 1e-6);
        regressor.fit(&x, &y).unwrap();

        let queries = matrix(2, 2, vec![4.0, 5.0, 5.0, 6.0]);
        let predictions = regressor.predict(&queries).unwrap();

        assert_eq!(predictions.len(), 2);
        assert!(predictions.iter().all(|p| !p.is_nan()));
    }
}