// yscv_model/batch_infer.rs
use yscv_tensor::Tensor;

use crate::ModelError;

/// Configuration for splitting a large input into fixed-size inference batches.
#[derive(Debug, Clone)]
pub struct DynamicBatchConfig {
    /// Maximum number of samples (along axis 0) per inference call.
    pub max_batch_size: usize,
    /// When `true`, a trailing partial batch is zero-padded up to
    /// `max_batch_size` before being passed to the inference function.
    pub pad_incomplete: bool,
}
13
14impl Default for DynamicBatchConfig {
15 fn default() -> Self {
16 Self {
17 max_batch_size: 32,
18 pad_incomplete: false,
19 }
20 }
21}
22
23pub fn batched_inference<F>(
27 input: &Tensor,
28 config: &DynamicBatchConfig,
29 infer_fn: F,
30) -> Result<Tensor, ModelError>
31where
32 F: Fn(&Tensor) -> Result<Tensor, ModelError>,
33{
34 let shape = input.shape();
35 if shape.is_empty() {
36 return Err(ModelError::InvalidFlattenShape {
37 got: shape.to_vec(),
38 });
39 }
40 let n = shape[0];
41 let sample_size: usize = shape[1..].iter().product();
42 let data = input.data();
43 let bs = config.max_batch_size;
44
45 let mut all_outputs: Vec<f32> = Vec::new();
46 let mut out_sample_shape: Option<Vec<usize>> = None;
47
48 let mut offset = 0;
49 while offset < n {
50 let batch_n = (n - offset).min(bs);
51 let start = offset * sample_size;
52 let end = (offset + batch_n) * sample_size;
53 let mut batch_data = data[start..end].to_vec();
54
55 let actual_n = if config.pad_incomplete && batch_n < bs {
56 let pad = (bs - batch_n) * sample_size;
57 batch_data.extend(std::iter::repeat_n(0.0f32, pad));
58 bs
59 } else {
60 batch_n
61 };
62
63 let mut batch_shape = shape.to_vec();
64 batch_shape[0] = actual_n;
65 let batch_tensor = Tensor::from_vec(batch_shape, batch_data)?;
66
67 let result = infer_fn(&batch_tensor)?;
68 let result_shape = result.shape();
69
70 if out_sample_shape.is_none() {
71 out_sample_shape = Some(result_shape[1..].to_vec());
72 }
73
74 let out_sample_size: usize = result_shape[1..].iter().product();
75 let useful_data = &result.data()[..batch_n * out_sample_size];
76 all_outputs.extend_from_slice(useful_data);
77
78 offset += batch_n;
79 }
80
81 let sample_shape = out_sample_shape.unwrap_or_default();
82 let mut final_shape = vec![n];
83 final_shape.extend_from_slice(&sample_shape);
84 Tensor::from_vec(final_shape, all_outputs).map_err(Into::into)
85}
86
/// Accumulates individual samples of one fixed shape until enough are
/// buffered to stack into a batch tensor.
pub struct BatchCollector {
    // Flattened data of each buffered sample, in insertion order.
    samples: Vec<Vec<f32>>,
    // Required shape of every pushed sample.
    sample_shape: Vec<usize>,
    // Number of buffered samples that constitutes a full batch.
    max_batch: usize,
}
93
94impl BatchCollector {
95 pub fn new(sample_shape: Vec<usize>, max_batch: usize) -> Self {
96 Self {
97 samples: Vec::new(),
98 sample_shape,
99 max_batch,
100 }
101 }
102
103 pub fn push(&mut self, sample: &Tensor) -> Result<(), ModelError> {
105 if sample.shape() != self.sample_shape {
106 return Err(ModelError::InvalidParameterShape {
107 parameter: "sample",
108 expected: self.sample_shape.clone(),
109 got: sample.shape().to_vec(),
110 });
111 }
112 self.samples.push(sample.data().to_vec());
113 Ok(())
114 }
115
116 pub fn is_ready(&self) -> bool {
118 self.samples.len() >= self.max_batch
119 }
120
121 pub fn flush(&mut self) -> Result<Option<Tensor>, ModelError> {
123 if self.samples.is_empty() {
124 return Ok(None);
125 }
126 let n = self.samples.len().min(self.max_batch);
127 let batch: Vec<f32> = self.samples.drain(..n).flatten().collect();
128 let mut shape = vec![n];
129 shape.extend_from_slice(&self.sample_shape);
130 let t = Tensor::from_vec(shape, batch)?;
131 Ok(Some(t))
132 }
133
134 pub fn pending(&self) -> usize {
136 self.samples.len()
137 }
138}