// unsloth_rs/kernels/ternary/quantize.rs

// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! FP → Ternary quantization with scale calibration.
//!
//! This module implements TWN-style (Ternary Weight Networks) quantization
//! to convert floating-point weights to ternary {-1, 0, +1} representation.
//!
//! ## Quantization Formula
//!
//! For a weight tensor W with threshold Δ:
//!
//! ```text
//! W_ternary[i] = +1  if W[i] > Δ
//!              =  0  if |W[i]| ≤ Δ
//!              = -1  if W[i] < -Δ
//!
//! scale = mean(|W[i]|) for i where |W[i]| > Δ
//! ```
//!
//! ## Calibration Methods
//!
//! - **`AbsMax`**: Δ = α × max(|W|), where α ∈ [0, 1] (typically 0.7)
//! - **Percentile**: Δ = percentile(|W|, p), e.g., p=99.5
//! - **`MeanStd`**: Δ = mean(|W|) + k × std(|W|)
//!
//! ## References
//!
//! - Li et al., "Ternary Weight Networks" (2016)
//! - Mellempudi et al., "Ternary Neural Networks with Fine-Grained Quantization" (2017)

use candle_core::{DType, Tensor};

use super::config::{CalibrationMethodConfig, TernaryConfig};
use super::types::TernaryTensor;
use crate::error::{Result, UnslothError};

/// Strategy for choosing the ternary quantization threshold Δ.
#[derive(Debug, Clone, Copy)]
pub enum CalibrationMethod {
    /// Δ = `factor` × max(|W|); the TWN paper uses a factor of 0.7.
    AbsMax {
        /// Scaling factor applied to max absolute value (typically 0.7).
        factor: f32,
    },

    /// Δ = the given percentile of |W| (e.g. 99.5 to clip outliers).
    Percentile {
        /// Percentile value (0-100) for threshold selection.
        percentile: f32,
    },

    /// Δ = mean(|W|) + `k` × std(|W|); `k` is typically 1.0-2.0.
    MeanStd {
        /// Standard deviation multiplier.
        k: f32,
    },

    /// Δ is a caller-supplied constant.
    Manual {
        /// Fixed threshold for ternary quantization.
        threshold: f32,
    },
}
67
68impl Default for CalibrationMethod {
69    fn default() -> Self {
70        Self::AbsMax { factor: 0.7 }
71    }
72}
73
74impl From<CalibrationMethodConfig> for CalibrationMethod {
75    fn from(config: CalibrationMethodConfig) -> Self {
76        match config {
77            CalibrationMethodConfig::AbsMax => Self::AbsMax { factor: 0.7 },
78            CalibrationMethodConfig::Percentile(p) => Self::Percentile { percentile: p },
79            CalibrationMethodConfig::MeanStd(k) => Self::MeanStd { k },
80            CalibrationMethodConfig::Manual(t) => Self::Manual { threshold: t },
81        }
82    }
83}
84
/// Summary statistics gathered while quantizing a tensor.
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Fraction of weights quantized to 0.
    pub sparsity: f32,

    /// Fraction of weights quantized to +1.
    pub positive_ratio: f32,

    /// Fraction of weights quantized to -1.
    pub negative_ratio: f32,

    /// Threshold Δ chosen for each output channel (row).
    pub thresholds: Vec<f32>,

    /// Scale factor computed for each output channel (row).
    pub scales: Vec<f32>,

    /// Mean absolute reconstruction error over all elements.
    pub mean_error: f32,

    /// Largest absolute reconstruction error over all elements.
    pub max_error: f32,
}
109
110/// Quantize a 2D tensor to ternary representation.
111///
112/// # Arguments
113///
114/// * `tensor` - Input tensor [`out_features`, `in_features`] (must be 2D, f32)
115/// * `config` - Ternary configuration with calibration settings
116///
117/// # Returns
118///
119/// Tuple of (`TernaryTensor`, `QuantizationStats`)
120///
121/// # Errors
122///
123/// Returns error if tensor is not 2D or not f32.
124///
125/// # Example
126///
127/// ```rust,ignore
128/// use unsloth_rs::kernels::ternary::{quantize_tensor, TernaryConfig};
129///
130/// let weights = Tensor::randn(0.0f32, 1.0, (1024, 4096), &Device::Cpu)?;
131/// let (ternary, stats) = quantize_tensor(&weights, &TernaryConfig::default())?;
132///
133/// println!("Sparsity: {:.1}%", stats.sparsity * 100.0);
134/// println!("Compression: {:.1}x", ternary.compression_ratio());
135/// ```
136pub fn quantize_tensor(
137    tensor: &Tensor,
138    config: &TernaryConfig,
139) -> Result<(TernaryTensor, QuantizationStats)> {
140    // Validate input
141    let shape = tensor.shape();
142    if shape.dims().len() != 2 {
143        return Err(UnslothError::ShapeMismatch {
144            // Expected a 2D tensor (rank 2) for weight matrix
145            expected: vec![2],
146            actual: shape.dims().to_vec(),
147        });
148    }
149
150    if tensor.dtype() != DType::F32 {
151        return Err(UnslothError::InvalidConfig(format!(
152            "quantize_tensor requires f32, got {:?}",
153            tensor.dtype()
154        )));
155    }
156
157    let (out_features, in_features) = (shape.dims()[0], shape.dims()[1]);
158
159    // Get data as f32 vec (move to CPU if needed)
160    let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
161
162    // Determine calibration method
163    let calibration = CalibrationMethod::from(config.calibration_method);
164
165    // Quantize per row (per output channel)
166    let k_words = in_features.div_ceil(32);
167    let mut plus_plane = vec![0u32; out_features * k_words];
168    let mut minus_plane = vec![0u32; out_features * k_words];
169    let mut scales = vec![0.0f32; out_features];
170    let mut thresholds = vec![0.0f32; out_features];
171
172    let mut total_positive = 0usize;
173    let mut total_negative = 0usize;
174    let mut total_zero = 0usize;
175    let mut total_error = 0.0f64;
176    let mut max_error = 0.0f32;
177
178    for row in 0..out_features {
179        let row_start = row * in_features;
180        let row_data = &data[row_start..row_start + in_features];
181
182        // Compute threshold for this row
183        let threshold = compute_threshold(row_data, calibration);
184        thresholds[row] = threshold;
185
186        // Quantize and compute scale
187        let (row_plus, row_minus, scale, pos, neg, zero) =
188            quantize_row(row_data, threshold, k_words);
189
190        // Copy to output planes
191        let plane_offset = row * k_words;
192        plus_plane[plane_offset..plane_offset + k_words].copy_from_slice(&row_plus);
193        minus_plane[plane_offset..plane_offset + k_words].copy_from_slice(&row_minus);
194        scales[row] = scale;
195
196        total_positive += pos;
197        total_negative += neg;
198        total_zero += zero;
199
200        // Compute reconstruction error
201        for (i, &val) in row_data.iter().enumerate() {
202            let word_idx = i / 32;
203            let bit_idx = i % 32;
204            let mask = 1u32 << bit_idx;
205
206            let is_plus = (row_plus[word_idx] & mask) != 0;
207            let is_minus = (row_minus[word_idx] & mask) != 0;
208
209            let reconstructed = if is_plus {
210                scale
211            } else if is_minus {
212                -scale
213            } else {
214                0.0
215            };
216
217            let error = (val - reconstructed).abs();
218            total_error += f64::from(error);
219            max_error = max_error.max(error);
220        }
221    }
222
223    let total_elements = out_features * in_features;
224    #[allow(clippy::cast_precision_loss)] // Sparsity calculations for statistics only
225    let stats = QuantizationStats {
226        sparsity: total_zero as f32 / total_elements as f32,
227        positive_ratio: total_positive as f32 / total_elements as f32,
228        negative_ratio: total_negative as f32 / total_elements as f32,
229        thresholds,
230        scales: scales.clone(),
231        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] // Error statistics approximation
232        mean_error: (total_error / total_elements as f64) as f32,
233        max_error,
234    };
235
236    let ternary = TernaryTensor::new(plus_plane, minus_plane, scales, (out_features, in_features));
237
238    Ok((ternary, stats))
239}
240
241/// Compute quantization threshold using specified calibration method.
242fn compute_threshold(data: &[f32], method: CalibrationMethod) -> f32 {
243    match method {
244        CalibrationMethod::AbsMax { factor } => {
245            let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
246            factor * max_abs
247        }
248
249        CalibrationMethod::Percentile { percentile } => {
250            let mut abs_values: Vec<f32> = data.iter().map(|x| x.abs()).collect();
251            abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
252
253            #[allow(
254                clippy::cast_possible_truncation,
255                clippy::cast_sign_loss,
256                clippy::cast_precision_loss
257            )]
258            // Percentile calculation: precision loss acceptable for threshold approximation
259            let idx = ((percentile / 100.0) * (abs_values.len() - 1) as f32) as usize;
260            abs_values[idx.min(abs_values.len() - 1)]
261        }
262
263        CalibrationMethod::MeanStd { k } => {
264            #[allow(clippy::cast_precision_loss)]
265            // Precision loss acceptable for statistical calculations
266            let n = data.len() as f64;
267            let abs_values: Vec<f64> = data.iter().map(|x| f64::from(x.abs())).collect();
268
269            let mean = abs_values.iter().sum::<f64>() / n;
270            let variance = abs_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
271            let std = variance.sqrt();
272
273            // Truncation acceptable for threshold calculation
274            #[allow(clippy::cast_possible_truncation)]
275            let threshold_value = (mean + f64::from(k) * std) as f32;
276            threshold_value
277        }
278
279        CalibrationMethod::Manual { threshold } => threshold,
280    }
281}
282
/// Quantize a single row to ternary planes.
///
/// Bit `i % 32` of word `i / 32` is set in the plus plane for weights above
/// `threshold`, and in the minus plane for weights below `-threshold`.
///
/// Returns (`plus_plane`, `minus_plane`, `scale`, `positive_count`, `negative_count`, `zero_count`)
fn quantize_row(
    data: &[f32],
    threshold: f32,
    k_words: usize,
) -> (Vec<u32>, Vec<u32>, f32, usize, usize, usize) {
    let mut plus_bits = vec![0u32; k_words];
    let mut minus_bits = vec![0u32; k_words];

    // Running sums of |w| over the non-zero codes feed the scale below.
    let mut pos_sum = 0.0f64;
    let mut neg_sum = 0.0f64;
    let mut n_pos = 0usize;
    let mut n_neg = 0usize;
    let mut n_zero = 0usize;

    for (idx, &w) in data.iter().enumerate() {
        let word = idx / 32;
        let bit = 1u32 << (idx % 32);

        if w > threshold {
            plus_bits[word] |= bit;
            pos_sum += f64::from(w.abs());
            n_pos += 1;
        } else if w < -threshold {
            minus_bits[word] |= bit;
            neg_sum += f64::from(w.abs());
            n_neg += 1;
        } else {
            n_zero += 1;
        }
    }

    // Scale = mean |w| over the non-zero codes; 1.0 for an all-zero row.
    let nonzero = n_pos + n_neg;
    // Truncation/precision loss acceptable for scale approximation
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    let scale = if nonzero == 0 {
        1.0
    } else {
        ((pos_sum + neg_sum) / nonzero as f64) as f32
    };

    (plus_bits, minus_bits, scale, n_pos, n_neg, n_zero)
}
338
339/// Dequantize a ternary tensor back to f32 (for validation).
340///
341/// # Arguments
342///
343/// * `ternary` - Ternary tensor to dequantize
344///
345/// # Returns
346///
347/// Candle tensor [`out_features`, `in_features`] with reconstructed f32 values.
348///
349/// # Errors
350///
351/// Returns error if tensor creation fails.
352pub fn dequantize_tensor(ternary: &TernaryTensor) -> Result<Tensor> {
353    let (out_features, in_features) = ternary.dims();
354    let mut data = vec![0.0f32; out_features * in_features];
355
356    for row in 0..out_features {
357        let scale = ternary.scales[row];
358        let planes = ternary.get_row_planes(row);
359
360        for col in 0..in_features {
361            let val = planes.get(col);
362            data[row * in_features + col] = f32::from(val) * scale;
363        }
364    }
365
366    let tensor = Tensor::from_vec(data, (out_features, in_features), &candle_core::Device::Cpu)?;
367    Ok(tensor)
368}
369
370/// Quantize weights directly from Candle Linear layer.
371///
372/// # Arguments
373///
374/// * `weights` - Weight tensor from `nn::Linear` [`out_features`, `in_features`]
375/// * `config` - Ternary configuration
376///
377/// # Returns
378///
379/// Ternary tensor ready for use in `TernaryLinear`.
380///
381/// # Errors
382///
383/// Returns an error if the weight tensor cannot be accessed, has invalid dimensions,
384/// or if quantization fails due to numerical issues.
385pub fn quantize_linear_weights(weights: &Tensor, config: &TernaryConfig) -> Result<TernaryTensor> {
386    let (ternary, _stats) = quantize_tensor(weights, config)?;
387    Ok(ternary)
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_quantize_simple() -> Result<()> {
        // Two rows of hand-picked weights straddling the manual threshold 0.3.
        let values: Vec<f32> = vec![
            0.5, -0.5, 0.1, -0.1, 0.8, -0.8, 0.0, 0.3, // Row 0
            1.0, -1.0, 0.2, -0.2, 0.0, 0.0, 0.9, -0.9, // Row 1
        ];
        let tensor = Tensor::from_vec(values, (2, 8), &Device::Cpu)?;

        let config = TernaryConfig {
            calibration_method: CalibrationMethodConfig::Manual(0.3),
            ..Default::default()
        };

        let (ternary, stats) = quantize_tensor(&tensor, &config)?;

        assert_eq!(ternary.dims(), (2, 8));
        // All three ternary codes should appear for this input.
        assert!(stats.sparsity > 0.0);
        assert!(stats.positive_ratio > 0.0);
        assert!(stats.negative_ratio > 0.0);

        Ok(())
    }

    #[test]
    fn test_quantize_dequantize_roundtrip() -> Result<()> {
        // Linear ramp covering [-1, 1); quantization is lossy but should track it.
        #[allow(clippy::cast_precision_loss)] // test-data generation only
        let original: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
        let tensor = Tensor::from_vec(original.clone(), (4, 64), &Device::Cpu)?;

        let config = TernaryConfig::default();
        let (ternary, _stats) = quantize_tensor(&tensor, &config)?;

        let reconstructed = dequantize_tensor(&ternary)?;
        let round_trip: Vec<f32> = reconstructed.flatten_all()?.to_vec1()?;

        // Mean squared error — exact recovery is impossible with 3 levels.
        #[allow(clippy::cast_precision_loss)] // test metric only
        let mse = original
            .iter()
            .zip(round_trip.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / original.len() as f32;

        // A well-calibrated ternary code keeps MSE comfortably below 0.5.
        assert!(mse < 0.5, "MSE too high: {mse}");

        Ok(())
    }

    #[test]
    fn test_calibration_methods() {
        let samples: Vec<f32> = vec![0.1, 0.5, 1.0, -0.3, -0.8, 2.0, -1.5, 0.0];

        // AbsMax: factor * max(|x|) = 0.7 * 2.0 = 1.4
        let abs_max = compute_threshold(&samples, CalibrationMethod::AbsMax { factor: 0.7 });
        assert!((abs_max - 1.4).abs() < 0.01);

        // Manual passes the fixed value straight through.
        let manual = compute_threshold(&samples, CalibrationMethod::Manual { threshold: 0.5 });
        assert!((manual - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_sparsity_detection() -> Result<()> {
        // 1000 weights, 100 of them non-zero (alternating sign) → 90% sparse.
        let mut values = vec![0.0f32; 1000];
        for i in 0..100 {
            values[i * 10] = if i % 2 == 0 { 1.0 } else { -1.0 };
        }
        let tensor = Tensor::from_vec(values, (10, 100), &Device::Cpu)?;

        let config = TernaryConfig {
            calibration_method: CalibrationMethodConfig::Manual(0.1),
            ..Default::default()
        };

        let (ternary, stats) = quantize_tensor(&tensor, &config)?;

        // Quantization must preserve the ~90% sparsity of the input.
        assert!(stats.sparsity > 0.85, "Sparsity: {}", stats.sparsity);
        assert!(ternary.sparsity() > 0.85);

        Ok(())
    }

    #[test]
    fn test_compression_ratio() -> Result<()> {
        let values = vec![0.0f32; 4096 * 4096];
        let tensor = Tensor::from_vec(values, (4096, 4096), &Device::Cpu)?;

        let config = TernaryConfig::default();
        let (ternary, _) = quantize_tensor(&tensor, &config)?;

        // Two bitplanes vs. 32-bit floats → roughly 16x compression.
        let ratio = ternary.compression_ratio();
        assert!(ratio > 10.0, "Compression ratio too low: {ratio}");

        Ok(())
    }
}