1use anyhow::{Context, Result};
4use polars::prelude::*;
5use std::fs::File;
6use std::io::Cursor;
7use std::path::Path;
8
/// Settings that control how a `DataParser` reads CSV and Parquet files.
///
/// Construct with `ParserConfig::default()` and override individual fields with
/// struct-update syntax, e.g. `ParserConfig { csv_delimiter: b';', ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParserConfig {
    /// Allow memory-mapped reading for CSV files whose size exceeds `mmap_threshold`.
    pub use_mmap: bool,
    /// File size in bytes that a CSV file must strictly exceed before memory mapping is used.
    pub mmap_threshold: u64,
    /// Field separator byte for CSV input (e.g. `b','` or `b';'`).
    pub csv_delimiter: u8,
    /// Whether column types should be inferred from the data.
    pub infer_schema: bool,
    /// Maximum number of rows scanned when inferring column types; `None` scans the whole file.
    pub infer_schema_length: Option<usize>,
    /// Number of leading rows to skip before parsing begins.
    pub skip_rows: usize,
    /// Whether the CSV input contains a header row with column names.
    pub has_header: bool,
}
27
28impl Default for ParserConfig {
29 fn default() -> Self {
30 Self {
31 use_mmap: true,
32 mmap_threshold: 1024 * 1024, csv_delimiter: b',',
34 infer_schema: true,
35 infer_schema_length: Some(1000),
36 skip_rows: 0,
37 has_header: true,
38 }
39 }
40}
41
42pub struct DataParser {
44 config: ParserConfig,
45}
46
47impl DataParser {
48 pub fn new() -> Self {
50 Self { config: ParserConfig::default() }
51 }
52
53 pub fn with_config(config: ParserConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn read_csv_lazy(&self, path: &Path) -> Result<LazyFrame> {
60 tracing::debug!("读取 CSV 文件: {:?}", path);
61
62 if !path.exists() {
64 return Err(anyhow::anyhow!(
65 "文件不存在: {}\n💡 提示: 请检查路径是否正确,或使用 'ls' 命令确认文件存在",
66 path.display()
67 ));
68 }
69
70 let file_size = std::fs::metadata(path)
72 .with_context(|| format!("无法获取文件元数据: {}", path.display()))?
73 .len();
74 tracing::debug!("文件大小: {} 字节", file_size);
75
76 let use_mmap = self.config.use_mmap && file_size > self.config.mmap_threshold;
78
79 if use_mmap {
80 tracing::debug!("使用内存映射读取大文件");
81 self.read_csv_with_mmap(path)
82 } else {
83 tracing::debug!("使用标准文件读取");
84 self.read_csv_standard(path)
85 }
86 }
87
88 fn read_csv_standard(&self, path: &Path) -> Result<LazyFrame> {
90 let df = LazyCsvReader::new(path)
91 .with_has_header(self.config.has_header)
92 .with_separator(self.config.csv_delimiter)
93 .with_skip_rows(self.config.skip_rows)
94 .with_infer_schema_length(self.config.infer_schema_length)
95 .finish()
96 .map_err(|e| anyhow::anyhow!(
97 "读取 CSV 文件失败: {}\n --> 文件: {}\n💡 提示: 请确认文件是有效的 CSV 格式,分隔符是否正确",
98 e, path.display()
99 ))?;
100
101 Ok(df)
102 }
103
104 fn read_csv_with_mmap(&self, path: &Path) -> Result<LazyFrame> {
106 use memmap2::Mmap;
107
108 let file = File::open(path).with_context(|| format!("无法打开文件: {}", path.display()))?;
110
111 let mmap = unsafe {
113 Mmap::map(&file).with_context(|| format!("内存映射失败: {}", path.display()))?
114 };
115
116 let cursor = Cursor::new(&mmap[..]);
118
119 let df = CsvReadOptions::default()
121 .with_has_header(self.config.has_header)
122 .with_infer_schema_length(self.config.infer_schema_length)
123 .into_reader_with_file_handle(cursor)
124 .finish()
125 .map_err(|e| {
126 anyhow::anyhow!(
127 "读取 CSV 文件失败 (mmap 模式): {}\n --> 文件: {}",
128 e,
129 path.display()
130 )
131 })?
132 .lazy();
133
134 Ok(df)
135 }
136
137 pub fn read_parquet_lazy(&self, path: &Path) -> Result<LazyFrame> {
139 tracing::debug!("读取 Parquet 文件: {:?}", path);
140
141 if !path.exists() {
143 return Err(anyhow::anyhow!(
144 "文件不存在: {}\n💡 提示: 请检查路径是否正确,或使用 'ls' 命令确认文件存在",
145 path.display()
146 ));
147 }
148
149 let args = ScanArgsParquet::default();
150 let df = LazyFrame::scan_parquet(path, args)
151 .map_err(|e| anyhow::anyhow!(
152 "读取 Parquet 文件失败: {}\n --> 文件: {}\n💡 提示: 请确认文件是有效的 Parquet 格式",
153 e, path.display()
154 ))?;
155
156 Ok(df)
157 }
158
159 pub fn read_lazy(&self, path: &Path) -> Result<LazyFrame> {
161 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("").to_lowercase();
162
163 match extension.as_str() {
164 "csv" => self.read_csv_lazy(path),
165 "parquet" => self.read_parquet_lazy(path),
166 _ => Err(anyhow::anyhow!(
167 "不支持的文件格式: '{}'\n --> 文件: {}\n💡 提示: 支持的格式为 csv 和 parquet",
168 extension,
169 path.display()
170 )),
171 }
172 }
173
174 pub fn read_csv(&self, path: &Path) -> Result<DataFrame> {
176 let lf = self.read_csv_lazy(path)?;
177 lf.collect().with_context(|| format!("执行 CSV 查询失败: {}", path.display()))
178 }
179
180 pub fn read_parquet(&self, path: &Path) -> Result<DataFrame> {
182 let lf = self.read_parquet_lazy(path)?;
183 lf.collect().with_context(|| format!("执行 Parquet 查询失败: {}", path.display()))
184 }
185
186 pub fn read(&self, path: &Path) -> Result<DataFrame> {
188 let lf = self.read_lazy(path)?;
189 lf.collect().with_context(|| format!("执行查询失败: {}", path.display()))
190 }
191}
192
193impl Default for DataParser {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Header plus three data rows shared by the CSV fixtures below.
    const SAMPLE_ROWS: [&str; 4] = [
        "id,name,age,city",
        "1,Alice,28,Beijing",
        "2,Bob,32,Shanghai",
        "3,Charlie,25,Guangzhou",
    ];

    /// Writes the sample CSV (one row per line) into a fresh temporary file and flushes it.
    fn create_test_csv() -> NamedTempFile {
        let mut tmp = NamedTempFile::new().unwrap();
        for row in SAMPLE_ROWS.iter() {
            writeln!(tmp, "{}", row).unwrap();
        }
        tmp.flush().unwrap();
        tmp
    }

    #[test]
    fn test_read_csv_lazy() {
        let fixture = create_test_csv();

        let frame = DataParser::new().read_csv_lazy(fixture.path()).unwrap().collect().unwrap();

        assert_eq!(frame.height(), 3);
        assert_eq!(frame.width(), 4);
        assert_eq!(frame.get_column_names(), vec!["id", "name", "age", "city"]);
    }

    #[test]
    fn test_read_csv() {
        let fixture = create_test_csv();

        let frame = DataParser::new().read_csv(fixture.path()).unwrap();

        assert_eq!(frame.height(), 3);
        assert_eq!(frame.width(), 4);
    }

    #[test]
    fn test_read_auto_detect() {
        // A real `.csv` extension is needed so `read` can dispatch on it.
        let dir = tempfile::tempdir().unwrap();
        let csv_path = dir.path().join("test.csv");
        let contents = SAMPLE_ROWS.join("\n") + "\n";
        std::fs::write(&csv_path, contents).unwrap();

        let frame = DataParser::new().read(&csv_path).unwrap();

        assert_eq!(frame.height(), 3);
        assert_eq!(frame.width(), 4);
    }

    #[test]
    fn test_file_not_found() {
        let outcome = DataParser::new().read_csv_lazy(Path::new("/nonexistent/file.csv"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_unsupported_format() {
        match DataParser::new().read_lazy(Path::new("test.txt")) {
            Ok(_) => panic!("expected an unsupported-format error"),
            Err(err) => assert!(err.to_string().contains("不支持的文件格式")),
        }
    }

    #[test]
    fn test_custom_config() {
        let parser = DataParser::with_config(ParserConfig {
            csv_delimiter: b';',
            has_header: false,
            ..Default::default()
        });

        assert_eq!(parser.config.csv_delimiter, b';');
        assert!(!parser.config.has_header);
    }

    #[test]
    fn test_mmap_threshold() {
        let fixture = create_test_csv();
        let size = std::fs::metadata(fixture.path()).unwrap().len();

        // A threshold just below the file size forces the memory-mapped code path.
        let parser = DataParser::with_config(ParserConfig {
            use_mmap: true,
            mmap_threshold: size - 1,
            ..Default::default()
        });

        assert_eq!(parser.read_csv(fixture.path()).unwrap().height(), 3);
    }
}