// sklears_utils/preprocessing.rs

//! Data preprocessing utilities for machine learning
//!
//! This module provides utilities for data cleaning, outlier detection,
//! data transformation, and quality assessment.

use crate::{UtilsError, UtilsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::numeric::Float;
use std::collections::HashMap;

/// Data cleaning utilities
pub struct DataCleaner;

impl DataCleaner {
    /// Remove rows containing missing values (NaN)
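    ///
    /// # Examples
    ///
    /// A minimal sketch; it assumes this module is exposed as
    /// `sklears_utils::preprocessing` (a path inferred from the file
    /// location, not confirmed).
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::DataCleaner;
    ///
    /// let data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, 6.0]];
    /// let cleaned = DataCleaner::drop_missing_rows(&data).unwrap();
    /// assert_eq!(cleaned.nrows(), 2);
    /// ```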
    pub fn drop_missing_rows<T>(data: &Array2<T>) -> UtilsResult<Array2<T>>
    where
        T: Float + Clone + std::iter::Sum,
    {
        let mut valid_rows = Vec::new();

        for (row_idx, row) in data.axis_iter(Axis(0)).enumerate() {
            if !row.iter().any(|&x| x.is_nan()) {
                valid_rows.push(row_idx);
            }
        }

        if valid_rows.is_empty() {
            return Err(UtilsError::EmptyInput);
        }

        let mut result = Array2::zeros((valid_rows.len(), data.ncols()));
        for (new_idx, &old_idx) in valid_rows.iter().enumerate() {
            result.row_mut(new_idx).assign(&data.row(old_idx));
        }

        Ok(result)
    }

    /// Fill missing values (NaN) with the specified value
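    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as
    /// `drop_missing_rows`.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::DataCleaner;
    ///
    /// let mut data = array![[1.0, f64::NAN], [3.0, 4.0]];
    /// DataCleaner::fill_missing(&mut data, 0.0);
    /// assert_eq!(data, array![[1.0, 0.0], [3.0, 4.0]]);
    /// ```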
    pub fn fill_missing<T>(data: &mut Array2<T>, fill_value: T)
    where
        T: Float + Clone + std::iter::Sum,
    {
        data.mapv_inplace(|x| if x.is_nan() { fill_value } else { x });
    }

    /// Fill missing values with the corresponding column means
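    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::DataCleaner;
    ///
    /// let mut data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, 6.0]];
    /// DataCleaner::fill_with_mean(&mut data).unwrap();
    /// // Column 0 has valid entries (1.0, 5.0), so the gap becomes 3.0.
    /// assert_eq!(data[[1, 0]], 3.0);
    /// ```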
    pub fn fill_with_mean<T>(data: &mut Array2<T>) -> UtilsResult<()>
    where
        T: Float + Clone + std::iter::Sum,
    {
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let valid_values: Vec<T> = col.iter().cloned().filter(|x| !x.is_nan()).collect();

            if !valid_values.is_empty() {
                let mean =
                    valid_values.iter().cloned().sum::<T>() / T::from(valid_values.len()).unwrap();

                for row_idx in 0..data.nrows() {
                    if data[[row_idx, col_idx]].is_nan() {
                        data[[row_idx, col_idx]] = mean;
                    }
                }
            }
        }
        Ok(())
    }

    /// Fill missing values with the corresponding column medians
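    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::DataCleaner;
    ///
    /// let mut data = array![[1.0], [f64::NAN], [2.0], [10.0]];
    /// DataCleaner::fill_with_median(&mut data).unwrap();
    /// // The valid entries (1.0, 2.0, 10.0) have median 2.0.
    /// assert_eq!(data[[1, 0]], 2.0);
    /// ```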
    pub fn fill_with_median<T>(data: &mut Array2<T>) -> UtilsResult<()>
    where
        T: Float + Clone + PartialOrd,
    {
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mut valid_values: Vec<T> = col.iter().cloned().filter(|x| !x.is_nan()).collect();

            if !valid_values.is_empty() {
                valid_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let median = if valid_values.len() % 2 == 0 {
                    let mid = valid_values.len() / 2;
                    (valid_values[mid - 1] + valid_values[mid]) / T::from(2).unwrap()
                } else {
                    valid_values[valid_values.len() / 2]
                };

                for row_idx in 0..data.nrows() {
                    if data[[row_idx, col_idx]].is_nan() {
                        data[[row_idx, col_idx]] = median;
                    }
                }
            }
        }
        Ok(())
    }
}

/// Outlier detection methods
pub struct OutlierDetector;

impl OutlierDetector {
    /// Detect outliers using the Z-score method
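    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::OutlierDetector;
    ///
    /// let data = array![1.0, 2.0, 3.0, 4.0, 100.0];
    /// let outliers = OutlierDetector::zscore_outliers(&data.view(), 1.5);
    /// assert_eq!(outliers, vec![4]);
    /// ```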
    pub fn zscore_outliers<T>(data: &ArrayView1<T>, threshold: T) -> Vec<usize>
    where
        T: Float + Clone + std::iter::Sum,
    {
        let mean = data.iter().cloned().sum::<T>() / T::from(data.len()).unwrap();
        // Population variance (divides by n rather than n - 1)
        let variance =
            data.iter().map(|&x| (x - mean).powi(2)).sum::<T>() / T::from(data.len()).unwrap();
        let std_dev = variance.sqrt();

        if std_dev == T::zero() {
            return Vec::new();
        }

        data.iter()
            .enumerate()
            .filter_map(|(idx, &value)| {
                let z_score = (value - mean).abs() / std_dev;
                if z_score > threshold {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Detect outliers using the IQR (interquartile range) method
    ///
    /// Panics if the data contains NaN, since values are ordered via `partial_cmp`.
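    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::OutlierDetector;
    ///
    /// let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 100.0];
    /// let outliers = OutlierDetector::iqr_outliers(&data.view(), 1.5);
    /// assert_eq!(outliers, vec![5]);
    /// ```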
    pub fn iqr_outliers<T>(data: &ArrayView1<T>, multiplier: T) -> Vec<usize>
    where
        T: Float + Clone + PartialOrd,
    {
        let mut sorted_data: Vec<T> = data.iter().cloned().collect();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = sorted_data.len();
        if n < 4 {
            return Vec::new();
        }

        // Index-based quartiles (no interpolation between neighboring values)
        let q1_idx = n / 4;
        let q3_idx = 3 * n / 4;
        let q1 = sorted_data[q1_idx];
        let q3 = sorted_data[q3_idx];
        let iqr = q3 - q1;

        let lower_bound = q1 - multiplier * iqr;
        let upper_bound = q3 + multiplier * iqr;

        data.iter()
            .enumerate()
            .filter_map(|(idx, &value)| {
                if value < lower_bound || value > upper_bound {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Detect outliers using the modified Z-score method (based on the median and MAD)
    ///
    /// Panics if the data contains NaN, since values are ordered via `partial_cmp`.
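    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::OutlierDetector;
    ///
    /// let data = array![1.0, 2.0, 3.0, 4.0, 100.0];
    /// // MAD = 1.0, so only 100.0 exceeds the conventional 3.5 threshold.
    /// let outliers = OutlierDetector::modified_zscore_outliers(&data.view(), 3.5);
    /// assert_eq!(outliers, vec![4]);
    /// ```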
    pub fn modified_zscore_outliers<T>(data: &ArrayView1<T>, threshold: T) -> Vec<usize>
    where
        T: Float + Clone + PartialOrd,
    {
        let mut sorted_data: Vec<T> = data.iter().cloned().collect();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = sorted_data.len();
        if n == 0 {
            return Vec::new();
        }

        let median = if n % 2 == 0 {
            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / T::from(2).unwrap()
        } else {
            sorted_data[n / 2]
        };

        // Calculate the MAD (median absolute deviation)
        let mut deviations: Vec<T> = data.iter().map(|&x| (x - median).abs()).collect();
        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mad = if deviations.len() % 2 == 0 {
            let mid = deviations.len() / 2;
            (deviations[mid - 1] + deviations[mid]) / T::from(2).unwrap()
        } else {
            deviations[deviations.len() / 2]
        };

        if mad == T::zero() {
            return Vec::new();
        }

        // Modified Z-score: 0.6745 * |x - median| / MAD (Iglewicz & Hoaglin).
        // The 0.6745 factor makes the score comparable to a standard Z-score
        // under normality; it is equivalent to dividing by 1.4826 * MAD, so
        // only one of the two scale factors may be applied, not both.
        data.iter()
            .enumerate()
            .filter_map(|(idx, &value)| {
                let modified_z = T::from(0.6745).unwrap() * (value - median).abs() / mad;
                if modified_z > threshold {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Feature scaling utilities
pub struct FeatureScaler;

impl FeatureScaler {
    /// Standard scaling (Z-score normalization)
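    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::FeatureScaler;
    ///
    /// let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
    /// let (scaled, means, stds) = FeatureScaler::standard_scale(&data).unwrap();
    /// assert_eq!(means[0], 2.0);
    /// // Each non-constant column is centered on zero after scaling.
    /// assert!(scaled.column(0).iter().sum::<f64>().abs() < 1e-10);
    /// ```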
    pub fn standard_scale<T>(data: &Array2<T>) -> UtilsResult<(Array2<T>, Array1<T>, Array1<T>)>
    where
        T: Float + Clone + std::iter::Sum,
    {
        let mut scaled_data = data.clone();
        let mut means = Array1::zeros(data.ncols());
        let mut stds = Array1::zeros(data.ncols());

        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mean = col.iter().cloned().sum::<T>() / T::from(col.len()).unwrap();
            // Population variance (divides by n rather than n - 1)
            let variance =
                col.iter().map(|&x| (x - mean).powi(2)).sum::<T>() / T::from(col.len()).unwrap();
            let std_dev = variance.sqrt();

            means[col_idx] = mean;
            stds[col_idx] = std_dev;

            // Constant columns (zero standard deviation) are left unscaled
            // to avoid division by zero.
            if std_dev != T::zero() {
                for row_idx in 0..data.nrows() {
                    scaled_data[[row_idx, col_idx]] = (data[[row_idx, col_idx]] - mean) / std_dev;
                }
            }
        }

        Ok((scaled_data, means, stds))
    }

    /// Min-max scaling to the [0, 1] range
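    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::FeatureScaler;
    ///
    /// let data = array![[1.0], [2.0], [3.0]];
    /// let (scaled, mins, maxs) = FeatureScaler::minmax_scale(&data).unwrap();
    /// assert_eq!(scaled, array![[0.0], [0.5], [1.0]]);
    /// assert_eq!((mins[0], maxs[0]), (1.0, 3.0));
    /// ```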
    pub fn minmax_scale<T>(data: &Array2<T>) -> UtilsResult<(Array2<T>, Array1<T>, Array1<T>)>
    where
        T: Float + Clone + PartialOrd,
    {
        let mut scaled_data = data.clone();
        let mut mins = Array1::zeros(data.ncols());
        let mut maxs = Array1::zeros(data.ncols());

        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let min_val = col
                .iter()
                .cloned()
                .fold(col[0], |acc, x| if x < acc { x } else { acc });
            let max_val = col
                .iter()
                .cloned()
                .fold(col[0], |acc, x| if x > acc { x } else { acc });

            mins[col_idx] = min_val;
            maxs[col_idx] = max_val;

            // Constant columns (zero range) are left unscaled to avoid
            // division by zero.
            let range = max_val - min_val;
            if range != T::zero() {
                for row_idx in 0..data.nrows() {
                    scaled_data[[row_idx, col_idx]] = (data[[row_idx, col_idx]] - min_val) / range;
                }
            }
        }

        Ok((scaled_data, mins, maxs))
    }

    /// Robust scaling using the median and IQR
    ///
    /// Panics if the data contains NaN, since values are ordered via `partial_cmp`.
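    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::FeatureScaler;
    ///
    /// let data = array![[1.0], [2.0], [3.0], [4.0], [100.0]];
    /// let (scaled, medians, iqrs) = FeatureScaler::robust_scale(&data).unwrap();
    /// // median = 3.0, index-based IQR = 4.0 - 2.0 = 2.0
    /// assert_eq!((medians[0], iqrs[0]), (3.0, 2.0));
    /// assert_eq!(scaled[[4, 0]], 48.5);
    /// ```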
    pub fn robust_scale<T>(data: &Array2<T>) -> UtilsResult<(Array2<T>, Array1<T>, Array1<T>)>
    where
        T: Float + Clone + PartialOrd,
    {
        let mut scaled_data = data.clone();
        let mut medians = Array1::zeros(data.ncols());
        let mut iqrs = Array1::zeros(data.ncols());

        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mut sorted_col: Vec<T> = col.iter().cloned().collect();
            sorted_col.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let n = sorted_col.len();
            let median = if n % 2 == 0 {
                (sorted_col[n / 2 - 1] + sorted_col[n / 2]) / T::from(2).unwrap()
            } else {
                sorted_col[n / 2]
            };

            // Index-based quartiles, as in `OutlierDetector::iqr_outliers`
            let q1_idx = n / 4;
            let q3_idx = 3 * n / 4;
            let q1 = sorted_col[q1_idx];
            let q3 = sorted_col[q3_idx];
            let iqr = q3 - q1;

            medians[col_idx] = median;
            iqrs[col_idx] = iqr;

            // Columns with zero IQR are left unscaled to avoid division by zero.
            if iqr != T::zero() {
                for row_idx in 0..data.nrows() {
                    scaled_data[[row_idx, col_idx]] = (data[[row_idx, col_idx]] - median) / iqr;
                }
            }
        }

        Ok((scaled_data, medians, iqrs))
    }
}

/// Data quality assessment utilities
pub struct DataQualityAssessor;

impl DataQualityAssessor {
    /// Calculate missing value statistics
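    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::DataQualityAssessor;
    ///
    /// let data = array![[1.0, f64::NAN], [3.0, 4.0]];
    /// let stats = DataQualityAssessor::missing_value_stats(&data);
    /// // One missing cell out of four
    /// assert_eq!(stats["total_missing_ratio"], 0.25);
    /// ```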
    pub fn missing_value_stats<T>(data: &Array2<T>) -> HashMap<String, f64>
    where
        T: Float,
    {
        let total_cells = data.len() as f64;
        let mut missing_count = 0;
        let mut missing_per_column = Vec::new();
        let mut missing_per_row = Vec::new();

        // Count missing values per column
        for col_idx in 0..data.ncols() {
            let col_missing = data.column(col_idx).iter().filter(|&&x| x.is_nan()).count();
            missing_per_column.push(col_missing as f64 / data.nrows() as f64);
            missing_count += col_missing;
        }

        // Count missing values per row
        for row_idx in 0..data.nrows() {
            let row_missing = data.row(row_idx).iter().filter(|&&x| x.is_nan()).count();
            missing_per_row.push(row_missing as f64 / data.ncols() as f64);
        }

        let mut stats = HashMap::new();
        stats.insert(
            "total_missing_ratio".to_string(),
            missing_count as f64 / total_cells,
        );
        stats.insert(
            "max_column_missing_ratio".to_string(),
            missing_per_column.iter().cloned().fold(0.0, f64::max),
        );
        stats.insert(
            "max_row_missing_ratio".to_string(),
            missing_per_row.iter().cloned().fold(0.0, f64::max),
        );
        stats.insert(
            "columns_with_missing".to_string(),
            missing_per_column.iter().filter(|&&x| x > 0.0).count() as f64,
        );
        stats.insert(
            "rows_with_missing".to_string(),
            missing_per_row.iter().filter(|&&x| x > 0.0).count() as f64,
        );

        stats
    }

    /// Calculate basic data quality metrics
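    ///
    /// # Examples
    ///
    /// A minimal sketch, under the same module-path assumption as above.
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use sklears_utils::preprocessing::DataQualityAssessor;
    ///
    /// let data = array![[1.0, 2.0], [3.0, 4.0]];
    /// let metrics = DataQualityAssessor::quality_metrics(&data);
    /// assert_eq!(metrics["completeness"], 1.0);
    /// ```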
    pub fn quality_metrics<T>(data: &Array2<T>) -> HashMap<String, f64>
    where
        T: Float + PartialOrd + std::iter::Sum + std::fmt::Display,
    {
        let mut metrics = HashMap::new();

        // Completeness: the ratio of non-missing cells
        let total_cells = data.len() as f64;
        let missing_count = data.iter().filter(|&&x| x.is_nan()).count() as f64;
        metrics.insert(
            "completeness".to_string(),
            1.0 - (missing_count / total_cells),
        );

        // Uniqueness: the average ratio of distinct values per column
        let mut unique_counts = Vec::new();
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mut unique_values = std::collections::HashSet::new();
            for &value in col.iter() {
                if !value.is_nan() {
                    // Floats are not hashable, so values are keyed by a
                    // fixed-precision string representation (an approximation)
                    unique_values.insert(format!("{value:.6}"));
                }
            }
            let uniqueness = unique_values.len() as f64 / col.len() as f64;
            unique_counts.push(uniqueness);
        }

        let avg_uniqueness = unique_counts.iter().sum::<f64>() / unique_counts.len() as f64;
        metrics.insert("uniqueness".to_string(), avg_uniqueness);

        // Consistency: derived from the average coefficient of variation
        // (lower variation maps to a score closer to 1)
        let mut cv_values = Vec::new();
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let valid_values: Vec<T> = col.iter().cloned().filter(|x| !x.is_nan()).collect();

            if valid_values.len() > 1 {
                let mean =
                    valid_values.iter().cloned().sum::<T>() / T::from(valid_values.len()).unwrap();
                let variance = valid_values.iter().map(|&x| (x - mean).powi(2)).sum::<T>()
                    / T::from(valid_values.len()).unwrap();
                let std_dev = variance.sqrt();

                if mean != T::zero() {
                    let cv = (std_dev / mean.abs()).to_f64().unwrap();
                    cv_values.push(cv);
                }
            }
        }

        if !cv_values.is_empty() {
            let avg_cv = cv_values.iter().sum::<f64>() / cv_values.len() as f64;
            metrics.insert("consistency".to_string(), 1.0 / (1.0 + avg_cv)); // higher is better
        }

        metrics
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_drop_missing_rows() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [f64::NAN, 11.0, 12.0]
        ];

        let cleaned = DataCleaner::drop_missing_rows(&data).unwrap();
        assert_eq!(cleaned.nrows(), 2);
        assert_eq!(cleaned.row(0), array![1.0, 2.0, 3.0]);
        assert_eq!(cleaned.row(1), array![7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_fill_missing_with_value() {
        let mut data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, f64::NAN]];

        DataCleaner::fill_missing(&mut data, 0.0);

        assert_eq!(data, array![[1.0, 2.0], [0.0, 4.0], [5.0, 0.0]]);
    }

    #[test]
    fn test_fill_with_mean() {
        let mut data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, f64::NAN]];

        DataCleaner::fill_with_mean(&mut data).unwrap();

        // Mean of first column (1, 5) is 3; mean of second column (2, 4) is 3.
        assert_abs_diff_eq!(data[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[2, 1]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_zscore_outliers() {
        let data = array![1.0, 2.0, 3.0, 4.0, 100.0]; // 100 is clearly an outlier
        let outliers = OutlierDetector::zscore_outliers(&data.view(), 1.5);
        assert_eq!(outliers, vec![4]);
    }

    #[test]
    fn test_iqr_outliers() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]; // 100 is an outlier
        let outliers = OutlierDetector::iqr_outliers(&data.view(), 1.5);
        assert_eq!(outliers, vec![5]);
    }

    #[test]
    fn test_standard_scaling() {
        let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];

        let (scaled, _means, _stds) = FeatureScaler::standard_scale(&data).unwrap();

        // Each scaled column should have mean ~0.
        for col_idx in 0..scaled.ncols() {
            let col = scaled.column(col_idx);
            let mean = col.iter().sum::<f64>() / col.len() as f64;
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_minmax_scaling() {
        let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];

        let (scaled, _mins, _maxs) = FeatureScaler::minmax_scale(&data).unwrap();

        // Check that scaled data spans the [0, 1] range
        for col_idx in 0..scaled.ncols() {
            let col = scaled.column(col_idx);
            let min_val = col.iter().cloned().fold(col[0], f64::min);
            let max_val = col.iter().cloned().fold(col[0], f64::max);

            assert_abs_diff_eq!(min_val, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(max_val, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_missing_value_stats() {
        let data = array![[1.0, 2.0, 3.0], [f64::NAN, 5.0, 6.0], [7.0, f64::NAN, 9.0]];

        let stats = DataQualityAssessor::missing_value_stats(&data);

        assert_abs_diff_eq!(stats["total_missing_ratio"], 2.0 / 9.0, epsilon = 1e-10);
        assert_eq!(stats["columns_with_missing"], 2.0);
        assert_eq!(stats["rows_with_missing"], 2.0);
    }

    #[test]
    fn test_quality_metrics() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let metrics = DataQualityAssessor::quality_metrics(&data);

        // All values are present, so completeness should be 1.0.
        assert_abs_diff_eq!(metrics["completeness"], 1.0, epsilon = 1e-10);
        assert!(metrics.contains_key("uniqueness"));
        assert!(metrics.contains_key("consistency"));
    }
}