tl_gpu/batch.rs
// BatchInference: auto-batching for GPU inference

use tl_ai::{TlModel, TlTensor, predict_batch};

/// Batch inference helper. Currently delegates to the CPU-based `predict_batch`.
/// Future: keep intermediate data on the GPU between batches.
pub struct BatchInference;
8
impl BatchInference {
    /// Run a batched prediction: splits `input` into batches of `batch_size`
    /// (default 32) and collects the per-batch results.
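    ///
    /// # Example
    ///
    /// An illustrative sketch, not a confirmed `tl_ai` API: `TlModel::load`
    /// and the helper building `input` are assumed names, so the block is
    /// marked `ignore`.
    ///
    /// ```ignore
    /// use tl_ai::TlModel;
    /// use tl_gpu::batch::BatchInference;
    ///
    /// let model = TlModel::load("model.bin")?;  // assumed loader API
    /// let input = make_input_tensor();          // hypothetical helper returning a TlTensor
    /// let output = BatchInference::batch_predict(&model, &input, Some(64))?;
    /// // Passing `None` falls back to the default batch size of 32.
    /// let output_default = BatchInference::batch_predict(&model, &input, None)?;
    /// ```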
    pub fn batch_predict(
        model: &TlModel,
        input: &TlTensor,
        batch_size: Option<usize>,
    ) -> Result<TlTensor, String> {
        // Fall back to a batch size of 32 when the caller leaves it unspecified.
        let bs = batch_size.unwrap_or(32);
        predict_batch(model, input, bs)
    }
}

#[cfg(test)]
mod tests {
    // Batch inference tests require a model file; coverage lives at the
    // integration-test level.
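
    // Sketch of the intended integration-level smoke test. `TlModel::load`,
    // the fixture path, and the shape accessor are assumptions about the
    // tl_ai API, so the body stays commented out and the test is ignored.
    #[test]
    #[ignore = "requires a model fixture on disk"]
    fn batch_predict_smoke() {
        // let model = tl_ai::TlModel::load("fixtures/model.bin").expect("load model");
        // let input = load_input_fixture();  // hypothetical helper returning a TlTensor
        // let output = super::BatchInference::batch_predict(&model, &input, Some(8))
        //     .expect("batched prediction");
        // assert_eq!(output.shape()[0], input.shape()[0]);  // assumed shape accessor
    }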
}