Skip to main content

tl_gpu/
batch.rs

1// BatchInference — auto-batching for GPU inference
2
3use tl_ai::{TlModel, TlTensor, predict_batch};
4
/// Stateless entry point for batched model inference.
///
/// At present the work is handed straight to `predict_batch` from `tl_ai`
/// (described by the original author as CPU-based). Keeping intermediate data
/// resident on the GPU between batches is planned but not implemented yet.
///
/// The type carries no fields, so `BatchInference::default()` and the bare
/// `BatchInference` value are interchangeable.
// `Copy` is intentionally not derived: once GPU-side state is added the type
// may stop being trivially copyable, and removing `Copy` later would be a
// breaking change for callers.
#[derive(Debug, Clone, Default)]
pub struct BatchInference;
8
9impl BatchInference {
10    /// Run batched prediction. Splits input into batches and collects results.
11    pub fn batch_predict(
12        model: &TlModel,
13        input: &TlTensor,
14        batch_size: Option<usize>,
15    ) -> Result<TlTensor, String> {
16        let bs = batch_size.unwrap_or(32);
17        predict_batch(model, input, bs)
18    }
19}
20
#[cfg(test)]
mod tests {
    // Intentionally empty: exercising `batch_predict` requires a model file,
    // so it is covered by integration-level tests rather than unit tests here.
}