// tensorlogic_train/batch.rs

1//! Batch management and data loading.
2
3use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{Array, ArrayView, Ix2};
5use std::collections::HashSet;
6
7/// Configuration for batch processing.
/// Configuration for batch processing.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Whether to shuffle data each epoch.
    pub shuffle: bool,
    /// Whether to drop the last incomplete batch.
    pub drop_last: bool,
    /// Random seed for shuffling; `None` means non-deterministic.
    pub seed: Option<u64>,
}

impl Default for BatchConfig {
    /// Defaults: batches of 32, shuffling enabled, keep the final
    /// partial batch, no fixed seed.
    fn default() -> Self {
        BatchConfig {
            seed: None,
            drop_last: false,
            shuffle: true,
            batch_size: 32,
        }
    }
}
30
31/// Iterator over batches of data.
32pub struct BatchIterator {
33    /// Configuration.
34    config: BatchConfig,
35    /// Total number of samples.
36    num_samples: usize,
37    /// Current batch index.
38    current_batch: usize,
39    /// Shuffled indices (if shuffle=true).
40    indices: Vec<usize>,
41}
42
43impl BatchIterator {
44    /// Create a new batch iterator.
45    pub fn new(num_samples: usize, config: BatchConfig) -> Self {
46        let mut indices: Vec<usize> = (0..num_samples).collect();
47
48        if config.shuffle {
49            // Simple shuffle using seed if provided
50            if let Some(seed) = config.seed {
51                // Deterministic shuffle based on seed
52                let mut rng_state = seed;
53                for i in (1..indices.len()).rev() {
54                    rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
55                    let j = (rng_state % (i as u64 + 1)) as usize;
56                    indices.swap(i, j);
57                }
58            } else {
59                // Non-deterministic shuffle using simple algorithm
60                use std::collections::hash_map::RandomState;
61                use std::hash::BuildHasher;
62                let hasher = RandomState::new();
63                indices.sort_by_cached_key(|&i| hasher.hash_one(i));
64            }
65        }
66
67        Self {
68            config,
69            num_samples,
70            current_batch: 0,
71            indices,
72        }
73    }
74
75    /// Get the next batch indices.
76    pub fn next_batch(&mut self) -> Option<Vec<usize>> {
77        if self.current_batch * self.config.batch_size >= self.num_samples {
78            return None;
79        }
80
81        let start = self.current_batch * self.config.batch_size;
82        let end = (start + self.config.batch_size).min(self.num_samples);
83
84        if self.config.drop_last && end - start < self.config.batch_size {
85            return None;
86        }
87
88        self.current_batch += 1;
89        Some(self.indices[start..end].to_vec())
90    }
91
92    /// Reset iterator to the beginning.
93    pub fn reset(&mut self) {
94        self.current_batch = 0;
95
96        if self.config.shuffle {
97            // Re-shuffle for next epoch
98            if let Some(seed) = self.config.seed {
99                let mut rng_state = seed.wrapping_add(self.current_batch as u64);
100                for i in (1..self.indices.len()).rev() {
101                    rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
102                    let j = (rng_state % (i as u64 + 1)) as usize;
103                    self.indices.swap(i, j);
104                }
105            } else {
106                use std::collections::hash_map::RandomState;
107                use std::hash::BuildHasher;
108                let hasher = RandomState::new();
109                self.indices
110                    .sort_by_cached_key(|&i| hasher.hash_one((i, self.current_batch)));
111            }
112        }
113    }
114
115    /// Get total number of batches.
116    pub fn num_batches(&self) -> usize {
117        let total = self.num_samples.div_ceil(self.config.batch_size);
118        if self.config.drop_last && !self.num_samples.is_multiple_of(self.config.batch_size) {
119            total - 1
120        } else {
121            total
122        }
123    }
124}
125
/// Data shuffler for randomizing training data.
pub struct DataShuffler {
    /// Random seed.
    #[allow(dead_code)]
    seed: Option<u64>,
    /// Internal state for random number generation.
    rng_state: u64,
}

impl DataShuffler {
    /// Create a new data shuffler.
    ///
    /// Without a seed the state starts from a fixed default (42), so
    /// unseeded shufflers are still deterministic across runs.
    pub fn new(seed: Option<u64>) -> Self {
        let rng_state = match seed {
            Some(s) => s,
            None => 42,
        };
        Self { seed, rng_state }
    }

    /// Advance the internal LCG and return the new state value.
    fn next_state(&mut self) -> u64 {
        self.rng_state = self
            .rng_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1);
        self.rng_state
    }

    /// Shuffle indices in place (Fisher-Yates, walked back to front).
    pub fn shuffle(&mut self, indices: &mut [usize]) {
        for i in (1..indices.len()).rev() {
            let j = (self.next_state() % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }
    }

    /// Generate a random permutation of `0..n`.
    pub fn permutation(&mut self, n: usize) -> Vec<usize> {
        let mut order: Vec<usize> = (0..n).collect();
        self.shuffle(&mut order);
        order
    }
}
163
164/// Extract batches from data arrays.
165pub fn extract_batch(
166    data: &ArrayView<f64, Ix2>,
167    indices: &[usize],
168) -> TrainResult<Array<f64, Ix2>> {
169    let batch_size = indices.len();
170    let num_features = data.ncols();
171    let mut batch = Array::zeros((batch_size, num_features));
172
173    for (i, &idx) in indices.iter().enumerate() {
174        if idx >= data.nrows() {
175            return Err(TrainError::BatchError(format!(
176                "Index {} out of bounds for data with {} rows",
177                idx,
178                data.nrows()
179            )));
180        }
181        batch.row_mut(i).assign(&data.row(idx));
182    }
183
184    Ok(batch)
185}
186
187/// Stratified batch sampler for balanced class sampling.
188#[allow(dead_code)]
189pub struct StratifiedSampler {
190    /// Class labels for each sample.
191    labels: Vec<usize>,
192    /// Indices for each class.
193    class_indices: Vec<Vec<usize>>,
194    /// Current position in each class.
195    class_positions: Vec<usize>,
196    /// Batch size.
197    batch_size: usize,
198    /// Random seed.
199    seed: Option<u64>,
200}
201
202impl StratifiedSampler {
203    /// Create a new stratified sampler.
204    #[allow(dead_code)]
205    pub fn new(labels: Vec<usize>, batch_size: usize, seed: Option<u64>) -> TrainResult<Self> {
206        if labels.is_empty() {
207            return Err(TrainError::BatchError("Empty labels".to_string()));
208        }
209
210        // Find unique classes
211        let unique_classes: HashSet<usize> = labels.iter().copied().collect();
212        let num_classes = unique_classes.len();
213
214        // Group indices by class
215        let mut class_indices = vec![Vec::new(); num_classes];
216        for (idx, &label) in labels.iter().enumerate() {
217            class_indices[label].push(idx);
218        }
219
220        // Shuffle each class independently
221        let mut shuffler = DataShuffler::new(seed);
222        for class_idx in &mut class_indices {
223            shuffler.shuffle(class_idx);
224        }
225
226        Ok(Self {
227            labels,
228            class_indices,
229            class_positions: vec![0; num_classes],
230            batch_size,
231            seed,
232        })
233    }
234
235    /// Get next stratified batch.
236    #[allow(dead_code)]
237    pub fn next_batch(&mut self) -> Option<Vec<usize>> {
238        let num_classes = self.class_indices.len();
239        let samples_per_class = self.batch_size / num_classes;
240
241        let mut batch_indices = Vec::new();
242
243        for class_id in 0..num_classes {
244            let class_samples = &self.class_indices[class_id];
245            let pos = self.class_positions[class_id];
246
247            // Check if we have enough samples for this class
248            if pos + samples_per_class > class_samples.len() {
249                // Not enough samples for a complete stratified batch
250                return None;
251            }
252
253            // Add samples from this class
254            for i in 0..samples_per_class {
255                batch_indices.push(class_samples[pos + i]);
256            }
257
258            self.class_positions[class_id] += samples_per_class;
259        }
260
261        if batch_indices.is_empty() {
262            None
263        } else {
264            Some(batch_indices)
265        }
266    }
267
268    /// Reset sampler.
269    #[allow(dead_code)]
270    pub fn reset(&mut self) {
271        self.class_positions.fill(0);
272
273        // Re-shuffle each class
274        let mut shuffler = DataShuffler::new(self.seed);
275        for class_idx in &mut self.class_indices {
276            shuffler.shuffle(class_idx);
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Unshuffled iteration over 10 samples with batch_size 3:
    // three full batches plus a final partial batch of 1.
    #[test]
    fn test_batch_iterator() {
        let config = BatchConfig {
            batch_size: 3,
            shuffle: false,
            drop_last: false,
            seed: Some(42),
        };
        let mut iter = BatchIterator::new(10, config);

        let batch1 = iter.next_batch().unwrap();
        assert_eq!(batch1.len(), 3);

        let batch2 = iter.next_batch().unwrap();
        assert_eq!(batch2.len(), 3);

        let batch3 = iter.next_batch().unwrap();
        assert_eq!(batch3.len(), 3);

        let batch4 = iter.next_batch().unwrap();
        assert_eq!(batch4.len(), 1); // Last batch with remaining samples

        assert!(iter.next_batch().is_none());
    }

    // Same setup but with drop_last: the trailing partial batch of 1
    // must be discarded, leaving exactly three batches.
    #[test]
    fn test_batch_iterator_drop_last() {
        let config = BatchConfig {
            batch_size: 3,
            shuffle: false,
            drop_last: true,
            seed: Some(42),
        };
        let mut iter = BatchIterator::new(10, config);

        let batch1 = iter.next_batch().unwrap();
        assert_eq!(batch1.len(), 3);

        let batch2 = iter.next_batch().unwrap();
        assert_eq!(batch2.len(), 3);

        let batch3 = iter.next_batch().unwrap();
        assert_eq!(batch3.len(), 3);

        assert!(iter.next_batch().is_none()); // Last incomplete batch is dropped
    }

    // Seeded shuffles are deterministic, so the assert_ne below is
    // stable — it relies on seed 42 not producing the identity
    // permutation for 5 elements.
    #[test]
    fn test_data_shuffler() {
        let mut shuffler = DataShuffler::new(Some(42));
        let mut indices = vec![0, 1, 2, 3, 4];
        let original = indices.clone();

        shuffler.shuffle(&mut indices);
        assert_ne!(indices, original); // Should be shuffled
        assert_eq!(indices.len(), original.len());
    }

    // Rows 0 and 2 of a 4x2 matrix should be copied verbatim into a
    // 2x2 batch, in the order given by `indices`.
    #[test]
    fn test_extract_batch() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let indices = vec![0, 2];

        let batch = extract_batch(&data.view(), &indices).unwrap();
        assert_eq!(batch.shape(), &[2, 2]);
        assert_eq!(batch[[0, 0]], 1.0);
        assert_eq!(batch[[0, 1]], 2.0);
        assert_eq!(batch[[1, 0]], 5.0);
        assert_eq!(batch[[1, 1]], 6.0);
    }

    // Three balanced classes of three samples each, batch_size 6:
    // the batch must contain exactly two samples of every class.
    #[test]
    fn test_stratified_sampler() {
        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let mut sampler = StratifiedSampler::new(labels, 6, Some(42)).unwrap();

        let batch = sampler.next_batch().unwrap();
        assert_eq!(batch.len(), 6);

        // Count class distribution in batch
        let mut class_counts = vec![0; 3];
        for &idx in &batch {
            class_counts[sampler.labels[idx]] += 1;
        }

        // Each class should have 2 samples (6 / 3 classes)
        assert_eq!(class_counts, vec![2, 2, 2]);
    }
}