// synth_claw/datasets/local.rs

1use polars::prelude::*;
2use std::fs::File;
3use std::io::{BufRead, BufReader};
4use std::path::PathBuf;
5
6use super::{DataSource, DatasetInfo, Record};
7use crate::config::FileFormat;
8use crate::{Error, Result};
9
10pub struct LocalSource {
11    path: PathBuf,
12    format: FileFormat,
13    info: DatasetInfo,
14}
15
16impl LocalSource {
17    pub fn new(path: PathBuf, format: FileFormat) -> Result<Self> {
18        if !path.exists() {
19            return Err(Error::Dataset(format!("File not found: {:?}", path)));
20        }
21
22        let info = Self::detect_info(&path, &format)?;
23
24        Ok(Self { path, format, info })
25    }
26
27    fn detect_info(path: &PathBuf, format: &FileFormat) -> Result<DatasetInfo> {
28        let (columns, num_rows) = match format {
29            FileFormat::Jsonl => Self::detect_jsonl_info(path)?,
30            FileFormat::Json => Self::detect_json_info(path)?,
31            FileFormat::Csv => Self::detect_csv_info(path)?,
32            FileFormat::Parquet => Self::detect_parquet_info(path)?,
33        };
34
35        Ok(DatasetInfo {
36            name: path
37                .file_name()
38                .and_then(|n| n.to_str())
39                .unwrap_or("local")
40                .to_string(),
41            description: None,
42            num_rows,
43            columns,
44            splits: vec![],
45        })
46    }
47
48    fn detect_jsonl_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
49        let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
50        let reader = BufReader::new(file);
51        let mut columns = Vec::new();
52        let mut num_rows = 0;
53
54        for (i, line) in reader.lines().enumerate() {
55            let line = line.map_err(|e| Error::Dataset(e.to_string()))?;
56            if line.trim().is_empty() {
57                continue;
58            }
59            num_rows += 1;
60
61            if i == 0 {
62                let obj: serde_json::Value =
63                    serde_json::from_str(&line).map_err(|e| Error::Dataset(e.to_string()))?;
64                if let Some(map) = obj.as_object() {
65                    columns = map.keys().cloned().collect();
66                }
67            }
68        }
69
70        Ok((columns, num_rows))
71    }
72
73    fn detect_json_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
74        let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
75        let data: serde_json::Value =
76            serde_json::from_reader(file).map_err(|e| Error::Dataset(e.to_string()))?;
77
78        match data {
79            serde_json::Value::Array(arr) => {
80                let num_rows = arr.len();
81                let columns = arr
82                    .first()
83                    .and_then(|v| v.as_object())
84                    .map(|m| m.keys().cloned().collect())
85                    .unwrap_or_default();
86                Ok((columns, num_rows))
87            }
88            _ => Err(Error::Dataset("JSON file must contain an array".into())),
89        }
90    }
91
92    fn detect_csv_info(path: &std::path::Path) -> Result<(Vec<String>, usize)> {
93        let df = CsvReadOptions::default()
94            .with_has_header(true)
95            .try_into_reader_with_file_path(Some(path.to_path_buf()))
96            .map_err(|e| Error::Dataset(e.to_string()))?
97            .finish()
98            .map_err(|e| Error::Dataset(e.to_string()))?;
99
100        let columns: Vec<String> = df
101            .get_column_names()
102            .iter()
103            .map(|s| s.to_string())
104            .collect();
105        Ok((columns, df.height()))
106    }
107
108    fn detect_parquet_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
109        let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
110        let df = ParquetReader::new(file)
111            .finish()
112            .map_err(|e| Error::Dataset(e.to_string()))?;
113
114        let columns: Vec<String> = df
115            .get_column_names()
116            .iter()
117            .map(|s| s.to_string())
118            .collect();
119        Ok((columns, df.height()))
120    }
121
122    fn load_jsonl(&self, sample: Option<usize>) -> Result<Vec<Record>> {
123        let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
124        let reader = BufReader::new(file);
125        let mut records = Vec::new();
126
127        for (i, line) in reader.lines().enumerate() {
128            if sample.is_some_and(|n| records.len() >= n) {
129                break;
130            }
131
132            let line = line.map_err(|e| Error::Dataset(e.to_string()))?;
133            if line.trim().is_empty() {
134                continue;
135            }
136
137            let data: serde_json::Value =
138                serde_json::from_str(&line).map_err(|e| Error::Dataset(e.to_string()))?;
139            records.push(Record { data, index: i });
140        }
141
142        Ok(records)
143    }
144
145    fn load_json(&self, sample: Option<usize>) -> Result<Vec<Record>> {
146        let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
147        let data: serde_json::Value =
148            serde_json::from_reader(file).map_err(|e| Error::Dataset(e.to_string()))?;
149
150        match data {
151            serde_json::Value::Array(arr) => {
152                let limit = sample.unwrap_or(arr.len());
153                Ok(arr
154                    .into_iter()
155                    .take(limit)
156                    .enumerate()
157                    .map(|(i, data)| Record { data, index: i })
158                    .collect())
159            }
160            _ => Err(Error::Dataset("JSON file must contain an array".into())),
161        }
162    }
163
164    fn load_csv(&self, sample: Option<usize>) -> Result<Vec<Record>> {
165        let mut df = CsvReadOptions::default()
166            .with_has_header(true)
167            .try_into_reader_with_file_path(Some(self.path.clone()))
168            .map_err(|e| Error::Dataset(e.to_string()))?
169            .finish()
170            .map_err(|e| Error::Dataset(e.to_string()))?;
171
172        if let Some(n) = sample {
173            df = df.head(Some(n));
174        }
175
176        dataframe_to_records(df)
177    }
178
179    fn load_parquet(&self, sample: Option<usize>) -> Result<Vec<Record>> {
180        let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
181        let mut df = ParquetReader::new(file)
182            .finish()
183            .map_err(|e| Error::Dataset(e.to_string()))?;
184
185        if let Some(n) = sample {
186            df = df.head(Some(n));
187        }
188
189        dataframe_to_records(df)
190    }
191}
192
193impl DataSource for LocalSource {
194    fn info(&self) -> &DatasetInfo {
195        &self.info
196    }
197
198    fn load(&mut self, sample: Option<usize>) -> Result<Vec<Record>> {
199        match self.format {
200            FileFormat::Jsonl => self.load_jsonl(sample),
201            FileFormat::Json => self.load_json(sample),
202            FileFormat::Csv => self.load_csv(sample),
203            FileFormat::Parquet => self.load_parquet(sample),
204        }
205    }
206}
207
208fn dataframe_to_records(df: DataFrame) -> Result<Vec<Record>> {
209    let mut records = Vec::with_capacity(df.height());
210
211    for i in 0..df.height() {
212        let row = df
213            .get(i)
214            .ok_or_else(|| Error::Dataset("Row not found".into()))?;
215        let mut map = serde_json::Map::new();
216
217        for (col_name, value) in df.get_column_names().iter().zip(row.iter()) {
218            let json_value = anyvalue_to_json(value);
219            map.insert(col_name.to_string(), json_value);
220        }
221
222        records.push(Record {
223            data: serde_json::Value::Object(map),
224            index: i,
225        });
226    }
227
228    Ok(records)
229}
230
231fn anyvalue_to_json(value: &AnyValue) -> serde_json::Value {
232    match value {
233        AnyValue::Null => serde_json::Value::Null,
234        AnyValue::Boolean(b) => serde_json::Value::Bool(*b),
235        AnyValue::String(s) => serde_json::Value::String(s.to_string()),
236        AnyValue::StringOwned(s) => serde_json::Value::String(s.to_string()),
237        AnyValue::Float32(n) => serde_json::Number::from_f64(*n as f64)
238            .map(serde_json::Value::Number)
239            .unwrap_or(serde_json::Value::Null),
240        AnyValue::Float64(n) => serde_json::Number::from_f64(*n)
241            .map(serde_json::Value::Number)
242            .unwrap_or(serde_json::Value::Null),
243        other => serde_json::Value::String(format!("{}", other)),
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temp file containing each entry of `lines`, one per line.
    fn temp_file_with_lines(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file
    }

    #[test]
    fn test_load_jsonl() {
        let file = temp_file_with_lines(&[
            r#"{"text": "hello", "label": 1}"#,
            r#"{"text": "world", "label": 0}"#,
            r#"{"text": "test", "label": 1}"#,
        ]);

        let mut source = LocalSource::new(file.path().to_path_buf(), FileFormat::Jsonl).unwrap();
        let records = source.load(Some(2)).unwrap();

        // Sampling keeps exactly the first two records, in file order.
        let texts: Vec<Option<&str>> = records.iter().map(|r| r.data["text"].as_str()).collect();
        assert_eq!(texts, [Some("hello"), Some("world")]);
    }

    #[test]
    fn test_load_json() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(br#"[{"text": "a", "n": 1}, {"text": "b", "n": 2}]"#)
            .unwrap();

        let mut source = LocalSource::new(file.path().to_path_buf(), FileFormat::Json).unwrap();
        let records = source.load(None).unwrap();

        assert_eq!(records.len(), 2);
        assert_eq!(records[0].data["text"], "a");
    }

    #[test]
    fn test_local_source_info() {
        let file = temp_file_with_lines(&[
            r#"{"col1": "val1", "col2": 123}"#,
            r#"{"col1": "val2", "col2": 456}"#,
        ]);

        let source = LocalSource::new(file.path().to_path_buf(), FileFormat::Jsonl).unwrap();
        let info = source.info();

        assert_eq!(info.num_rows, 2);
        for expected in ["col1", "col2"] {
            assert!(
                info.columns.iter().any(|c| c == expected),
                "missing column {expected}"
            );
        }
    }
}