scirs2_datasets/loaders.rs

//! Data loading utilities

use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use csv::ReaderBuilder;
use scirs2_core::ndarray::{Array1, Array2};
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Load a dataset from a CSV file (legacy API)
#[allow(dead_code)]
pub fn load_csv_legacy<P: AsRef<Path>>(
    path: P,
    has_header: bool,
    target_column: Option<usize>,
) -> Result<Dataset> {
    let config = CsvConfig::new()
        .with_header(has_header)
        .with_target_column(target_column);
    load_csv(path, config)
}

/// Load a dataset from a JSON file
#[allow(dead_code)]
pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Dataset> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let reader = BufReader::new(file);

    let dataset: Dataset = serde_json::from_reader(reader)
        .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse JSON: {e}")))?;

    Ok(dataset)
}

/// Save a dataset to a JSON file
#[allow(dead_code)]
pub fn save_json<P: AsRef<Path>>(dataset: &Dataset, path: P) -> Result<()> {
    let file = File::create(path).map_err(DatasetsError::IoError)?;

    serde_json::to_writer_pretty(file, dataset)
        .map_err(|e| DatasetsError::SerdeError(format!("Failed to write JSON: {e}")))?;

    Ok(())
}
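
// A minimal usage sketch for the JSON helpers above, assuming `Dataset::new`
// takes a data matrix plus an optional target vector (as it does elsewhere in
// this module). The file name `dataset.json` and the 2x2 matrix are purely
// illustrative.
#[allow(dead_code)]
fn example_json_round_trip() -> Result<()> {
    let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
        .map_err(|e| DatasetsError::Other(format!("Failed to build example data: {e}")))?;
    let target: Option<Array1<f64>> = None;
    let dataset = Dataset::new(data, target);

    // Write the dataset out as pretty-printed JSON, then read it back.
    save_json(&dataset, "dataset.json")?;
    let _restored = load_json("dataset.json")?;
    Ok(())
}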

/// Load raw data from a file
#[allow(dead_code)]
pub fn load_raw<P: AsRef<Path>>(path: P) -> Result<Vec<u8>> {
    let mut file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut buffer = Vec::new();

    file.read_to_end(&mut buffer)
        .map_err(DatasetsError::IoError)?;

    Ok(buffer)
}

/// Configuration for CSV loading operations
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Whether the CSV has a header row
    pub has_header: bool,
    /// Index of the target column (if any)
    pub target_column: Option<usize>,
    /// Delimiter character
    pub delimiter: u8,
    /// Quote character
    pub quote: u8,
    /// Whether to use double quotes
    pub double_quote: bool,
    /// Escape character
    pub escape: Option<u8>,
    /// Flexible parsing (ignore inconsistent columns)
    pub flexible: bool,
}

impl Default for CsvConfig {
    fn default() -> Self {
        Self {
            has_header: true,
            target_column: None,
            delimiter: b',',
            quote: b'"',
            double_quote: true,
            escape: None,
            flexible: false,
        }
    }
}

impl CsvConfig {
    /// Create a new CSV configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set whether the CSV has headers
    pub fn with_header(mut self, has_header: bool) -> Self {
        self.has_header = has_header;
        self
    }

    /// Set the target column index
    pub fn with_target_column(mut self, target_column: Option<usize>) -> Self {
        self.target_column = target_column;
        self
    }

    /// Set the delimiter character
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.delimiter = delimiter;
        self
    }

    /// Set flexible parsing mode
    pub fn with_flexible(mut self, flexible: bool) -> Self {
        self.flexible = flexible;
        self
    }
}
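
// A minimal usage sketch for the builder above: configure a semicolon-delimited,
// headerless CSV whose fourth column (index 3) is the target, then hand it to
// `load_csv`, defined later in this module. The file name is illustrative.
#[allow(dead_code)]
fn example_csv_config() -> Result<Dataset> {
    let config = CsvConfig::new()
        .with_header(false)
        .with_delimiter(b';')
        .with_target_column(Some(3))
        .with_flexible(false);
    load_csv("measurements.csv", config)
}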

/// Configuration for streaming dataset loading
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Size of each chunk (number of rows)
    pub chunk_size: usize,
    /// Whether to use parallel processing
    pub parallel: bool,
    /// Number of parallel threads (0 = auto-detect)
    pub num_threads: usize,
    /// Maximum memory usage in bytes (0 = unlimited)
    pub max_memory: usize,
    /// Whether to use memory mapping for large files
    pub use_mmap: bool,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            parallel: true,
            num_threads: 0, // Auto-detect
            max_memory: 0,  // Unlimited
            use_mmap: false,
        }
    }
}

impl StreamingConfig {
    /// Create a new streaming configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the chunk size
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Enable or disable parallel processing
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Set the number of threads
    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Set maximum memory usage
    pub fn with_max_memory(mut self, max_memory: usize) -> Self {
        self.max_memory = max_memory;
        self
    }

    /// Enable or disable memory mapping
    pub fn with_mmap(mut self, use_mmap: bool) -> Self {
        self.use_mmap = use_mmap;
        self
    }
}
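
// A minimal configuration sketch: a StreamingConfig tuned for a large file,
// processed in 10_000-row chunks. The memory cap and mmap flag are illustrative
// values; note that the loaders below currently consume only `chunk_size` and
// `parallel`.
#[allow(dead_code)]
fn example_streaming_config() -> StreamingConfig {
    StreamingConfig::new()
        .with_chunk_size(10_000)
        .with_parallel(true)
        .with_num_threads(4)
        .with_max_memory(512 * 1024 * 1024)
        .with_mmap(false)
}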

/// Iterator for streaming dataset chunks
pub struct DatasetChunkIterator {
    reader: csv::Reader<File>,
    chunk_size: usize,
    target_column: Option<usize>,
    featurenames: Option<Vec<String>>,
    n_features: usize,
    buffer: Vec<Vec<f64>>,
    finished: bool,
}

impl DatasetChunkIterator {
    /// Create a new chunk iterator
    pub fn new<P: AsRef<Path>>(path: P, csv_config: CsvConfig, chunk_size: usize) -> Result<Self> {
        let file = File::open(path).map_err(DatasetsError::IoError)?;
        let mut reader = ReaderBuilder::new()
            .has_headers(csv_config.has_header)
            .delimiter(csv_config.delimiter)
            .quote(csv_config.quote)
            .double_quote(csv_config.double_quote)
            .escape(csv_config.escape)
            .flexible(csv_config.flexible)
            .from_reader(file);

        // Read header if present
        let featurenames = if csv_config.has_header {
            let headers = reader.headers().map_err(|e| {
                DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {e}"))
            })?;
            Some(
                headers
                    .iter()
                    .map(|s| s.to_string())
                    .collect::<Vec<String>>(),
            )
        } else {
            None
        };

        // Determine number of features
        let n_features = if let Some(ref names) = featurenames {
            if csv_config.target_column.is_some() {
                names.len() - 1
            } else {
                names.len()
            }
        } else {
            // Determined from the first data row instead
            0
        };

        Ok(Self {
            reader,
            chunk_size,
            target_column: csv_config.target_column,
            featurenames,
            n_features,
            buffer: Vec::new(),
            finished: false,
        })
    }

    /// Get feature names
    pub fn featurenames(&self) -> Option<&Vec<String>> {
        self.featurenames.as_ref()
    }

    /// Get number of features
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}

impl Iterator for DatasetChunkIterator {
    type Item = Result<Dataset>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }

        self.buffer.clear();

        // Read up to chunk_size rows
        for _ in 0..self.chunk_size {
            match self.reader.records().next() {
                Some(Ok(record)) => {
                    let values: Vec<f64> = match record
                        .iter()
                        .map(|s| s.parse::<f64>())
                        .collect::<std::result::Result<Vec<f64>, _>>()
                    {
                        Ok(vals) => vals,
                        Err(e) => {
                            return Some(Err(DatasetsError::InvalidFormat(format!(
                                "Failed to parse value: {e}"
                            ))))
                        }
                    };

                    if !values.is_empty() {
                        // Update n_features if not yet known (headerless files)
                        if self.n_features == 0 {
                            self.n_features = if self.target_column.is_some() {
                                values.len() - 1
                            } else {
                                values.len()
                            };
                        }
                        self.buffer.push(values);
                    }
                }
                Some(Err(e)) => {
                    return Some(Err(DatasetsError::InvalidFormat(format!(
                        "Failed to read CSV record: {e}"
                    ))))
                }
                None => {
                    self.finished = true;
                    break;
                }
            }
        }

        if self.buffer.is_empty() {
            return None;
        }

        // Create dataset from buffer
        let n_rows = self.buffer.len();
        let n_cols = self.buffer[0].len();

        let (data, target) = if let Some(idx) = self.target_column {
            if idx >= n_cols {
                return Some(Err(DatasetsError::InvalidFormat(format!(
                    "Target column index {idx} is out of bounds (max: {})",
                    n_cols - 1
                ))));
            }

            let mut data_array = Array2::zeros((n_rows, n_cols - 1));
            let mut target_array = Array1::zeros(n_rows);

            for (i, row) in self.buffer.iter().enumerate() {
                let mut data_col = 0;
                for (j, &val) in row.iter().enumerate() {
                    if j == idx {
                        target_array[i] = val;
                    } else {
                        data_array[[i, data_col]] = val;
                        data_col += 1;
                    }
                }
            }

            (data_array, Some(target_array))
        } else {
            let mut data_array = Array2::zeros((n_rows, n_cols));

            for (i, row) in self.buffer.iter().enumerate() {
                for (j, &val) in row.iter().enumerate() {
                    data_array[[i, j]] = val;
                }
            }

            (data_array, None)
        };

        let mut dataset = Dataset::new(data, target);

        // Set feature names (excluding target column)
        if let Some(ref names) = self.featurenames {
            let featurenames = if let Some(target_idx) = self.target_column {
                names
                    .iter()
                    .enumerate()
                    .filter_map(|(i, name)| {
                        if i != target_idx {
                            Some(name.clone())
                        } else {
                            None
                        }
                    })
                    .collect()
            } else {
                names.clone()
            };
            dataset = dataset.with_featurenames(featurenames);
        }

        Some(Ok(dataset))
    }
}

/// Load a CSV file using streaming with configurable chunking
#[allow(dead_code)]
pub fn load_csv_streaming<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    streaming_config: StreamingConfig,
) -> Result<DatasetChunkIterator> {
    DatasetChunkIterator::new(path, csv_config, streaming_config.chunk_size)
}
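
// A minimal streaming sketch: iterate over a large CSV in 5_000-row chunks and
// count how many chunks were produced. Each iterator item is a `Result<Dataset>`
// covering one chunk; the file name is illustrative.
#[allow(dead_code)]
fn example_stream_chunks() -> Result<usize> {
    let csv_config = CsvConfig::new()
        .with_header(true)
        .with_target_column(Some(0));
    let streaming_config = StreamingConfig::new().with_chunk_size(5_000);

    let mut chunks = 0;
    for chunk in load_csv_streaming("big_table.csv", csv_config, streaming_config)? {
        let _dataset = chunk?;
        chunks += 1;
    }
    Ok(chunks)
}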

/// Load a large CSV file by processing it in chunks
/// (the chunked path currently runs sequentially; see `load_csv_parallel_chunks`)
#[allow(dead_code)]
pub fn load_csv_parallel<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    streaming_config: StreamingConfig,
) -> Result<Dataset> {
    // First pass: determine dataset dimensions
    let file = File::open(&path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(csv_config.has_header)
        .delimiter(csv_config.delimiter)
        .from_reader(file);

    let featurenames = if csv_config.has_header {
        let headers = reader.headers().map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {e}"))
        })?;
        Some(
            headers
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
        )
    } else {
        None
    };

    // Count rows and determine column count
    let mut row_count = 0;
    let mut col_count = 0;

    for result in reader.records() {
        let record = result
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}")))?;

        if col_count == 0 {
            col_count = record.len();
        }
        row_count += 1;
    }

    if row_count == 0 {
        return Err(DatasetsError::InvalidFormat(
            "CSV file is empty".to_string(),
        ));
    }

    // Validate the target column before allocating output arrays
    if let Some(idx) = csv_config.target_column {
        if idx >= col_count {
            return Err(DatasetsError::InvalidFormat(format!(
                "Target column index {idx} is out of bounds (max: {})",
                col_count - 1
            )));
        }
    }

    // Determine final dimensions
    let data_cols = if csv_config.target_column.is_some() {
        col_count - 1
    } else {
        col_count
    };

    // Create output arrays
    let data = Arc::new(Mutex::new(Array2::zeros((row_count, data_cols))));
    let target = if csv_config.target_column.is_some() {
        Some(Arc::new(Mutex::new(Array1::zeros(row_count))))
    } else {
        None
    };

    // Second pass: process the file in chunks
    if streaming_config.parallel && row_count > streaming_config.chunk_size {
        load_csv_parallel_chunks(
            &path,
            csv_config.clone(),
            streaming_config,
            data.clone(),
            target.clone(),
            row_count,
        )?;
    } else {
        load_csv_sequential(&path, csv_config.clone(), data.clone(), target.clone())?;
    }

    // Extract final arrays
    let final_data = Arc::try_unwrap(data)
        .map_err(|_| DatasetsError::Other("Failed to unwrap data array".to_string()))?
        .into_inner()
        .map_err(|_| DatasetsError::Other("Failed to acquire data lock".to_string()))?;

    let final_target = if let Some(target_arc) = target {
        Some(
            Arc::try_unwrap(target_arc)
                .map_err(|_| DatasetsError::Other("Failed to unwrap target array".to_string()))?
                .into_inner()
                .map_err(|_| DatasetsError::Other("Failed to acquire target lock".to_string()))?,
        )
    } else {
        None
    };

    let mut dataset = Dataset::new(final_data, final_target);

    // Set feature names (excluding the target column)
    if let Some(names) = featurenames {
        let featurenames = if let Some(target_idx) = csv_config.target_column {
            names
                .iter()
                .enumerate()
                .filter_map(|(i, name)| {
                    if i != target_idx {
                        Some(name.clone())
                    } else {
                        None
                    }
                })
                .collect()
        } else {
            names
        };
        dataset = dataset.with_featurenames(featurenames);
    }

    Ok(dataset)
}
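
// A minimal usage sketch for the two-pass loader above. The file name and
// target column are illustrative; with these settings the chunked branch is
// taken once the row count exceeds `chunk_size`.
#[allow(dead_code)]
fn example_load_parallel() -> Result<Dataset> {
    let csv_config = CsvConfig::new()
        .with_header(true)
        .with_target_column(Some(4));
    let streaming_config = StreamingConfig::new()
        .with_chunk_size(2_000)
        .with_parallel(true);
    load_csv_parallel("sensors.csv", csv_config, streaming_config)
}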

/// Load CSV in chunks (parallel execution is currently disabled; chunks run sequentially)
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
fn load_csv_parallel_chunks<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    streaming_config: StreamingConfig,
    data: Arc<Mutex<Array2<f64>>>,
    target: Option<Arc<Mutex<Array1<f64>>>>,
    total_rows: usize,
) -> Result<()> {
    let chunk_size = streaming_config.chunk_size;
    let num_chunks = total_rows.div_ceil(chunk_size);

    // Process chunks sequentially for now; a failed chunk aborts the whole load
    for chunk_idx in 0..num_chunks {
        let start_row = chunk_idx * chunk_size;
        let end_row = std::cmp::min(start_row + chunk_size, total_rows);

        process_csv_chunk(
            &path,
            &csv_config,
            start_row,
            end_row,
            data.clone(),
            target.clone(),
        )?;
    }

    Ok(())
}

/// Process a single CSV chunk
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
fn process_csv_chunk<P: AsRef<Path>>(
    path: P,
    csv_config: &CsvConfig,
    start_row: usize,
    end_row: usize,
    data: Arc<Mutex<Array2<f64>>>,
    target: Option<Arc<Mutex<Array1<f64>>>>,
) -> Result<()> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(csv_config.has_header)
        .delimiter(csv_config.delimiter)
        .from_reader(file);

    // Consume the header row if present; rows outside [start_row, end_row) are skipped below
    if csv_config.has_header {
        reader
            .headers()
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read headers: {e}")))?;
    }

    for (current_row, result) in reader.records().enumerate() {
        if current_row >= end_row {
            break;
        }

        if current_row >= start_row {
            let record = result.map_err(|e| {
                DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}"))
            })?;

            let values: Vec<f64> = record
                .iter()
                .map(|s| s.parse::<f64>())
                .collect::<std::result::Result<Vec<f64>, _>>()
                .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse value: {e}")))?;

            // Write to shared arrays
            {
                let mut data_lock = data.lock().unwrap();
                if let Some(target_idx) = csv_config.target_column {
                    let mut data_col = 0;
                    for (j, &val) in values.iter().enumerate() {
                        if j == target_idx {
                            if let Some(ref target_arc) = target {
                                let mut target_lock = target_arc.lock().unwrap();
                                target_lock[current_row] = val;
                            }
                        } else {
                            data_lock[[current_row, data_col]] = val;
                            data_col += 1;
                        }
                    }
                } else {
                    for (j, &val) in values.iter().enumerate() {
                        data_lock[[current_row, j]] = val;
                    }
                }
            }
        }
    }

    Ok(())
}

/// Load CSV sequentially (fallback)
#[allow(dead_code)]
fn load_csv_sequential<P: AsRef<Path>>(
    path: P,
    csv_config: CsvConfig,
    data: Arc<Mutex<Array2<f64>>>,
    target: Option<Arc<Mutex<Array1<f64>>>>,
) -> Result<()> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(csv_config.has_header)
        .delimiter(csv_config.delimiter)
        .from_reader(file);

    if csv_config.has_header {
        reader
            .headers()
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read headers: {e}")))?;
    }

    for (row_idx, result) in reader.records().enumerate() {
        let record = result
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}")))?;

        let values: Vec<f64> = record
            .iter()
            .map(|s| s.parse::<f64>())
            .collect::<std::result::Result<Vec<f64>, _>>()
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to parse value: {e}")))?;

        {
            let mut data_lock = data.lock().unwrap();
            if let Some(target_idx) = csv_config.target_column {
                let mut data_col = 0;
                for (j, &val) in values.iter().enumerate() {
                    if j == target_idx {
                        if let Some(ref target_arc) = target {
                            let mut target_lock = target_arc.lock().unwrap();
                            target_lock[row_idx] = val;
                        }
                    } else {
                        data_lock[[row_idx, data_col]] = val;
                        data_col += 1;
                    }
                }
            } else {
                for (j, &val) in values.iter().enumerate() {
                    data_lock[[row_idx, j]] = val;
                }
            }
        }
    }

    Ok(())
}

/// Enhanced CSV loader with improved configuration
#[allow(dead_code)]
pub fn load_csv<P: AsRef<Path>>(path: P, config: CsvConfig) -> Result<Dataset> {
    let file = File::open(path).map_err(DatasetsError::IoError)?;
    let mut reader = ReaderBuilder::new()
        .has_headers(config.has_header)
        .delimiter(config.delimiter)
        .quote(config.quote)
        .double_quote(config.double_quote)
        .escape(config.escape)
        .flexible(config.flexible)
        .from_reader(file);

    let mut records: Vec<Vec<f64>> = Vec::new();
    let mut header: Option<Vec<String>> = None;

    // Read header if needed
    if config.has_header {
        let headers = reader.headers().map_err(|e| {
            DatasetsError::InvalidFormat(format!("Failed to read CSV headers: {e}"))
        })?;
        header = Some(headers.iter().map(|s| s.to_string()).collect());
    }

    // Read rows
    for result in reader.records() {
        let record = result
            .map_err(|e| DatasetsError::InvalidFormat(format!("Failed to read CSV record: {e}")))?;

        let values: Vec<f64> = record
            .iter()
            .map(|s| {
                s.parse::<f64>().map_err(|_| {
                    DatasetsError::InvalidFormat(format!("Failed to parse value: {s}"))
                })
            })
            .collect::<Result<Vec<f64>>>()?;

        if !values.is_empty() {
            records.push(values);
        }
    }

    if records.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "CSV file is empty".to_string(),
        ));
    }

    // Create data array and target array if needed
    let n_rows = records.len();
    let n_cols = records[0].len();

    let (data, target, featurenames, _targetname) = if let Some(idx) = config.target_column {
        if idx >= n_cols {
            return Err(DatasetsError::InvalidFormat(format!(
                "Target column index {idx} is out of bounds (max: {})",
                n_cols - 1
            )));
        }

        let mut data_array = Array2::zeros((n_rows, n_cols - 1));
        let mut target_array = Array1::zeros(n_rows);

        for (i, row) in records.iter().enumerate() {
            let mut data_col = 0;
            for (j, &val) in row.iter().enumerate() {
                if j == idx {
                    target_array[i] = val;
                } else {
                    data_array[[i, data_col]] = val;
                    data_col += 1;
                }
            }
        }

        let featurenames = header.as_ref().map(|h| {
            let mut names = Vec::new();
            for (j, name) in h.iter().enumerate() {
                if j != idx {
                    names.push(name.clone());
                }
            }
            names
        });

        (
            data_array,
            Some(target_array),
            featurenames,
            header.as_ref().map(|h| h[idx].clone()),
        )
    } else {
        let mut data_array = Array2::zeros((n_rows, n_cols));

        for (i, row) in records.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                data_array[[i, j]] = val;
            }
        }

        (data_array, None, header, None)
    };

    let mut dataset = Dataset::new(data, target);

    if let Some(names) = featurenames {
        dataset = dataset.with_featurenames(names);
    }

    Ok(dataset)
}
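
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;

    // A minimal smoke-test sketch for `load_csv`: write a tiny three-column CSV
    // to the system temp directory, load it back with the last column as the
    // target, and only assert that loading succeeds. File name and contents are
    // illustrative.
    #[test]
    fn load_csv_smoke_test() {
        let path = std::env::temp_dir().join("scirs2_loaders_smoke.csv");
        {
            let mut file = File::create(&path).expect("create temp csv");
            writeln!(file, "a,b,target").expect("write header");
            writeln!(file, "1.0,2.0,0").expect("write row");
            writeln!(file, "3.0,4.0,1").expect("write row");
        }

        let config = CsvConfig::new()
            .with_header(true)
            .with_target_column(Some(2));
        let result = load_csv(&path, config);
        assert!(result.is_ok());

        let _ = std::fs::remove_file(&path);
    }
}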