//! Quantization operations for weight and activation compression.
//!
//! Wraps `bitnet-quantize` for ternary quantization with various scaling methods.
//!
//! # Quantization Methods
//!
//! - **AbsMean**: scale = mean(|W|), round W/scale to {-1, 0, +1} (the BitNet b1.58 method)
//! - **AbsMax**: scale = max(|W|); the scale tracks the largest magnitude in the
//!   group, making it more robust to outliers
//!
//! # Example
//!
//! ```rust,ignore
//! use tritter_accel::core::quantization::{quantize_absmean, QuantizeConfig};
//! use candle_core::{Device, Tensor};
//!
//! let weights = Tensor::randn(0f32, 1f32, (512, 512), &Device::Cpu)?;
//! let config = QuantizeConfig::default();
//! let result = quantize_absmean(&weights, &config)?;
//! println!("Quantized to {} groups", result.scales.len());
//! ```

use bitnet_quantize::{quantize_weights, BitNetConfig, TernaryWeight};
use candle_core::{Device, Tensor};
use thiserror::Error;

use super::ternary::PackedTernary;

/// Errors from quantization operations.
#[derive(Debug, Error)]
pub enum QuantizationError {
    /// Invalid configuration.
    #[error("invalid config: {0}")]
    InvalidConfig(String),

    /// Tensor operation failed.
    #[error("tensor error: {0}")]
    Tensor(#[from] candle_core::Error),

    /// BitNet quantization failed.
    #[error("quantization failed: {0}")]
    Quantize(#[from] bitnet_quantize::BitNetError),
}

/// Configuration for quantization operations.
#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    /// Group size for block-wise quantization.
    /// Smaller groups = more scales = higher accuracy.
    /// 0 = one scale per row; otherwise one scale per `group_size` elements.
    pub group_size: usize,

    /// Whether to use symmetric quantization.
    pub symmetric: bool,
}

impl Default for QuantizeConfig {
    fn default() -> Self {
        Self {
            group_size: 0, // Per-row by default
            symmetric: true,
        }
    }
}

impl QuantizeConfig {
    /// Set group size.
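    ///
    /// A minimal usage sketch (`rust,ignore`, matching the module example; illustrative only):
    ///
    /// ```rust,ignore
    /// // One scale per 128-element block instead of one per row.
    /// let config = QuantizeConfig::default().with_group_size(128);
    /// ```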
    pub fn with_group_size(mut self, size: usize) -> Self {
        self.group_size = size;
        self
    }
}

/// Result of quantization operation.
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Quantized ternary values as i8 (-1, 0, +1).
    pub values: Vec<i8>,
    /// Scale factors, laid out row-major: `rows * groups_per_row` entries
    /// (one per row when `group_size` is 0).
    pub scales: Vec<f32>,
    /// Original tensor shape.
    pub shape: (usize, usize),
    /// Group size used.
    pub group_size: usize,
}

impl QuantizationResult {
    /// Convert to PackedTernary for efficient storage and matmul.
    ///
    /// `PackedTernary` stores one scale per row, so group scales are
    /// averaged per row; this loses per-group precision.
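    ///
    /// # Example
    ///
    /// An illustrative sketch (`rust,ignore`, not compiled as a doctest;
    /// assumes `weights` is a 2D f32 tensor):
    ///
    /// ```rust,ignore
    /// let result = quantize_absmean(&weights, &QuantizeConfig::default())?;
    /// let packed = result.to_packed()?;
    /// ```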
    pub fn to_packed(&self) -> Result<PackedTernary, super::ternary::TernaryError> {
        // For per-row quantization, scales already align with rows.
        // For group quantization, collapse the group scales to one per row.
        let (rows, cols) = self.shape;

        if self.group_size == 0 || self.group_size >= cols {
            // Per-row quantization
            PackedTernary::from_i8(&self.values, &self.scales, self.shape)
        } else {
            // Group quantization - PackedTernary needs one scale per row,
            // so average the group scales for each row
            let groups_per_row = cols.div_ceil(self.group_size);
            let mut row_scales = Vec::with_capacity(rows);

            for row in 0..rows {
                let start = row * groups_per_row;
                let end = (start + groups_per_row).min(self.scales.len());
                let avg: f32 = self.scales[start..end].iter().sum::<f32>() / (end - start) as f32;
                row_scales.push(avg);
            }

            PackedTernary::from_i8(&self.values, &row_scales, self.shape)
        }
    }

    /// Get values as f32 tensor with scales applied.
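    ///
    /// # Example
    ///
    /// An illustrative sketch (`rust,ignore`; `result` is a
    /// `QuantizationResult` from one of the quantize functions):
    ///
    /// ```rust,ignore
    /// let dense = result.to_tensor(&Device::Cpu)?; // f32 tensor, original shape
    /// ```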
    pub fn to_tensor(&self, device: &Device) -> Result<Tensor, QuantizationError> {
        let (rows, cols) = self.shape;
        let mut output = vec![0.0f32; rows * cols];

        if self.group_size == 0 || self.group_size >= cols {
            // Per-row scaling
            for row in 0..rows {
                let scale = self.scales[row];
                for col in 0..cols {
                    let idx = row * cols + col;
                    output[idx] = f32::from(self.values[idx]) * scale;
                }
            }
        } else {
            // Per-group scaling
            let groups_per_row = cols.div_ceil(self.group_size);
            for row in 0..rows {
                for col in 0..cols {
                    let group = col / self.group_size;
                    let scale_idx = row * groups_per_row + group;
                    let idx = row * cols + col;
                    output[idx] = f32::from(self.values[idx]) * self.scales[scale_idx];
                }
            }
        }

        Ok(Tensor::from_vec(output, (rows, cols), device)?)
    }
}

/// Quantize weights using AbsMean scaling (BitNet b1.58 method).
///
/// For each group: scale = mean(|W|), then round W/scale to {-1, 0, +1}.
///
/// # Arguments
///
/// * `weights` - 2D weight tensor
/// * `config` - Quantization configuration
///
/// # Returns
///
/// Quantized weights with scales.
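///
/// # Example
///
/// An illustrative sketch (`rust,ignore`, not compiled as a doctest):
///
/// ```rust,ignore
/// let weights = Tensor::randn(0f32, 1f32, (256, 256), &Device::Cpu)?;
/// let result = quantize_absmean(&weights, &QuantizeConfig::default())?;
/// assert_eq!(result.scales.len(), 256); // group_size 0 => one scale per row
/// ```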
pub fn quantize_absmean(
    weights: &Tensor,
    config: &QuantizeConfig,
) -> Result<QuantizationResult, QuantizationError> {
    let (rows, cols) = weights.dims2()?;

    // Use row-wise quantization if group_size is 0
    let effective_group_size = if config.group_size == 0 {
        cols
    } else {
        config.group_size
    };

    let bitnet_config = BitNetConfig::default().with_group_size(effective_group_size);

    let ternary: TernaryWeight = quantize_weights(weights, &bitnet_config)?;

    // Unpack ternary values; `ternary.data` holds one packed entry per row
    let mut values = Vec::with_capacity(rows * cols);
    for packed in &ternary.data {
        for col in 0..cols {
            values.push(packed.get(col).value());
        }
    }

    Ok(QuantizationResult {
        values,
        scales: ternary.scales,
        shape: (rows, cols),
        group_size: effective_group_size,
    })
}

/// Quantize weights using AbsMax scaling.
///
/// For each group: scale = max(|W|), then round W/scale to {-1, 0, +1}.
/// More robust to outliers than AbsMean, since the scale tracks the
/// largest magnitude in the group.
///
/// # Arguments
///
/// * `weights` - 2D weight tensor
/// * `config` - Quantization configuration
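///
/// # Example
///
/// An illustrative sketch (`rust,ignore`, not compiled as a doctest):
///
/// ```rust,ignore
/// let weights = Tensor::randn(0f32, 1f32, (128, 256), &Device::Cpu)?;
/// let config = QuantizeConfig::default().with_group_size(64);
/// let result = quantize_absmax(&weights, &config)?;
/// assert_eq!(result.scales.len(), 128 * (256 / 64)); // rows * groups_per_row
/// ```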
pub fn quantize_absmax(
    weights: &Tensor,
    config: &QuantizeConfig,
) -> Result<QuantizationResult, QuantizationError> {
    let (rows, cols) = weights.dims2()?;
    let data: Vec<f32> = weights.flatten_all()?.to_vec1()?;

    let effective_group_size = if config.group_size == 0 {
        cols
    } else {
        config.group_size
    };

    let groups_per_row = cols.div_ceil(effective_group_size);
    let mut scales = Vec::with_capacity(rows * groups_per_row);
    let mut values = Vec::with_capacity(rows * cols);

    for row in 0..rows {
        for group in 0..groups_per_row {
            let start = group * effective_group_size;
            let end = (start + effective_group_size).min(cols);

            // Find max absolute value in group
            let mut max_abs = 0.0f32;
            for col in start..end {
                let val = data[row * cols + col].abs();
                if val > max_abs {
                    max_abs = val;
                }
            }

            // Avoid division by zero
            let scale = if max_abs > 1e-10 { max_abs } else { 1.0 };
            scales.push(scale);

            // Quantize this group: round to the nearest ternary value,
            // i.e. |val/scale| <= 0.5 -> 0, otherwise sign(val)
            for col in start..end {
                let val = data[row * cols + col];
                let normalized = val / scale;
                let quantized = if normalized > 0.5 {
                    1i8
                } else if normalized < -0.5 {
                    -1i8
                } else {
                    0i8
                };
                values.push(quantized);
            }
        }
    }

    Ok(QuantizationResult {
        values,
        scales,
        shape: (rows, cols),
        group_size: effective_group_size,
    })
}

/// Quantize activations for inference.
///
/// Uses AbsMax per-tensor scaling to preserve dynamic range.
/// Returns the scaled tensor together with the scale factor needed to undo it.
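///
/// # Example
///
/// An illustrative sketch (`rust,ignore`, not compiled as a doctest):
///
/// ```rust,ignore
/// let x = Tensor::randn(0f32, 1f32, (1, 512), &Device::Cpu)?;
/// let (scaled, scale) = quantize_activations(&x)?;
/// // `scaled` lies in [-1, 1]; multiply by `scale` to recover the input.
/// ```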
pub fn quantize_activations(activations: &Tensor) -> Result<(Tensor, f32), QuantizationError> {
    let data: Vec<f32> = activations.flatten_all()?.to_vec1()?;

    // Find max absolute value
    let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
    let scale = if max_abs > 1e-10 { max_abs } else { 1.0 };

    // Scale to [-1, 1] range
    let scaled: Vec<f32> = data.iter().map(|x| x / scale).collect();

    Ok((
        Tensor::from_vec(scaled, activations.shape(), activations.device())?,
        scale,
    ))
}

/// Dequantize ternary values back to float.
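///
/// # Example
///
/// An illustrative sketch (`rust,ignore`; `result` is a
/// `QuantizationResult` from one of the quantize functions):
///
/// ```rust,ignore
/// let restored = dequantize(&result, &Device::Cpu)?; // same shape as the original weights
/// ```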
pub fn dequantize(result: &QuantizationResult, device: &Device) -> Result<Tensor, QuantizationError> {
    result.to_tensor(device)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_absmean() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmean(&weights, &config).unwrap();

        assert_eq!(result.shape, (2, 4));
        assert_eq!(result.values.len(), 8);
        assert_eq!(result.scales.len(), 2); // Per-row

        // All values should be in {-1, 0, +1}
        for v in &result.values {
            assert!([-1, 0, 1].contains(v));
        }
    }

    #[test]
    fn test_quantize_absmax() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmax(&weights, &config).unwrap();

        assert_eq!(result.shape, (2, 4));
        assert_eq!(result.values.len(), 8);

        for v in &result.values {
            assert!([-1, 0, 1].contains(v));
        }
    }

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.8f32, -0.8, 0.0, 0.8, -0.8, 0.8, -0.8, 0.0],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmean(&weights, &config).unwrap();
        let dequantized = dequantize(&result, &device).unwrap();

        // Check shape preserved
        assert_eq!(dequantized.dims(), &[2, 4]);

        // Dequantized values should be close to original for saturated values
        let deq_data: Vec<f32> = dequantized.flatten_all().unwrap().to_vec1().unwrap();
        let orig_data: Vec<f32> = weights.flatten_all().unwrap().to_vec1().unwrap();

        // At least the signs should match for non-zero values
        for (d, o) in deq_data.iter().zip(orig_data.iter()) {
            if o.abs() > 0.5 {
                assert_eq!(d.signum(), o.signum());
            }
        }
    }
}