unsloth_rs/kernels/ternary/
quantize.rs1use super::config::{CalibrationMethodConfig, TernaryConfig};
33use super::types::TernaryTensor;
34use crate::error::{Result, UnslothError};
35use candle_core::{DType, Tensor};
36
/// Strategy used to pick the ternary cut-off threshold for one row of weights.
///
/// During quantization, values strictly above the threshold map to `+1`, values
/// strictly below its negation map to `-1`, and everything else maps to `0`.
#[derive(Debug, Clone, Copy)]
pub enum CalibrationMethod {
    /// Threshold is `factor` times the largest absolute value in the row.
    AbsMax {
        /// Multiplier applied to the row's maximum absolute value.
        factor: f32,
    },

    /// Threshold is the absolute value found at the given percentile of the row.
    Percentile {
        /// Percentile expressed on a 0-100 scale.
        percentile: f32,
    },

    /// Threshold is `mean(|x|) + k * std(|x|)` computed over the row.
    MeanStd {
        /// Number of standard deviations added to the mean absolute value.
        k: f32,
    },

    /// Threshold is supplied directly by the caller.
    Manual {
        /// The fixed threshold to use as-is.
        threshold: f32,
    },
}
67
68impl Default for CalibrationMethod {
69 fn default() -> Self {
70 Self::AbsMax { factor: 0.7 }
71 }
72}
73
74impl From<CalibrationMethodConfig> for CalibrationMethod {
75 fn from(config: CalibrationMethodConfig) -> Self {
76 match config {
77 CalibrationMethodConfig::AbsMax => Self::AbsMax { factor: 0.7 },
78 CalibrationMethodConfig::Percentile(p) => Self::Percentile { percentile: p },
79 CalibrationMethodConfig::MeanStd(k) => Self::MeanStd { k },
80 CalibrationMethodConfig::Manual(t) => Self::Manual { threshold: t },
81 }
82 }
83}
84
/// Summary statistics collected while quantizing a tensor to ternary form.
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Fraction of all elements that were quantized to zero.
    pub sparsity: f32,

    /// Fraction of all elements that were quantized to `+1`.
    pub positive_ratio: f32,

    /// Fraction of all elements that were quantized to `-1`.
    pub negative_ratio: f32,

    /// Threshold used for each row, in row order.
    pub thresholds: Vec<f32>,

    /// Scale assigned to each row, in row order.
    pub scales: Vec<f32>,

    /// Mean absolute difference between the original and reconstructed values.
    pub mean_error: f32,

    /// Largest absolute difference between an original and reconstructed value.
    pub max_error: f32,
}
109
110pub fn quantize_tensor(
137 tensor: &Tensor,
138 config: &TernaryConfig,
139) -> Result<(TernaryTensor, QuantizationStats)> {
140 let shape = tensor.shape();
142 if shape.dims().len() != 2 {
143 return Err(UnslothError::ShapeMismatch {
144 expected: vec![2],
146 actual: shape.dims().to_vec(),
147 });
148 }
149
150 if tensor.dtype() != DType::F32 {
151 return Err(UnslothError::InvalidConfig(format!(
152 "quantize_tensor requires f32, got {:?}",
153 tensor.dtype()
154 )));
155 }
156
157 let (out_features, in_features) = (shape.dims()[0], shape.dims()[1]);
158
159 let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
161
162 let calibration = CalibrationMethod::from(config.calibration_method);
164
165 let k_words = in_features.div_ceil(32);
167 let mut plus_plane = vec![0u32; out_features * k_words];
168 let mut minus_plane = vec![0u32; out_features * k_words];
169 let mut scales = vec![0.0f32; out_features];
170 let mut thresholds = vec![0.0f32; out_features];
171
172 let mut total_positive = 0usize;
173 let mut total_negative = 0usize;
174 let mut total_zero = 0usize;
175 let mut total_error = 0.0f64;
176 let mut max_error = 0.0f32;
177
178 for row in 0..out_features {
179 let row_start = row * in_features;
180 let row_data = &data[row_start..row_start + in_features];
181
182 let threshold = compute_threshold(row_data, calibration);
184 thresholds[row] = threshold;
185
186 let (row_plus, row_minus, scale, pos, neg, zero) =
188 quantize_row(row_data, threshold, k_words);
189
190 let plane_offset = row * k_words;
192 plus_plane[plane_offset..plane_offset + k_words].copy_from_slice(&row_plus);
193 minus_plane[plane_offset..plane_offset + k_words].copy_from_slice(&row_minus);
194 scales[row] = scale;
195
196 total_positive += pos;
197 total_negative += neg;
198 total_zero += zero;
199
200 for (i, &val) in row_data.iter().enumerate() {
202 let word_idx = i / 32;
203 let bit_idx = i % 32;
204 let mask = 1u32 << bit_idx;
205
206 let is_plus = (row_plus[word_idx] & mask) != 0;
207 let is_minus = (row_minus[word_idx] & mask) != 0;
208
209 let reconstructed = if is_plus {
210 scale
211 } else if is_minus {
212 -scale
213 } else {
214 0.0
215 };
216
217 let error = (val - reconstructed).abs();
218 total_error += f64::from(error);
219 max_error = max_error.max(error);
220 }
221 }
222
223 let total_elements = out_features * in_features;
224 #[allow(clippy::cast_precision_loss)] let stats = QuantizationStats {
226 sparsity: total_zero as f32 / total_elements as f32,
227 positive_ratio: total_positive as f32 / total_elements as f32,
228 negative_ratio: total_negative as f32 / total_elements as f32,
229 thresholds,
230 scales: scales.clone(),
231 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] mean_error: (total_error / total_elements as f64) as f32,
233 max_error,
234 };
235
236 let ternary = TernaryTensor::new(plus_plane, minus_plane, scales, (out_features, in_features));
237
238 Ok((ternary, stats))
239}
240
241fn compute_threshold(data: &[f32], method: CalibrationMethod) -> f32 {
243 match method {
244 CalibrationMethod::AbsMax { factor } => {
245 let max_abs = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
246 factor * max_abs
247 }
248
249 CalibrationMethod::Percentile { percentile } => {
250 let mut abs_values: Vec<f32> = data.iter().map(|x| x.abs()).collect();
251 abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
252
253 #[allow(
254 clippy::cast_possible_truncation,
255 clippy::cast_sign_loss,
256 clippy::cast_precision_loss
257 )]
258 let idx = ((percentile / 100.0) * (abs_values.len() - 1) as f32) as usize;
260 abs_values[idx.min(abs_values.len() - 1)]
261 }
262
263 CalibrationMethod::MeanStd { k } => {
264 #[allow(clippy::cast_precision_loss)]
265 let n = data.len() as f64;
267 let abs_values: Vec<f64> = data.iter().map(|x| f64::from(x.abs())).collect();
268
269 let mean = abs_values.iter().sum::<f64>() / n;
270 let variance = abs_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
271 let std = variance.sqrt();
272
273 #[allow(clippy::cast_possible_truncation)]
275 let threshold_value = (mean + f64::from(k) * std) as f32;
276 threshold_value
277 }
278
279 CalibrationMethod::Manual { threshold } => threshold,
280 }
281}
282
/// Packs one row into `+1` / `-1` bit-planes and derives its scale.
///
/// An element strictly greater than `threshold` sets its bit in the plus plane;
/// an element strictly less than `-threshold` sets its bit in the minus plane;
/// anything else counts as zero. Element `i` occupies word `i / 32`, bit `i % 32`.
///
/// Returns `(plus, minus, scale, positive_count, negative_count, zero_count)`,
/// where both planes have `k_words` words and `scale` is the mean absolute value
/// of the non-zero elements, or `1.0` when every element quantized to zero.
fn quantize_row(
    data: &[f32],
    threshold: f32,
    k_words: usize,
) -> (Vec<u32>, Vec<u32>, f32, usize, usize, usize) {
    const BITS_PER_WORD: usize = 32;

    let mut plus_words = vec![0u32; k_words];
    let mut minus_words = vec![0u32; k_words];

    let (mut pos_magnitude, mut neg_magnitude) = (0.0f64, 0.0f64);
    let (mut n_pos, mut n_neg, mut n_zero) = (0usize, 0usize, 0usize);

    // Walk the row one packed word at a time; `bit` is the position inside that word.
    for (word, chunk) in data.chunks(BITS_PER_WORD).enumerate() {
        for (bit, &value) in chunk.iter().enumerate() {
            let flag = 1u32 << bit;

            if value > threshold {
                plus_words[word] |= flag;
                pos_magnitude += f64::from(value.abs());
                n_pos += 1;
            } else if value < -threshold {
                minus_words[word] |= flag;
                neg_magnitude += f64::from(value.abs());
                n_neg += 1;
            } else {
                n_zero += 1;
            }
        }
    }

    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    let scale = match n_pos + n_neg {
        // Nothing survived the threshold: fall back to a neutral scale.
        0 => 1.0,
        kept => ((pos_magnitude + neg_magnitude) / kept as f64) as f32,
    };

    (plus_words, minus_words, scale, n_pos, n_neg, n_zero)
}
338
339pub fn dequantize_tensor(ternary: &TernaryTensor) -> Result<Tensor> {
353 let (out_features, in_features) = ternary.dims();
354 let mut data = vec![0.0f32; out_features * in_features];
355
356 for row in 0..out_features {
357 let scale = ternary.scales[row];
358 let planes = ternary.get_row_planes(row);
359
360 for col in 0..in_features {
361 let val = planes.get(col);
362 data[row * in_features + col] = f32::from(val) * scale;
363 }
364 }
365
366 let tensor = Tensor::from_vec(data, (out_features, in_features), &candle_core::Device::Cpu)?;
367 Ok(tensor)
368}
369
370pub fn quantize_linear_weights(weights: &Tensor, config: &TernaryConfig) -> Result<TernaryTensor> {
386 let (ternary, _stats) = quantize_tensor(weights, config)?;
387 Ok(ternary)
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// A small 2x8 matrix with a manual threshold must produce all three ternary states.
    #[test]
    fn test_quantize_simple() -> Result<()> {
        let data: Vec<f32> = vec![
            0.5, -0.5, 0.1, -0.1, 0.8, -0.8, 0.0, 0.3, // row 0
            1.0, -1.0, 0.2, -0.2, 0.0, 0.0, 0.9, -0.9, // row 1
        ];
        let tensor = Tensor::from_vec(data, (2, 8), &Device::Cpu)?;

        let config = TernaryConfig {
            calibration_method: CalibrationMethodConfig::Manual(0.3),
            ..Default::default()
        };

        let (ternary, stats) = quantize_tensor(&tensor, &config)?;

        assert_eq!(ternary.dims(), (2, 8));
        assert!(stats.sparsity > 0.0);
        assert!(stats.positive_ratio > 0.0);
        assert!(stats.negative_ratio > 0.0);

        Ok(())
    }

    /// Quantize then dequantize a linear ramp and check the error stays bounded.
    #[test]
    fn test_quantize_dequantize_roundtrip() -> Result<()> {
        let data: Vec<f32> = (0..256)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                {
                    (i as f32 - 128.0) / 128.0
                }
            })
            .collect();
        let tensor = Tensor::from_vec(data.clone(), (4, 64), &Device::Cpu)?;

        let config = TernaryConfig::default();
        let (ternary, _stats) = quantize_tensor(&tensor, &config)?;

        let reconstructed = dequantize_tensor(&ternary)?;
        let recon_data: Vec<f32> = reconstructed.flatten_all()?.to_vec1()?;

        let mse: f32 = data
            .iter()
            .zip(recon_data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / {
                #[allow(clippy::cast_precision_loss)]
                {
                    data.len() as f32
                }
            };

        assert!(mse < 0.5, "MSE too high: {mse}");

        Ok(())
    }

    /// Every calibration method yields the expected threshold on a fixed row.
    #[test]
    fn test_calibration_methods() {
        let data: Vec<f32> = vec![0.1, 0.5, 1.0, -0.3, -0.8, 2.0, -1.5, 0.0];

        // 0.7 * max(|x|) = 0.7 * 2.0
        let t1 = compute_threshold(&data, CalibrationMethod::AbsMax { factor: 0.7 });
        assert!((t1 - 1.4).abs() < 0.01);

        let t2 = compute_threshold(&data, CalibrationMethod::Manual { threshold: 0.5 });
        assert!((t2 - 0.5).abs() < 0.001);

        // sorted |x| = [0, .1, .3, .5, .8, 1, 1.5, 2]; floor(0.5 * 7) = 3 -> 0.5
        let t3 = compute_threshold(&data, CalibrationMethod::Percentile { percentile: 50.0 });
        assert!((t3 - 0.5).abs() < 0.001);

        // The 100th percentile is the largest magnitude.
        let t4 = compute_threshold(&data, CalibrationMethod::Percentile { percentile: 100.0 });
        assert!((t4 - 2.0).abs() < 0.001);

        // With k = 0 the threshold is just mean(|x|) = 6.2 / 8.
        let t5 = compute_threshold(&data, CalibrationMethod::MeanStd { k: 0.0 });
        assert!((t5 - 0.775).abs() < 0.001);
    }

    /// A tensor where only every tenth element is non-zero must be reported as sparse.
    #[test]
    fn test_sparsity_detection() -> Result<()> {
        let mut data = vec![0.0f32; 1000];
        for (i, slot) in data.iter_mut().step_by(10).enumerate() {
            *slot = if i % 2 == 0 { 1.0 } else { -1.0 };
        }
        let tensor = Tensor::from_vec(data, (10, 100), &Device::Cpu)?;

        let config = TernaryConfig {
            calibration_method: CalibrationMethodConfig::Manual(0.1),
            ..Default::default()
        };

        let (ternary, stats) = quantize_tensor(&tensor, &config)?;

        assert!(stats.sparsity > 0.85, "Sparsity: {}", stats.sparsity);
        assert!(ternary.sparsity() > 0.85);

        Ok(())
    }

    /// A large weight matrix must compress by more than 10x.
    #[test]
    fn test_compression_ratio() -> Result<()> {
        let data = vec![0.0f32; 4096 * 4096];
        let tensor = Tensor::from_vec(data, (4096, 4096), &Device::Cpu)?;

        let config = TernaryConfig::default();
        let (ternary, _) = quantize_tensor(&tensor, &config)?;

        let ratio = ternary.compression_ratio();
        assert!(ratio > 10.0, "Compression ratio too low: {ratio}");

        Ok(())
    }
}