1use crate::errors::{Result, TrustformersError};
9use crate::tensor::Tensor;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration controlling how activations are quantized at runtime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationQuantConfig {
 /// Default quantization scheme for layers without a per-layer override.
 pub scheme: ActivationQuantScheme,
 /// Symmetric (zero_point = 0) vs. asymmetric quantization.
 pub symmetric: bool,
 /// Number of statistics updates to gather before quantization kicks in.
 pub calibration_samples: usize,
 /// Percentile used to clip outliers when deriving the quantization range
 /// (1.0 disables percentile clipping).
 pub percentile: f32,
 /// Weight given to a new extreme when updating the EMA min/max.
 pub ema_decay: f32,
 /// Whether to apply quantization during training passes (calibration still
 /// happens either way).
 pub quantize_during_training: bool,
 /// Per-layer overrides keyed by layer name.
 pub layer_configs: HashMap<String, LayerQuantConfig>,
}
31
/// Supported activation quantization schemes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationQuantScheme {
 /// 8-bit quantization (one byte per element).
 Int8,
 /// 16-bit quantization (two little-endian bytes per element).
 Int16,
 /// Dynamic quantization; currently handled identically to `Int8` in this
 /// module.
 Dynamic,
 /// Adaptive quantization: variance-based scale adjustment plus 3-sigma
 /// outlier clipping.
 Adaptive,
}
44
/// Per-layer quantization overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
 /// When false, activations for this layer pass through unquantized.
 pub enabled: bool,
 /// Scheme override; `None` falls back to the global scheme.
 pub scheme: Option<ActivationQuantScheme>,
 /// Bit-width override; `None` derives the width from the scheme (8 or 16).
 pub bits: Option<u8>,
 /// Whether this layer contributes to calibration statistics.
 pub calibrate: bool,
}
57
/// Running statistics gathered during calibration for one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationStats {
 /// Smallest finite value observed (starts at +infinity).
 pub min_val: f32,
 /// Largest finite value observed (starts at -infinity).
 pub max_val: f32,
 /// Sum of observed values (f64 to limit accumulation error).
 pub sum: f64,
 /// Sum of squared values; used to derive the variance.
 pub sum_squares: f64,
 /// Number of finite values observed.
 pub count: usize,
 /// Coarse histogram of (representative value, count) pairs used for
 /// percentile estimation.
 pub histogram: Vec<(f32, usize)>,
 /// Exponential moving average of the minimum (starts at +infinity).
 pub ema_min: f32,
 /// Exponential moving average of the maximum (starts at -infinity).
 pub ema_max: f32,
}
78
/// A quantized activation tensor together with its dequantization parameters.
#[derive(Debug, Clone)]
pub struct QuantizedActivation {
 /// Raw quantized bytes (1 byte/element for 8-bit schemes, 2 little-endian
 /// bytes/element for Int16).
 pub data: Vec<u8>,
 /// Scale factor: real = (q - zero_point) * scale.
 pub scale: f32,
 /// Zero-point offset in the quantized domain.
 pub zero_point: i32,
 /// Logical tensor shape.
 pub shape: Vec<usize>,
 /// Scheme used to produce `data`.
 pub scheme: ActivationQuantScheme,
 /// Bit width of each quantized element.
 pub bits: u8,
}
95
/// Stateful quantizer that calibrates per-layer statistics and then applies
/// fake quantization (quantize + dequantize) to activation tensors.
pub struct ActivationQuantizer {
 // Global settings plus per-layer overrides.
 config: ActivationQuantConfig,
 // Calibration statistics keyed by layer name.
 layer_stats: HashMap<String, ActivationStats>,
 // True while gathering calibration statistics.
 calibrating: bool,
 // Total number of statistics updates performed (shared across all layers).
 calibration_count: usize,
}
106
107impl Default for ActivationQuantConfig {
108 fn default() -> Self {
109 Self {
110 scheme: ActivationQuantScheme::Int8,
111 symmetric: false,
112 calibration_samples: 100,
113 percentile: 0.99,
114 ema_decay: 0.01,
115 quantize_during_training: false,
116 layer_configs: HashMap::new(),
117 }
118 }
119}
120
121impl Default for LayerQuantConfig {
122 fn default() -> Self {
123 Self {
124 enabled: true,
125 scheme: None,
126 bits: None,
127 calibrate: true,
128 }
129 }
130}
131
132impl Default for ActivationStats {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl ActivationStats {
139 pub fn new() -> Self {
141 Self {
142 min_val: f32::INFINITY,
143 max_val: f32::NEG_INFINITY,
144 sum: 0.0,
145 sum_squares: 0.0,
146 count: 0,
147 histogram: Vec::new(),
148 ema_min: f32::INFINITY,
149 ema_max: f32::NEG_INFINITY,
150 }
151 }
152
153 pub fn update(&mut self, tensor: &Tensor, ema_decay: f32) -> Result<()> {
155 match tensor {
156 Tensor::F32(arr) => {
157 let data: Vec<f32> = arr.iter().cloned().collect();
158
159 for &val in &data {
160 if !val.is_finite() {
161 continue; }
163
164 self.min_val = self.min_val.min(val);
165 self.max_val = self.max_val.max(val);
166 self.sum += val as f64;
167 self.sum_squares += (val * val) as f64;
168 self.count += 1;
169
170 if self.ema_min.is_infinite() {
172 self.ema_min = val;
173 self.ema_max = val;
174 } else {
175 if val < self.ema_min {
176 self.ema_min = self.ema_min * (1.0 - ema_decay) + val * ema_decay;
177 }
178 if val > self.ema_max {
179 self.ema_max = self.ema_max * (1.0 - ema_decay) + val * ema_decay;
180 }
181 }
182 }
183
184 let num_bins = 1000;
186 let range = self.max_val - self.min_val;
187 if range > 0.0 {
188 self.histogram.resize(num_bins, (0.0, 0));
189 for &val in &data {
190 if val.is_finite() {
191 let bin_idx =
192 ((val - self.min_val) / range * (num_bins - 1) as f32) as usize;
193 let bin_idx = bin_idx.min(num_bins - 1);
194 self.histogram[bin_idx].0 = val;
195 self.histogram[bin_idx].1 += 1;
196 }
197 }
198 }
199 },
200 _ => {
201 return Err(TrustformersError::quantization_error(
202 "Unsupported tensor type for activation quantization".into(),
203 ))
204 },
205 }
206
207 Ok(())
208 }
209
210 pub fn mean(&self) -> f32 {
212 if self.count == 0 {
213 0.0
214 } else {
215 (self.sum / self.count as f64) as f32
216 }
217 }
218
219 pub fn variance(&self) -> f32 {
221 if self.count <= 1 {
222 0.0
223 } else {
224 let mean = self.mean() as f64;
225 let variance = (self.sum_squares / self.count as f64) - (mean * mean);
226 variance.max(0.0) as f32
227 }
228 }
229
230 pub fn percentile(&self, p: f32) -> f32 {
232 if self.histogram.is_empty() || self.count == 0 {
233 return self.max_val;
234 }
235
236 let target_count = (self.count as f32 * p) as usize;
237 let mut cumulative_count = 0;
238
239 for &(val, count) in &self.histogram {
240 cumulative_count += count;
241 if cumulative_count >= target_count {
242 return val;
243 }
244 }
245
246 self.max_val
247 }
248
249 pub fn get_quantization_params(
251 &self,
252 symmetric: bool,
253 bits: u8,
254 percentile: f32,
255 ) -> Result<(f32, i32)> {
256 if self.count == 0 {
257 return Err(TrustformersError::quantization_error(
258 "No statistics available for quantization".into(),
259 ));
260 }
261
262 let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
263 let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };
264
265 let min_val = if percentile < 1.0 {
266 -self.percentile(1.0 - percentile)
268 } else {
269 self.min_val
270 };
271
272 let max_val = if percentile < 1.0 { self.percentile(percentile) } else { self.max_val };
273
274 let (scale, zero_point) = if symmetric {
275 let abs_max = max_val.abs().max(min_val.abs());
276 if abs_max == 0.0 {
277 return Ok((1.0, 0));
278 }
279 let scale = abs_max / (q_max - q_min) as f32;
280 (scale, 0)
281 } else {
282 if max_val == min_val {
283 return Ok((1.0, q_min));
284 }
285 let scale = (max_val - min_val) / (q_max - q_min) as f32;
286 let zero_point = q_min - (min_val / scale).round() as i32;
287 let zero_point = zero_point.clamp(q_min, q_max);
288 (scale, zero_point)
289 };
290
291 Ok((scale, zero_point))
292 }
293}
294
295impl QuantizedActivation {
296 pub fn new(
298 data: Vec<u8>,
299 scale: f32,
300 zero_point: i32,
301 shape: Vec<usize>,
302 scheme: ActivationQuantScheme,
303 bits: u8,
304 ) -> Self {
305 Self {
306 data,
307 scale,
308 zero_point,
309 shape,
310 scheme,
311 bits,
312 }
313 }
314
315 pub fn dequantize(&self) -> Result<Tensor> {
317 let total_elements: usize = self.shape.iter().product();
318 let mut result = Vec::with_capacity(total_elements);
319
320 match self.scheme {
321 ActivationQuantScheme::Int8 | ActivationQuantScheme::Dynamic => {
322 for &quantized_val in &self.data {
323 let int_val = quantized_val as i32 - self.zero_point;
324 let float_val = int_val as f32 * self.scale;
325 result.push(float_val);
326 }
327 },
328 ActivationQuantScheme::Int16 => {
329 for chunk in self.data.chunks(2) {
331 if chunk.len() == 2 {
332 let int16_val =
333 u16::from_le_bytes([chunk[0], chunk[1]]) as i32 - self.zero_point;
334 let float_val = int16_val as f32 * self.scale;
335 result.push(float_val);
336 }
337 }
338 },
339 ActivationQuantScheme::Adaptive => {
340 for &quantized_val in &self.data {
342 let int_val = quantized_val as i32 - self.zero_point;
343 let float_val = int_val as f32 * self.scale;
344 result.push(float_val);
345 }
346 },
347 }
348
349 Tensor::from_vec(result, &self.shape)
350 }
351}
352
impl ActivationQuantizer {
    /// Creates a quantizer from `config`, starting in calibration mode with
    /// no collected statistics.
    pub fn new(config: ActivationQuantConfig) -> Self {
        Self {
            config,
            layer_stats: HashMap::new(),
            calibrating: true,
            calibration_count: 0,
        }
    }

    /// Discards all collected statistics and re-enters calibration mode.
    pub fn start_calibration(&mut self) {
        self.calibrating = true;
        self.calibration_count = 0;
        self.layer_stats.clear();
    }

    /// Stops statistics collection; subsequent calls quantize using the
    /// statistics gathered so far.
    pub fn end_calibration(&mut self) {
        self.calibrating = false;
    }

    /// True once calibration was ended explicitly or enough samples have been
    /// observed.
    pub fn is_calibration_complete(&self) -> bool {
        !self.calibrating || self.calibration_count >= self.config.calibration_samples
    }

    /// Quantizes (fake-quantizes) `tensor` for `layer_name`, returning an f32
    /// tensor of the same shape.
    ///
    /// During calibration the tensor is returned unchanged (after updating
    /// statistics) until `calibration_samples` updates have been seen. When
    /// `training` is true and `quantize_during_training` is false, the tensor
    /// also passes through unchanged (still contributing to calibration).
    ///
    /// # Errors
    /// Fails for unsupported tensor types or when no statistics exist for the
    /// layer at quantization time.
    pub fn quantize_activation(
        &mut self,
        tensor: &Tensor,
        layer_name: &str,
        training: bool,
    ) -> Result<Tensor> {
        // Fall back to the default per-layer config when none was registered.
        let layer_config = self.config.layer_configs.get(layer_name).cloned().unwrap_or_default();

        if !layer_config.enabled {
            return Ok(tensor.clone());
        }

        // Training passes may still feed calibration even when quantization
        // itself is disabled during training.
        if training && !self.config.quantize_during_training {
            if self.calibrating && layer_config.calibrate {
                self.update_statistics(tensor, layer_name)?;
            }
            return Ok(tensor.clone());
        }

        if self.calibrating && layer_config.calibrate {
            self.update_statistics(tensor, layer_name)?;

            // NOTE(review): calibration_count is shared across all layers,
            // not tracked per layer — confirm this is intended.
            if self.calibration_count < self.config.calibration_samples {
                return Ok(tensor.clone());
            }
        }

        self.apply_quantization(tensor, layer_name, &layer_config)
    }

    /// Folds `tensor` into the layer's running statistics and bumps the
    /// global calibration counter.
    fn update_statistics(&mut self, tensor: &Tensor, layer_name: &str) -> Result<()> {
        let stats = self.layer_stats.entry(layer_name.to_string()).or_default();

        stats.update(tensor, self.config.ema_decay)?;
        self.calibration_count += 1;

        Ok(())
    }

    /// Resolves the effective scheme and bit width for the layer, derives the
    /// quantization parameters from its statistics, and dispatches to the
    /// matching quantization routine.
    fn apply_quantization(
        &self,
        tensor: &Tensor,
        layer_name: &str,
        layer_config: &LayerQuantConfig,
    ) -> Result<Tensor> {
        let stats = self.layer_stats.get(layer_name).ok_or_else(|| {
            TrustformersError::quantization_error(format!(
                "No calibration statistics found for layer {}",
                layer_name
            ))
        })?;

        // Per-layer overrides win; otherwise use the global scheme and a bit
        // width implied by the scheme (8-bit schemes -> 8, Int16 -> 16).
        let scheme = layer_config.scheme.unwrap_or(self.config.scheme);
        let bits = layer_config.bits.unwrap_or(match scheme {
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => 8,
            ActivationQuantScheme::Int16 => 16,
        });

        let (scale, zero_point) =
            stats.get_quantization_params(self.config.symmetric, bits, self.config.percentile)?;

        match scheme {
            ActivationQuantScheme::Int8 | ActivationQuantScheme::Dynamic => {
                self.quantize_int8(tensor, scale, zero_point)
            },
            ActivationQuantScheme::Int16 => self.quantize_int16(tensor, scale, zero_point),
            ActivationQuantScheme::Adaptive => {
                self.quantize_adaptive(tensor, stats, scale, zero_point)
            },
        }
    }

    /// Fake-quantizes through an unsigned 8-bit domain: quantize, clamp to
    /// [0, 255], then dequantize back to f32.
    ///
    /// NOTE(review): the clamp assumes the asymmetric/unsigned range; with a
    /// symmetric config (zero_point = 0) negative inputs collapse to the zero
    /// point — confirm this is intended.
    fn quantize_int8(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        let q_val = ((val / scale).round() as i32 + zero_point).clamp(0, 255) as u8;

                        (q_val as i32 - zero_point) as f32 * scale
                    })
                    .collect();

                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for activation quantization".into(),
            )),
        }
    }

    /// Fake-quantizes through an unsigned 16-bit domain ([0, 65535]).
    fn quantize_int16(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        let q_val =
                            ((val / scale).round() as i32 + zero_point).clamp(0, 65535) as u16;

                        (q_val as i32 - zero_point) as f32 * scale
                    })
                    .collect();

                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for activation quantization".into(),
            )),
        }
    }

    /// Adaptive fake quantization: halves the scale for low-variance
    /// distributions (finer resolution) and clips values beyond three
    /// standard deviations from the calibrated mean before quantizing.
    fn quantize_adaptive(
        &self,
        tensor: &Tensor,
        stats: &ActivationStats,
        scale: f32,
        zero_point: i32,
    ) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let variance = stats.variance();
                let mean = stats.mean();

                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        // Tight distributions get a finer scale.
                        let effective_scale = if variance < 0.1 {
                            scale * 0.5 } else {
                            scale
                        };

                        // 3-sigma outlier clipping around the calibrated mean.
                        let clipped_val = if (val - mean).abs() > 3.0 * variance.sqrt() {
                            if val > mean {
                                mean + 3.0 * variance.sqrt()
                            } else {
                                mean - 3.0 * variance.sqrt()
                            }
                        } else {
                            val
                        };

                        let q_val = ((clipped_val / effective_scale).round() as i32 + zero_point)
                            .clamp(0, 255) as u8;

                        (q_val as i32 - zero_point) as f32 * effective_scale
                    })
                    .collect();

                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for adaptive quantization".into(),
            )),
        }
    }

    /// Returns the calibration statistics for `layer_name`, if any.
    pub fn get_layer_stats(&self, layer_name: &str) -> Option<&ActivationStats> {
        self.layer_stats.get(layer_name)
    }

    /// Returns all per-layer calibration statistics.
    pub fn get_all_stats(&self) -> &HashMap<String, ActivationStats> {
        &self.layer_stats
    }

    /// Serializes the per-layer statistics to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Fails on serialization or file-write errors.
    pub fn save_calibration(&self, path: &str) -> Result<()> {
        let json_data = serde_json::to_string_pretty(&self.layer_stats).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to serialize statistics: {}", e))
        })?;

        std::fs::write(path, json_data).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to write file: {}", e))
        })?;

        Ok(())
    }

    /// Loads statistics previously written by `save_calibration` and leaves
    /// calibration mode.
    ///
    /// # Errors
    /// Fails on file-read or deserialization errors.
    pub fn load_calibration(&mut self, path: &str) -> Result<()> {
        let json_data = std::fs::read_to_string(path).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to read file: {}", e))
        })?;

        self.layer_stats = serde_json::from_str(&json_data).map_err(|e| {
            TrustformersError::quantization_error(format!(
                "Failed to deserialize statistics: {}",
                e
            ))
        })?;

        self.calibrating = false;
        Ok(())
    }

    /// Registers (or replaces) the per-layer configuration for `layer_name`.
    pub fn configure_layer(&mut self, layer_name: &str, config: LayerQuantConfig) {
        self.config.layer_configs.insert(layer_name.to_string(), config);
    }

    /// Disables quantization for `layer_name`; other settings stay default.
    pub fn disable_layer(&mut self, layer_name: &str) {
        let config = LayerQuantConfig {
            enabled: false,
            ..Default::default()
        };
        self.config.layer_configs.insert(layer_name.to_string(), config);
    }

    /// Fraction of activation memory saved relative to f32 storage
    /// (8-bit schemes: 0.75; 16-bit: 0.5).
    pub fn get_memory_savings(&self) -> f32 {
        match self.config.scheme {
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => 0.75, ActivationQuantScheme::Int16 => 0.5, }
    }
}
621
#[cfg(test)]
mod tests {
    use super::*;

    // Basic running statistics (count/min/max/mean) from a single update.
    #[test]
    fn test_activation_stats_update() {
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).expect("Tensor from_vec failed");

        let mut stats = ActivationStats::new();
        stats.update(&input, 0.01).expect("tensor operation failed");

        assert_eq!(stats.count, 5);
        assert_eq!(stats.min_val, 1.0);
        assert_eq!(stats.max_val, 5.0);
        assert_eq!(stats.mean(), 3.0);
    }

    // Symmetric 8-bit parameters: positive scale, zero point pinned at 0.
    #[test]
    fn test_activation_stats_quantization_params() {
        let input = Tensor::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5])
            .expect("Tensor from_vec failed");

        let mut stats = ActivationStats::new();
        stats.update(&input, 0.01).expect("tensor operation failed");

        let (scale, zero_point) =
            stats.get_quantization_params(true, 8, 1.0).expect("operation failed in test");
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
    }

    // After feeding calibration samples, per-layer statistics exist.
    #[test]
    fn test_activation_quantizer_calibration() {
        let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig {
            calibration_samples: 2,
            ..Default::default()
        });
        assert!(quantizer.calibrating);

        let first = Tensor::randn(&[10, 20]).expect("Failed to create random tensor");
        let second = Tensor::randn(&[10, 20]).expect("Failed to create random tensor");

        quantizer
            .quantize_activation(&first, "layer1", false)
            .expect("tensor operation failed");
        quantizer
            .quantize_activation(&second, "layer1", false)
            .expect("tensor operation failed");

        assert!(quantizer.get_layer_stats("layer1").is_some());
    }

    // INT8 fake quantization preserves the tensor shape.
    #[test]
    fn test_activation_quantizer_int8() {
        let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig {
            calibration_samples: 1,
            scheme: ActivationQuantScheme::Int8,
            ..Default::default()
        });

        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");

        // One calibration pass, then quantize for real.
        quantizer
            .quantize_activation(&input, "test_layer", false)
            .expect("tensor operation failed");
        quantizer.end_calibration();

        let result = quantizer
            .quantize_activation(&input, "test_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), input.shape());
    }

    // Per-layer overrides are registered and disabled layers pass through.
    #[test]
    fn test_activation_quantizer_layer_config() {
        let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig::default());

        quantizer.configure_layer(
            "special_layer",
            LayerQuantConfig {
                enabled: true,
                scheme: Some(ActivationQuantScheme::Int16),
                bits: Some(16),
                calibrate: true,
            },
        );
        quantizer.disable_layer("disabled_layer");

        let input = Tensor::randn(&[8, 8]).expect("Failed to create random tensor");

        let result = quantizer
            .quantize_activation(&input, "disabled_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), input.shape());
    }

    // Adaptive scheme handles outlier-heavy data and preserves shape.
    #[test]
    fn test_activation_quantizer_adaptive() {
        let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig {
            scheme: ActivationQuantScheme::Adaptive,
            calibration_samples: 1,
            ..Default::default()
        });

        let input = Tensor::from_vec(vec![0.1, 0.2, 0.15, 0.18, 10.0], &[5])
            .expect("Tensor from_vec failed");

        quantizer
            .quantize_activation(&input, "adaptive_layer", false)
            .expect("tensor operation failed");
        quantizer.end_calibration();

        let result = quantizer
            .quantize_activation(&input, "adaptive_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), input.shape());
    }

    // Dequantization reproduces the original tensor shape.
    #[test]
    fn test_quantized_activation_dequantization() {
        let _original_data = [1.0, 2.0, 3.0, 4.0];
        let shape = vec![4];

        let quantized_data = vec![64, 128, 192, 255];
        let scale = 4.0 / 255.0;
        let zero_point = 0;

        let quant_activation = QuantizedActivation::new(
            quantized_data,
            scale,
            zero_point,
            shape.clone(),
            ActivationQuantScheme::Int8,
            8,
        );

        let dequantized = quant_activation.dequantize().expect("Dequantization failed");
        assert_eq!(dequantized.shape(), shape);
    }

    // INT8 keeps 1 of 4 bytes per f32 element: 75% savings.
    #[test]
    fn test_memory_savings_calculation() {
        let quantizer = ActivationQuantizer::new(ActivationQuantConfig {
            scheme: ActivationQuantScheme::Int8,
            ..Default::default()
        });

        assert_eq!(quantizer.get_memory_savings(), 0.75);
    }

    // Percentile estimate for uniform 1..=100 lands near the true value.
    #[test]
    fn test_percentile_calculation() {
        let input = Tensor::from_vec((1..=100).map(|x| x as f32).collect(), &[100])
            .expect("tensor operation failed");

        let mut stats = ActivationStats::new();
        stats.update(&input, 0.01).expect("tensor operation failed");

        let p95 = stats.percentile(0.95);
        assert!((90.0..=100.0).contains(&p95));
    }

    // Config survives a serde JSON round trip.
    #[test]
    fn test_serialization() {
        let config = ActivationQuantConfig::default();
        let serialized = serde_json::to_string(&config).expect("JSON serialization failed");
        let deserialized: ActivationQuantConfig =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");

        assert_eq!(config.scheme, deserialized.scheme);
        assert_eq!(config.symmetric, deserialized.symmetric);
    }
}