1#![allow(unused_variables)] use crate::errors::Result;
10use crate::quantization::base::QuantizationScheme;
11use crate::tensor::Tensor;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Top-level configuration for mixed-bit quantization: per-layer overrides,
/// a fallback layer config, sensitivity-analysis settings, and an optional
/// strategy for automatically choosing per-element bit widths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedBitConfig {
    /// Per-layer configuration overrides, keyed by layer name.
    pub layer_configs: HashMap<String, LayerQuantConfig>,
    /// Fallback used for layers with no entry in `layer_configs`.
    pub default_config: LayerQuantConfig,
    /// Settings for the sensitivity analyzer that scores elements.
    pub sensitivity_config: SensitivityConfig,
    /// When `Some`, bit widths are allocated automatically per element;
    /// when `None`, every element uses the layer's `weight_bits`.
    pub auto_bit_allocation: Option<AutoBitAllocationStrategy>,
}
27
/// Quantization settings for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
    /// Default bit width for weight elements; also the baseline for the
    /// auto-allocation bit budget.
    pub weight_bits: u8,
    /// Bit width for activations — not consumed in this module; presumably
    /// read by the runtime side. TODO confirm.
    pub activation_bits: u8,
    /// Quantization scheme identifier — not consumed in this module.
    pub scheme: QuantizationScheme,
    /// `true` selects symmetric quantization (zero point at mid-range);
    /// `false` selects asymmetric (affine) quantization.
    pub symmetric: bool,
    /// Optional group size — not consumed in this module; grouping here is
    /// by assigned bit width instead. TODO confirm intended use.
    pub group_size: Option<usize>,
    /// Optional per-channel bit widths — not consumed in this module.
    pub channel_bits: Option<Vec<u8>>,
}
44
/// Settings controlling sensitivity analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityConfig {
    /// Number of calibration samples — not consumed in this module;
    /// presumably used by a calibration pipeline. TODO confirm.
    pub calibration_samples: usize,
    /// Sensitivity threshold — not consumed in this module. TODO confirm.
    pub sensitivity_threshold: f32,
    /// Metrics to compute; their per-element scores are averaged with
    /// equal weight.
    pub metrics: Vec<SensitivityMetric>,
}
55
/// Metrics used to score per-element sensitivity. In this module several
/// variants are computed as magnitude-based proxies from the weights alone
/// (see `compute_metric_scores`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SensitivityMetric {
    /// Computed here as absolute weight magnitude (proxy for gradients).
    GradientMagnitude,
    /// Computed here as squared weight magnitude (Hessian-diagonal proxy).
    HessianDiagonal,
    /// Computed here as absolute weight magnitude (proxy).
    ActivationVariance,
    /// Per-element squared deviation from the mean weight.
    WeightVariance,
    /// Computed here as absolute weight magnitude (proxy).
    OutputSensitivity,
}
70
/// Strategies for automatically assigning a bit width to each element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AutoBitAllocationStrategy {
    /// Distribute a compression-limited bit budget, most sensitive
    /// elements first.
    SensitivityBased {
        /// Fraction of the uniform `weight_bits` budget to keep
        /// (e.g. 0.25 targets a 4x reduction of the bit budget).
        target_compression: f32,
        /// Lower clamp for any element's bit width.
        min_bits: u8,
        /// Upper clamp for any element's bit width.
        max_bits: u8,
    },
    /// Adjust bits around a base width in proportion to each element's
    /// sensitivity relative to the mean.
    AdaptiveUniform {
        /// Baseline bit width before adjustment.
        base_bits: u8,
        /// Adjustment magnitude scaled by normalized sensitivity.
        adjustment_range: u8,
    },
    /// Greedy upgrade from a 2-bit floor under modeled latency and
    /// accuracy budgets.
    PerformanceDriven {
        /// Per-element normalized latency budget (model units).
        target_latency: f32,
        /// Maximum tolerated modeled accuracy loss.
        accuracy_tolerance: f32,
    },
}
98
/// A tensor quantized with per-element (mixed) bit widths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedBitQuantizedTensor {
    /// Name of the layer this tensor belongs to.
    pub layer_name: String,
    /// Quantized payload, one block per distinct bit width.
    pub quantized_data: Vec<QuantizedBlock>,
    /// Original tensor shape.
    pub shape: Vec<usize>,
    /// Layer configuration the tensor was quantized with.
    pub config: LayerQuantConfig,
    /// Per-element sensitivity scores recorded at quantization time.
    pub sensitivity_scores: Vec<f32>,
}
113
/// One group of elements quantized at a common bit width.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedBlock {
    /// Quantized codes, one byte per element (codes occupy the low
    /// `bit_width` bits of each byte).
    pub data: Vec<u8>,
    /// Dequantization scale: value = (code - zero_point) * scale.
    pub scale: f32,
    /// Dequantization zero point.
    pub zero_point: i32,
    /// Bit width shared by all codes in this block.
    pub bit_width: u8,
    /// Logical shape of the block (currently `[data.len()]`).
    pub block_shape: Vec<usize>,
    /// Flat position(s) of this block's elements within the original
    /// tensor's flattened data.
    pub block_offset: Vec<usize>,
}
130
/// Quantizer that assigns different bit widths to different elements
/// based on per-element sensitivity analysis.
pub struct MixedBitQuantizer {
    // Quantization configuration (layer overrides, allocation strategy).
    config: MixedBitConfig,
    // Produces and caches per-element sensitivity scores.
    sensitivity_analyzer: SensitivityAnalyzer,
}
136
/// Computes per-element sensitivity scores and caches them by layer name.
struct SensitivityAnalyzer {
    // Metric selection and analysis settings.
    config: SensitivityConfig,
    // Cached scores keyed by layer name (never invalidated).
    sensitivity_cache: HashMap<String, Vec<f32>>,
}
142
143impl Default for MixedBitConfig {
144 fn default() -> Self {
145 Self {
146 layer_configs: HashMap::new(),
147 default_config: LayerQuantConfig::default(),
148 sensitivity_config: SensitivityConfig::default(),
149 auto_bit_allocation: Some(AutoBitAllocationStrategy::SensitivityBased {
150 target_compression: 0.25, min_bits: 2,
152 max_bits: 8,
153 }),
154 }
155 }
156}
157
158impl Default for LayerQuantConfig {
159 fn default() -> Self {
160 Self {
161 weight_bits: 4,
162 activation_bits: 8,
163 scheme: QuantizationScheme::Int4,
164 symmetric: true,
165 group_size: Some(128),
166 channel_bits: None,
167 }
168 }
169}
170
171impl Default for SensitivityConfig {
172 fn default() -> Self {
173 Self {
174 calibration_samples: 128,
175 sensitivity_threshold: 0.01,
176 metrics: vec![
177 SensitivityMetric::GradientMagnitude,
178 SensitivityMetric::ActivationVariance,
179 SensitivityMetric::WeightVariance,
180 ],
181 }
182 }
183}
184
185impl MixedBitQuantizer {
186 pub fn new(config: MixedBitConfig) -> Self {
188 let sensitivity_analyzer = SensitivityAnalyzer::new(config.sensitivity_config.clone());
189 Self {
190 config,
191 sensitivity_analyzer,
192 }
193 }
194
195 pub fn quantize(
197 &mut self,
198 tensor: &Tensor,
199 layer_name: &str,
200 ) -> Result<MixedBitQuantizedTensor> {
201 let layer_config = self
203 .config
204 .layer_configs
205 .get(layer_name)
206 .cloned()
207 .unwrap_or_else(|| self.config.default_config.clone());
208
209 let sensitivity_scores = if let Some(ref auto_strategy) = self.config.auto_bit_allocation {
211 self.sensitivity_analyzer
212 .analyze_sensitivity(tensor, layer_name, &layer_config)?
213 } else {
214 vec![1.0; tensor.shape().iter().product()]
215 };
216
217 let bit_allocation = self.allocate_bits(&sensitivity_scores, &layer_config)?;
219
220 let quantized_blocks = self.quantize_blocks(tensor, &bit_allocation, &layer_config)?;
222
223 Ok(MixedBitQuantizedTensor {
224 layer_name: layer_name.to_string(),
225 quantized_data: quantized_blocks,
226 shape: tensor.shape(),
227 config: layer_config,
228 sensitivity_scores,
229 })
230 }
231
232 fn allocate_bits(
234 &self,
235 sensitivity_scores: &[f32],
236 config: &LayerQuantConfig,
237 ) -> Result<Vec<u8>> {
238 let mut bit_allocation = vec![config.weight_bits; sensitivity_scores.len()];
239
240 if let Some(ref strategy) = self.config.auto_bit_allocation {
241 match strategy {
242 AutoBitAllocationStrategy::SensitivityBased {
243 target_compression,
244 min_bits,
245 max_bits,
246 } => {
247 let mut indexed_scores: Vec<(usize, f32)> = sensitivity_scores
249 .iter()
250 .enumerate()
251 .map(|(i, &score)| (i, score))
252 .collect();
253 indexed_scores
254 .sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
255
256 let total_elements = sensitivity_scores.len();
258 let target_total_bits = (total_elements as f32
259 * config.weight_bits as f32
260 * target_compression) as usize;
261 let mut allocated_bits = 0;
262
263 for (idx, _) in indexed_scores {
265 let remaining_elements =
266 total_elements - allocated_bits / (*max_bits as usize);
267 let remaining_budget = target_total_bits.saturating_sub(allocated_bits);
268
269 let avg_bits_remaining =
270 remaining_budget.checked_div(remaining_elements).unwrap_or(0);
271 if avg_bits_remaining > 0 {
272 let bits = (avg_bits_remaining as u8).clamp(*min_bits, *max_bits);
273 bit_allocation[idx] = bits;
274 allocated_bits += bits as usize;
275 }
276 }
277 },
278 AutoBitAllocationStrategy::AdaptiveUniform {
279 base_bits,
280 adjustment_range,
281 } => {
282 let mean_sensitivity =
284 sensitivity_scores.iter().sum::<f32>() / sensitivity_scores.len() as f32;
285
286 for (i, &score) in sensitivity_scores.iter().enumerate() {
287 let normalized_score = score / mean_sensitivity;
288 let adjustment = (normalized_score * *adjustment_range as f32) as i8;
289 let bits = (*base_bits as i8 + adjustment).clamp(1, 8) as u8;
290 bit_allocation[i] = bits;
291 }
292 },
293 AutoBitAllocationStrategy::PerformanceDriven {
294 target_latency,
295 accuracy_tolerance,
296 } => {
297 return self.allocate_bits_performance_driven(
299 sensitivity_scores,
300 config,
301 *target_latency,
302 *accuracy_tolerance,
303 );
304 },
305 }
306 }
307
308 Ok(bit_allocation)
309 }
310
311 #[allow(dead_code)]
313 fn allocate_bits_sensitivity_based(
314 &self,
315 sensitivity_scores: &[f32],
316 config: &LayerQuantConfig,
317 ) -> Result<Vec<u8>> {
318 let mut bit_allocation = vec![config.weight_bits; sensitivity_scores.len()];
319
320 let mut sorted_scores = sensitivity_scores.to_vec();
322 sorted_scores.sort_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"));
323
324 let high_sensitivity_threshold =
325 sorted_scores[(sorted_scores.len() * 90 / 100).min(sorted_scores.len() - 1)];
326 let low_sensitivity_threshold = sorted_scores[sorted_scores.len() * 10 / 100];
327
328 for (i, &score) in sensitivity_scores.iter().enumerate() {
329 if score >= high_sensitivity_threshold {
330 bit_allocation[i] = 8; } else if score <= low_sensitivity_threshold {
332 bit_allocation[i] = 2; } else {
334 bit_allocation[i] = 4; }
336 }
337
338 Ok(bit_allocation)
339 }
340
341 fn allocate_bits_performance_driven(
343 &self,
344 sensitivity_scores: &[f32],
345 config: &LayerQuantConfig,
346 target_latency: f32,
347 accuracy_tolerance: f32,
348 ) -> Result<Vec<u8>> {
349 let total_elements = sensitivity_scores.len();
350
351 let performance_factor = |bits: u8| -> f32 {
354 match bits {
355 1 => 0.1, 2 => 0.25, 3 => 0.4, 4 => 0.6, 5 => 0.75, 6 => 0.85, 7 => 0.92, 8 => 1.0, _ => 1.0,
364 }
365 };
366
367 let accuracy_impact = |sensitivity: f32, bits: u8| -> f32 {
369 let base_impact = sensitivity / 100.0; let bit_factor = (8.0 - bits as f32) / 7.0; base_impact * bit_factor
372 };
373
374 let mut indexed_scores: Vec<(usize, f32)> =
376 sensitivity_scores.iter().enumerate().map(|(i, &score)| (i, score)).collect();
377 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
378
379 let mut current_bits = vec![2u8; total_elements];
381 let mut current_latency = 0.0;
382 let mut current_accuracy_loss = 0.0;
383
384 for (i, &score) in sensitivity_scores.iter().enumerate() {
386 current_latency += performance_factor(2);
387 current_accuracy_loss += accuracy_impact(score, 2);
388 }
389
390 for (idx, sensitivity) in indexed_scores {
392 let current_element_bits = current_bits[idx];
393
394 for new_bits in (current_element_bits + 1)..=8 {
396 let latency_change =
397 performance_factor(new_bits) - performance_factor(current_element_bits);
398 let accuracy_change = accuracy_impact(sensitivity, current_element_bits)
399 - accuracy_impact(sensitivity, new_bits);
400
401 let new_latency = current_latency + latency_change;
402 let new_accuracy_loss = current_accuracy_loss - accuracy_change;
403
404 let normalized_latency = new_latency / total_elements as f32;
406 if normalized_latency <= target_latency && new_accuracy_loss <= accuracy_tolerance {
407 current_bits[idx] = new_bits;
409 current_latency = new_latency;
410 current_accuracy_loss = new_accuracy_loss;
411 } else {
412 break;
414 }
415 }
416 }
417
418 let bit_allocation = current_bits;
420
421 Ok(bit_allocation)
422 }
423
424 fn quantize_blocks(
426 &self,
427 tensor: &Tensor,
428 bit_allocation: &[u8],
429 config: &LayerQuantConfig,
430 ) -> Result<Vec<QuantizedBlock>> {
431 let data = tensor.data()?;
432 let shape = tensor.shape();
433 let mut blocks = Vec::new();
434
435 let mut bit_groups: HashMap<u8, Vec<(usize, f32)>> = HashMap::new();
437 for (i, (&bits, &value)) in bit_allocation.iter().zip(data.iter()).enumerate() {
438 bit_groups.entry(bits).or_default().push((i, value));
439 }
440
441 for (bit_width, elements) in bit_groups {
443 let values: Vec<f32> = elements.iter().map(|(_, v)| *v).collect();
444 let indices: Vec<usize> = elements.iter().map(|(i, _)| *i).collect();
445
446 let (quantized_data, scale, zero_point) =
447 self.quantize_group(&values, bit_width, config)?;
448
449 blocks.push(QuantizedBlock {
450 data: quantized_data,
451 scale,
452 zero_point,
453 bit_width,
454 block_shape: vec![values.len()],
455 block_offset: vec![indices[0]], });
457 }
458
459 Ok(blocks)
460 }
461
462 fn quantize_group(
464 &self,
465 values: &[f32],
466 bit_width: u8,
467 config: &LayerQuantConfig,
468 ) -> Result<(Vec<u8>, f32, i32)> {
469 if values.is_empty() {
470 return Ok((Vec::new(), 1.0, 0));
471 }
472
473 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
474 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
475
476 let qmin = 0;
477 let qmax = (1 << bit_width) - 1;
478
479 let (scale, zero_point) = if config.symmetric {
480 let max_abs = max_val.abs().max(min_val.abs());
481 let scale = max_abs / (qmax as f32 / 2.0);
482 (scale, qmax / 2)
483 } else {
484 let scale = (max_val - min_val) / (qmax - qmin) as f32;
485 let zero_point = qmin as f32 - min_val / scale;
486 (scale, zero_point.round() as i32)
487 };
488
489 let mut quantized = Vec::with_capacity(values.len());
490 for &value in values {
491 let q_val = (value / scale + zero_point as f32).round() as i32;
492 let clamped = q_val.clamp(qmin, qmax) as u8;
493 quantized.push(clamped);
494 }
495
496 Ok((quantized, scale, zero_point))
497 }
498
499 pub fn compression_ratio(
501 &self,
502 original_size: usize,
503 quantized_tensor: &MixedBitQuantizedTensor,
504 ) -> f32 {
505 let compressed_size: usize =
506 quantized_tensor.quantized_data.iter().map(|block| block.data.len()).sum();
507
508 original_size as f32 / compressed_size as f32
509 }
510
511 pub fn memory_savings(
513 &self,
514 original_tensor: &Tensor,
515 quantized_tensor: &MixedBitQuantizedTensor,
516 ) -> f32 {
517 let original_bytes = original_tensor.size() * std::mem::size_of::<f32>();
518 let quantized_bytes: usize =
519 quantized_tensor.quantized_data.iter().map(|block| block.data.len()).sum();
520
521 1.0 - (quantized_bytes as f32 / original_bytes as f32)
522 }
523}
524
525impl MixedBitQuantizedTensor {
526 pub fn dequantize(&self) -> Result<Tensor> {
528 let total_elements: usize = self.shape.iter().product();
529 let mut result = vec![0.0f32; total_elements];
530
531 for block in &self.quantized_data {
532 for (i, &quantized_val) in block.data.iter().enumerate() {
533 let dequantized = (quantized_val as i32 - block.zero_point) as f32 * block.scale;
534 if i < result.len() {
536 result[i] = dequantized;
537 }
538 }
539 }
540
541 Tensor::from_vec(result, &self.shape)
542 }
543
544 pub fn average_bit_width(&self) -> f32 {
546 let total_elements: usize = self.quantized_data.iter().map(|b| b.data.len()).sum();
547 if total_elements == 0 {
548 return 0.0;
549 }
550
551 let total_bits: f32 = self
552 .quantized_data
553 .iter()
554 .map(|block| block.data.len() as f32 * block.bit_width as f32)
555 .sum();
556
557 total_bits / total_elements as f32
558 }
559
560 pub fn memory_footprint(&self) -> usize {
562 self.quantized_data.iter().map(|block| block.data.len()).sum()
563 }
564}
565
566impl SensitivityAnalyzer {
567 fn new(config: SensitivityConfig) -> Self {
568 Self {
569 config,
570 sensitivity_cache: HashMap::new(),
571 }
572 }
573
574 fn analyze_sensitivity(
576 &mut self,
577 tensor: &Tensor,
578 layer_name: &str,
579 _config: &LayerQuantConfig,
580 ) -> Result<Vec<f32>> {
581 if let Some(cached_scores) = self.sensitivity_cache.get(layer_name) {
583 return Ok(cached_scores.clone());
584 }
585
586 let data = tensor.data()?;
587 let mut sensitivity_scores = vec![0.0; data.len()];
588
589 for metric in &self.config.metrics {
591 let metric_scores = self.compute_metric_scores(tensor, metric)?;
592
593 for (i, score) in metric_scores.iter().enumerate() {
595 sensitivity_scores[i] += score / self.config.metrics.len() as f32;
596 }
597 }
598
599 self.sensitivity_cache
601 .insert(layer_name.to_string(), sensitivity_scores.clone());
602
603 Ok(sensitivity_scores)
604 }
605
606 fn compute_metric_scores(
608 &self,
609 tensor: &Tensor,
610 metric: &SensitivityMetric,
611 ) -> Result<Vec<f32>> {
612 let data = tensor.data()?;
613
614 match metric {
615 SensitivityMetric::WeightVariance => {
616 let mean = data.iter().sum::<f32>() / data.len() as f32;
618 let variance: Vec<f32> = data.iter().map(|&x| (x - mean).powi(2)).collect();
619 Ok(variance)
620 },
621 SensitivityMetric::GradientMagnitude => {
622 Ok(data.iter().map(|&x| x.abs()).collect())
624 },
625 SensitivityMetric::ActivationVariance => {
626 Ok(data.iter().map(|&x| x.abs()).collect())
628 },
629 SensitivityMetric::HessianDiagonal => {
630 let hessian_approx: Vec<f32> = data.iter().map(|&x| x.powi(2)).collect();
632 Ok(hessian_approx)
633 },
634 SensitivityMetric::OutputSensitivity => {
635 Ok(data.iter().map(|&x| x.abs()).collect())
637 },
638 }
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_mixed_bit_quantizer_creation() {
        // The default config ships with an auto-allocation strategy.
        let quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        assert!(quantizer.config.auto_bit_allocation.is_some());
    }

    #[test]
    fn test_mixed_bit_quantization() -> Result<()> {
        let input = Tensor::randn(&[4, 4])?;
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());

        let output = quantizer.quantize(&input, "test_layer")?;

        // Shape is recorded and at least one block is produced.
        assert_eq!(output.shape, vec![4, 4]);
        assert!(!output.quantized_data.is_empty());
        Ok(())
    }

    #[test]
    fn test_mixed_bit_dequantization() -> Result<()> {
        let input = Tensor::randn(&[2, 2])?;
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());

        let roundtrip = quantizer.quantize(&input, "test_layer")?.dequantize()?;

        // Roundtrip preserves the tensor's shape.
        assert_eq!(roundtrip.shape(), input.shape());
        Ok(())
    }

    #[test]
    fn test_average_bit_width() -> Result<()> {
        let input = Tensor::randn(&[8])?;
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());

        let avg_bits = quantizer.quantize(&input, "test_layer")?.average_bit_width();

        // Positive and bounded by the 8-bit ceiling.
        assert!(avg_bits > 0.0);
        assert!(avg_bits <= 8.0);
        Ok(())
    }

    #[test]
    fn test_compression_ratio() -> Result<()> {
        let input = Tensor::randn(&[1024])?;
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());

        let quantized = quantizer.quantize(&input, "test_layer")?;
        let ratio = quantizer.compression_ratio(input.size(), &quantized);

        // Each code occupies at most one byte per element.
        assert!(ratio >= 1.0);
        Ok(())
    }

    #[test]
    fn test_sensitivity_analysis() -> Result<()> {
        let mut analyzer = SensitivityAnalyzer::new(SensitivityConfig::default());
        let input = Tensor::randn(&[4, 4])?;

        let scores =
            analyzer.analyze_sensitivity(&input, "test_layer", &LayerQuantConfig::default())?;

        // One non-negative score per element.
        assert_eq!(scores.len(), 16);
        assert!(scores.iter().all(|&score| score >= 0.0));
        Ok(())
    }
}