1use anyhow::Result;
4use polars::prelude::*;
5use std::collections::HashMap;
6
/// Aggregated data-quality findings for one `DataFrame`,
/// produced by `DataProfiler::profile`.
#[derive(Debug, Clone)]
pub struct QualityReport {
 /// Number of rows in the profiled frame.
 pub total_rows: usize,
 /// Number of columns in the profiled frame.
 pub total_columns: usize,
 /// Column names in frame order.
 pub column_names: Vec<String>,
 /// Missing-value stats keyed by column name; only columns with at
 /// least one null appear here.
 pub missing_values: HashMap<String, MissingStats>,
 /// Number of duplicated rows (every occurrence of a repeated row counts).
 pub duplicate_rows: usize,
 /// `Debug`-formatted polars dtype per column name.
 pub column_types: HashMap<String, String>,
 /// Human-readable suggestions, sorted by descending severity.
 pub suggestions: Vec<Suggestion>,
 /// IQR outlier info per numeric column that had at least one outlier.
 pub outliers: HashMap<String, OutlierInfo>,
}
27
/// One actionable data-quality suggestion emitted by the profiler.
#[derive(Debug, Clone)]
pub struct Suggestion {
 /// Which quality dimension this suggestion concerns.
 pub suggestion_type: SuggestionType,
 /// Human-readable message (Chinese) describing the finding and advice.
 pub message: String,
 /// Affected column, or `None` for frame-wide findings (e.g. duplicates).
 pub column: Option<String>,
 /// How serious the finding is.
 pub severity: Severity,
}
40
/// Category of a [`Suggestion`].
#[derive(Debug, Clone, PartialEq)]
pub enum SuggestionType {
 /// A column contains null values.
 MissingValues,
 /// The frame contains duplicated rows.
 Duplicates,
 /// A numeric column contains IQR outliers.
 Outliers,
 /// Data-type concern. NOTE(review): not produced by `DataProfiler::profile`
 /// in this module — presumably emitted elsewhere; confirm before removing.
 DataType,
 /// Distribution concern. NOTE(review): likewise not produced here.
 Distribution,
}
55
/// Severity level of a suggestion.
///
/// Variant order matters: the derived `Ord` makes
/// `Info < Warning < Error`, which `generate_suggestions` relies on when
/// sorting suggestions by descending severity.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
 Info,
 Warning,
 Error,
}
66
/// Outlier summary for one numeric column.
#[derive(Debug, Clone)]
pub struct OutlierInfo {
 /// Total number of outlying values found in the column.
 pub count: usize,
 /// `count` as a percentage of the frame's row count.
 pub percentage: f64,
 /// Sample of outlier row indices — at most the first 10, even when
 /// `count` is larger (the batch scan truncates to bound report size).
 pub indices: Vec<usize>,
}
77
/// Missing-value summary for one column.
#[derive(Debug, Clone)]
pub struct MissingStats {
 /// Number of null entries in the column.
 pub count: usize,
 /// `count` as a percentage of the frame's row count.
 pub percentage: f64,
}
86
/// Per-column statistics returned by `DataProfiler::column_stats`.
#[derive(Debug, Clone)]
pub struct ColumnStats {
 /// Column name.
 pub name: String,
 /// `Debug`-formatted polars dtype.
 pub dtype: String,
 /// Number of distinct values in the column.
 pub unique_count: usize,
 /// Number of null entries.
 pub null_count: usize,
 /// `null_count` as a percentage of the frame's row count.
 pub null_percentage: f64,
}
101
102pub struct DataProfiler;
104
105impl DataProfiler {
106 pub fn new() -> Self {
108 Self
109 }
110
111 pub fn profile(&self, df: &DataFrame) -> Result<QualityReport> {
113 let total_rows = df.height();
114 let total_columns = df.width();
115 let column_names: Vec<String> =
116 df.get_column_names().iter().map(|s| s.to_string()).collect();
117
118 let mut missing_values = HashMap::new();
120 let mut column_types = HashMap::new();
121
122 for col_name in &column_names {
123 let series = df.column(col_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
124
125 let null_count = series.null_count();
127 let percentage =
128 if total_rows > 0 { (null_count as f64 / total_rows as f64) * 100.0 } else { 0.0 };
129
130 if null_count > 0 {
131 missing_values
132 .insert(col_name.clone(), MissingStats { count: null_count, percentage });
133 }
134
135 column_types.insert(col_name.clone(), format!("{:?}", series.dtype()));
137 }
138
139 let duplicate_rows = self.count_duplicates(df)?;
141
142 let outliers = self.detect_outliers_batch(df)?;
144
145 let suggestions =
147 self.generate_suggestions(total_rows, &missing_values, duplicate_rows, &outliers);
148
149 Ok(QualityReport {
150 total_rows,
151 total_columns,
152 column_names,
153 missing_values,
154 duplicate_rows,
155 column_types,
156 suggestions,
157 outliers,
158 })
159 }
160
161 fn detect_outliers_batch(&self, df: &DataFrame) -> Result<HashMap<String, OutlierInfo>> {
163 let mut outliers = HashMap::new();
164
165 for col_name in df.get_column_names() {
166 let column = df.column(col_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
167 let series = column.as_materialized_series();
168
169 if !series.dtype().is_numeric() {
171 continue;
172 }
173
174 match self.detect_outliers(df, col_name) {
176 Ok(indices) if !indices.is_empty() => {
177 let count = indices.len();
178 let percentage = if df.height() > 0 {
179 (count as f64 / df.height() as f64) * 100.0
180 } else {
181 0.0
182 };
183
184 let indices_sample = indices.into_iter().take(10).collect();
186
187 outliers.insert(
188 col_name.to_string(),
189 OutlierInfo { count, percentage, indices: indices_sample },
190 );
191 }
192 _ => {}
193 }
194 }
195
196 Ok(outliers)
197 }
198
199 fn generate_suggestions(
201 &self,
202 total_rows: usize,
203 missing_values: &HashMap<String, MissingStats>,
204 duplicate_rows: usize,
205 outliers: &HashMap<String, OutlierInfo>,
206 ) -> Vec<Suggestion> {
207 let mut suggestions = Vec::new();
208
209 for (col_name, stats) in missing_values {
211 let severity = if stats.percentage > 50.0 {
212 Severity::Error
213 } else if stats.percentage > 10.0 {
214 Severity::Warning
215 } else {
216 Severity::Info
217 };
218
219 let message = if stats.percentage > 50.0 {
220 format!(
221 "列 '{}' 缺失值过多 ({:.1}%),建议检查数据源或考虑删除该列",
222 col_name, stats.percentage
223 )
224 } else if stats.percentage > 10.0 {
225 format!(
226 "列 '{}' 存在 {:.1}% 缺失值,建议使用插值或填充方法处理",
227 col_name, stats.percentage
228 )
229 } else {
230 format!("列 '{}' 存在少量缺失值 ({:.1}%)", col_name, stats.percentage)
231 };
232
233 suggestions.push(Suggestion {
234 suggestion_type: SuggestionType::MissingValues,
235 message,
236 column: Some(col_name.clone()),
237 severity,
238 });
239 }
240
241 if duplicate_rows > 0 {
243 let dup_percentage = if total_rows > 0 {
244 (duplicate_rows as f64 / total_rows as f64) * 100.0
245 } else {
246 0.0
247 };
248
249 let severity = if dup_percentage > 20.0 { Severity::Warning } else { Severity::Info };
250
251 let message = if dup_percentage > 20.0 {
252 format!(
253 "检测到 {} 行重复数据 ({:.1}%),强烈建议去重以提高数据质量",
254 duplicate_rows, dup_percentage
255 )
256 } else {
257 format!("检测到 {} 行重复数据,建议检查是否需要去重", duplicate_rows)
258 };
259
260 suggestions.push(Suggestion {
261 suggestion_type: SuggestionType::Duplicates,
262 message,
263 column: None,
264 severity,
265 });
266 }
267
268 for (col_name, info) in outliers {
270 let severity = if info.percentage > 5.0 { Severity::Warning } else { Severity::Info };
271
272 let message = if info.percentage > 5.0 {
273 format!(
274 "列 '{}' 检测到 {} 个离群值 ({:.1}%),建议检查数据异常",
275 col_name, info.count, info.percentage
276 )
277 } else {
278 format!("列 '{}' 检测到 {} 个离群值,可能需要进一步分析", col_name, info.count)
279 };
280
281 suggestions.push(Suggestion {
282 suggestion_type: SuggestionType::Outliers,
283 message,
284 column: Some(col_name.clone()),
285 severity,
286 });
287 }
288
289 suggestions.sort_by(|a, b| b.severity.cmp(&a.severity));
291
292 suggestions
293 }
294
295 fn count_duplicates(&self, df: &DataFrame) -> Result<usize> {
297 let mask = df.is_duplicated().map_err(|e| anyhow::anyhow!("检测重复行失败: {}", e))?;
299
300 let duplicate_count = mask.sum().ok_or_else(|| anyhow::anyhow!("计算重复行数失败"))?;
301
302 Ok(duplicate_count as usize)
303 }
304
305 pub fn column_stats(&self, df: &DataFrame, column_name: &str) -> Result<ColumnStats> {
307 let series = df.column(column_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
308
309 let total_rows = df.height();
310 let null_count = series.null_count();
311 let null_percentage =
312 if total_rows > 0 { (null_count as f64 / total_rows as f64) * 100.0 } else { 0.0 };
313
314 let unique_count =
315 series.n_unique().map_err(|e| anyhow::anyhow!("计算唯一值失败: {}", e))?;
316
317 Ok(ColumnStats {
318 name: column_name.to_string(),
319 dtype: format!("{:?}", series.dtype()),
320 unique_count,
321 null_count,
322 null_percentage,
323 })
324 }
325
326 pub fn detect_outliers(&self, df: &DataFrame, column_name: &str) -> Result<Vec<usize>> {
329 let column = df.column(column_name).map_err(|e| anyhow::anyhow!("获取列失败: {}", e))?;
330
331 let series = column.as_materialized_series();
333
334 if !series.dtype().is_numeric() {
336 return Err(anyhow::anyhow!("列 {} 不是数值类型", column_name));
337 }
338
339 let sorted =
341 series.sort(Default::default()).map_err(|e| anyhow::anyhow!("排序失败: {}", e))?;
342
343 let len = sorted.len();
344 let q1_idx = len / 4;
345 let q3_idx = (3 * len) / 4;
346
347 let q1_val = sorted
348 .get(q1_idx)
349 .map_err(|e| anyhow::anyhow!("获取 Q1 失败: {}", e))?
350 .try_extract::<f64>()
351 .map_err(|e| anyhow::anyhow!("提取 Q1 值失败: {}", e))?;
352
353 let q3_val = sorted
354 .get(q3_idx)
355 .map_err(|e| anyhow::anyhow!("获取 Q3 失败: {}", e))?
356 .try_extract::<f64>()
357 .map_err(|e| anyhow::anyhow!("提取 Q3 值失败: {}", e))?;
358
359 let iqr = q3_val - q1_val;
360 let lower_bound = q1_val - 1.5 * iqr;
361 let upper_bound = q3_val + 1.5 * iqr;
362
363 let mut outlier_indices = Vec::new();
365 for idx in 0..series.len() {
366 if let Ok(val) = series.get(idx) {
367 if let Ok(v) = val.try_extract::<f64>() {
368 if v < lower_bound || v > upper_bound {
369 outlier_indices.push(idx);
370 }
371 }
372 }
373 }
374
375 Ok(outlier_indices)
376 }
377}
378
379impl Default for DataProfiler {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use polars::df;

    /// End-to-end profile over a small frame with a single null.
    #[test]
    fn test_profile_basic() {
        let frame = df! {
            "id" => &[1, 2, 3, 4, 5],
            "name" => &["Alice", "Bob", "Charlie", "David", "Eve"],
            "age" => &[Some(25), Some(30), None, Some(35), Some(40)],
        }
        .unwrap();

        let report = DataProfiler::new().profile(&frame).unwrap();

        assert_eq!(report.total_rows, 5);
        assert_eq!(report.total_columns, 3);
        assert_eq!(report.column_names.len(), 3);

        // "age" has exactly one null out of five rows (20%).
        assert!(report.missing_values.contains_key("age"));
        let age_missing = &report.missing_values["age"];
        assert_eq!(age_missing.count, 1);
        assert_eq!(age_missing.percentage, 20.0);
    }

    /// Per-column statistics: distinct values and null counts.
    #[test]
    fn test_column_stats() {
        let frame = df! {
            "id" => &[1, 2, 3, 4, 5],
            "category" => &["A", "B", "A", "C", "B"],
        }
        .unwrap();

        let stats = DataProfiler::new().column_stats(&frame, "category").unwrap();

        assert_eq!(stats.name, "category");
        // Three distinct categories, no nulls.
        assert_eq!(stats.unique_count, 3);
        assert_eq!(stats.null_count, 0);
    }

    /// Every occurrence of a repeated row counts: two duplicated pairs => 4.
    #[test]
    fn test_count_duplicates() {
        let frame = df! {
            "id" => &[1, 2, 3, 2, 1],
            "value" => &[10, 20, 30, 20, 10],
        }
        .unwrap();

        let duplicate_count = DataProfiler::new().count_duplicates(&frame).unwrap();
        assert_eq!(duplicate_count, 4);
    }

    /// The single extreme value (100.0 at row 5) must be flagged by IQR.
    #[test]
    fn test_detect_outliers() {
        let frame = df! {
            "values" => &[1.0, 2.0, 3.0, 4.0, 5.0, 100.0],
        }
        .unwrap();

        let outliers = DataProfiler::new().detect_outliers(&frame, "values").unwrap();

        assert!(!outliers.is_empty());
        assert!(outliers.contains(&5));
    }

    /// A frame without nulls produces an empty missing-values map.
    #[test]
    fn test_no_missing_values() {
        let frame = df! {
            "id" => &[1, 2, 3],
            "name" => &["A", "B", "C"],
        }
        .unwrap();

        let report = DataProfiler::new().profile(&frame).unwrap();
        assert!(report.missing_values.is_empty());
    }

    /// Both missing-value and outlier findings must surface as suggestions.
    #[test]
    fn test_suggestions_generation() {
        let frame = df! {
            "id" => &[1, 2, 3, 4, 5],
            "age" => &[Some(25), None, None, Some(35), Some(40)],
            "score" => &[1.0, 2.0, 3.0, 4.0, 100.0],
        }
        .unwrap();

        let report = DataProfiler::new().profile(&frame).unwrap();
        assert!(!report.suggestions.is_empty());

        let has_missing = report
            .suggestions
            .iter()
            .any(|s| s.suggestion_type == SuggestionType::MissingValues);
        assert!(has_missing);

        let has_outlier = report
            .suggestions
            .iter()
            .any(|s| s.suggestion_type == SuggestionType::Outliers);
        assert!(has_outlier);
    }

    /// Batch detection flags only the columns that really contain outliers.
    #[test]
    fn test_batch_outlier_detection() {
        let frame = df! {
            "col1" => &[1.0, 2.0, 3.0, 4.0, 5.0, 100.0],
            "col2" => &[10.0, 20.0, 30.0, 40.0, 50.0, 500.0],
            "col3" => &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        }
        .unwrap();

        let report = DataProfiler::new().profile(&frame).unwrap();

        assert!(report.outliers.contains_key("col1"));
        assert!(report.outliers.contains_key("col2"));
        // col3 is a clean ramp: no outliers, so no entry.
        assert!(!report.outliers.contains_key("col3"));

        let col1_outliers = &report.outliers["col1"];
        assert!(col1_outliers.count > 0);
        assert!(col1_outliers.percentage > 0.0);
    }

    /// Duplicated rows must produce at least one Duplicates suggestion.
    #[test]
    fn test_duplicate_suggestions() {
        let frame = df! {
            "id" => &[1, 2, 3, 2, 1],
            "value" => &[10, 20, 30, 20, 10],
        }
        .unwrap();

        let report = DataProfiler::new().profile(&frame).unwrap();

        let has_dup = report
            .suggestions
            .iter()
            .any(|s| s.suggestion_type == SuggestionType::Duplicates);
        assert!(has_dup);
    }

    /// 80% missing => Error severity; 20% missing => Warning/Info range.
    #[test]
    fn test_severity_levels() {
        let frame = df! {
            "critical" => &[Some(1), None, None, None, None],
            "warning" => &[Some(1), Some(2), None, Some(4), Some(5)],
        }
        .unwrap();

        let report = DataProfiler::new().profile(&frame).unwrap();

        let has_error = report.suggestions.iter().any(|s| s.severity == Severity::Error);
        let has_warning_or_info = report
            .suggestions
            .iter()
            .any(|s| s.severity == Severity::Warning || s.severity == Severity::Info);

        assert!(has_error || has_warning_or_info);
    }
}