Skip to main content

torsh_data/dataloader/
core.rs

1//! Core DataLoader implementation
2//!
3//! This module contains the fundamental DataLoader functionality including the main
4//! DataLoader struct, its iterator, builder pattern, and core traits.
5
6use crate::{
7    collate::{Collate, DefaultCollate},
8    dataset::Dataset,
9    sampler::{BatchSampler, BatchingSampler, RandomSampler, SequentialSampler},
10};
11// ✅ SciRS2 POLICY: Use scirs2_core::parallel_ops instead of rayon::prelude
12use scirs2_core::parallel_ops::*;
13use torsh_core::error::Result;
14
15#[cfg(not(feature = "std"))]
16use alloc::{boxed::Box, vec::Vec};
17
18/// Trait for DataLoader functionality
19pub trait DataLoaderTrait<D: Dataset, C: Collate<D::Item>> {
20    /// Get the number of batches
21    fn len(&self) -> usize;
22
23    /// Check if the dataloader is empty
24    fn is_empty(&self) -> bool;
25}
26
/// DataLoader for batching and iterating over datasets.
///
/// A `DataLoader` pairs a dataset with a batch sampler and a collate function:
/// - the sampler decides which indices form each batch;
/// - the collate function merges the fetched samples into one batch value.
///
/// When `num_workers > 1`, the samples of each batch are fetched in parallel.
///
/// # Type Parameters
///
/// - `D`: Dataset type implementing the [`Dataset`] trait
/// - `S`: Sampler type implementing the [`BatchSampler`] trait
/// - `C`: Collate function type implementing the [`Collate`] trait
///
/// # Examples
///
/// ```rust,ignore
/// use torsh_data::dataloader::core::DataLoader;
/// use torsh_data::dataset::TensorDataset;
///
/// // Five samples along the first dimension.
/// let tensor = torsh_tensor::creation::ones::<f32>(&[5])?;
/// let dataset = TensorDataset::from_tensor(tensor);
/// let dataloader = DataLoader::builder(dataset)
///     .batch_size(2)
///     .num_workers(4)
///     .build()?;
///
/// for batch in dataloader.iter() {
///     let batch = batch?;
///     // Process batch
/// }
/// ```
pub struct DataLoader<D, S, C> {
    dataset: D,
    sampler: S,
    collate_fn: C,
    num_workers: usize,
    // Carried over from the builder, but the batch iterator does not read it yet:
    // memory pinning is still a placeholder.
    #[allow(dead_code)]
    pin_memory: bool,
    // The builder already applies this when it constructs the batching sampler;
    // the copy here is informational only.
    #[allow(dead_code)]
    drop_last: bool,
    // Carried over from the builder; no loading path enforces it yet.
    #[allow(dead_code)]
    timeout: Option<std::time::Duration>,
}
66
67impl<D: Dataset> DataLoader<D, (), ()> {
68    /// Create a new DataLoader builder
69    ///
70    /// This provides a fluent API for configuring DataLoader options.
71    ///
72    /// # Arguments
73    ///
74    /// * `dataset` - The dataset to iterate over
75    ///
76    /// # Returns
77    ///
78    /// A DataLoaderBuilder for configuring the DataLoader
79    pub fn builder(dataset: D) -> DataLoaderBuilder<D> {
80        DataLoaderBuilder::new(dataset)
81    }
82}
83
84impl<D, S, C> DataLoader<D, S, C>
85where
86    D: Dataset,
87    S: BatchSampler,
88    C: Collate<D::Item>,
89{
90    /// Create an iterator over the dataset
91    ///
92    /// Returns a DataLoaderIterator that will yield batches according to
93    /// the configured sampler and collation function.
94    pub fn iter(&self) -> DataLoaderIterator<'_, D, S, C> {
95        DataLoaderIterator {
96            dataset: &self.dataset,
97            sampler_iter: self.sampler.iter(),
98            collate_fn: &self.collate_fn,
99            num_workers: self.num_workers,
100        }
101    }
102
103    /// Get the number of batches
104    ///
105    /// Returns the total number of batches that will be produced by this DataLoader,
106    /// based on the underlying sampler's batch count.
107    pub fn len(&self) -> usize {
108        self.sampler.len()
109    }
110
111    /// Check if the dataloader is empty
112    ///
113    /// Returns true if the DataLoader will produce zero batches.
114    pub fn is_empty(&self) -> bool {
115        self.sampler.is_empty()
116    }
117
118    /// Get the dataset
119    pub fn dataset(&self) -> &D {
120        &self.dataset
121    }
122
123    /// Get the sampler
124    pub fn sampler(&self) -> &S {
125        &self.sampler
126    }
127
128    /// Get the collate function
129    pub fn collate_fn(&self) -> &C {
130        &self.collate_fn
131    }
132
133    /// Get the number of workers
134    pub fn num_workers(&self) -> usize {
135        self.num_workers
136    }
137}
138
139impl<D, S, C> DataLoaderTrait<D, C> for DataLoader<D, S, C>
140where
141    D: Dataset + Sync,
142    S: BatchSampler + Sync,
143    C: Collate<D::Item> + Sync,
144    D::Item: Send,
145    C::Output: Send,
146    S::Iter: Iterator<Item = Vec<usize>>,
147{
148    fn len(&self) -> usize {
149        self.sampler.len()
150    }
151
152    fn is_empty(&self) -> bool {
153        self.sampler.is_empty()
154    }
155}
156
157/// Iterator for DataLoader
158///
159/// This iterator handles the actual batch loading process, including parallel
160/// processing when multiple workers are configured.
161pub struct DataLoaderIterator<'a, D, S, C>
162where
163    D: Dataset,
164    S: BatchSampler,
165    C: Collate<D::Item>,
166{
167    dataset: &'a D,
168    sampler_iter: S::Iter,
169    collate_fn: &'a C,
170    num_workers: usize,
171}
172
173impl<D, S, C> Iterator for DataLoaderIterator<'_, D, S, C>
174where
175    D: Dataset + Sync,
176    D::Item: Send,
177    S: BatchSampler,
178    S::Iter: Iterator<Item = Vec<usize>>,
179    C: Collate<D::Item> + Sync,
180    C::Output: Send,
181{
182    type Item = Result<C::Output>;
183
184    fn next(&mut self) -> Option<Self::Item> {
185        let indices = self.sampler_iter.next()?;
186
187        let batch_result = if self.num_workers > 1 {
188            // Parallel loading using Rayon
189            let samples: Result<Vec<_>> = indices
190                .into_par_iter()
191                .map(|idx| self.dataset.get(idx))
192                .collect();
193
194            match samples {
195                Ok(samples) => self.collate_fn.collate(samples),
196                Err(e) => return Some(Err(e)),
197            }
198        } else {
199            // Sequential loading
200            let mut samples = Vec::with_capacity(indices.len());
201            for idx in indices {
202                match self.dataset.get(idx) {
203                    Ok(sample) => samples.push(sample),
204                    Err(e) => return Some(Err(e)),
205                }
206            }
207            self.collate_fn.collate(samples)
208        };
209
210        // Apply memory pinning if enabled
211        match batch_result {
212            Ok(batch) => {
213                // Note: Memory pinning implementation would be applied here
214                // For now, this is a placeholder for the pin_memory flag
215                Some(Ok(batch))
216            }
217            Err(e) => Some(Err(e)),
218        }
219    }
220}
221
222/// Builder for DataLoader
223///
224/// Provides a fluent API for configuring DataLoader instances with various options
225/// such as batch size, shuffling, number of workers, and memory pinning.
226///
227/// # Examples
228///
229/// ```rust,ignore
230/// use torsh_data::dataloader::core::DataLoaderBuilder;
231/// use torsh_data::dataset::TensorDataset;
232///
233/// let dataset = TensorDataset::new(vec![1, 2, 3, 4, 5]);
234/// let dataloader = DataLoaderBuilder::new(dataset)
235///     .batch_size(32)
236///     .shuffle(true)
237///     .num_workers(4)
238///     .pin_memory(true)
239///     .drop_last(true)
240///     .build()?;
241/// ```
242pub struct DataLoaderBuilder<D: Dataset> {
243    dataset: D,
244    batch_size: Option<usize>,
245    shuffle: bool,
246    num_workers: usize,
247    pin_memory: bool,
248    drop_last: bool,
249    timeout: Option<std::time::Duration>,
250    generator: Option<u64>,
251}
252
253impl<D: Dataset> DataLoaderBuilder<D> {
254    /// Create a new builder
255    ///
256    /// # Arguments
257    ///
258    /// * `dataset` - The dataset to create a DataLoader for
259    pub fn new(dataset: D) -> Self {
260        Self {
261            dataset,
262            batch_size: None,
263            shuffle: false,
264            num_workers: 0,
265            pin_memory: false,
266            drop_last: false,
267            timeout: None,
268            generator: None,
269        }
270    }
271
272    /// Set batch size
273    ///
274    /// # Arguments
275    ///
276    /// * `batch_size` - Number of samples per batch
277    pub fn batch_size(mut self, batch_size: usize) -> Self {
278        self.batch_size = Some(batch_size);
279        self
280    }
281
282    /// Set whether to shuffle the data
283    ///
284    /// # Arguments
285    ///
286    /// * `shuffle` - Whether to randomly shuffle the dataset
287    pub fn shuffle(mut self, shuffle: bool) -> Self {
288        self.shuffle = shuffle;
289        self
290    }
291
292    /// Set the number of worker threads
293    ///
294    /// # Arguments
295    ///
296    /// * `num_workers` - Number of worker threads for parallel data loading
297    pub fn num_workers(mut self, num_workers: usize) -> Self {
298        self.num_workers = num_workers;
299        self
300    }
301
302    /// Set whether to pin memory
303    ///
304    /// # Arguments
305    ///
306    /// * `pin_memory` - Whether to pin memory for faster GPU transfers
307    pub fn pin_memory(mut self, pin_memory: bool) -> Self {
308        self.pin_memory = pin_memory;
309        self
310    }
311
312    /// Set whether to drop the last incomplete batch
313    ///
314    /// # Arguments
315    ///
316    /// * `drop_last` - Whether to drop the last batch if it's smaller than batch_size
317    pub fn drop_last(mut self, drop_last: bool) -> Self {
318        self.drop_last = drop_last;
319        self
320    }
321
322    /// Set timeout for collecting a batch
323    ///
324    /// # Arguments
325    ///
326    /// * `timeout` - Maximum time to wait for batch collection
327    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
328        self.timeout = Some(timeout);
329        self
330    }
331
332    /// Set random generator seed
333    ///
334    /// # Arguments
335    ///
336    /// * `seed` - Random seed for reproducible shuffling
337    pub fn generator(mut self, seed: u64) -> Self {
338        self.generator = Some(seed);
339        self
340    }
341
342    /// Build the DataLoader with sequential sampling
343    ///
344    /// Creates a DataLoader that processes the dataset in sequential order.
345    /// This is the default behavior when shuffle is false or not specified.
346    pub fn build(
347        self,
348    ) -> Result<DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>> {
349        let batch_size = self.batch_size.unwrap_or(1);
350        let base_sampler = SequentialSampler::new(self.dataset.len());
351        let batch_sampler = BatchingSampler::new(base_sampler, batch_size, self.drop_last);
352
353        Ok(DataLoader {
354            dataset: self.dataset,
355            sampler: batch_sampler,
356            collate_fn: DefaultCollate,
357            num_workers: self.num_workers,
358            pin_memory: self.pin_memory,
359            drop_last: self.drop_last,
360            timeout: self.timeout,
361        })
362    }
363
364    /// Build the DataLoader with random sampling (shuffled)
365    ///
366    /// Creates a DataLoader that randomly shuffles the dataset order.
367    /// Useful for training scenarios where data order should be randomized.
368    pub fn build_with_random_sampling(
369        self,
370    ) -> Result<DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>> {
371        let batch_size = self.batch_size.unwrap_or(1);
372        let mut base_sampler = RandomSampler::new(self.dataset.len(), None, false);
373
374        if let Some(seed) = self.generator {
375            base_sampler = base_sampler.with_generator(seed);
376        }
377
378        let batch_sampler = BatchingSampler::new(base_sampler, batch_size, self.drop_last);
379
380        Ok(DataLoader {
381            dataset: self.dataset,
382            sampler: batch_sampler,
383            collate_fn: DefaultCollate,
384            num_workers: self.num_workers,
385            pin_memory: self.pin_memory,
386            drop_last: self.drop_last,
387            timeout: self.timeout,
388        })
389    }
390
391    /// Build the DataLoader with auto-selected sampling strategy
392    ///
393    /// Automatically chooses between sequential and random sampling based on
394    /// the shuffle setting configured in the builder.
395    pub fn build_auto(self) -> Result<Box<dyn DataLoaderTrait<D, DefaultCollate> + Send + Sync>>
396    where
397        D: Send + Sync + 'static,
398        D::Item: Send + Sync + 'static,
399        DefaultCollate: Collate<D::Item>,
400        <DefaultCollate as Collate<D::Item>>::Output: Send,
401    {
402        if self.shuffle {
403            Ok(Box::new(self.build_with_random_sampling()?))
404        } else {
405            Ok(Box::new(self.build()?))
406        }
407    }
408}
409
410/// Simplified DataLoader type for common use cases
411///
412/// This type alias provides a convenient shorthand for the most common DataLoader
413/// configuration using sequential sampling and default collation.
414pub type SimpleDataLoader<D> = DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>;
415
416/// Simplified DataLoader type for random sampling use cases
417///
418/// This type alias provides a convenient shorthand for DataLoader with random sampling.
419pub type RandomDataLoader<D> = DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>;
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::TensorDataset;

    #[test]
    fn test_dataloader_builder() {
        // Create a tensor with 5 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let builder = DataLoaderBuilder::new(dataset);

        assert_eq!(builder.batch_size, None);
        assert!(!builder.shuffle);
        assert_eq!(builder.num_workers, 0);
        assert!(!builder.pin_memory);
        assert!(!builder.drop_last);
        assert_eq!(builder.timeout, None);
        assert_eq!(builder.generator, None);
    }

    #[test]
    fn test_dataloader_builder_configuration() {
        // Create a tensor with 5 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let builder = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .shuffle(true)
            .num_workers(4)
            .pin_memory(true)
            .drop_last(true)
            .timeout(std::time::Duration::from_millis(250))
            .generator(7);

        assert_eq!(builder.batch_size, Some(2));
        assert!(builder.shuffle);
        assert_eq!(builder.num_workers, 4);
        assert!(builder.pin_memory);
        assert!(builder.drop_last);
        assert_eq!(builder.timeout, Some(std::time::Duration::from_millis(250)));
        assert_eq!(builder.generator, Some(7));
    }

    #[test]
    fn test_dataloader_sequential_build() {
        // Create a tensor with 5 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 3); // 5 items, batch size 2 = 3 batches (last with 1 item)
        assert!(!dataloader.is_empty());
    }

    #[test]
    fn test_dataloader_random_build() {
        // Create a tensor with 5 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .generator(42)
            .build_with_random_sampling()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    #[test]
    fn test_random_dataloader_yields_len_batches() {
        // Iterating a shuffled loader must produce as many batches as it reports
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .generator(42)
            .build_with_random_sampling()
            .expect("operation should succeed");

        assert_eq!(dataloader.iter().count(), dataloader.len());
    }

    #[test]
    fn test_dataloader_iteration() {
        // Create a tensor with 4 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[4]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        let mut iter = dataloader.iter();
        let batch1 = iter
            .next()
            .expect("iterator should have a next element")
            .expect("operation should succeed");
        let batch2 = iter
            .next()
            .expect("iterator should have a next element")
            .expect("operation should succeed");
        assert!(iter.next().is_none());

        // Verify batch contents (each batch should have 1 stacked tensor)
        assert_eq!(batch1.len(), 1);
        assert_eq!(batch2.len(), 1);

        // Verify the stacked tensor shape: [batch_size, sample_features]
        // Original tensor is [4], each sample is [1], batched becomes [2, 1]
        assert_eq!(batch1[0].shape().dims(), &[2, 1]); // 2 samples batched
        assert_eq!(batch2[0].shape().dims(), &[2, 1]); // 2 samples batched
    }

    #[test]
    fn test_dataloader_parallel_iteration() {
        // Same data as `test_dataloader_iteration`, but num_workers > 1 takes the
        // parallel fetch path, which must produce identically shaped batches.
        let tensor = torsh_tensor::creation::ones::<f32>(&[4]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .num_workers(4)
            .build()
            .expect("operation should succeed");

        let batches: Vec<_> = dataloader
            .iter()
            .map(|batch| batch.expect("operation should succeed"))
            .collect();

        assert_eq!(batches.len(), 2);
        for batch in &batches {
            assert_eq!(batch.len(), 1);
            assert_eq!(batch[0].shape().dims(), &[2, 1]);
        }
    }

    #[test]
    fn test_iterator_size_hint_is_consistent() {
        // Whatever bounds the iterator reports must bracket the real batch count
        let tensor = torsh_tensor::creation::ones::<f32>(&[4]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        let (lower, upper) = dataloader.iter().size_hint();
        let actual = dataloader.iter().count();
        assert!(lower <= actual);
        assert!(upper.map_or(true, |upper| upper >= actual));
    }

    #[test]
    fn test_dataloader_drop_last() {
        // Create a tensor with 5 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .drop_last(true)
            .build()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 2); // 5 items, batch size 2, drop_last = 2 complete batches
    }

    #[test]
    fn test_dataloader_trait_implementation() {
        // Create a tensor with 5 samples (first dimension is number of samples)
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        // Test trait methods
        assert_eq!(DataLoaderTrait::len(&dataloader), 3);
        assert!(!DataLoaderTrait::is_empty(&dataloader));
    }

    #[test]
    fn test_empty_dataloader() {
        let tensors: Vec<torsh_tensor::Tensor<f32>> = vec![];
        let dataset = TensorDataset::new(tensors);
        let dataloader = DataLoaderBuilder::new(dataset)
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(dataloader.len(), 0);
        assert!(dataloader.is_empty());

        let mut iter = dataloader.iter();
        assert!(iter.next().is_none());
    }
}