
trustformers_core/quantization/activation.rs

//! Activation quantization for TrustformeRS
//!
//! This module provides activation quantization functionality, which quantizes intermediate
//! layer outputs during inference and training. Unlike weight quantization, which is applied
//! to model parameters, activation quantization is applied dynamically to the data flowing
//! through the network.

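//! # Example
//!
//! An illustrative calibrate-then-quantize flow (a sketch, not a doctest; it assumes
//! the `Tensor::randn` constructor that the tests below also use):
//!
//! ```ignore
//! let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig {
//!     calibration_samples: 8,
//!     ..Default::default()
//! });
//! for _batch in 0..8 {
//!     let acts = Tensor::randn(&[16, 64])?;
//!     // During calibration the original tensor is returned unchanged.
//!     quantizer.quantize_activation(&acts, "encoder.layer0", false)?;
//! }
//! quantizer.end_calibration();
//! // After calibration, activations are fake-quantized using the collected statistics.
//! let acts = Tensor::randn(&[16, 64])?;
//! let quantized = quantizer.quantize_activation(&acts, "encoder.layer0", false)?;
//! ```
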
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration for activation quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationQuantConfig {
    /// Quantization scheme for activations
    pub scheme: ActivationQuantScheme,
    /// Whether to use symmetric quantization
    pub symmetric: bool,
    /// Number of calibration samples used to collect statistics
    pub calibration_samples: usize,
    /// Percentile for outlier-aware quantization (e.g., 0.99)
    pub percentile: f32,
    /// Moving-average decay for running statistics
    pub ema_decay: f32,
    /// Whether to apply quantization during training
    pub quantize_during_training: bool,
    /// Layer-specific configurations
    pub layer_configs: HashMap<String, LayerQuantConfig>,
}

/// Activation quantization schemes
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationQuantScheme {
    /// 8-bit integer quantization
    Int8,
    /// 16-bit integer quantization
    Int16,
    /// Dynamic range quantization
    Dynamic,
    /// Adaptive quantization based on activation distribution
    Adaptive,
}

/// Layer-specific quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
    /// Whether to quantize this layer's activations
    pub enabled: bool,
    /// Custom quantization scheme for this layer
    pub scheme: Option<ActivationQuantScheme>,
    /// Custom bit width (overrides scheme if provided)
    pub bits: Option<u8>,
    /// Whether to use layer-specific calibration
    pub calibrate: bool,
}

/// Statistics for activation quantization calibration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationStats {
    /// Running minimum value
    pub min_val: f32,
    /// Running maximum value
    pub max_val: f32,
    /// Running sum for mean calculation
    pub sum: f64,
    /// Running sum of squares for variance calculation
    pub sum_squares: f64,
    /// Number of samples observed
    pub count: usize,
    /// Histogram for percentile calculation
    pub histogram: Vec<(f32, usize)>,
    /// EMA of minimum values
    pub ema_min: f32,
    /// EMA of maximum values
    pub ema_max: f32,
}

/// Quantized activation tensor
#[derive(Debug, Clone)]
pub struct QuantizedActivation {
    /// Quantized data
    pub data: Vec<u8>,
    /// Quantization scale
    pub scale: f32,
    /// Zero point for asymmetric quantization
    pub zero_point: i32,
    /// Original tensor shape
    pub shape: Vec<usize>,
    /// Quantization scheme used
    pub scheme: ActivationQuantScheme,
    /// Number of bits used
    pub bits: u8,
}

/// Activation quantization manager
pub struct ActivationQuantizer {
    config: ActivationQuantConfig,
    /// Layer statistics for calibration
    layer_stats: HashMap<String, ActivationStats>,
    /// Whether calibration phase is active
    calibrating: bool,
    /// Number of calibration samples seen
    calibration_count: usize,
}

impl Default for ActivationQuantConfig {
    fn default() -> Self {
        Self {
            scheme: ActivationQuantScheme::Int8,
            symmetric: false,
            calibration_samples: 100,
            percentile: 0.99,
            ema_decay: 0.01,
            quantize_during_training: false,
            layer_configs: HashMap::new(),
        }
    }
}

impl Default for LayerQuantConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            scheme: None,
            bits: None,
            calibrate: true,
        }
    }
}

impl Default for ActivationStats {
    fn default() -> Self {
        Self::new()
    }
}

impl ActivationStats {
    /// Create new activation statistics
    pub fn new() -> Self {
        Self {
            min_val: f32::INFINITY,
            max_val: f32::NEG_INFINITY,
            sum: 0.0,
            sum_squares: 0.0,
            count: 0,
            histogram: Vec::new(),
            ema_min: f32::INFINITY,
            ema_max: f32::NEG_INFINITY,
        }
    }

    /// Update statistics with new tensor
    pub fn update(&mut self, tensor: &Tensor, ema_decay: f32) -> Result<()> {
        match tensor {
            Tensor::F32(arr) => {
                let data: Vec<f32> = arr.iter().copied().collect();

                for &val in &data {
                    if !val.is_finite() {
                        continue; // Skip NaN/Inf values
                    }

                    self.min_val = self.min_val.min(val);
                    self.max_val = self.max_val.max(val);
                    self.sum += val as f64;
                    self.sum_squares += (val * val) as f64;
                    self.count += 1;

                    // Update EMA min/max
                    if self.ema_min.is_infinite() {
                        self.ema_min = val;
                        self.ema_max = val;
                    } else {
                        if val < self.ema_min {
                            self.ema_min = self.ema_min * (1.0 - ema_decay) + val * ema_decay;
                        }
                        if val > self.ema_max {
                            self.ema_max = self.ema_max * (1.0 - ema_decay) + val * ema_decay;
                        }
                    }
                }

                // Update histogram (simple binning for percentile calculation).
                // Note: counts accumulated in earlier updates are not re-binned when
                // the observed range grows, so the histogram is an approximation.
                let num_bins = 1000;
                let range = self.max_val - self.min_val;
                if range > 0.0 {
                    self.histogram.resize(num_bins, (0.0, 0));
                    for &val in &data {
                        if val.is_finite() {
                            let bin_idx =
                                ((val - self.min_val) / range * (num_bins - 1) as f32) as usize;
                            let bin_idx = bin_idx.min(num_bins - 1);
                            // Store the bin center as the representative value so that
                            // percentile lookups do not depend on insertion order.
                            self.histogram[bin_idx].0 =
                                self.min_val + (bin_idx as f32 + 0.5) * range / num_bins as f32;
                            self.histogram[bin_idx].1 += 1;
                        }
                    }
                }
            },
            _ => {
                return Err(TrustformersError::quantization_error(
                    "Unsupported tensor type for activation quantization".into(),
                ))
            },
        }

        Ok(())
    }

    /// Get mean of observed values
    pub fn mean(&self) -> f32 {
        if self.count == 0 {
            0.0
        } else {
            (self.sum / self.count as f64) as f32
        }
    }

    /// Get variance of observed values
    pub fn variance(&self) -> f32 {
        if self.count <= 1 {
            0.0
        } else {
            let mean = self.mean() as f64;
            let variance = (self.sum_squares / self.count as f64) - (mean * mean);
            variance.max(0.0) as f32
        }
    }

    /// Get percentile value from histogram
    pub fn percentile(&self, p: f32) -> f32 {
        if self.histogram.is_empty() || self.count == 0 {
            return self.max_val;
        }

        let target_count = (self.count as f32 * p) as usize;
        let mut cumulative_count = 0;

        for &(val, count) in &self.histogram {
            if count == 0 {
                continue; // Skip empty bins; their placeholder value is meaningless
            }
            cumulative_count += count;
            if cumulative_count >= target_count {
                return val;
            }
        }

        self.max_val
    }

    /// Get quantization parameters based on statistics
    pub fn get_quantization_params(
        &self,
        symmetric: bool,
        bits: u8,
        percentile: f32,
    ) -> Result<(f32, i32)> {
        if self.count == 0 {
            return Err(TrustformersError::quantization_error(
                "No statistics available for quantization".into(),
            ));
        }

        let q_min = if symmetric { -(1 << (bits - 1)) } else { 0 };
        let q_max = if symmetric { (1 << (bits - 1)) - 1 } else { (1 << bits) - 1 };

        let min_val = if percentile < 1.0 {
            // Use percentile-based clipping for outlier robustness
            self.percentile(1.0 - percentile)
        } else {
            self.min_val
        };

        let max_val = if percentile < 1.0 { self.percentile(percentile) } else { self.max_val };

        let (scale, zero_point) = if symmetric {
            let abs_max = max_val.abs().max(min_val.abs());
            if abs_max == 0.0 {
                return Ok((1.0, 0));
            }
            // Divide by q_max (not q_max - q_min) so that +/-abs_max maps onto the
            // signed range without clipping half of it.
            let scale = abs_max / q_max as f32;
            (scale, 0)
        } else {
            if max_val == min_val {
                return Ok((1.0, q_min));
            }
            let scale = (max_val - min_val) / (q_max - q_min) as f32;
            let zero_point = q_min - (min_val / scale).round() as i32;
            let zero_point = zero_point.clamp(q_min, q_max);
            (scale, zero_point)
        };

        Ok((scale, zero_point))
    }
}

impl QuantizedActivation {
    /// Create new quantized activation
    pub fn new(
        data: Vec<u8>,
        scale: f32,
        zero_point: i32,
        shape: Vec<usize>,
        scheme: ActivationQuantScheme,
        bits: u8,
    ) -> Self {
        Self {
            data,
            scale,
            zero_point,
            shape,
            scheme,
            bits,
        }
    }

    /// Dequantize back to float tensor
    pub fn dequantize(&self) -> Result<Tensor> {
        let total_elements: usize = self.shape.iter().product();
        let mut result = Vec::with_capacity(total_elements);

        match self.scheme {
            // Adaptive currently uses the same 8-bit layout as Int8/Dynamic
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => {
                for &quantized_val in &self.data {
                    let int_val = quantized_val as i32 - self.zero_point;
                    let float_val = int_val as f32 * self.scale;
                    result.push(float_val);
                }
            },
            ActivationQuantScheme::Int16 => {
                // For 16-bit, unpack two little-endian bytes per element
                for chunk in self.data.chunks(2) {
                    if chunk.len() == 2 {
                        let int16_val =
                            u16::from_le_bytes([chunk[0], chunk[1]]) as i32 - self.zero_point;
                        let float_val = int16_val as f32 * self.scale;
                        result.push(float_val);
                    }
                }
            },
        }

        Tensor::from_vec(result, &self.shape)
    }
}

impl ActivationQuantizer {
    /// Create new activation quantizer
    pub fn new(config: ActivationQuantConfig) -> Self {
        Self {
            config,
            layer_stats: HashMap::new(),
            calibrating: true,
            calibration_count: 0,
        }
    }

    /// Start calibration phase
    pub fn start_calibration(&mut self) {
        self.calibrating = true;
        self.calibration_count = 0;
        self.layer_stats.clear();
    }

    /// End calibration phase
    pub fn end_calibration(&mut self) {
        self.calibrating = false;
    }

    /// Check if calibration is complete
    pub fn is_calibration_complete(&self) -> bool {
        !self.calibrating || self.calibration_count >= self.config.calibration_samples
    }

    /// Quantize activation tensor for a specific layer
    pub fn quantize_activation(
        &mut self,
        tensor: &Tensor,
        layer_name: &str,
        training: bool,
    ) -> Result<Tensor> {
        // Get layer-specific configuration
        let layer_config = self.config.layer_configs.get(layer_name).cloned().unwrap_or_default();

        if !layer_config.enabled {
            return Ok(tensor.clone());
        }

        // Don't quantize during training unless explicitly enabled
        if training && !self.config.quantize_during_training {
            if self.calibrating && layer_config.calibrate {
                self.update_statistics(tensor, layer_name)?;
            }
            return Ok(tensor.clone());
        }

        // Update statistics during calibration
        if self.calibrating && layer_config.calibrate {
            self.update_statistics(tensor, layer_name)?;

            // Return original tensor during calibration
            if self.calibration_count < self.config.calibration_samples {
                return Ok(tensor.clone());
            }
        }

        // Apply quantization
        self.apply_quantization(tensor, layer_name, &layer_config)
    }

    /// Update statistics for a layer
    fn update_statistics(&mut self, tensor: &Tensor, layer_name: &str) -> Result<()> {
        let stats = self.layer_stats.entry(layer_name.to_string()).or_default();

        stats.update(tensor, self.config.ema_decay)?;
        self.calibration_count += 1;

        Ok(())
    }

    /// Apply quantization to tensor
    fn apply_quantization(
        &self,
        tensor: &Tensor,
        layer_name: &str,
        layer_config: &LayerQuantConfig,
    ) -> Result<Tensor> {
        let stats = self.layer_stats.get(layer_name).ok_or_else(|| {
            TrustformersError::quantization_error(format!(
                "No calibration statistics found for layer {}",
                layer_name
            ))
        })?;

        let scheme = layer_config.scheme.unwrap_or(self.config.scheme);
        let bits = layer_config.bits.unwrap_or(match scheme {
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => 8,
            ActivationQuantScheme::Int16 => 16,
        });

        let (scale, zero_point) =
            stats.get_quantization_params(self.config.symmetric, bits, self.config.percentile)?;

        match scheme {
            ActivationQuantScheme::Int8 | ActivationQuantScheme::Dynamic => {
                self.quantize_int8(tensor, scale, zero_point)
            },
            ActivationQuantScheme::Int16 => self.quantize_int16(tensor, scale, zero_point),
            ActivationQuantScheme::Adaptive => {
                self.quantize_adaptive(tensor, stats, scale, zero_point)
            },
        }
    }

    /// Quantize tensor to 8-bit integers (fake quantization: quantize, then dequantize)
    fn quantize_int8(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                // Clamp to the signed range for symmetric quantization so that
                // negative activations are not clipped away
                let (q_min, q_max) = if self.config.symmetric { (-128, 127) } else { (0, 255) };
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        let q_val =
                            ((val / scale).round() as i32 + zero_point).clamp(q_min, q_max);

                        (q_val - zero_point) as f32 * scale
                    })
                    .collect();

                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for activation quantization".into(),
            )),
        }
    }

    /// Quantize tensor to 16-bit integers (fake quantization: quantize, then dequantize)
    fn quantize_int16(&self, tensor: &Tensor, scale: f32, zero_point: i32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let (q_min, q_max) =
                    if self.config.symmetric { (-32768, 32767) } else { (0, 65535) };
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        let q_val =
                            ((val / scale).round() as i32 + zero_point).clamp(q_min, q_max);

                        (q_val - zero_point) as f32 * scale
                    })
                    .collect();

                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for activation quantization".into(),
            )),
        }
    }

    /// Adaptive quantization based on activation distribution
    fn quantize_adaptive(
        &self,
        tensor: &Tensor,
        stats: &ActivationStats,
        scale: f32,
        zero_point: i32,
    ) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let variance = stats.variance();
                let mean = stats.mean();
                let std_dev = variance.sqrt();
                let (q_min, q_max) = if self.config.symmetric { (-128, 127) } else { (0, 255) };

                // Use different quantization strategies based on distribution characteristics
                let quantized_data: Vec<f32> = arr
                    .iter()
                    .map(|&val| {
                        // For low-variance activations, use a finer effective scale
                        let effective_scale = if variance < 0.1 { scale * 0.5 } else { scale };

                        // Clip outliers more than three standard deviations from the mean
                        let clipped_val =
                            val.clamp(mean - 3.0 * std_dev, mean + 3.0 * std_dev);

                        let q_val = ((clipped_val / effective_scale).round() as i32 + zero_point)
                            .clamp(q_min, q_max);

                        (q_val - zero_point) as f32 * effective_scale
                    })
                    .collect();

                Tensor::from_vec(quantized_data, arr.shape())
            },
            _ => Err(TrustformersError::quantization_error(
                "Unsupported tensor type for adaptive quantization".into(),
            )),
        }
    }

    /// Get statistics for a specific layer
    pub fn get_layer_stats(&self, layer_name: &str) -> Option<&ActivationStats> {
        self.layer_stats.get(layer_name)
    }

    /// Get all layer statistics
    pub fn get_all_stats(&self) -> &HashMap<String, ActivationStats> {
        &self.layer_stats
    }

    /// Save calibration statistics to file
    pub fn save_calibration(&self, path: &str) -> Result<()> {
        let json_data = serde_json::to_string_pretty(&self.layer_stats).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to serialize statistics: {}", e))
        })?;

        std::fs::write(path, json_data).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to write file: {}", e))
        })?;

        Ok(())
    }

    /// Load calibration statistics from file
    pub fn load_calibration(&mut self, path: &str) -> Result<()> {
        let json_data = std::fs::read_to_string(path).map_err(|e| {
            TrustformersError::quantization_error(format!("Failed to read file: {}", e))
        })?;

        self.layer_stats = serde_json::from_str(&json_data).map_err(|e| {
            TrustformersError::quantization_error(format!(
                "Failed to deserialize statistics: {}",
                e
            ))
        })?;

        self.calibrating = false;
        Ok(())
    }

    /// Configure quantization for a specific layer
    pub fn configure_layer(&mut self, layer_name: &str, config: LayerQuantConfig) {
        self.config.layer_configs.insert(layer_name.to_string(), config);
    }

    /// Disable quantization for a specific layer
    pub fn disable_layer(&mut self, layer_name: &str) {
        let config = LayerQuantConfig {
            enabled: false,
            ..Default::default()
        };
        self.config.layer_configs.insert(layer_name.to_string(), config);
    }

    /// Get memory savings from activation quantization
    pub fn get_memory_savings(&self) -> f32 {
        // Estimate memory savings based on bit width
        match self.config.scheme {
            ActivationQuantScheme::Int8
            | ActivationQuantScheme::Dynamic
            | ActivationQuantScheme::Adaptive => 0.75, // 32-bit to 8-bit = 75% savings
            ActivationQuantScheme::Int16 => 0.5, // 32-bit to 16-bit = 50% savings
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_stats_update() {
        let mut stats = ActivationStats::new();
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).expect("Tensor from_vec failed");

        stats.update(&tensor, 0.01).expect("tensor operation failed");

        assert_eq!(stats.count, 5);
        assert_eq!(stats.min_val, 1.0);
        assert_eq!(stats.max_val, 5.0);
        assert_eq!(stats.mean(), 3.0);
    }
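
    // Added check (illustrative): `variance` is the population variance computed from
    // the running sums, so for 1..=5 it should be E[x^2] - mean^2 = 11 - 9 = 2.
    #[test]
    fn test_activation_stats_variance() {
        let mut stats = ActivationStats::new();
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).expect("Tensor from_vec failed");

        stats.update(&tensor, 0.01).expect("tensor operation failed");

        assert!((stats.variance() - 2.0).abs() < 1e-6);
    }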

    #[test]
    fn test_activation_stats_quantization_params() {
        let mut stats = ActivationStats::new();
        let tensor = Tensor::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5])
            .expect("Tensor from_vec failed");

        stats.update(&tensor, 0.01).expect("tensor operation failed");

        let (scale, zero_point) =
            stats.get_quantization_params(true, 8, 1.0).expect("operation failed in test");
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0); // Symmetric quantization
    }
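
    // Added check (illustrative): asymmetric 8-bit parameters for the range [-2, 2].
    // With q_min = 0 and q_max = 255, scale = (max - min) / 255 = 4/255 and
    // zero_point = -round(min / scale) = 128, so zero dequantizes exactly to 0.
    #[test]
    fn test_activation_stats_asymmetric_params() {
        let mut stats = ActivationStats::new();
        let tensor = Tensor::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0], &[5])
            .expect("Tensor from_vec failed");

        stats.update(&tensor, 0.01).expect("tensor operation failed");

        let (scale, zero_point) =
            stats.get_quantization_params(false, 8, 1.0).expect("operation failed in test");
        assert!((scale - 4.0 / 255.0).abs() < 1e-6);
        assert_eq!(zero_point, 128);
    }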

    #[test]
    fn test_activation_quantizer_calibration() {
        let config = ActivationQuantConfig {
            calibration_samples: 2,
            ..Default::default()
        };
        let mut quantizer = ActivationQuantizer::new(config);

        let tensor1 = Tensor::randn(&[10, 20]).expect("Failed to create random tensor");
        let tensor2 = Tensor::randn(&[10, 20]).expect("Failed to create random tensor");

        // Calibration phase
        assert!(quantizer.calibrating);
        quantizer
            .quantize_activation(&tensor1, "layer1", false)
            .expect("tensor operation failed");
        quantizer
            .quantize_activation(&tensor2, "layer1", false)
            .expect("tensor operation failed");

        // Should have statistics now
        assert!(quantizer.get_layer_stats("layer1").is_some());
    }
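
    // Added check (illustrative): `is_calibration_complete` reports completion once the
    // configured number of calibration samples has been observed, and
    // `start_calibration` resets the collected state.
    #[test]
    fn test_calibration_completion() {
        let config = ActivationQuantConfig {
            calibration_samples: 2,
            ..Default::default()
        };
        let mut quantizer = ActivationQuantizer::new(config);
        assert!(!quantizer.is_calibration_complete());

        let tensor = Tensor::randn(&[4, 4]).expect("Failed to create random tensor");
        quantizer
            .quantize_activation(&tensor, "layer1", false)
            .expect("tensor operation failed");
        quantizer
            .quantize_activation(&tensor, "layer1", false)
            .expect("tensor operation failed");
        assert!(quantizer.is_calibration_complete());

        quantizer.start_calibration();
        assert!(!quantizer.is_calibration_complete());
        assert!(quantizer.get_layer_stats("layer1").is_none());
    }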

    #[test]
    fn test_activation_quantizer_int8() {
        let config = ActivationQuantConfig {
            calibration_samples: 1,
            scheme: ActivationQuantScheme::Int8,
            ..Default::default()
        };

        let mut quantizer = ActivationQuantizer::new(config);

        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).expect("Tensor from_vec failed");

        // Calibrate
        quantizer
            .quantize_activation(&tensor, "test_layer", false)
            .expect("tensor operation failed");
        quantizer.end_calibration();

        // Quantize
        let result = quantizer
            .quantize_activation(&tensor, "test_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), tensor.shape());
    }
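
    // Added check (illustrative): with symmetric quantization and percentile clipping
    // disabled, an 8-bit round trip of [-1, 0, 1] should reproduce the values almost
    // exactly (scale = abs_max / 127), and in particular must preserve negative values.
    #[test]
    fn test_activation_quantizer_int8_symmetric() {
        let config = ActivationQuantConfig {
            calibration_samples: 1,
            scheme: ActivationQuantScheme::Int8,
            symmetric: true,
            percentile: 1.0,
            ..Default::default()
        };

        let mut quantizer = ActivationQuantizer::new(config);

        let tensor =
            Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).expect("Tensor from_vec failed");

        // Calibrate, then quantize
        quantizer
            .quantize_activation(&tensor, "sym_layer", false)
            .expect("tensor operation failed");
        quantizer.end_calibration();
        let result = quantizer
            .quantize_activation(&tensor, "sym_layer", false)
            .expect("tensor operation failed");

        if let (Tensor::F32(input), Tensor::F32(output)) = (&tensor, &result) {
            for (x, y) in input.iter().zip(output.iter()) {
                assert!((x - y).abs() < 0.02);
            }
        } else {
            panic!("Expected F32 tensors");
        }
    }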

    #[test]
    fn test_activation_quantizer_layer_config() {
        let config = ActivationQuantConfig::default();
        let mut quantizer = ActivationQuantizer::new(config);

        // Configure a specific layer
        let layer_config = LayerQuantConfig {
            enabled: true,
            scheme: Some(ActivationQuantScheme::Int16),
            bits: Some(16),
            calibrate: true,
        };
        quantizer.configure_layer("special_layer", layer_config);

        // Disable another layer
        quantizer.disable_layer("disabled_layer");

        let tensor = Tensor::randn(&[8, 8]).expect("Failed to create random tensor");

        // Disabled layer should return the original tensor (a clone, with no
        // quantization applied)
        let result = quantizer
            .quantize_activation(&tensor, "disabled_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), tensor.shape());
    }
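
    // Added check (illustrative): a disabled layer's activations should pass through
    // with their values bit-for-bit unchanged, not just with the same shape.
    #[test]
    fn test_disabled_layer_passthrough() {
        let mut quantizer = ActivationQuantizer::new(ActivationQuantConfig::default());
        quantizer.disable_layer("disabled_layer");

        let tensor = Tensor::randn(&[8, 8]).expect("Failed to create random tensor");
        let result = quantizer
            .quantize_activation(&tensor, "disabled_layer", false)
            .expect("tensor operation failed");

        if let (Tensor::F32(input), Tensor::F32(output)) = (&tensor, &result) {
            assert!(input.iter().zip(output.iter()).all(|(x, y)| x == y));
        } else {
            panic!("Expected F32 tensors");
        }
    }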

    #[test]
    fn test_activation_quantizer_adaptive() {
        let config = ActivationQuantConfig {
            scheme: ActivationQuantScheme::Adaptive,
            calibration_samples: 1,
            ..Default::default()
        };

        let mut quantizer = ActivationQuantizer::new(config);

        let tensor = Tensor::from_vec(vec![0.1, 0.2, 0.15, 0.18, 10.0], &[5])
            .expect("Tensor from_vec failed"); // One outlier

        // Calibrate
        quantizer
            .quantize_activation(&tensor, "adaptive_layer", false)
            .expect("tensor operation failed");
        quantizer.end_calibration();

        // Quantize with adaptive scheme
        let result = quantizer
            .quantize_activation(&tensor, "adaptive_layer", false)
            .expect("tensor operation failed");
        assert_eq!(result.shape(), tensor.shape());
    }

    #[test]
    fn test_quantized_activation_dequantization() {
        let shape = vec![4];

        // Simulated 8-bit quantized values, corresponding roughly to [1.0, 2.0, 3.0, 4.0]
        let quantized_data = vec![64, 128, 192, 255];
        let scale = 4.0 / 255.0; // Scale for range [0, 4]
        let zero_point = 0;

        let quant_activation = QuantizedActivation::new(
            quantized_data,
            scale,
            zero_point,
            shape.clone(),
            ActivationQuantScheme::Int8,
            8,
        );

        let dequantized = quant_activation.dequantize().expect("Dequantization failed");
        assert_eq!(dequantized.shape(), shape);
    }
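
    // Added check (illustrative): dequantization computes (q - zero_point) * scale, so
    // with scale = 0.1 and zero_point = 128, the bytes [0, 128, 255] map to
    // [-12.8, 0.0, 12.7].
    #[test]
    fn test_dequantization_values() {
        let quant_activation = QuantizedActivation::new(
            vec![0, 128, 255],
            0.1,
            128,
            vec![3],
            ActivationQuantScheme::Int8,
            8,
        );

        let dequantized = quant_activation.dequantize().expect("Dequantization failed");
        if let Tensor::F32(arr) = &dequantized {
            let values: Vec<f32> = arr.iter().copied().collect();
            let expected = [-12.8, 0.0, 12.7];
            for (v, e) in values.iter().zip(expected.iter()) {
                assert!((v - e).abs() < 1e-5);
            }
        } else {
            panic!("Expected an F32 tensor");
        }
    }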

    #[test]
    fn test_memory_savings_calculation() {
        let config = ActivationQuantConfig {
            scheme: ActivationQuantScheme::Int8,
            ..Default::default()
        };
        let quantizer = ActivationQuantizer::new(config);

        let savings = quantizer.get_memory_savings();
        assert_eq!(savings, 0.75); // 75% savings for int8
    }
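
    // Added check (illustrative): 16-bit activations halve the memory of 32-bit floats
    #[test]
    fn test_memory_savings_int16() {
        let config = ActivationQuantConfig {
            scheme: ActivationQuantScheme::Int16,
            ..Default::default()
        };
        let quantizer = ActivationQuantizer::new(config);

        assert_eq!(quantizer.get_memory_savings(), 0.5); // 50% savings for int16
    }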

    #[test]
    fn test_percentile_calculation() {
        let mut stats = ActivationStats::new();
        let tensor = Tensor::from_vec((1..=100).map(|x| x as f32).collect(), &[100])
            .expect("tensor operation failed");

        stats.update(&tensor, 0.01).expect("tensor operation failed");

        let p95 = stats.percentile(0.95);
        assert!((90.0..=100.0).contains(&p95)); // Should be around 95
    }

    #[test]
    fn test_serialization() {
        let config = ActivationQuantConfig::default();
        let serialized = serde_json::to_string(&config).expect("JSON serialization failed");
        let deserialized: ActivationQuantConfig =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");

        assert_eq!(config.scheme, deserialized.scheme);
        assert_eq!(config.symmetric, deserialized.symmetric);
    }
}