scirs2_datasets/
loaders.rs

1//! Data loading utilities
2
3use crate::error::{DatasetsError, Result};
4use crate::utils::Dataset;
5use csv::ReaderBuilder;
6use ndarray::{Array1, Array2};
7use std::fs::File;
8use std::io::{BufReader, Read};
9use std::path::Path;
10
11/// Load a dataset from a CSV file
12pub fn load_csv<P: AsRef<Path>>(
13    path: P,
14    has_header: bool,
15    target_column: Option<usize>,
16) -> Result<Dataset> {
17    let file = File::open(path).map_err(DatasetsError::IoError)?;
18    let mut reader = ReaderBuilder::new()
19        .has_headers(has_header)
20        .from_reader(file);
21
22    let mut records: Vec<Vec<f64>> = Vec::new();
23    let mut header: Option<Vec<String>> = None;
24
25    // Read header if needed
26    if has_header {
27        let headers = reader.headers().map_err(|e| {
28            DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {}", e))
29        })?;
30        header = Some(headers.iter().map(|s| s.to_string()).collect());
31    }
32
33    // Read rows
34    for result in reader.records() {
35        let record = result.map_err(|e| {
36            DatasetsError::InvalidFormat(format!("Failed to read CSV record: {}", e))
37        })?;
38
39        let values: Vec<f64> = record
40            .iter()
41            .map(|s| {
42                s.parse::<f64>().map_err(|_| {
43                    DatasetsError::InvalidFormat(format!("Failed to parse value: {}", s))
44                })
45            })
46            .collect::<Result<Vec<f64>>>()?;
47
48        if !values.is_empty() {
49            records.push(values);
50        }
51    }
52
53    if records.is_empty() {
54        return Err(DatasetsError::InvalidFormat(
55            "CSV file is empty".to_string(),
56        ));
57    }
58
59    // Create data array and target array if needed
60    let n_rows = records.len();
61    let n_cols = records[0].len();
62
63    let (data, target, feature_names, _target_name) = if let Some(idx) = target_column {
64        if idx >= n_cols {
65            return Err(DatasetsError::InvalidFormat(format!(
66                "Target column index {} is out of bounds (max: {})",
67                idx,
68                n_cols - 1
69            )));
70        }
71
72        let mut data_array = Array2::zeros((n_rows, n_cols - 1));
73        let mut target_array = Array1::zeros(n_rows);
74
75        for (i, row) in records.iter().enumerate() {
76            let mut data_col = 0;
77            for (j, &val) in row.iter().enumerate() {
78                if j == idx {
79                    target_array[i] = val;
80                } else {
81                    data_array[[i, data_col]] = val;
82                    data_col += 1;
83                }
84            }
85        }
86
87        let feature_names = header.as_ref().map(|h| {
88            let mut names = Vec::new();
89            for (j, name) in h.iter().enumerate() {
90                if j != idx {
91                    names.push(name.clone());
92                }
93            }
94            names
95        });
96
97        (
98            data_array,
99            Some(target_array),
100            feature_names,
101            header.as_ref().map(|h| h[idx].clone()),
102        )
103    } else {
104        let mut data_array = Array2::zeros((n_rows, n_cols));
105
106        for (i, row) in records.iter().enumerate() {
107            for (j, &val) in row.iter().enumerate() {
108                data_array[[i, j]] = val;
109            }
110        }
111
112        (data_array, None, header, None)
113    };
114
115    let mut dataset = Dataset::new(data, target);
116
117    if let Some(names) = feature_names {
118        dataset = dataset.with_feature_names(names);
119    }
120
121    Ok(dataset)
122}
123
124/// Load a dataset from a JSON file
125pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Dataset> {
126    let file = File::open(path).map_err(DatasetsError::IoError)?;
127    let reader = BufReader::new(file);
128
129    let dataset: Dataset = serde_json::from_reader(reader)
130        .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse JSON: {}", e)))?;
131
132    Ok(dataset)
133}
134
135/// Save a dataset to a JSON file
136pub fn save_json<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
137    let file = File::create(path).map_err(DatasetsError::IoError)?;
138
139    serde_json::to_writer_pretty(file, dataset)
140        .map_err(|e| DatasetsError::SerdeError(format!("Failed to write JSON: {}", e)))?;
141
142    Ok(())
143}
144
145/// Load raw data from a file
146pub fn load_raw<P: AsRef<Path>>(path: P) -> Result<Vec<u8>> {
147    let mut file = File::open(path).map_err(DatasetsError::IoError)?;
148    let mut buffer = Vec::new();
149
150    file.read_to_end(&mut buffer)
151        .map_err(DatasetsError::IoError)?;
152
153    Ok(buffer)
154}