unsloth_rs/kernels/ternary/model.rs

// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! End-to-end model quantization for ternary inference.
//!
//! Provides utilities to convert pretrained FP models to ternary format
//! for memory-efficient inference.
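//!
//! A minimal end-to-end sketch (marked `ignore`: the layer name and tensor shape
//! are illustrative assumptions, not a doctest):
//!
//! ```ignore
//! use std::collections::HashMap;
//! use candle_core::{Device, Tensor};
//!
//! let device = Device::Cpu;
//! let mut weights = HashMap::new();
//! weights.insert(
//!     "model.layers.0.mlp.up_proj".to_string(),
//!     Tensor::randn(0.0f32, 1.0, (4096, 1024), &device)?,
//! );
//!
//! // Quantize every eligible linear layer; small layers and layers matching the
//! // default skip patterns ("embed", "norm", "lm_head") are preserved as-is.
//! let model = quantize_weights_collection(
//!     weights,
//!     HashMap::new(),
//!     ModelQuantizationConfig::default(),
//!     &device,
//! )?;
//! model.stats.print_summary();
//! ```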

use super::config::TernaryConfig;
use super::linear::TernaryLinear;
use super::quantize::quantize_tensor;
use crate::error::{Result, UnslothError};
use candle_core::{Device, Tensor};
use std::collections::HashMap;

/// Statistics from model quantization.
#[derive(Debug, Clone, Default)]
pub struct QuantizationStats {
    /// Number of layers quantized to ternary
    pub layers_quantized: usize,
    /// Number of layers skipped (non-linear or below threshold)
    pub layers_skipped: usize,
    /// Total parameters in original model (computed by `finalize_stats()`)
    pub original_params: usize,
    /// Total parameters in quantized model (as ternary)
    pub quantized_params: usize,
    /// Original model size in bytes (FP32, computed by `finalize_stats()`)
    pub original_bytes: usize,
    /// Total model size in bytes (includes both quantized and preserved layers)
    pub quantized_bytes: usize,
    /// Average sparsity across quantized layers
    pub average_sparsity: f32,
    /// Per-layer sparsity
    pub layer_sparsities: HashMap<String, f32>,
    /// Flag to track if `finalize_stats` has been called (for idempotency)
    finalized: bool,
}

impl QuantizationStats {
    /// Compression ratio (original / quantized).
    ///
    /// # Returns
    ///
    /// Returns 1.0 (no compression) if `quantized_bytes` is zero, indicating
    /// no quantization occurred. Otherwise returns `original_bytes` / `quantized_bytes`.
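    ///
    /// For example, 16 MB of FP32 weights reduced to roughly 4 MB of packed ternary
    /// plus preserved data would report a ratio of about 4.0 (illustrative numbers only).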
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        if self.quantized_bytes == 0 {
            1.0 // No compression if nothing was quantized
        } else {
            // Precision loss acceptable for compression ratio metric
            #[allow(clippy::cast_precision_loss)]
            {
                self.original_bytes as f32 / self.quantized_bytes as f32
            }
        }
    }

    /// Print summary statistics.
    pub fn print_summary(&self) {
        println!("=== Quantization Summary ===");
        println!("Layers quantized: {}", self.layers_quantized);
        println!("Layers skipped: {}", self.layers_skipped);
        println!("Original params: {}", self.original_params);
        println!("Quantized params: {}", self.quantized_params);
        println!(
            "Size reduction: {:.2}x ({:.2} MB -> {:.2} MB)",
            self.compression_ratio(),
            self.original_bytes as f64 / 1e6,
            self.quantized_bytes as f64 / 1e6
        );
        println!("Average sparsity: {:.1}%", self.average_sparsity * 100.0);
    }
}

/// Result of quantizing a single linear layer.
#[derive(Debug)]
pub struct QuantizedLayer {
    /// The ternary linear layer
    pub layer: TernaryLinear,
    /// Original layer name
    pub name: String,
    /// Sparsity of the quantized weights
    pub sparsity: f32,
}

/// Configuration for model quantization.
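///
/// # Example
///
/// A sketch of overriding the defaults (marked `ignore`; the chosen values are
/// purely illustrative):
///
/// ```ignore
/// let config = ModelQuantizationConfig {
///     min_layer_size: 4096, // skip anything smaller than 64x64
///     skip_patterns: vec!["embed".to_string(), "lm_head".to_string()],
///     verbose: true,
///     ..Default::default()
/// };
/// ```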
#[derive(Debug, Clone)]
pub struct ModelQuantizationConfig {
    /// Ternary quantization config
    pub ternary_config: TernaryConfig,
    /// Minimum layer size to quantize (skip small layers)
    pub min_layer_size: usize,
    /// Skip layers matching these patterns
    pub skip_patterns: Vec<String>,
    /// Verbose logging
    pub verbose: bool,
}

impl Default for ModelQuantizationConfig {
    fn default() -> Self {
        Self {
            ternary_config: TernaryConfig::default(),
            min_layer_size: 1024, // Skip very small layers
            skip_patterns: vec![
                "embed".to_string(),
                "norm".to_string(),
                "lm_head".to_string(),
            ],
            verbose: false,
        }
    }
}

/// Quantize a single linear layer's weights to ternary.
///
/// # Arguments
///
/// * `weight` - Weight tensor [`out_features`, `in_features`]
/// * `bias` - Optional bias tensor [`out_features`]
/// * `name` - Layer name for logging
/// * `config` - Quantization configuration
/// * `_device` - Target device (currently unused; weights remain on their original device).
///               Kept for future multi-device support and API stability.
///
/// # Returns
///
/// Quantized layer and statistics, or None if layer should be skipped
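///
/// # Example
///
/// A minimal single-layer sketch (marked `ignore`; the layer name and shape are
/// arbitrary assumptions):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
///
/// let device = Device::Cpu;
/// let weight = Tensor::randn(0.0f32, 1.0, (256, 512), &device)?;
/// let config = ModelQuantizationConfig::default();
///
/// // Returns Ok(None) for layers that are too small or match a skip pattern.
/// if let Some(quantized) = quantize_linear_layer(&weight, None, "mlp.down_proj", &config, &device)? {
///     println!("{} sparsity: {:.1}%", quantized.name, quantized.sparsity * 100.0);
/// }
/// ```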
pub fn quantize_linear_layer(
    weight: &Tensor,
    bias: Option<&Tensor>,
    name: &str,
    config: &ModelQuantizationConfig,
    _device: &Device,
) -> Result<Option<QuantizedLayer>> {
    let dims = weight.dims();
    if dims.len() != 2 {
        return Err(UnslothError::ShapeMismatch {
            // Expected a 2D tensor (rank 2) for [out_features, in_features]
            expected: vec![2],
            actual: dims.to_vec(),
        });
    }

    let (out_features, in_features) = (dims[0], dims[1]);
    let num_params = out_features * in_features;

    // Check if layer should be skipped
    if num_params < config.min_layer_size {
        if config.verbose {
            println!("Skipping {name} (too small: {num_params} params)");
        }
        return Ok(None);
    }

    for pattern in &config.skip_patterns {
        if name.to_lowercase().contains(&pattern.to_lowercase()) {
            if config.verbose {
                println!("Skipping {name} (matches pattern: {pattern})");
            }
            return Ok(None);
        }
    }

    // Quantize weights
    let (ternary_weights, _scale) = quantize_tensor(weight, &config.ternary_config)?;

    let sparsity = ternary_weights.sparsity();

    if config.verbose {
        println!(
            "Quantizing {}: [{}, {}] -> sparsity {:.1}%",
            name,
            out_features,
            in_features,
            sparsity * 100.0
        );
    }

    // Create ternary linear layer
    // Note: bias.cloned() is necessary because TernaryLinear::with_config expects Option<Tensor> (owned),
    // but we receive Option<&Tensor> (borrowed). This clone is intentional for API ergonomics.
    let layer = TernaryLinear::with_config(ternary_weights, bias.cloned(), config.ternary_config)?;

    Ok(Some(QuantizedLayer {
        layer,
        name: name.to_string(),
        sparsity,
    }))
}

/// Container for a quantized transformer model.
#[derive(Debug)]
pub struct TernaryModel {
    /// Quantized linear layers by name
    pub layers: HashMap<String, TernaryLinear>,
    /// Non-quantized layers/tensors preserved from original
    pub preserved_tensors: HashMap<String, Tensor>,
    /// Quantization statistics
    pub stats: QuantizationStats,
    /// Configuration used
    pub config: ModelQuantizationConfig,
}

impl TernaryModel {
    /// Create a new empty ternary model.
    #[must_use]
    pub fn new(config: ModelQuantizationConfig) -> Self {
        Self {
            layers: HashMap::new(),
            preserved_tensors: HashMap::new(),
            stats: QuantizationStats::default(),
            config,
        }
    }

    /// Add a quantized layer.
    pub fn add_layer(&mut self, name: String, layer: TernaryLinear, sparsity: f32) {
        let (out_features, in_features) = layer.dims();
        let num_params = out_features * in_features;

        self.stats.layers_quantized += 1;
        self.stats.quantized_params += num_params;
        // Note: original_params will be accumulated in finalize_stats()
        self.stats.quantized_bytes += layer.memory_bytes();
        self.stats.layer_sparsities.insert(name.clone(), sparsity);

        self.layers.insert(name, layer);
    }

    /// Add a preserved (non-quantized) tensor.
    pub fn add_preserved(&mut self, name: String, tensor: Tensor) {
        let num_params = tensor.elem_count();
        self.stats.layers_skipped += 1;
        self.stats.original_params += num_params;
        self.stats.quantized_bytes += num_params * 4; // Still FP32

        self.preserved_tensors.insert(name, tensor);
    }

    /// Finalize statistics after all layers have been added.
    ///
    /// This method is idempotent - calling it multiple times has no additional effect.
    /// It computes the total original parameters and bytes from quantized and preserved layers.
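    ///
    /// # Example
    ///
    /// A sketch of the manual-assembly path (marked `ignore`; note that
    /// `quantize_weights_collection` already calls this for you):
    ///
    /// ```ignore
    /// let mut model = TernaryModel::new(ModelQuantizationConfig::default());
    /// if let Some(q) = quantize_linear_layer(&weight, None, "proj", &model.config, &device)? {
    ///     model.add_layer(q.name, q.layer, q.sparsity);
    /// } else {
    ///     model.add_preserved("proj.weight".to_string(), weight);
    /// }
    /// model.finalize_stats(); // safe to call more than once
    /// println!("compression: {:.2}x", model.stats.compression_ratio());
    /// ```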
    pub fn finalize_stats(&mut self) {
        // Guard against multiple calls
        if self.stats.finalized {
            return;
        }

        // Total original params = quantized + preserved
        self.stats.original_params += self.stats.quantized_params;
        // Total original bytes (FP32) = all params * 4 bytes
        self.stats.original_bytes = self.stats.original_params * 4;

        if !self.stats.layer_sparsities.is_empty() {
            self.stats.average_sparsity = self.stats.layer_sparsities.values().sum::<f32>()
                / self.stats.layer_sparsities.len() as f32;
        }

        self.stats.finalized = true;
    }

    /// Get a quantized layer by name.
    #[must_use]
    pub fn get_layer(&self, name: &str) -> Option<&TernaryLinear> {
        self.layers.get(name)
    }

    /// Get a preserved tensor by name.
    #[must_use]
    pub fn get_preserved(&self, name: &str) -> Option<&Tensor> {
        self.preserved_tensors.get(name)
    }
}

/// Quantize a collection of weight tensors into a `TernaryModel`.
///
/// # Arguments
///
/// * `weights` - Map of layer names to weight tensors
/// * `biases` - Map of layer names to bias tensors (layers without a bias may be omitted)
/// * `config` - Quantization configuration
/// * `device` - Target device
///
/// # Returns
///
/// Quantized model with statistics
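///
/// # Example
///
/// A sketch showing how skipped layers are preserved (marked `ignore`; the names
/// and shapes are illustrative):
///
/// ```ignore
/// let mut weights = HashMap::new();
/// weights.insert("q_proj".to_string(), Tensor::randn(0.0f32, 1.0, (512, 512), &device)?);
/// weights.insert("norm".to_string(), Tensor::randn(0.0f32, 1.0, (64, 64), &device)?);
///
/// let model = quantize_weights_collection(weights, HashMap::new(), ModelQuantizationConfig::default(), &device)?;
///
/// assert!(model.get_layer("q_proj").is_some());          // quantized in place
/// assert!(model.get_preserved("norm.weight").is_some()); // skipped, stored as "{name}.weight"
/// ```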
pub fn quantize_weights_collection(
    weights: HashMap<String, Tensor>,
    biases: HashMap<String, Tensor>,
    config: ModelQuantizationConfig,
    device: &Device,
) -> Result<TernaryModel> {
    let mut model = TernaryModel::new(config);

    for (name, weight) in weights {
        let bias = biases.get(&name);

        if let Some(quantized) = quantize_linear_layer(&weight, bias, &name, &model.config, device)?
        {
            model.add_layer(quantized.name, quantized.layer, quantized.sparsity);
        } else {
            // Preserve the original weight with consistent naming: {name}.weight
            model.add_preserved(format!("{name}.weight"), weight);
            if let Some(b) = bias {
                model.add_preserved(format!("{name}.bias"), b.clone());
            }
        }
    }

    model.finalize_stats();

    if model.config.verbose {
        model.stats.print_summary();
    }

    Ok(model)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_stats() {
        let stats = QuantizationStats {
            original_bytes: 1000,
            quantized_bytes: 100,
            ..Default::default()
        };

        assert!((stats.compression_ratio() - 10.0).abs() < 0.001);
    }

    #[test]
    fn test_model_quantization_config_default() {
        let config = ModelQuantizationConfig::default();
        assert_eq!(config.min_layer_size, 1024);
        assert!(config.skip_patterns.contains(&"embed".to_string()));
    }

    #[test]
    fn test_quantize_linear_layer() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 0, // Don't skip
            skip_patterns: vec![],
            ..Default::default()
        };

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device)?;

        let result = quantize_linear_layer(&weight, None, "test_layer", &config, &device)?;

        assert!(result.is_some());
        let quantized = result.unwrap();
        assert_eq!(quantized.name, "test_layer");
        assert!(quantized.sparsity >= 0.0 && quantized.sparsity <= 1.0);

        Ok(())
    }

    #[test]
    fn test_skip_small_layer() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 10000, // Large threshold
            ..Default::default()
        };

        let weight = Tensor::randn(0.0f32, 1.0, (8, 8), &device)?;

        let result = quantize_linear_layer(&weight, None, "small_layer", &config, &device)?;

        assert!(result.is_none());

        Ok(())
    }

    #[test]
    fn test_skip_pattern() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (128, 128), &device)?;

        let result = quantize_linear_layer(&weight, None, "model.embed_tokens", &config, &device)?;

        assert!(result.is_none()); // Should skip due to "embed" pattern

        Ok(())
    }

    #[test]
    fn test_ternary_model() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 0,
            skip_patterns: vec![],
            verbose: false,
            ..Default::default()
        };

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), &device)?,
        );
        weights.insert(
            "layer2".to_string(),
            Tensor::randn(0.0f32, 1.0, (128, 64), &device)?,
        );

        let model = quantize_weights_collection(weights, HashMap::new(), config, &device)?;

        assert_eq!(model.stats.layers_quantized, 2);
        assert!(model.get_layer("layer1").is_some());
        assert!(model.get_layer("layer2").is_some());

        // Verify accounting: layer1 = 64*128=8192, layer2 = 128*64=8192, total = 16384
        let expected_params = 64 * 128 + 128 * 64;
        assert_eq!(model.stats.original_params, expected_params);
        assert_eq!(model.stats.quantized_params, expected_params);
        assert_eq!(model.stats.original_bytes, expected_params * 4); // FP32

        Ok(())
    }

    #[test]
    fn test_accounting_with_preserved() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 10000, // Skip small layers
            skip_patterns: vec![],
            verbose: false,
            ..Default::default()
        };

        let mut weights = HashMap::new();
        // Large layer - will be quantized
        weights.insert(
            "large".to_string(),
            Tensor::randn(0.0f32, 1.0, (256, 256), &device)?,
        );
        // Small layer - will be preserved
        weights.insert(
            "small".to_string(),
            Tensor::randn(0.0f32, 1.0, (8, 8), &device)?,
        );

        let model = quantize_weights_collection(weights, HashMap::new(), config, &device)?;

        assert_eq!(model.stats.layers_quantized, 1);
        assert_eq!(model.stats.layers_skipped, 1);

        // Verify accounting
        let large_params = 256 * 256; // 65536
        let small_params = 8 * 8; // 64
        let total_params = large_params + small_params;

        assert_eq!(model.stats.quantized_params, large_params);
        assert_eq!(model.stats.original_params, total_params);
        assert_eq!(model.stats.original_bytes, total_params * 4); // FP32

        Ok(())
    }

    #[test]
    fn test_finalize_stats_idempotent() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 0,
            skip_patterns: vec![],
            verbose: false,
            ..Default::default()
        };

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), &device)?,
        );

        let mut model = quantize_weights_collection(weights, HashMap::new(), config, &device)?;

        // Store initial values
        let initial_original_params = model.stats.original_params;
        let initial_original_bytes = model.stats.original_bytes;

        // Call finalize_stats again
        model.finalize_stats();

        // Values should not change (idempotent)
        assert_eq!(model.stats.original_params, initial_original_params);
        assert_eq!(model.stats.original_bytes, initial_original_bytes);

        // Call a third time to verify
        model.finalize_stats();
        assert_eq!(model.stats.original_params, initial_original_params);

        Ok(())
    }

    #[test]
    fn test_compression_ratio_no_quantization() {
        let stats = QuantizationStats::default();
        // No quantization - should return 1.0 (no compression)
        assert!((stats.compression_ratio() - 1.0).abs() < 0.001);
    }
}