//! Dynamic batch inference utilities: splitting large inputs into batches
//! for an inference callback, and collecting single samples into batches.
use yscv_tensor::Tensor;

use crate::ModelError;
/// Configuration controlling how inputs are grouped into batches
/// before being handed to an inference function.
#[derive(Debug, Clone)]
pub struct DynamicBatchConfig {
    /// Upper bound on the number of samples dispatched per batch.
    pub max_batch_size: usize,
    /// When `true`, zero-pad a trailing partial batch up to
    /// `max_batch_size` so every dispatch has a fixed leading dimension.
    pub pad_incomplete: bool,
}

impl Default for DynamicBatchConfig {
    /// Defaults: batches of 32 samples, no padding of a final partial batch.
    fn default() -> Self {
        DynamicBatchConfig {
            pad_incomplete: false,
            max_batch_size: 32,
        }
    }
}
22
23/// Splits a large input into batches, runs inference, and reassembles.
24///
25/// `input` is `[N, ...]`, `infer_fn` processes `[B, ...]` and returns `[B, ...]`.
26pub fn batched_inference<F>(
27    input: &Tensor,
28    config: &DynamicBatchConfig,
29    infer_fn: F,
30) -> Result<Tensor, ModelError>
31where
32    F: Fn(&Tensor) -> Result<Tensor, ModelError>,
33{
34    let shape = input.shape();
35    if shape.is_empty() {
36        return Err(ModelError::InvalidFlattenShape {
37            got: shape.to_vec(),
38        });
39    }
40    let n = shape[0];
41    let sample_size: usize = shape[1..].iter().product();
42    let data = input.data();
43    let bs = config.max_batch_size;
44
45    let mut all_outputs: Vec<f32> = Vec::new();
46    let mut out_sample_shape: Option<Vec<usize>> = None;
47
48    let mut offset = 0;
49    while offset < n {
50        let batch_n = (n - offset).min(bs);
51        let start = offset * sample_size;
52        let end = (offset + batch_n) * sample_size;
53        let mut batch_data = data[start..end].to_vec();
54
55        let actual_n = if config.pad_incomplete && batch_n < bs {
56            let pad = (bs - batch_n) * sample_size;
57            batch_data.extend(std::iter::repeat_n(0.0f32, pad));
58            bs
59        } else {
60            batch_n
61        };
62
63        let mut batch_shape = shape.to_vec();
64        batch_shape[0] = actual_n;
65        let batch_tensor = Tensor::from_vec(batch_shape, batch_data)?;
66
67        let result = infer_fn(&batch_tensor)?;
68        let result_shape = result.shape();
69
70        if out_sample_shape.is_none() {
71            out_sample_shape = Some(result_shape[1..].to_vec());
72        }
73
74        let out_sample_size: usize = result_shape[1..].iter().product();
75        let useful_data = &result.data()[..batch_n * out_sample_size];
76        all_outputs.extend_from_slice(useful_data);
77
78        offset += batch_n;
79    }
80
81    let sample_shape = out_sample_shape.unwrap_or_default();
82    let mut final_shape = vec![n];
83    final_shape.extend_from_slice(&sample_shape);
84    Tensor::from_vec(final_shape, all_outputs).map_err(Into::into)
85}
86
87/// Collects individual samples into batches for efficient processing.
/// Accumulates individual samples so they can be emitted as batches
/// for efficient processing.
pub struct BatchCollector {
    /// Shape every pushed sample must have.
    sample_shape: Vec<usize>,
    /// Largest number of samples a single flush will emit.
    max_batch: usize,
    /// Raw per-sample data, kept in insertion order.
    samples: Vec<Vec<f32>>,
}
93
94impl BatchCollector {
95    pub fn new(sample_shape: Vec<usize>, max_batch: usize) -> Self {
96        Self {
97            samples: Vec::new(),
98            sample_shape,
99            max_batch,
100        }
101    }
102
103    /// Adds a sample `[...]` (must match `sample_shape`).
104    pub fn push(&mut self, sample: &Tensor) -> Result<(), ModelError> {
105        if sample.shape() != self.sample_shape {
106            return Err(ModelError::InvalidParameterShape {
107                parameter: "sample",
108                expected: self.sample_shape.clone(),
109                got: sample.shape().to_vec(),
110            });
111        }
112        self.samples.push(sample.data().to_vec());
113        Ok(())
114    }
115
116    /// Returns true if the collector has enough samples for a full batch.
117    pub fn is_ready(&self) -> bool {
118        self.samples.len() >= self.max_batch
119    }
120
121    /// Flushes collected samples as a batched tensor `[N, ...]`.
122    pub fn flush(&mut self) -> Result<Option<Tensor>, ModelError> {
123        if self.samples.is_empty() {
124            return Ok(None);
125        }
126        let n = self.samples.len().min(self.max_batch);
127        let batch: Vec<f32> = self.samples.drain(..n).flatten().collect();
128        let mut shape = vec![n];
129        shape.extend_from_slice(&self.sample_shape);
130        let t = Tensor::from_vec(shape, batch)?;
131        Ok(Some(t))
132    }
133
134    /// Number of pending samples.
135    pub fn pending(&self) -> usize {
136        self.samples.len()
137    }
138}