scirs2_datasets/streaming/
iterator.rs1use crate::error::DatasetsError;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12
/// Origin of the rows read by a [`NewStreamingIterator`].
///
/// `InMemory`, `Csv` and `Directory` can be loaded; building an iterator from
/// `Parquet` returns an error.
#[derive(Debug)]
#[non_exhaustive]
pub enum DataSource {
    /// Rows already held in memory; each inner `Vec` is one row of feature values
    /// (no labels).
    InMemory(Vec<Vec<f64>>),

    /// Path to a CSV file. The first line is skipped as a header; when a row has two
    /// or more columns, the last column is used as the label.
    Csv(String),

    /// Path to a Parquet file (not loadable by this iterator).
    Parquet(String),

    /// Path to a directory; every file directly inside it (subdirectories are skipped)
    /// is parsed as CSV, in sorted path order.
    Directory(String),
}
45
/// Tuning options for [`NewStreamingIterator`].
#[derive(Debug, Clone)]
pub struct StreamingIteratorConfig {
    /// Maximum number of rows per chunk; the final chunk may be shorter.
    pub chunk_size: usize,
    /// Requested prefetch depth, in chunks.
    ///
    /// NOTE(review): the iterator in this module never reads this field — confirm
    /// whether prefetching is implemented elsewhere.
    pub prefetch: usize,
    /// Shuffle rows within each chunk; chunk boundaries themselves stay fixed.
    pub shuffle: bool,
    /// Seed for the shuffling RNG.
    pub seed: u64,
}

impl Default for StreamingIteratorConfig {
    /// 1024-row chunks, prefetch depth 2, no shuffling, seed 42.
    fn default() -> Self {
        let (chunk_size, prefetch) = (1024, 2);
        Self {
            chunk_size,
            prefetch,
            shuffle: false,
            seed: 42,
        }
    }
}
74
75#[derive(Debug, Clone)]
81pub struct StreamingDataChunk {
82 pub features: Array2<f64>,
84 pub labels: Option<Vec<f64>>,
86 pub chunk_id: usize,
88}
89
90impl StreamingDataChunk {
91 pub fn n_rows(&self) -> usize {
93 self.features.nrows()
94 }
95
96 pub fn n_features(&self) -> usize {
98 self.features.ncols()
99 }
100}
101
/// `(row-major feature values, one optional label per row, features per row)`,
/// as produced by `RowStore::parse_csv_file`.
type CsvParseResult = (Vec<f64>, Vec<Option<f64>>, usize);

/// Fully materialised dataset backing a [`NewStreamingIterator`].
struct RowStore {
    /// Row-major feature values: row `i` occupies
    /// `data[i * n_features..(i + 1) * n_features]`.
    data: Vec<f64>,
    /// One entry per row; `None` when that row has no label.
    labels: Vec<Option<f64>>,
    /// Number of feature columns per row.
    n_features: usize,
    /// Number of rows stored.
    n_rows: usize,
}
123
124impl RowStore {
125 fn from_in_memory(rows: Vec<Vec<f64>>) -> Result<Self, DatasetsError> {
126 if rows.is_empty() {
127 return Ok(Self {
128 data: vec![],
129 labels: vec![],
130 n_features: 0,
131 n_rows: 0,
132 });
133 }
134 let n_features = rows[0].len();
135 if n_features == 0 {
136 return Err(DatasetsError::InvalidFormat(
137 "InMemory rows must have at least one element".to_string(),
138 ));
139 }
140 let n_rows = rows.len();
141 let mut data = Vec::with_capacity(n_rows * n_features);
142 for row in &rows {
143 if row.len() != n_features {
144 return Err(DatasetsError::InvalidFormat(format!(
145 "Inconsistent row length: expected {n_features}, got {}",
146 row.len()
147 )));
148 }
149 data.extend_from_slice(row);
150 }
151 Ok(Self {
152 data,
153 labels: vec![None; n_rows],
154 n_features,
155 n_rows,
156 })
157 }
158
159 fn parse_csv_file(path: &str) -> Result<CsvParseResult, DatasetsError> {
161 use std::fs::File;
162 use std::io::{BufRead, BufReader};
163
164 let file = File::open(path).map_err(DatasetsError::IoError)?;
165 let reader = BufReader::new(file);
166 let mut lines = reader.lines();
167
168 let _ = lines.next();
170
171 let mut all_data: Vec<f64> = Vec::new();
172 let mut all_labels: Vec<Option<f64>> = Vec::new();
173 let mut n_features: Option<usize> = None;
174
175 for line_res in lines {
176 let line = line_res.map_err(DatasetsError::IoError)?;
177 let line = line.trim();
178 if line.is_empty() {
179 continue;
180 }
181 let values: Vec<f64> = line
182 .split(',')
183 .map(|s| s.trim().parse::<f64>().unwrap_or(0.0))
184 .collect();
185 if values.is_empty() {
186 continue;
187 }
188 let features_here = values.len() - 1; if features_here == 0 {
190 match n_features {
192 None => n_features = Some(1),
193 Some(f) if f != 1 => {
194 return Err(DatasetsError::InvalidFormat(
195 "Inconsistent number of columns in CSV".to_string(),
196 ))
197 }
198 _ => {}
199 }
200 all_data.push(values[0]);
201 all_labels.push(None);
202 } else {
203 match n_features {
204 None => n_features = Some(features_here),
205 Some(f) if f != features_here => {
206 return Err(DatasetsError::InvalidFormat(
207 "Inconsistent number of columns in CSV".to_string(),
208 ))
209 }
210 _ => {}
211 }
212 all_data.extend_from_slice(&values[..features_here]);
213 all_labels.push(Some(*values.last().expect("non-empty")));
214 }
215 }
216
217 let nf = n_features.unwrap_or(0);
218 Ok((all_data, all_labels, nf))
219 }
220
221 fn from_csv(path: &str) -> Result<Self, DatasetsError> {
222 let (data, labels, n_features) = Self::parse_csv_file(path)?;
223 let n_rows = data.len().checked_div(n_features).unwrap_or(0);
224 Ok(Self {
225 data,
226 labels,
227 n_features,
228 n_rows,
229 })
230 }
231
232 fn from_directory(dir: &str) -> Result<Self, DatasetsError> {
233 use std::fs;
234 let mut all_data: Vec<f64> = Vec::new();
235 let mut all_labels: Vec<Option<f64>> = Vec::new();
236 let mut n_features: Option<usize> = None;
237
238 let entries = fs::read_dir(dir).map_err(DatasetsError::IoError)?;
239 let mut paths: Vec<_> = entries
240 .filter_map(|e| e.ok().map(|de| de.path()))
241 .filter(|p| p.is_file())
242 .collect();
243 paths.sort(); for path in paths {
246 let path_str = path.to_string_lossy();
247 let (data, labels, nf) = Self::parse_csv_file(&path_str)?;
248 if nf == 0 {
249 continue;
250 }
251 match n_features {
252 None => n_features = Some(nf),
253 Some(f) if f != nf => {
254 return Err(DatasetsError::InvalidFormat(format!(
255 "Directory file {} has {nf} features, expected {f}",
256 path.display()
257 )))
258 }
259 _ => {}
260 }
261 all_data.extend(data);
262 all_labels.extend(labels);
263 }
264
265 let nf = n_features.unwrap_or(0);
266 let n_rows = all_data.len().checked_div(nf).unwrap_or(0);
267 Ok(Self {
268 data: all_data,
269 labels: all_labels,
270 n_features: nf,
271 n_rows,
272 })
273 }
274
275 fn slice_chunk(
277 &self,
278 start: usize,
279 end: usize,
280 chunk_id: usize,
281 shuffle: bool,
282 rng: &mut StdRng,
283 ) -> Result<StreamingDataChunk, DatasetsError> {
284 let end = end.min(self.n_rows);
285 if start >= end {
286 let features = Array2::zeros((0, self.n_features.max(1)));
288 return Ok(StreamingDataChunk {
289 features,
290 labels: None,
291 chunk_id,
292 });
293 }
294 let count = end - start;
295 let nf = self.n_features;
296
297 let mut indices: Vec<usize> = (start..end).collect();
299 if shuffle {
300 for i in (1..count).rev() {
302 let j = rng.next_u64() as usize % (i + 1);
303 indices.swap(i, j);
304 }
305 }
306
307 let mut feat_flat: Vec<f64> = Vec::with_capacity(count * nf);
308 let mut labels_out: Vec<f64> = Vec::with_capacity(count);
309 let mut has_labels = false;
310
311 for &row_idx in &indices {
312 let base = row_idx * nf;
313 feat_flat.extend_from_slice(&self.data[base..base + nf]);
314 if let Some(lbl) = self.labels[row_idx] {
315 labels_out.push(lbl);
316 has_labels = true;
317 } else {
318 labels_out.push(0.0);
319 }
320 }
321
322 let features = Array2::from_shape_vec((count, nf), feat_flat)
323 .map_err(|e| DatasetsError::ComputationError(format!("Shape error: {e}")))?;
324
325 Ok(StreamingDataChunk {
326 features,
327 labels: if has_labels { Some(labels_out) } else { None },
328 chunk_id,
329 })
330 }
331}
332
333pub struct NewStreamingIterator {
342 store: RowStore,
343 config: StreamingIteratorConfig,
344 current_chunk: usize,
345 rng: StdRng,
346}
347
348impl NewStreamingIterator {
349 pub fn new(source: DataSource, config: StreamingIteratorConfig) -> Result<Self, DatasetsError> {
354 let store = match source {
355 DataSource::InMemory(rows) => RowStore::from_in_memory(rows)?,
356 DataSource::Csv(path) => RowStore::from_csv(&path)?,
357 DataSource::Directory(dir) => RowStore::from_directory(&dir)?,
358 DataSource::Parquet(_) => {
359 return Err(DatasetsError::Other(
360 "Parquet source requires the `formats` feature".to_string(),
361 ))
362 }
363 };
364
365 let rng = StdRng::seed_from_u64(config.seed);
366 Ok(Self {
367 store,
368 config,
369 current_chunk: 0,
370 rng,
371 })
372 }
373
374 pub fn n_chunks(&self) -> Option<usize> {
376 if self.config.chunk_size == 0 {
377 return Some(0);
378 }
379 Some(self.store.n_rows.div_ceil(self.config.chunk_size))
380 }
381
382 pub fn n_features(&self) -> usize {
384 self.store.n_features
385 }
386
387 pub fn n_rows(&self) -> usize {
389 self.store.n_rows
390 }
391
392 pub fn reset(&mut self) {
394 self.current_chunk = 0;
395 }
396}
397
398impl Iterator for NewStreamingIterator {
399 type Item = Result<StreamingDataChunk, DatasetsError>;
400
401 fn next(&mut self) -> Option<Self::Item> {
402 let chunk_size = self.config.chunk_size;
403 let start = self.current_chunk * chunk_size;
404 if start >= self.store.n_rows && self.store.n_rows > 0 {
405 return None;
406 }
407 if self.store.n_rows == 0 {
409 return None;
410 }
411 let end = (start + chunk_size).min(self.store.n_rows);
412 let chunk_id = self.current_chunk;
413 self.current_chunk += 1;
414
415 let result =
416 self.store
417 .slice_chunk(start, end, chunk_id, self.config.shuffle, &mut self.rng);
418 Some(result)
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` rows of `f` features where row `i`, column `j` holds `i * f + j`.
    fn make_rows(n: usize, f: usize) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| (0..f).map(|j| (i * f + j) as f64).collect())
            .collect()
    }

    /// Per-process temp path that is removed (file or directory) when dropped, even if
    /// the test panics. This keeps concurrent test runs from colliding or leaking files.
    struct TempPath(std::path::PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!(
                "scirs2_streaming_{}_{name}",
                std::process::id()
            )))
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            let _ = if self.0.is_dir() {
                std::fs::remove_dir_all(&self.0)
            } else {
                std::fs::remove_file(&self.0)
            };
        }
    }

    fn config(chunk_size: usize) -> StreamingIteratorConfig {
        StreamingIteratorConfig {
            chunk_size,
            ..Default::default()
        }
    }

    fn collect_chunks(
        source: DataSource,
        cfg: StreamingIteratorConfig,
    ) -> Vec<StreamingDataChunk> {
        NewStreamingIterator::new(source, cfg)
            .expect("construction failed")
            .map(|r| r.expect("chunk error"))
            .collect()
    }

    #[test]
    fn test_streaming_inmemory() {
        let iter = NewStreamingIterator::new(DataSource::InMemory(make_rows(100, 4)), config(30))
            .expect("construction failed");
        assert_eq!(iter.n_chunks(), Some(4));
        assert_eq!(iter.n_rows(), 100);
        assert_eq!(iter.n_features(), 4);
    }

    #[test]
    fn test_streaming_chunk_size() {
        let chunks = collect_chunks(DataSource::InMemory(make_rows(55, 3)), config(20));
        let sizes: Vec<usize> = chunks.iter().map(StreamingDataChunk::n_rows).collect();
        assert_eq!(sizes, vec![20, 20, 15]);
        for (expected_id, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.chunk_id, expected_id);
        }
    }

    #[test]
    fn test_streaming_empty_source() {
        let iter = NewStreamingIterator::new(
            DataSource::InMemory(vec![]),
            StreamingIteratorConfig::default(),
        )
        .expect("construction");
        let chunks: Vec<_> = iter.collect();
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_streaming_single_row() {
        let chunks = collect_chunks(DataSource::InMemory(vec![vec![1.0, 2.0, 3.0]]), config(10));
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].n_rows(), 1);
        assert_eq!(chunks[0].n_features(), 3);
    }

    #[test]
    fn test_streaming_exact_multiple() {
        let chunks = collect_chunks(DataSource::InMemory(make_rows(60, 2)), config(20));
        assert_eq!(chunks.len(), 3);
        for chunk in &chunks {
            assert_eq!(chunk.n_rows(), 20);
        }
    }

    #[test]
    fn test_streaming_reset() {
        let mut iter = NewStreamingIterator::new(DataSource::InMemory(make_rows(10, 2)), config(5))
            .expect("construction");
        let first_run: Vec<_> = iter.by_ref().map(|r| r.expect("err")).collect();
        iter.reset();
        let second_run: Vec<_> = iter.map(|r| r.expect("err")).collect();
        assert_eq!(first_run.len(), second_run.len());
        for (a, b) in first_run.iter().zip(&second_run) {
            assert_eq!(a.chunk_id, b.chunk_id);
            assert_eq!(a.features, b.features);
        }
    }

    #[test]
    fn test_streaming_shuffle_within_chunk() {
        let shuffled_cfg = StreamingIteratorConfig {
            chunk_size: 5,
            shuffle: true,
            seed: 7,
            ..Default::default()
        };
        let plain = collect_chunks(DataSource::InMemory(make_rows(10, 2)), config(5));
        let shuffled = collect_chunks(DataSource::InMemory(make_rows(10, 2)), shuffled_cfg.clone());
        let again = collect_chunks(DataSource::InMemory(make_rows(10, 2)), shuffled_cfg);
        assert_eq!(plain.len(), shuffled.len());
        for ((p, s), a) in plain.iter().zip(&shuffled).zip(&again) {
            // Same seed, same order.
            assert_eq!(s.features, a.features);
            // Rows stay intact: row i of make_rows(_, 2) is [2i, 2i + 1].
            for row in s.features.outer_iter() {
                assert_eq!(row[1], row[0] + 1.0);
            }
            // Shuffling only permutes rows inside the chunk.
            let mut got = s.features.column(0).to_vec();
            got.sort_by(f64::total_cmp);
            assert_eq!(got, p.features.column(0).to_vec());
        }
    }

    #[test]
    fn test_streaming_csv() {
        use std::io::Write;
        let tmp = TempPath::new("iterator_test.csv");
        {
            let mut f = std::io::BufWriter::new(std::fs::File::create(&tmp.0).expect("create"));
            writeln!(f, "a,b,c,label").expect("write header");
            for i in 0..20_usize {
                writeln!(f, "{},{},{},{}", i, i + 1, i + 2, i % 3).expect("write row");
            }
            f.flush().expect("flush");
        }
        let chunks = collect_chunks(
            DataSource::Csv(tmp.0.to_string_lossy().into_owned()),
            config(8),
        );
        assert_eq!(chunks.len(), 3);
        let total_rows: usize = chunks.iter().map(|c| c.n_rows()).sum();
        assert_eq!(total_rows, 20);
        assert_eq!(chunks[0].n_features(), 3);
        assert_eq!(chunks[0].features.row(0).to_vec(), vec![0.0, 1.0, 2.0]);
        assert_eq!(
            chunks[0].labels.as_deref(),
            Some(&[0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0][..])
        );
    }

    #[test]
    fn test_streaming_directory() {
        let dir = TempPath::new("iterator_test_dir");
        std::fs::create_dir_all(&dir.0).expect("create dir");
        // Written out of order to check that files are read in sorted path order.
        std::fs::write(dir.0.join("b.csv"), "x,y,label\n5,6,0\n").expect("write b");
        std::fs::write(dir.0.join("a.csv"), "x,y,label\n1,2,0\n3,4,1\n").expect("write a");
        let chunks = collect_chunks(
            DataSource::Directory(dir.0.to_string_lossy().into_owned()),
            config(10),
        );
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].n_rows(), 3);
        assert_eq!(chunks[0].n_features(), 2);
        assert_eq!(chunks[0].features.column(0).to_vec(), vec![1.0, 3.0, 5.0]);
        assert_eq!(chunks[0].labels.as_deref(), Some(&[0.0, 1.0, 0.0][..]));
    }

    #[test]
    fn test_streaming_parquet_unsupported() {
        let result = NewStreamingIterator::new(
            DataSource::Parquet("data.parquet".to_string()),
            StreamingIteratorConfig::default(),
        );
        assert!(result.is_err());
    }
}