1use anyhow::{Context, Result};
2use polars::prelude::*;
3use std::fs::File;
4use std::path::Path;
5
/// An in-memory labelled dataset.
///
/// `PartialEq` is derived, but note that features may contain `NaN` (the
/// loaders in this file use it for missing cells), and `NaN != NaN`, so a
/// dataset with missing values never compares equal to itself.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Dataset {
    /// Name of the dataset, chosen by the caller that loads it.
    pub name: String,
    /// Feature values, one inner vector per row.
    pub features: Vec<Vec<f32>>,
    /// Labels; the loaders in this file produce one per row.
    pub labels: Vec<u8>,
}
15
16pub fn list_datasets(dir: &Path) -> Result<Vec<(String, std::path::PathBuf)>> {
20 let mut entries: Vec<_> = std::fs::read_dir(dir)
21 .with_context(|| format!("cannot read data dir: {}", dir.display()))?
22 .filter_map(|e| e.ok())
23 .filter(|e| {
24 e.path()
25 .extension()
26 .map(|x| x == "parquet" || x == "csv")
27 .unwrap_or(false)
28 })
29 .collect();
30
31 entries.sort_by_key(|e| e.path());
32
33 let mut seen: std::collections::HashMap<String, std::path::PathBuf> =
35 std::collections::HashMap::new();
36 for e in entries {
37 let path = e.path();
38 let stem = path
39 .file_stem()
40 .unwrap_or_default()
41 .to_string_lossy()
42 .into_owned();
43 let is_parquet = path.extension().map(|x| x == "parquet").unwrap_or(false);
44 if is_parquet || !seen.contains_key(&stem) {
45 seen.insert(stem, path);
46 }
47 }
48
49 let mut result: Vec<_> = seen.into_iter().collect();
50 result.sort_by(|a, b| a.1.cmp(&b.1));
51
52 Ok(result)
53}
54
55pub fn load_dataset(name: String, path: &Path) -> Result<Dataset> {
57 let loader = if path.extension().map(|x| x == "parquet").unwrap_or(false) {
58 load_parquet
59 } else {
60 load_csv
61 };
62 loader(path)
63 .with_context(|| format!("failed to load {}", path.display()))
64 .map(|ds| Dataset { name, ..ds })
65}
66
67fn load_parquet(path: &Path) -> Result<Dataset> {
71 let file = File::open(path).context("open parquet file")?;
72 let df = ParquetReader::new(file).finish().context("parquet parse")?;
73
74 extract_dataset(df)
75}
76
77fn load_csv(path: &Path) -> Result<Dataset> {
81 let df = CsvReadOptions::default()
82 .with_has_header(true)
83 .try_into_reader_with_file_path(Some(path.into()))
84 .context("csv reader init")?
85 .finish()
86 .context("csv parse")?;
87
88 extract_dataset(df)
89}
90
91fn extract_dataset(df: DataFrame) -> Result<Dataset> {
95 let n_rows = df.height();
96 let n_cols = df.width();
97 anyhow::ensure!(
98 n_cols >= 3,
99 "dataset must have timestamp, at least one feature column, and one label column"
100 );
101
102 let cols: &[Column] = df.columns();
103 let label_col: &Column = cols.last().unwrap();
104 let labels: Vec<u8> = label_col
105 .cast(&DataType::Int64)
106 .context("label cast")?
107 .i64()
108 .context("label as i64")?
109 .into_iter()
110 .map(|v: Option<i64>| v.unwrap_or(0) as u8)
111 .collect();
112
113 let feature_cols: &[Column] = &cols[1..n_cols - 1];
115 let cast_cols: Vec<Column> = feature_cols
116 .iter()
117 .map(|c: &Column| c.cast(&DataType::Float32).context("feature cast"))
118 .collect::<Result<_>>()?;
119
120 let features: Vec<Vec<f32>> = (0..n_rows)
121 .map(|i| {
122 cast_cols
123 .iter()
124 .map(|c: &Column| {
125 c.f32()
126 .expect("cast to f32 failed")
127 .get(i)
128 .unwrap_or(f32::NAN)
129 })
130 .collect()
131 })
132 .collect();
133
134 Ok(Dataset {
135 name: String::new(),
136 features,
137 labels,
138 })
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Column, DataFrame, ParquetWriter};
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Builds the three-row reference frame: `timestamp`, features `x` and
    /// `y`, and `label`.
    fn make_df() -> DataFrame {
        let columns = vec![
            Column::new("timestamp".into(), &[0i64, 1, 2]),
            Column::new("x".into(), &[1.0f64, 3.0, 5.0]),
            Column::new("y".into(), &[2.0f64, 4.0, 6.0]),
            Column::new("label".into(), &[0i64, 1, 0]),
        ];
        DataFrame::new(3, columns).unwrap()
    }

    /// Writes `df` into a fresh temporary `.parquet` file and returns its handle.
    fn write_parquet(df: &mut DataFrame) -> NamedTempFile {
        let tmp = NamedTempFile::with_suffix(".parquet").unwrap();
        let sink = std::fs::File::create(tmp.path()).unwrap();
        ParquetWriter::new(sink).finish(df).unwrap();
        tmp
    }

    /// Checks the labels and the first two feature rows of the reference data.
    fn assert_reference_rows(ds: &Dataset) {
        assert_eq!(ds.labels, vec![0, 1, 0]);
        assert_eq!(ds.features.len(), 3);
        assert_eq!(ds.features[0], vec![1.0, 2.0]);
        assert_eq!(ds.features[1], vec![3.0, 4.0]);
    }

    #[test]
    fn parse_simple_csv() {
        let mut tmp = NamedTempFile::with_suffix(".csv").unwrap();
        let lines = [
            "timestamp,x,y,label",
            "0,1.0,2.0,0",
            "1,3.0,4.0,1",
            "2,5.0,6.0,0",
        ];
        for line in lines {
            writeln!(tmp, "{line}").unwrap();
        }

        let parsed = load_csv(tmp.path()).unwrap();
        assert_reference_rows(&parsed);
    }

    #[test]
    fn parse_simple_parquet() {
        let mut frame = make_df();
        let tmp = write_parquet(&mut frame);

        let parsed = load_parquet(tmp.path()).unwrap();
        assert_reference_rows(&parsed);
    }

    #[test]
    fn load_dataset_dispatches_by_extension() {
        let mut frame = make_df();
        let tmp = write_parquet(&mut frame);

        let loaded = load_dataset("test".into(), tmp.path()).unwrap();
        assert_eq!(loaded.name, "test");
        assert_eq!(loaded.labels, vec![0, 1, 0]);
    }

    #[test]
    fn list_datasets_parquet_preferred_over_csv() {
        let workdir = TempDir::new().unwrap();
        let stem = "mydata";

        // Same stem in both formats: the CSV twin must lose.
        let csv_file = workdir.path().join(format!("{stem}.csv"));
        std::fs::write(&csv_file, "timestamp,x,label\n0,9.0,1\n1,8.0,1\n2,7.0,1\n").unwrap();

        let mut frame = make_df();
        let parquet_file = workdir.path().join(format!("{stem}.parquet"));
        let sink = std::fs::File::create(&parquet_file).unwrap();
        ParquetWriter::new(sink).finish(&mut frame).unwrap();

        let found = list_datasets(workdir.path()).unwrap();
        assert_eq!(found.len(), 1, "duplicate stems should be deduplicated");
        let (found_stem, found_path) = &found[0];
        assert_eq!(found_stem, stem);
        assert_eq!(found_path.extension().unwrap(), "parquet");
    }

    #[test]
    fn list_datasets_falls_back_to_csv() {
        let workdir = TempDir::new().unwrap();
        let csv_file = workdir.path().join("only.csv");
        std::fs::write(&csv_file, "timestamp,x,label\n0,1.0,0\n").unwrap();

        let found = list_datasets(workdir.path()).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].1.extension().unwrap(), "csv");
    }

    #[test]
    fn list_datasets_sorted_by_path() {
        let workdir = TempDir::new().unwrap();
        // Created out of order on purpose.
        for name in ["c_data", "a_data", "b_data"] {
            let target = workdir.path().join(format!("{name}.csv"));
            std::fs::write(target, "t,x,l\n0,1.0,0\n").unwrap();
        }

        let found = list_datasets(workdir.path()).unwrap();
        let names: Vec<&str> = found.iter().map(|(stem, _)| stem.as_str()).collect();
        assert_eq!(names, ["a_data", "b_data", "c_data"]);
    }

    #[test]
    fn rejects_too_few_columns() {
        let mut tmp = NamedTempFile::with_suffix(".csv").unwrap();
        for line in ["timestamp,label", "0,0"] {
            writeln!(tmp, "{line}").unwrap();
        }

        assert!(load_csv(tmp.path()).is_err());
    }
}