//! Dataset loaders for `scirs2_datasets`: CSV, JSON, and raw binary files.
use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use csv::ReaderBuilder;
use ndarray::{Array1, Array2};
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
10
11pub fn load_csv<P: AsRef<Path>>(
13 path: P,
14 has_header: bool,
15 target_column: Option<usize>,
16) -> Result<Dataset> {
17 let file = File::open(path).map_err(DatasetsError::IoError)?;
18 let mut reader = ReaderBuilder::new()
19 .has_headers(has_header)
20 .from_reader(file);
21
22 let mut records: Vec<Vec<f64>> = Vec::new();
23 let mut header: Option<Vec<String>> = None;
24
25 if has_header {
27 let headers = reader.headers().map_err(|e| {
28 DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {}", e))
29 })?;
30 header = Some(headers.iter().map(|s| s.to_string()).collect());
31 }
32
33 for result in reader.records() {
35 let record = result.map_err(|e| {
36 DatasetsError::InvalidFormat(format!("Failed to read CSV record: {}", e))
37 })?;
38
39 let values: Vec<f64> = record
40 .iter()
41 .map(|s| {
42 s.parse::<f64>().map_err(|_| {
43 DatasetsError::InvalidFormat(format!("Failed to parse value: {}", s))
44 })
45 })
46 .collect::<Result<Vec<f64>>>()?;
47
48 if !values.is_empty() {
49 records.push(values);
50 }
51 }
52
53 if records.is_empty() {
54 return Err(DatasetsError::InvalidFormat(
55 "CSV file is empty".to_string(),
56 ));
57 }
58
59 let n_rows = records.len();
61 let n_cols = records[0].len();
62
63 let (data, target, feature_names, _target_name) = if let Some(idx) = target_column {
64 if idx >= n_cols {
65 return Err(DatasetsError::InvalidFormat(format!(
66 "Target column index {} is out of bounds (max: {})",
67 idx,
68 n_cols - 1
69 )));
70 }
71
72 let mut data_array = Array2::zeros((n_rows, n_cols - 1));
73 let mut target_array = Array1::zeros(n_rows);
74
75 for (i, row) in records.iter().enumerate() {
76 let mut data_col = 0;
77 for (j, &val) in row.iter().enumerate() {
78 if j == idx {
79 target_array[i] = val;
80 } else {
81 data_array[[i, data_col]] = val;
82 data_col += 1;
83 }
84 }
85 }
86
87 let feature_names = header.as_ref().map(|h| {
88 let mut names = Vec::new();
89 for (j, name) in h.iter().enumerate() {
90 if j != idx {
91 names.push(name.clone());
92 }
93 }
94 names
95 });
96
97 (
98 data_array,
99 Some(target_array),
100 feature_names,
101 header.as_ref().map(|h| h[idx].clone()),
102 )
103 } else {
104 let mut data_array = Array2::zeros((n_rows, n_cols));
105
106 for (i, row) in records.iter().enumerate() {
107 for (j, &val) in row.iter().enumerate() {
108 data_array[[i, j]] = val;
109 }
110 }
111
112 (data_array, None, header, None)
113 };
114
115 let mut dataset = Dataset::new(data, target);
116
117 if let Some(names) = feature_names {
118 dataset = dataset.with_feature_names(names);
119 }
120
121 Ok(dataset)
122}
123
124pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Dataset> {
126 let file = File::open(path).map_err(DatasetsError::IoError)?;
127 let reader = BufReader::new(file);
128
129 let dataset: Dataset = serde_json::from_reader(reader)
130 .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse JSON: {}", e)))?;
131
132 Ok(dataset)
133}
134
135pub fn save_json<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
137 let file = File::create(path).map_err(DatasetsError::IoError)?;
138
139 serde_json::to_writer_pretty(file, dataset)
140 .map_err(|e| DatasetsError::SerdeError(format!("Failed to write JSON: {}", e)))?;
141
142 Ok(())
143}
144
145pub fn load_raw<P: AsRef<Path>>(path: P) -> Result<Vec<u8>> {
147 let mut file = File::open(path).map_err(DatasetsError::IoError)?;
148 let mut buffer = Vec::new();
149
150 file.read_to_end(&mut buffer)
151 .map_err(DatasetsError::IoError)?;
152
153 Ok(buffer)
154}