use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tensorlogic_ir::OpType;
use thiserror::Error;
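
/// Identifier for a graph node, used to key per-node calibration statistics
/// and quantization parameters.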
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct NodeId(pub usize);
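
/// Errors that can occur during quantization and calibration.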
#[derive(Error, Debug, Clone, PartialEq)]
pub enum QuantizationError {
    #[error("Unsupported data type for quantization: {0}")]
    UnsupportedDataType(String),

    #[error("Invalid quantization range: min={min}, max={max}")]
    InvalidRange { min: f64, max: f64 },

    #[error("Calibration failed: {0}")]
    CalibrationFailed(String),

    #[error("Quantization not supported for operation: {0:?}")]
    UnsupportedOperation(OpType),

    #[error("Invalid quantization parameters: {0}")]
    InvalidParameters(String),

    #[error("Insufficient calibration data")]
    InsufficientData,
}
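
/// Supported quantized data types, from 1-bit binary up to 16-bit floating point.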
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 8-bit signed integer.
    Int8,
    /// 4-bit signed integer.
    Int4,
    /// 2-bit signed integer.
    Int2,
    /// 8-bit floating point with 4 exponent and 3 mantissa bits.
    FP8E4M3,
    /// 8-bit floating point with 5 exponent and 2 mantissa bits.
    FP8E5M2,
    /// IEEE 754 half-precision floating point.
    FP16,
    /// bfloat16 floating point.
    BF16,
    /// 1-bit binary quantization.
    Binary,
    /// Ternary quantization ({-1, 0, +1}), stored in 2 bits.
    Ternary,
}

impl QuantizationType {
    /// Number of bits used to store a single quantized value.
    pub fn bits(&self) -> u32 {
        match self {
            Self::Binary => 1,
            Self::Int2 | Self::Ternary => 2,
            Self::Int4 => 4,
            Self::Int8 | Self::FP8E4M3 | Self::FP8E5M2 => 8,
            Self::FP16 | Self::BF16 => 16,
        }
    }

    /// Compression ratio relative to 32-bit floating point.
    pub fn compression_ratio(&self) -> f64 {
        32.0 / self.bits() as f64
    }

    /// Returns `true` for the floating-point quantization types.
    pub fn is_floating_point(&self) -> bool {
        matches!(
            self,
            Self::FP8E4M3 | Self::FP8E5M2 | Self::FP16 | Self::BF16
        )
    }
}
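
/// Granularity at which quantization parameters are computed.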
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    /// A single scale/zero-point pair for the whole tensor.
    PerTensor,
    /// One scale/zero-point pair per channel along `axis`.
    PerChannel { axis: usize },
    /// One scale/zero-point pair per group of `group_size` elements along `axis`.
    PerGroup { axis: usize, group_size: usize },
}
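
/// Whether the quantization range is centered on zero.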
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationSymmetry {
    /// Zero point fixed at 0; range is symmetric around zero.
    Symmetric,
    /// Non-zero zero point; range follows the observed min/max.
    Asymmetric,
}
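
/// When quantization parameters are determined.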
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
    /// Parameters computed ahead of time from calibration data.
    Static,
    /// Parameters computed at runtime from observed activations.
    Dynamic,
    /// Quantization-aware training.
    QAT,
}
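
/// Strategy used to derive quantization ranges from calibration data.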
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationStrategy {
    /// Use the observed minimum and maximum values.
    MinMax,
    /// Clip to the given lower/upper percentiles of the observed distribution.
    Percentile { lower: u32, upper: u32 },
    /// Choose the range that minimizes the mean squared quantization error.
    MSE,
    /// Choose the range that minimizes KL divergence to the original distribution.
    KLDivergence,
    /// Entropy-based range selection.
    Entropy,
}
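
/// Scale/zero-point parameters for quantizing a single tensor.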
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Target quantized data type.
    pub qtype: QuantizationType,
    /// Scale factor(s); one entry per tensor, channel, or group.
    pub scale: Vec<f64>,
    /// Zero point(s); one entry per tensor, channel, or group.
    pub zero_point: Vec<i32>,
    /// Granularity the parameters were computed at.
    pub granularity: QuantizationGranularity,
    /// Symmetric or asymmetric quantization.
    pub symmetry: QuantizationSymmetry,
    /// Minimum value observed during calibration, if known.
    pub observed_min: Option<f64>,
    /// Maximum value observed during calibration, if known.
    pub observed_max: Option<f64>,
}
impl QuantizationParams {
    /// Builds symmetric per-tensor parameters from the absolute maximum value.
    ///
    /// Returns an error if `abs_max` is not positive or the type has no
    /// symmetric integer range defined here.
    pub fn symmetric_per_tensor(
        qtype: QuantizationType,
        abs_max: f64,
    ) -> Result<Self, QuantizationError> {
        if abs_max <= 0.0 {
            return Err(QuantizationError::InvalidRange {
                min: -abs_max,
                max: abs_max,
            });
        }

        let qmax = match qtype {
            QuantizationType::Int8 => 127.0,
            QuantizationType::Int4 => 7.0,
            QuantizationType::Int2 => 1.0,
            QuantizationType::Binary => 1.0,
            QuantizationType::Ternary => 1.0,
            _ => {
                return Err(QuantizationError::UnsupportedDataType(format!(
                    "{:?}",
                    qtype
                )))
            }
        };

        let scale = abs_max / qmax;

        Ok(Self {
            qtype,
            scale: vec![scale],
            zero_point: vec![0],
            granularity: QuantizationGranularity::PerTensor,
            symmetry: QuantizationSymmetry::Symmetric,
            observed_min: Some(-abs_max),
            observed_max: Some(abs_max),
        })
    }

    /// Builds asymmetric per-tensor parameters from an observed `[min, max]` range.
    pub fn asymmetric_per_tensor(
        qtype: QuantizationType,
        min: f64,
        max: f64,
    ) -> Result<Self, QuantizationError> {
        if min >= max {
            return Err(QuantizationError::InvalidRange { min, max });
        }

        let (qmin, qmax) = match qtype {
            QuantizationType::Int8 => (-128.0, 127.0),
            QuantizationType::Int4 => (-8.0, 7.0),
            QuantizationType::Int2 => (-2.0, 1.0),
            _ => {
                return Err(QuantizationError::UnsupportedDataType(format!(
                    "{:?}",
                    qtype
                )))
            }
        };

        let scale = (max - min) / (qmax - qmin);
        let zero_point = (qmin - min / scale).round() as i32;

        Ok(Self {
            qtype,
            scale: vec![scale],
            zero_point: vec![zero_point],
            granularity: QuantizationGranularity::PerTensor,
            symmetry: QuantizationSymmetry::Asymmetric,
            observed_min: Some(min),
            observed_max: Some(max),
        })
    }

    /// Quantizes a floating-point value to its integer representation,
    /// clamped to the type's representable range.
    pub fn quantize(&self, value: f64) -> i32 {
        let scale = self.scale[0];
        let zero_point = self.zero_point[0];
        ((value / scale).round() as i32 + zero_point).clamp(self.qmin(), self.qmax())
    }

    /// Maps a quantized integer back to its approximate floating-point value.
    pub fn dequantize(&self, qvalue: i32) -> f64 {
        let scale = self.scale[0];
        let zero_point = self.zero_point[0];
        (qvalue - zero_point) as f64 * scale
    }

    fn qmin(&self) -> i32 {
        match self.qtype {
            QuantizationType::Int8 => -128,
            QuantizationType::Int4 => -8,
            QuantizationType::Int2 => -2,
            QuantizationType::Binary => 0,
            QuantizationType::Ternary => -1,
            _ => 0,
        }
    }

    fn qmax(&self) -> i32 {
        match self.qtype {
            QuantizationType::Int8 => 127,
            QuantizationType::Int4 => 7,
            QuantizationType::Int2 => 1,
            QuantizationType::Binary => 1,
            QuantizationType::Ternary => 1,
            _ => 255,
        }
    }
}
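
/// Configuration for a quantization pass.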
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Quantized type used for nodes without an override.
    pub default_qtype: QuantizationType,
    /// Static, dynamic, or quantization-aware training.
    pub mode: QuantizationMode,
    /// Granularity of the computed parameters.
    pub granularity: QuantizationGranularity,
    /// Symmetric or asymmetric quantization.
    pub symmetry: QuantizationSymmetry,
    /// Strategy used to derive ranges from calibration data.
    pub calibration: CalibrationStrategy,
    /// Minimum number of calibration samples required.
    pub calibration_samples: usize,
    /// Operations to leave unquantized.
    pub skip_ops: Vec<OpType>,
    /// Per-node overrides of the default quantized type.
    pub node_overrides: HashMap<NodeId, QuantizationType>,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            default_qtype: QuantizationType::Int8,
            mode: QuantizationMode::Static,
            granularity: QuantizationGranularity::PerTensor,
            symmetry: QuantizationSymmetry::Symmetric,
            calibration: CalibrationStrategy::MinMax,
            calibration_samples: 100,
            skip_ops: vec![],
            node_overrides: HashMap::new(),
        }
    }
}

impl QuantizationConfig {
    /// INT8 quantization with the default settings.
    pub fn int8() -> Self {
        Self {
            default_qtype: QuantizationType::Int8,
            ..Default::default()
        }
    }

    /// INT4 quantization with the default settings.
    pub fn int4() -> Self {
        Self {
            default_qtype: QuantizationType::Int4,
            ..Default::default()
        }
    }

    /// FP8 (E4M3) quantization with symmetric ranges.
    pub fn fp8() -> Self {
        Self {
            default_qtype: QuantizationType::FP8E4M3,
            symmetry: QuantizationSymmetry::Symmetric,
            ..Default::default()
        }
    }

    /// Quantization-aware training with the given quantized type.
    pub fn qat(qtype: QuantizationType) -> Self {
        Self {
            default_qtype: qtype,
            mode: QuantizationMode::QAT,
            ..Default::default()
        }
    }

    /// Switches to per-channel granularity along `axis`.
    pub fn per_channel(mut self, axis: usize) -> Self {
        self.granularity = QuantizationGranularity::PerChannel { axis };
        self
    }

    /// Switches to asymmetric quantization.
    pub fn asymmetric(mut self) -> Self {
        self.symmetry = QuantizationSymmetry::Asymmetric;
        self
    }

    /// Sets the calibration strategy.
    pub fn with_calibration(mut self, strategy: CalibrationStrategy) -> Self {
        self.calibration = strategy;
        self
    }
}
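
/// Running min/max statistics collected during calibration.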
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CalibrationStats {
    /// Smallest value observed per node.
    pub min_values: HashMap<NodeId, f64>,
    /// Largest value observed per node.
    pub max_values: HashMap<NodeId, f64>,
    /// Optional per-node histograms of observed values.
    pub histograms: HashMap<NodeId, Vec<u32>>,
    /// Number of calibration updates recorded.
    pub num_samples: usize,
}

impl CalibrationStats {
    pub fn new() -> Self {
        Self::default()
    }

    /// Records the observed min/max for a node, widening any existing range.
    pub fn update(&mut self, node_id: NodeId, min: f64, max: f64) {
        self.min_values
            .entry(node_id)
            .and_modify(|v| *v = v.min(min))
            .or_insert(min);
        self.max_values
            .entry(node_id)
            .and_modify(|v| *v = v.max(max))
            .or_insert(max);
        self.num_samples += 1;
    }

    /// Computes quantization parameters for a node from the collected statistics,
    /// honoring any per-node type override in `config`.
    pub fn compute_params(
        &self,
        node_id: NodeId,
        config: &QuantizationConfig,
    ) -> Result<QuantizationParams, QuantizationError> {
        let min = self
            .min_values
            .get(&node_id)
            .ok_or(QuantizationError::InsufficientData)?;
        let max = self
            .max_values
            .get(&node_id)
            .ok_or(QuantizationError::InsufficientData)?;

        let qtype = config
            .node_overrides
            .get(&node_id)
            .copied()
            .unwrap_or(config.default_qtype);

        match config.symmetry {
            QuantizationSymmetry::Symmetric => {
                let abs_max = min.abs().max(max.abs());
                QuantizationParams::symmetric_per_tensor(qtype, abs_max)
            }
            QuantizationSymmetry::Asymmetric => {
                QuantizationParams::asymmetric_per_tensor(qtype, *min, *max)
            }
        }
    }
}
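
/// Drives calibration and computes per-node quantization parameters from a
/// [`QuantizationConfig`].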
pub struct Quantizer {
    config: QuantizationConfig,
    stats: CalibrationStats,
    params: HashMap<NodeId, QuantizationParams>,
}

impl Quantizer {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            stats: CalibrationStats::new(),
            params: HashMap::new(),
        }
    }

    /// Quantizer preconfigured for INT8.
    pub fn int8() -> Self {
        Self::new(QuantizationConfig::int8())
    }

    /// Quantizer preconfigured for INT4.
    pub fn int4() -> Self {
        Self::new(QuantizationConfig::int4())
    }

    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

    pub fn stats(&self) -> &CalibrationStats {
        &self.stats
    }

    /// Quantization parameters computed for `node_id`, if calibration has been finalized.
    pub fn get_params(&self, node_id: NodeId) -> Option<&QuantizationParams> {
        self.params.get(&node_id)
    }

    /// Records an observed min/max range for a node.
    pub fn calibrate(&mut self, node_id: NodeId, min: f64, max: f64) {
        self.stats.update(node_id, min, max);
    }

    /// Computes quantization parameters for every calibrated node.
    ///
    /// Fails with [`QuantizationError::InsufficientData`] if fewer than
    /// `calibration_samples` updates were recorded.
    pub fn finalize_calibration(&mut self) -> Result<(), QuantizationError> {
        if self.stats.num_samples < self.config.calibration_samples {
            return Err(QuantizationError::InsufficientData);
        }

        for &node_id in self.stats.min_values.keys() {
            let params = self.stats.compute_params(node_id, &self.config)?;
            self.params.insert(node_id, params);
        }

        Ok(())
    }

    /// Summarizes the quantization decisions made so far.
    pub fn summary(&self) -> QuantizationSummary {
        let mut type_counts = HashMap::new();
        for params in self.params.values() {
            *type_counts.entry(params.qtype).or_insert(0) += 1;
        }

        let total_params = self.params.len();
        let avg_compression = self
            .params
            .values()
            .map(|p| p.qtype.compression_ratio())
            .sum::<f64>()
            / total_params.max(1) as f64;

        QuantizationSummary {
            num_quantized_nodes: total_params,
            type_distribution: type_counts,
            avg_compression_ratio: avg_compression,
            calibration_samples: self.stats.num_samples,
        }
    }
}
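
/// Summary of a quantizer's decisions, suitable for reporting.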
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationSummary {
    /// Number of nodes with computed quantization parameters.
    pub num_quantized_nodes: usize,
    /// How many nodes use each quantized type.
    pub type_distribution: HashMap<QuantizationType, usize>,
    /// Average compression ratio across quantized nodes (relative to f32).
    pub avg_compression_ratio: f64,
    /// Number of calibration updates that were recorded.
    pub calibration_samples: usize,
}

impl QuantizationSummary {
    /// Estimated memory savings as a percentage of the original size.
    pub fn memory_savings(&self) -> f64 {
        if self.avg_compression_ratio > 1.0 {
            (1.0 - 1.0 / self.avg_compression_ratio) * 100.0
        } else {
            0.0
        }
    }
}
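
/// Simulated ("fake") quantization: values are quantized and immediately
/// dequantized so the rounding error is visible while computation stays in
/// floating point, as used for quantization-aware training.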
pub struct FakeQuantize {
    params: QuantizationParams,
    enabled: bool,
}

impl FakeQuantize {
    pub fn new(params: QuantizationParams) -> Self {
        Self {
            params,
            enabled: true,
        }
    }

    /// Enables or disables the quantize/dequantize round trip.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Applies quantize-then-dequantize; returns the input unchanged when disabled.
    pub fn forward(&self, value: f64) -> f64 {
        if !self.enabled {
            return value;
        }

        let qvalue = self.params.quantize(value);
        self.params.dequantize(qvalue)
    }

    /// Applies [`forward`](Self::forward) element-wise.
    pub fn forward_batch(&self, values: &[f64]) -> Vec<f64> {
        values.iter().map(|&v| self.forward(v)).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_properties() {
        assert_eq!(QuantizationType::Int8.bits(), 8);
        assert_eq!(QuantizationType::Int4.bits(), 4);
        assert_eq!(QuantizationType::Binary.bits(), 1);
        assert_eq!(QuantizationType::Int8.compression_ratio(), 4.0);
        assert!(QuantizationType::FP16.is_floating_point());
        assert!(!QuantizationType::Int8.is_floating_point());
    }

    #[test]
    fn test_symmetric_quantization() {
        let params =
            QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 127.0).unwrap();
        assert_eq!(params.scale[0], 1.0);
        assert_eq!(params.zero_point[0], 0);

        assert_eq!(params.quantize(0.0), 0);
        assert_eq!(params.quantize(127.0), 127);
        assert_eq!(params.quantize(-127.0), -127);
        assert!((params.dequantize(127) - 127.0).abs() < 1e-10);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let params =
            QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, -10.0, 20.0).unwrap();

        assert!(params.scale[0] > 0.0);
        assert_ne!(params.zero_point[0], 0);

        let original = 5.0;
        let quantized = params.quantize(original);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - original).abs() < 1.0);
    }

    #[test]
    fn test_quantization_config() {
        let config = QuantizationConfig::int8();
        assert_eq!(config.default_qtype, QuantizationType::Int8);

        let config = QuantizationConfig::int4().per_channel(0).asymmetric();
        assert_eq!(config.default_qtype, QuantizationType::Int4);
        assert!(matches!(
            config.granularity,
            QuantizationGranularity::PerChannel { axis: 0 }
        ));
        assert_eq!(config.symmetry, QuantizationSymmetry::Asymmetric);
    }

    #[test]
    fn test_calibration_stats() {
        let mut stats = CalibrationStats::new();
        stats.update(NodeId(0), -5.0, 10.0);
        stats.update(NodeId(0), -8.0, 12.0);

        assert_eq!(stats.min_values[&NodeId(0)], -8.0);
        assert_eq!(stats.max_values[&NodeId(0)], 12.0);
    }

    #[test]
    fn test_quantizer() {
        let mut quantizer = Quantizer::int8();

        quantizer.calibrate(NodeId(0), -10.0, 10.0);
        quantizer.calibrate(NodeId(0), -8.0, 12.0);

        for _ in 0..100 {
            quantizer.calibrate(NodeId(0), -10.0, 10.0);
        }

        assert!(quantizer.finalize_calibration().is_ok());
        assert!(quantizer.get_params(NodeId(0)).is_some());

        let summary = quantizer.summary();
        assert_eq!(summary.num_quantized_nodes, 1);
        assert!(summary.avg_compression_ratio > 1.0);
    }

    #[test]
    fn test_fake_quantize() {
        let params =
            QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 10.0).unwrap();
        let fake_quant = FakeQuantize::new(params);

        let original = 3.5;
        let faked = fake_quant.forward(original);

        assert!((faked - original).abs() < 1.0);
    }
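
    // Extra sanity check: with fake quantization disabled, `forward` and
    // `forward_batch` should be exact pass-throughs.
    #[test]
    fn test_fake_quantize_disabled() {
        let params =
            QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, 10.0).unwrap();
        let mut fake_quant = FakeQuantize::new(params);
        fake_quant.set_enabled(false);

        let original = 3.5;
        assert_eq!(fake_quant.forward(original), original);
        assert_eq!(fake_quant.forward_batch(&[1.0, 2.0]), vec![1.0, 2.0]);
    }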

    #[test]
    fn test_quantization_summary() {
        let mut quantizer = Quantizer::int8();
        for _ in 0..100 {
            quantizer.calibrate(NodeId(0), -10.0, 10.0);
        }
        quantizer.finalize_calibration().unwrap();

        let summary = quantizer.summary();
        assert!(summary.memory_savings() > 0.0);
        assert!(summary.memory_savings() < 100.0);
    }

    #[test]
    fn test_int4_quantization() {
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int4, 7.0).unwrap();

        let value = 5.0;
        let qvalue = params.quantize(value);
        assert!((-8..=7).contains(&qvalue));
    }

    #[test]
    fn test_invalid_range() {
        let result = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, 10.0, 5.0);
        assert!(matches!(
            result,
            Err(QuantizationError::InvalidRange { .. })
        ));
    }
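
    // Extra sanity check: finalizing calibration with fewer samples than
    // `calibration_samples` should report InsufficientData.
    #[test]
    fn test_insufficient_calibration_samples() {
        let mut quantizer = Quantizer::int8();
        quantizer.calibrate(NodeId(0), -1.0, 1.0);

        assert!(matches!(
            quantizer.finalize_calibration(),
            Err(QuantizationError::InsufficientData)
        ));
    }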
}