// tl_data/io.rs

1use crate::engine::DataEngine;
2use datafusion::prelude::*;
3use std::path::Path;
4
5impl DataEngine {
6    /// Read a CSV file into a DataFusion DataFrame.
7    /// Supports local paths and s3:// URLs (after register_s3).
8    pub fn read_csv(&self, path: &str) -> Result<DataFrame, String> {
9        if !path.starts_with("s3://")
10            && !path.starts_with("http://")
11            && !path.starts_with("https://")
12        {
13            let p = Path::new(path);
14            if !p.exists() {
15                return Err(format!("CSV file not found: {}", p.display()));
16            }
17        }
18        self.rt
19            .block_on(self.ctx.read_csv(path, CsvReadOptions::default()))
20            .map_err(|e| format!("CSV read error: {e}"))
21    }
22
23    /// Read a Parquet file into a DataFusion DataFrame.
24    /// Supports local paths and s3:// URLs (after register_s3).
25    pub fn read_parquet(&self, path: &str) -> Result<DataFrame, String> {
26        if !path.starts_with("s3://")
27            && !path.starts_with("http://")
28            && !path.starts_with("https://")
29        {
30            let p = Path::new(path);
31            if !p.exists() {
32                return Err(format!("Parquet file not found: {}", p.display()));
33            }
34        }
35        self.rt
36            .block_on(self.ctx.read_parquet(path, ParquetReadOptions::default()))
37            .map_err(|e| format!("Parquet read error: {e}"))
38    }
39
40    /// Write a DataFrame to a CSV file.
41    pub fn write_csv(&self, df: DataFrame, path: &str) -> Result<(), String> {
42        self.rt
43            .block_on(df.write_csv(
44                path,
45                datafusion::dataframe::DataFrameWriteOptions::default(),
46                None,
47            ))
48            .map_err(|e| format!("CSV write error: {e}"))?;
49        Ok(())
50    }
51
52    /// Write a DataFrame to a Parquet file.
53    pub fn write_parquet(&self, df: DataFrame, path: &str) -> Result<(), String> {
54        self.rt
55            .block_on(df.write_parquet(
56                path,
57                datafusion::dataframe::DataFrameWriteOptions::default(),
58                None,
59            ))
60            .map_err(|e| format!("Parquet write error: {e}"))?;
61        Ok(())
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Total number of rows across all record batches that `df` produces.
    fn row_count(engine: &DataEngine, df: DataFrame) -> usize {
        engine
            .collect(df)
            .unwrap()
            .iter()
            .map(|b| b.num_rows())
            .sum()
    }

    #[test]
    fn test_csv_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let csv_path = dir.path().join("test.csv");

        // Write a small input CSV with a header and three data rows.
        fs::write(
            &csv_path,
            "id,name,age\n1,Alice,30\n2,Bob,25\n3,Charlie,35\n",
        )
        .unwrap();

        let engine = DataEngine::new();
        let df = engine.read_csv(csv_path.to_str().unwrap()).unwrap();

        // `collect` consumes its DataFrame; clone so the same plan can be written below
        // instead of re-reading the input file.
        assert_eq!(row_count(&engine, df.clone()), 3);

        // Write back.
        let out_dir = dir.path().join("output");
        let out = out_dir.to_str().unwrap();
        engine.write_csv(df, out).unwrap();

        // Verify output directory was created.
        assert!(out_dir.exists());

        // Read the written output back so the test actually checks the round trip,
        // not just that something was created on disk.
        let reread = engine.read_csv(out).unwrap();
        assert_eq!(row_count(&engine, reread), 3);
    }
}