Skip to main content

xore_process/
export.rs

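//! Data export module.
//!
//! Exports data to several file formats and offers a streaming-style entry point for large
//! inputs. Note that the streaming entry point currently collects the full result before writing.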
4
5use anyhow::{Context, Result};
6use polars::prelude::*;
7use std::fs::File;
8use std::io::Write;
9use std::path::Path;
10
// Note: IPC support in Polars 0.45 may require an additional crate feature.
// Until `IpcWriter` is available, Arrow export is temporarily disabled and falls back to Parquet.
13
/// File format produced by an export.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// Delimiter-separated text (CSV).
    Csv,
    /// JSON with one object per line.
    Json,
    /// Parquet columnar storage format.
    Parquet,
    /// Arrow IPC format.
    Arrow,
}

impl ExportFormat {
    /// Infers the format from a file extension given without the leading dot.
    ///
    /// Matching ignores ASCII case. The recognised extensions are `csv`, `json`, `jsonl`,
    /// `parquet`, `arrow` and `ipc`; any other input yields `None`.
    pub fn from_extension(ext: &str) -> Option<Self> {
        // Every recognised spelling paired with the format it selects.
        const KNOWN_EXTENSIONS: [(&str, ExportFormat); 6] = [
            ("csv", ExportFormat::Csv),
            ("json", ExportFormat::Json),
            ("jsonl", ExportFormat::Json),
            ("parquet", ExportFormat::Parquet),
            ("arrow", ExportFormat::Arrow),
            ("ipc", ExportFormat::Arrow),
        ];

        KNOWN_EXTENSIONS
            .iter()
            .find(|entry| ext.eq_ignore_ascii_case(entry.0))
            .map(|entry| entry.1)
    }

    /// Returns the canonical file extension (without a leading dot) for this format.
    pub fn extension(&self) -> &'static str {
        match *self {
            ExportFormat::Csv => "csv",
            ExportFormat::Json => "json",
            ExportFormat::Parquet => "parquet",
            ExportFormat::Arrow => "arrow",
        }
    }
}
49
/// Compression codec requested for an export.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// Write the data uncompressed.
    None,
    /// Gzip compression.
    Gzip,
    /// Zstd compression.
    Zstd,
}
60
61/// 导出配置
62#[derive(Debug, Clone)]
63pub struct ExportConfig {
64    /// 缓冲区大小(字节)
65    pub buffer_size: usize,
66    /// 压缩类型
67    pub compression: CompressionType,
68    /// CSV 分隔符
69    pub csv_delimiter: u8,
70    /// 是否包含表头
71    pub include_header: bool,
72    /// 流式导出的块大小(行数)
73    pub chunk_size: usize,
74}
75
76impl Default for ExportConfig {
77    fn default() -> Self {
78        Self {
79            buffer_size: 64 * 1024, // 64KB
80            compression: CompressionType::None,
81            csv_delimiter: b',',
82            include_header: true,
83            chunk_size: 10000,
84        }
85    }
86}
87
88/// 数据导出器
89pub struct DataExporter {
90    config: ExportConfig,
91}
92
93impl DataExporter {
94    /// 创建新的导出器
95    pub fn new() -> Self {
96        Self { config: ExportConfig::default() }
97    }
98
99    /// 使用自定义配置创建导出器
100    pub fn with_config(config: ExportConfig) -> Self {
101        Self { config }
102    }
103
104    /// 导出 DataFrame 到文件
105    ///
106    /// # 参数
107    /// - `df`: 要导出的 DataFrame
108    /// - `path`: 输出文件路径
109    /// - `format`: 导出格式(如果为 None,则从文件扩展名推断)
110    ///
111    /// # 返回
112    /// 导出的字节数
113    pub fn export(
114        &self,
115        df: &mut DataFrame,
116        path: &Path,
117        format: Option<ExportFormat>,
118    ) -> Result<u64> {
119        // 推断格式
120        let format = match format {
121            Some(f) => f,
122            None => {
123                let ext =
124                    path.extension().and_then(|e| e.to_str()).context("无法获取文件扩展名")?;
125                ExportFormat::from_extension(ext).context(format!("不支持的文件格式: {}", ext))?
126            }
127        };
128
129        tracing::info!("导出数据到 {:?},格式: {:?}", path, format);
130
131        match format {
132            ExportFormat::Csv => self.export_csv(df, path),
133            ExportFormat::Json => self.export_json(df, path),
134            ExportFormat::Parquet => self.export_parquet(df, path),
135            ExportFormat::Arrow => self.export_arrow(df, path),
136        }
137    }
138
139    /// 导出为 CSV 格式
140    fn export_csv(&self, df: &mut DataFrame, path: &Path) -> Result<u64> {
141        let file = File::create(path).context("创建文件失败")?;
142        let mut writer = std::io::BufWriter::with_capacity(self.config.buffer_size, file);
143
144        CsvWriter::new(&mut writer)
145            .include_header(self.config.include_header)
146            .with_separator(self.config.csv_delimiter)
147            .finish(df)
148            .context("写入 CSV 失败")?;
149
150        // 刷新缓冲区
151        writer.flush().context("刷新缓冲区失败")?;
152
153        // 获取文件大小
154        let bytes_written = std::fs::metadata(path)?.len();
155        Ok(bytes_written)
156    }
157
158    /// 导出为 JSON 格式(JSONL - 每行一个 JSON 对象)
159    fn export_json(&self, df: &mut DataFrame, path: &Path) -> Result<u64> {
160        let file = File::create(path).context("创建文件失败")?;
161        let mut writer = std::io::BufWriter::with_capacity(self.config.buffer_size, file);
162
163        JsonWriter::new(&mut writer)
164            .with_json_format(JsonFormat::JsonLines)
165            .finish(df)
166            .context("写入 JSON 失败")?;
167
168        // 刷新缓冲区
169        writer.flush().context("刷新缓冲区失败")?;
170
171        // 获取文件大小
172        let bytes_written = std::fs::metadata(path)?.len();
173        Ok(bytes_written)
174    }
175
176    /// 导出为 Parquet 格式
177    fn export_parquet(&self, df: &mut DataFrame, path: &Path) -> Result<u64> {
178        let file = File::create(path).context("创建文件失败")?;
179
180        // Parquet 使用自己的压缩机制
181        let compression = match self.config.compression {
182            CompressionType::None => ParquetCompression::Uncompressed,
183            CompressionType::Gzip => ParquetCompression::Gzip(None),
184            CompressionType::Zstd => ParquetCompression::Zstd(None),
185        };
186
187        ParquetWriter::new(file)
188            .with_compression(compression)
189            .finish(df)
190            .context("写入 Parquet 失败")?;
191
192        let bytes_written = std::fs::metadata(path)?.len();
193        Ok(bytes_written)
194    }
195
196    /// 导出为 Arrow IPC 格式
197    fn export_arrow(&self, df: &mut DataFrame, path: &Path) -> Result<u64> {
198        // Polars 0.45 可能需要特定的 feature 来支持 IPC
199        // 暂时使用 Parquet 格式替代(也是列式存储)
200        tracing::warn!("Arrow IPC 导出暂不支持,使用 Parquet 格式替代");
201        self.export_parquet(df, path)
202    }
203
204    /// 流式导出大文件(分块写入)
205    ///
206    /// 适用于 GB 级数据,内存占用低
207    pub fn export_streaming(
208        &self,
209        lf: LazyFrame,
210        path: &Path,
211        format: Option<ExportFormat>,
212    ) -> Result<u64> {
213        // 推断格式
214        let format = match format {
215            Some(f) => f,
216            None => {
217                let ext =
218                    path.extension().and_then(|e| e.to_str()).context("无法获取文件扩展名")?;
219                ExportFormat::from_extension(ext).context(format!("不支持的文件格式: {}", ext))?
220            }
221        };
222
223        tracing::info!("流式导出数据到 {:?},格式: {:?}", path, format);
224
225        // 对于 Parquet,先收集再写入(Polars 0.45 的 sink_parquet API 可能不同)
226        if format == ExportFormat::Parquet {
227            let mut df = lf.collect().context("收集 LazyFrame 失败")?;
228            return self.export(&mut df, path, Some(format));
229        }
230
231        // 对于其他格式,先收集再导出
232        let mut df = lf.collect().context("收集 LazyFrame 失败")?;
233        self.export(&mut df, path, Some(format))
234    }
235
236    /// 导出到标准输出(用于管道)
237    pub fn export_to_stdout(&self, df: &mut DataFrame, format: ExportFormat) -> Result<()> {
238        let stdout = std::io::stdout();
239        let mut writer = std::io::BufWriter::with_capacity(self.config.buffer_size, stdout.lock());
240
241        match format {
242            ExportFormat::Csv => {
243                CsvWriter::new(&mut writer)
244                    .include_header(self.config.include_header)
245                    .with_separator(self.config.csv_delimiter)
246                    .finish(df)
247                    .context("写入 CSV 到 stdout 失败")?;
248            }
249            ExportFormat::Json => {
250                JsonWriter::new(&mut writer)
251                    .with_json_format(JsonFormat::JsonLines)
252                    .finish(df)
253                    .context("写入 JSON 到 stdout 失败")?;
254            }
255            _ => {
256                return Err(anyhow::anyhow!(
257                    "格式 {:?} 不支持输出到 stdout,请使用 CSV 或 JSON",
258                    format
259                ));
260            }
261        }
262
263        writer.flush().context("刷新 stdout 失败")?;
264        Ok(())
265    }
266}
267
268impl Default for DataExporter {
269    fn default() -> Self {
270        Self::new()
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a small five-row frame with integer, string and float columns.
    fn create_test_dataframe() -> DataFrame {
        df! {
            "id" => &[1, 2, 3, 4, 5],
            "name" => &["Alice", "Bob", "Charlie", "David", "Eve"],
            "age" => &[25, 30, 35, 40, 45],
            "score" => &[85.5, 90.0, 78.5, 92.0, 88.5],
        }
        .unwrap()
    }

    #[test]
    fn test_export_format_from_extension() {
        assert_eq!(ExportFormat::from_extension("csv"), Some(ExportFormat::Csv));
        assert_eq!(ExportFormat::from_extension("json"), Some(ExportFormat::Json));
        assert_eq!(ExportFormat::from_extension("parquet"), Some(ExportFormat::Parquet));
        assert_eq!(ExportFormat::from_extension("arrow"), Some(ExportFormat::Arrow));
        assert_eq!(ExportFormat::from_extension("txt"), None);
    }

    #[test]
    fn test_export_format_from_extension_aliases_and_case() {
        // Alternative spellings map onto the same formats.
        assert_eq!(ExportFormat::from_extension("jsonl"), Some(ExportFormat::Json));
        assert_eq!(ExportFormat::from_extension("ipc"), Some(ExportFormat::Arrow));
        // Matching is case-insensitive.
        assert_eq!(ExportFormat::from_extension("CSV"), Some(ExportFormat::Csv));
        assert_eq!(ExportFormat::from_extension("Parquet"), Some(ExportFormat::Parquet));
        assert_eq!(ExportFormat::from_extension(""), None);
    }

    #[test]
    fn test_export_format_extension_round_trip() {
        for format in [
            ExportFormat::Csv,
            ExportFormat::Json,
            ExportFormat::Parquet,
            ExportFormat::Arrow,
        ] {
            assert_eq!(ExportFormat::from_extension(format.extension()), Some(format));
        }
    }

    #[test]
    fn test_export_csv() {
        let mut df = create_test_dataframe();
        let temp_file = NamedTempFile::with_suffix(".csv").unwrap();
        let exporter = DataExporter::new();

        let bytes = exporter.export(&mut df, temp_file.path(), Some(ExportFormat::Csv)).unwrap();

        assert!(bytes > 0);
        assert!(temp_file.path().exists());

        // Check the written content.
        let content = std::fs::read_to_string(temp_file.path()).unwrap();
        assert!(content.contains("id,name,age,score"));
        assert!(content.contains("Alice"));
    }

    #[test]
    fn test_export_json() {
        let mut df = create_test_dataframe();
        let temp_file = NamedTempFile::with_suffix(".json").unwrap();
        let exporter = DataExporter::new();

        let bytes = exporter.export(&mut df, temp_file.path(), Some(ExportFormat::Json)).unwrap();

        assert!(bytes > 0);

        // Check the written content.
        let content = std::fs::read_to_string(temp_file.path()).unwrap();
        assert!(content.contains("Alice"));
        assert!(content.contains("\"age\":25"));
    }

    #[test]
    fn test_export_parquet() {
        let mut df = create_test_dataframe();
        let temp_file = NamedTempFile::with_suffix(".parquet").unwrap();
        let exporter = DataExporter::new();

        let bytes =
            exporter.export(&mut df, temp_file.path(), Some(ExportFormat::Parquet)).unwrap();

        assert!(bytes > 0);
        assert!(temp_file.path().exists());
    }

    #[test]
    fn test_export_arrow() {
        let mut df = create_test_dataframe();
        let temp_file = NamedTempFile::with_suffix(".arrow").unwrap();
        let exporter = DataExporter::new();

        let bytes = exporter.export(&mut df, temp_file.path(), Some(ExportFormat::Arrow)).unwrap();

        assert!(bytes > 0);
        assert!(temp_file.path().exists());
    }

    #[test]
    fn test_export_auto_detect_format() {
        let mut df = create_test_dataframe();
        let temp_file = NamedTempFile::with_suffix(".csv").unwrap();
        let exporter = DataExporter::new();

        // No explicit format: it must be inferred from the `.csv` extension.
        let bytes = exporter.export(&mut df, temp_file.path(), None).unwrap();

        assert!(bytes > 0);
    }

    #[test]
    fn test_export_auto_detect_rejects_unknown_or_missing_extension() {
        let mut df = create_test_dataframe();
        let exporter = DataExporter::new();

        // Format resolution fails before any file is created, so these paths are never written.
        let unknown = std::path::Path::new("xore_export_unknown_extension.txt");
        assert!(exporter.export(&mut df, unknown, None).is_err());
        assert!(!unknown.exists());

        let missing = std::path::Path::new("xore_export_without_extension");
        assert!(exporter.export(&mut df, missing, None).is_err());
        assert!(!missing.exists());
    }

    #[test]
    fn test_export_empty_dataframe() {
        let mut df = df! {
            "col1" => Vec::<i32>::new(),
            "col2" => Vec::<String>::new(),
        }
        .unwrap();

        let temp_file = NamedTempFile::with_suffix(".csv").unwrap();
        let exporter = DataExporter::new();

        let bytes = exporter.export(&mut df, temp_file.path(), Some(ExportFormat::Csv)).unwrap();

        assert!(bytes > 0); // The header row is still written.
    }

    #[test]
    fn test_export_with_custom_config() {
        let mut df = create_test_dataframe();
        let temp_file = NamedTempFile::with_suffix(".csv").unwrap();

        let config =
            ExportConfig { csv_delimiter: b';', include_header: true, ..Default::default() };

        let exporter = DataExporter::with_config(config);
        exporter.export(&mut df, temp_file.path(), Some(ExportFormat::Csv)).unwrap();

        // The header must use the custom delimiter between every column.
        let content = std::fs::read_to_string(temp_file.path()).unwrap();
        assert!(content.contains("id;name;age;score"));
    }
}