use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use csv::ReaderBuilder;
use scirs2_core::ndarray::{Array1, Array2};
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
use std::sync::{Arc, Mutex};

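/// Loads a CSV file using the legacy `(path, has_header, target_column)`
/// signature; it simply delegates to [`load_csv`] with a [`CsvConfig`].
///
/// # Examples
///
/// A minimal sketch; the path is illustrative and the import depends on
/// where this module is exposed:
///
/// ```ignore
/// // Treat the first row as a header and column 2 as the target.
/// let dataset = load_csv_legacy("data.csv", true, Some(2))?;
/// ```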
#[allow(dead_code)]
pub fn load_csv_legacy<P: AsRef<Path>>(
    path: P,
    has_header: bool,
    target_column: Option<usize>,
) -> Result<Dataset> {
    let config = CsvConfig::new()
        .with_header(has_header)
        .with_target_column(target_column);
    load_csv(path, config)
}

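/// Loads a [`Dataset`] that was serialized to JSON (e.g., by [`save_json`]).
///
/// # Examples
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let dataset = load_json("dataset.json")?;
/// ```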
#[allow(dead_code)]
pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Dataset> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let reader = BufReader::new(file);

    let dataset: Dataset = serde_json::from_reader(reader)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse JSON: {e}")))?;

    Ok(dataset)
}

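/// Serializes a [`Dataset`] to pretty-printed JSON at the given path.
///
/// # Examples
///
/// A round-trip sketch; the path is illustrative:
///
/// ```ignore
/// save_json(&dataset, "dataset.json")?;
/// let restored = load_json("dataset.json")?;
/// ```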
#[allow(dead_code)]
pub fn save_json<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
    let file = File::create(path).map_err(DatasetsError::IoError)?;

    serde_json::to_writer_pretty(file, dataset)
        .map_err(|e| DatasetsError::SerdeError(format!("Failed to write JSON: {e}")))?;

    Ok(())
}

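/// Reads an entire file into memory as raw bytes.
///
/// # Examples
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let bytes = load_raw("weights.bin")?;
/// println!("read {} bytes", bytes.len());
/// ```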
#[allow(dead_code)]
pub fn load_raw<P: AsRef<Path>>(path: P) -> Result<Vec<u8>> {
    let mut file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut buffer = Vec::new();

    file.read_to_end(&mut buffer)
        .map_err(DatasetsError::IoError)?;

    Ok(buffer)
}

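/// Configuration for CSV parsing, built with a fluent interface.
///
/// # Examples
///
/// A minimal builder sketch; the path is illustrative:
///
/// ```ignore
/// let config = CsvConfig::new()
///     .with_header(true)
///     .with_delimiter(b';')
///     .with_target_column(Some(0));
/// let dataset = load_csv("data.csv", config)?;
/// ```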
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Whether the first row is a header with column names.
    pub has_header: bool,
    /// Zero-based index of the column to extract as the target, if any.
    pub target_column: Option<usize>,
    /// Field delimiter (defaults to `b','`).
    pub delimiter: u8,
    /// Quote character (defaults to `b'"'`).
    pub quote: u8,
    /// Whether doubled quotes inside a quoted field count as an escaped quote.
    pub double_quote: bool,
    /// Optional escape character; `None` disables escaping.
    pub escape: Option<u8>,
    /// Whether records may have varying numbers of fields.
    pub flexible: bool,
}

impl Default for CsvConfig {
    fn default() -> Self {
        Self {
            has_header: true,
            target_column: None,
            delimiter: b',',
            quote: b'"',
            double_quote: true,
            escape: None,
            flexible: false,
        }
    }
}

impl CsvConfig {
    /// Creates a configuration with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether the first row is treated as a header.
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.has_header = has_header;
        self
    }

    /// Sets the target column index, or `None` for no target.
    pub fn with_target_column(mut self, target_column: Option<usize>) -> Self {
        self.target_column = target_column;
        self
    }

    /// Sets the field delimiter.
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = delimiter;
        self
    }

    /// Sets whether records may have varying numbers of fields.
    pub fn with_flexible(mut self, flexible: bool) -> Self {
        self.flexible = flexible;
        self
    }
}

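/// Configuration for streaming and parallel dataset loading.
///
/// # Examples
///
/// A minimal builder sketch:
///
/// ```ignore
/// let streaming = StreamingConfig::new()
///     .with_chunk_size(10_000)
///     .with_parallel(true);
/// ```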
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Number of rows per chunk.
    pub chunk_size: usize,
    /// Whether to process chunks in parallel.
    pub parallel: bool,
    /// Number of worker threads (`0` = use the default).
    pub num_threads: usize,
    /// Maximum memory usage in bytes (`0` = no explicit limit).
    pub max_memory: usize,
    /// Whether to memory-map the input file instead of buffered reads.
    pub use_mmap: bool,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            parallel: true,
            num_threads: 0,
            max_memory: 0,
            use_mmap: false,
        }
    }
}

impl StreamingConfig {
    /// Creates a configuration with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of rows per chunk.
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Enables or disables parallel chunk processing.
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Sets the number of worker threads.
    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Sets the maximum memory budget in bytes.
    pub fn with_max_memory(mut self, max_memory: usize) -> Self {
        self.max_memory = max_memory;
        self
    }

    /// Enables or disables memory-mapped reads.
    pub fn with_mmap(mut self, use_mmap: bool) -> Self {
        self.use_mmap = use_mmap;
        self
    }
}

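/// Streams a CSV file as a sequence of [`Dataset`] chunks, so files larger
/// than memory can be processed incrementally.
///
/// # Examples
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let config = CsvConfig::new().with_header(true).with_target_column(Some(0));
/// let chunks = DatasetChunkIterator::new("large.csv", config, 1000)?;
/// for chunk in chunks {
///     let dataset = chunk?;
///     // process up to 1000 rows at a time
/// }
/// ```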
pub struct DatasetChunkIterator {
    reader: csv::Reader<File>,
    chunk_size: usize,
    target_column: Option<usize>,
    featurenames: Option<Vec<String>>,
    n_features: usize,
    buffer: Vec<Vec<f64>>,
    finished: bool,
}

impl DatasetChunkIterator {
    /// Opens `path` and prepares to yield chunks of `chunk_size` rows.
    pub fn new<P: AsRef<Path>>(path: P, csv_config: CsvConfig, chunk_size: usize) -> Result<Self> {
        let file = File::open(path).map_err(DatasetsError::IoError)?;
        let mut reader = ReaderBuilder::new()
            .has_headers(csv_config.has_header)
            .delimiter(csv_config.delimiter)
            .quote(csv_config.quote)
            .double_quote(csv_config.double_quote)
            .flexible(csv_config.flexible)
            .from_reader(file);

        // Read the header row up front so feature names are available
        // before the first chunk is produced.
        let featurenames = if csv_config.has_header {
            let headers = reader.headers().map_err(|e| {
                DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {e}"))
            })?;
            Some(
                headers
                    .iter()
                    .map(|s| s.to_string())
                    .collect::<Vec<String>>(),
            )
        } else {
            None
        };

        // With a header the feature count is known immediately; otherwise
        // it is inferred from the first record (see `next`).
        let n_features = if let Some(ref names) = featurenames {
            if csv_config.target_column.is_some() {
                names.len() - 1
            } else {
                names.len()
            }
        } else {
            0
        };

        Ok(Self {
            reader,
            chunk_size,
            target_column: csv_config.target_column,
            featurenames,
            n_features,
            buffer: Vec::new(),
            finished: false,
        })
    }

    /// Returns the column names from the header row, if one was present.
    pub fn featurenames(&self) -> Option<&Vec<String>> {
        self.featurenames.as_ref()
    }

    /// Returns the number of feature columns (excluding any target column).
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}

impl Iterator for DatasetChunkIterator {
    type Item = Result<Dataset>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }

        self.buffer.clear();

        // Read up to `chunk_size` records into the buffer.
        for _ in 0..self.chunk_size {
            match self.reader.records().next() {
                Some(Ok(record)) => {
                    let values: Vec<f64> = match record
                        .iter()
                        .map(|s| s.parse::<f64>())
                        .collect::<std::result::Result<Vec<f64>, _>>()
                    {
                        Ok(vals) => vals,
                        Err(e) => {
                            return Some(Err(DatasetsError::InvalidFormat(format!(
                                "Failed to parse value: {e}"
                            ))))
                        }
                    };

                    if !values.is_empty() {
                        // Without a header row, infer the feature count
                        // from the first record.
                        if self.n_features == 0 {
                            self.n_features = if self.target_column.is_some() {
                                values.len() - 1
                            } else {
                                values.len()
                            };
                        }
                        self.buffer.push(values);
                    }
                }
                Some(Err(e)) => {
                    return Some(Err(DatasetsError::InvalidFormat(format!(
                        "Failed to read CSV record: {e}"
                    ))))
                }
                None => {
                    self.finished = true;
                    break;
                }
            }
        }

        if self.buffer.is_empty() {
            return None;
        }

        let n_rows = self.buffer.len();
        let n_cols = self.buffer[0].len();

        // Split the buffered rows into a feature matrix and, if requested,
        // a target vector.
        let (data, target) = if let Some(idx) = self.target_column {
            if idx >= n_cols {
                return Some(Err(DatasetsError::InvalidFormat(format!(
                    "Target column index {idx} is out of bounds (max: {})",
                    n_cols - 1
                ))));
            }

            let mut data_array = Array2::zeros((n_rows, n_cols - 1));
            let mut target_array = Array1::zeros(n_rows);

            for (i, row) in self.buffer.iter().enumerate() {
                let mut data_col = 0;
                for (j, &val) in row.iter().enumerate() {
                    if j == idx {
                        target_array[i] = val;
                    } else {
                        data_array[[i, data_col]] = val;
                        data_col += 1;
                    }
                }
            }

            (data_array, Some(target_array))
        } else {
            let mut data_array = Array2::zeros((n_rows, n_cols));

            for (i, row) in self.buffer.iter().enumerate() {
                for (j, &val) in row.iter().enumerate() {
                    data_array[[i, j]] = val;
                }
            }

            (data_array, None)
        };

        let mut dataset = Dataset::new(data, target);

        // Attach feature names, dropping the name of the target column.
        if let Some(ref names) = self.featurenames {
            let featurenames = if let Some(target_idx) = self.target_column {
                names
                    .iter()
                    .enumerate()
                    .filter_map(|(i, name)| {
                        if i != target_idx {
                            Some(name.clone())
                        } else {
                            None
                        }
                    })
                    .collect()
            } else {
                names.clone()
            };
            dataset = dataset.with_featurenames(featurenames);
        }

        Some(Ok(dataset))
    }
}

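/// Loads a CSV file as an iterator over [`Dataset`] chunks, using
/// `streaming_config.chunk_size` rows per chunk.
///
/// # Examples
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let chunks = load_csv_streaming(
///     "large.csv",
///     CsvConfig::new().with_header(true),
///     StreamingConfig::new().with_chunk_size(5000),
/// )?;
/// for chunk in chunks {
///     let dataset = chunk?;
///     // incremental processing here
/// }
/// ```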
#[allow(dead_code)]
pub fn load_csv_streaming<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    streaming_config: StreamingConfig,
) -> Result<DatasetChunkIterator> {
    DatasetChunkIterator::new(path, csv_config, streaming_config.chunk_size)
}

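/// Loads an entire CSV file into a [`Dataset`], optionally processing it in
/// chunks. The file is scanned once to count rows so the arrays can be
/// allocated up front, then read again to fill them.
///
/// # Examples
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let dataset = load_csv_parallel(
///     "big.csv",
///     CsvConfig::new().with_header(true).with_target_column(Some(0)),
///     StreamingConfig::new().with_chunk_size(10_000),
/// )?;
/// ```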
#[allow(dead_code)]
pub fn load_csv_parallel<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    streaming_config: StreamingConfig,
) -> Result<Dataset> {
    let file = File::open(&path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(csv_config.has_header)
        .delimiter(csv_config.delimiter)
        .from_reader(file);

    let featurenames = if csv_config.has_header {
        let headers = reader.headers().map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {e}"))
        })?;
        Some(
            headers
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
        )
    } else {
        None
    };

    // First pass: count rows and columns so the arrays can be preallocated.
    let mut row_count = 0;
    let mut col_count = 0;

    for result in reader.records() {
        let record = result
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}")))?;

        if col_count == 0 {
            col_count = record.len();
        }
        row_count += 1;
    }

    if row_count == 0 {
        return Err(DatasetsError::InvalidFormat(
            "CSV file is empty".to_string(),
        ));
    }

    let data_cols = if csv_config.target_column.is_some() {
        col_count - 1
    } else {
        col_count
    };

    // Shared, lock-protected output arrays that chunk workers write into.
    let data = Arc::new(Mutex::new(Array2::zeros((row_count, data_cols))));
    let target = if csv_config.target_column.is_some() {
        Some(Arc::new(Mutex::new(Array1::zeros(row_count))))
    } else {
        None
    };

    // Second pass: fill the arrays, chunked or sequentially.
    if streaming_config.parallel && row_count > streaming_config.chunk_size {
        load_csv_parallel_chunks(
            &path,
            csv_config.clone(),
            streaming_config,
            data.clone(),
            target.clone(),
            row_count,
        )?;
    } else {
        load_csv_sequential(&path, csv_config.clone(), data.clone(), target.clone())?;
    }

    // All other Arc clones have been dropped by now, so unwrap the arrays.
    let final_data = Arc::try_unwrap(data)
        .map_err(|_| DatasetsError::Other("Failed to unwrap data array".to_string()))?
        .into_inner()
        .map_err(|_| DatasetsError::Other("Failed to acquire data lock".to_string()))?;

    let final_target = if let Some(target_arc) = target {
        Some(
            Arc::try_unwrap(target_arc)
                .map_err(|_| DatasetsError::Other("Failed to unwrap target array".to_string()))?
                .into_inner()
                .map_err(|_| DatasetsError::Other("Failed to acquire target lock".to_string()))?,
        )
    } else {
        None
    };

    let mut dataset = Dataset::new(final_data, final_target);

    // Attach feature names, dropping the name of the target column.
    if let Some(names) = featurenames {
        let featurenames = if let Some(target_idx) = csv_config.target_column {
            names
                .iter()
                .enumerate()
                .filter_map(|(i, name)| {
                    if i != target_idx {
                        Some(name.clone())
                    } else {
                        None
                    }
                })
                .collect()
        } else {
            names
        };
        dataset = dataset.with_featurenames(featurenames);
    }

    Ok(dataset)
}

/// Splits the row range into chunks and delegates each to [`process_csv_chunk`].
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
fn load_csv_parallel_chunks<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    streaming_config: StreamingConfig,
    data: Arc<Mutex<Array2<f64>>>,
    target: Option<Arc<Mutex<Array1<f64>>>>,
    total_rows: usize,
) -> Result<()> {
    let chunk_size = streaming_config.chunk_size;
    let num_chunks = total_rows.div_ceil(chunk_size);

    // Note: chunks are currently processed one after another; a failed
    // chunk is logged and skipped rather than aborting the whole load.
    for chunk_idx in 0..num_chunks {
        let start_row = chunk_idx * chunk_size;
        let end_row = std::cmp::min(start_row + chunk_size, total_rows);

        if let Err(e) = process_csv_chunk(
            &path,
            &csv_config,
            start_row,
            end_row,
            data.clone(),
            target.clone(),
        ) {
            eprintln!("Error processing chunk {chunk_idx}: {e}");
        }
    }

    Ok(())
}

/// Reads rows `start_row..end_row` from the file and writes them into the
/// shared output arrays.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
fn process_csv_chunk<P: AsRef<Path>>(
    path: P,
    csv_config: &CsvConfig,
    start_row: usize,
    end_row: usize,
    data: Arc<Mutex<Array2<f64>>>,
    target: Option<Arc<Mutex<Array1<f64>>>>,
) -> Result<()> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(csv_config.has_header)
        .delimiter(csv_config.delimiter)
        .from_reader(file);

    // Consume the header row so record indices line up with data rows.
    if csv_config.has_header {
        reader
            .headers()
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read headers: {e}")))?;
    }

    for (current_row, result) in reader.records().enumerate() {
        if current_row >= end_row {
            break;
        }

        if current_row >= start_row {
            let record = result.map_err(|e| {
                DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}"))
            })?;

            let values: Vec<f64> = record
                .iter()
                .map(|s| s.parse::<f64>())
                .collect::<std::result::Result<Vec<f64>, _>>()
                .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse value: {e}")))?;

            // Write this row into the shared arrays under the lock.
            {
                let mut data_lock = data.lock().unwrap();
                if let Some(target_idx) = csv_config.target_column {
                    let mut data_col = 0;
                    for (j, &val) in values.iter().enumerate() {
                        if j == target_idx {
                            if let Some(ref target_arc) = target {
                                let mut target_lock = target_arc.lock().unwrap();
                                target_lock[current_row] = val;
                            }
                        } else {
                            data_lock[[current_row, data_col]] = val;
                            data_col += 1;
                        }
                    }
                } else {
                    for (j, &val) in values.iter().enumerate() {
                        data_lock[[current_row, j]] = val;
                    }
                }
            }
        }
    }

    Ok(())
}

/// Sequential fallback that fills the shared arrays in a single pass.
#[allow(dead_code)]
fn load_csv_sequential<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    data: Arc<Mutex<Array2<f64>>>,
    target: Option<Arc<Mutex<Array1<f64>>>>,
) -> Result<()> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(csv_config.has_header)
        .delimiter(csv_config.delimiter)
        .from_reader(file);

    // Consume the header row so record indices line up with data rows.
    if csv_config.has_header {
        reader
            .headers()
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read headers: {e}")))?;
    }

    for (row_idx, result) in reader.records().enumerate() {
        let record = result
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}")))?;

        let values: Vec<f64> = record
            .iter()
            .map(|s| s.parse::<f64>())
            .collect::<std::result::Result<Vec<f64>, _>>()
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse value: {e}")))?;

        // Write this row into the shared arrays under the lock.
        {
            let mut data_lock = data.lock().unwrap();
            if let Some(target_idx) = csv_config.target_column {
                let mut data_col = 0;
                for (j, &val) in values.iter().enumerate() {
                    if j == target_idx {
                        if let Some(ref target_arc) = target {
                            let mut target_lock = target_arc.lock().unwrap();
                            target_lock[row_idx] = val;
                        }
                    } else {
                        data_lock[[row_idx, data_col]] = val;
                        data_col += 1;
                    }
                }
            } else {
                for (j, &val) in values.iter().enumerate() {
                    data_lock[[row_idx, j]] = val;
                }
            }
        }
    }

    Ok(())
}

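/// Loads a CSV file into a [`Dataset`] according to `config`. All values
/// must parse as `f64`; if `config.target_column` is set, that column is
/// split out as the target vector.
///
/// # Examples
///
/// A minimal sketch; the path is illustrative:
///
/// ```ignore
/// let config = CsvConfig::new()
///     .with_header(true)
///     .with_target_column(Some(4));
/// let dataset = load_csv("iris.csv", config)?;
/// ```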
#[allow(dead_code)]
pub fn load_csv<P: AsRef<Path>>(path: P, config: CsvConfig) -> Result<Dataset> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(config.has_header)
        .delimiter(config.delimiter)
        .quote(config.quote)
        .double_quote(config.double_quote)
        .flexible(config.flexible)
        .from_reader(file);

    let mut records: Vec<Vec<f64>> = Vec::new();
    let mut header: Option<Vec<String>> = None;

    if config.has_header {
        let headers = reader.headers().map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {e}"))
        })?;
        header = Some(headers.iter().map(|s| s.to_string()).collect());
    }

    // Parse every record into a row of f64 values.
    for result in reader.records() {
        let record = result
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}")))?;

        let values: Vec<f64> = record
            .iter()
            .map(|s| {
                s.parse::<f64>().map_err(|_| {
                    DatasetsError::InvalidFormat(format!("Failed to parse value: {s}"))
                })
            })
            .collect::<Result<Vec<f64>>>()?;

        if !values.is_empty() {
            records.push(values);
        }
    }

    if records.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "CSV file is empty".to_string(),
        ));
    }

    let n_rows = records.len();
    let n_cols = records[0].len();

    // Split out the target column if one was requested.
    let (data, target, featurenames, _targetname) = if let Some(idx) = config.target_column {
        if idx >= n_cols {
            return Err(DatasetsError::InvalidFormat(format!(
                "Target column index {idx} is out of bounds (max: {})",
                n_cols - 1
            )));
        }

        let mut data_array = Array2::zeros((n_rows, n_cols - 1));
        let mut target_array = Array1::zeros(n_rows);

        for (i, row) in records.iter().enumerate() {
            let mut data_col = 0;
            for (j, &val) in row.iter().enumerate() {
                if j == idx {
                    target_array[i] = val;
                } else {
                    data_array[[i, data_col]] = val;
                    data_col += 1;
                }
            }
        }

        let featurenames = header.as_ref().map(|h| {
            let mut names = Vec::new();
            for (j, name) in h.iter().enumerate() {
                if j != idx {
                    names.push(name.clone());
                }
            }
            names
        });

        (
            data_array,
            Some(target_array),
            featurenames,
            header.as_ref().map(|h| h[idx].clone()),
        )
    } else {
        let mut data_array = Array2::zeros((n_rows, n_cols));

        for (i, row) in records.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                data_array[[i, j]] = val;
            }
        }

        (data_array, None, header, None)
    };

    let mut dataset = Dataset::new(data, target);

    if let Some(names) = featurenames {
        dataset = dataset.with_featurenames(names);
    }

    Ok(dataset)
}