use bitnet_quantize::{quantize_weights, BitNetConfig, TernaryWeight};
use candle_core::{Device, Tensor};
use thiserror::Error;

use super::ternary::PackedTernary;

#[derive(Debug, Error)]
pub enum QuantizationError {
    #[error("invalid config: {0}")]
    InvalidConfig(String),

    #[error("tensor error: {0}")]
    Tensor(#[from] candle_core::Error),

    #[error("quantization failed: {0}")]
    Quantize(#[from] bitnet_quantize::BitNetError),
}

#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    /// Columns per quantization group; `0` means a single group spanning each
    /// row, i.e. one scale per row.
    pub group_size: usize,

    /// Quantize symmetrically around zero.
    pub symmetric: bool,
}

impl Default for QuantizeConfig {
    fn default() -> Self {
        Self {
            group_size: 0,
            symmetric: true,
        }
    }
}

impl QuantizeConfig {
    /// Builder-style setter for the group size.
    pub fn with_group_size(mut self, size: usize) -> Self {
        self.group_size = size;
        self
    }
}

#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Row-major ternary values in `{-1, 0, 1}`.
    pub values: Vec<i8>,
    /// One scale per row, or one per `(row, group)` pair when grouping is enabled.
    pub scales: Vec<f32>,
    /// Matrix shape as `(rows, cols)`.
    pub shape: (usize, usize),
    /// Effective group size used during quantization.
    pub group_size: usize,
}

impl QuantizationResult {
    pub fn to_packed(&self) -> Result<PackedTernary, super::ternary::TernaryError> {
        let (rows, cols) = self.shape;

        if self.group_size == 0 || self.group_size >= cols {
            // Quantization was per-row, so the scales already map one-to-one.
            PackedTernary::from_i8(&self.values, &self.scales, self.shape)
        } else {
            // Collapse per-group scales into a per-row average; this is lossy
            // but matches the one-scale-per-row layout passed to `from_i8`.
            let groups_per_row = cols.div_ceil(self.group_size);
            let mut row_scales = Vec::with_capacity(rows);

            for row in 0..rows {
                let start = row * groups_per_row;
                let end = (start + groups_per_row).min(self.scales.len());
                let avg: f32 = self.scales[start..end].iter().sum::<f32>() / (end - start) as f32;
                row_scales.push(avg);
            }

            PackedTernary::from_i8(&self.values, &row_scales, self.shape)
        }
    }

    pub fn to_tensor(&self, device: &Device) -> Result<Tensor, QuantizationError> {
        let (rows, cols) = self.shape;
        let mut output = vec![0.0f32; rows * cols];

        if self.group_size == 0 || self.group_size >= cols {
            // One scale per row.
            for row in 0..rows {
                let scale = self.scales[row];
                for col in 0..cols {
                    let idx = row * cols + col;
                    output[idx] = f32::from(self.values[idx]) * scale;
                }
            }
        } else {
            // One scale per (row, group) pair.
            let groups_per_row = cols.div_ceil(self.group_size);
            for row in 0..rows {
                for col in 0..cols {
                    let group = col / self.group_size;
                    let scale_idx = row * groups_per_row + group;
                    let idx = row * cols + col;
                    output[idx] = f32::from(self.values[idx]) * self.scales[scale_idx];
                }
            }
        }

        Ok(Tensor::from_vec(output, (rows, cols), device)?)
    }
}

/// Ternary quantization using the abs-mean scheme from `bitnet_quantize`.
pub fn quantize_absmean(
    weights: &Tensor,
    config: &QuantizeConfig,
) -> Result<QuantizationResult, QuantizationError> {
    let (rows, cols) = weights.dims2()?;

    // A group size of 0 means one group spanning the whole row.
    let effective_group_size = if config.group_size == 0 {
        cols
    } else {
        config.group_size
    };

    let bitnet_config = BitNetConfig::default().with_group_size(effective_group_size);

    let ternary: TernaryWeight = quantize_weights(weights, &bitnet_config)?;

    // Unpack each packed row back into plain i8 ternary values.
    let mut values = Vec::with_capacity(rows * cols);
    for packed in &ternary.data {
        for col in 0..cols {
            values.push(packed.get(col).value());
        }
    }

    Ok(QuantizationResult {
        values,
        scales: ternary.scales,
        shape: (rows, cols),
        group_size: effective_group_size,
    })
}

/// Ternary quantization with per-group abs-max scales, computed directly
/// rather than going through `bitnet_quantize`.
pub fn quantize_absmax(
    weights: &Tensor,
    config: &QuantizeConfig,
) -> Result<QuantizationResult, QuantizationError> {
    let (rows, cols) = weights.dims2()?;
    let data: Vec<f32> = weights.flatten_all()?.to_vec1()?;

    let effective_group_size = if config.group_size == 0 {
        cols
    } else {
        config.group_size
    };

    let groups_per_row = cols.div_ceil(effective_group_size);
    let mut scales = Vec::with_capacity(rows * groups_per_row);
    let mut values = Vec::with_capacity(rows * cols);

    for row in 0..rows {
        for group in 0..groups_per_row {
            let start = group * effective_group_size;
            let end = (start + effective_group_size).min(cols);

            // Scale each group by its largest absolute value.
            let mut max_abs = 0.0f32;
            for col in start..end {
                let val = data[row * cols + col].abs();
                if val > max_abs {
                    max_abs = val;
                }
            }

            // Guard against all-zero groups.
            let scale = if max_abs > 1e-10 { max_abs } else { 1.0 };
            scales.push(scale);

            // Normalize into [-1, 1] and round to the nearest ternary value.
            for col in start..end {
                let val = data[row * cols + col];
                let normalized = val / scale;
                let quantized = if normalized > 0.5 {
                    1i8
                } else if normalized < -0.5 {
                    -1i8
                } else {
                    0i8
                };
                values.push(quantized);
            }
        }
    }

    Ok(QuantizationResult {
        values,
        scales,
        shape: (rows, cols),
        group_size: effective_group_size,
    })
}

/// Rescales activations into `[-1, 1]` by their absolute maximum and returns
/// the scale needed to undo the normalization. No integer rounding happens
/// here; callers get scaled `f32` values plus the scale.
pub fn quantize_activations(activations: &Tensor) -> Result<(Tensor, f32), QuantizationError> {
    let data: Vec<f32> = activations.flatten_all()?.to_vec1()?;

    let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
    let scale = if max_abs > 1e-10 { max_abs } else { 1.0 };

    let scaled: Vec<f32> = data.iter().map(|x| x / scale).collect();

    Ok((
        Tensor::from_vec(scaled, activations.shape(), activations.device())?,
        scale,
    ))
}

/// Convenience wrapper over [`QuantizationResult::to_tensor`].
pub fn dequantize(result: &QuantizationResult, device: &Device) -> Result<Tensor, QuantizationError> {
    result.to_tensor(device)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_absmean() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmean(&weights, &config).unwrap();

        assert_eq!(result.shape, (2, 4));
        assert_eq!(result.values.len(), 8);
        assert_eq!(result.scales.len(), 2);

        for v in &result.values {
            assert!([-1, 0, 1].contains(v));
        }
    }

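    // A minimal packing smoke test; it assumes `PackedTernary::from_i8`
    // succeeds for valid ternary values paired with one scale per row, and
    // checks only that packing works, not the packed layout itself.
    #[test]
    fn test_to_packed_per_row() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        // The default config quantizes per row, so `to_packed` forwards the
        // per-row scales without averaging.
        let result = quantize_absmax(&weights, &QuantizeConfig::default()).unwrap();
        assert!(result.to_packed().is_ok());
    }
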
    #[test]
    fn test_quantize_absmax() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmax(&weights, &config).unwrap();

        assert_eq!(result.shape, (2, 4));
        assert_eq!(result.values.len(), 8);

        for v in &result.values {
            assert!([-1, 0, 1].contains(v));
        }
    }
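
    // A sketch of the grouped path: with group_size = 2 on a 2x4 matrix there
    // are two groups per row, hence four scales, each the max |w| within its
    // group. The expected constants follow directly from the input values.
    #[test]
    fn test_quantize_absmax_grouped() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default().with_group_size(2);
        let result = quantize_absmax(&weights, &config).unwrap();

        assert_eq!(result.group_size, 2);
        assert_eq!(result.scales.len(), 4);

        let expected = [0.5f32, 0.8, 0.6, 0.7];
        for (s, e) in result.scales.iter().zip(expected.iter()) {
            assert!((s - e).abs() < 1e-6);
        }
    }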

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let device = Device::Cpu;
        let weights = Tensor::from_vec(
            vec![0.8f32, -0.8, 0.0, 0.8, -0.8, 0.8, -0.8, 0.0],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmean(&weights, &config).unwrap();
        let dequantized = dequantize(&result, &device).unwrap();

        assert_eq!(dequantized.dims(), &[2, 4]);

        let deq_data: Vec<f32> = dequantized.flatten_all().unwrap().to_vec1().unwrap();
        let orig_data: Vec<f32> = weights.flatten_all().unwrap().to_vec1().unwrap();

        // Ternary quantization should at least preserve the sign of
        // large-magnitude weights.
        for (d, o) in deq_data.iter().zip(orig_data.iter()) {
            if o.abs() > 0.5 {
                assert_eq!(d.signum(), o.signum());
            }
        }
    }
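
    // A sketch of the activation path, relying only on what
    // quantize_activations does above: rescale by the absolute maximum, so
    // outputs stay in [-1, 1] and multiplying by the returned scale recovers
    // the inputs.
    #[test]
    fn test_quantize_activations() {
        let device = Device::Cpu;
        let input = vec![2.0f32, -4.0, 1.0, 0.5];
        let activations = Tensor::from_vec(input.clone(), (2, 2), &device).unwrap();

        let (scaled, scale) = quantize_activations(&activations).unwrap();
        assert!((scale - 4.0).abs() < 1e-6);

        let data: Vec<f32> = scaled.flatten_all().unwrap().to_vec1().unwrap();
        for (s, o) in data.iter().zip(input.iter()) {
            assert!(s.abs() <= 1.0);
            assert!((s * scale - o).abs() < 1e-6);
        }
    }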
}