1use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{Array, ArrayView, Ix2};
5use std::collections::HashSet;
6
/// Options controlling how sample indices are grouped into mini-batches.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Upper bound on the number of sample indices in one batch.
    pub batch_size: usize,
    /// Whether the sample order is shuffled before batching.
    pub shuffle: bool,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    pub drop_last: bool,
    /// Seed for a reproducible shuffle; in `BatchIterator`, `None` selects a
    /// non-reproducible, randomly keyed shuffle instead.
    pub seed: Option<u64>,
}
19
20impl Default for BatchConfig {
21 fn default() -> Self {
22 Self {
23 batch_size: 32,
24 shuffle: true,
25 drop_last: false,
26 seed: None,
27 }
28 }
29}
30
31pub struct BatchIterator {
33 config: BatchConfig,
35 num_samples: usize,
37 current_batch: usize,
39 indices: Vec<usize>,
41}
42
43impl BatchIterator {
44 pub fn new(num_samples: usize, config: BatchConfig) -> Self {
46 let mut indices: Vec<usize> = (0..num_samples).collect();
47
48 if config.shuffle {
49 if let Some(seed) = config.seed {
51 let mut rng_state = seed;
53 for i in (1..indices.len()).rev() {
54 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
55 let j = (rng_state % (i as u64 + 1)) as usize;
56 indices.swap(i, j);
57 }
58 } else {
59 use std::collections::hash_map::RandomState;
61 use std::hash::BuildHasher;
62 let hasher = RandomState::new();
63 indices.sort_by_cached_key(|&i| hasher.hash_one(i));
64 }
65 }
66
67 Self {
68 config,
69 num_samples,
70 current_batch: 0,
71 indices,
72 }
73 }
74
75 pub fn next_batch(&mut self) -> Option<Vec<usize>> {
77 if self.current_batch * self.config.batch_size >= self.num_samples {
78 return None;
79 }
80
81 let start = self.current_batch * self.config.batch_size;
82 let end = (start + self.config.batch_size).min(self.num_samples);
83
84 if self.config.drop_last && end - start < self.config.batch_size {
85 return None;
86 }
87
88 self.current_batch += 1;
89 Some(self.indices[start..end].to_vec())
90 }
91
92 pub fn reset(&mut self) {
94 self.current_batch = 0;
95
96 if self.config.shuffle {
97 if let Some(seed) = self.config.seed {
99 let mut rng_state = seed.wrapping_add(self.current_batch as u64);
100 for i in (1..self.indices.len()).rev() {
101 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
102 let j = (rng_state % (i as u64 + 1)) as usize;
103 self.indices.swap(i, j);
104 }
105 } else {
106 use std::collections::hash_map::RandomState;
107 use std::hash::BuildHasher;
108 let hasher = RandomState::new();
109 self.indices
110 .sort_by_cached_key(|&i| hasher.hash_one((i, self.current_batch)));
111 }
112 }
113 }
114
115 pub fn num_batches(&self) -> usize {
117 let total = self.num_samples.div_ceil(self.config.batch_size);
118 if self.config.drop_last && !self.num_samples.is_multiple_of(self.config.batch_size) {
119 total - 1
120 } else {
121 total
122 }
123 }
124}
125
/// Reproducible Fisher–Yates shuffler driven by a 64-bit linear congruential
/// generator.
///
/// A `None` seed falls back to the fixed state `42`, so output is always
/// deterministic.
pub struct DataShuffler {
    /// Seed given at construction. It is not read anywhere in the code shown
    /// here, hence the `dead_code` allowance.
    #[allow(dead_code)]
    seed: Option<u64>,
    /// Generator state; advanced once per swap and carried across calls.
    rng_state: u64,
}

impl DataShuffler {
    /// Creates a shuffler whose generator starts at `seed`, or at `42` when
    /// `seed` is `None`.
    pub fn new(seed: Option<u64>) -> Self {
        let rng_state = match seed {
            Some(value) => value,
            None => 42,
        };
        DataShuffler { seed, rng_state }
    }

    /// Shuffles `indices` in place, walking from the last position down to
    /// position 1 and swapping each with a generator-chosen earlier-or-equal
    /// position.
    pub fn shuffle(&mut self, indices: &mut [usize]) {
        const MULTIPLIER: u64 = 6364136223846793005;

        let mut upper = indices.len();
        while upper > 1 {
            upper -= 1;
            self.rng_state = self.rng_state.wrapping_mul(MULTIPLIER).wrapping_add(1);
            let bound = upper as u64 + 1;
            let pick = (self.rng_state % bound) as usize;
            indices.swap(upper, pick);
        }
    }

    /// Returns a shuffled arrangement of `0..n`, advancing the generator state.
    pub fn permutation(&mut self, n: usize) -> Vec<usize> {
        let mut perm = Vec::with_capacity(n);
        perm.extend(0..n);
        self.shuffle(&mut perm);
        perm
    }
}
163
164pub fn extract_batch(
166 data: &ArrayView<f64, Ix2>,
167 indices: &[usize],
168) -> TrainResult<Array<f64, Ix2>> {
169 let batch_size = indices.len();
170 let num_features = data.ncols();
171 let mut batch = Array::zeros((batch_size, num_features));
172
173 for (i, &idx) in indices.iter().enumerate() {
174 if idx >= data.nrows() {
175 return Err(TrainError::BatchError(format!(
176 "Index {} out of bounds for data with {} rows",
177 idx,
178 data.nrows()
179 )));
180 }
181 batch.row_mut(i).assign(&data.row(idx));
182 }
183
184 Ok(batch)
185}
186
187#[allow(dead_code)]
189pub struct StratifiedSampler {
190 labels: Vec<usize>,
192 class_indices: Vec<Vec<usize>>,
194 class_positions: Vec<usize>,
196 batch_size: usize,
198 seed: Option<u64>,
200}
201
202impl StratifiedSampler {
203 #[allow(dead_code)]
205 pub fn new(labels: Vec<usize>, batch_size: usize, seed: Option<u64>) -> TrainResult<Self> {
206 if labels.is_empty() {
207 return Err(TrainError::BatchError("Empty labels".to_string()));
208 }
209
210 let unique_classes: HashSet<usize> = labels.iter().copied().collect();
212 let num_classes = unique_classes.len();
213
214 let mut class_indices = vec![Vec::new(); num_classes];
216 for (idx, &label) in labels.iter().enumerate() {
217 class_indices[label].push(idx);
218 }
219
220 let mut shuffler = DataShuffler::new(seed);
222 for class_idx in &mut class_indices {
223 shuffler.shuffle(class_idx);
224 }
225
226 Ok(Self {
227 labels,
228 class_indices,
229 class_positions: vec![0; num_classes],
230 batch_size,
231 seed,
232 })
233 }
234
235 #[allow(dead_code)]
237 pub fn next_batch(&mut self) -> Option<Vec<usize>> {
238 let num_classes = self.class_indices.len();
239 let samples_per_class = self.batch_size / num_classes;
240
241 let mut batch_indices = Vec::new();
242
243 for class_id in 0..num_classes {
244 let class_samples = &self.class_indices[class_id];
245 let pos = self.class_positions[class_id];
246
247 if pos + samples_per_class > class_samples.len() {
249 return None;
251 }
252
253 for i in 0..samples_per_class {
255 batch_indices.push(class_samples[pos + i]);
256 }
257
258 self.class_positions[class_id] += samples_per_class;
259 }
260
261 if batch_indices.is_empty() {
262 None
263 } else {
264 Some(batch_indices)
265 }
266 }
267
268 #[allow(dead_code)]
270 pub fn reset(&mut self) {
271 self.class_positions.fill(0);
272
273 let mut shuffler = DataShuffler::new(self.seed);
275 for class_idx in &mut self.class_indices {
276 shuffler.shuffle(class_idx);
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Drains `iter`, collecting the length of every batch it yields.
    fn batch_sizes(iter: &mut BatchIterator) -> Vec<usize> {
        let mut sizes = Vec::new();
        while let Some(batch) = iter.next_batch() {
            sizes.push(batch.len());
        }
        sizes
    }

    #[test]
    fn test_batch_iterator() {
        let mut iter = BatchIterator::new(
            10,
            BatchConfig {
                batch_size: 3,
                shuffle: false,
                drop_last: false,
                seed: Some(42),
            },
        );
        // Ten samples in batches of three: three full batches plus one leftover.
        assert_eq!(batch_sizes(&mut iter), vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_batch_iterator_drop_last() {
        let mut iter = BatchIterator::new(
            10,
            BatchConfig {
                batch_size: 3,
                shuffle: false,
                drop_last: true,
                seed: Some(42),
            },
        );
        // The single leftover sample is dropped.
        assert_eq!(batch_sizes(&mut iter), vec![3, 3, 3]);
    }

    #[test]
    fn test_data_shuffler() {
        let original: Vec<usize> = (0..5).collect();
        let mut shuffled = original.clone();
        DataShuffler::new(Some(42)).shuffle(&mut shuffled);

        assert_eq!(shuffled.len(), original.len());
        assert_ne!(shuffled, original);
    }

    #[test]
    fn test_extract_batch() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let batch = extract_batch(&data.view(), &[0, 2]).unwrap();

        assert_eq!(batch.shape(), &[2, 2]);
        assert_eq!(batch, array![[1.0, 2.0], [5.0, 6.0]]);
    }

    #[test]
    fn test_stratified_sampler() {
        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let mut sampler = StratifiedSampler::new(labels, 6, Some(42)).unwrap();

        let batch = sampler.next_batch().unwrap();
        assert_eq!(batch.len(), 6);

        let class_counts = batch.iter().fold(vec![0; 3], |mut counts, &idx| {
            counts[sampler.labels[idx]] += 1;
            counts
        });
        assert_eq!(class_counts, vec![2, 2, 2]);
    }
}