//! scirs2_io/ml_framework/batch_processing.rs
use crate::error::Result;
use crate::ml_framework::{datasets, MLTensor};
use scirs2_core::parallel_ops::*;
/// Splits tensor collections into fixed-size batches and runs a
/// caller-supplied function over each batch in parallel.
pub struct BatchProcessor {
    // Number of tensors per chunk handed to the processing function.
    batch_size: usize,
    // Reserved for future prefetching support; currently written in
    // `new` but never read, hence the dead_code allowance.
    #[allow(dead_code)]
    prefetch_factor: usize,
}
13
14impl BatchProcessor {
15 pub fn new(batchsize: usize) -> Self {
16 Self {
17 batch_size: batchsize,
18 prefetch_factor: 2,
19 }
20 }
21
22 pub fn process_batches<F>(&self, data: &[MLTensor], processfn: F) -> Result<Vec<MLTensor>>
24 where
25 F: Fn(&[MLTensor]) -> Result<Vec<MLTensor>> + Send + Sync,
26 {
27 let results: Result<Vec<Vec<MLTensor>>> =
28 data.par_chunks(self.batch_size).map(processfn).collect();
29
30 results.map(|chunks| chunks.into_iter().flatten().collect())
31 }
32
33 pub fn create_dataloader(&self, dataset: &datasets::MLDataset) -> DataLoader {
35 DataLoader {
36 dataset: dataset.clone(),
37 batch_size: self.batch_size,
38 shuffle: false,
39 current_idx: 0,
40 }
41 }
42}
43
/// Iterates over an owned `MLDataset` in fixed-size batches of
/// `(features, labels)` tensors. Constructed by
/// `BatchProcessor::create_dataloader`.
#[derive(Clone)]
pub struct DataLoader {
    // Owned copy of the dataset being iterated.
    dataset: datasets::MLDataset,
    // Maximum number of samples per yielded batch; the final batch may
    // be shorter.
    batch_size: usize,
    // NOTE(review): set by `create_dataloader` but never read by the
    // `Iterator` impl — shuffling is not implemented yet.
    shuffle: bool,
    // Index of the first sample of the next batch to yield.
    current_idx: usize,
}
52
53impl Iterator for DataLoader {
54 type Item = (Vec<MLTensor>, Option<Vec<MLTensor>>);
55
56 fn next(&mut self) -> Option<Self::Item> {
57 if self.current_idx >= self.dataset.len() {
58 return None;
59 }
60
61 let end_idx = (self.current_idx + self.batch_size).min(self.dataset.len());
62 let features = self.dataset.features[self.current_idx..end_idx].to_vec();
63 let labels = self
64 .dataset
65 .labels
66 .as_ref()
67 .map(|l| l[self.current_idx..end_idx].to_vec());
68
69 self.current_idx = end_idx;
70 Some((features, labels))
71 }
72}