scirs2_io/ml_framework/
batch_processing.rs

1//! Batch processing support for ML models
2
3use crate::error::Result;
4use crate::ml_framework::{datasets, MLTensor};
5use scirs2_core::parallel_ops::*;
6
/// Batch processor for ML models.
///
/// Holds the batching configuration used to split tensor slices into
/// fixed-size chunks and to build [`DataLoader`]s.
#[derive(Debug, Clone)]
pub struct BatchProcessor {
    /// Maximum number of items placed in one batch.
    batch_size: usize,
    /// Prefetch setting; not read anywhere in this file yet (presumably
    /// reserved for future prefetching support), hence the lint exemption.
    #[allow(dead_code)]
    prefetch_factor: usize,
}
13
14impl BatchProcessor {
15    pub fn new(batchsize: usize) -> Self {
16        Self {
17            batch_size: batchsize,
18            prefetch_factor: 2,
19        }
20    }
21
22    /// Process data in batches
23    pub fn process_batches<F>(&self, data: &[MLTensor], processfn: F) -> Result<Vec<MLTensor>>
24    where
25        F: Fn(&[MLTensor]) -> Result<Vec<MLTensor>> + Send + Sync,
26    {
27        let results: Result<Vec<Vec<MLTensor>>> =
28            data.par_chunks(self.batch_size).map(processfn).collect();
29
30        results.map(|chunks| chunks.into_iter().flatten().collect())
31    }
32
33    /// Create data loader
34    pub fn create_dataloader(&self, dataset: &datasets::MLDataset) -> DataLoader {
35        DataLoader {
36            dataset: dataset.clone(),
37            batch_size: self.batch_size,
38            shuffle: false,
39            current_idx: 0,
40        }
41    }
42}
43
44/// Data loader for batched iteration
45#[derive(Clone)]
46pub struct DataLoader {
47    dataset: datasets::MLDataset,
48    batch_size: usize,
49    shuffle: bool,
50    current_idx: usize,
51}
52
53impl Iterator for DataLoader {
54    type Item = (Vec<MLTensor>, Option<Vec<MLTensor>>);
55
56    fn next(&mut self) -> Option<Self::Item> {
57        if self.current_idx >= self.dataset.len() {
58            return None;
59        }
60
61        let end_idx = (self.current_idx + self.batch_size).min(self.dataset.len());
62        let features = self.dataset.features[self.current_idx..end_idx].to_vec();
63        let labels = self
64            .dataset
65            .labels
66            .as_ref()
67            .map(|l| l[self.current_idx..end_idx].to_vec());
68
69        self.current_idx = end_idx;
70        Some((features, labels))
71    }
72}