1use std::collections::HashMap;
4
5use rand::rngs::StdRng;
6use rand::seq::SliceRandom;
7use rand::{thread_rng, SeedableRng};
8
9use rayon::prelude::*;
10
11use shrew_core::backend::Backend;
12use shrew_core::tensor::Tensor;
13use shrew_core::DType;
14
15use crate::dataset::{Dataset, Sample};
16use crate::transform::Transform;
17
/// Configuration for a [`DataLoader`]: batching, shuffling, and tensor dtype.
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Shuffle the sample order at the start of each epoch.
    pub shuffle: bool,
    /// Drop a trailing batch smaller than `batch_size`.
    pub drop_last: bool,
    /// Element type of the tensors produced for each batch.
    pub dtype: DType,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    /// NOTE(review): the count itself is not used as a pool size — any
    /// non-zero value merely enables parallel fetching; confirm intent.
    pub num_workers: usize,
    /// Fixed RNG seed for shuffling; `None` uses a thread-local RNG.
    pub seed: Option<u64>,
}
34
35impl Default for DataLoaderConfig {
36 fn default() -> Self {
37 Self {
38 batch_size: 32,
39 shuffle: true,
40 drop_last: false,
41 dtype: DType::F32,
42 num_workers: 0,
43 seed: None,
44 }
45 }
46}
47
48impl DataLoaderConfig {
49 pub fn batch_size(mut self, bs: usize) -> Self {
50 self.batch_size = bs;
51 self
52 }
53
54 pub fn shuffle(mut self, s: bool) -> Self {
55 self.shuffle = s;
56 self
57 }
58
59 pub fn drop_last(mut self, d: bool) -> Self {
60 self.drop_last = d;
61 self
62 }
63
64 pub fn dtype(mut self, d: DType) -> Self {
65 self.dtype = d;
66 self
67 }
68
69 pub fn num_workers(mut self, n: usize) -> Self {
70 self.num_workers = n;
71 self
72 }
73
74 pub fn seed(mut self, s: u64) -> Self {
75 self.seed = Some(s);
76 self
77 }
78}
79
/// Batches samples from a [`Dataset`] into device tensors for backend `B`.
pub struct DataLoader<'a, B: Backend> {
    // Source of raw samples.
    dataset: &'a dyn Dataset,
    // Batching / shuffling / dtype settings.
    config: DataLoaderConfig,
    // Per-sample transforms, applied in insertion order.
    transforms: Vec<Box<dyn Transform>>,
    // Device on which batch tensors are created.
    device: B::Device,
    // Current permutation of sample indices; reshuffled when enabled.
    indices: Vec<usize>,
}
92
93impl<'a, B: Backend> DataLoader<'a, B> {
94 pub fn new(dataset: &'a dyn Dataset, device: B::Device, config: DataLoaderConfig) -> Self {
96 let indices: Vec<usize> = (0..dataset.len()).collect();
97 Self {
98 dataset,
99 config,
100 transforms: Vec::new(),
101 device,
102 indices,
103 }
104 }
105
106 pub fn with_transform(mut self, t: Box<dyn Transform>) -> Self {
108 self.transforms.push(t);
109 self
110 }
111
112 pub fn num_batches(&self) -> usize {
114 if self.config.drop_last {
115 self.dataset.len() / self.config.batch_size
116 } else {
117 self.dataset.len().div_ceil(self.config.batch_size)
118 }
119 }
120
121 pub fn len(&self) -> usize {
123 self.dataset.len()
124 }
125
126 pub fn is_empty(&self) -> bool {
128 self.dataset.is_empty()
129 }
130
131 pub fn reshuffle(&mut self) {
133 if self.config.shuffle {
134 match self.config.seed {
135 Some(seed) => {
136 let mut rng = StdRng::seed_from_u64(seed);
137 self.indices.shuffle(&mut rng);
138 }
139 None => {
140 let mut rng = thread_rng();
141 self.indices.shuffle(&mut rng);
142 }
143 }
144 }
145 }
146
147 fn fetch_samples(&self, indices: &[usize]) -> Vec<Sample> {
149 if self.config.num_workers > 0 && indices.len() > 1 {
150 indices
152 .par_iter()
153 .map(|&i| {
154 let mut s = self.dataset.get(i);
155 for t in &self.transforms {
156 s = t.apply(s);
157 }
158 s
159 })
160 .collect()
161 } else {
162 indices
164 .iter()
165 .map(|&i| {
166 let mut s = self.dataset.get(i);
167 for t in &self.transforms {
168 s = t.apply(s);
169 }
170 s
171 })
172 .collect()
173 }
174 }
175
176 pub fn epoch_batches(
181 &mut self,
182 input_name: &str,
183 target_name: &str,
184 ) -> Result<Vec<HashMap<String, Tensor<B>>>, shrew_core::Error> {
185 self.reshuffle();
186
187 let bs = self.config.batch_size;
188 let n = self.dataset.len();
189 let num_batches = self.num_batches();
190 let mut batches = Vec::with_capacity(num_batches);
191
192 for batch_idx in 0..num_batches {
193 let start = batch_idx * bs;
194 let end = (start + bs).min(n);
195 let actual_bs = end - start;
196
197 let batch_indices: Vec<usize> = (start..end).map(|i| self.indices[i]).collect();
199 let samples = self.fetch_samples(&batch_indices);
200
201 let feat_shape = samples[0].feature_shape.clone();
203 let tgt_shape = samples[0].target_shape.clone();
204
205 let mut feat_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].features.len());
206 let mut tgt_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].target.len());
207
208 for s in &samples {
209 feat_data.extend_from_slice(&s.features);
210 tgt_data.extend_from_slice(&s.target);
211 }
212
213 let mut batch_feat_shape = vec![actual_bs];
215 batch_feat_shape.extend_from_slice(&feat_shape);
216
217 let mut batch_tgt_shape = vec![actual_bs];
218 batch_tgt_shape.extend_from_slice(&tgt_shape);
219
220 let feat_tensor = Tensor::<B>::from_f64_slice(
221 &feat_data,
222 batch_feat_shape,
223 self.config.dtype,
224 &self.device,
225 )?;
226
227 let tgt_tensor = Tensor::<B>::from_f64_slice(
228 &tgt_data,
229 batch_tgt_shape,
230 self.config.dtype,
231 &self.device,
232 )?;
233
234 let mut batch_map = HashMap::new();
235 batch_map.insert(input_name.to_string(), feat_tensor);
236 batch_map.insert(target_name.to_string(), tgt_tensor);
237
238 batches.push(batch_map);
239 }
240
241 Ok(batches)
242 }
243
244 pub fn iter_batches(
246 &mut self,
247 input_name: &str,
248 target_name: &str,
249 ) -> BatchIterator<'_, 'a, B> {
250 self.reshuffle();
251 BatchIterator {
252 loader: self,
253 batch_idx: 0,
254 input_name: input_name.to_string(),
255 target_name: target_name.to_string(),
256 }
257 }
258}
259
/// Lazy per-batch iterator returned by [`DataLoader::iter_batches`].
pub struct BatchIterator<'l, 'a, B: Backend> {
    // Loader whose (already reshuffled) index permutation we walk.
    loader: &'l DataLoader<'a, B>,
    // Index of the next batch to yield.
    batch_idx: usize,
    // Map key under which the feature tensor is inserted.
    input_name: String,
    // Map key under which the target tensor is inserted.
    target_name: String,
}
267
268impl<'l, 'a, B: Backend> Iterator for BatchIterator<'l, 'a, B> {
269 type Item = Result<HashMap<String, Tensor<B>>, shrew_core::Error>;
270
271 fn next(&mut self) -> Option<Self::Item> {
272 let bs = self.loader.config.batch_size;
273 let n = self.loader.dataset.len();
274 let start = self.batch_idx * bs;
275
276 if start >= n {
277 return None;
278 }
279
280 if self.loader.config.drop_last && start + bs > n {
281 return None;
282 }
283
284 let end = (start + bs).min(n);
285 let actual_bs = end - start;
286 self.batch_idx += 1;
287
288 let batch_indices: Vec<usize> = (start..end).map(|i| self.loader.indices[i]).collect();
290 let samples = self.loader.fetch_samples(&batch_indices);
291
292 let feat_shape = samples[0].feature_shape.clone();
293 let tgt_shape = samples[0].target_shape.clone();
294
295 let mut feat_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].features.len());
296 let mut tgt_data: Vec<f64> = Vec::with_capacity(actual_bs * samples[0].target.len());
297
298 for s in &samples {
299 feat_data.extend_from_slice(&s.features);
300 tgt_data.extend_from_slice(&s.target);
301 }
302
303 let mut batch_feat_shape = vec![actual_bs];
304 batch_feat_shape.extend_from_slice(&feat_shape);
305
306 let mut batch_tgt_shape = vec![actual_bs];
307 batch_tgt_shape.extend_from_slice(&tgt_shape);
308
309 let feat_tensor = match Tensor::<B>::from_f64_slice(
310 &feat_data,
311 batch_feat_shape,
312 self.loader.config.dtype,
313 &self.loader.device,
314 ) {
315 Ok(t) => t,
316 Err(e) => return Some(Err(e)),
317 };
318
319 let tgt_tensor = match Tensor::<B>::from_f64_slice(
320 &tgt_data,
321 batch_tgt_shape,
322 self.loader.config.dtype,
323 &self.loader.device,
324 ) {
325 Ok(t) => t,
326 Err(e) => return Some(Err(e)),
327 };
328
329 let mut batch_map = HashMap::new();
330 batch_map.insert(self.input_name.clone(), feat_tensor);
331 batch_map.insert(self.target_name.clone(), tgt_tensor);
332
333 Some(Ok(batch_map))
334 }
335}