use crate::{UtilsError, UtilsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::numeric::Float;
use std::collections::HashMap;

/// Utilities for cleaning missing (NaN) values out of 2-D data.
pub struct DataCleaner;

impl DataCleaner {
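    /// Removes every row containing at least one NaN value.
    ///
    /// Returns `UtilsError::EmptyInput` if no complete rows remain.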
    pub fn drop_missing_rows<T>(data: &Array2<T>) -> UtilsResult<Array2<T>>
    where
        T: Float + Clone + std::iter::Sum,
    {
        let mut valid_rows = Vec::new();

        for (row_idx, row) in data.axis_iter(Axis(0)).enumerate() {
            if !row.iter().any(|&x| x.is_nan()) {
                valid_rows.push(row_idx);
            }
        }

        if valid_rows.is_empty() {
            return Err(UtilsError::EmptyInput);
        }

        let mut result = Array2::zeros((valid_rows.len(), data.ncols()));
        for (new_idx, &old_idx) in valid_rows.iter().enumerate() {
            result.row_mut(new_idx).assign(&data.row(old_idx));
        }

        Ok(result)
    }

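    /// Replaces every NaN in `data` with `fill_value`, in place.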
    pub fn fill_missing<T>(data: &mut Array2<T>, fill_value: T)
    where
        T: Float + Clone + std::iter::Sum,
    {
        data.mapv_inplace(|x| if x.is_nan() { fill_value } else { x });
    }

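    /// Replaces NaNs in each column with that column's mean, computed over the
    /// non-NaN values. Columns that are entirely NaN are left unchanged.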
    pub fn fill_with_mean<T>(data: &mut Array2<T>) -> UtilsResult<()>
    where
        T: Float + Clone + std::iter::Sum,
    {
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let valid_values: Vec<T> = col.iter().cloned().filter(|x| !x.is_nan()).collect();

            if !valid_values.is_empty() {
                let mean =
                    valid_values.iter().cloned().sum::<T>() / T::from(valid_values.len()).unwrap();

                for row_idx in 0..data.nrows() {
                    if data[[row_idx, col_idx]].is_nan() {
                        data[[row_idx, col_idx]] = mean;
                    }
                }
            }
        }
        Ok(())
    }

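    /// Replaces NaNs in each column with that column's median, computed over
    /// the non-NaN values. Columns that are entirely NaN are left unchanged.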
    pub fn fill_with_median<T>(data: &mut Array2<T>) -> UtilsResult<()>
    where
        T: Float + Clone + PartialOrd,
    {
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mut valid_values: Vec<T> = col.iter().cloned().filter(|x| !x.is_nan()).collect();

            if !valid_values.is_empty() {
                valid_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let median = if valid_values.len() % 2 == 0 {
                    let mid = valid_values.len() / 2;
                    (valid_values[mid - 1] + valid_values[mid]) / T::from(2).unwrap()
                } else {
                    valid_values[valid_values.len() / 2]
                };

                for row_idx in 0..data.nrows() {
                    if data[[row_idx, col_idx]].is_nan() {
                        data[[row_idx, col_idx]] = median;
                    }
                }
            }
        }
        Ok(())
    }
}

/// Index-based outlier detection over one-dimensional data.
pub struct OutlierDetector;

impl OutlierDetector {
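    /// Returns the indices of values whose absolute z-score
    /// `|x - mean| / std_dev` exceeds `threshold` (2.0 or 3.0 are common
    /// choices). Returns an empty vector when the data has zero variance.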
    pub fn zscore_outliers<T>(data: &ArrayView1<T>, threshold: T) -> Vec<usize>
    where
        T: Float + Clone + std::iter::Sum,
    {
        let mean = data.iter().cloned().sum::<T>() / T::from(data.len()).unwrap();
        let variance =
            data.iter().map(|&x| (x - mean).powi(2)).sum::<T>() / T::from(data.len()).unwrap();
        let std_dev = variance.sqrt();

        if std_dev == T::zero() {
            return Vec::new();
        }

        data.iter()
            .enumerate()
            .filter_map(|(idx, &value)| {
                let z_score = (value - mean).abs() / std_dev;
                if z_score > threshold {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }

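    /// Returns the indices of values outside
    /// `[Q1 - multiplier * IQR, Q3 + multiplier * IQR]`, with quartiles
    /// approximated at ranks `n / 4` and `3 * n / 4` of the sorted data.
    /// The conventional `multiplier` is 1.5. Requires at least four values.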
    pub fn iqr_outliers<T>(data: &ArrayView1<T>, multiplier: T) -> Vec<usize>
    where
        T: Float + Clone + PartialOrd,
    {
        let mut sorted_data: Vec<T> = data.iter().cloned().collect();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = sorted_data.len();
        if n < 4 {
            return Vec::new();
        }

        let q1_idx = n / 4;
        let q3_idx = 3 * n / 4;
        let q1 = sorted_data[q1_idx];
        let q3 = sorted_data[q3_idx];
        let iqr = q3 - q1;

        let lower_bound = q1 - multiplier * iqr;
        let upper_bound = q3 + multiplier * iqr;

        data.iter()
            .enumerate()
            .filter_map(|(idx, &value)| {
                if value < lower_bound || value > upper_bound {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }

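    /// Returns the indices of values whose modified z-score
    /// `0.6745 * |x - median| / MAD` exceeds `threshold` (Iglewicz and
    /// Hoaglin suggest 3.5). More robust to extreme values than
    /// `zscore_outliers`. Returns an empty vector when the MAD is zero.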
    pub fn modified_zscore_outliers<T>(data: &ArrayView1<T>, threshold: T) -> Vec<usize>
    where
        T: Float + Clone + PartialOrd,
    {
        let mut sorted_data: Vec<T> = data.iter().cloned().collect();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = sorted_data.len();
        if n == 0 {
            return Vec::new();
        }

        let median = if n % 2 == 0 {
            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / T::from(2).unwrap()
        } else {
            sorted_data[n / 2]
        };

        let mut deviations: Vec<T> = data.iter().map(|&x| (x - median).abs()).collect();
        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mad = if deviations.len() % 2 == 0 {
            let mid = deviations.len() / 2;
            (deviations[mid - 1] + deviations[mid]) / T::from(2).unwrap()
        } else {
            deviations[deviations.len() / 2]
        };

        if mad == T::zero() {
            return Vec::new();
        }

        // Modified z-score: M_i = 0.6745 * |x_i - median| / MAD. The raw MAD
        // is used here; additionally scaling it by 1.4826 would double-apply
        // the normal-consistency correction, since 1.4826 ≈ 1 / 0.6745.
        data.iter()
            .enumerate()
            .filter_map(|(idx, &value)| {
                let modified_z = T::from(0.6745).unwrap() * (value - median).abs() / mad;
                if modified_z > threshold {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Column-wise feature scaling; each method returns the scaled data together
/// with the fitted per-column parameters.
pub struct FeatureScaler;

impl FeatureScaler {
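    /// Standardizes each column to zero mean and unit variance (using the
    /// population variance). Returns the scaled data together with the
    /// per-column means and standard deviations; constant columns are left
    /// unscaled.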
    pub fn standard_scale<T>(data: &Array2<T>) -> UtilsResult<(Array2<T>, Array1<T>, Array1<T>)>
    where
        T: Float + Clone + std::iter::Sum,
    {
        let mut scaled_data = data.clone();
        let mut means = Array1::zeros(data.ncols());
        let mut stds = Array1::zeros(data.ncols());

        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mean = col.iter().cloned().sum::<T>() / T::from(col.len()).unwrap();
            let variance =
                col.iter().map(|&x| (x - mean).powi(2)).sum::<T>() / T::from(col.len()).unwrap();
            let std_dev = variance.sqrt();

            means[col_idx] = mean;
            stds[col_idx] = std_dev;

            if std_dev != T::zero() {
                for row_idx in 0..data.nrows() {
                    scaled_data[[row_idx, col_idx]] = (data[[row_idx, col_idx]] - mean) / std_dev;
                }
            }
        }

        Ok((scaled_data, means, stds))
    }

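    /// Linearly rescales each column into `[0, 1]`. Returns the scaled data
    /// together with the per-column minima and maxima; constant columns are
    /// left unscaled.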
    pub fn minmax_scale<T>(data: &Array2<T>) -> UtilsResult<(Array2<T>, Array1<T>, Array1<T>)>
    where
        T: Float + Clone + PartialOrd,
    {
        // Indexing `col[0]` below would panic on an empty input.
        if data.nrows() == 0 {
            return Err(UtilsError::EmptyInput);
        }

        let mut scaled_data = data.clone();
        let mut mins = Array1::zeros(data.ncols());
        let mut maxs = Array1::zeros(data.ncols());

        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let min_val = col
                .iter()
                .cloned()
                .fold(col[0], |acc, x| if x < acc { x } else { acc });
            let max_val = col
                .iter()
                .cloned()
                .fold(col[0], |acc, x| if x > acc { x } else { acc });

            mins[col_idx] = min_val;
            maxs[col_idx] = max_val;

            let range = max_val - min_val;
            if range != T::zero() {
                for row_idx in 0..data.nrows() {
                    scaled_data[[row_idx, col_idx]] = (data[[row_idx, col_idx]] - min_val) / range;
                }
            }
        }

        Ok((scaled_data, mins, maxs))
    }

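    /// Scales each column as `(x - median) / IQR`, which is less sensitive to
    /// outliers than standard scaling. Returns the scaled data together with
    /// the per-column medians and IQRs; columns with zero IQR are left
    /// unscaled.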
    pub fn robust_scale<T>(data: &Array2<T>) -> UtilsResult<(Array2<T>, Array1<T>, Array1<T>)>
    where
        T: Float + Clone + PartialOrd,
    {
        // The median computation below would panic on an empty input.
        if data.nrows() == 0 {
            return Err(UtilsError::EmptyInput);
        }

        let mut scaled_data = data.clone();
        let mut medians = Array1::zeros(data.ncols());
        let mut iqrs = Array1::zeros(data.ncols());

        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mut sorted_col: Vec<T> = col.iter().cloned().collect();
            sorted_col.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let n = sorted_col.len();
            let median = if n % 2 == 0 {
                (sorted_col[n / 2 - 1] + sorted_col[n / 2]) / T::from(2).unwrap()
            } else {
                sorted_col[n / 2]
            };

            let q1_idx = n / 4;
            let q3_idx = 3 * n / 4;
            let q1 = sorted_col[q1_idx];
            let q3 = sorted_col[q3_idx];
            let iqr = q3 - q1;

            medians[col_idx] = median;
            iqrs[col_idx] = iqr;

            if iqr != T::zero() {
                for row_idx in 0..data.nrows() {
                    scaled_data[[row_idx, col_idx]] = (data[[row_idx, col_idx]] - median) / iqr;
                }
            }
        }

        Ok((scaled_data, medians, iqrs))
    }
}

/// Summary statistics describing the quality of a 2-D dataset.
pub struct DataQualityAssessor;

impl DataQualityAssessor {
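    /// Computes missing-value statistics: the overall missing ratio, the
    /// worst per-column and per-row missing ratios, and the number of columns
    /// and rows that contain any NaN.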
    pub fn missing_value_stats<T>(data: &Array2<T>) -> HashMap<String, f64>
    where
        T: Float,
    {
        let total_cells = data.len() as f64;
        let mut missing_count = 0;
        let mut missing_per_column = Vec::new();
        let mut missing_per_row = Vec::new();

        for col_idx in 0..data.ncols() {
            let col_missing = data.column(col_idx).iter().filter(|&&x| x.is_nan()).count();
            missing_per_column.push(col_missing as f64 / data.nrows() as f64);
            missing_count += col_missing;
        }

        for row_idx in 0..data.nrows() {
            let row_missing = data.row(row_idx).iter().filter(|&&x| x.is_nan()).count();
            missing_per_row.push(row_missing as f64 / data.ncols() as f64);
        }

        let mut stats = HashMap::new();
        stats.insert(
            "total_missing_ratio".to_string(),
            missing_count as f64 / total_cells,
        );
        stats.insert(
            "max_column_missing_ratio".to_string(),
            missing_per_column.iter().cloned().fold(0.0, f64::max),
        );
        stats.insert(
            "max_row_missing_ratio".to_string(),
            missing_per_row.iter().cloned().fold(0.0, f64::max),
        );
        stats.insert(
            "columns_with_missing".to_string(),
            missing_per_column.iter().filter(|&&x| x > 0.0).count() as f64,
        );
        stats.insert(
            "rows_with_missing".to_string(),
            missing_per_row.iter().filter(|&&x| x > 0.0).count() as f64,
        );

        stats
    }

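    /// Computes simple quality metrics: `completeness` (fraction of non-NaN
    /// cells), `uniqueness` (mean fraction of distinct values per column,
    /// comparing values formatted to six decimal places), and `consistency`
    /// (`1 / (1 + mean coefficient of variation)`, present only when at least
    /// one column has more than one valid value and a nonzero mean).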
    pub fn quality_metrics<T>(data: &Array2<T>) -> HashMap<String, f64>
    where
        T: Float + PartialOrd + std::iter::Sum + std::fmt::Display,
    {
        let mut metrics = HashMap::new();

        // Completeness: fraction of cells that are not NaN.
        let total_cells = data.len() as f64;
        let missing_count = data.iter().filter(|&&x| x.is_nan()).count() as f64;
        metrics.insert(
            "completeness".to_string(),
            1.0 - (missing_count / total_cells),
        );

        // Uniqueness: mean fraction of distinct values per column, comparing
        // values formatted to six decimal places.
        let mut unique_counts = Vec::new();
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let mut unique_values = std::collections::HashSet::new();
            for &value in col.iter() {
                if !value.is_nan() {
                    unique_values.insert(format!("{value:.6}"));
                }
            }
            let uniqueness = unique_values.len() as f64 / col.len() as f64;
            unique_counts.push(uniqueness);
        }

        let avg_uniqueness = unique_counts.iter().sum::<f64>() / unique_counts.len() as f64;
        metrics.insert("uniqueness".to_string(), avg_uniqueness);

        // Consistency: derived from the mean coefficient of variation across
        // columns; 1.0 means perfectly consistent (zero variation).
        let mut cv_values = Vec::new();
        for col_idx in 0..data.ncols() {
            let col = data.column(col_idx);
            let valid_values: Vec<T> = col.iter().cloned().filter(|x| !x.is_nan()).collect();

            if valid_values.len() > 1 {
                let mean =
                    valid_values.iter().cloned().sum::<T>() / T::from(valid_values.len()).unwrap();
                let variance = valid_values.iter().map(|&x| (x - mean).powi(2)).sum::<T>()
                    / T::from(valid_values.len()).unwrap();
                let std_dev = variance.sqrt();

                if mean != T::zero() {
                    let cv = (std_dev / mean.abs()).to_f64().unwrap();
                    cv_values.push(cv);
                }
            }
        }

        if !cv_values.is_empty() {
            let avg_cv = cv_values.iter().sum::<f64>() / cv_values.len() as f64;
            metrics.insert("consistency".to_string(), 1.0 / (1.0 + avg_cv));
        }

        metrics
    }
}

436
437#[allow(non_snake_case)]
438#[cfg(test)]
439mod tests {
440 use super::*;
441 use approx::assert_abs_diff_eq;
442 use scirs2_core::ndarray::array;
443
444 #[test]
445 fn test_drop_missing_rows() {
446 let data = array![
447 [1.0, 2.0, 3.0],
448 [4.0, f64::NAN, 6.0],
449 [7.0, 8.0, 9.0],
450 [f64::NAN, 11.0, 12.0]
451 ];
452
453 let cleaned = DataCleaner::drop_missing_rows(&data).unwrap();
454 assert_eq!(cleaned.nrows(), 2);
455 assert_eq!(cleaned.row(0), array![1.0, 2.0, 3.0]);
456 assert_eq!(cleaned.row(1), array![7.0, 8.0, 9.0]);
457 }
458
459 #[test]
460 fn test_fill_missing_with_value() {
461 let mut data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, f64::NAN]];
462
463 DataCleaner::fill_missing(&mut data, 0.0);
464
465 assert_eq!(data, array![[1.0, 2.0], [0.0, 4.0], [5.0, 0.0]]);
466 }
467
468 #[test]
469 fn test_fill_with_mean() {
470 let mut data = array![[1.0, 2.0], [f64::NAN, 4.0], [5.0, f64::NAN]];
471
472 DataCleaner::fill_with_mean(&mut data).unwrap();
473
474 assert_abs_diff_eq!(data[[1, 0]], 3.0, epsilon = 1e-10);
476 assert_abs_diff_eq!(data[[2, 1]], 3.0, epsilon = 1e-10);
477 }
478
479 #[test]
480 fn test_zscore_outliers() {
481 let data = array![1.0, 2.0, 3.0, 4.0, 100.0]; let outliers = OutlierDetector::zscore_outliers(&data.view(), 1.5);
483 assert_eq!(outliers, vec![4]);
484 }
485
486 #[test]
487 fn test_iqr_outliers() {
488 let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]; let outliers = OutlierDetector::iqr_outliers(&data.view(), 1.5);
490 assert_eq!(outliers, vec![5]);
491 }
492
493 #[test]
494 fn test_standard_scaling() {
495 let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
496
497 let (scaled, _means, _stds) = FeatureScaler::standard_scale(&data).unwrap();
498
499 for col_idx in 0..scaled.ncols() {
501 let col = scaled.column(col_idx);
502 let mean = col.iter().sum::<f64>() / col.len() as f64;
503 assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
504 }
505 }
506
507 #[test]
508 fn test_minmax_scaling() {
509 let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
510
511 let (scaled, _mins, _maxs) = FeatureScaler::minmax_scale(&data).unwrap();
512
513 for col_idx in 0..scaled.ncols() {
515 let col = scaled.column(col_idx);
516 let min_val = col.iter().cloned().fold(col[0], f64::min);
517 let max_val = col.iter().cloned().fold(col[0], f64::max);
518
519 assert_abs_diff_eq!(min_val, 0.0, epsilon = 1e-10);
520 assert_abs_diff_eq!(max_val, 1.0, epsilon = 1e-10);
521 }
522 }
523
524 #[test]
525 fn test_missing_value_stats() {
526 let data = array![[1.0, 2.0, 3.0], [f64::NAN, 5.0, 6.0], [7.0, f64::NAN, 9.0]];
527
528 let stats = DataQualityAssessor::missing_value_stats(&data);
529
530 assert_abs_diff_eq!(stats["total_missing_ratio"], 2.0 / 9.0, epsilon = 1e-10);
531 assert_eq!(stats["columns_with_missing"], 2.0);
532 assert_eq!(stats["rows_with_missing"], 2.0);
533 }
534
535 #[test]
536 fn test_quality_metrics() {
537 let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
538
539 let metrics = DataQualityAssessor::quality_metrics(&data);
540
541 assert_abs_diff_eq!(metrics["completeness"], 1.0, epsilon = 1e-10);
543 assert!(metrics.contains_key("uniqueness"));
544 assert!(metrics.contains_key("consistency"));
545 }
546}