1use crate::{
7 collate::{Collate, DefaultCollate},
8 dataset::Dataset,
9 sampler::{BatchSampler, BatchingSampler, RandomSampler, SequentialSampler},
10};
11use scirs2_core::parallel_ops::*;
13use torsh_core::error::Result;
14
15#[cfg(not(feature = "std"))]
16use alloc::{boxed::Box, vec::Vec};
17
18pub trait DataLoaderTrait<D: Dataset, C: Collate<D::Item>> {
20 fn len(&self) -> usize;
22
23 fn is_empty(&self) -> bool;
25}
26
/// Couples a dataset with a batch sampler and a collate function.
///
/// Instances are assembled by [`DataLoaderBuilder`]; iterate over batches with
/// `iter()`.
pub struct DataLoader<D, S, C> {
    /// Source of the individual samples.
    dataset: D,
    /// Yields the index batches that drive iteration.
    sampler: S,
    /// Combines the samples of one batch into the batch output.
    collate_fn: C,
    /// Samples of a batch are fetched through a parallel iterator only when
    /// this is greater than 1.
    num_workers: usize,
    // Recorded from the builder; nothing in this file reads it yet.
    #[allow(dead_code)]
    pin_memory: bool,
    // Already handed to the batch sampler by the builder; this copy is only
    // kept alongside the other settings and is not read in this file.
    #[allow(dead_code)]
    drop_last: bool,
    // Recorded from the builder; nothing in this file reads it yet.
    // `core::time::Duration` is the same type as `std::time::Duration` but
    // also resolves when the `std` feature is disabled.
    #[allow(dead_code)]
    timeout: Option<core::time::Duration>,
}
66
67impl<D: Dataset> DataLoader<D, (), ()> {
68 pub fn builder(dataset: D) -> DataLoaderBuilder<D> {
80 DataLoaderBuilder::new(dataset)
81 }
82}
83
84impl<D, S, C> DataLoader<D, S, C>
85where
86 D: Dataset,
87 S: BatchSampler,
88 C: Collate<D::Item>,
89{
90 pub fn iter(&self) -> DataLoaderIterator<'_, D, S, C> {
95 DataLoaderIterator {
96 dataset: &self.dataset,
97 sampler_iter: self.sampler.iter(),
98 collate_fn: &self.collate_fn,
99 num_workers: self.num_workers,
100 }
101 }
102
103 pub fn len(&self) -> usize {
108 self.sampler.len()
109 }
110
111 pub fn is_empty(&self) -> bool {
115 self.sampler.is_empty()
116 }
117
118 pub fn dataset(&self) -> &D {
120 &self.dataset
121 }
122
123 pub fn sampler(&self) -> &S {
125 &self.sampler
126 }
127
128 pub fn collate_fn(&self) -> &C {
130 &self.collate_fn
131 }
132
133 pub fn num_workers(&self) -> usize {
135 self.num_workers
136 }
137}
138
139impl<D, S, C> DataLoaderTrait<D, C> for DataLoader<D, S, C>
140where
141 D: Dataset + Sync,
142 S: BatchSampler + Sync,
143 C: Collate<D::Item> + Sync,
144 D::Item: Send,
145 C::Output: Send,
146 S::Iter: Iterator<Item = Vec<usize>>,
147{
148 fn len(&self) -> usize {
149 self.sampler.len()
150 }
151
152 fn is_empty(&self) -> bool {
153 self.sampler.is_empty()
154 }
155}
156
157pub struct DataLoaderIterator<'a, D, S, C>
162where
163 D: Dataset,
164 S: BatchSampler,
165 C: Collate<D::Item>,
166{
167 dataset: &'a D,
168 sampler_iter: S::Iter,
169 collate_fn: &'a C,
170 num_workers: usize,
171}
172
173impl<D, S, C> Iterator for DataLoaderIterator<'_, D, S, C>
174where
175 D: Dataset + Sync,
176 D::Item: Send,
177 S: BatchSampler,
178 S::Iter: Iterator<Item = Vec<usize>>,
179 C: Collate<D::Item> + Sync,
180 C::Output: Send,
181{
182 type Item = Result<C::Output>;
183
184 fn next(&mut self) -> Option<Self::Item> {
185 let indices = self.sampler_iter.next()?;
186
187 let batch_result = if self.num_workers > 1 {
188 let samples: Result<Vec<_>> = indices
190 .into_par_iter()
191 .map(|idx| self.dataset.get(idx))
192 .collect();
193
194 match samples {
195 Ok(samples) => self.collate_fn.collate(samples),
196 Err(e) => return Some(Err(e)),
197 }
198 } else {
199 let mut samples = Vec::with_capacity(indices.len());
201 for idx in indices {
202 match self.dataset.get(idx) {
203 Ok(sample) => samples.push(sample),
204 Err(e) => return Some(Err(e)),
205 }
206 }
207 self.collate_fn.collate(samples)
208 };
209
210 match batch_result {
212 Ok(batch) => {
213 Some(Ok(batch))
216 }
217 Err(e) => Some(Err(e)),
218 }
219 }
220}
221
222pub struct DataLoaderBuilder<D: Dataset> {
243 dataset: D,
244 batch_size: Option<usize>,
245 shuffle: bool,
246 num_workers: usize,
247 pin_memory: bool,
248 drop_last: bool,
249 timeout: Option<std::time::Duration>,
250 generator: Option<u64>,
251}
252
253impl<D: Dataset> DataLoaderBuilder<D> {
254 pub fn new(dataset: D) -> Self {
260 Self {
261 dataset,
262 batch_size: None,
263 shuffle: false,
264 num_workers: 0,
265 pin_memory: false,
266 drop_last: false,
267 timeout: None,
268 generator: None,
269 }
270 }
271
272 pub fn batch_size(mut self, batch_size: usize) -> Self {
278 self.batch_size = Some(batch_size);
279 self
280 }
281
282 pub fn shuffle(mut self, shuffle: bool) -> Self {
288 self.shuffle = shuffle;
289 self
290 }
291
292 pub fn num_workers(mut self, num_workers: usize) -> Self {
298 self.num_workers = num_workers;
299 self
300 }
301
302 pub fn pin_memory(mut self, pin_memory: bool) -> Self {
308 self.pin_memory = pin_memory;
309 self
310 }
311
312 pub fn drop_last(mut self, drop_last: bool) -> Self {
318 self.drop_last = drop_last;
319 self
320 }
321
322 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
328 self.timeout = Some(timeout);
329 self
330 }
331
332 pub fn generator(mut self, seed: u64) -> Self {
338 self.generator = Some(seed);
339 self
340 }
341
342 pub fn build(
347 self,
348 ) -> Result<DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>> {
349 let batch_size = self.batch_size.unwrap_or(1);
350 let base_sampler = SequentialSampler::new(self.dataset.len());
351 let batch_sampler = BatchingSampler::new(base_sampler, batch_size, self.drop_last);
352
353 Ok(DataLoader {
354 dataset: self.dataset,
355 sampler: batch_sampler,
356 collate_fn: DefaultCollate,
357 num_workers: self.num_workers,
358 pin_memory: self.pin_memory,
359 drop_last: self.drop_last,
360 timeout: self.timeout,
361 })
362 }
363
364 pub fn build_with_random_sampling(
369 self,
370 ) -> Result<DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>> {
371 let batch_size = self.batch_size.unwrap_or(1);
372 let mut base_sampler = RandomSampler::new(self.dataset.len(), None, false);
373
374 if let Some(seed) = self.generator {
375 base_sampler = base_sampler.with_generator(seed);
376 }
377
378 let batch_sampler = BatchingSampler::new(base_sampler, batch_size, self.drop_last);
379
380 Ok(DataLoader {
381 dataset: self.dataset,
382 sampler: batch_sampler,
383 collate_fn: DefaultCollate,
384 num_workers: self.num_workers,
385 pin_memory: self.pin_memory,
386 drop_last: self.drop_last,
387 timeout: self.timeout,
388 })
389 }
390
391 pub fn build_auto(self) -> Result<Box<dyn DataLoaderTrait<D, DefaultCollate> + Send + Sync>>
396 where
397 D: Send + Sync + 'static,
398 D::Item: Send + Sync + 'static,
399 DefaultCollate: Collate<D::Item>,
400 <DefaultCollate as Collate<D::Item>>::Output: Send,
401 {
402 if self.shuffle {
403 Ok(Box::new(self.build_with_random_sampling()?))
404 } else {
405 Ok(Box::new(self.build()?))
406 }
407 }
408}
409
/// Loader type returned by `DataLoaderBuilder::build`: sequential indices,
/// grouped into batches, combined with the default collate function.
pub type SimpleDataLoader<D> = DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>;
415
/// Loader type returned by `DataLoaderBuilder::build_with_random_sampling`:
/// randomly sampled indices, grouped into batches, combined with the default
/// collate function.
pub type RandomDataLoader<D> = DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>;
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::TensorDataset;

    /// A fresh builder has every option at its default, including the
    /// timeout and the seed.
    #[test]
    fn test_dataloader_builder() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let builder = DataLoaderBuilder::new(dataset);

        assert_eq!(builder.batch_size, None);
        assert!(!builder.shuffle);
        assert_eq!(builder.num_workers, 0);
        assert!(!builder.pin_memory);
        assert!(!builder.drop_last);
        assert_eq!(builder.timeout, None);
        assert_eq!(builder.generator, None);
    }

    /// Every setter records its argument on the builder.
    #[test]
    fn test_dataloader_builder_configuration() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let builder = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .shuffle(true)
            .num_workers(4)
            .pin_memory(true)
            .drop_last(true)
            .timeout(core::time::Duration::from_secs(30))
            .generator(7);

        assert_eq!(builder.batch_size, Some(2));
        assert!(builder.shuffle);
        assert_eq!(builder.num_workers, 4);
        assert!(builder.pin_memory);
        assert!(builder.drop_last);
        assert_eq!(builder.timeout, Some(core::time::Duration::from_secs(30)));
        assert_eq!(builder.generator, Some(7));
    }

    /// Five samples in batches of two give three batches when the short last
    /// batch is kept.
    #[test]
    fn test_dataloader_sequential_build() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    /// Random sampling reports the same batch count as sequential sampling.
    #[test]
    fn test_dataloader_random_build() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .generator(42)
            .build_with_random_sampling()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    /// Four samples in batches of two yield exactly two batches, each holding
    /// one tensor of shape `[2, 1]`.
    #[test]
    fn test_dataloader_iteration() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[4]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        let mut batches = dataloader.iter();
        let first = batches
            .next()
            .expect("iterator should have a next element")
            .expect("operation should succeed");
        let second = batches
            .next()
            .expect("iterator should have a next element")
            .expect("operation should succeed");
        assert!(batches.next().is_none());

        assert_eq!(first.len(), 1);
        assert_eq!(second.len(), 1);

        assert_eq!(first[0].shape().dims(), &[2, 1]);
        assert_eq!(second[0].shape().dims(), &[2, 1]);
    }

    /// With `drop_last`, the short trailing batch is not counted.
    #[test]
    fn test_dataloader_drop_last() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .drop_last(true)
            .build()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 2);
    }

    /// The trait methods agree with the inherent ones.
    #[test]
    fn test_dataloader_trait_implementation() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(DataLoaderTrait::len(&dataloader), 3);
        assert!(!DataLoaderTrait::is_empty(&dataloader));
    }

    /// A loader over an empty dataset reports zero batches and yields nothing.
    #[test]
    fn test_empty_dataloader() {
        let tensors: Vec<torsh_tensor::Tensor<f32>> = vec![];
        let dataset = TensorDataset::new(tensors);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 0);
        assert!(dataloader.is_empty());

        let mut batches = dataloader.iter();
        assert!(batches.next().is_none());
    }
}