// shrew_data/loader.rs
// DataLoader — batching, shuffling, iteration

3use std::collections::HashMap;
4
5use rand::rngs::StdRng;
6use rand::seq::SliceRandom;
7use rand::{thread_rng, SeedableRng};
8
9use rayon::prelude::*;
10
11use shrew_core::backend::Backend;
12use shrew_core::tensor::Tensor;
13use shrew_core::DType;
14
15use crate::dataset::{Dataset, Sample};
16use crate::transform::Transform;
17
/// Configuration for the DataLoader.
///
/// Construct via [`Default`] (batch size 32, shuffled, last partial batch
/// kept, `F32` tensors, sequential fetching, unseeded) and refine with the
/// builder methods on `DataLoaderConfig`.
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Whether to shuffle indices each epoch.
    pub shuffle: bool,
    /// Whether to drop the last incomplete batch.
    pub drop_last: bool,
    /// DType for the created tensors.
    pub dtype: DType,
    /// Number of parallel workers for sample fetching (0 = sequential).
    pub num_workers: usize,
    /// Optional random seed for reproducible shuffling.
    pub seed: Option<u64>,
}
34
35impl Default for DataLoaderConfig {
36    fn default() -> Self {
37        Self {
38            batch_size: 32,
39            shuffle: true,
40            drop_last: false,
41            dtype: DType::F32,
42            num_workers: 0,
43            seed: None,
44        }
45    }
46}
47
48impl DataLoaderConfig {
49    pub fn batch_size(mut self, bs: usize) -> Self {
50        self.batch_size = bs;
51        self
52    }
53
54    pub fn shuffle(mut self, s: bool) -> Self {
55        self.shuffle = s;
56        self
57    }
58
59    pub fn drop_last(mut self, d: bool) -> Self {
60        self.drop_last = d;
61        self
62    }
63
64    pub fn dtype(mut self, d: DType) -> Self {
65        self.dtype = d;
66        self
67    }
68
69    pub fn num_workers(mut self, n: usize) -> Self {
70        self.num_workers = n;
71        self
72    }
73
74    pub fn seed(mut self, s: u64) -> Self {
75        self.seed = Some(s);
76        self
77    }
78}
79
/// A DataLoader wraps a Dataset and produces batches of tensors.
///
/// Each batch is a `HashMap<String, Tensor<B>>` with keys `"input"` and
/// `"target"`, matching the format expected by `Executor::run()` and
/// `Trainer::train()`.
pub struct DataLoader<'a, B: Backend> {
    // The underlying dataset samples are drawn from.
    dataset: &'a dyn Dataset,
    // Batching/shuffling/dtype options (see `DataLoaderConfig`).
    config: DataLoaderConfig,
    // Transform pipeline applied to every sample, in insertion order.
    transforms: Vec<Box<dyn Transform>>,
    // Device on which the batch tensors are created.
    device: B::Device,
    // Permutation of 0..dataset.len(); reshuffled each epoch when enabled.
    indices: Vec<usize>,
}
92
93impl<'a, B: Backend> DataLoader<'a, B> {
94    /// Create a new DataLoader over a dataset.
95    pub fn new(dataset: &'a dyn Dataset, device: B::Device, config: DataLoaderConfig) -> Self {
96        let indices: Vec<usize> = (0..dataset.len()).collect();
97        Self {
98            dataset,
99            config,
100            transforms: Vec::new(),
101            device,
102            indices,
103        }
104    }
105
106    /// Add a transform to apply to each sample.
107    pub fn with_transform(mut self, t: Box<dyn Transform>) -> Self {
108        self.transforms.push(t);
109        self
110    }
111
112    /// The number of batches per epoch.
113    pub fn num_batches(&self) -> usize {
114        if self.config.drop_last {
115            self.dataset.len() / self.config.batch_size
116        } else {
117            self.dataset.len().div_ceil(self.config.batch_size)
118        }
119    }
120
121    /// Total number of samples.
122    pub fn len(&self) -> usize {
123        self.dataset.len()
124    }
125
126    /// Whether the dataset is empty.
127    pub fn is_empty(&self) -> bool {
128        self.dataset.is_empty()
129    }
130
131    /// Reshuffle indices (call at the start of each epoch).
132    pub fn reshuffle(&mut self) {
133        if self.config.shuffle {
134            match self.config.seed {
135                Some(seed) => {
136                    let mut rng = StdRng::seed_from_u64(seed);
137                    self.indices.shuffle(&mut rng);
138                }
139                None => {
140                    let mut rng = thread_rng();
141                    self.indices.shuffle(&mut rng);
142                }
143            }
144        }
145    }
146
147    /// Fetch a slice of samples, optionally in parallel via rayon.
148    fn fetch_samples(&self, indices: &[usize]) -> Vec<Sample> {
149        if self.config.num_workers > 0 && indices.len() > 1 {
150            // Parallel fetch + transform
151            indices
152                .par_iter()
153                .map(|&i| {
154                    let mut s = self.dataset.get(i);
155                    for t in &self.transforms {
156                        s = t.apply(s);
157                    }
158                    s
159                })
160                .collect()
161        } else {
162            // Sequential
163            indices
164                .iter()
165                .map(|&i| {
166                    let mut s = self.dataset.get(i);
167                    for t in &self.transforms {
168                        s = t.apply(s);
169                    }
170                    s
171                })
172                .collect()
173        }
174    }
175
176    /// Produce all batches for one epoch as a Vec of HashMap tensors.
177    ///
178    /// Each HashMap has keys matching the input/target names you'll pass
179    /// to the executor.  By default: `"input"` and `"target"`.
180    pub fn epoch_batches(
181        &mut self,
182        input_name: &str,
183        target_name: &str,
184    ) -> Result<Vec<HashMap<String, Tensor<B>>>, shrew_core::Error> {
185        self.reshuffle();
186
187        let bs = self.config.batch_size;
188        let n = self.dataset.len();
189        let num_batches = self.num_batches();
190        let mut batches = Vec::with_capacity(num_batches);
191
192        for batch_idx in 0..num_batches {
193            let start = batch_idx * bs;
194            let end = (start + bs).min(n);
195            let actual_bs = end - start;
196
197            // Collect samples (potentially in parallel)
198            let batch_indices: Vec<usize> = (start..end).map(|i| self.indices[i]).collect();
199            let samples = self.fetch_samples(&batch_indices);
200
201            // Stack features into a batch tensor [actual_bs, ...feature_shape]
202            let feat_shape = samples[0].feature_shape.clone();
203            let tgt_shape = samples[0].target_shape.clone();
204
205            let mut feat_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].features.len());
206            let mut tgt_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].target.len());
207
208            for s in &samples {
209                feat_data.extend_from_slice(&s.features);
210                tgt_data.extend_from_slice(&s.target);
211            }
212
213            // Build batch shapes: [actual_bs, ...original_shape]
214            let mut batch_feat_shape = vec![actual_bs];
215            batch_feat_shape.extend_from_slice(&feat_shape);
216
217            let mut batch_tgt_shape = vec![actual_bs];
218            batch_tgt_shape.extend_from_slice(&tgt_shape);
219
220            let feat_tensor = Tensor::<B>::from_f64_slice(
221                &feat_data,
222                batch_feat_shape,
223                self.config.dtype,
224                &self.device,
225            )?;
226
227            let tgt_tensor = Tensor::<B>::from_f64_slice(
228                &tgt_data,
229                batch_tgt_shape,
230                self.config.dtype,
231                &self.device,
232            )?;
233
234            let mut batch_map = HashMap::new();
235            batch_map.insert(input_name.to_string(), feat_tensor);
236            batch_map.insert(target_name.to_string(), tgt_tensor);
237
238            batches.push(batch_map);
239        }
240
241        Ok(batches)
242    }
243
244    /// Iterate over batches one at a time (lower memory than `epoch_batches`).
245    pub fn iter_batches(
246        &mut self,
247        input_name: &str,
248        target_name: &str,
249    ) -> BatchIterator<'_, 'a, B> {
250        self.reshuffle();
251        BatchIterator {
252            loader: self,
253            batch_idx: 0,
254            input_name: input_name.to_string(),
255            target_name: target_name.to_string(),
256        }
257    }
258}
259
/// Iterator that yields one batch at a time.
pub struct BatchIterator<'l, 'a, B: Backend> {
    // Borrowed loader providing dataset, config, transforms, and indices.
    loader: &'l DataLoader<'a, B>,
    // Index of the next batch to produce (0-based).
    batch_idx: usize,
    // Key under which the feature tensor is inserted into each batch map.
    input_name: String,
    // Key under which the target tensor is inserted into each batch map.
    target_name: String,
}
267
268impl<'l, 'a, B: Backend> Iterator for BatchIterator<'l, 'a, B> {
269    type Item = Result<HashMap<String, Tensor<B>>, shrew_core::Error>;
270
271    fn next(&mut self) -> Option<Self::Item> {
272        let bs = self.loader.config.batch_size;
273        let n = self.loader.dataset.len();
274        let start = self.batch_idx * bs;
275
276        if start >= n {
277            return None;
278        }
279
280        if self.loader.config.drop_last && start + bs > n {
281            return None;
282        }
283
284        let end = (start + bs).min(n);
285        let actual_bs = end - start;
286        self.batch_idx += 1;
287
288        // Collect and transform samples (potentially in parallel)
289        let batch_indices: Vec<usize> = (start..end).map(|i| self.loader.indices[i]).collect();
290        let samples = self.loader.fetch_samples(&batch_indices);
291
292        let feat_shape = samples[0].feature_shape.clone();
293        let tgt_shape = samples[0].target_shape.clone();
294
295        let mut feat_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].features.len());
296        let mut tgt_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].target.len());
297
298        for s in &samples {
299            feat_data.extend_from_slice(&s.features);
300            tgt_data.extend_from_slice(&s.target);
301        }
302
303        let mut batch_feat_shape = vec![actual_bs];
304        batch_feat_shape.extend_from_slice(&feat_shape);
305
306        let mut batch_tgt_shape = vec![actual_bs];
307        batch_tgt_shape.extend_from_slice(&tgt_shape);
308
309        let feat_tensor = match Tensor::<B>::from_f64_slice(
310            &feat_data,
311            batch_feat_shape,
312            self.loader.config.dtype,
313            &self.loader.device,
314        ) {
315            Ok(t) => t,
316            Err(e) => return Some(Err(e)),
317        };
318
319        let tgt_tensor = match Tensor::<B>::from_f64_slice(
320            &tgt_data,
321            batch_tgt_shape,
322            self.loader.config.dtype,
323            &self.loader.device,
324        ) {
325            Ok(t) => t,
326            Err(e) => return Some(Err(e)),
327        };
328
329        let mut batch_map = HashMap::new();
330        batch_map.insert(self.input_name.clone(), feat_tensor);
331        batch_map.insert(self.target_name.clone(), tgt_tensor);
332
333        Some(Ok(batch_map))
334    }
335}