Skip to main content

xore_process/
parser.rs

1//! 数据解析器 - 基于 Polars 的高性能数据加载
2
3use anyhow::{Context, Result};
4use polars::prelude::*;
5use std::fs::File;
6use std::io::Cursor;
7use std::path::Path;
8
9/// 数据解析器配置
10#[derive(Debug, Clone)]
11pub struct ParserConfig {
12    /// 是否使用内存映射(适用于大文件)
13    pub use_mmap: bool,
14    /// 内存映射阈值(字节),超过此大小使用 mmap
15    pub mmap_threshold: u64,
16    /// CSV 分隔符
17    pub csv_delimiter: u8,
18    /// 是否自动推断 Schema
19    pub infer_schema: bool,
20    /// Schema 推断时扫描的行数
21    pub infer_schema_length: Option<usize>,
22    /// 是否跳过空行
23    pub skip_rows: usize,
24    /// 是否有表头
25    pub has_header: bool,
26}
27
28impl Default for ParserConfig {
29    fn default() -> Self {
30        Self {
31            use_mmap: true,
32            mmap_threshold: 1024 * 1024, // 1MB
33            csv_delimiter: b',',
34            infer_schema: true,
35            infer_schema_length: Some(1000),
36            skip_rows: 0,
37            has_header: true,
38        }
39    }
40}
41
42/// 数据解析器
43pub struct DataParser {
44    config: ParserConfig,
45}
46
47impl DataParser {
48    /// 创建新的解析器
49    pub fn new() -> Self {
50        Self { config: ParserConfig::default() }
51    }
52
53    /// 使用自定义配置创建解析器
54    pub fn with_config(config: ParserConfig) -> Self {
55        Self { config }
56    }
57
58    /// 读取 CSV 文件并返回 LazyFrame
59    pub fn read_csv_lazy(&self, path: &Path) -> Result<LazyFrame> {
60        tracing::debug!("读取 CSV 文件: {:?}", path);
61
62        // 检查文件是否存在
63        if !path.exists() {
64            return Err(anyhow::anyhow!(
65                "文件不存在: {}\n💡 提示: 请检查路径是否正确,或使用 'ls' 命令确认文件存在",
66                path.display()
67            ));
68        }
69
70        // 获取文件大小
71        let file_size = std::fs::metadata(path)
72            .with_context(|| format!("无法获取文件元数据: {}", path.display()))?
73            .len();
74        tracing::debug!("文件大小: {} 字节", file_size);
75
76        // 根据文件大小决定是否使用 mmap
77        let use_mmap = self.config.use_mmap && file_size > self.config.mmap_threshold;
78
79        if use_mmap {
80            tracing::debug!("使用内存映射读取大文件");
81            self.read_csv_with_mmap(path)
82        } else {
83            tracing::debug!("使用标准文件读取");
84            self.read_csv_standard(path)
85        }
86    }
87
88    /// 使用标准方式读取 CSV
89    fn read_csv_standard(&self, path: &Path) -> Result<LazyFrame> {
90        let df = LazyCsvReader::new(path)
91            .with_has_header(self.config.has_header)
92            .with_separator(self.config.csv_delimiter)
93            .with_skip_rows(self.config.skip_rows)
94            .with_infer_schema_length(self.config.infer_schema_length)
95            .finish()
96            .map_err(|e| anyhow::anyhow!(
97                "读取 CSV 文件失败: {}\n  --> 文件: {}\n💡 提示: 请确认文件是有效的 CSV 格式,分隔符是否正确",
98                e, path.display()
99            ))?;
100
101        Ok(df)
102    }
103
104    /// 使用内存映射读取 CSV
105    fn read_csv_with_mmap(&self, path: &Path) -> Result<LazyFrame> {
106        use memmap2::Mmap;
107
108        // 打开文件
109        let file = File::open(path).with_context(|| format!("无法打开文件: {}", path.display()))?;
110
111        // 创建内存映射
112        let mmap = unsafe {
113            Mmap::map(&file).with_context(|| format!("内存映射失败: {}", path.display()))?
114        };
115
116        // 使用 Cursor 包装 mmap 数据
117        let cursor = Cursor::new(&mmap[..]);
118
119        // 读取 CSV
120        let df = CsvReadOptions::default()
121            .with_has_header(self.config.has_header)
122            .with_infer_schema_length(self.config.infer_schema_length)
123            .into_reader_with_file_handle(cursor)
124            .finish()
125            .map_err(|e| {
126                anyhow::anyhow!(
127                    "读取 CSV 文件失败 (mmap 模式): {}\n  --> 文件: {}",
128                    e,
129                    path.display()
130                )
131            })?
132            .lazy();
133
134        Ok(df)
135    }
136
137    /// 读取 Parquet 文件并返回 LazyFrame
138    pub fn read_parquet_lazy(&self, path: &Path) -> Result<LazyFrame> {
139        tracing::debug!("读取 Parquet 文件: {:?}", path);
140
141        // 检查文件是否存在
142        if !path.exists() {
143            return Err(anyhow::anyhow!(
144                "文件不存在: {}\n💡 提示: 请检查路径是否正确,或使用 'ls' 命令确认文件存在",
145                path.display()
146            ));
147        }
148
149        let args = ScanArgsParquet::default();
150        let df = LazyFrame::scan_parquet(path, args)
151            .map_err(|e| anyhow::anyhow!(
152                "读取 Parquet 文件失败: {}\n  --> 文件: {}\n💡 提示: 请确认文件是有效的 Parquet 格式",
153                e, path.display()
154            ))?;
155
156        Ok(df)
157    }
158
159    /// 自动识别格式并读取
160    pub fn read_lazy(&self, path: &Path) -> Result<LazyFrame> {
161        let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("").to_lowercase();
162
163        match extension.as_str() {
164            "csv" => self.read_csv_lazy(path),
165            "parquet" => self.read_parquet_lazy(path),
166            _ => Err(anyhow::anyhow!(
167                "不支持的文件格式: '{}'\n  --> 文件: {}\n💡 提示: 支持的格式为 csv 和 parquet",
168                extension,
169                path.display()
170            )),
171        }
172    }
173
174    /// 读取并收集为 DataFrame(用于小数据集或需要立即执行的场景)
175    pub fn read_csv(&self, path: &Path) -> Result<DataFrame> {
176        let lf = self.read_csv_lazy(path)?;
177        lf.collect().with_context(|| format!("执行 CSV 查询失败: {}", path.display()))
178    }
179
180    /// 读取 Parquet 并收集为 DataFrame
181    pub fn read_parquet(&self, path: &Path) -> Result<DataFrame> {
182        let lf = self.read_parquet_lazy(path)?;
183        lf.collect().with_context(|| format!("执行 Parquet 查询失败: {}", path.display()))
184    }
185
186    /// 自动识别格式并读取为 DataFrame
187    pub fn read(&self, path: &Path) -> Result<DataFrame> {
188        let lf = self.read_lazy(path)?;
189        lf.collect().with_context(|| format!("执行查询失败: {}", path.display()))
190    }
191}
192
193impl Default for DataParser {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199#[cfg(test)]
200mod tests {
201    use super::*;
202    use std::io::Write;
203    use tempfile::NamedTempFile;
204
205    fn create_test_csv() -> NamedTempFile {
206        let mut file = NamedTempFile::new().unwrap();
207        writeln!(file, "id,name,age,city").unwrap();
208        writeln!(file, "1,Alice,28,Beijing").unwrap();
209        writeln!(file, "2,Bob,32,Shanghai").unwrap();
210        writeln!(file, "3,Charlie,25,Guangzhou").unwrap();
211        file.flush().unwrap();
212        file
213    }
214
215    #[test]
216    fn test_read_csv_lazy() {
217        let file = create_test_csv();
218        let parser = DataParser::new();
219
220        let lf = parser.read_csv_lazy(file.path()).unwrap();
221        let df = lf.collect().unwrap();
222
223        assert_eq!(df.height(), 3);
224        assert_eq!(df.width(), 4);
225        assert_eq!(df.get_column_names(), vec!["id", "name", "age", "city"]);
226    }
227
228    #[test]
229    fn test_read_csv() {
230        let file = create_test_csv();
231        let parser = DataParser::new();
232
233        let df = parser.read_csv(file.path()).unwrap();
234
235        assert_eq!(df.height(), 3);
236        assert_eq!(df.width(), 4);
237    }
238
239    #[test]
240    fn test_read_auto_detect() {
241        // 创建带 .csv 扩展名的临时文件
242        let temp_dir = tempfile::tempdir().unwrap();
243        let file_path = temp_dir.path().join("test.csv");
244
245        std::fs::write(
246            &file_path,
247            "id,name,age,city\n1,Alice,28,Beijing\n2,Bob,32,Shanghai\n3,Charlie,25,Guangzhou\n",
248        )
249        .unwrap();
250
251        let parser = DataParser::new();
252        let df = parser.read(&file_path).unwrap();
253
254        assert_eq!(df.height(), 3);
255        assert_eq!(df.width(), 4);
256    }
257
258    #[test]
259    fn test_file_not_found() {
260        let parser = DataParser::new();
261        let result = parser.read_csv_lazy(Path::new("/nonexistent/file.csv"));
262
263        assert!(result.is_err());
264    }
265
266    #[test]
267    fn test_unsupported_format() {
268        let parser = DataParser::new();
269        let result = parser.read_lazy(Path::new("test.txt"));
270
271        assert!(result.is_err());
272        if let Err(e) = result {
273            assert!(e.to_string().contains("不支持的文件格式"));
274        }
275    }
276
277    #[test]
278    fn test_custom_config() {
279        let config = ParserConfig { csv_delimiter: b';', has_header: false, ..Default::default() };
280
281        let parser = DataParser::with_config(config);
282        assert_eq!(parser.config.csv_delimiter, b';');
283        assert!(!parser.config.has_header);
284    }
285
286    #[test]
287    fn test_mmap_threshold() {
288        let file = create_test_csv();
289        let file_size = std::fs::metadata(file.path()).unwrap().len();
290
291        // 设置阈值低于文件大小,应该使用 mmap
292        let config =
293            ParserConfig { use_mmap: true, mmap_threshold: file_size - 1, ..Default::default() };
294
295        let parser = DataParser::with_config(config);
296        let df = parser.read_csv(file.path()).unwrap();
297
298        assert_eq!(df.height(), 3);
299    }
300}