1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use super::core::{BatchSampler, Sampler};
10
11#[derive(Debug, Clone)]
29pub struct BatchingSampler<S: Sampler> {
30 sampler: S,
31 batch_size: usize,
32 drop_last: bool,
33}
34
35impl<S: Sampler> BatchingSampler<S> {
36 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
48 assert!(batch_size > 0, "Batch size must be positive");
49 Self {
50 sampler,
51 batch_size,
52 drop_last,
53 }
54 }
55
56 pub fn batch_size(&self) -> usize {
58 self.batch_size
59 }
60
61 pub fn drop_last(&self) -> bool {
63 self.drop_last
64 }
65
66 pub fn sampler(&self) -> &S {
68 &self.sampler
69 }
70
71 pub fn into_sampler(self) -> S {
73 self.sampler
74 }
75
76 pub fn into_distributed(
86 self,
87 num_replicas: usize,
88 rank: usize,
89 ) -> BatchingSampler<super::distributed::DistributedWrapper<S>> {
90 let distributed_sampler = self.sampler.into_distributed(num_replicas, rank);
91 BatchingSampler::new(distributed_sampler, self.batch_size, self.drop_last)
92 }
93}
94
95impl<S: Sampler> BatchSampler for BatchingSampler<S> {
96 type Iter = BatchSamplerIter<S::Iter>;
97
98 fn iter(&self) -> Self::Iter {
99 BatchSamplerIter::new(self.sampler.iter(), self.batch_size, self.drop_last)
100 }
101
102 fn num_batches(&self) -> usize {
103 let total_samples = self.sampler.len();
104 if total_samples == 0 {
105 return 0;
106 }
107
108 if self.drop_last {
109 total_samples / self.batch_size
110 } else {
111 (total_samples + self.batch_size - 1) / self.batch_size
112 }
113 }
114}
115
/// Iterator adapter that groups the `usize` items of an inner iterator into
/// `Vec<usize>` batches of at most `batch_size` elements.
///
/// When `drop_last` is set, a trailing batch shorter than `batch_size` is
/// discarded instead of being yielded.
#[derive(Debug)]
pub struct BatchSamplerIter<I: Iterator<Item = usize>> {
    /// Source of the individual indices.
    inner: I,
    /// Maximum number of indices per yielded batch.
    batch_size: usize,
    /// Whether a short trailing batch is discarded.
    drop_last: bool,
}

impl<I: Iterator<Item = usize>> BatchSamplerIter<I> {
    /// Wraps `inner` so that it is consumed in batches of `batch_size`.
    ///
    /// A `batch_size` of zero is accepted and produces an iterator that never
    /// yields a batch (and never pulls from `inner`).
    pub fn new(inner: I, batch_size: usize, drop_last: bool) -> Self {
        Self {
            inner,
            batch_size,
            drop_last,
        }
    }

    /// Returns the configured number of indices per batch.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns `true` when a short trailing batch is discarded.
    pub fn drop_last(&self) -> bool {
        self.drop_last
    }

    /// Number of batches that `samples` remaining indices would produce under
    /// the current `batch_size` / `drop_last` settings.
    fn batch_count_for(&self, samples: usize) -> usize {
        // A zero batch size never yields anything; this also avoids dividing
        // by zero below.
        if self.batch_size == 0 {
            return 0;
        }

        // Quotient + remainder instead of `(n + b - 1) / b`, which overflows
        // when `n + b` exceeds `usize::MAX`.
        let full_batches = samples / self.batch_size;
        let has_partial = samples % self.batch_size != 0;

        if has_partial && !self.drop_last {
            full_batches + 1
        } else {
            full_batches
        }
    }
}

impl<I: Iterator<Item = usize>> Iterator for BatchSamplerIter<I> {
    type Item = Vec<usize>;

    /// Pulls up to `batch_size` indices from the inner iterator.
    ///
    /// Returns `None` when no index was available, or when fewer than
    /// `batch_size` were available and `drop_last` is set.
    fn next(&mut self) -> Option<Self::Item> {
        // Never reserve more than the inner iterator says it can still
        // deliver: this avoids a capacity-overflow panic for very large batch
        // sizes and a wasted allocation once the source is exhausted.
        let capacity = self
            .inner
            .size_hint()
            .1
            .map_or(self.batch_size, |upper| upper.min(self.batch_size));

        let mut batch = Vec::with_capacity(capacity);
        // `take` stops after `batch_size` items or at the first `None`,
        // so the inner iterator is advanced exactly as far as needed.
        batch.extend(self.inner.by_ref().take(self.batch_size));

        let is_partial = batch.len() < self.batch_size;
        if batch.is_empty() || (is_partial && self.drop_last) {
            None
        } else {
            Some(batch)
        }
    }

    /// Translates the inner iterator's size hint from indices to batches.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.inner.size_hint();
        (
            self.batch_count_for(lower),
            upper.map(|u| self.batch_count_for(u)),
        )
    }
}
189
190pub fn batch<S: Sampler>(sampler: S, batch_size: usize, drop_last: bool) -> BatchingSampler<S> {
200 BatchingSampler::new(sampler, batch_size, drop_last)
201}
202
203pub fn batch_keep_last<S: Sampler>(sampler: S, batch_size: usize) -> BatchingSampler<S> {
213 BatchingSampler::new(sampler, batch_size, false)
214}
215
216pub fn batch_drop_last<S: Sampler>(sampler: S, batch_size: usize) -> BatchingSampler<S> {
226 BatchingSampler::new(sampler, batch_size, true)
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::basic::SequentialSampler;

    #[test]
    fn test_batching_sampler_basic() {
        let base_sampler = SequentialSampler::new(10);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        assert_eq!(batch_sampler.batch_size(), 3);
        assert!(!batch_sampler.drop_last());
        // 10 indices in batches of 3 -> 3 full batches + 1 partial.
        assert_eq!(batch_sampler.num_batches(), 4);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        // The partial batch is kept because `drop_last` is false.
        assert_eq!(batches[3], vec![9]);
    }

    #[test]
    fn test_batching_sampler_drop_last() {
        let base_sampler = SequentialSampler::new(10);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, true);

        assert!(batch_sampler.drop_last());
        // The trailing `[9]` batch is discarded.
        assert_eq!(batch_sampler.num_batches(), 3);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }

    #[test]
    fn test_batching_sampler_exact_division() {
        let base_sampler = SequentialSampler::new(9);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, true);

        assert_eq!(batch_sampler.num_batches(), 3);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }

    #[test]
    fn test_batching_sampler_empty() {
        let base_sampler = SequentialSampler::new(0);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        assert_eq!(batch_sampler.num_batches(), 0);
        assert!(batch_sampler.is_empty());

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 0);
    }

    #[test]
    fn test_batching_sampler_single_item() {
        let base_sampler = SequentialSampler::new(1);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        assert_eq!(batch_sampler.num_batches(), 1);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0], vec![0]);
    }

    #[test]
    fn test_batching_sampler_single_item_drop_last() {
        let base_sampler = SequentialSampler::new(1);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, true);

        assert_eq!(batch_sampler.num_batches(), 0);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 0);
    }

    #[test]
    #[should_panic(expected = "Batch size must be positive")]
    fn test_batching_sampler_zero_batch_size() {
        let base_sampler = SequentialSampler::new(10);
        BatchingSampler::new(base_sampler, 0, false);
    }

    #[test]
    fn test_batch_sampler_iter_size_hint() {
        let base_sampler = SequentialSampler::new(10);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        let iter = batch_sampler.iter();
        assert_eq!(iter.size_hint(), (4, Some(4)));

        let batch_sampler_drop = BatchingSampler::new(SequentialSampler::new(10), 3, true);
        let iter_drop = batch_sampler_drop.iter();
        assert_eq!(iter_drop.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_batching_sampler_into_sampler() {
        let base_sampler = SequentialSampler::new(5);
        let batch_sampler = BatchingSampler::new(base_sampler, 2, false);

        let recovered_sampler = batch_sampler.into_sampler();
        assert_eq!(recovered_sampler.len(), 5);
    }

    #[test]
    fn test_convenience_functions() {
        let base_sampler = SequentialSampler::new(10);

        let batch_keep = batch_keep_last(base_sampler.clone(), 3);
        assert!(!batch_keep.drop_last());
        assert_eq!(batch_keep.num_batches(), 4);

        let batch_drop = batch_drop_last(base_sampler.clone(), 3);
        assert!(batch_drop.drop_last());
        assert_eq!(batch_drop.num_batches(), 3);

        let batch_generic = batch(base_sampler, 3, true);
        assert!(batch_generic.drop_last());
        assert_eq!(batch_generic.num_batches(), 3);
    }

    #[test]
    fn test_batch_sampler_iter_properties() {
        let base_sampler = SequentialSampler::new(7);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        let mut iter = batch_sampler.iter();
        assert_eq!(iter.batch_size(), 3);
        assert!(!iter.drop_last());

        let batch1 = iter.next().expect("iterator should have a next element");
        assert_eq!(batch1, vec![0, 1, 2]);

        let batch2 = iter.next().expect("iterator should have a next element");
        assert_eq!(batch2, vec![3, 4, 5]);

        let batch3 = iter.next().expect("iterator should have a next element");
        assert_eq!(batch3, vec![6]);

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_batch_sizes() {
        // (dataset_size, batch_size, drop_last, expected_batches)
        let test_cases = vec![
            (10, 1, false, 10),
            (10, 10, false, 1),
            (10, 15, false, 1),
            (0, 5, false, 0),
            (10, 3, true, 3),
            (10, 15, true, 0),
            (0, 5, true, 0),
        ];

        // Every row is exercised, including the empty-dataset rows: both the
        // predicted count and the number of batches actually produced must
        // match the expectation.
        for (dataset_size, batch_size, drop_last, expected_batches) in test_cases {
            let base_sampler = SequentialSampler::new(dataset_size);
            let batch_sampler = BatchingSampler::new(base_sampler, batch_size, drop_last);

            assert_eq!(
                batch_sampler.num_batches(),
                expected_batches,
                "Failed for dataset_size={}, batch_size={}, drop_last={}",
                dataset_size,
                batch_size,
                drop_last
            );

            let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
            assert_eq!(
                batches.len(),
                expected_batches,
                "Actual batch count doesn't match for dataset_size={}, batch_size={}, drop_last={}",
                dataset_size,
                batch_size,
                drop_last
            );
        }
    }

    #[test]
    fn test_edge_case_large_batch_size() {
        let base_sampler = SequentialSampler::new(3);
        let batch_sampler = BatchingSampler::new(base_sampler, 100, false);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0], vec![0, 1, 2]);
    }
}