1use crate::engine::DataEngine;
2use datafusion::prelude::*;
3use std::path::Path;
4
5impl DataEngine {
6 pub fn read_csv(&self, path: &str) -> Result<DataFrame, String> {
9 if !path.starts_with("s3://")
10 && !path.starts_with("http://")
11 && !path.starts_with("https://")
12 {
13 let p = Path::new(path);
14 if !p.exists() {
15 return Err(format!("CSV file not found: {}", p.display()));
16 }
17 }
18 self.rt
19 .block_on(self.ctx.read_csv(path, CsvReadOptions::default()))
20 .map_err(|e| format!("CSV read error: {e}"))
21 }
22
23 pub fn read_parquet(&self, path: &str) -> Result<DataFrame, String> {
26 if !path.starts_with("s3://")
27 && !path.starts_with("http://")
28 && !path.starts_with("https://")
29 {
30 let p = Path::new(path);
31 if !p.exists() {
32 return Err(format!("Parquet file not found: {}", p.display()));
33 }
34 }
35 self.rt
36 .block_on(self.ctx.read_parquet(path, ParquetReadOptions::default()))
37 .map_err(|e| format!("Parquet read error: {e}"))
38 }
39
40 pub fn write_csv(&self, df: DataFrame, path: &str) -> Result<(), String> {
42 self.rt
43 .block_on(df.write_csv(
44 path,
45 datafusion::dataframe::DataFrameWriteOptions::default(),
46 None,
47 ))
48 .map_err(|e| format!("CSV write error: {e}"))?;
49 Ok(())
50 }
51
52 pub fn write_parquet(&self, df: DataFrame, path: &str) -> Result<(), String> {
54 self.rt
55 .block_on(df.write_parquet(
56 path,
57 datafusion::dataframe::DataFrameWriteOptions::default(),
58 None,
59 ))
60 .map_err(|e| format!("Parquet write error: {e}"))?;
61 Ok(())
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Header plus three data rows, used as the CSV fixture.
    const SAMPLE_CSV: &str = "id,name,age\n1,Alice,30\n2,Bob,25\n3,Charlie,35\n";

    #[test]
    fn test_csv_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let csv_path = dir.path().join("test.csv");
        fs::write(&csv_path, SAMPLE_CSV).unwrap();
        let csv = csv_path.to_str().unwrap();

        let engine = DataEngine::new();

        let df = engine.read_csv(csv).unwrap();
        let batches = engine.collect(df).unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);

        // `collect` took the first DataFrame by value, so read the source again for the write.
        let df = engine.read_csv(csv).unwrap();
        let out_dir = dir.path().join("output");
        engine.write_csv(df, out_dir.to_str().unwrap()).unwrap();

        // The output may be a single file or a directory of part files; either way it
        // must exist and contain something.
        assert!(out_dir.exists());
        if out_dir.is_dir() {
            assert!(
                fs::read_dir(&out_dir).unwrap().next().is_some(),
                "output directory is empty"
            );
        } else {
            assert!(
                fs::metadata(&out_dir).unwrap().len() > 0,
                "output file is empty"
            );
        }
    }

    #[test]
    fn test_read_csv_missing_file_reports_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("missing.csv");

        let err = DataEngine::new()
            .read_csv(missing.to_str().unwrap())
            .err()
            .expect("reading a missing CSV file must fail");
        assert!(
            err.starts_with("CSV file not found:"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_read_parquet_missing_file_reports_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("missing.parquet");

        let err = DataEngine::new()
            .read_parquet(missing.to_str().unwrap())
            .err()
            .expect("reading a missing Parquet file must fail");
        assert!(
            err.starts_with("Parquet file not found:"),
            "unexpected error: {err}"
        );
    }
}