Skip to main content

scirs2_datasets/streaming/
iterator.rs

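//! Streaming (chunked) iterator API, intended for datasets exceeding RAM.
//!
//! Provides a lazy, chunk-based iteration interface over multiple data sources
//! (in-memory vectors, CSV files, directories of files). Each iteration step
//! yields a [`StreamingDataChunk`] holding at most `chunk_size` rows, so
//! downstream processing only ever handles one bounded chunk at a time.
//!
//! Note: the current loader reads CSV and directory sources fully into memory
//! when the iterator is constructed; only the per-chunk output is bounded.
//! Truly out-of-core sources are the goal but are not yet implemented.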
7
8use crate::error::DatasetsError;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12
13// ---------------------------------------------------------------------------
14// DataSource
15// ---------------------------------------------------------------------------
16
/// Origin of data for a [`NewStreamingIterator`].
///
/// This enum is `#[non_exhaustive]` so that future sources (e.g. Arrow IPC,
/// HDF5) can be added without breaking existing match arms.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum DataSource {
    /// All rows are already in memory as a `Vec<Vec<f64>>`.
    ///
    /// Each inner `Vec<f64>` is one row; all rows must have the same, non-zero
    /// length. In-memory rows carry no labels.
    InMemory(Vec<Vec<f64>>),

    /// Path to a comma-separated file whose first line is a header and is
    /// skipped. Blank lines are ignored and quoted fields are not supported.
    ///
    /// With two or more columns, every column except the last is a feature and
    /// the last column is the label. A single-column file is read as one
    /// feature per row with no labels.
    Csv(String),

    /// Path to a Parquet file.
    ///
    /// Not supported yet: [`NewStreamingIterator::new`] always returns an
    /// error for this source.
    Parquet(String),

    /// Path to a directory.
    ///
    /// Every file directly inside it (subdirectories are not visited) is read
    /// as a CSV using the same conventions as [`DataSource::Csv`]. Files are
    /// processed in sorted path order and their rows concatenated; files with
    /// no data rows are skipped, and all other files must agree on the number
    /// of features.
    Directory(String),
}
45
46// ---------------------------------------------------------------------------
47// StreamingConfig
48// ---------------------------------------------------------------------------
49
/// Tunable settings for a [`NewStreamingIterator`].
///
/// Start from [`Default::default`] and override individual fields with
/// struct-update syntax.
#[derive(Debug, Clone)]
pub struct StreamingIteratorConfig {
    /// Maximum number of rows per yielded chunk. Default: 1024.
    pub chunk_size: usize,
    /// Reserved for a future asynchronous prefetch of this many chunks; the
    /// iterator does not read it yet. Default: 2.
    pub prefetch: usize,
    /// When `true`, rows inside each chunk are shuffled with Fisher-Yates;
    /// rows never move between chunks. Default: `false`.
    pub shuffle: bool,
    /// Seed for the shuffle RNG; it only affects output when `shuffle` is
    /// `true`. Default: 42.
    pub seed: u64,
}

impl Default for StreamingIteratorConfig {
    fn default() -> Self {
        let (chunk_size, prefetch) = (1024, 2);
        Self {
            seed: 42,
            shuffle: false,
            prefetch,
            chunk_size,
        }
    }
}
74
75// ---------------------------------------------------------------------------
76// StreamingDataChunk
77// ---------------------------------------------------------------------------
78
79/// A single chunk produced by [`NewStreamingIterator`].
80#[derive(Debug, Clone)]
81pub struct StreamingDataChunk {
82    /// Feature matrix, shape `[actual_rows, n_features]`.
83    pub features: Array2<f64>,
84    /// Optional label vector, length `actual_rows`.
85    pub labels: Option<Vec<f64>>,
86    /// Zero-based index of this chunk within the stream.
87    pub chunk_id: usize,
88}
89
90impl StreamingDataChunk {
91    /// Number of rows in this chunk.
92    pub fn n_rows(&self) -> usize {
93        self.features.nrows()
94    }
95
96    /// Number of features (columns) in this chunk.
97    pub fn n_features(&self) -> usize {
98        self.features.ncols()
99    }
100}
101
102// ---------------------------------------------------------------------------
103// Internal helper: in-memory row store
104// ---------------------------------------------------------------------------
105
/// Parsed CSV result: `(data, labels, n_features)`.
///
/// `data` holds the row-major feature values, `labels` has one entry per
/// parsed row, and `n_features` is the per-row feature count (0 when no data
/// rows were parsed).
type CsvParseResult = (Vec<f64>, Vec<Option<f64>>, usize);

/// Row-major buffer holding every row of a data source.
///
/// CSV and directory sources are read eagerly when the store is built, so the
/// total row count and feature dimensionality are known up front and every
/// source is iterated through the same code path.
struct RowStore {
    /// Flat row-major feature storage: row `i`, feature `j` is at
    /// `data[i * n_features + j]`.
    data: Vec<f64>,
    /// One optional label per row, indexed like the rows of `data`.
    labels: Vec<Option<f64>>,
    /// Number of features per row.
    n_features: usize,
    /// Total number of rows.
    n_rows: usize,
}
123
124impl RowStore {
125    fn from_in_memory(rows: Vec<Vec<f64>>) -> Result<Self, DatasetsError> {
126        if rows.is_empty() {
127            return Ok(Self {
128                data: vec![],
129                labels: vec![],
130                n_features: 0,
131                n_rows: 0,
132            });
133        }
134        let n_features = rows[0].len();
135        if n_features == 0 {
136            return Err(DatasetsError::InvalidFormat(
137                "InMemory rows must have at least one element".to_string(),
138            ));
139        }
140        let n_rows = rows.len();
141        let mut data = Vec::with_capacity(n_rows * n_features);
142        for row in &rows {
143            if row.len() != n_features {
144                return Err(DatasetsError::InvalidFormat(format!(
145                    "Inconsistent row length: expected {n_features}, got {}",
146                    row.len()
147                )));
148            }
149            data.extend_from_slice(row);
150        }
151        Ok(Self {
152            data,
153            labels: vec![None; n_rows],
154            n_features,
155            n_rows,
156        })
157    }
158
159    /// Read a single CSV file (header-skipped, last column = label).
160    fn parse_csv_file(path: &str) -> Result<CsvParseResult, DatasetsError> {
161        use std::fs::File;
162        use std::io::{BufRead, BufReader};
163
164        let file = File::open(path).map_err(DatasetsError::IoError)?;
165        let reader = BufReader::new(file);
166        let mut lines = reader.lines();
167
168        // skip header
169        let _ = lines.next();
170
171        let mut all_data: Vec<f64> = Vec::new();
172        let mut all_labels: Vec<Option<f64>> = Vec::new();
173        let mut n_features: Option<usize> = None;
174
175        for line_res in lines {
176            let line = line_res.map_err(DatasetsError::IoError)?;
177            let line = line.trim();
178            if line.is_empty() {
179                continue;
180            }
181            let values: Vec<f64> = line
182                .split(',')
183                .map(|s| s.trim().parse::<f64>().unwrap_or(0.0))
184                .collect();
185            if values.is_empty() {
186                continue;
187            }
188            let features_here = values.len() - 1; // last col = label
189            if features_here == 0 {
190                // Single column: treat as feature, no label
191                match n_features {
192                    None => n_features = Some(1),
193                    Some(f) if f != 1 => {
194                        return Err(DatasetsError::InvalidFormat(
195                            "Inconsistent number of columns in CSV".to_string(),
196                        ))
197                    }
198                    _ => {}
199                }
200                all_data.push(values[0]);
201                all_labels.push(None);
202            } else {
203                match n_features {
204                    None => n_features = Some(features_here),
205                    Some(f) if f != features_here => {
206                        return Err(DatasetsError::InvalidFormat(
207                            "Inconsistent number of columns in CSV".to_string(),
208                        ))
209                    }
210                    _ => {}
211                }
212                all_data.extend_from_slice(&values[..features_here]);
213                all_labels.push(Some(*values.last().expect("non-empty")));
214            }
215        }
216
217        let nf = n_features.unwrap_or(0);
218        Ok((all_data, all_labels, nf))
219    }
220
221    fn from_csv(path: &str) -> Result<Self, DatasetsError> {
222        let (data, labels, n_features) = Self::parse_csv_file(path)?;
223        let n_rows = data.len().checked_div(n_features).unwrap_or(0);
224        Ok(Self {
225            data,
226            labels,
227            n_features,
228            n_rows,
229        })
230    }
231
232    fn from_directory(dir: &str) -> Result<Self, DatasetsError> {
233        use std::fs;
234        let mut all_data: Vec<f64> = Vec::new();
235        let mut all_labels: Vec<Option<f64>> = Vec::new();
236        let mut n_features: Option<usize> = None;
237
238        let entries = fs::read_dir(dir).map_err(DatasetsError::IoError)?;
239        let mut paths: Vec<_> = entries
240            .filter_map(|e| e.ok().map(|de| de.path()))
241            .filter(|p| p.is_file())
242            .collect();
243        paths.sort(); // deterministic order
244
245        for path in paths {
246            let path_str = path.to_string_lossy();
247            let (data, labels, nf) = Self::parse_csv_file(&path_str)?;
248            if nf == 0 {
249                continue;
250            }
251            match n_features {
252                None => n_features = Some(nf),
253                Some(f) if f != nf => {
254                    return Err(DatasetsError::InvalidFormat(format!(
255                        "Directory file {} has {nf} features, expected {f}",
256                        path.display()
257                    )))
258                }
259                _ => {}
260            }
261            all_data.extend(data);
262            all_labels.extend(labels);
263        }
264
265        let nf = n_features.unwrap_or(0);
266        let n_rows = all_data.len().checked_div(nf).unwrap_or(0);
267        Ok(Self {
268            data: all_data,
269            labels: all_labels,
270            n_features: nf,
271            n_rows,
272        })
273    }
274
275    /// Extract a slice of rows `[start, end)` as a `StreamingDataChunk`.
276    fn slice_chunk(
277        &self,
278        start: usize,
279        end: usize,
280        chunk_id: usize,
281        shuffle: bool,
282        rng: &mut StdRng,
283    ) -> Result<StreamingDataChunk, DatasetsError> {
284        let end = end.min(self.n_rows);
285        if start >= end {
286            // Return an empty chunk
287            let features = Array2::zeros((0, self.n_features.max(1)));
288            return Ok(StreamingDataChunk {
289                features,
290                labels: None,
291                chunk_id,
292            });
293        }
294        let count = end - start;
295        let nf = self.n_features;
296
297        // Build index list (for shuffle support)
298        let mut indices: Vec<usize> = (start..end).collect();
299        if shuffle {
300            // Fisher-Yates
301            for i in (1..count).rev() {
302                let j = rng.next_u64() as usize % (i + 1);
303                indices.swap(i, j);
304            }
305        }
306
307        let mut feat_flat: Vec<f64> = Vec::with_capacity(count * nf);
308        let mut labels_out: Vec<f64> = Vec::with_capacity(count);
309        let mut has_labels = false;
310
311        for &row_idx in &indices {
312            let base = row_idx * nf;
313            feat_flat.extend_from_slice(&self.data[base..base + nf]);
314            if let Some(lbl) = self.labels[row_idx] {
315                labels_out.push(lbl);
316                has_labels = true;
317            } else {
318                labels_out.push(0.0);
319            }
320        }
321
322        let features = Array2::from_shape_vec((count, nf), feat_flat)
323            .map_err(|e| DatasetsError::ComputationError(format!("Shape error: {e}")))?;
324
325        Ok(StreamingDataChunk {
326            features,
327            labels: if has_labels { Some(labels_out) } else { None },
328            chunk_id,
329        })
330    }
331}
332
333// ---------------------------------------------------------------------------
334// NewStreamingIterator
335// ---------------------------------------------------------------------------
336
337/// Streaming iterator over a [`DataSource`], yielding [`StreamingDataChunk`]s.
338///
339/// Call [`NewStreamingIterator::new`] to construct, then use it as a standard
340/// Rust `Iterator<Item = Result<StreamingDataChunk, DatasetsError>>`.
341pub struct NewStreamingIterator {
342    store: RowStore,
343    config: StreamingIteratorConfig,
344    current_chunk: usize,
345    rng: StdRng,
346}
347
348impl NewStreamingIterator {
349    /// Construct a streaming iterator from the given source and configuration.
350    ///
351    /// For `Csv` and `Directory` sources, the file(s) are read eagerly during
352    /// construction so that the total row count is known immediately.
353    pub fn new(source: DataSource, config: StreamingIteratorConfig) -> Result<Self, DatasetsError> {
354        let store = match source {
355            DataSource::InMemory(rows) => RowStore::from_in_memory(rows)?,
356            DataSource::Csv(path) => RowStore::from_csv(&path)?,
357            DataSource::Directory(dir) => RowStore::from_directory(&dir)?,
358            DataSource::Parquet(_) => {
359                return Err(DatasetsError::Other(
360                    "Parquet source requires the `formats` feature".to_string(),
361                ))
362            }
363        };
364
365        let rng = StdRng::seed_from_u64(config.seed);
366        Ok(Self {
367            store,
368            config,
369            current_chunk: 0,
370            rng,
371        })
372    }
373
374    /// Total number of chunks (known because source is fully loaded).
375    pub fn n_chunks(&self) -> Option<usize> {
376        if self.config.chunk_size == 0 {
377            return Some(0);
378        }
379        Some(self.store.n_rows.div_ceil(self.config.chunk_size))
380    }
381
382    /// Number of features per row.
383    pub fn n_features(&self) -> usize {
384        self.store.n_features
385    }
386
387    /// Total number of rows across all chunks.
388    pub fn n_rows(&self) -> usize {
389        self.store.n_rows
390    }
391
392    /// Reset the iterator to the beginning of the stream.
393    pub fn reset(&mut self) {
394        self.current_chunk = 0;
395    }
396}
397
398impl Iterator for NewStreamingIterator {
399    type Item = Result<StreamingDataChunk, DatasetsError>;
400
401    fn next(&mut self) -> Option<Self::Item> {
402        let chunk_size = self.config.chunk_size;
403        let start = self.current_chunk * chunk_size;
404        if start >= self.store.n_rows && self.store.n_rows > 0 {
405            return None;
406        }
407        // Handle empty source: emit nothing
408        if self.store.n_rows == 0 {
409            return None;
410        }
411        let end = (start + chunk_size).min(self.store.n_rows);
412        let chunk_id = self.current_chunk;
413        self.current_chunk += 1;
414
415        let result =
416            self.store
417                .slice_chunk(start, end, chunk_id, self.config.shuffle, &mut self.rng);
418        Some(result)
419    }
420}
421
422// ---------------------------------------------------------------------------
423// Tests
424// ---------------------------------------------------------------------------
425
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// `n` rows of `f` features; row `i`, column `j` holds `i * f + j`.
    fn make_rows(n: usize, f: usize) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| (0..f).map(|j| (i * f + j) as f64).collect())
            .collect()
    }

    /// Config with the given chunk size and defaults elsewhere.
    fn chunked(chunk_size: usize) -> StreamingIteratorConfig {
        StreamingIteratorConfig {
            chunk_size,
            ..Default::default()
        }
    }

    /// Drain the iterator, failing the test on the first chunk error.
    fn collect_chunks(iter: NewStreamingIterator) -> Vec<StreamingDataChunk> {
        iter.map(|r| r.expect("chunk error")).collect()
    }

    /// Write a CSV file made of `header` followed by `rows`, one per line.
    fn write_csv(path: &Path, header: &str, rows: impl IntoIterator<Item = String>) {
        let mut text = format!("{header}\n");
        for row in rows {
            text.push_str(&row);
            text.push('\n');
        }
        std::fs::write(path, text).expect("write csv");
    }

    /// Per-process temp path so concurrent test runs do not collide.
    fn temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("scirs2_streaming_{}_{name}", std::process::id()))
    }

    #[test]
    fn test_streaming_inmemory() {
        let iter = NewStreamingIterator::new(DataSource::InMemory(make_rows(100, 4)), chunked(30))
            .expect("construction failed");
        // 100 rows / 30 = 4 chunks (3 full + 1 partial)
        assert_eq!(iter.n_chunks(), Some(4));
        assert_eq!(iter.n_features(), 4);
        assert_eq!(iter.n_rows(), 100);
    }

    #[test]
    fn test_streaming_chunk_size() {
        let iter = NewStreamingIterator::new(DataSource::InMemory(make_rows(55, 3)), chunked(20))
            .expect("construction failed");
        let chunks = collect_chunks(iter);
        // 55 / 20 = 3 chunks: 20, 20, 15
        let sizes: Vec<usize> = chunks.iter().map(StreamingDataChunk::n_rows).collect();
        assert_eq!(sizes, vec![20, 20, 15]);
        let ids: Vec<usize> = chunks.iter().map(|c| c.chunk_id).collect();
        assert_eq!(ids, vec![0, 1, 2]);
        // In-memory rows carry no labels.
        assert!(chunks.iter().all(|c| c.labels.is_none()));
    }

    #[test]
    fn test_streaming_preserves_order_without_shuffle() {
        let iter = NewStreamingIterator::new(DataSource::InMemory(make_rows(7, 2)), chunked(3))
            .expect("construction");
        let chunks = collect_chunks(iter);
        assert_eq!(chunks.len(), 3);
        // Row i holds [2i, 2i + 1]; chunk 1 starts at row 3, chunk 2 at row 6.
        assert_eq!(chunks[1].features[[0, 0]], 6.0);
        assert_eq!(chunks[1].features[[0, 1]], 7.0);
        assert_eq!(chunks[2].features[[0, 0]], 12.0);
    }

    #[test]
    fn test_streaming_empty_source() {
        let iter = NewStreamingIterator::new(
            DataSource::InMemory(vec![]),
            StreamingIteratorConfig::default(),
        )
        .expect("construction");
        assert_eq!(iter.n_chunks(), Some(0));
        let chunks: Vec<_> = iter.collect();
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_streaming_single_row() {
        let iter = NewStreamingIterator::new(
            DataSource::InMemory(vec![vec![1.0, 2.0, 3.0]]),
            chunked(10),
        )
        .expect("construction");
        let chunks = collect_chunks(iter);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].n_rows(), 1);
        assert_eq!(chunks[0].n_features(), 3);
    }

    #[test]
    fn test_streaming_exact_multiple() {
        // 60 rows, chunk_size=20 → exactly 3 full chunks
        let iter = NewStreamingIterator::new(DataSource::InMemory(make_rows(60, 2)), chunked(20))
            .expect("construction");
        let chunks = collect_chunks(iter);
        assert_eq!(chunks.len(), 3);
        for chunk in &chunks {
            assert_eq!(chunk.n_rows(), 20);
        }
    }

    #[test]
    fn test_streaming_reset() {
        let mut iter =
            NewStreamingIterator::new(DataSource::InMemory(make_rows(10, 2)), chunked(5))
                .expect("construction");
        let first_run: Vec<_> = iter.by_ref().map(|r| r.expect("err")).collect();
        iter.reset();
        let second_run = collect_chunks(iter);
        assert_eq!(first_run.len(), second_run.len());
        // Without shuffling, a second pass reproduces the first exactly.
        for (a, b) in first_run.iter().zip(&second_run) {
            assert_eq!(a.chunk_id, b.chunk_id);
            assert_eq!(a.features, b.features);
        }
    }

    #[test]
    fn test_streaming_shuffle_permutes_within_chunk() {
        let config = StreamingIteratorConfig {
            chunk_size: 5,
            shuffle: true,
            seed: 7,
            ..Default::default()
        };
        let chunks = collect_chunks(
            NewStreamingIterator::new(DataSource::InMemory(make_rows(10, 2)), config.clone())
                .expect("construction"),
        );
        assert_eq!(chunks.len(), 2);
        for (k, chunk) in chunks.iter().enumerate() {
            // Column 0 of row i is 2i; recover which rows ended up here.
            let mut rows: Vec<usize> = chunk
                .features
                .column(0)
                .iter()
                .map(|&v| v as usize / 2)
                .collect();
            rows.sort_unstable();
            assert_eq!(rows, (k * 5..k * 5 + 5).collect::<Vec<_>>());
        }
        // Same seed, same shuffle.
        let again = collect_chunks(
            NewStreamingIterator::new(DataSource::InMemory(make_rows(10, 2)), config)
                .expect("construction"),
        );
        for (a, b) in chunks.iter().zip(&again) {
            assert_eq!(a.features, b.features);
        }
    }

    #[test]
    fn test_streaming_rejects_ragged_rows() {
        let rows = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(NewStreamingIterator::new(DataSource::InMemory(rows), chunked(4)).is_err());
    }

    #[test]
    fn test_streaming_parquet_unsupported() {
        let source = DataSource::Parquet("data.parquet".to_string());
        assert!(NewStreamingIterator::new(source, chunked(4)).is_err());
    }

    #[test]
    fn test_streaming_csv() {
        let path = temp_path("test.csv");
        write_csv(
            &path,
            "a,b,c,label",
            (0..20_usize).map(|i| format!("{},{},{},{}", i, i + 1, i + 2, i % 3)),
        );
        let result = NewStreamingIterator::new(
            DataSource::Csv(path.to_string_lossy().into_owned()),
            chunked(8),
        );
        // The source is read eagerly, so the file can be removed up front.
        let _ = std::fs::remove_file(&path);
        let iter = result.expect("construction");
        assert_eq!(iter.n_features(), 3);
        let chunks = collect_chunks(iter);
        // 20 rows / 8 = 3 chunks (8, 8, 4)
        assert_eq!(chunks.len(), 3);
        let total_rows: usize = chunks.iter().map(StreamingDataChunk::n_rows).sum();
        assert_eq!(total_rows, 20);
        let labels = chunks[0].labels.as_ref().expect("labels should be present");
        assert_eq!(&labels[..3], &[0.0, 1.0, 2.0]);
        // Row 1 is [1, 2, 3].
        assert_eq!(chunks[0].features[[1, 2]], 3.0);
    }

    #[test]
    fn test_streaming_directory() {
        let dir = temp_path("dir");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("create dir");
        write_csv(
            &dir.join("a.csv"),
            "x,y,label",
            (0..3_usize).map(|i| format!("{},{},{}", i, i + 1, i % 2)),
        );
        write_csv(
            &dir.join("b.csv"),
            "x,y,label",
            (10..13_usize).map(|i| format!("{},{},{}", i, i + 1, i % 2)),
        );
        let result = NewStreamingIterator::new(
            DataSource::Directory(dir.to_string_lossy().into_owned()),
            chunked(4),
        );
        let _ = std::fs::remove_dir_all(&dir);
        let iter = result.expect("construction");
        assert_eq!(iter.n_rows(), 6);
        assert_eq!(iter.n_features(), 2);
        let chunks = collect_chunks(iter);
        assert_eq!(chunks.len(), 2);
        // Files are concatenated in sorted order: a.csv rows, then b.csv rows.
        assert_eq!(chunks[0].features[[0, 0]], 0.0);
        assert_eq!(chunks[0].features[[3, 0]], 10.0);
        assert!(chunks[1].labels.is_some());
    }
}