// tenflowers_neural/utils/batch_processing.rs
//! Batch Processing Utilities
//!
//! This module provides efficient batch processing utilities for neural network training,
//! including batching strategies, collation functions, and data sampling methods.

use std::collections::HashMap;

use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
use scirs2_core::RngExt;
use tenflowers_core::{Device, Result, Tensor, TensorError};

#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};

14/// Batch sampling strategy
15#[derive(Debug, Clone, Copy, PartialEq, Eq)]
16#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
17pub enum SamplingStrategy {
18    /// Sequential sampling (in order)
19    Sequential,
20    /// Random sampling (with replacement)
21    Random,
22    /// Shuffle once at the beginning
23    Shuffle,
24    /// Stratified sampling (balanced classes)
25    Stratified,
26    /// Weighted sampling based on importance
27    Weighted,
28}
29
30/// Padding strategy for variable-length sequences
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
33pub enum PaddingStrategy {
34    /// Pad to the longest sequence in the batch
35    LongestInBatch,
36    /// Pad to a fixed maximum length
37    FixedLength,
38    /// Pad to the nearest multiple of a value
39    NearestMultiple,
40    /// No padding (all sequences must be same length)
41    NoPadding,
42}
43
44/// Collation strategy for combining samples into batches
45#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
47pub enum CollationStrategy {
48    /// Stack tensors along batch dimension
49    Stack,
50    /// Concatenate tensors
51    Concatenate,
52    /// Pad and stack (for variable-length sequences)
53    PadAndStack,
54    /// Custom collation (user-defined)
55    Custom,
56}
57
58/// Batch configuration
59#[derive(Debug, Clone)]
60#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
61pub struct BatchConfig {
62    /// Batch size
63    pub batch_size: usize,
64    /// Whether to drop the last incomplete batch
65    pub drop_last: bool,
66    /// Sampling strategy
67    pub sampling_strategy: SamplingStrategy,
68    /// Padding strategy (for sequences)
69    pub padding_strategy: PaddingStrategy,
70    /// Collation strategy
71    pub collation_strategy: CollationStrategy,
72    /// Maximum sequence length (for padding)
73    pub max_sequence_length: Option<usize>,
74    /// Padding value
75    pub padding_value: f32,
76    /// Random seed for reproducibility
77    pub seed: Option<u64>,
78}
79
80impl Default for BatchConfig {
81    fn default() -> Self {
82        Self {
83            batch_size: 32,
84            drop_last: false,
85            sampling_strategy: SamplingStrategy::Sequential,
86            padding_strategy: PaddingStrategy::LongestInBatch,
87            collation_strategy: CollationStrategy::Stack,
88            max_sequence_length: None,
89            padding_value: 0.0,
90            seed: None,
91        }
92    }
93}
94
95impl BatchConfig {
96    /// Create new batch configuration
97    pub fn new(batch_size: usize) -> Self {
98        Self {
99            batch_size,
100            ..Default::default()
101        }
102    }
103
104    /// Set whether to drop last incomplete batch
105    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
106        self.drop_last = drop_last;
107        self
108    }
109
110    /// Set sampling strategy
111    pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
112        self.sampling_strategy = strategy;
113        self
114    }
115
116    /// Set padding strategy
117    pub fn with_padding_strategy(mut self, strategy: PaddingStrategy) -> Self {
118        self.padding_strategy = strategy;
119        self
120    }
121
122    /// Set collation strategy
123    pub fn with_collation_strategy(mut self, strategy: CollationStrategy) -> Self {
124        self.collation_strategy = strategy;
125        self
126    }
127
128    /// Set maximum sequence length
129    pub fn with_max_sequence_length(mut self, max_len: usize) -> Self {
130        self.max_sequence_length = Some(max_len);
131        self
132    }
133
134    /// Set padding value
135    pub fn with_padding_value(mut self, value: f32) -> Self {
136        self.padding_value = value;
137        self
138    }
139
140    /// Set random seed
141    pub fn with_seed(mut self, seed: u64) -> Self {
142        self.seed = Some(seed);
143        self
144    }
145}
146
147/// Perform an in-place Fisher-Yates shuffle on `indices` using the provided RNG.
148fn fisher_yates_shuffle(indices: &mut [usize], rng: &mut StdRng) {
149    let n = indices.len();
150    if n <= 1 {
151        return;
152    }
153    for i in (1..n).rev() {
154        let j = rng.random_range(0..=i);
155        indices.swap(i, j);
156    }
157}
158
159/// Batch sampler for generating batch indices
160pub struct BatchSampler {
161    dataset_size: usize,
162    config: BatchConfig,
163    current_index: usize,
164    indices: Vec<usize>,
165    rng: StdRng,
166}
167
168impl BatchSampler {
169    /// Create new batch sampler
170    pub fn new(dataset_size: usize, config: BatchConfig) -> Self {
171        let seed = config.seed.unwrap_or(0);
172        let mut rng = StdRng::seed_from_u64(seed);
173
174        let mut indices: Vec<usize> = (0..dataset_size).collect();
175        if config.sampling_strategy == SamplingStrategy::Shuffle {
176            fisher_yates_shuffle(&mut indices, &mut rng);
177        }
178
179        Self {
180            dataset_size,
181            config,
182            current_index: 0,
183            indices,
184            rng,
185        }
186    }
187
188    /// Get next batch of indices
189    pub fn next_batch(&mut self) -> Option<Vec<usize>> {
190        if self.current_index >= self.dataset_size {
191            return None;
192        }
193
194        let end_index = (self.current_index + self.config.batch_size).min(self.dataset_size);
195        let batch_indices: Vec<usize> = self.indices[self.current_index..end_index].to_vec();
196
197        self.current_index = end_index;
198
199        // Check if we should drop the last incomplete batch
200        if self.config.drop_last && batch_indices.len() < self.config.batch_size {
201            None
202        } else {
203            Some(batch_indices)
204        }
205    }
206
207    /// Reset the sampler to the beginning
208    pub fn reset(&mut self) {
209        self.current_index = 0;
210
211        // Reshuffle if needed — uses the continued RNG state so each epoch
212        // gets a different permutation while remaining deterministic from
213        // the original seed.
214        if self.config.sampling_strategy == SamplingStrategy::Shuffle {
215            self.indices = (0..self.dataset_size).collect();
216            fisher_yates_shuffle(&mut self.indices, &mut self.rng);
217        }
218    }
219
220    /// Get total number of batches
221    pub fn num_batches(&self) -> usize {
222        let total = (self.dataset_size + self.config.batch_size - 1) / self.config.batch_size;
223        if self.config.drop_last && self.dataset_size % self.config.batch_size != 0 {
224            total - 1
225        } else {
226            total
227        }
228    }
229
230    /// Get current batch index
231    pub fn current_batch_index(&self) -> usize {
232        self.current_index / self.config.batch_size
233    }
234}
235
236/// Collation function for combining samples into batches
237pub struct Collator {
238    config: BatchConfig,
239}
240
241impl Collator {
242    /// Create new collator
243    pub fn new(config: BatchConfig) -> Self {
244        Self { config }
245    }
246
247    /// Collate samples into a batch
248    pub fn collate<T>(&self, samples: &[Tensor<T>]) -> Result<Tensor<T>>
249    where
250        T: Clone + Default,
251    {
252        if samples.is_empty() {
253            return Err(TensorError::invalid_shape_simple(
254                "Cannot collate empty batch".to_string(),
255            ));
256        }
257
258        match self.config.collation_strategy {
259            CollationStrategy::Stack => self.stack_samples(samples),
260            CollationStrategy::PadAndStack => self.pad_and_stack_samples(samples),
261            _ => {
262                // Placeholder for other strategies
263                self.stack_samples(samples)
264            }
265        }
266    }
267
268    /// Stack samples along batch dimension
269    fn stack_samples<T>(&self, samples: &[Tensor<T>]) -> Result<Tensor<T>>
270    where
271        T: Clone + Default,
272    {
273        // Placeholder implementation
274        // Real implementation would properly stack tensors
275        Ok(samples[0].clone())
276    }
277
278    /// Pad and stack samples (for variable-length sequences)
279    fn pad_and_stack_samples<T>(&self, samples: &[Tensor<T>]) -> Result<Tensor<T>>
280    where
281        T: Clone + Default,
282    {
283        // Placeholder implementation
284        // Real implementation would:
285        // 1. Find max length based on padding strategy
286        // 2. Pad each sample to max length
287        // 3. Stack padded samples
288        Ok(samples[0].clone())
289    }
290
291    /// Get padding length based on strategy
292    fn get_padding_length(&self, sample_lengths: &[usize]) -> usize {
293        match self.config.padding_strategy {
294            PaddingStrategy::LongestInBatch => *sample_lengths.iter().max().unwrap_or(&0),
295            PaddingStrategy::FixedLength => self.config.max_sequence_length.unwrap_or(512),
296            PaddingStrategy::NearestMultiple => {
297                let max_len = *sample_lengths.iter().max().unwrap_or(&0);
298                let multiple = self.config.max_sequence_length.unwrap_or(8);
299                ((max_len + multiple - 1) / multiple) * multiple
300            }
301            PaddingStrategy::NoPadding => sample_lengths[0],
302        }
303    }
304}
305
306/// Batch statistics for monitoring
307#[derive(Debug, Clone)]
308#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
309pub struct BatchStatistics {
310    /// Total number of batches processed
311    pub total_batches: usize,
312    /// Total number of samples processed
313    pub total_samples: usize,
314    /// Average batch size
315    pub avg_batch_size: f64,
316    /// Min batch size seen
317    pub min_batch_size: usize,
318    /// Max batch size seen
319    pub max_batch_size: usize,
320    /// Average padding ratio (for sequences)
321    pub avg_padding_ratio: f64,
322}
323
324impl BatchStatistics {
325    /// Create new batch statistics
326    pub fn new() -> Self {
327        Self {
328            total_batches: 0,
329            total_samples: 0,
330            avg_batch_size: 0.0,
331            min_batch_size: usize::MAX,
332            max_batch_size: 0,
333            avg_padding_ratio: 0.0,
334        }
335    }
336
337    /// Record a batch
338    pub fn record_batch(&mut self, batch_size: usize, padding_ratio: f64) {
339        self.total_batches += 1;
340        self.total_samples += batch_size;
341        self.min_batch_size = self.min_batch_size.min(batch_size);
342        self.max_batch_size = self.max_batch_size.max(batch_size);
343
344        // Update running averages
345        let n = self.total_batches as f64;
346        self.avg_batch_size = (self.avg_batch_size * (n - 1.0) + batch_size as f64) / n;
347        self.avg_padding_ratio = (self.avg_padding_ratio * (n - 1.0) + padding_ratio) / n;
348    }
349
350    /// Reset statistics
351    pub fn reset(&mut self) {
352        *self = Self::new();
353    }
354
355    /// Get efficiency (ratio of real data to padded data)
356    pub fn efficiency(&self) -> f64 {
357        1.0 - self.avg_padding_ratio
358    }
359}
360
361impl Default for BatchStatistics {
362    fn default() -> Self {
363        Self::new()
364    }
365}
366
367/// Utilities for batch processing
368pub mod batch_utils {
369    use super::*;
370
371    /// Calculate optimal batch size for given memory constraints
372    pub fn calculate_optimal_batch_size(
373        sample_memory_bytes: usize,
374        available_memory_bytes: usize,
375        safety_factor: f64,
376    ) -> usize {
377        let usable_memory = (available_memory_bytes as f64 * safety_factor) as usize;
378        (usable_memory / sample_memory_bytes).max(1)
379    }
380
381    /// Calculate number of batches for a dataset
382    pub fn calculate_num_batches(dataset_size: usize, batch_size: usize, drop_last: bool) -> usize {
383        let total = (dataset_size + batch_size - 1) / batch_size;
384        if drop_last && dataset_size % batch_size != 0 {
385            total - 1
386        } else {
387            total
388        }
389    }
390
391    /// Calculate padding overhead
392    pub fn calculate_padding_overhead(original_lengths: &[usize], padded_length: usize) -> f64 {
393        let original_total: usize = original_lengths.iter().sum();
394        let padded_total = original_lengths.len() * padded_length;
395
396        if padded_total == 0 {
397            0.0
398        } else {
399            1.0 - (original_total as f64 / padded_total as f64)
400        }
401    }
402
403    /// Find optimal padding length to minimize overhead
404    pub fn find_optimal_padding_length(lengths: &[usize], multiple: usize) -> usize {
405        let max_len = lengths.iter().max().copied().unwrap_or(0);
406        ((max_len + multiple - 1) / multiple) * multiple
407    }
408
409    /// Group samples by similar lengths for efficient batching
410    pub fn group_by_length(lengths: Vec<usize>, num_groups: usize) -> Vec<Vec<usize>> {
411        if lengths.is_empty() || num_groups == 0 {
412            return vec![];
413        }
414
415        let mut indexed_lengths: Vec<_> = lengths.into_iter().enumerate().collect();
416        indexed_lengths.sort_by_key(|(_, len)| *len);
417
418        let group_size = (indexed_lengths.len() + num_groups - 1) / num_groups;
419        let mut groups = vec![Vec::new(); num_groups];
420
421        for (group_idx, chunk) in indexed_lengths.chunks(group_size).enumerate() {
422            groups[group_idx] = chunk.iter().map(|(idx, _)| *idx).collect();
423        }
424
425        groups.into_iter().filter(|g| !g.is_empty()).collect()
426    }
427
428    /// Calculate memory efficiency of batching strategy
429    pub fn calculate_memory_efficiency(
430        batch_size: usize,
431        avg_sequence_length: usize,
432        max_sequence_length: usize,
433    ) -> f64 {
434        let used = batch_size * avg_sequence_length;
435        let allocated = batch_size * max_sequence_length;
436
437        if allocated == 0 {
438            0.0
439        } else {
440            used as f64 / allocated as f64
441        }
442    }
443}
444
445#[cfg(test)]
446mod tests {
447    use super::*;
448
449    #[test]
450    fn test_sampling_strategy_variants() {
451        let strategies = [
452            SamplingStrategy::Sequential,
453            SamplingStrategy::Random,
454            SamplingStrategy::Shuffle,
455            SamplingStrategy::Stratified,
456            SamplingStrategy::Weighted,
457        ];
458
459        assert_eq!(strategies.len(), 5);
460    }
461
462    #[test]
463    fn test_batch_config_default() {
464        let config = BatchConfig::default();
465        assert_eq!(config.batch_size, 32);
466        assert!(!config.drop_last);
467        assert_eq!(config.sampling_strategy, SamplingStrategy::Sequential);
468    }
469
470    #[test]
471    fn test_batch_config_builder() {
472        let config = BatchConfig::new(64)
473            .with_drop_last(true)
474            .with_padding_value(1.0)
475            .with_max_sequence_length(128)
476            .with_seed(42);
477
478        assert_eq!(config.batch_size, 64);
479        assert!(config.drop_last);
480        assert_eq!(config.padding_value, 1.0);
481        assert_eq!(config.max_sequence_length, Some(128));
482        assert_eq!(config.seed, Some(42));
483    }
484
485    #[test]
486    fn test_batch_sampler_creation() {
487        let config = BatchConfig::new(10);
488        let sampler = BatchSampler::new(100, config);
489
490        assert_eq!(sampler.dataset_size, 100);
491        assert_eq!(sampler.num_batches(), 10);
492    }
493
494    #[test]
495    fn test_batch_sampler_next_batch() {
496        let config = BatchConfig::new(10);
497        let mut sampler = BatchSampler::new(25, config);
498
499        let batch1 = sampler.next_batch();
500        assert!(batch1.is_some());
501        assert_eq!(batch1.expect("test: operation should succeed").len(), 10);
502
503        let batch2 = sampler.next_batch();
504        assert!(batch2.is_some());
505        assert_eq!(batch2.expect("test: operation should succeed").len(), 10);
506
507        let batch3 = sampler.next_batch();
508        assert!(batch3.is_some());
509        assert_eq!(batch3.expect("test: operation should succeed").len(), 5); // Last incomplete batch
510    }
511
512    #[test]
513    fn test_batch_sampler_drop_last() {
514        let config = BatchConfig::new(10).with_drop_last(true);
515        let mut sampler = BatchSampler::new(25, config);
516
517        sampler.next_batch();
518        sampler.next_batch();
519        let batch3 = sampler.next_batch();
520
521        assert!(batch3.is_none()); // Last batch should be dropped
522    }
523
524    #[test]
525    fn test_batch_sampler_num_batches() {
526        let config = BatchConfig::new(10);
527        let sampler = BatchSampler::new(25, config);
528        assert_eq!(sampler.num_batches(), 3);
529
530        let config_drop = BatchConfig::new(10).with_drop_last(true);
531        let sampler_drop = BatchSampler::new(25, config_drop);
532        assert_eq!(sampler_drop.num_batches(), 2);
533    }
534
535    #[test]
536    fn test_batch_sampler_reset() {
537        let config = BatchConfig::new(10);
538        let mut sampler = BatchSampler::new(25, config);
539
540        sampler.next_batch();
541        sampler.next_batch();
542        assert_eq!(sampler.current_batch_index(), 2);
543
544        sampler.reset();
545        assert_eq!(sampler.current_batch_index(), 0);
546    }
547
548    #[test]
549    fn test_collator_creation() {
550        let config = BatchConfig::new(32);
551        let collator = Collator::new(config);
552
553        // Just verify it was created successfully
554        assert_eq!(collator.config.batch_size, 32);
555    }
556
557    #[test]
558    fn test_batch_statistics_creation() {
559        let stats = BatchStatistics::new();
560        assert_eq!(stats.total_batches, 0);
561        assert_eq!(stats.total_samples, 0);
562        assert_eq!(stats.avg_batch_size, 0.0);
563    }
564
565    #[test]
566    fn test_batch_statistics_record() {
567        let mut stats = BatchStatistics::new();
568
569        stats.record_batch(32, 0.1);
570        assert_eq!(stats.total_batches, 1);
571        assert_eq!(stats.total_samples, 32);
572        assert_eq!(stats.avg_batch_size, 32.0);
573
574        stats.record_batch(30, 0.15);
575        assert_eq!(stats.total_batches, 2);
576        assert_eq!(stats.total_samples, 62);
577        assert_eq!(stats.avg_batch_size, 31.0);
578    }
579
580    #[test]
581    fn test_batch_statistics_min_max() {
582        let mut stats = BatchStatistics::new();
583
584        stats.record_batch(32, 0.1);
585        stats.record_batch(20, 0.1);
586        stats.record_batch(40, 0.1);
587
588        assert_eq!(stats.min_batch_size, 20);
589        assert_eq!(stats.max_batch_size, 40);
590    }
591
592    #[test]
593    fn test_batch_statistics_efficiency() {
594        let mut stats = BatchStatistics::new();
595        stats.record_batch(32, 0.2); // 20% padding
596
597        let efficiency = stats.efficiency();
598        assert!((efficiency - 0.8).abs() < 0.01); // 80% efficiency
599    }
600
601    #[test]
602    fn test_utils_calculate_optimal_batch_size() {
603        let batch_size = batch_utils::calculate_optimal_batch_size(
604            1024 * 1024,      // 1 MB per sample
605            1024 * 1024 * 64, // 64 MB available
606            0.8,              // 80% safety factor
607        );
608
609        assert!(batch_size > 0);
610        assert!(batch_size <= 64);
611    }
612
613    #[test]
614    fn test_utils_calculate_num_batches() {
615        assert_eq!(batch_utils::calculate_num_batches(100, 32, false), 4);
616        assert_eq!(batch_utils::calculate_num_batches(100, 32, true), 3);
617        assert_eq!(batch_utils::calculate_num_batches(96, 32, false), 3);
618        assert_eq!(batch_utils::calculate_num_batches(96, 32, true), 3);
619    }
620
621    #[test]
622    fn test_utils_calculate_padding_overhead() {
623        let lengths = vec![10, 15, 12, 8];
624        let overhead = batch_utils::calculate_padding_overhead(&lengths, 20);
625
626        // Total original: 45, padded total: 80
627        // Overhead: 1 - (45/80) = 0.4375
628        assert!((overhead - 0.4375).abs() < 0.01);
629    }
630
631    #[test]
632    fn test_utils_find_optimal_padding_length() {
633        let lengths = vec![10, 15, 18, 22];
634        let optimal = batch_utils::find_optimal_padding_length(&lengths, 8);
635
636        assert_eq!(optimal, 24); // Next multiple of 8 after 22
637    }
638
639    #[test]
640    fn test_utils_group_by_length() {
641        let lengths = vec![10, 25, 15, 30, 20, 12, 28];
642        let groups = batch_utils::group_by_length(lengths, 3);
643
644        assert_eq!(groups.len(), 3);
645        // Each group should have similar-length sequences
646        for group in &groups {
647            assert!(!group.is_empty());
648        }
649    }
650
651    #[test]
652    fn test_utils_calculate_memory_efficiency() {
653        let efficiency = batch_utils::calculate_memory_efficiency(
654            32,  // batch size
655            100, // avg sequence length
656            128, // max sequence length
657        );
658
659        // Efficiency: (32 * 100) / (32 * 128) = 3200 / 4096 ≈ 0.78
660        assert!((efficiency - 0.78125).abs() < 0.01);
661    }
662
663    #[test]
664    fn test_padding_strategy_variants() {
665        let strategies = [
666            PaddingStrategy::LongestInBatch,
667            PaddingStrategy::FixedLength,
668            PaddingStrategy::NearestMultiple,
669            PaddingStrategy::NoPadding,
670        ];
671
672        assert_eq!(strategies.len(), 4);
673    }
674
675    #[test]
676    fn test_collation_strategy_variants() {
677        let strategies = [
678            CollationStrategy::Stack,
679            CollationStrategy::Concatenate,
680            CollationStrategy::PadAndStack,
681            CollationStrategy::Custom,
682        ];
683
684        assert_eq!(strategies.len(), 4);
685    }
686}