Skip to main content

scirs2_datasets/
sampling.rs

//! Mini-batch sampler for iterating over datasets in batches.
//!
//! Provides configurable sampling strategies including sequential, random,
//! stratified (proportional label representation per batch), and weighted
//! random sampling.
6
7use crate::error::{DatasetsError, Result};
8
9// ─────────────────────────────────────────────────────────────────────────────
10// LCG helper (deterministic PRNG without external rand dependency)
11// ─────────────────────────────────────────────────────────────────────────────
12
/// Minimal 64-bit LCG (Knuth MMIX parameters).
///
/// Deterministic, seedable PRNG kept local so the crate needs no external
/// `rand` dependency. Not cryptographically secure.
struct Lcg64 {
    // Current generator state; advanced on every draw.
    state: u64,
}

impl Lcg64 {
    /// Seed the generator. The `+1` offset maps seed 0 to state 1 instead of 0.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the LCG one step and return the full 64-bit state.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform draw in `[0, n)`; returns 0 when `n == 0`.
    ///
    /// Uses multiply-shift (Lemire) range reduction, which keys off the HIGH
    /// bits of the LCG output. The previous `% n` reduction used the low bits,
    /// whose period in a power-of-two-modulus LCG is only 2^(k+1) for bit k —
    /// e.g. `next_usize(2)` alternated 0,1,0,1 deterministically, which is
    /// visibly non-random in shuffles of small slices.
    fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        ((self.next_u64() as u128 * n as u128) >> 64) as usize
    }

    /// Uniform `f64` in `[0, 1)` built from the top 53 bits (one full mantissa).
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}
44
45// ─────────────────────────────────────────────────────────────────────────────
46// Public types
47// ─────────────────────────────────────────────────────────────────────────────
48
/// Strategy used by the mini-batch sampler to select samples.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Default)]
pub enum SamplerStrategy {
    /// Iterate through the dataset in order. Shuffled first when
    /// [`SamplerConfig::shuffle`] is set, which makes it behave like `Random`.
    #[default]
    Sequential,
    /// Randomly shuffle all indices, then iterate sequentially over the shuffled order.
    /// Shuffles regardless of the [`SamplerConfig::shuffle`] flag.
    Random,
    /// Each batch maintains the same label proportions as the full dataset
    /// (per-class index lists interleaved round-robin).
    Stratified,
    /// Sample indices according to per-example weights (with replacement).
    WeightedRandom {
        /// Per-sample weight. If shorter than the dataset it is padded with
        /// uniform weight 1.0; extra entries beyond the dataset are ignored.
        /// A non-positive total falls back to the identity order.
        weights: Vec<f64>,
    },
}
66
/// Configuration for the [`MiniBatchSampler`].
#[derive(Debug, Clone)]
pub struct SamplerConfig {
    /// Number of samples per batch. Must be >= 1; `iter_batches` rejects 0.
    pub batch_size: usize,
    /// Whether to shuffle the dataset before creating batches. Consulted only
    /// by the `Sequential` strategy (full shuffle) and `Stratified` strategy
    /// (per-class shuffle); `Random` always shuffles and `WeightedRandom`
    /// ignores this flag.
    pub shuffle: bool,
    /// If `true`, the last batch is dropped when it has fewer than `batch_size` samples.
    pub drop_last: bool,
    /// Random seed for reproducibility. Identical seed + config + data yields
    /// identical batches.
    pub seed: u64,
    /// Sampling strategy.
    pub strategy: SamplerStrategy,
}
81
82impl Default for SamplerConfig {
83    fn default() -> Self {
84        Self {
85            batch_size: 32,
86            shuffle: true,
87            drop_last: false,
88            seed: 42,
89            strategy: SamplerStrategy::default(),
90        }
91    }
92}
93
/// A single mini-batch of data and labels.
///
/// `data`, `labels`, and `indices` are parallel vectors of equal length:
/// entry `k` of each refers to the same original sample.
#[derive(Debug, Clone)]
pub struct MiniBatch {
    /// Feature vectors for this batch (cloned from the source dataset).
    pub data: Vec<Vec<f64>>,
    /// Labels for this batch.
    pub labels: Vec<usize>,
    /// Indices into the original dataset for this batch.
    pub indices: Vec<usize>,
}
104
/// Mini-batch sampler that yields batches from a dataset.
///
/// A cheap, cloneable wrapper around a [`SamplerConfig`]; it holds no dataset
/// state itself.
///
/// Construct via [`MiniBatchSampler::new`] and iterate with [`iter_batches`](MiniBatchSampler::iter_batches).
#[derive(Debug, Clone)]
pub struct MiniBatchSampler {
    config: SamplerConfig,
}
112
113impl MiniBatchSampler {
114    /// Create a new sampler with the given configuration.
115    pub fn new(config: SamplerConfig) -> Self {
116        Self { config }
117    }
118
119    /// Return a reference to the current configuration.
120    pub fn config(&self) -> &SamplerConfig {
121        &self.config
122    }
123
124    /// Generate all mini-batches for the given data and labels.
125    ///
126    /// Returns a `Vec<MiniBatch>` where each batch has at most `config.batch_size`
127    /// samples. When `drop_last` is `true`, the final partial batch is omitted.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if data and labels have different lengths, or if
132    /// `batch_size` is zero.
133    pub fn iter_batches(&self, data: &[Vec<f64>], labels: &[usize]) -> Result<Vec<MiniBatch>> {
134        iter_batches(data, labels, &self.config)
135    }
136}
137
138/// Generate mini-batches from a dataset according to the given configuration.
139///
140/// This is the free-function equivalent of [`MiniBatchSampler::iter_batches`].
141///
142/// # Errors
143///
144/// Returns an error when `data.len() != labels.len()` or `config.batch_size == 0`.
145pub fn iter_batches(
146    data: &[Vec<f64>],
147    labels: &[usize],
148    config: &SamplerConfig,
149) -> Result<Vec<MiniBatch>> {
150    let n = data.len();
151    if n != labels.len() {
152        return Err(DatasetsError::InvalidFormat(format!(
153            "data length ({}) != labels length ({})",
154            n,
155            labels.len()
156        )));
157    }
158    if config.batch_size == 0 {
159        return Err(DatasetsError::InvalidFormat(
160            "batch_size must be >= 1".into(),
161        ));
162    }
163    if n == 0 {
164        return Ok(Vec::new());
165    }
166
167    let indices = build_index_order(n, labels, config);
168    let mut batches = Vec::new();
169    let mut offset = 0;
170
171    while offset < indices.len() {
172        let end = (offset + config.batch_size).min(indices.len());
173        let batch_indices: Vec<usize> = indices[offset..end].to_vec();
174
175        if config.drop_last && batch_indices.len() < config.batch_size {
176            break;
177        }
178
179        let batch_data: Vec<Vec<f64>> = batch_indices.iter().map(|&i| data[i].clone()).collect();
180        let batch_labels: Vec<usize> = batch_indices.iter().map(|&i| labels[i]).collect();
181
182        batches.push(MiniBatch {
183            data: batch_data,
184            labels: batch_labels,
185            indices: batch_indices,
186        });
187
188        offset = end;
189    }
190
191    Ok(batches)
192}
193
194// ─────────────────────────────────────────────────────────────────────────────
195// Internal helpers
196// ─────────────────────────────────────────────────────────────────────────────
197
198/// Build the ordered index list according to the strategy.
199fn build_index_order(n: usize, labels: &[usize], config: &SamplerConfig) -> Vec<usize> {
200    match &config.strategy {
201        SamplerStrategy::Sequential => {
202            let mut indices: Vec<usize> = (0..n).collect();
203            if config.shuffle {
204                fisher_yates_shuffle(&mut indices, config.seed);
205            }
206            indices
207        }
208
209        SamplerStrategy::Random => {
210            let mut indices: Vec<usize> = (0..n).collect();
211            fisher_yates_shuffle(&mut indices, config.seed);
212            indices
213        }
214
215        SamplerStrategy::Stratified => build_stratified_order(n, labels, config),
216
217        SamplerStrategy::WeightedRandom { weights } => build_weighted_order(n, weights, config),
218    }
219}
220
221/// Fisher-Yates shuffle using seeded LCG.
222fn fisher_yates_shuffle(indices: &mut [usize], seed: u64) {
223    let n = indices.len();
224    if n <= 1 {
225        return;
226    }
227    let mut rng = Lcg64::new(seed);
228    for i in (1..n).rev() {
229        let j = rng.next_usize(i + 1);
230        indices.swap(i, j);
231    }
232}
233
234/// Build an index order that groups samples so each batch has proportional
235/// label representation.
236fn build_stratified_order(n: usize, labels: &[usize], config: &SamplerConfig) -> Vec<usize> {
237    if n == 0 {
238        return Vec::new();
239    }
240
241    // Group indices by class.
242    let max_class = labels.iter().copied().max().unwrap_or(0);
243    let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); max_class + 1];
244    for (i, &label) in labels.iter().enumerate() {
245        class_indices[label].push(i);
246    }
247
248    // Optionally shuffle within each class.
249    if config.shuffle {
250        for (cls, indices) in class_indices.iter_mut().enumerate() {
251            let class_seed = config.seed.wrapping_add(cls as u64 * 0x9e37_79b9_7f4a_7c15);
252            fisher_yates_shuffle(indices, class_seed);
253        }
254    }
255
256    // Interleave: round-robin across classes to ensure proportional representation
257    // within each batch-sized chunk.
258    let mut result = Vec::with_capacity(n);
259    let mut cursors: Vec<usize> = vec![0; class_indices.len()];
260    let mut remaining = n;
261
262    while remaining > 0 {
263        let mut added = false;
264        for (cls, indices) in class_indices.iter().enumerate() {
265            if cursors[cls] < indices.len() {
266                result.push(indices[cursors[cls]]);
267                cursors[cls] += 1;
268                remaining -= 1;
269                added = true;
270                if remaining == 0 {
271                    break;
272                }
273            }
274        }
275        if !added {
276            break;
277        }
278    }
279
280    result
281}
282
283/// Build an index order using weighted sampling with replacement.
284fn build_weighted_order(n: usize, weights: &[f64], config: &SamplerConfig) -> Vec<usize> {
285    if n == 0 || weights.is_empty() {
286        return Vec::new();
287    }
288
289    let mut rng = Lcg64::new(config.seed);
290    let actual_weights: Vec<f64> = if weights.len() >= n {
291        weights[..n].to_vec()
292    } else {
293        // Pad with uniform weight 1.0.
294        let mut w = weights.to_vec();
295        w.resize(n, 1.0);
296        w
297    };
298
299    // Build cumulative distribution.
300    let total: f64 = actual_weights.iter().sum();
301    if total <= 0.0 {
302        // Fallback to uniform.
303        return (0..n).collect();
304    }
305    let cumulative: Vec<f64> = actual_weights
306        .iter()
307        .scan(0.0, |acc, &w| {
308            *acc += w / total;
309            Some(*acc)
310        })
311        .collect();
312
313    // Sample n indices with replacement.
314    (0..n)
315        .map(|_| {
316            let u = rng.next_f64();
317            // Binary search in cumulative distribution.
318            match cumulative.binary_search_by(|probe| {
319                probe.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal)
320            }) {
321                Ok(idx) => idx.min(n - 1),
322                Err(idx) => idx.min(n - 1),
323            }
324        })
325        .collect()
326}
327
328// ─────────────────────────────────────────────────────────────────────────────
329// Tests
330// ─────────────────────────────────────────────────────────────────────────────
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic synthetic dataset: sample `i` has features
    /// `[i*n_features, i*n_features + 1, ...]`, so every row is distinct.
    fn make_simple_data(n: usize, n_features: usize) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| {
                (0..n_features)
                    .map(|j| (i * n_features + j) as f64)
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_sequential_batches_correct_size() {
        let data = make_simple_data(100, 5);
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let config = SamplerConfig {
            batch_size: 32,
            shuffle: false,
            drop_last: false,
            seed: 42,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        // 100 / 32 = 3 full + 1 partial = 4 batches
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].data.len(), 32);
        assert_eq!(batches[3].data.len(), 4); // remainder
    }

    #[test]
    fn test_drop_last() {
        let data = make_simple_data(50, 3);
        let labels: Vec<usize> = vec![0; 50];
        let config = SamplerConfig {
            batch_size: 16,
            shuffle: false,
            drop_last: true,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        // 50 / 16 = 3 full batches (48), last partial (2) dropped
        assert_eq!(batches.len(), 3);
        for b in &batches {
            assert_eq!(b.data.len(), 16);
        }
    }

    #[test]
    fn test_random_shuffles_indices() {
        let data = make_simple_data(20, 2);
        let labels: Vec<usize> = vec![0; 20];
        let config = SamplerConfig {
            batch_size: 20,
            shuffle: true,
            drop_last: false,
            seed: 99,
            strategy: SamplerStrategy::Random,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        assert_eq!(batches.len(), 1);
        // Indices should be a permutation of 0..20
        let mut sorted = batches[0].indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..20).collect::<Vec<_>>());
        // Very unlikely to be in natural order
        // (seed-dependent; would only fail if the shuffle were a no-op).
        assert_ne!(batches[0].indices, (0..20).collect::<Vec<_>>());
    }

    #[test]
    fn test_stratified_label_proportions() {
        // 60 class-0, 40 class-1
        let n = 100;
        let mut labels: Vec<usize> = vec![0; 60];
        labels.extend(vec![1; 40]);
        let data = make_simple_data(n, 2);

        let config = SamplerConfig {
            batch_size: 20,
            shuffle: false,
            drop_last: false,
            seed: 42,
            strategy: SamplerStrategy::Stratified,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        assert_eq!(batches.len(), 5); // 100 / 20

        // Check overall: total class proportions should be preserved
        let total_c0: usize = batches
            .iter()
            .map(|b| b.labels.iter().filter(|&&l| l == 0).count())
            .sum();
        let total_c1: usize = batches
            .iter()
            .map(|b| b.labels.iter().filter(|&&l| l == 1).count())
            .sum();
        assert_eq!(total_c0, 60);
        assert_eq!(total_c1, 40);

        // Each batch should have at least some representation of both classes
        // (round-robin interleaving distributes them across batches)
        let batches_with_both: usize = batches
            .iter()
            .filter(|b| {
                let c0 = b.labels.iter().filter(|&&l| l == 0).count();
                let c1 = b.labels.iter().filter(|&&l| l == 1).count();
                c0 > 0 && c1 > 0
            })
            .count();
        // At least 4 out of 5 batches should have both classes
        // (the minority class runs out before the final batch).
        assert!(
            batches_with_both >= 4,
            "Expected most batches to have both classes, got {batches_with_both}"
        );
    }

    #[test]
    fn test_weighted_sampling() {
        let n = 50;
        let data = make_simple_data(n, 2);
        let labels: Vec<usize> = vec![0; n];
        // Give all weight to index 0
        let mut weights = vec![0.0; n];
        weights[0] = 1.0;

        let config = SamplerConfig {
            batch_size: 10,
            shuffle: false,
            drop_last: false,
            seed: 42,
            strategy: SamplerStrategy::WeightedRandom { weights },
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        // All samples should be index 0
        for batch in &batches {
            for &idx in &batch.indices {
                assert_eq!(idx, 0, "All indices should be 0 with weight=[1,0,0,...]");
            }
        }
    }

    #[test]
    fn test_reproducibility_same_seed() {
        let data = make_simple_data(40, 3);
        let labels: Vec<usize> = (0..40).map(|i| i % 2).collect();
        let config = SamplerConfig {
            batch_size: 10,
            shuffle: true,
            drop_last: false,
            seed: 777,
            strategy: SamplerStrategy::Random,
        };
        // Two runs with the same seed/config must produce identical batches.
        let b1 = iter_batches(&data, &labels, &config).expect("ok");
        let b2 = iter_batches(&data, &labels, &config).expect("ok");
        assert_eq!(b1.len(), b2.len());
        for (a, b) in b1.iter().zip(b2.iter()) {
            assert_eq!(a.indices, b.indices);
        }
    }

    #[test]
    fn test_mismatched_lengths_error() {
        let data = make_simple_data(10, 2);
        let labels: Vec<usize> = vec![0; 5];
        let config = SamplerConfig::default();
        assert!(iter_batches(&data, &labels, &config).is_err());
    }

    #[test]
    fn test_zero_batch_size_error() {
        let data = make_simple_data(10, 2);
        let labels: Vec<usize> = vec![0; 10];
        let config = SamplerConfig {
            batch_size: 0,
            ..Default::default()
        };
        assert!(iter_batches(&data, &labels, &config).is_err());
    }

    #[test]
    fn test_empty_dataset() {
        // Empty input is valid and yields zero batches, not an error.
        let data: Vec<Vec<f64>> = Vec::new();
        let labels: Vec<usize> = Vec::new();
        let config = SamplerConfig::default();
        let batches = iter_batches(&data, &labels, &config).expect("ok");
        assert!(batches.is_empty());
    }

    #[test]
    fn test_sampler_struct() {
        // The struct wrapper must behave identically to the free function.
        let data = make_simple_data(20, 2);
        let labels: Vec<usize> = vec![0; 20];
        let sampler = MiniBatchSampler::new(SamplerConfig {
            batch_size: 5,
            shuffle: false,
            drop_last: false,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        });
        let batches = sampler.iter_batches(&data, &labels).expect("ok");
        assert_eq!(batches.len(), 4);
        assert_eq!(sampler.config().batch_size, 5);
    }

    #[test]
    fn test_all_indices_covered_sequential() {
        // Without drop_last, every dataset index appears exactly once.
        let n = 37;
        let data = make_simple_data(n, 2);
        let labels: Vec<usize> = vec![0; n];
        let config = SamplerConfig {
            batch_size: 10,
            shuffle: false,
            drop_last: false,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("ok");
        let mut all_indices: Vec<usize> = batches
            .iter()
            .flat_map(|b| b.indices.iter().copied())
            .collect();
        all_indices.sort_unstable();
        assert_eq!(all_indices, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn test_batch_data_matches_original() {
        // data/labels/indices within a batch must stay aligned with the source.
        let data = make_simple_data(15, 3);
        let labels: Vec<usize> = (0..15).map(|i| i % 2).collect();
        let config = SamplerConfig {
            batch_size: 5,
            shuffle: false,
            drop_last: false,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("ok");
        for batch in &batches {
            for (pos, &idx) in batch.indices.iter().enumerate() {
                assert_eq!(batch.data[pos], data[idx]);
                assert_eq!(batch.labels[pos], labels[idx]);
            }
        }
    }
}