//! Data-quality profiler: computes missing-value, duplicate-row, and outlier
//! statistics for a DataFrame and produces prioritized suggestions.

use anyhow::Result;
use polars::prelude::*;
use std::collections::HashMap;

/// Data-quality report for a single DataFrame.
#[derive(Debug, Clone)]
pub struct QualityReport {
    /// Total number of rows.
    pub total_rows: usize,
    /// Total number of columns.
    pub total_columns: usize,
    /// Column names, in frame order.
    pub column_names: Vec<String>,
    /// Missing-value statistics per column (only columns with at least one null).
    pub missing_values: HashMap<String, MissingStats>,
    /// Number of rows flagged as duplicated.
    pub duplicate_rows: usize,
    /// Debug-formatted dtype name per column.
    pub column_types: HashMap<String, String>,
    /// Generated suggestions, sorted highest severity first.
    pub suggestions: Vec<Suggestion>,
    /// Outlier-detection results per numeric column.
    pub outliers: HashMap<String, OutlierInfo>,
}

/// A single actionable suggestion produced by the profiler.
#[derive(Debug, Clone)]
pub struct Suggestion {
    /// Category of the suggestion.
    pub suggestion_type: SuggestionType,
    /// Human-readable description.
    pub message: String,
    /// Related column name, if the suggestion targets a specific column.
    pub column: Option<String>,
    /// Severity level.
    pub severity: Severity,
}

/// Category of a profiler suggestion.
#[derive(Debug, Clone, PartialEq)]
pub enum SuggestionType {
    /// Missing-value handling.
    MissingValues,
    /// Duplicate rows.
    Duplicates,
    /// Outlier values.
    Outliers,
    /// Data-type issues.
    DataType,
    /// Data-distribution issues.
    Distribution,
}

/// Severity of a suggestion.
///
/// NOTE: variant order matters — the derived `Ord` yields
/// `Info < Warning < Error`, which the suggestion sorting relies on
/// to put the most severe items first.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    /// Informational.
    Info,
    /// Warning.
    Warning,
    /// Error.
    Error,
}

/// Outlier-detection result for one column.
#[derive(Debug, Clone)]
pub struct OutlierInfo {
    /// Number of outliers detected.
    pub count: usize,
    /// Outliers as a percentage of total rows.
    pub percentage: f64,
    /// Row indices of outliers (at most the first 10 are kept).
    pub indices: Vec<usize>,
}

/// Missing-value statistics for one column.
#[derive(Debug, Clone)]
pub struct MissingStats {
    /// Number of null values.
    pub count: usize,
    /// Nulls as a percentage of total rows.
    pub percentage: f64,
}

/// Detailed statistics for a single column.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name.
    pub name: String,
    /// Debug-formatted dtype name.
    pub dtype: String,
    /// Number of unique values.
    pub unique_count: usize,
    /// Number of null values.
    pub null_count: usize,
    /// Nulls as a percentage of total rows.
    pub null_percentage: f64,
}

/// Stateless data profiler (unit struct; all methods take `&self` for API symmetry).
pub struct DataProfiler;

105impl DataProfiler {
106    /// 创建新的分析器
107    pub fn new() -> Self {
108        Self
109    }
110
111    /// 生成数据质量报告
112    pub fn profile(&self, df: &DataFrame) -> Result<QualityReport> {
113        let total_rows = df.height();
114        let total_columns = df.width();
115        let column_names: Vec<String> =
116            df.get_column_names().iter().map(|s| s.to_string()).collect();
117
118        // 统计缺失值
119        let mut missing_values = HashMap::new();
120        let mut column_types = HashMap::new();
121
122        for col_name in &column_names {
123            let series = df.column(col_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
124
125            // 缺失值统计
126            let null_count = series.null_count();
127            let percentage =
128                if total_rows > 0 { (null_count as f64 / total_rows as f64) * 100.0 } else { 0.0 };
129
130            if null_count > 0 {
131                missing_values
132                    .insert(col_name.clone(), MissingStats { count: null_count, percentage });
133            }
134
135            // 数据类型
136            column_types.insert(col_name.clone(), format!("{:?}", series.dtype()));
137        }
138
139        // 统计重复行(使用哈希)
140        let duplicate_rows = self.count_duplicates(df)?;
141
142        // 批量检测离群值
143        let outliers = self.detect_outliers_batch(df)?;
144
145        // 生成智能建议
146        let suggestions =
147            self.generate_suggestions(total_rows, &missing_values, duplicate_rows, &outliers);
148
149        Ok(QualityReport {
150            total_rows,
151            total_columns,
152            column_names,
153            missing_values,
154            duplicate_rows,
155            column_types,
156            suggestions,
157            outliers,
158        })
159    }
160
161    /// 批量检测所有数值列的离群值
162    fn detect_outliers_batch(&self, df: &DataFrame) -> Result<HashMap<String, OutlierInfo>> {
163        let mut outliers = HashMap::new();
164
165        for col_name in df.get_column_names() {
166            let column = df.column(col_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
167            let series = column.as_materialized_series();
168
169            // 只处理数值类型
170            if !series.dtype().is_numeric() {
171                continue;
172            }
173
174            // 检测离群值
175            match self.detect_outliers(df, col_name) {
176                Ok(indices) if !indices.is_empty() => {
177                    let count = indices.len();
178                    let percentage = if df.height() > 0 {
179                        (count as f64 / df.height() as f64) * 100.0
180                    } else {
181                        0.0
182                    };
183
184                    // 只保留前 10 个索引
185                    let indices_sample = indices.into_iter().take(10).collect();
186
187                    outliers.insert(
188                        col_name.to_string(),
189                        OutlierInfo { count, percentage, indices: indices_sample },
190                    );
191                }
192                _ => {}
193            }
194        }
195
196        Ok(outliers)
197    }
198
199    /// 生成智能建议
200    fn generate_suggestions(
201        &self,
202        total_rows: usize,
203        missing_values: &HashMap<String, MissingStats>,
204        duplicate_rows: usize,
205        outliers: &HashMap<String, OutlierInfo>,
206    ) -> Vec<Suggestion> {
207        let mut suggestions = Vec::new();
208
209        // 缺失值建议
210        for (col_name, stats) in missing_values {
211            let severity = if stats.percentage > 50.0 {
212                Severity::Error
213            } else if stats.percentage > 10.0 {
214                Severity::Warning
215            } else {
216                Severity::Info
217            };
218
219            let message = if stats.percentage > 50.0 {
220                format!(
221                    "列 '{}' 缺失值过多 ({:.1}%),建议检查数据源或考虑删除该列",
222                    col_name, stats.percentage
223                )
224            } else if stats.percentage > 10.0 {
225                format!(
226                    "列 '{}' 存在 {:.1}% 缺失值,建议使用插值或填充方法处理",
227                    col_name, stats.percentage
228                )
229            } else {
230                format!("列 '{}' 存在少量缺失值 ({:.1}%)", col_name, stats.percentage)
231            };
232
233            suggestions.push(Suggestion {
234                suggestion_type: SuggestionType::MissingValues,
235                message,
236                column: Some(col_name.clone()),
237                severity,
238            });
239        }
240
241        // 重复数据建议
242        if duplicate_rows > 0 {
243            let dup_percentage = if total_rows > 0 {
244                (duplicate_rows as f64 / total_rows as f64) * 100.0
245            } else {
246                0.0
247            };
248
249            let severity = if dup_percentage > 20.0 { Severity::Warning } else { Severity::Info };
250
251            let message = if dup_percentage > 20.0 {
252                format!(
253                    "检测到 {} 行重复数据 ({:.1}%),强烈建议去重以提高数据质量",
254                    duplicate_rows, dup_percentage
255                )
256            } else {
257                format!("检测到 {} 行重复数据,建议检查是否需要去重", duplicate_rows)
258            };
259
260            suggestions.push(Suggestion {
261                suggestion_type: SuggestionType::Duplicates,
262                message,
263                column: None,
264                severity,
265            });
266        }
267
268        // 离群值建议
269        for (col_name, info) in outliers {
270            let severity = if info.percentage > 5.0 { Severity::Warning } else { Severity::Info };
271
272            let message = if info.percentage > 5.0 {
273                format!(
274                    "列 '{}' 检测到 {} 个离群值 ({:.1}%),建议检查数据异常",
275                    col_name, info.count, info.percentage
276                )
277            } else {
278                format!("列 '{}' 检测到 {} 个离群值,可能需要进一步分析", col_name, info.count)
279            };
280
281            suggestions.push(Suggestion {
282                suggestion_type: SuggestionType::Outliers,
283                message,
284                column: Some(col_name.clone()),
285                severity,
286            });
287        }
288
289        // 按严重程度排序
290        suggestions.sort_by(|a, b| b.severity.cmp(&a.severity));
291
292        suggestions
293    }
294
295    /// 统计重复行数
296    fn count_duplicates(&self, df: &DataFrame) -> Result<usize> {
297        // 使用 Polars 的 is_duplicated 功能
298        let mask = df.is_duplicated().map_err(|e| anyhow::anyhow!("检测重复行失败: {}", e))?;
299
300        let duplicate_count = mask.sum().ok_or_else(|| anyhow::anyhow!("计算重复行数失败"))?;
301
302        Ok(duplicate_count as usize)
303    }
304
305    /// 获取列的详细统计信息
306    pub fn column_stats(&self, df: &DataFrame, column_name: &str) -> Result<ColumnStats> {
307        let series = df.column(column_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
308
309        let total_rows = df.height();
310        let null_count = series.null_count();
311        let null_percentage =
312            if total_rows > 0 { (null_count as f64 / total_rows as f64) * 100.0 } else { 0.0 };
313
314        let unique_count =
315            series.n_unique().map_err(|e| anyhow::anyhow!("计算唯一值失败: {}", e))?;
316
317        Ok(ColumnStats {
318            name: column_name.to_string(),
319            dtype: format!("{:?}", series.dtype()),
320            unique_count,
321            null_count,
322            null_percentage,
323        })
324    }
325
326    /// 检测数值列的离群值(使用 IQR 方法)
327    /// 注意:此功能需要将 Column 转换为 Series
328    pub fn detect_outliers(&self, df: &DataFrame, column_name: &str) -> Result<Vec<usize>> {
329        let column = df.column(column_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
330
331        // 转换为 Series
332        let series = column.as_materialized_series();
333
334        // 只处理数值类型
335        if !series.dtype().is_numeric() {
336            return Err(anyhow::anyhow!("列 {} 不是数值类型", column_name));
337        }
338
339        // 计算 Q1, Q3 使用 median 和排序
340        let sorted =
341            series.sort(Default::default()).map_err(|e| anyhow::anyhow!("排序失败: {}", e))?;
342
343        let len = sorted.len();
344        let q1_idx = len / 4;
345        let q3_idx = (3 * len) / 4;
346
347        let q1_val = sorted
348            .get(q1_idx)
349            .map_err(|e| anyhow::anyhow!("获取 Q1 失败: {}", e))?
350            .try_extract::<f64>()
351            .map_err(|e| anyhow::anyhow!("提取 Q1 值失败: {}", e))?;
352
353        let q3_val = sorted
354            .get(q3_idx)
355            .map_err(|e| anyhow::anyhow!("获取 Q3 失败: {}", e))?
356            .try_extract::<f64>()
357            .map_err(|e| anyhow::anyhow!("提取 Q3 值失败: {}", e))?;
358
359        let iqr = q3_val - q1_val;
360        let lower_bound = q1_val - 1.5 * iqr;
361        let upper_bound = q3_val + 1.5 * iqr;
362
363        // 找出离群值的索引
364        let mut outlier_indices = Vec::new();
365        for idx in 0..series.len() {
366            if let Ok(val) = series.get(idx) {
367                if let Ok(v) = val.try_extract::<f64>() {
368                    if v < lower_bound || v > upper_bound {
369                        outlier_indices.push(idx);
370                    }
371                }
372            }
373        }
374
375        Ok(outlier_indices)
376    }
377}
378
379impl Default for DataProfiler {
380    fn default() -> Self {
381        Self::new()
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use polars::df;

    #[test]
    fn test_profile_basic() {
        let df = df! {
            "id" => &[1, 2, 3, 4, 5],
            "name" => &["Alice", "Bob", "Charlie", "David", "Eve"],
            "age" => &[Some(25), Some(30), None, Some(35), Some(40)],
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let report = profiler.profile(&df).unwrap();

        assert_eq!(report.total_rows, 5);
        assert_eq!(report.total_columns, 3);
        assert_eq!(report.column_names.len(), 3);

        // Check the missing-value statistics.
        assert!(report.missing_values.contains_key("age"));
        let age_missing = &report.missing_values["age"];
        assert_eq!(age_missing.count, 1);
        assert_eq!(age_missing.percentage, 20.0);
    }

    #[test]
    fn test_column_stats() {
        let df = df! {
            "id" => &[1, 2, 3, 4, 5],
            "category" => &["A", "B", "A", "C", "B"],
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let stats = profiler.column_stats(&df, "category").unwrap();

        assert_eq!(stats.name, "category");
        assert_eq!(stats.unique_count, 3); // A, B, C
        assert_eq!(stats.null_count, 0);
    }

    #[test]
    fn test_count_duplicates() {
        let df = df! {
            "id" => &[1, 2, 3, 2, 1],
            "value" => &[10, 20, 30, 20, 10],
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let duplicate_count = profiler.count_duplicates(&df).unwrap();

        // is_duplicated() flags ALL duplicated rows, including the first occurrence,
        // so [1,10], [2,20], [2,20], [1,10] are all flagged: 4 rows in total.
        assert_eq!(duplicate_count, 4);
    }

    #[test]
    fn test_detect_outliers() {
        let df = df! {
            "values" => &[1.0, 2.0, 3.0, 4.0, 5.0, 100.0], // 100.0 is the outlier
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let outliers = profiler.detect_outliers(&df, "values").unwrap();

        assert!(!outliers.is_empty());
        assert!(outliers.contains(&5)); // index 5 is the outlier
    }

    #[test]
    fn test_no_missing_values() {
        let df = df! {
            "id" => &[1, 2, 3],
            "name" => &["A", "B", "C"],
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let report = profiler.profile(&df).unwrap();

        assert!(report.missing_values.is_empty());
    }

    #[test]
    fn test_suggestions_generation() {
        let df = df! {
            "id" => &[1, 2, 3, 4, 5],
            "age" => &[Some(25), None, None, Some(35), Some(40)], // 40% missing
            "score" => &[1.0, 2.0, 3.0, 4.0, 100.0], // contains an outlier
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let report = profiler.profile(&df).unwrap();

        // Suggestions should be generated.
        assert!(!report.suggestions.is_empty());

        // There should be a missing-value suggestion.
        let missing_suggestions: Vec<_> = report
            .suggestions
            .iter()
            .filter(|s| s.suggestion_type == SuggestionType::MissingValues)
            .collect();
        assert!(!missing_suggestions.is_empty());

        // There should be an outlier suggestion.
        let outlier_suggestions: Vec<_> = report
            .suggestions
            .iter()
            .filter(|s| s.suggestion_type == SuggestionType::Outliers)
            .collect();
        assert!(!outlier_suggestions.is_empty());
    }

    #[test]
    fn test_batch_outlier_detection() {
        let df = df! {
            "col1" => &[1.0, 2.0, 3.0, 4.0, 5.0, 100.0],
            "col2" => &[10.0, 20.0, 30.0, 40.0, 50.0, 500.0],
            "col3" => &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], // no outliers
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let report = profiler.profile(&df).unwrap();

        // col1 and col2 should have outliers.
        assert!(report.outliers.contains_key("col1"));
        assert!(report.outliers.contains_key("col2"));

        // col3 should not have outliers.
        assert!(!report.outliers.contains_key("col3"));

        // Check the outlier info.
        let col1_outliers = &report.outliers["col1"];
        assert!(col1_outliers.count > 0);
        assert!(col1_outliers.percentage > 0.0);
    }

    #[test]
    fn test_duplicate_suggestions() {
        let df = df! {
            "id" => &[1, 2, 3, 2, 1], // has duplicates
            "value" => &[10, 20, 30, 20, 10],
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let report = profiler.profile(&df).unwrap();

        // There should be a duplicate-data suggestion.
        let dup_suggestions: Vec<_> = report
            .suggestions
            .iter()
            .filter(|s| s.suggestion_type == SuggestionType::Duplicates)
            .collect();
        assert!(!dup_suggestions.is_empty());
    }

    #[test]
    fn test_severity_levels() {
        let df = df! {
            "critical" => &[Some(1), None, None, None, None], // 80% missing -> Error
            "warning" => &[Some(1), Some(2), None, Some(4), Some(5)], // 20% missing -> Warning/Info
        }
        .unwrap();

        let profiler = DataProfiler::new();
        let report = profiler.profile(&df).unwrap();

        // Suggestions of differing severities should be present.
        let has_error = report.suggestions.iter().any(|s| s.severity == Severity::Error);
        let has_warning_or_info = report
            .suggestions
            .iter()
            .any(|s| s.severity == Severity::Warning || s.severity == Severity::Info);

        assert!(has_error || has_warning_or_info);
    }
}