Skip to main content

touchstone_rs/
loader.rs

1use anyhow::{Context, Result};
2use polars::prelude::*;
3use std::fs::File;
4use std::path::Path;
5
/// In-memory representation of one benchmark dataset.
///
/// `features` and `labels` are parallel: row `i` of `features` is described by
/// `labels[i]`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Dataset {
    /// Dataset identifier, usually derived from the file stem.
    pub name: String,
    /// Feature matrix in row-major layout (`n_points x n_dims`).
    pub features: Vec<Vec<f32>>,
    /// Binary anomaly labels aligned with `features`.
    pub labels: Vec<u8>,
}
15
16/// Returns sorted `(name, path)` pairs without loading any data.
17///
18/// Parquet files take priority over CSV files when both share the same stem.
19pub fn list_datasets(dir: &Path) -> Result<Vec<(String, std::path::PathBuf)>> {
20    let mut entries: Vec<_> = std::fs::read_dir(dir)
21        .with_context(|| format!("cannot read data dir: {}", dir.display()))?
22        .filter_map(|e| e.ok())
23        .filter(|e| {
24            e.path()
25                .extension()
26                .map(|x| x == "parquet" || x == "csv")
27                .unwrap_or(false)
28        })
29        .collect();
30
31    entries.sort_by_key(|e| e.path());
32
33    // Deduplicate by stem, preferring parquet over csv.
34    let mut seen: std::collections::HashMap<String, std::path::PathBuf> =
35        std::collections::HashMap::new();
36    for e in entries {
37        let path = e.path();
38        let stem = path
39            .file_stem()
40            .unwrap_or_default()
41            .to_string_lossy()
42            .into_owned();
43        let is_parquet = path.extension().map(|x| x == "parquet").unwrap_or(false);
44        if is_parquet || !seen.contains_key(&stem) {
45            seen.insert(stem, path);
46        }
47    }
48
49    let mut result: Vec<_> = seen.into_iter().collect();
50    result.sort_by(|a, b| a.1.cmp(&b.1));
51
52    Ok(result)
53}
54
55/// Loads a dataset from a Parquet or CSV file, applying the given name to the result.
56pub fn load_dataset(name: String, path: &Path) -> Result<Dataset> {
57    let loader = if path.extension().map(|x| x == "parquet").unwrap_or(false) {
58        load_parquet
59    } else {
60        load_csv
61    };
62    loader(path)
63        .with_context(|| format!("failed to load {}", path.display()))
64        .map(|ds| Dataset { name, ..ds })
65}
66
67/// Parses a Parquet file into a Dataset.
68///
69/// Expects format: `timestamp, feature_1, ..., feature_n, label`
70fn load_parquet(path: &Path) -> Result<Dataset> {
71    let file = File::open(path).context("open parquet file")?;
72    let df = ParquetReader::new(file).finish().context("parquet parse")?;
73
74    extract_dataset(df)
75}
76
77/// Parses a CSV file into a Dataset.
78///
79/// Expects format: `timestamp, feature_1, ..., feature_n, label`
80fn load_csv(path: &Path) -> Result<Dataset> {
81    let df = CsvReadOptions::default()
82        .with_has_header(true)
83        .try_into_reader_with_file_path(Some(path.into()))
84        .context("csv reader init")?
85        .finish()
86        .context("csv parse")?;
87
88    extract_dataset(df)
89}
90
91/// Shared extraction logic for both CSV and Parquet DataFrames.
92///
93/// Expects columns: timestamp (skipped), feature_1..feature_n, label (last).
94fn extract_dataset(df: DataFrame) -> Result<Dataset> {
95    let n_rows = df.height();
96    let n_cols = df.width();
97    anyhow::ensure!(
98        n_cols >= 3,
99        "dataset must have timestamp, at least one feature column, and one label column"
100    );
101
102    let cols: &[Column] = df.columns();
103    let label_col: &Column = cols.last().unwrap();
104    let labels: Vec<u8> = label_col
105        .cast(&DataType::Int64)
106        .context("label cast")?
107        .i64()
108        .context("label as i64")?
109        .into_iter()
110        .map(|v: Option<i64>| v.unwrap_or(0) as u8)
111        .collect();
112
113    // cols[0] is timestamp — skip it
114    let feature_cols: &[Column] = &cols[1..n_cols - 1];
115    let cast_cols: Vec<Column> = feature_cols
116        .iter()
117        .map(|c: &Column| c.cast(&DataType::Float32).context("feature cast"))
118        .collect::<Result<_>>()?;
119
120    let features: Vec<Vec<f32>> = (0..n_rows)
121        .map(|i| {
122            cast_cols
123                .iter()
124                .map(|c: &Column| {
125                    c.f32()
126                        .expect("cast to f32 failed")
127                        .get(i)
128                        .unwrap_or(f32::NAN)
129                })
130                .collect()
131        })
132        .collect();
133
134    Ok(Dataset {
135        name: String::new(),
136        features,
137        labels,
138    })
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::{Column, DataFrame, ParquetWriter};
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Three-row fixture: `timestamp`, two float features (`x`, `y`), `label`.
    fn sample_frame() -> DataFrame {
        let columns = vec![
            Column::new("timestamp".into(), &[0i64, 1, 2]),
            Column::new("x".into(), &[1.0f64, 3.0, 5.0]),
            Column::new("y".into(), &[2.0f64, 4.0, 6.0]),
            Column::new("label".into(), &[0i64, 1, 0]),
        ];
        DataFrame::new(3, columns).unwrap()
    }

    /// Writes `df` to a fresh `.parquet` temp file and returns its handle.
    fn parquet_tempfile(df: &mut DataFrame) -> NamedTempFile {
        let tmp = NamedTempFile::with_suffix(".parquet").unwrap();
        let sink = std::fs::File::create(tmp.path()).unwrap();
        ParquetWriter::new(sink).finish(df).unwrap();
        tmp
    }

    /// Writes `lines` (one per row, newline-terminated) to a `.csv` temp file.
    fn csv_tempfile(lines: &[&str]) -> NamedTempFile {
        let mut tmp = NamedTempFile::with_suffix(".csv").unwrap();
        for line in lines {
            writeln!(tmp, "{line}").unwrap();
        }
        tmp
    }

    /// Checks the labels and leading feature rows shared by both fixtures.
    fn assert_sample_contents(ds: &Dataset) {
        assert_eq!(ds.labels, vec![0, 1, 0]);
        assert_eq!(ds.features.len(), 3);
        assert_eq!(ds.features[0], vec![1.0, 2.0]);
        assert_eq!(ds.features[1], vec![3.0, 4.0]);
    }

    #[test]
    fn parse_simple_csv() {
        let tmp = csv_tempfile(&[
            "timestamp,x,y,label",
            "0,1.0,2.0,0",
            "1,3.0,4.0,1",
            "2,5.0,6.0,0",
        ]);

        let parsed = load_csv(tmp.path()).unwrap();
        assert_sample_contents(&parsed);
    }

    #[test]
    fn parse_simple_parquet() {
        let tmp = parquet_tempfile(&mut sample_frame());

        let parsed = load_parquet(tmp.path()).unwrap();
        assert_sample_contents(&parsed);
    }

    #[test]
    fn load_dataset_dispatches_by_extension() {
        let tmp = parquet_tempfile(&mut sample_frame());

        let loaded = load_dataset("test".into(), tmp.path()).unwrap();
        assert_eq!(loaded.name, "test");
        assert_eq!(loaded.labels, vec![0, 1, 0]);
    }

    #[test]
    fn list_datasets_parquet_preferred_over_csv() {
        let stem = "mydata";
        let dir = TempDir::new().unwrap();
        let root = dir.path();

        // CSV sibling with the same stem.
        std::fs::write(
            root.join(format!("{stem}.csv")),
            "timestamp,x,label\n0,9.0,1\n1,8.0,1\n2,7.0,1\n",
        )
        .unwrap();

        // Parquet sibling holding different content than the CSV.
        let sink = std::fs::File::create(root.join(format!("{stem}.parquet"))).unwrap();
        ParquetWriter::new(sink)
            .finish(&mut sample_frame())
            .unwrap();

        let found = list_datasets(root).unwrap();
        assert_eq!(found.len(), 1, "duplicate stems should be deduplicated");

        let (name, path) = &found[0];
        assert_eq!(name.as_str(), stem);
        assert_eq!(path.extension().unwrap(), "parquet");
    }

    #[test]
    fn list_datasets_falls_back_to_csv() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("only.csv"), "timestamp,x,label\n0,1.0,0\n").unwrap();

        let found = list_datasets(dir.path()).unwrap();
        assert_eq!(found.len(), 1);

        let (_, path) = &found[0];
        assert_eq!(path.extension().unwrap(), "csv");
    }

    #[test]
    fn list_datasets_sorted_by_path() {
        let dir = TempDir::new().unwrap();
        // Created out of order on purpose.
        for stem in ["c_data", "a_data", "b_data"] {
            let target = dir.path().join(format!("{stem}.csv"));
            std::fs::write(target, "t,x,l\n0,1.0,0\n").unwrap();
        }

        let found = list_datasets(dir.path()).unwrap();
        let stems: Vec<&str> = found.iter().map(|(stem, _)| stem.as_str()).collect();
        assert_eq!(stems, ["a_data", "b_data", "c_data"]);
    }

    #[test]
    fn rejects_too_few_columns() {
        // Timestamp and label only: there is no feature column.
        let tmp = csv_tempfile(&["timestamp,label", "0,0"]);

        assert!(load_csv(tmp.path()).is_err());
    }
}