Skip to main content

shrew_data/
csv_dataset.rs

// CsvDataset — load tabular data from CSV files
//
// A lightweight CSV parser that doesn't require an external CSV crate.
// Supports headerless or header-row CSVs. The caller specifies which columns
// are features and which are targets; when a list is left empty, all columns
// but the last are used as features and the last column is used as the target.
6
7use std::fs;
8use std::path::Path;
9
10use crate::dataset::{Dataset, Sample};
11
12/// A dataset loaded from a CSV file.
13///
14/// All values are parsed as `f64`.  Non-numeric fields will cause a panic.
15///
16/// # Example
17/// ```ignore
18/// // Load iris.csv: 4 feature columns, 1 target column (last)
19/// let ds = CsvDataset::load("data/iris.csv", CsvConfig {
20///     has_header: true,
21///     feature_cols: vec![0, 1, 2, 3],
22///     target_cols: vec![4],
23///     delimiter: b',',
24/// }).unwrap();
25/// ```
26#[derive(Debug)]
27pub struct CsvDataset {
28    samples: Vec<Sample>,
29    feature_shape: Vec<usize>,
30    target_shape: Vec<usize>,
31}
32
/// Configuration for loading a CSV file.
///
/// Column indices are zero-based positions within a row after splitting on
/// [`CsvConfig::delimiter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvConfig {
    /// Whether the first non-blank row is a header (to be skipped).
    pub has_header: bool,
    /// Column indices to use as features.
    ///
    /// When empty, every column except the last one is used.
    pub feature_cols: Vec<usize>,
    /// Column indices to use as targets.
    ///
    /// When empty, only the last column is used.
    pub target_cols: Vec<usize>,
    /// Delimiter byte separating fields (default: `,`).
    pub delimiter: u8,
}
45
46impl Default for CsvConfig {
47    fn default() -> Self {
48        Self {
49            has_header: true,
50            feature_cols: Vec::new(),
51            target_cols: Vec::new(),
52            delimiter: b',',
53        }
54    }
55}
56
57impl CsvConfig {
58    pub fn has_header(mut self, h: bool) -> Self {
59        self.has_header = h;
60        self
61    }
62    pub fn feature_cols(mut self, cols: Vec<usize>) -> Self {
63        self.feature_cols = cols;
64        self
65    }
66    pub fn target_cols(mut self, cols: Vec<usize>) -> Self {
67        self.target_cols = cols;
68        self
69    }
70    pub fn delimiter(mut self, d: u8) -> Self {
71        self.delimiter = d;
72        self
73    }
74}
75
76impl CsvDataset {
77    /// Load a CSV file from disk.
78    pub fn load<P: AsRef<Path>>(path: P, config: CsvConfig) -> Result<Self, String> {
79        let content = fs::read_to_string(path.as_ref())
80            .map_err(|e| format!("CsvDataset: failed to read {:?}: {}", path.as_ref(), e))?;
81        Self::from_string(&content, config)
82    }
83
84    /// Parse CSV from an in-memory string.
85    pub fn from_string(content: &str, config: CsvConfig) -> Result<Self, String> {
86        let delim = config.delimiter as char;
87        let lines: Vec<&str> = content.lines().filter(|l| !l.trim().is_empty()).collect();
88
89        if lines.is_empty() {
90            return Err("CsvDataset: empty CSV".to_string());
91        }
92
93        let start = if config.has_header { 1 } else { 0 };
94        if start >= lines.len() {
95            return Err("CsvDataset: CSV has only a header, no data".to_string());
96        }
97
98        // Auto-detect columns if not specified
99        let first_row: Vec<&str> = lines[start].split(delim).collect();
100        let num_cols = first_row.len();
101
102        let feat_cols = if config.feature_cols.is_empty() {
103            // All columns except the last
104            (0..num_cols.saturating_sub(1)).collect::<Vec<_>>()
105        } else {
106            config.feature_cols
107        };
108
109        let tgt_cols = if config.target_cols.is_empty() {
110            // Last column only
111            vec![num_cols - 1]
112        } else {
113            config.target_cols
114        };
115
116        let mut samples = Vec::with_capacity(lines.len() - start);
117
118        for (line_no, &line) in lines[start..].iter().enumerate() {
119            let cols: Vec<&str> = line.split(delim).collect();
120            if cols.len() != num_cols {
121                return Err(format!(
122                    "CsvDataset: line {} has {} columns, expected {}",
123                    line_no + start + 1,
124                    cols.len(),
125                    num_cols
126                ));
127            }
128
129            let mut features = Vec::with_capacity(feat_cols.len());
130            for &c in &feat_cols {
131                let val: f64 = cols[c].trim().parse().map_err(|e| {
132                    format!(
133                        "CsvDataset: line {}, col {}: parse error: {}",
134                        line_no + start + 1,
135                        c,
136                        e
137                    )
138                })?;
139                features.push(val);
140            }
141
142            let mut target = Vec::with_capacity(tgt_cols.len());
143            for &c in &tgt_cols {
144                let val: f64 = cols[c].trim().parse().map_err(|e| {
145                    format!(
146                        "CsvDataset: line {}, col {}: parse error: {}",
147                        line_no + start + 1,
148                        c,
149                        e
150                    )
151                })?;
152                target.push(val);
153            }
154
155            samples.push(Sample {
156                features,
157                feature_shape: vec![feat_cols.len()],
158                target,
159                target_shape: vec![tgt_cols.len()],
160            });
161        }
162
163        let feature_shape = vec![feat_cols.len()];
164        let target_shape = vec![tgt_cols.len()];
165
166        Ok(Self {
167            samples,
168            feature_shape,
169            target_shape,
170        })
171    }
172}
173
174impl Dataset for CsvDataset {
175    fn len(&self) -> usize {
176        self.samples.len()
177    }
178
179    fn get(&self, index: usize) -> Sample {
180        self.samples[index].clone()
181    }
182
183    fn feature_shape(&self) -> &[usize] {
184        &self.feature_shape
185    }
186
187    fn target_shape(&self) -> &[usize] {
188        &self.target_shape
189    }
190
191    fn name(&self) -> &str {
192        "csv"
193    }
194}
195
// Tests
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Header row is skipped; columns are auto-detected (all-but-last / last).
    #[test]
    fn csv_with_header() {
        let csv = "a,b,c\n1.0,2.0,0.0\n3.0,4.0,1.0\n5.0,6.0,0.0\n";
        let config = CsvConfig::default();
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.len(), 3);
        assert_eq!(ds.feature_shape(), &[2]);
        assert_eq!(ds.target_shape(), &[1]);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(0).target, vec![0.0]);
        assert_eq!(ds.get(2).features, vec![5.0, 6.0]);
    }

    /// Without a header the first line is already data.
    #[test]
    fn csv_no_header() {
        let csv = "1.0,2.0,3.0\n4.0,5.0,6.0\n";
        let config = CsvConfig::default().has_header(false);
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(0).target, vec![3.0]);
    }

    /// Explicit feature/target column lists are honoured in the given order.
    #[test]
    fn csv_custom_columns() {
        let csv = "a,b,c,d\n1,2,3,4\n5,6,7,8\n";
        let config = CsvConfig::default()
            .feature_cols(vec![0, 2])
            .target_cols(vec![1, 3]);
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.feature_shape(), &[2]);
        assert_eq!(ds.target_shape(), &[2]);
        assert_eq!(ds.get(0).features, vec![1.0, 3.0]);
        assert_eq!(ds.get(0).target, vec![2.0, 4.0]);
    }

    /// A non-default delimiter byte is used for splitting.
    #[test]
    fn csv_tab_delimiter() {
        let csv = "a\tb\tc\n1.0\t2.0\t0.0\n3.0\t4.0\t1.0\n";
        let config = CsvConfig::default().delimiter(b'\t');
        let ds = CsvDataset::from_string(csv, config).unwrap();
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(1).target, vec![1.0]);
    }

    /// A non-numeric field yields an error rather than a panic.
    #[test]
    fn csv_parse_error() {
        let csv = "a,b,c\n1.0,hello,0.0\n";
        let config = CsvConfig::default();
        let result = CsvDataset::from_string(csv, config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("parse error"));
    }

    /// Empty input is rejected.
    #[test]
    fn csv_empty() {
        let csv = "";
        let result = CsvDataset::from_string(csv, CsvConfig::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty CSV"));
    }

    /// A header with no data rows is rejected.
    #[test]
    fn csv_header_only() {
        let result = CsvDataset::from_string("a,b,c\n", CsvConfig::default());
        assert!(result.unwrap_err().contains("only a header"));
    }

    /// A row whose column count differs from the first data row is rejected.
    #[test]
    fn csv_ragged_row() {
        let csv = "a,b,c\n1.0,2.0,0.0\n3.0,4.0\n";
        let err = CsvDataset::from_string(csv, CsvConfig::default()).unwrap_err();
        assert!(err.contains("has 2 columns, expected 3"));
    }
}