use crate::error::{Error, Result};
use ndarray::{ArrayD, ArrayView, Zip};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

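/// Configuration for tensor quantization.
///
/// A minimal usage sketch (illustrative only; the import path `crate::quant`
/// is a hypothetical location for this module):
///
/// ```ignore
/// use crate::quant::{QuantizationConfig, QuantizationScheme};
///
/// // 4-bit asymmetric quantization, otherwise default settings.
/// let config = QuantizationConfig {
///     bits: 4,
///     scheme: QuantizationScheme::Asymmetric,
///     ..Default::default()
/// };
/// ```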
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Target bit width (e.g. 8 for int8).
    pub bits: u8,
    /// Whether the quantized representation is signed.
    pub signed: bool,
    /// Symmetric, asymmetric, or power-of-two scaling.
    pub scheme: QuantizationScheme,
    /// Number of samples used during calibration.
    pub calibration_size: usize,
    /// Static, dynamic, or quantization-aware-training mode.
    pub mode: QuantizationMode,
    /// Quantize each channel with its own parameters.
    pub per_channel: bool,
    /// Fraction of the observed value range to keep (outlier clipping).
    pub range_clipping: f32,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            bits: 8,
            signed: true,
            scheme: QuantizationScheme::Symmetric,
            calibration_size: 1000,
            mode: QuantizationMode::Static,
            per_channel: false,
            range_clipping: 0.999,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Zero-centered range; `zero_point` is always 0.
    Symmetric,
    /// Affine mapping with a non-zero `zero_point`.
    Asymmetric,
    /// Scale restricted to powers of two (shift-friendly on hardware).
    PowerOfTwo,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
    /// Parameters fixed after offline calibration.
    Static,
    /// Parameters computed per tensor at inference time.
    Dynamic,
    /// Quantization-aware training (fake quantization).
    QAT,
}

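/// Scale/zero-point parameters for a quantized tensor.
///
/// A sketch of the affine mapping these parameters encode (this mirrors the
/// formulas used by `quantize_tensor` and `dequantize` below):
///
/// ```ignore
/// // q = clamp(round(x / scale) + zero_point, qmin, qmax)
/// // x ≈ (q - zero_point) * scale
/// let params = QuantizationParams::new(8, true);
/// assert_eq!((params.qmin, params.qmax), (-128, 127));
/// ```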
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Step size between adjacent quantization levels.
    pub scale: f32,
    /// Integer offset mapping the real value 0.0 into the quantized grid.
    pub zero_point: i32,
    /// Bit width these parameters were derived for.
    pub bits: u8,
    /// Smallest representable quantized value.
    pub qmin: i32,
    /// Largest representable quantized value.
    pub qmax: i32,
}

impl QuantizationParams {
    pub fn new(bits: u8, signed: bool) -> Self {
        let (qmin, qmax) = if signed {
            (-(1 << (bits - 1)), (1 << (bits - 1)) - 1)
        } else {
            (0, (1 << bits) - 1)
        };

        Self {
            scale: 1.0,
            zero_point: 0,
            bits,
            qmin,
            qmax,
        }
    }

    pub fn from_tensor(
        tensor: &ArrayView<f32, ndarray::IxDyn>,
        config: &QuantizationConfig,
    ) -> Result<Self> {
        let mut params = Self::new(config.bits, config.signed);

        let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Clip the observed range symmetrically around its center to reduce
        // the influence of outliers.
        let range = max_val - min_val;
        let clipped_range = range * config.range_clipping;
        let center = (max_val + min_val) / 2.0;
        let clipped_min = center - clipped_range / 2.0;
        let clipped_max = center + clipped_range / 2.0;

        match config.scheme {
            QuantizationScheme::Symmetric => {
                let abs_max = clipped_max.abs().max(clipped_min.abs());
                params.scale = (2.0 * abs_max) / (params.qmax - params.qmin) as f32;
                params.zero_point = 0;
            }
            QuantizationScheme::Asymmetric => {
                params.scale = (clipped_max - clipped_min) / (params.qmax - params.qmin) as f32;
                params.zero_point = params.qmin - (clipped_min / params.scale).round() as i32;
            }
            QuantizationScheme::PowerOfTwo => {
                let abs_max = clipped_max.abs().max(clipped_min.abs());
                let scale_log2 = (abs_max / (1 << (config.bits - 1)) as f32).log2().ceil();
                params.scale = 2.0_f32.powf(scale_log2);
                params.zero_point = 0;
            }
        }

        Ok(params)
    }
}

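/// An `i8` tensor plus the parameters needed to dequantize it.
///
/// Round-trip sketch (illustrative; `array!` comes from `ndarray`):
///
/// ```ignore
/// use ndarray::array;
///
/// let tensor = array![[1.0_f32, -1.0], [2.0, -2.0]].into_dyn();
/// let q = QuantizedTensor::from_float(&tensor, &QuantizationConfig::default())?;
/// let restored = q.dequantize();
/// // `restored` approximates `tensor` up to quantization error.
/// ```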
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    pub data: ArrayD<i8>,
    pub params: QuantizationParams,
    pub shape: Vec<usize>,
}

impl QuantizedTensor {
    pub fn from_float(tensor: &ArrayD<f32>, config: &QuantizationConfig) -> Result<Self> {
        let params = QuantizationParams::from_tensor(&tensor.view(), config)?;
        let quantized_data = Self::quantize_tensor(tensor, &params)?;

        Ok(Self {
            data: quantized_data,
            params,
            shape: tensor.shape().to_vec(),
        })
    }

    fn quantize_tensor(tensor: &ArrayD<f32>, params: &QuantizationParams) -> Result<ArrayD<i8>> {
        let quantized = tensor.mapv(|x| {
            let q_val = (x / params.scale).round() + params.zero_point as f32;
            let clamped = q_val.max(params.qmin as f32).min(params.qmax as f32);
            clamped as i8
        });

        Ok(quantized)
    }

    pub fn dequantize(&self) -> ArrayD<f32> {
        self.data
            .mapv(|q| (q as f32 - self.params.zero_point as f32) * self.params.scale)
    }

    pub fn size_bytes(&self) -> usize {
        // One byte per i8 element plus the (small) parameter struct.
        self.data.len() + std::mem::size_of::<QuantizationParams>()
    }

    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.size_bytes();
        original_size as f32 / quantized_size as f32
    }
}

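/// Post-training quantizer that calibrates per-tensor parameters from
/// observed activation statistics via a KL-divergence threshold search.
///
/// Typical workflow (sketch; `"layer1"`, `calibration_batches`, and `weights`
/// are placeholder names):
///
/// ```ignore
/// let mut ptq = PostTrainingQuantizer::new(QuantizationConfig::default());
/// for batch in &calibration_batches {
///     ptq.add_calibration_data("layer1", batch);
/// }
/// let params = ptq.finalize_calibration()?;
/// let q = ptq.quantize_tensor(&weights, &params["layer1"])?;
/// ```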
#[derive(Debug)]
pub struct PostTrainingQuantizer {
    config: QuantizationConfig,
    calibration_stats: HashMap<String, TensorStats>,
}

#[derive(Debug, Clone)]
struct TensorStats {
    min: f32,
    max: f32,
    mean: f32,
    std: f32,
    histogram: Vec<u32>,
}

impl TensorStats {
    fn new() -> Self {
        Self {
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            mean: 0.0,
            std: 0.0,
            histogram: vec![0; 256],
        }
    }

    fn update(&mut self, tensor: &ArrayView<f32, ndarray::IxDyn>) {
        // Fold-based min/max avoids panicking on NaN values or empty tensors.
        self.min = tensor.iter().cloned().fold(self.min, f32::min);
        self.max = tensor.iter().cloned().fold(self.max, f32::max);

        let sum: f32 = tensor.sum();
        let count = tensor.len() as f32;
        self.mean = sum / count;

        let variance: f32 = tensor.iter().map(|&x| (x - self.mean).powi(2)).sum::<f32>() / count;
        self.std = variance.sqrt();

        // Skip the histogram when the range is degenerate to avoid NaN bins.
        if self.max > self.min {
            for &val in tensor.iter() {
                let normalized =
                    ((val - self.min) / (self.max - self.min) * 255.0).round() as usize;
                self.histogram[normalized.min(255)] += 1;
            }
        }
    }
}

impl PostTrainingQuantizer {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            calibration_stats: HashMap::new(),
        }
    }

    pub fn add_calibration_data(&mut self, name: &str, tensor: &ArrayD<f32>) {
        let stats = self
            .calibration_stats
            .entry(name.to_string())
            .or_insert_with(TensorStats::new);
        stats.update(&tensor.view());
    }

    pub fn finalize_calibration(&mut self) -> Result<HashMap<String, QuantizationParams>> {
        let mut params_map = HashMap::new();

        for (name, stats) in &self.calibration_stats {
            let optimal_params = self.compute_optimal_params(stats)?;
            params_map.insert(name.clone(), optimal_params);
        }

        Ok(params_map)
    }

    fn compute_optimal_params(&self, stats: &TensorStats) -> Result<QuantizationParams> {
        let mut best_params = QuantizationParams::new(self.config.bits, self.config.signed);
        let mut best_kl_div = f32::INFINITY;

        // Sweep candidate clipping thresholds over the upper half of the
        // histogram and keep the one with the lowest KL divergence.
        for threshold_idx in 128..=255 {
            let threshold = stats.min + (threshold_idx as f32 / 255.0) * (stats.max - stats.min);

            let mut params = QuantizationParams::new(self.config.bits, self.config.signed);

            match self.config.scheme {
                QuantizationScheme::Symmetric => {
                    params.scale = (2.0 * threshold) / (params.qmax - params.qmin) as f32;
                    params.zero_point = 0;
                }
                QuantizationScheme::Asymmetric => {
                    params.scale = (threshold - stats.min) / (params.qmax - params.qmin) as f32;
                    params.zero_point = params.qmin - (stats.min / params.scale).round() as i32;
                }
                QuantizationScheme::PowerOfTwo => {
                    let scale_log2 = (threshold / (1 << (self.config.bits - 1)) as f32)
                        .log2()
                        .ceil();
                    params.scale = 2.0_f32.powf(scale_log2);
                    params.zero_point = 0;
                }
            }

            let kl_div = self.compute_kl_divergence(&stats.histogram, &params);

            if kl_div < best_kl_div {
                best_kl_div = kl_div;
                best_params = params;
            }
        }

        Ok(best_params)
    }

    fn compute_kl_divergence(&self, histogram: &[u32], params: &QuantizationParams) -> f32 {
        let total_count: u32 = histogram.iter().sum();
        if total_count == 0 {
            return 0.0;
        }

        let mut kl_div = 0.0;
        for (i, &count) in histogram.iter().enumerate() {
            if count > 0 {
                let p = count as f32 / total_count as f32;

                // Map the bin back to a normalized value, quantize it, and
                // look up the probability of the bin it lands in.
                let bin_value = i as f32 / 255.0;
                let quantized = (bin_value / params.scale)
                    .round()
                    .max(params.qmin as f32)
                    .min(params.qmax as f32);
                let dequantized = quantized * params.scale;

                let q = (dequantized * 255.0).round() as usize;
                let q_count = if q < histogram.len() { histogram[q] } else { 1 };
                let q_prob = (q_count as f32 / total_count as f32).max(1e-8);

                kl_div += p * (p / q_prob).ln();
            }
        }

        kl_div
    }

    pub fn quantize_tensor(
        &self,
        tensor: &ArrayD<f32>,
        params: &QuantizationParams,
    ) -> Result<QuantizedTensor> {
        let quantized_data = QuantizedTensor::quantize_tensor(tensor, params)?;

        Ok(QuantizedTensor {
            data: quantized_data,
            params: params.clone(),
            shape: tensor.shape().to_vec(),
        })
    }
}

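/// Fake-quantization helper for quantization-aware training: after a warm-up
/// period it simulates quantization in the forward pass while tracking
/// scale/zero-point with an exponential moving average.
///
/// Sketch of one training step (illustrative; `activations` is a placeholder):
///
/// ```ignore
/// let mut qat = QuantizationAwareTraining::new(QuantizationConfig::default());
/// qat.set_warmup_steps(100);
/// qat.init_layer_params("layer1", &activations)?;
/// // After warm-up, `simulated` carries the quantization error forward so
/// // downstream layers train against it.
/// let simulated = qat.fake_quantize("layer1", &activations)?;
/// ```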
#[derive(Debug)]
pub struct QuantizationAwareTraining {
    config: QuantizationConfig,
    layer_params: HashMap<String, QuantizationParams>,
    step_count: usize,
    warmup_steps: usize,
}

impl QuantizationAwareTraining {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            layer_params: HashMap::new(),
            step_count: 0,
            warmup_steps: 1000,
        }
    }

    pub fn set_warmup_steps(&mut self, steps: usize) {
        self.warmup_steps = steps;
    }

    pub fn init_layer_params(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<()> {
        let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
        self.layer_params.insert(layer_name.to_string(), params);
        Ok(())
    }

    pub fn fake_quantize(&mut self, layer_name: &str, tensor: &ArrayD<f32>) -> Result<ArrayD<f32>> {
        self.step_count += 1;

        // During warm-up, pass the tensor through unchanged.
        if self.step_count < self.warmup_steps {
            return Ok(tensor.clone());
        }

        let params = self.layer_params.get_mut(layer_name).ok_or_else(|| {
            Error::InvalidArgument(format!("Layer {} not initialized", layer_name))
        })?;

        // Track the observed range with an exponential moving average.
        let new_params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
        let alpha = 0.01;
        params.scale = params.scale * (1.0 - alpha) + new_params.scale * alpha;
        if self.config.scheme == QuantizationScheme::Asymmetric {
            params.zero_point = ((params.zero_point as f32) * (1.0 - alpha)
                + (new_params.zero_point as f32) * alpha)
                .round() as i32;
        }

        // Quantize and immediately dequantize to simulate quantization error.
        let quantized = QuantizedTensor::quantize_tensor(tensor, params)?;
        let dequantized = quantized.mapv(|q| (q as f32 - params.zero_point as f32) * params.scale);

        Ok(dequantized)
    }

    pub fn get_quantization_params(&self) -> &HashMap<String, QuantizationParams> {
        &self.layer_params
    }

    pub fn add_quantization_noise(&self, tensor: &ArrayD<f32>, noise_scale: f32) -> ArrayD<f32> {
        use rand::Rng;
        let mut rng = rand::rng();

        tensor.mapv(|x| {
            // Uniform noise in [-0.5, 0.5), scaled by `noise_scale`.
            let noise = rng.random::<f32>() - 0.5;
            x + noise * noise_scale
        })
    }
}

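/// Assigns different bit widths to different layers based on simple
/// sensitivity heuristics (variance, entropy, and a gradient-norm proxy).
///
/// Sketch (illustrative; `layer_outputs` is a placeholder
/// `HashMap<String, ArrayD<f32>>` of recorded activations):
///
/// ```ignore
/// let mut mbq = MixedBitWidthQuantizer::new();
/// mbq.analyze_sensitivity(&layer_outputs)?;
/// // Most sensitive third of layers -> 8 bits, middle third -> 6, rest -> 4.
/// let cfg = mbq.get_layer_config("layer1");
/// ```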
#[derive(Debug)]
pub struct MixedBitWidthQuantizer {
    layer_configs: HashMap<String, QuantizationConfig>,
    sensitivity_scores: HashMap<String, f32>,
}

impl Default for MixedBitWidthQuantizer {
    fn default() -> Self {
        Self::new()
    }
}

impl MixedBitWidthQuantizer {
    pub fn new() -> Self {
        Self {
            layer_configs: HashMap::new(),
            sensitivity_scores: HashMap::new(),
        }
    }

    pub fn set_layer_config(&mut self, layer_name: &str, config: QuantizationConfig) {
        self.layer_configs.insert(layer_name.to_string(), config);
    }

    pub fn analyze_sensitivity(
        &mut self,
        layer_outputs: &HashMap<String, ArrayD<f32>>,
    ) -> Result<()> {
        for (layer_name, output) in layer_outputs {
            let variance = self.compute_variance(output);
            let entropy = self.compute_entropy(output);
            let gradient_norm = self.compute_gradient_norm(output);

            let sensitivity = variance * 0.4 + entropy * 0.3 + gradient_norm * 0.3;
            self.sensitivity_scores
                .insert(layer_name.clone(), sensitivity);
        }

        self.assign_bit_widths()?;

        Ok(())
    }

    fn compute_variance(&self, tensor: &ArrayD<f32>) -> f32 {
        let mean = tensor.mean().unwrap_or(0.0);
        tensor.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / tensor.len() as f32
    }

    fn compute_entropy(&self, tensor: &ArrayD<f32>) -> f32 {
        let mut histogram = vec![0; 256];
        let min_val = tensor.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = tensor.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let range = max_val - min_val;

        if range == 0.0 {
            return 0.0;
        }

        for &val in tensor.iter() {
            let bin = ((val - min_val) / range * 255.0).round() as usize;
            let bin = bin.min(255);
            histogram[bin] += 1;
        }

        let total = tensor.len() as f32;
        let mut entropy = 0.0;
        for count in histogram {
            if count > 0 {
                let p = count as f32 / total;
                entropy -= p * p.ln();
            }
        }

        entropy
    }

    fn compute_gradient_norm(&self, tensor: &ArrayD<f32>) -> f32 {
        // Placeholder heuristic: accumulates `shape[axis] - 1` per axis, a
        // shape-based proxy rather than a true gradient norm.
        let mut grad_norm = 0.0;
        for axis in 0..tensor.ndim() {
            if tensor.shape()[axis] > 1 {
                for _i in 0..tensor.shape()[axis] - 1 {
                    grad_norm += 1.0;
                }
            }
        }
        grad_norm / tensor.len() as f32
    }

    fn assign_bit_widths(&mut self) -> Result<()> {
        let mut scores: Vec<(String, f32)> = self
            .sensitivity_scores
            .iter()
            .map(|(name, &score)| (name.clone(), score))
            .collect();

        // Most sensitive layers first.
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        for (i, (layer_name, _)) in scores.iter().enumerate() {
            // Top third keeps 8 bits, middle third gets 6, the rest 4.
            let bits = if i < scores.len() / 3 {
                8
            } else if i < 2 * scores.len() / 3 {
                6
            } else {
                4
            };

            let mut config = self
                .layer_configs
                .get(layer_name)
                .cloned()
                .unwrap_or_default();
            config.bits = bits;
            self.layer_configs.insert(layer_name.clone(), config);
        }

        Ok(())
    }

    pub fn get_layer_config(&self, layer_name: &str) -> Option<&QuantizationConfig> {
        self.layer_configs.get(layer_name)
    }

    pub fn get_sensitivity_score(&self, layer_name: &str) -> Option<f32> {
        self.sensitivity_scores.get(layer_name).copied()
    }
}

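/// Quantizes tensors on the fly, optionally caching parameters per key so
/// repeated tensors (e.g. the same layer across steps) reuse calibration.
///
/// Sketch (illustrative; `"encoder.0"` is a placeholder cache key):
///
/// ```ignore
/// let mut dq = DynamicQuantizer::new(QuantizationConfig::default());
/// let q1 = dq.quantize(&tensor, Some("encoder.0"))?; // computes params
/// let q2 = dq.quantize(&tensor, Some("encoder.0"))?; // cache hit
/// ```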
#[derive(Debug)]
pub struct DynamicQuantizer {
    config: QuantizationConfig,
    params_cache: HashMap<String, QuantizationParams>,
    cache_size_limit: usize,
}

impl DynamicQuantizer {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            params_cache: HashMap::new(),
            cache_size_limit: 100,
        }
    }

    pub fn quantize(
        &mut self,
        tensor: &ArrayD<f32>,
        cache_key: Option<&str>,
    ) -> Result<QuantizedTensor> {
        let params = if let Some(key) = cache_key {
            if let Some(cached_params) = self.params_cache.get(key) {
                cached_params.clone()
            } else {
                let params = QuantizationParams::from_tensor(&tensor.view(), &self.config)?;
                self.cache_params(key, params.clone());
                params
            }
        } else {
            QuantizationParams::from_tensor(&tensor.view(), &self.config)?
        };

        let quantized_data = QuantizedTensor::quantize_tensor(tensor, &params)?;

        Ok(QuantizedTensor {
            data: quantized_data,
            params,
            shape: tensor.shape().to_vec(),
        })
    }

    fn cache_params(&mut self, key: &str, params: QuantizationParams) {
        if self.params_cache.len() >= self.cache_size_limit {
            // Evict an arbitrary entry (HashMap iteration order is unspecified).
            if let Some(first_key) = self.params_cache.keys().next().cloned() {
                self.params_cache.remove(&first_key);
            }
        }
        self.params_cache.insert(key.to_string(), params);
    }

    pub fn clear_cache(&mut self) {
        self.params_cache.clear();
    }

    pub fn cache_stats(&self) -> (usize, usize) {
        (self.params_cache.len(), self.cache_size_limit)
    }
}

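/// Helper routines for measuring quantization quality and cost.
///
/// Sketch (illustrative; `original` and `quantized` are placeholders):
///
/// ```ignore
/// let err = utils::compute_quantization_error(&original, &quantized); // MSE
/// let x = utils::estimate_size_reduction(8); // 32-bit -> 8-bit = 4.0
/// ```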
pub mod utils {
    use super::*;

    pub fn compute_quantization_error(original: &ArrayD<f32>, quantized: &QuantizedTensor) -> f32 {
        let dequantized = quantized.dequantize();
        let mse = Zip::from(original)
            .and(&dequantized)
            .fold(0.0, |acc, &orig, &deq| acc + (orig - deq).powi(2));
        mse / original.len() as f32
    }

    pub fn estimate_size_reduction(bit_width: u8) -> f32 {
        32.0 / bit_width as f32
    }

    pub fn estimate_performance_gain(bit_width: u8) -> f32 {
        match bit_width {
            8 => 2.0,
            4 => 4.0,
            1 => 16.0,
            _ => 1.0,
        }
    }

    pub fn convert_quantization_scheme(
        tensor: &QuantizedTensor,
        target_scheme: QuantizationScheme,
        target_bits: u8,
    ) -> Result<QuantizedTensor> {
        let float_tensor = tensor.dequantize();

        let config = QuantizationConfig {
            scheme: target_scheme,
            bits: target_bits,
            ..Default::default()
        };

        QuantizedTensor::from_float(&float_tensor, &config)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};
    use ndarray_rand::rand::distributions::Standard;
    use ndarray_rand::RandomExt;

    #[test]
    fn test_quantization_config_default() {
        let config = QuantizationConfig::default();
        assert_eq!(config.bits, 8);
        assert!(config.signed);
        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_quantization_params_creation() {
        let params = QuantizationParams::new(8, true);
        assert_eq!(params.bits, 8);
        assert_eq!(params.qmin, -128);
        assert_eq!(params.qmax, 127);
    }

    #[test]
    fn test_symmetric_quantization() {
        let tensor = array![[1.0, -1.0], [2.0, -2.0]].into_dyn();
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Symmetric,
            ..Default::default()
        };

        let quantized = QuantizedTensor::from_float(&tensor, &config).unwrap();
        let _dequantized = quantized.dequantize();

        let error = utils::compute_quantization_error(&tensor, &quantized);
        assert!(error < 0.1);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let tensor = array![[0.0, 1.0], [2.0, 3.0]].into_dyn();
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Asymmetric,
            ..Default::default()
        };

        let quantized = QuantizedTensor::from_float(&tensor, &config).unwrap();
        let _dequantized = quantized.dequantize();

        assert!(quantized.params.zero_point != 0);
        let error = utils::compute_quantization_error(&tensor, &quantized);
        assert!(error < 0.1);
    }

    #[test]
    fn test_post_training_quantization() {
        let mut ptq = PostTrainingQuantizer::new(QuantizationConfig::default());

        let calib_data = Array2::random((100, 50), Standard).into_dyn();
        ptq.add_calibration_data("layer1", &calib_data);

        let params = ptq.finalize_calibration().unwrap();
        assert!(params.contains_key("layer1"));
    }

    #[test]
    fn test_quantization_aware_training() {
        let mut qat = QuantizationAwareTraining::new(QuantizationConfig::default());
        let tensor = Array2::ones((10, 10)).into_dyn();

        qat.init_layer_params("layer1", &tensor).unwrap();
        let fake_quantized = qat.fake_quantize("layer1", &tensor).unwrap();

        assert_eq!(fake_quantized.shape(), tensor.shape());
    }

    #[test]
    fn test_mixed_bitwidth_quantization() {
        let mut mbq = MixedBitWidthQuantizer::new();

        let mut outputs = HashMap::new();
        outputs.insert(
            "layer1".to_string(),
            Array2::random((50, 50), Standard).into_dyn(),
        );
        outputs.insert("layer2".to_string(), Array2::ones((50, 50)).into_dyn());

        mbq.analyze_sensitivity(&outputs).unwrap();

        assert!(mbq.get_sensitivity_score("layer1").is_some());
        assert!(mbq.get_layer_config("layer1").is_some());
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantizer::new(QuantizationConfig::default());
        let tensor = Array2::random((20, 20), Standard).into_dyn();

        let quantized = dq.quantize(&tensor, Some("test_key")).unwrap();
        assert_eq!(quantized.shape, tensor.shape().to_vec());

        let (cache_size, _) = dq.cache_stats();
        assert_eq!(cache_size, 1);
    }

    #[test]
    fn test_quantization_utilities() {
        let original = Array2::random((10, 10), Standard).into_dyn();
        let quantized =
            QuantizedTensor::from_float(&original, &QuantizationConfig::default()).unwrap();

        let error = utils::compute_quantization_error(&original, &quantized);
        assert!(error >= 0.0);

        let size_reduction = utils::estimate_size_reduction(8);
        assert_eq!(size_reduction, 4.0);

        let perf_gain = utils::estimate_performance_gain(8);
        assert_eq!(perf_gain, 2.0);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = Array2::ones((100, 100)).into_dyn();
        let quantized =
            QuantizedTensor::from_float(&tensor, &QuantizationConfig::default()).unwrap();

        let ratio = quantized.compression_ratio();
        assert!(ratio > 1.0);
    }

    #[test]
    fn test_power_of_two_quantization() {
        let tensor = Array2::random((10, 10), Standard).into_dyn();
        let config = QuantizationConfig {
            scheme: QuantizationScheme::PowerOfTwo,
            ..Default::default()
        };

        let quantized = QuantizedTensor::from_float(&tensor, &config).unwrap();

        let scale_log2 = quantized.params.scale.log2();
        assert!((scale_log2.round() - scale_log2).abs() < 1e-6);
    }

    #[test]
    fn test_quantization_scheme_conversion() {
        let tensor = Array2::random((10, 10), Standard).into_dyn();
        let quantized =
            QuantizedTensor::from_float(&tensor, &QuantizationConfig::default()).unwrap();

        let converted =
            utils::convert_quantization_scheme(&quantized, QuantizationScheme::Asymmetric, 4)
                .unwrap();

        assert_eq!(converted.params.bits, 4);
    }
}