
tritter_accel/core/inference.rs

//! Inference acceleration utilities.
//!
//! Provides tools for accelerating neural network inference:
//! - Batched operations
//! - Device dispatch (CPU/GPU)
//! - Model optimization helpers
//!
//! # Example
//!
//! ```rust,ignore
//! use tritter_accel::core::inference::{InferenceEngine, InferenceConfig};
//! use candle_core::Device;
//!
//! let config = InferenceConfig::default();
//! let engine = InferenceEngine::new(config)?;
//!
//! // Run batched inference with automatic chunking
//! let outputs = engine.batched_forward(&inputs, |batch| engine.linear(batch, &weight, None))?;
//! ```

use candle_core::{Device, Tensor};
use thiserror::Error;

use super::quantization::{quantize_absmean, QuantizationError, QuantizeConfig};
use super::ternary::{matmul, PackedTernary, TernaryError, TernaryMatmulConfig};

/// Errors from inference operations.
#[derive(Debug, Error)]
pub enum InferenceError {
    /// Configuration error.
    #[error("config error: {0}")]
    Config(String),

    /// Device error.
    #[error("device error: {0}")]
    Device(String),

    /// Shape mismatch.
    #[error("shape mismatch: {0}")]
    Shape(String),

    /// Tensor error.
    #[error("tensor error: {0}")]
    Tensor(#[from] candle_core::Error),

    /// Ternary error.
    #[error("ternary error: {0}")]
    Ternary(#[from] TernaryError),

    /// Quantization error.
    #[error("quantization error: {0}")]
    Quantization(#[from] QuantizationError),
}

/// Configuration for inference engine.
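///
/// # Example
///
/// A minimal sketch of the builder-style configuration (the defaults are
/// `DeviceType::Auto`, a maximum batch size of 32, and quantization disabled):
///
/// ```rust,ignore
/// let config = InferenceConfig::default()
///     .with_device(DeviceType::Cpu)
///     .with_max_batch_size(16)
///     .with_quantization(true);
/// ```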
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Preferred device.
    pub device: DeviceType,
    /// Maximum batch size for batched operations.
    pub max_batch_size: usize,
    /// Enable weight quantization.
    pub quantize_weights: bool,
    /// Enable activation caching.
    pub cache_activations: bool,
}

/// Device type for inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// Automatic device selection.
    Auto,
    /// Force CPU.
    Cpu,
    /// Force GPU with optional device index.
    Gpu(Option<usize>),
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            device: DeviceType::Auto,
            max_batch_size: 32,
            quantize_weights: false,
            cache_activations: false,
        }
    }
}

impl InferenceConfig {
    /// Set device type.
    pub fn with_device(mut self, device: DeviceType) -> Self {
        self.device = device;
        self
    }

    /// Set max batch size.
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size;
        self
    }

    /// Enable weight quantization.
    pub fn with_quantization(mut self, enabled: bool) -> Self {
        self.quantize_weights = enabled;
        self
    }
}

/// Inference engine for accelerated forward passes.
#[derive(Debug)]
pub struct InferenceEngine {
    config: InferenceConfig,
    device: Device,
}

impl InferenceEngine {
    /// Create a new inference engine.
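    ///
    /// `DeviceType::Auto` falls back to the CPU when no CUDA device is available,
    /// and `DeviceType::Gpu(_)` returns an error if the crate was built without
    /// the `cuda` feature.
    ///
    /// # Example
    ///
    /// A minimal sketch (ignored in doctests) of requesting a specific CUDA device:
    ///
    /// ```rust,ignore
    /// let config = InferenceConfig::default().with_device(DeviceType::Gpu(Some(0)));
    /// let engine = InferenceEngine::new(config)?;
    /// assert!(engine.is_gpu());
    /// ```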
    pub fn new(config: InferenceConfig) -> Result<Self, InferenceError> {
        let device = match config.device {
            DeviceType::Cpu => Device::Cpu,
            DeviceType::Auto => {
                #[cfg(feature = "cuda")]
                {
                    Device::cuda_if_available(0).unwrap_or(Device::Cpu)
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Device::Cpu
                }
            }
            DeviceType::Gpu(ordinal) => {
                #[cfg(feature = "cuda")]
                {
                    let idx = ordinal.unwrap_or(0);
                    Device::new_cuda(idx)
                        .map_err(|e| InferenceError::Device(format!("CUDA device {idx}: {e}")))?
                }
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = ordinal;
                    return Err(InferenceError::Device(
                        "CUDA not compiled. Rebuild with --features cuda".to_string(),
                    ));
                }
            }
        };

        Ok(Self { config, device })
    }

    /// Get the active device.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Check if running on GPU.
    pub fn is_gpu(&self) -> bool {
        matches!(self.device, Device::Cuda(_))
    }

    /// Linear layer forward pass with full-precision weights.
    ///
    /// Computes: `output = input @ weight.T + bias`. For quantized weights, see
    /// [`Self::ternary_linear`].
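    ///
    /// # Example
    ///
    /// A minimal sketch (ignored in doctests): a `(1, 4)` input against a
    /// `(2, 4)` weight yields a `(1, 2)` output.
    ///
    /// ```rust,ignore
    /// let output = engine.linear(&input, &weight, None)?;
    /// assert_eq!(output.dims(), &[1, 2]);
    /// ```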
    pub fn linear(
        &self,
        input: &Tensor,
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> Result<Tensor, InferenceError> {
        // Move tensors to device if needed
        let input = input.to_device(&self.device)?;
        let weight = weight.to_device(&self.device)?;

        // Compute matmul
        let output = input.matmul(&weight.t()?)?;

        // Add bias if present
        let output = if let Some(b) = bias {
            let b = b.to_device(&self.device)?;
            output.broadcast_add(&b)?
        } else {
            output
        };

        Ok(output)
    }

    /// Ternary linear layer (quantized weights).
    ///
    /// Quantizes weights to ternary for memory-efficient inference.
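    ///
    /// Note that the weights are re-quantized on every call; for repeated
    /// inference with the same weights, see [`TernaryLayer`].
    ///
    /// # Example
    ///
    /// A minimal sketch (ignored in doctests), assuming `input`, `weight`, and
    /// `bias` are tensors with compatible shapes:
    ///
    /// ```rust,ignore
    /// let output = engine.ternary_linear(&input, &weight, Some(&bias))?;
    /// ```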
    pub fn ternary_linear(
        &self,
        input: &Tensor,
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> Result<Tensor, InferenceError> {
        // Quantize weights
        let quant_config = QuantizeConfig::default();
        let quantized = quantize_absmean(weight, &quant_config)?;
        let packed = quantized.to_packed()?;

        // Move input to device
        let input = input.to_device(&self.device)?;

        // Ternary matmul
        let matmul_config = TernaryMatmulConfig::default();
        let output = matmul(&input, &packed, Some(&matmul_config))?;

        // Add bias
        let output = if let Some(b) = bias {
            let b = b.to_device(&self.device)?;
            output.broadcast_add(&b)?
        } else {
            output
        };

        Ok(output)
    }

    /// Batched inference with automatic chunking.
    ///
    /// Splits large batches into smaller chunks to fit in memory.
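    ///
    /// Chunks of at most `max_batch_size` rows are passed to `forward_fn`
    /// independently and the results are concatenated along dimension 0.
    ///
    /// # Example
    ///
    /// A minimal sketch (ignored in doctests), assuming `weight` is a suitable
    /// weight tensor:
    ///
    /// ```rust,ignore
    /// let outputs = engine.batched_forward(&inputs, |chunk| {
    ///     engine.linear(chunk, &weight, None)
    /// })?;
    /// ```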
    pub fn batched_forward<F>(
        &self,
        inputs: &Tensor,
        forward_fn: F,
    ) -> Result<Tensor, InferenceError>
    where
        F: Fn(&Tensor) -> Result<Tensor, InferenceError>,
    {
        let batch_size = inputs.dim(0)?;

        if batch_size <= self.config.max_batch_size {
            return forward_fn(inputs);
        }

        // Split into chunks
        let mut outputs = Vec::new();
        let mut start = 0;

        while start < batch_size {
            let end = (start + self.config.max_batch_size).min(batch_size);
            let chunk = inputs.narrow(0, start, end - start)?;
            let output = forward_fn(&chunk)?;
            outputs.push(output);
            start = end;
        }

        // Concatenate outputs
        Ok(Tensor::cat(&outputs, 0)?)
    }

    /// Apply softmax along specified dimension.
    pub fn softmax(&self, input: &Tensor, dim: usize) -> Result<Tensor, InferenceError> {
        let input = input.to_device(&self.device)?;
        Ok(candle_nn::ops::softmax(&input, dim)?)
    }

    /// Apply layer normalization.
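    ///
    /// Normalizes over the last dimension:
    /// `output = (input - mean) / sqrt(var + eps) * weight + bias`.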
    pub fn layer_norm(
        &self,
        input: &Tensor,
        weight: &Tensor,
        bias: &Tensor,
        eps: f64,
    ) -> Result<Tensor, InferenceError> {
        let input = input.to_device(&self.device)?;
        let weight = weight.to_device(&self.device)?;
        let bias = bias.to_device(&self.device)?;

        // Compute mean and variance along last dimension
        let dim = input.dims().len() - 1;
        let mean = input.mean_keepdim(dim)?;
        let var = input
            .broadcast_sub(&mean)?
            .sqr()?
            .mean_keepdim(dim)?;

        // Normalize
        let normalized = input
            .broadcast_sub(&mean)?
            .broadcast_div(&(var + eps)?.sqrt()?)?;

        // Scale and shift
        Ok(normalized.broadcast_mul(&weight)?.broadcast_add(&bias)?)
    }
}

/// Pre-computed ternary layer for repeated inference.
///
/// Stores quantized weights to avoid re-quantization overhead.
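///
/// # Example
///
/// A minimal sketch (ignored in doctests): quantize once, then reuse the packed
/// weights across forward passes.
///
/// ```rust,ignore
/// let layer = TernaryLayer::from_tensor(&weight, None)?;
/// let output = layer.forward(&input)?;
/// println!("compression: {:.1}x", layer.compression_ratio());
/// ```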
#[derive(Debug)]
pub struct TernaryLayer {
    /// Packed ternary weights.
    pub weights: PackedTernary,
    /// Bias (optional).
    pub bias: Option<Vec<f32>>,
    /// Input features.
    pub in_features: usize,
    /// Output features.
    pub out_features: usize,
}

impl TernaryLayer {
    /// Create from float weight tensor.
    pub fn from_tensor(
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> Result<Self, InferenceError> {
        let (out_features, in_features) = weight.dims2()?;

        // Quantize weights
        let quant_config = QuantizeConfig::default();
        let quantized = quantize_absmean(weight, &quant_config)?;
        let weights = quantized.to_packed()?;

        // Extract bias if present
        let bias = if let Some(b) = bias {
            Some(b.flatten_all()?.to_vec1()?)
        } else {
            None
        };

        Ok(Self {
            weights,
            bias,
            in_features,
            out_features,
        })
    }

    /// Forward pass.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, InferenceError> {
        let matmul_config = TernaryMatmulConfig::default();
        let output = matmul(input, &self.weights, Some(&matmul_config))?;

        // Add bias
        if let Some(ref bias) = self.bias {
            let bias_tensor = Tensor::from_vec(bias.clone(), self.out_features, input.device())?;
            Ok(output.broadcast_add(&bias_tensor)?)
        } else {
            Ok(output)
        }
    }

    /// Memory usage in bytes.
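    ///
    /// Counts the 2-bit packed weights, one `f32` scale per output row, and the
    /// optional `f32` bias. For example, a 16 x 32 layer without bias needs
    /// 16 * 32 * 2 = 1024 bits (128 bytes) of packed weights plus 16 * 4 = 64
    /// bytes of scales, i.e. 192 bytes versus 2048 bytes in `f32` (about 10.7x
    /// smaller).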
    pub fn memory_bytes(&self) -> usize {
        // Packed weights: 2 bits per value
        let weight_bits = self.in_features * self.out_features * 2;
        let weight_bytes = weight_bits.div_ceil(8);

        // Scales: f32 per row
        let scale_bytes = self.out_features * 4;

        // Bias: f32 per output
        let bias_bytes = self.bias.as_ref().map(|b| b.len() * 4).unwrap_or(0);

        weight_bytes + scale_bytes + bias_bytes
    }

    /// Original (unquantized) memory usage for comparison.
    pub fn original_memory_bytes(&self) -> usize {
        // f32 weights + f32 bias
        let weight_bytes = self.in_features * self.out_features * 4;
        let bias_bytes = self.bias.as_ref().map(|b| b.len() * 4).unwrap_or(0);
        weight_bytes + bias_bytes
    }

    /// Compression ratio achieved.
    #[allow(clippy::cast_precision_loss)]
    pub fn compression_ratio(&self) -> f32 {
        self.original_memory_bytes() as f32 / self.memory_bytes() as f32
    }
}

/// KV cache for efficient autoregressive inference.
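///
/// # Example
///
/// A minimal sketch (ignored in doctests) of one decoding step: push this
/// step's keys and values and get back the full cached sequence for attention.
///
/// ```rust,ignore
/// let mut cache = KVCache::new(max_seq_len);
/// let (keys, values) = cache.update(new_keys, new_values)?;
/// ```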
#[derive(Debug)]
pub struct KVCache {
    /// Cached keys.
    keys: Vec<Tensor>,
    /// Cached values.
    values: Vec<Tensor>,
    /// Maximum sequence length.
    max_seq_len: usize,
    /// Current sequence length.
    seq_len: usize,
}

impl KVCache {
    /// Create a new KV cache.
    pub fn new(max_seq_len: usize) -> Self {
        Self {
            keys: Vec::new(),
            values: Vec::new(),
            max_seq_len,
            seq_len: 0,
        }
    }

    /// Update cache with new key-value pairs.
    ///
    /// Returns the full cached keys and values, concatenated along the
    /// sequence dimension (dim 1).
    pub fn update(
        &mut self,
        new_keys: Tensor,
        new_values: Tensor,
    ) -> Result<(Tensor, Tensor), InferenceError> {
        // Append new KV pairs
        self.keys.push(new_keys);
        self.values.push(new_values);
        self.seq_len += 1;

        // Evict the oldest entry first so the returned tensors never exceed
        // the maximum sequence length
        if self.seq_len > self.max_seq_len {
            self.keys.remove(0);
            self.values.remove(0);
            self.seq_len = self.max_seq_len;
        }

        // Concatenate all cached KVs
        let all_keys = Tensor::cat(&self.keys, 1)?;
        let all_values = Tensor::cat(&self.values, 1)?;

        Ok((all_keys, all_values))
    }

    /// Clear the cache.
    pub fn clear(&mut self) {
        self.keys.clear();
        self.values.clear();
        self.seq_len = 0;
    }

    /// Current sequence length.
    pub fn len(&self) -> usize {
        self.seq_len
    }

    /// Check if cache is empty.
    pub fn is_empty(&self) -> bool {
        self.seq_len == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_engine_creation() {
        let config = InferenceConfig::default().with_device(DeviceType::Cpu);
        let engine = InferenceEngine::new(config).unwrap();

        assert!(!engine.is_gpu());
    }

    #[test]
    fn test_linear_forward() {
        let config = InferenceConfig::default().with_device(DeviceType::Cpu);
        let engine = InferenceEngine::new(config).unwrap();

        let input = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], (1, 4), engine.device()).unwrap();
        let weight = Tensor::from_vec(vec![1.0f32; 8], (2, 4), engine.device()).unwrap();

        let output = engine.linear(&input, &weight, None).unwrap();

        assert_eq!(output.dims(), &[1, 2]);
    }

    #[test]
    fn test_ternary_layer() {
        let device = Device::Cpu;
        let weight = Tensor::randn(0f32, 1f32, (16, 32), &device).unwrap();

        let layer = TernaryLayer::from_tensor(&weight, None).unwrap();

        // Check compression
        assert!(layer.compression_ratio() > 10.0);

        // Test forward pass
        let input = Tensor::randn(0f32, 1f32, (1, 32), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.dims(), &[1, 16]);
    }

    #[test]
    fn test_kv_cache() {
        let mut cache = KVCache::new(4);

        assert!(cache.is_empty());

        let device = Device::Cpu;
        let k1 = Tensor::zeros((1, 1, 8), candle_core::DType::F32, &device).unwrap();
        let v1 = Tensor::zeros((1, 1, 8), candle_core::DType::F32, &device).unwrap();

        let (keys, values) = cache.update(k1, v1).unwrap();

        assert_eq!(cache.len(), 1);
        assert_eq!(keys.dim(1).unwrap(), 1);
        assert_eq!(values.dim(1).unwrap(), 1);
    }

    #[test]
    fn test_batched_forward() {
        let config = InferenceConfig::default()
            .with_device(DeviceType::Cpu)
            .with_max_batch_size(2);
        let engine = InferenceEngine::new(config).unwrap();

        // Create input with 5 samples (larger than max_batch_size)
        let input = Tensor::randn(0f32, 1f32, (5, 4), engine.device()).unwrap();

        // Simple identity-like forward function
        let output = engine
            .batched_forward(&input, |x| Ok(x.clone()))
            .unwrap();

        assert_eq!(output.dims(), &[5, 4]);
    }
}