use candle_core::{Device, Tensor};
use thiserror::Error;

use super::quantization::{quantize_absmean, QuantizationError, QuantizeConfig};
use super::ternary::{matmul, PackedTernary, TernaryError, TernaryMatmulConfig};

/// Errors that can occur during inference.
#[derive(Debug, Error)]
pub enum InferenceError {
    #[error("config error: {0}")]
    Config(String),

    #[error("device error: {0}")]
    Device(String),

    #[error("shape mismatch: {0}")]
    Shape(String),

    #[error("tensor error: {0}")]
    Tensor(#[from] candle_core::Error),

    #[error("ternary error: {0}")]
    Ternary(#[from] TernaryError),

    #[error("quantization error: {0}")]
    Quantization(#[from] QuantizationError),
}

/// Configuration for the inference engine.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Target device for computation.
    pub device: DeviceType,
    /// Maximum number of samples processed in a single forward pass.
    pub max_batch_size: usize,
    /// Quantize weights to ternary before running linear layers.
    pub quantize_weights: bool,
    /// Cache activations between forward passes.
    pub cache_activations: bool,
}

/// Device selection for inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// Use a GPU if one is available, otherwise fall back to the CPU.
    Auto,
    /// Force CPU execution.
    Cpu,
    /// Use the CUDA device with the given ordinal (defaults to 0).
    Gpu(Option<usize>),
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            device: DeviceType::Auto,
            max_batch_size: 32,
            quantize_weights: false,
            cache_activations: false,
        }
    }
}

impl InferenceConfig {
    /// Set the target device.
    pub fn with_device(mut self, device: DeviceType) -> Self {
        self.device = device;
        self
    }

    /// Set the maximum batch size for batched forward passes.
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size;
        self
    }

    /// Enable or disable ternary weight quantization.
    pub fn with_quantization(mut self, enabled: bool) -> Self {
        self.quantize_weights = enabled;
        self
    }
}

/// Inference engine that dispatches tensor operations to the configured device.
#[derive(Debug)]
pub struct InferenceEngine {
    config: InferenceConfig,
    device: Device,
}

impl InferenceEngine {
    /// Create a new engine, resolving the configured [`DeviceType`] to a
    /// concrete candle [`Device`].
    pub fn new(config: InferenceConfig) -> Result<Self, InferenceError> {
        let device = match config.device {
            DeviceType::Cpu => Device::Cpu,
            DeviceType::Auto => {
                #[cfg(feature = "cuda")]
                {
                    Device::cuda_if_available(0).unwrap_or(Device::Cpu)
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Device::Cpu
                }
            }
            DeviceType::Gpu(ordinal) => {
                #[cfg(feature = "cuda")]
                {
                    let idx = ordinal.unwrap_or(0);
                    Device::new_cuda(idx)
                        .map_err(|e| InferenceError::Device(format!("CUDA device {idx}: {e}")))?
                }
                #[cfg(not(feature = "cuda"))]
                {
                    let _ = ordinal;
                    return Err(InferenceError::Device(
                        "CUDA support not compiled in; rebuild with --features cuda".to_string(),
                    ));
                }
            }
        };

        Ok(Self { config, device })
    }

    /// The device this engine runs on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Whether the engine is backed by a CUDA device.
    pub fn is_gpu(&self) -> bool {
        matches!(self.device, Device::Cuda(_))
    }

    /// Dense linear layer: `output = input @ weight^T + bias`.
    ///
    /// `input` has shape `(batch, in_features)`, `weight` has shape
    /// `(out_features, in_features)`, and `bias`, if present, has shape `(out_features,)`.
    pub fn linear(
        &self,
        input: &Tensor,
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> Result<Tensor, InferenceError> {
        let input = input.to_device(&self.device)?;
        let weight = weight.to_device(&self.device)?;

        let output = input.matmul(&weight.t()?)?;

        let output = if let Some(b) = bias {
            let b = b.to_device(&self.device)?;
            output.broadcast_add(&b)?
        } else {
            output
        };

        Ok(output)
    }

    /// Linear layer with on-the-fly ternary weight quantization.
    ///
    /// The full-precision `weight` is quantized with absmean scaling, packed,
    /// and multiplied with `input` via the ternary matmul kernel.
    pub fn ternary_linear(
        &self,
        input: &Tensor,
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> Result<Tensor, InferenceError> {
        // Quantize the weights to {-1, 0, +1} and pack them.
        let quant_config = QuantizeConfig::default();
        let quantized = quantize_absmean(weight, &quant_config)?;
        let packed = quantized.to_packed()?;

        let input = input.to_device(&self.device)?;

        let matmul_config = TernaryMatmulConfig::default();
        let output = matmul(&input, &packed, Some(&matmul_config))?;

        let output = if let Some(b) = bias {
            let b = b.to_device(&self.device)?;
            output.broadcast_add(&b)?
        } else {
            output
        };

        Ok(output)
    }

    /// Run `forward_fn` over `inputs`, splitting along the batch dimension when
    /// the batch exceeds `max_batch_size`, then concatenate the chunk outputs.
    pub fn batched_forward<F>(
        &self,
        inputs: &Tensor,
        forward_fn: F,
    ) -> Result<Tensor, InferenceError>
    where
        F: Fn(&Tensor) -> Result<Tensor, InferenceError>,
    {
        let batch_size = inputs.dim(0)?;

        // Small batches go through in a single call.
        if batch_size <= self.config.max_batch_size {
            return forward_fn(inputs);
        }

        let mut outputs = Vec::new();
        let mut start = 0;

        while start < batch_size {
            let end = (start + self.config.max_batch_size).min(batch_size);
            let chunk = inputs.narrow(0, start, end - start)?;
            let output = forward_fn(&chunk)?;
            outputs.push(output);
            start = end;
        }

        Ok(Tensor::cat(&outputs, 0)?)
    }

    /// Softmax along `dim`, computed on the engine's device.
    pub fn softmax(&self, input: &Tensor, dim: usize) -> Result<Tensor, InferenceError> {
        let input = input.to_device(&self.device)?;
        Ok(candle_nn::ops::softmax(&input, dim)?)
    }

    /// Layer normalization over the last dimension, followed by an affine
    /// transform with `weight` and `bias`.
    pub fn layer_norm(
        &self,
        input: &Tensor,
        weight: &Tensor,
        bias: &Tensor,
        eps: f64,
    ) -> Result<Tensor, InferenceError> {
        let input = input.to_device(&self.device)?;
        let weight = weight.to_device(&self.device)?;
        let bias = bias.to_device(&self.device)?;

        // Normalize over the last dimension.
        let dim = input.dims().len() - 1;
        let mean = input.mean_keepdim(dim)?;
        let var = input.broadcast_sub(&mean)?.sqr()?.mean_keepdim(dim)?;

        let normalized = input
            .broadcast_sub(&mean)?
            .broadcast_div(&(var + eps)?.sqrt()?)?;

        Ok(normalized.broadcast_mul(&weight)?.broadcast_add(&bias)?)
    }
}

/// A linear layer whose weights are stored in packed ternary form.
#[derive(Debug)]
pub struct TernaryLayer {
    /// Packed ternary weight matrix.
    pub weights: PackedTernary,
    /// Optional bias, one value per output feature.
    pub bias: Option<Vec<f32>>,
    /// Number of input features.
    pub in_features: usize,
    /// Number of output features.
    pub out_features: usize,
}

impl TernaryLayer {
    /// Quantize a full-precision weight tensor (and optional bias) into a ternary layer.
    pub fn from_tensor(
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> Result<Self, InferenceError> {
        let (out_features, in_features) = weight.dims2()?;

        let quant_config = QuantizeConfig::default();
        let quantized = quantize_absmean(weight, &quant_config)?;
        let weights = quantized.to_packed()?;

        let bias = if let Some(b) = bias {
            Some(b.flatten_all()?.to_vec1()?)
        } else {
            None
        };

        Ok(Self {
            weights,
            bias,
            in_features,
            out_features,
        })
    }

    /// Forward pass: ternary matmul plus optional bias.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, InferenceError> {
        let matmul_config = TernaryMatmulConfig::default();
        let output = matmul(input, &self.weights, Some(&matmul_config))?;

        if let Some(ref bias) = self.bias {
            let bias_tensor = Tensor::from_vec(bias.clone(), self.out_features, input.device())?;
            Ok(output.broadcast_add(&bias_tensor)?)
        } else {
            Ok(output)
        }
    }

    /// Approximate memory footprint of the quantized layer in bytes:
    /// 2 bits per weight, one f32 scale per output row, plus an optional f32 bias.
    pub fn memory_bytes(&self) -> usize {
        let weight_bits = self.in_features * self.out_features * 2;
        let weight_bytes = weight_bits.div_ceil(8);

        let scale_bytes = self.out_features * 4;

        let bias_bytes = self.bias.as_ref().map(|b| b.len() * 4).unwrap_or(0);

        weight_bytes + scale_bytes + bias_bytes
    }

    /// Memory footprint of the equivalent f32 layer in bytes.
    pub fn original_memory_bytes(&self) -> usize {
        let weight_bytes = self.in_features * self.out_features * 4;
        let bias_bytes = self.bias.as_ref().map(|b| b.len() * 4).unwrap_or(0);
        weight_bytes + bias_bytes
    }

    /// Ratio of f32 memory to packed ternary memory.
    #[allow(clippy::cast_precision_loss)]
    pub fn compression_ratio(&self) -> f32 {
        self.original_memory_bytes() as f32 / self.memory_bytes() as f32
    }
}

/// A simple key/value cache for autoregressive decoding.
#[derive(Debug)]
pub struct KVCache {
    /// Cached key tensors, one entry per decode step.
    keys: Vec<Tensor>,
    /// Cached value tensors, one entry per decode step.
    values: Vec<Tensor>,
    /// Maximum number of cached steps before the oldest entry is evicted.
    max_seq_len: usize,
    /// Number of steps currently held in the cache.
    seq_len: usize,
}

impl KVCache {
    /// Create an empty cache that holds at most `max_seq_len` decode steps.
    pub fn new(max_seq_len: usize) -> Self {
        Self {
            keys: Vec::new(),
            values: Vec::new(),
            max_seq_len,
            seq_len: 0,
        }
    }

    /// Append one decode step of keys/values and return the full cached keys
    /// and values, concatenated along dimension 1.
    ///
    /// When the cache is full, the oldest entry is evicted before the new one
    /// is appended, so the returned tensors never cover more than
    /// `max_seq_len` steps.
    pub fn update(
        &mut self,
        new_keys: Tensor,
        new_values: Tensor,
    ) -> Result<(Tensor, Tensor), InferenceError> {
        // Evict the oldest entry first so the window never exceeds max_seq_len.
        if !self.keys.is_empty() && self.seq_len >= self.max_seq_len {
            self.keys.remove(0);
            self.values.remove(0);
            self.seq_len -= 1;
        }

        self.keys.push(new_keys);
        self.values.push(new_values);
        self.seq_len += 1;

        let all_keys = Tensor::cat(&self.keys, 1)?;
        let all_values = Tensor::cat(&self.values, 1)?;

        Ok((all_keys, all_values))
    }

    /// Remove all cached entries.
    pub fn clear(&mut self) {
        self.keys.clear();
        self.values.clear();
        self.seq_len = 0;
    }

    /// Number of cached decode steps.
    pub fn len(&self) -> usize {
        self.seq_len
    }

    /// True if nothing has been cached yet.
    pub fn is_empty(&self) -> bool {
        self.seq_len == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_engine_creation() {
        let config = InferenceConfig::default().with_device(DeviceType::Cpu);
        let engine = InferenceEngine::new(config).unwrap();

        assert!(!engine.is_gpu());
    }

    #[test]
    fn test_linear_forward() {
        let config = InferenceConfig::default().with_device(DeviceType::Cpu);
        let engine = InferenceEngine::new(config).unwrap();

        let input = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], (1, 4), engine.device()).unwrap();
        let weight = Tensor::from_vec(vec![1.0f32; 8], (2, 4), engine.device()).unwrap();

        let output = engine.linear(&input, &weight, None).unwrap();

        assert_eq!(output.dims(), &[1, 2]);
    }
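
    // Illustrative check (not part of the original suite): with unit weight and
    // zero bias, `layer_norm` should produce rows with approximately zero mean,
    // which exercises the normalization step independently of the affine transform.
    #[test]
    fn test_layer_norm_zero_mean() {
        let config = InferenceConfig::default().with_device(DeviceType::Cpu);
        let engine = InferenceEngine::new(config).unwrap();

        let input = Tensor::from_vec(
            vec![1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
            (2, 4),
            engine.device(),
        )
        .unwrap();
        let weight = Tensor::ones(4, candle_core::DType::F32, engine.device()).unwrap();
        let bias = Tensor::zeros(4, candle_core::DType::F32, engine.device()).unwrap();

        let output = engine.layer_norm(&input, &weight, &bias, 1e-5).unwrap();

        let means: Vec<f32> = output.mean(1).unwrap().to_vec1().unwrap();
        for m in means {
            assert!(m.abs() < 1e-4);
        }
    }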

    #[test]
    fn test_ternary_layer() {
        let device = Device::Cpu;
        let weight = Tensor::randn(0f32, 1f32, (16, 32), &device).unwrap();

        let layer = TernaryLayer::from_tensor(&weight, None).unwrap();

        assert!(layer.compression_ratio() > 10.0);

        let input = Tensor::randn(0f32, 1f32, (1, 32), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.dims(), &[1, 16]);
    }

    #[test]
    fn test_kv_cache() {
        let mut cache = KVCache::new(4);

        assert!(cache.is_empty());

        let device = Device::Cpu;
        let k1 = Tensor::zeros((1, 1, 8), candle_core::DType::F32, &device).unwrap();
        let v1 = Tensor::zeros((1, 1, 8), candle_core::DType::F32, &device).unwrap();

        let (keys, values) = cache.update(k1, v1).unwrap();

        assert_eq!(cache.len(), 1);
        assert_eq!(keys.dim(1).unwrap(), 1);
        assert_eq!(values.dim(1).unwrap(), 1);
    }
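
    // Illustrative check (not part of the original suite) of the size arithmetic
    // in `memory_bytes`: a 16x32 layer with no bias packs 16 * 32 weights at
    // 2 bits each (128 bytes) plus one f32 scale per output row (64 bytes),
    // versus 2048 bytes of f32 weights.
    #[test]
    fn test_ternary_layer_memory_bytes() {
        let device = Device::Cpu;
        let weight = Tensor::randn(0f32, 1f32, (16, 32), &device).unwrap();
        let layer = TernaryLayer::from_tensor(&weight, None).unwrap();

        assert_eq!(layer.memory_bytes(), 128 + 64);
        assert_eq!(layer.original_memory_bytes(), 2048);
    }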

    #[test]
    fn test_batched_forward() {
        let config = InferenceConfig::default()
            .with_device(DeviceType::Cpu)
            .with_max_batch_size(2);
        let engine = InferenceEngine::new(config).unwrap();

        let input = Tensor::randn(0f32, 1f32, (5, 4), engine.device()).unwrap();

        let output = engine
            .batched_forward(&input, |x| Ok(x.clone()))
            .unwrap();

        assert_eq!(output.dims(), &[5, 4]);
    }
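
    // Illustrative check (not part of the original suite): the softmax wrapper
    // should produce rows that sum to 1 along the chosen dimension.
    #[test]
    fn test_softmax_rows_sum_to_one() {
        let config = InferenceConfig::default().with_device(DeviceType::Cpu);
        let engine = InferenceEngine::new(config).unwrap();

        let input = Tensor::from_vec(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            (2, 3),
            engine.device(),
        )
        .unwrap();
        let output = engine.softmax(&input, 1).unwrap();

        let sums: Vec<f32> = output.sum(1).unwrap().to_vec1().unwrap();
        for s in sums {
            assert!((s - 1.0).abs() < 1e-5);
        }
    }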
}