xdl_dataframe/readers/
csv.rs1use crate::dataframe::DataFrame;
4use crate::error::{DataFrameError, DataFrameResult};
5use crate::series::Series;
6use csv::{ReaderBuilder, StringRecord};
7use indexmap::IndexMap;
8use std::path::Path;
9use xdl_core::XdlValue;
10
/// Options controlling how delimited text is parsed into a `DataFrame`.
///
/// Construct with [`Default::default`], `CsvReaderOptions::csv()` or
/// `CsvReaderOptions::tsv()` and refine with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsvReaderOptions {
    /// Byte that separates fields within a record (for example `b','` or `b'\t'`).
    pub delimiter: u8,
    /// Whether the first record holds column names rather than data.
    pub has_headers: bool,
    /// Number of leading data records to discard before rows are collected.
    pub skip_rows: usize,
    /// Upper bound on the number of data rows collected; `None` means no limit.
    pub max_rows: Option<usize>,
    /// Whether column values are converted to numeric types where possible
    /// instead of being kept as strings.
    pub infer_types: bool,
}
25
26impl Default for CsvReaderOptions {
27 fn default() -> Self {
28 Self {
29 delimiter: b',',
30 has_headers: true,
31 skip_rows: 0,
32 max_rows: None,
33 infer_types: true,
34 }
35 }
36}
37
38impl CsvReaderOptions {
39 pub fn csv() -> Self {
41 Self::default()
42 }
43
44 pub fn tsv() -> Self {
46 Self {
47 delimiter: b'\t',
48 ..Self::default()
49 }
50 }
51
52 pub fn with_delimiter(mut self, delimiter: u8) -> Self {
54 self.delimiter = delimiter;
55 self
56 }
57
58 pub fn with_headers(mut self, has_headers: bool) -> Self {
60 self.has_headers = has_headers;
61 self
62 }
63
64 pub fn with_skip_rows(mut self, skip_rows: usize) -> Self {
66 self.skip_rows = skip_rows;
67 self
68 }
69
70 pub fn with_max_rows(mut self, max_rows: usize) -> Self {
72 self.max_rows = Some(max_rows);
73 self
74 }
75
76 pub fn with_infer_types(mut self, infer_types: bool) -> Self {
78 self.infer_types = infer_types;
79 self
80 }
81}
82
83pub fn read_csv<P: AsRef<Path>>(path: P, options: CsvReaderOptions) -> DataFrameResult<DataFrame> {
85 let mut reader = ReaderBuilder::new()
86 .delimiter(options.delimiter)
87 .has_headers(options.has_headers)
88 .from_path(path)?;
89
90 let headers = if options.has_headers {
92 reader.headers()?.clone()
93 } else {
94 let first_record = reader
96 .records()
97 .next()
98 .ok_or_else(|| DataFrameError::ParseError("Empty CSV file".to_string()))??;
99 let num_cols = first_record.len();
100 let mut headers = StringRecord::new();
101 for i in 0..num_cols {
102 headers.push_field(&format!("col_{}", i));
103 }
104 headers
105 };
106
107 let num_cols = headers.len();
108
109 let mut column_data: Vec<Vec<String>> = vec![vec![]; num_cols];
111
112 let mut row_count = 0;
114 for (idx, result) in reader.records().enumerate() {
115 if idx < options.skip_rows {
116 continue;
117 }
118
119 if let Some(max) = options.max_rows {
120 if row_count >= max {
121 break;
122 }
123 }
124
125 let record = result?;
126 for (col_idx, field) in record.iter().enumerate() {
127 if col_idx < num_cols {
128 column_data[col_idx].push(field.to_string());
129 }
130 }
131 row_count += 1;
132 }
133
134 let mut columns = IndexMap::new();
136 for (col_idx, header) in headers.iter().enumerate() {
137 let col_values = if options.infer_types {
138 infer_and_convert_types(&column_data[col_idx])
139 } else {
140 column_data[col_idx]
141 .iter()
142 .map(|s| XdlValue::String(s.clone()))
143 .collect()
144 };
145
146 columns.insert(header.to_string(), Series::from_vec(col_values)?);
147 }
148
149 DataFrame::from_columns(columns)
150}
151
152pub fn read_csv_string(content: &str, options: CsvReaderOptions) -> DataFrameResult<DataFrame> {
154 let mut reader = ReaderBuilder::new()
155 .delimiter(options.delimiter)
156 .has_headers(options.has_headers)
157 .from_reader(content.as_bytes());
158
159 let headers = if options.has_headers {
160 reader.headers()?.clone()
161 } else {
162 let first_record = reader
163 .records()
164 .next()
165 .ok_or_else(|| DataFrameError::ParseError("Empty CSV string".to_string()))??;
166 let num_cols = first_record.len();
167 let mut headers = StringRecord::new();
168 for i in 0..num_cols {
169 headers.push_field(&format!("col_{}", i));
170 }
171 headers
172 };
173
174 let num_cols = headers.len();
175 let mut column_data: Vec<Vec<String>> = vec![vec![]; num_cols];
176
177 let mut row_count = 0;
178 for (idx, result) in reader.records().enumerate() {
179 if idx < options.skip_rows {
180 continue;
181 }
182
183 if let Some(max) = options.max_rows {
184 if row_count >= max {
185 break;
186 }
187 }
188
189 let record = result?;
190 for (col_idx, field) in record.iter().enumerate() {
191 if col_idx < num_cols {
192 column_data[col_idx].push(field.to_string());
193 }
194 }
195 row_count += 1;
196 }
197
198 let mut columns = IndexMap::new();
199 for (col_idx, header) in headers.iter().enumerate() {
200 let col_values = if options.infer_types {
201 infer_and_convert_types(&column_data[col_idx])
202 } else {
203 column_data[col_idx]
204 .iter()
205 .map(|s| XdlValue::String(s.clone()))
206 .collect()
207 };
208
209 columns.insert(header.to_string(), Series::from_vec(col_values)?);
210 }
211
212 DataFrame::from_columns(columns)
213}
214
215fn infer_and_convert_types(values: &[String]) -> Vec<XdlValue> {
217 if values.is_empty() {
218 return vec![];
219 }
220
221 let mut is_int = true;
223 let mut is_float = true;
224
225 for val in values.iter().take(100.min(values.len())) {
226 if val.is_empty() {
227 continue;
228 }
229
230 if is_int && val.parse::<i64>().is_err() {
231 is_int = false;
232 }
233
234 if is_float && val.parse::<f64>().is_err() {
235 is_float = false;
236 }
237
238 if !is_int && !is_float {
239 break;
240 }
241 }
242
243 values
245 .iter()
246 .map(|s| {
247 if s.is_empty() {
248 return XdlValue::Undefined;
249 }
250
251 if is_int {
252 if let Ok(i) = s.parse::<i32>() {
253 return XdlValue::Long(i);
254 } else if let Ok(i) = s.parse::<i64>() {
255 return XdlValue::Long64(i);
256 }
257 }
258
259 if is_float {
260 if let Ok(f) = s.parse::<f64>() {
261 return XdlValue::Double(f);
262 }
263 }
264
265 XdlValue::String(s.clone())
266 })
267 .collect()
268}
269
270pub fn write_csv<P: AsRef<Path>>(
272 dataframe: &DataFrame,
273 path: P,
274 delimiter: u8,
275) -> DataFrameResult<()> {
276 use std::fs::File;
277
278 let file = File::create(path)?;
279 let mut writer = csv::WriterBuilder::new()
280 .delimiter(delimiter)
281 .from_writer(file);
282
283 writer.write_record(dataframe.column_names())?;
285
286 for row_idx in 0..dataframe.nrows() {
288 let row = dataframe.row(row_idx)?;
289 let row_strings: Vec<String> = dataframe
290 .column_names()
291 .iter()
292 .map(|col_name| {
293 row.get(col_name)
294 .map(|v| v.to_string_repr())
295 .unwrap_or_default()
296 })
297 .collect();
298 writer.write_record(&row_strings)?;
299 }
300
301 writer.flush()?;
302 Ok(())
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Comma-separated text with a header row: one frame row per record and
    /// the header names preserved in order.
    #[test]
    fn test_read_csv_string() {
        let input = "name,age,city\nAlice,30,NYC\nBob,25,LA\nCarol,35,Chicago";
        let frame = read_csv_string(input, CsvReaderOptions::csv()).unwrap();

        assert_eq!(frame.nrows(), 3);
        assert_eq!(frame.ncols(), 3);
        assert_eq!(frame.column_names(), vec!["name", "age", "city"]);
    }

    /// Tab-separated text is split on tabs when the TSV preset is used.
    #[test]
    fn test_read_tsv_string() {
        let input = "name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA";
        let frame = read_csv_string(input, CsvReaderOptions::tsv()).unwrap();

        assert_eq!(frame.nrows(), 2);
        assert_eq!(frame.ncols(), 3);
    }

    /// Integer, float and text columns are all retrievable by name after a
    /// read with type inference enabled.
    #[test]
    fn test_type_inference() {
        let input = "int_col,float_col,str_col\n1,1.5,hello\n2,2.5,world\n3,3.5,test";
        let frame = read_csv_string(input, CsvReaderOptions::csv()).unwrap();

        for column in ["int_col", "float_col", "str_col"] {
            assert!(frame.column(column).is_ok());
        }
    }
}