Skip to main content

torsh_data/sampler/
batch.rs

//! Batch sampling functionality
//!
//! This module provides utilities for converting individual samplers into
//! batch samplers that yield batches of indices instead of individual indices.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use super::core::{BatchSampler, Sampler};
10
11/// Wrapper that converts any sampler into a batch sampler
12///
13/// This sampler takes an underlying sampler and groups its output into batches
14/// of a specified size. The last batch may be smaller than the batch size
15/// unless `drop_last` is set to true.
16///
17/// # Examples
18///
19/// ```rust,ignore
20/// use torsh_data::sampler::{SequentialSampler, BatchingSampler, BatchSampler};
21///
22/// let base_sampler = SequentialSampler::new(10);
23/// let batch_sampler = BatchingSampler::new(base_sampler, 3, false);
24///
25/// let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
26/// assert_eq!(batches.len(), 4); // [0,1,2], [3,4,5], [6,7,8], [9]
27/// ```
28#[derive(Debug, Clone)]
29pub struct BatchingSampler<S: Sampler> {
30    sampler: S,
31    batch_size: usize,
32    drop_last: bool,
33}
34
35impl<S: Sampler> BatchingSampler<S> {
36    /// Create a new batching sampler
37    ///
38    /// # Arguments
39    ///
40    /// * `sampler` - The underlying sampler to batch
41    /// * `batch_size` - Size of each batch
42    /// * `drop_last` - Whether to drop the last incomplete batch
43    ///
44    /// # Panics
45    ///
46    /// Panics if `batch_size` is 0
47    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
48        assert!(batch_size > 0, "Batch size must be positive");
49        Self {
50            sampler,
51            batch_size,
52            drop_last,
53        }
54    }
55
56    /// Get the batch size
57    pub fn batch_size(&self) -> usize {
58        self.batch_size
59    }
60
61    /// Check if dropping last incomplete batch
62    pub fn drop_last(&self) -> bool {
63        self.drop_last
64    }
65
66    /// Get a reference to the underlying sampler
67    pub fn sampler(&self) -> &S {
68        &self.sampler
69    }
70
71    /// Get the underlying sampler by value
72    pub fn into_sampler(self) -> S {
73        self.sampler
74    }
75
76    /// Convert this batching sampler into a distributed version
77    ///
78    /// This creates a distributed wrapper around the underlying sampler
79    /// and then wraps it with a new BatchingSampler.
80    ///
81    /// # Arguments
82    ///
83    /// * `num_replicas` - Total number of processes
84    /// * `rank` - Current process rank (0-based)
85    pub fn into_distributed(
86        self,
87        num_replicas: usize,
88        rank: usize,
89    ) -> BatchingSampler<super::distributed::DistributedWrapper<S>> {
90        let distributed_sampler = self.sampler.into_distributed(num_replicas, rank);
91        BatchingSampler::new(distributed_sampler, self.batch_size, self.drop_last)
92    }
93}
94
95impl<S: Sampler> BatchSampler for BatchingSampler<S> {
96    type Iter = BatchSamplerIter<S::Iter>;
97
98    fn iter(&self) -> Self::Iter {
99        BatchSamplerIter::new(self.sampler.iter(), self.batch_size, self.drop_last)
100    }
101
102    fn num_batches(&self) -> usize {
103        let total_samples = self.sampler.len();
104        if total_samples == 0 {
105            return 0;
106        }
107
108        if self.drop_last {
109            total_samples / self.batch_size
110        } else {
111            (total_samples + self.batch_size - 1) / self.batch_size
112        }
113    }
114}
115
/// Iterator that groups the `usize` items of an inner iterator into
/// `Vec<usize>` batches.
///
/// Each call to `next` pulls up to `batch_size` items from the inner iterator.
/// A batch shorter than `batch_size` (the inner iterator ran out) is yielded
/// only when `drop_last` is `false`.
#[derive(Debug)]
pub struct BatchSamplerIter<I: Iterator<Item = usize>> {
    /// Source of the individual indices.
    inner: I,
    /// Number of indices in a full batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
}

impl<I: Iterator<Item = usize>> BatchSamplerIter<I> {
    /// Create a new batch sampler iterator
    ///
    /// # Arguments
    ///
    /// * `inner` - Iterator supplying the individual indices
    /// * `batch_size` - Number of indices in a full batch
    /// * `drop_last` - Whether a trailing incomplete batch is discarded
    ///
    /// This constructor does not validate `batch_size`. With a `batch_size` of
    /// 0 the iterator yields nothing and reports a size hint of `(0, Some(0))`.
    pub fn new(inner: I, batch_size: usize, drop_last: bool) -> Self {
        Self {
            inner,
            batch_size,
            drop_last,
        }
    }

    /// Get the batch size
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Check if dropping last incomplete batch
    pub fn drop_last(&self) -> bool {
        self.drop_last
    }

    /// Number of batches that `items` remaining indices would produce.
    fn batches_for(&self, items: usize) -> usize {
        // A zero batch size never yields a batch (see `next`); answering 0
        // here also avoids dividing by zero.
        if self.batch_size == 0 {
            return 0;
        }

        let full_batches = items / self.batch_size;
        if self.drop_last {
            full_batches
        } else {
            // Ceiling division as quotient + (remainder != 0): the form
            // `(items + batch_size - 1) / batch_size` overflows when `items`
            // is close to `usize::MAX`, which unbounded inner iterators
            // report as their lower bound.
            full_batches + usize::from(items % self.batch_size != 0)
        }
    }
}

impl<I: Iterator<Item = usize>> Iterator for BatchSamplerIter<I> {
    type Item = Vec<usize>;

    /// Yield the next batch.
    ///
    /// Returns `None` when the inner iterator produces no further items, or
    /// when fewer than `batch_size` items were left and `drop_last` is set.
    fn next(&mut self) -> Option<Self::Item> {
        // `take` stops at `batch_size` items or at the first `None` from the
        // inner iterator. Collecting lets `Vec` size itself from the adapter's
        // size hint instead of reserving `batch_size` slots up front, so an
        // oversized batch size neither over-allocates nor overflows capacity.
        let batch: Vec<usize> = self.inner.by_ref().take(self.batch_size).collect();

        let is_partial = batch.len() < self.batch_size;
        if batch.is_empty() || (is_partial && self.drop_last) {
            None
        } else {
            Some(batch)
        }
    }

    /// Translate the inner iterator's item bounds into batch-count bounds.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.inner.size_hint();
        (
            self.batches_for(lower),
            upper.map(|items| self.batches_for(items)),
        )
    }
}
189
190/// Create a batch sampler from any sampler
191///
192/// Convenience function for creating a batch sampler.
193///
194/// # Arguments
195///
196/// * `sampler` - The underlying sampler
197/// * `batch_size` - Size of each batch
198/// * `drop_last` - Whether to drop the last incomplete batch
199pub fn batch<S: Sampler>(sampler: S, batch_size: usize, drop_last: bool) -> BatchingSampler<S> {
200    BatchingSampler::new(sampler, batch_size, drop_last)
201}
202
203/// Create a batch sampler that keeps the last incomplete batch
204///
205/// Convenience function for creating a batch sampler that doesn't drop
206/// the last batch even if it's incomplete.
207///
208/// # Arguments
209///
210/// * `sampler` - The underlying sampler
211/// * `batch_size` - Size of each batch
212pub fn batch_keep_last<S: Sampler>(sampler: S, batch_size: usize) -> BatchingSampler<S> {
213    BatchingSampler::new(sampler, batch_size, false)
214}
215
216/// Create a batch sampler that drops the last incomplete batch
217///
218/// Convenience function for creating a batch sampler that drops
219/// the last batch if it's incomplete.
220///
221/// # Arguments
222///
223/// * `sampler` - The underlying sampler
224/// * `batch_size` - Size of each batch
225pub fn batch_drop_last<S: Sampler>(sampler: S, batch_size: usize) -> BatchingSampler<S> {
226    BatchingSampler::new(sampler, batch_size, true)
227}
228
#[cfg(test)]
mod tests {
    //! Unit tests for `BatchingSampler`, `BatchSamplerIter` and the
    //! convenience constructors, driven by `SequentialSampler`.

    use super::*;
    use crate::sampler::basic::SequentialSampler;

    #[test]
    fn test_batching_sampler_basic() {
        let base_sampler = SequentialSampler::new(10);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        assert_eq!(batch_sampler.batch_size(), 3);
        assert!(!batch_sampler.drop_last());
        assert_eq!(batch_sampler.num_batches(), 4); // 10 items, 3 per batch = 4 batches

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        assert_eq!(batches[3], vec![9]); // Last incomplete batch
    }

    #[test]
    fn test_batching_sampler_drop_last() {
        let base_sampler = SequentialSampler::new(10);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, true);

        assert!(batch_sampler.drop_last());
        assert_eq!(batch_sampler.num_batches(), 3); // Drops last incomplete batch

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        // Last batch [9] is dropped
    }

    #[test]
    fn test_batching_sampler_exact_division() {
        let base_sampler = SequentialSampler::new(9);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, true);

        assert_eq!(batch_sampler.num_batches(), 3);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }

    #[test]
    fn test_batching_sampler_empty() {
        let base_sampler = SequentialSampler::new(0);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        assert_eq!(batch_sampler.num_batches(), 0);
        assert!(batch_sampler.is_empty());

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 0);
    }

    #[test]
    fn test_batching_sampler_single_item() {
        let base_sampler = SequentialSampler::new(1);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        assert_eq!(batch_sampler.num_batches(), 1);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0], vec![0]);
    }

    #[test]
    fn test_batching_sampler_single_item_drop_last() {
        let base_sampler = SequentialSampler::new(1);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, true);

        assert_eq!(batch_sampler.num_batches(), 0);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 0);
    }

    #[test]
    #[should_panic(expected = "Batch size must be positive")]
    fn test_batching_sampler_zero_batch_size() {
        let base_sampler = SequentialSampler::new(10);
        BatchingSampler::new(base_sampler, 0, false);
    }

    #[test]
    fn test_batch_sampler_iter_size_hint() {
        let base_sampler = SequentialSampler::new(10);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        let iter = batch_sampler.iter();
        assert_eq!(iter.size_hint(), (4, Some(4)));

        let batch_sampler_drop = BatchingSampler::new(SequentialSampler::new(10), 3, true);
        let iter_drop = batch_sampler_drop.iter();
        assert_eq!(iter_drop.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_batching_sampler_into_sampler() {
        let base_sampler = SequentialSampler::new(5);
        let batch_sampler = BatchingSampler::new(base_sampler, 2, false);

        let recovered_sampler = batch_sampler.into_sampler();
        assert_eq!(recovered_sampler.len(), 5);
    }

    #[test]
    fn test_convenience_functions() {
        let base_sampler = SequentialSampler::new(10);

        let batch_keep = batch_keep_last(base_sampler.clone(), 3);
        assert!(!batch_keep.drop_last());
        assert_eq!(batch_keep.num_batches(), 4);

        let batch_drop = batch_drop_last(base_sampler.clone(), 3);
        assert!(batch_drop.drop_last());
        assert_eq!(batch_drop.num_batches(), 3);

        let batch_generic = batch(base_sampler, 3, true);
        assert!(batch_generic.drop_last());
        assert_eq!(batch_generic.num_batches(), 3);
    }

    #[test]
    fn test_batch_sampler_iter_properties() {
        let base_sampler = SequentialSampler::new(7);
        let batch_sampler = BatchingSampler::new(base_sampler, 3, false);

        let mut iter = batch_sampler.iter();
        assert_eq!(iter.batch_size(), 3);
        assert!(!iter.drop_last());

        // Test collecting batches one by one
        let batch1 = iter.next().expect("iterator should have a next element");
        assert_eq!(batch1, vec![0, 1, 2]);

        let batch2 = iter.next().expect("iterator should have a next element");
        assert_eq!(batch2, vec![3, 4, 5]);

        let batch3 = iter.next().expect("iterator should have a next element");
        assert_eq!(batch3, vec![6]);

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_batch_sizes() {
        // Each row: (dataset_size, batch_size, drop_last, expected_batches).
        // Every row is exercised, including the empty dataset, which yields
        // zero batches rather than being an invalid combination.
        let test_cases = [
            (10, 1, false, 10), // Each item is its own batch
            (10, 10, false, 1), // Single batch with all items
            (10, 15, false, 1), // Batch size larger than dataset
            (0, 5, false, 0),   // Empty dataset
            (10, 1, true, 10),  // Every batch is full, nothing to drop
            (10, 3, true, 3),   // Trailing [9] is dropped
            (10, 15, true, 0),  // The only batch is incomplete, so it is dropped
            (0, 5, true, 0),    // Empty dataset with drop_last
        ];

        for (dataset_size, batch_size, drop_last, expected_batches) in test_cases {
            let base_sampler = SequentialSampler::new(dataset_size);
            let batch_sampler = BatchingSampler::new(base_sampler, batch_size, drop_last);

            assert_eq!(
                batch_sampler.num_batches(),
                expected_batches,
                "Failed for dataset_size={}, batch_size={}, drop_last={}",
                dataset_size,
                batch_size,
                drop_last
            );

            let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
            assert_eq!(
                batches.len(),
                expected_batches,
                "Actual batch count doesn't match for dataset_size={}, batch_size={}, drop_last={}",
                dataset_size,
                batch_size,
                drop_last
            );
        }
    }

    #[test]
    fn test_edge_case_large_batch_size() {
        let base_sampler = SequentialSampler::new(3);
        let batch_sampler = BatchingSampler::new(base_sampler, 100, false);

        let batches: Vec<Vec<usize>> = batch_sampler.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0], vec![0, 1, 2]);
    }
}