Skip to main content

scirs2_datasets/streaming/
dataloader.rs

1//! DataLoader-style batching API for neural network training.
2//!
3//! Provides a `DataLoader` struct that wraps an in-memory dataset and yields
4//! mini-batches according to configurable sampling strategies. Epoch-level
5//! shuffling, stratified sampling, and weighted-random sampling are all
6//! supported without external dependencies.
7
8use crate::error::DatasetsError;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12
13// ---------------------------------------------------------------------------
14// SamplingStrategy
15// ---------------------------------------------------------------------------
16
/// How rows are ordered / selected when building batches.
///
/// `#[non_exhaustive]` allows future strategies to be added without a
/// breaking change.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Rows are yielded in their natural (insertion) order.
    Sequential,

    /// Rows are globally shuffled at the start of each epoch.
    RandomShuffle,

    /// Stratified sampling: rows are grouped by class label, shuffled within
    /// each class, then interleaved round-robin (one row from each class in
    /// turn).
    ///
    /// With balanced classes and a `batch_size` that is a multiple of the
    /// number of classes, every batch holds the same number of rows from each
    /// class. With imbalanced classes the smaller classes run out first, so
    /// later batches contain only the remaining (larger) classes.
    ///
    /// The inner `Vec<usize>` maps each dataset row index to its class label
    /// (integer encoded): entry `i` is the label of row `i`.
    Stratified(Vec<usize>),

    /// Weighted random sampling without replacement (within an epoch).
    ///
    /// The inner `Vec<f64>` provides a non-negative weight for every row.
    /// Rows with higher weights are more likely to appear early in the epoch
    /// ordering. Negative weights are treated as zero, zero-weight rows are
    /// placed at the end of the epoch, and rows beyond the end of the vector
    /// receive a weight of `1.0`.
    WeightedRandom(Vec<f64>),
}
45
46// ---------------------------------------------------------------------------
47// DataLoaderConfig
48// ---------------------------------------------------------------------------
49
50/// Configuration for a [`DataLoader`].
51#[derive(Debug, Clone)]
52pub struct DataLoaderConfig {
53    /// Mini-batch size (default: 32).
54    pub batch_size: usize,
55    /// If `true`, epoch-level shuffling is performed (overrides `sampling`
56    /// strategy to `RandomShuffle` when set). Default: `true`.
57    pub shuffle: bool,
58    /// Drop the last (potentially smaller) batch if it has fewer than
59    /// `batch_size` rows. Default: `false`.
60    pub drop_last: bool,
61    /// RNG seed (default: 42).
62    pub seed: u64,
63    /// Row-selection strategy.  `shuffle = true` takes precedence over this
64    /// field by forcing `RandomShuffle` behaviour.
65    pub sampling: SamplingStrategy,
66}
67
68impl Default for DataLoaderConfig {
69    fn default() -> Self {
70        Self {
71            batch_size: 32,
72            shuffle: true,
73            drop_last: false,
74            seed: 42,
75            sampling: SamplingStrategy::RandomShuffle,
76        }
77    }
78}
79
80// ---------------------------------------------------------------------------
81// Batch
82// ---------------------------------------------------------------------------
83
84/// A single mini-batch produced by [`DataLoader`].
85#[derive(Debug, Clone)]
86pub struct Batch {
87    /// Feature matrix, shape `[actual_batch_size, n_features]`.
88    pub features: Array2<f64>,
89    /// Optional label vector, length `actual_batch_size`.
90    pub labels: Option<Array1<f64>>,
91    /// Original dataset indices of the rows in this batch.
92    pub indices: Vec<usize>,
93}
94
95impl Batch {
96    /// Number of rows in this batch.
97    pub fn batch_size(&self) -> usize {
98        self.features.nrows()
99    }
100
101    /// Number of features per row.
102    pub fn n_features(&self) -> usize {
103        self.features.ncols()
104    }
105}
106
107// ---------------------------------------------------------------------------
108// DataLoader
109// ---------------------------------------------------------------------------
110
111/// Mini-batch iterator over an in-memory dataset.
112///
113/// Construct with [`DataLoader::new`], then iterate with the standard `Iterator`
114/// interface.  Call [`DataLoader::reset_epoch`] to start a fresh epoch (with a
115/// new shuffle if configured).
116pub struct DataLoader {
117    features: Array2<f64>,
118    labels: Option<Vec<f64>>,
119    config: DataLoaderConfig,
120    /// Permuted row indices for the current epoch.
121    indices: Vec<usize>,
122    current_pos: usize,
123    epoch: usize,
124    rng: StdRng,
125}
126
127impl DataLoader {
128    /// Create a new `DataLoader`.
129    ///
130    /// `labels` is optional; when `None`, the yielded [`Batch`]es will have
131    /// `labels = None`.
132    pub fn new(features: Array2<f64>, labels: Option<Vec<f64>>, config: DataLoaderConfig) -> Self {
133        let n_rows = features.nrows();
134        let mut rng = StdRng::seed_from_u64(config.seed);
135        let indices = Self::build_indices(n_rows, &config, &mut rng);
136        Self {
137            features,
138            labels,
139            config,
140            indices,
141            current_pos: 0,
142            epoch: 0,
143            rng,
144        }
145    }
146
147    /// Total number of complete batches in the current epoch.
148    ///
149    /// If `drop_last` is `false` and the dataset size is not an exact multiple
150    /// of `batch_size`, this includes the partial final batch.
151    pub fn n_batches(&self) -> usize {
152        let n = self.indices.len();
153        let bs = self.config.batch_size.max(1);
154        if self.config.drop_last {
155            n / bs
156        } else {
157            n.div_ceil(bs)
158        }
159    }
160
161    /// Number of rows in the dataset.
162    pub fn n_rows(&self) -> usize {
163        self.features.nrows()
164    }
165
166    /// Number of features per row.
167    pub fn n_features(&self) -> usize {
168        self.features.ncols()
169    }
170
171    /// The 0-based epoch counter. Incremented by `reset_epoch`.
172    pub fn epoch(&self) -> usize {
173        self.epoch
174    }
175
176    /// Advance to the next epoch: resets the position and, when shuffling is
177    /// enabled, builds a fresh permutation.
178    pub fn reset_epoch(&mut self) {
179        self.epoch += 1;
180        self.current_pos = 0;
181        let n_rows = self.features.nrows();
182        self.indices = Self::build_indices(n_rows, &self.config, &mut self.rng);
183    }
184
185    // ------------------------------------------------------------------
186    // Internal helpers
187    // ------------------------------------------------------------------
188
189    /// Build an ordered index array according to the sampling strategy.
190    fn build_indices(n_rows: usize, config: &DataLoaderConfig, rng: &mut StdRng) -> Vec<usize> {
191        if n_rows == 0 {
192            return vec![];
193        }
194
195        // `shuffle = true` overrides the sampling strategy
196        if config.shuffle {
197            return Self::fisher_yates(n_rows, rng);
198        }
199
200        match &config.sampling {
201            SamplingStrategy::Sequential => (0..n_rows).collect(),
202
203            SamplingStrategy::RandomShuffle => Self::fisher_yates(n_rows, rng),
204
205            SamplingStrategy::Stratified(class_labels) => {
206                Self::stratified_indices(n_rows, class_labels, rng)
207            }
208
209            SamplingStrategy::WeightedRandom(weights) => {
210                Self::weighted_indices(n_rows, weights, rng)
211            }
212        }
213    }
214
215    /// Fisher-Yates full-dataset shuffle.
216    fn fisher_yates(n: usize, rng: &mut StdRng) -> Vec<usize> {
217        let mut idx: Vec<usize> = (0..n).collect();
218        for i in (1..n).rev() {
219            let j = (rng.next_u64() as usize) % (i + 1);
220            idx.swap(i, j);
221        }
222        idx
223    }
224
225    /// Interleave class buckets so every batch sees all classes.
226    fn stratified_indices(n_rows: usize, class_labels: &[usize], rng: &mut StdRng) -> Vec<usize> {
227        // Build per-class index lists
228        let max_class = class_labels.iter().copied().max().unwrap_or(0);
229        let mut buckets: Vec<Vec<usize>> = vec![vec![]; max_class + 1];
230        for (row, &cls) in class_labels.iter().enumerate().take(n_rows) {
231            buckets[cls].push(row);
232        }
233        // Shuffle within each bucket
234        for bucket in &mut buckets {
235            for i in (1..bucket.len()).rev() {
236                let j = (rng.next_u64() as usize) % (i + 1);
237                bucket.swap(i, j);
238            }
239        }
240        // Round-robin interleave
241        let mut result = Vec::with_capacity(n_rows);
242        let mut cursors = vec![0usize; buckets.len()];
243        let mut any_remaining = true;
244        while any_remaining {
245            any_remaining = false;
246            for (cls, bucket) in buckets.iter().enumerate() {
247                if cursors[cls] < bucket.len() {
248                    result.push(bucket[cursors[cls]]);
249                    cursors[cls] += 1;
250                    any_remaining = true;
251                }
252            }
253        }
254        result
255    }
256
257    /// Weighted sampling without replacement via the alias / rejection method.
258    ///
259    /// Uses a simple O(n log n) approach: sort by uniform_variate / weight,
260    /// which is equivalent to Efraimidis-Spirakis weighted reservoir sampling
261    /// with a reservoir of size n (i.e., all rows).
262    fn weighted_indices(n_rows: usize, weights: &[f64], rng: &mut StdRng) -> Vec<usize> {
263        let mut keyed: Vec<(f64, usize)> = (0..n_rows)
264            .map(|i| {
265                let w = if i < weights.len() {
266                    weights[i].max(0.0)
267                } else {
268                    1.0
269                };
270                // key = -ln(u) / w  (minimise → Efraimidis-Spirakis)
271                let u = (rng.next_u64() as f64 + 1.0) / (u64::MAX as f64 + 1.0);
272                let key = if w > 0.0 { -u.ln() / w } else { f64::INFINITY };
273                (key, i)
274            })
275            .collect();
276        keyed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
277        keyed.into_iter().map(|(_, idx)| idx).collect()
278    }
279
280    /// Extract rows at `row_indices` from the feature / label arrays.
281    fn extract_batch(&self, row_indices: &[usize]) -> Batch {
282        let nf = self.features.ncols();
283        let bs = row_indices.len();
284        let mut feat_flat = Vec::with_capacity(bs * nf);
285        let mut label_vals = Vec::with_capacity(bs);
286
287        for &ri in row_indices {
288            for j in 0..nf {
289                feat_flat.push(self.features[[ri, j]]);
290            }
291            if let Some(lbl_vec) = &self.labels {
292                label_vals.push(if ri < lbl_vec.len() { lbl_vec[ri] } else { 0.0 });
293            }
294        }
295
296        let features = Array2::from_shape_vec((bs, nf), feat_flat)
297            .unwrap_or_else(|_| Array2::zeros((bs, nf.max(1))));
298
299        let labels = if self.labels.is_some() {
300            Some(Array1::from_vec(label_vals))
301        } else {
302            None
303        };
304
305        Batch {
306            features,
307            labels,
308            indices: row_indices.to_vec(),
309        }
310    }
311}
312
313impl Iterator for DataLoader {
314    type Item = Batch;
315
316    fn next(&mut self) -> Option<Self::Item> {
317        let remaining = self.indices.len().saturating_sub(self.current_pos);
318        if remaining == 0 {
319            return None;
320        }
321
322        let bs = self.config.batch_size.max(1);
323        let batch_rows = remaining.min(bs);
324
325        // Drop incomplete last batch if requested
326        if self.config.drop_last && batch_rows < bs {
327            return None;
328        }
329
330        let start = self.current_pos;
331        let end = start + batch_rows;
332        let row_indices: Vec<usize> = self.indices[start..end].to_vec();
333        self.current_pos = end;
334
335        Some(self.extract_batch(&row_indices))
336    }
337}
338
339// ---------------------------------------------------------------------------
340// Tests
341// ---------------------------------------------------------------------------
342
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Loader over an `n × f` matrix holding `0..n*f` in row-major order, with
    /// labels `i % 3`. `shuffle` selects `RandomShuffle` or `Sequential`.
    fn make_loader(n: usize, f: usize, bs: usize, shuffle: bool) -> DataLoader {
        let data: Vec<f64> = (0..n * f).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((n, f), data).unwrap();
        let labels: Vec<f64> = (0..n).map(|i| (i % 3) as f64).collect();
        let config = DataLoaderConfig {
            batch_size: bs,
            shuffle,
            drop_last: false,
            seed: 42,
            sampling: if shuffle {
                SamplingStrategy::RandomShuffle
            } else {
                SamplingStrategy::Sequential
            },
        };
        DataLoader::new(features, Some(labels), config)
    }

    #[test]
    fn test_dataloader_basic() {
        // 100 rows, batch 32 → 4 batches (32, 32, 32, 4)
        let loader = make_loader(100, 4, 32, false);
        assert_eq!(loader.n_batches(), 4);
        let batches: Vec<_> = loader.collect();
        let sizes: Vec<usize> = batches.iter().map(|b| b.batch_size()).collect();
        assert_eq!(sizes, vec![32, 32, 32, 4]);

        // Sequential order visits every row exactly once, in order.
        let visited: Vec<usize> = batches
            .iter()
            .flat_map(|b| b.indices.iter().copied())
            .collect();
        assert_eq!(visited, (0..100).collect::<Vec<_>>());

        // Feature rows and labels are copied from the source row.
        let first = &batches[0];
        assert_eq!(first.n_features(), 4);
        assert_eq!(first.features[[1, 0]], 4.0);
        let labels = first.labels.as_ref().expect("labels were supplied");
        assert_eq!(labels[1], 1.0);
    }

    #[test]
    fn test_dataloader_last_batch() {
        // 105 rows, batch 32, drop_last=false → 4 batches (32, 32, 32, 9)
        let data: Vec<f64> = (0..105 * 2).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((105, 2), data).unwrap();
        let config = DataLoaderConfig {
            batch_size: 32,
            shuffle: false,
            drop_last: false,
            seed: 0,
            sampling: SamplingStrategy::Sequential,
        };
        let loader = DataLoader::new(features, None, config);
        assert_eq!(loader.n_batches(), 4);
        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 4);
        assert_eq!(batches.last().unwrap().batch_size(), 9);
        assert!(batches.iter().all(|b| b.labels.is_none()));
    }

    #[test]
    fn test_dataloader_drop_last() {
        // 105 rows, batch 32, drop_last=true → 3 batches of 32 each
        let data: Vec<f64> = (0..105 * 2).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((105, 2), data).unwrap();
        let config = DataLoaderConfig {
            batch_size: 32,
            shuffle: false,
            drop_last: true,
            seed: 0,
            sampling: SamplingStrategy::Sequential,
        };
        let loader = DataLoader::new(features, None, config);
        assert_eq!(loader.n_batches(), 3);
        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 3);
        for b in &batches {
            assert_eq!(b.batch_size(), 32);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        // Two consecutive epochs should produce different orderings (with high probability)
        let data: Vec<f64> = (0..50 * 2).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((50, 2), data).unwrap();
        let config = DataLoaderConfig {
            batch_size: 50,
            shuffle: true,
            drop_last: false,
            seed: 99,
            sampling: SamplingStrategy::RandomShuffle,
        };
        let mut loader = DataLoader::new(features, None, config);

        let first_batch = loader.next().expect("first epoch batch");
        loader.reset_epoch();
        let second_batch = loader.next().expect("second epoch batch");

        // The index orderings should differ (p(same) = 1/50! ≈ 0)
        assert_ne!(first_batch.indices, second_batch.indices);

        // ...but each epoch is still a permutation of all rows.
        let mut sorted = second_batch.indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..50).collect::<Vec<_>>());
    }

    #[test]
    fn test_dataloader_stratified() {
        // 30 rows, 3 classes × 10 each; batch_size=6 → each batch has 2 per class
        let n = 30usize;
        let data: Vec<f64> = (0..n * 2).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((n, 2), data).unwrap();
        let class_labels: Vec<usize> = (0..n).map(|i| i % 3).collect();
        let label_f64: Vec<f64> = class_labels.iter().map(|&c| c as f64).collect();
        let config = DataLoaderConfig {
            batch_size: 6,
            shuffle: false,
            drop_last: false,
            seed: 1,
            sampling: SamplingStrategy::Stratified(class_labels),
        };
        let loader = DataLoader::new(features, Some(label_f64), config);
        let batches: Vec<_> = loader.collect();
        // 30 / 6 = 5 batches
        assert_eq!(batches.len(), 5);

        // Balanced classes + batch size a multiple of the class count → every
        // batch holds exactly two rows of each class.
        for batch in &batches {
            let lbls = batch.labels.as_ref().expect("labels were supplied");
            let mut counts = [0usize; 3];
            for &x in lbls.iter() {
                counts[x as usize] += 1;
            }
            assert_eq!(counts, [2, 2, 2], "unbalanced batch: {counts:?}");
        }

        // Every row appears exactly once across the epoch.
        let mut seen: Vec<usize> = batches
            .iter()
            .flat_map(|b| b.indices.iter().copied())
            .collect();
        seen.sort_unstable();
        assert_eq!(seen, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn test_dataloader_epoch_count() {
        let mut loader = make_loader(20, 2, 5, true);
        assert_eq!(loader.epoch(), 0);
        // drain
        for _ in loader.by_ref() {}
        loader.reset_epoch();
        assert_eq!(loader.epoch(), 1);
        for _ in loader.by_ref() {}
        loader.reset_epoch();
        assert_eq!(loader.epoch(), 2);
    }

    #[test]
    fn test_dataloader_empty() {
        let features = Array2::<f64>::zeros((0, 3));
        let config = DataLoaderConfig::default();
        let loader = DataLoader::new(features, None, config);
        assert_eq!(loader.n_batches(), 0);
        let batches: Vec<_> = loader.collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_exact_multiple() {
        // 64 rows, batch 32, drop_last = false → exactly 2 full batches
        let loader = make_loader(64, 4, 32, false);
        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 2);
        for b in &batches {
            assert_eq!(b.batch_size(), 32);
        }
    }

    #[test]
    fn test_dataloader_weighted_random() {
        let n = 40usize;
        let data: Vec<f64> = (0..n * 2).map(|x| x as f64).collect();
        let features = Array2::from_shape_vec((n, 2), data).unwrap();
        // Give first 10 rows very high weight
        let weights: Vec<f64> = (0..n).map(|i| if i < 10 { 100.0 } else { 1.0 }).collect();
        let config = DataLoaderConfig {
            batch_size: n, // one big batch
            shuffle: false,
            drop_last: false,
            seed: 7,
            sampling: SamplingStrategy::WeightedRandom(weights),
        };
        let mut loader = DataLoader::new(features, None, config);
        let batch = loader.next().expect("batch");
        // High-weight rows (0-9) should dominate early positions
        let top10: Vec<usize> = batch.indices[..10].to_vec();
        let heavy_in_top10 = top10.iter().filter(|&&i| i < 10).count();
        // The heavy rows carry ~97% of the total weight, so nearly all of the
        // first 10 positions are expected to be heavy; assert a loose lower
        // bound of 5 to keep the test robust.
        assert!(
            heavy_in_top10 >= 5,
            "expected heavy rows near top, got {heavy_in_top10}"
        );
    }
}
524            heavy_in_top10 >= 5,
525            "expected heavy rows near top, got {heavy_in_top10}"
526        );
527    }
528}