Skip to main content

trustformers_core/quantization/
mixed_bit.rs

1//! Mixed-bit quantization implementation for TrustformeRS
2//!
3//! This module provides mixed-bit quantization where different layers/channels
4//! can use different quantization bit widths based on their importance and
5//! sensitivity to quantization errors.
6
7#![allow(unused_variables)] // Mixed-bit quantization implementation
8
9use crate::errors::Result;
10use crate::quantization::base::QuantizationScheme;
11use crate::tensor::Tensor;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Mixed-bit quantization configuration.
///
/// Controls how bit widths are assigned across a model: explicit per-layer
/// overrides, a fallback default, sensitivity-analysis settings, and an
/// optional automatic allocation strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedBitConfig {
    /// Per-layer quantization overrides, keyed by layer name.
    pub layer_configs: HashMap<String, LayerQuantConfig>,
    /// Default configuration for layers without an entry in `layer_configs`.
    pub default_config: LayerQuantConfig,
    /// Settings for the sensitivity analysis that drives bit allocation.
    pub sensitivity_config: SensitivityConfig,
    /// Automatic bit allocation strategy; `None` skips sensitivity analysis
    /// and uses the layer configuration's `weight_bits` uniformly.
    pub auto_bit_allocation: Option<AutoBitAllocationStrategy>,
}
27
/// Layer-specific quantization configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
    /// Bit width for weights; also the per-element default before automatic
    /// allocation adjusts it.
    pub weight_bits: u8,
    /// Bit width for activations.
    /// NOTE(review): not consumed by the quantizer in this module — confirm
    /// where activation quantization happens.
    pub activation_bits: u8,
    /// Quantization scheme to use.
    pub scheme: QuantizationScheme,
    /// Whether to use symmetric (zero-point at mid-range) quantization.
    pub symmetric: bool,
    /// Group size for grouped quantization.
    /// NOTE(review): not read anywhere in this module — verify against callers.
    pub group_size: Option<usize>,
    /// Channel-specific bit allocation.
    /// NOTE(review): not read anywhere in this module — verify against callers.
    pub channel_bits: Option<Vec<u8>>,
}
44
/// Sensitivity analysis configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityConfig {
    /// Number of calibration samples.
    /// NOTE(review): not consumed by `SensitivityAnalyzer` in this module —
    /// confirm where calibration actually happens.
    pub calibration_samples: usize,
    /// Threshold for sensitivity (higher = more sensitive).
    /// NOTE(review): not consumed by `SensitivityAnalyzer` in this module.
    pub sensitivity_threshold: f32,
    /// Metrics whose per-element scores are averaged into the combined
    /// sensitivity score (see `SensitivityAnalyzer::analyze_sensitivity`).
    pub metrics: Vec<SensitivityMetric>,
}
55
/// Sensitivity metrics for determining quantization bit allocation.
///
/// In this module every metric is approximated from the weight values alone;
/// see `SensitivityAnalyzer::compute_metric_scores` for the exact proxies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SensitivityMetric {
    /// Gradient magnitude (approximated here by weight magnitude).
    GradientMagnitude,
    /// Hessian diagonal (approximated here by squared weights).
    HessianDiagonal,
    /// Activation variance (approximated here by weight magnitude).
    ActivationVariance,
    /// Weight variance (squared deviation from the tensor mean).
    WeightVariance,
    /// Output sensitivity (approximated here by weight magnitude).
    OutputSensitivity,
}
70
/// Automatic bit allocation strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AutoBitAllocationStrategy {
    /// Spread a shared bit budget over elements, most sensitive first.
    SensitivityBased {
        /// Target compression ratio relative to a uniform `weight_bits`
        /// allocation (e.g. 0.25 targets 4x compression).
        target_compression: f32,
        /// Minimum bits per element.
        min_bits: u8,
        /// Maximum bits per element.
        max_bits: u8,
    },
    /// Uniform base width with a per-element adjustment driven by
    /// sensitivity normalized against the mean.
    AdaptiveUniform {
        /// Base bit width.
        base_bits: u8,
        /// Scale of the per-element adjustment added to `base_bits`
        /// (final width clamped to 1..=8).
        adjustment_range: u8,
    },
    /// Greedy allocation trading latency for accuracy.
    PerformanceDriven {
        /// Target per-element latency budget, in the synthetic units of the
        /// internal performance model (NOTE(review): not literal ms — see
        /// `allocate_bits_performance_driven`).
        target_latency: f32,
        /// Maximum tolerated cumulative accuracy loss.
        accuracy_tolerance: f32,
    },
}
98
/// Mixed-bit quantized tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedBitQuantizedTensor {
    /// Name of the layer this tensor belongs to.
    pub layer_name: String,
    /// Quantized blocks — one per distinct bit width used in the tensor.
    pub quantized_data: Vec<QuantizedBlock>,
    /// Original (dequantized) tensor shape.
    pub shape: Vec<usize>,
    /// Layer quantization configuration that produced this tensor.
    pub config: LayerQuantConfig,
    /// Per-element sensitivity scores computed during quantization.
    pub sensitivity_scores: Vec<f32>,
}
113
/// A block of quantized data sharing one bit width, scale, and zero point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedBlock {
    /// Quantized values — one byte per element regardless of `bit_width`
    /// (values are clamped to the bit-width range but not bit-packed).
    pub data: Vec<u8>,
    /// Scale factor for dequantization: value = (q - zero_point) * scale.
    pub scale: f32,
    /// Zero point for dequantization.
    pub zero_point: i32,
    /// Bit width used for this block.
    pub bit_width: u8,
    /// Block shape (currently just the flat element count).
    pub block_shape: Vec<usize>,
    /// Block offset in the original tensor.
    /// NOTE(review): only the first element's flat index is stored, so the
    /// full element→block mapping is not recoverable — see
    /// `MixedBitQuantizedTensor::dequantize`.
    pub block_offset: Vec<usize>,
}
130
/// Mixed-bit quantizer.
///
/// Owns the configuration plus a sensitivity analyzer whose per-layer score
/// cache persists across `quantize` calls.
pub struct MixedBitQuantizer {
    // Per-layer overrides, defaults, and the automatic allocation strategy.
    config: MixedBitConfig,
    // Computes and caches per-element sensitivity scores, keyed by layer name.
    sensitivity_analyzer: SensitivityAnalyzer,
}
136
/// Sensitivity analyzer for determining optimal bit allocations.
struct SensitivityAnalyzer {
    // Metrics to compute and combine for each tensor.
    config: SensitivityConfig,
    // Cached per-element scores, keyed by layer name.
    sensitivity_cache: HashMap<String, Vec<f32>>,
}
142
143impl Default for MixedBitConfig {
144    fn default() -> Self {
145        Self {
146            layer_configs: HashMap::new(),
147            default_config: LayerQuantConfig::default(),
148            sensitivity_config: SensitivityConfig::default(),
149            auto_bit_allocation: Some(AutoBitAllocationStrategy::SensitivityBased {
150                target_compression: 0.25, // 4x compression
151                min_bits: 2,
152                max_bits: 8,
153            }),
154        }
155    }
156}
157
158impl Default for LayerQuantConfig {
159    fn default() -> Self {
160        Self {
161            weight_bits: 4,
162            activation_bits: 8,
163            scheme: QuantizationScheme::Int4,
164            symmetric: true,
165            group_size: Some(128),
166            channel_bits: None,
167        }
168    }
169}
170
171impl Default for SensitivityConfig {
172    fn default() -> Self {
173        Self {
174            calibration_samples: 128,
175            sensitivity_threshold: 0.01,
176            metrics: vec![
177                SensitivityMetric::GradientMagnitude,
178                SensitivityMetric::ActivationVariance,
179                SensitivityMetric::WeightVariance,
180            ],
181        }
182    }
183}
184
185impl MixedBitQuantizer {
186    /// Create a new mixed-bit quantizer
187    pub fn new(config: MixedBitConfig) -> Self {
188        let sensitivity_analyzer = SensitivityAnalyzer::new(config.sensitivity_config.clone());
189        Self {
190            config,
191            sensitivity_analyzer,
192        }
193    }
194
195    /// Quantize a tensor using mixed-bit quantization
196    pub fn quantize(
197        &mut self,
198        tensor: &Tensor,
199        layer_name: &str,
200    ) -> Result<MixedBitQuantizedTensor> {
201        // Get or create layer configuration
202        let layer_config = self
203            .config
204            .layer_configs
205            .get(layer_name)
206            .cloned()
207            .unwrap_or_else(|| self.config.default_config.clone());
208
209        // Analyze sensitivity if needed
210        let sensitivity_scores = if let Some(ref auto_strategy) = self.config.auto_bit_allocation {
211            self.sensitivity_analyzer
212                .analyze_sensitivity(tensor, layer_name, &layer_config)?
213        } else {
214            vec![1.0; tensor.shape().iter().product()]
215        };
216
217        // Allocate bits based on sensitivity
218        let bit_allocation = self.allocate_bits(&sensitivity_scores, &layer_config)?;
219
220        // Quantize tensor into blocks
221        let quantized_blocks = self.quantize_blocks(tensor, &bit_allocation, &layer_config)?;
222
223        Ok(MixedBitQuantizedTensor {
224            layer_name: layer_name.to_string(),
225            quantized_data: quantized_blocks,
226            shape: tensor.shape(),
227            config: layer_config,
228            sensitivity_scores,
229        })
230    }
231
232    /// Allocate bits based on sensitivity scores
233    fn allocate_bits(
234        &self,
235        sensitivity_scores: &[f32],
236        config: &LayerQuantConfig,
237    ) -> Result<Vec<u8>> {
238        let mut bit_allocation = vec![config.weight_bits; sensitivity_scores.len()];
239
240        if let Some(ref strategy) = self.config.auto_bit_allocation {
241            match strategy {
242                AutoBitAllocationStrategy::SensitivityBased {
243                    target_compression,
244                    min_bits,
245                    max_bits,
246                } => {
247                    // Sort indices by sensitivity
248                    let mut indexed_scores: Vec<(usize, f32)> = sensitivity_scores
249                        .iter()
250                        .enumerate()
251                        .map(|(i, &score)| (i, score))
252                        .collect();
253                    indexed_scores
254                        .sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
255
256                    // Calculate target total bits
257                    let total_elements = sensitivity_scores.len();
258                    let target_total_bits = (total_elements as f32
259                        * config.weight_bits as f32
260                        * target_compression) as usize;
261                    let mut allocated_bits = 0;
262
263                    // Allocate bits starting from most sensitive
264                    for (idx, _) in indexed_scores {
265                        let remaining_elements =
266                            total_elements - allocated_bits / (*max_bits as usize);
267                        let remaining_budget = target_total_bits.saturating_sub(allocated_bits);
268
269                        let avg_bits_remaining =
270                            remaining_budget.checked_div(remaining_elements).unwrap_or(0);
271                        if avg_bits_remaining > 0 {
272                            let bits = (avg_bits_remaining as u8).clamp(*min_bits, *max_bits);
273                            bit_allocation[idx] = bits;
274                            allocated_bits += bits as usize;
275                        }
276                    }
277                },
278                AutoBitAllocationStrategy::AdaptiveUniform {
279                    base_bits,
280                    adjustment_range,
281                } => {
282                    // Calculate mean sensitivity
283                    let mean_sensitivity =
284                        sensitivity_scores.iter().sum::<f32>() / sensitivity_scores.len() as f32;
285
286                    for (i, &score) in sensitivity_scores.iter().enumerate() {
287                        let normalized_score = score / mean_sensitivity;
288                        let adjustment = (normalized_score * *adjustment_range as f32) as i8;
289                        let bits = (*base_bits as i8 + adjustment).clamp(1, 8) as u8;
290                        bit_allocation[i] = bits;
291                    }
292                },
293                AutoBitAllocationStrategy::PerformanceDriven {
294                    target_latency,
295                    accuracy_tolerance,
296                } => {
297                    // Performance-driven allocation optimizes for latency while maintaining accuracy
298                    return self.allocate_bits_performance_driven(
299                        sensitivity_scores,
300                        config,
301                        *target_latency,
302                        *accuracy_tolerance,
303                    );
304                },
305            }
306        }
307
308        Ok(bit_allocation)
309    }
310
311    /// Allocate bits based on sensitivity scores (fallback implementation)
312    #[allow(dead_code)]
313    fn allocate_bits_sensitivity_based(
314        &self,
315        sensitivity_scores: &[f32],
316        config: &LayerQuantConfig,
317    ) -> Result<Vec<u8>> {
318        let mut bit_allocation = vec![config.weight_bits; sensitivity_scores.len()];
319
320        // Find sensitivity percentiles
321        let mut sorted_scores = sensitivity_scores.to_vec();
322        sorted_scores.sort_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"));
323
324        let high_sensitivity_threshold =
325            sorted_scores[(sorted_scores.len() * 90 / 100).min(sorted_scores.len() - 1)];
326        let low_sensitivity_threshold = sorted_scores[sorted_scores.len() * 10 / 100];
327
328        for (i, &score) in sensitivity_scores.iter().enumerate() {
329            if score >= high_sensitivity_threshold {
330                bit_allocation[i] = 8; // High precision for sensitive parts
331            } else if score <= low_sensitivity_threshold {
332                bit_allocation[i] = 2; // Low precision for insensitive parts
333            } else {
334                bit_allocation[i] = 4; // Medium precision
335            }
336        }
337
338        Ok(bit_allocation)
339    }
340
341    /// Performance-driven bit allocation optimizing for latency while maintaining accuracy
342    fn allocate_bits_performance_driven(
343        &self,
344        sensitivity_scores: &[f32],
345        config: &LayerQuantConfig,
346        target_latency: f32,
347        accuracy_tolerance: f32,
348    ) -> Result<Vec<u8>> {
349        let total_elements = sensitivity_scores.len();
350
351        // Model performance characteristics (simplified model)
352        // Lower bits = faster computation but potentially lower accuracy
353        let performance_factor = |bits: u8| -> f32 {
354            match bits {
355                1 => 0.1,  // Very fast, very low accuracy impact
356                2 => 0.25, // Fast, low accuracy impact
357                3 => 0.4,  // Medium-fast, medium accuracy impact
358                4 => 0.6,  // Medium, medium accuracy impact
359                5 => 0.75, // Medium-slow, higher accuracy
360                6 => 0.85, // Slow, high accuracy
361                7 => 0.92, // Very slow, very high accuracy
362                8 => 1.0,  // Slowest, highest accuracy
363                _ => 1.0,
364            }
365        };
366
367        // Calculate accuracy impact based on sensitivity
368        let accuracy_impact = |sensitivity: f32, bits: u8| -> f32 {
369            let base_impact = sensitivity / 100.0; // Normalize sensitivity
370            let bit_factor = (8.0 - bits as f32) / 7.0; // Higher impact with fewer bits
371            base_impact * bit_factor
372        };
373
374        // Sort elements by sensitivity to prioritize important layers
375        let mut indexed_scores: Vec<(usize, f32)> =
376            sensitivity_scores.iter().enumerate().map(|(i, &score)| (i, score)).collect();
377        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
378
379        // Start with lowest bits for maximum performance
380        let mut current_bits = vec![2u8; total_elements];
381        let mut current_latency = 0.0;
382        let mut current_accuracy_loss = 0.0;
383
384        // Calculate initial latency and accuracy
385        for (i, &score) in sensitivity_scores.iter().enumerate() {
386            current_latency += performance_factor(2);
387            current_accuracy_loss += accuracy_impact(score, 2);
388        }
389
390        // Iteratively increase bits for most sensitive elements until we hit constraints
391        for (idx, sensitivity) in indexed_scores {
392            let current_element_bits = current_bits[idx];
393
394            // Try increasing bits for this element
395            for new_bits in (current_element_bits + 1)..=8 {
396                let latency_change =
397                    performance_factor(new_bits) - performance_factor(current_element_bits);
398                let accuracy_change = accuracy_impact(sensitivity, current_element_bits)
399                    - accuracy_impact(sensitivity, new_bits);
400
401                let new_latency = current_latency + latency_change;
402                let new_accuracy_loss = current_accuracy_loss - accuracy_change;
403
404                // Check if this change fits within our constraints
405                let normalized_latency = new_latency / total_elements as f32;
406                if normalized_latency <= target_latency && new_accuracy_loss <= accuracy_tolerance {
407                    // Apply the change
408                    current_bits[idx] = new_bits;
409                    current_latency = new_latency;
410                    current_accuracy_loss = new_accuracy_loss;
411                } else {
412                    // Can't improve this element further, move to next
413                    break;
414                }
415            }
416        }
417
418        // Apply final allocation
419        let bit_allocation = current_bits;
420
421        Ok(bit_allocation)
422    }
423
424    /// Quantize tensor into blocks with different bit widths
425    fn quantize_blocks(
426        &self,
427        tensor: &Tensor,
428        bit_allocation: &[u8],
429        config: &LayerQuantConfig,
430    ) -> Result<Vec<QuantizedBlock>> {
431        let data = tensor.data()?;
432        let shape = tensor.shape();
433        let mut blocks = Vec::new();
434
435        // Group elements by bit width
436        let mut bit_groups: HashMap<u8, Vec<(usize, f32)>> = HashMap::new();
437        for (i, (&bits, &value)) in bit_allocation.iter().zip(data.iter()).enumerate() {
438            bit_groups.entry(bits).or_default().push((i, value));
439        }
440
441        // Quantize each group
442        for (bit_width, elements) in bit_groups {
443            let values: Vec<f32> = elements.iter().map(|(_, v)| *v).collect();
444            let indices: Vec<usize> = elements.iter().map(|(i, _)| *i).collect();
445
446            let (quantized_data, scale, zero_point) =
447                self.quantize_group(&values, bit_width, config)?;
448
449            blocks.push(QuantizedBlock {
450                data: quantized_data,
451                scale,
452                zero_point,
453                bit_width,
454                block_shape: vec![values.len()],
455                block_offset: vec![indices[0]], // Simplified for now
456            });
457        }
458
459        Ok(blocks)
460    }
461
462    /// Quantize a group of values with specified bit width
463    fn quantize_group(
464        &self,
465        values: &[f32],
466        bit_width: u8,
467        config: &LayerQuantConfig,
468    ) -> Result<(Vec<u8>, f32, i32)> {
469        if values.is_empty() {
470            return Ok((Vec::new(), 1.0, 0));
471        }
472
473        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
474        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
475
476        let qmin = 0;
477        let qmax = (1 << bit_width) - 1;
478
479        let (scale, zero_point) = if config.symmetric {
480            let max_abs = max_val.abs().max(min_val.abs());
481            let scale = max_abs / (qmax as f32 / 2.0);
482            (scale, qmax / 2)
483        } else {
484            let scale = (max_val - min_val) / (qmax - qmin) as f32;
485            let zero_point = qmin as f32 - min_val / scale;
486            (scale, zero_point.round() as i32)
487        };
488
489        let mut quantized = Vec::with_capacity(values.len());
490        for &value in values {
491            let q_val = (value / scale + zero_point as f32).round() as i32;
492            let clamped = q_val.clamp(qmin, qmax) as u8;
493            quantized.push(clamped);
494        }
495
496        Ok((quantized, scale, zero_point))
497    }
498
499    /// Get compression ratio achieved
500    pub fn compression_ratio(
501        &self,
502        original_size: usize,
503        quantized_tensor: &MixedBitQuantizedTensor,
504    ) -> f32 {
505        let compressed_size: usize =
506            quantized_tensor.quantized_data.iter().map(|block| block.data.len()).sum();
507
508        original_size as f32 / compressed_size as f32
509    }
510
511    /// Estimate memory savings
512    pub fn memory_savings(
513        &self,
514        original_tensor: &Tensor,
515        quantized_tensor: &MixedBitQuantizedTensor,
516    ) -> f32 {
517        let original_bytes = original_tensor.size() * std::mem::size_of::<f32>();
518        let quantized_bytes: usize =
519            quantized_tensor.quantized_data.iter().map(|block| block.data.len()).sum();
520
521        1.0 - (quantized_bytes as f32 / original_bytes as f32)
522    }
523}
524
525impl MixedBitQuantizedTensor {
526    /// Dequantize back to original tensor
527    pub fn dequantize(&self) -> Result<Tensor> {
528        let total_elements: usize = self.shape.iter().product();
529        let mut result = vec![0.0f32; total_elements];
530
531        for block in &self.quantized_data {
532            for (i, &quantized_val) in block.data.iter().enumerate() {
533                let dequantized = (quantized_val as i32 - block.zero_point) as f32 * block.scale;
534                // Simplified mapping - in practice, would need proper index mapping
535                if i < result.len() {
536                    result[i] = dequantized;
537                }
538            }
539        }
540
541        Tensor::from_vec(result, &self.shape)
542    }
543
544    /// Get average bit width used
545    pub fn average_bit_width(&self) -> f32 {
546        let total_elements: usize = self.quantized_data.iter().map(|b| b.data.len()).sum();
547        if total_elements == 0 {
548            return 0.0;
549        }
550
551        let total_bits: f32 = self
552            .quantized_data
553            .iter()
554            .map(|block| block.data.len() as f32 * block.bit_width as f32)
555            .sum();
556
557        total_bits / total_elements as f32
558    }
559
560    /// Get memory footprint in bytes
561    pub fn memory_footprint(&self) -> usize {
562        self.quantized_data.iter().map(|block| block.data.len()).sum()
563    }
564}
565
566impl SensitivityAnalyzer {
567    fn new(config: SensitivityConfig) -> Self {
568        Self {
569            config,
570            sensitivity_cache: HashMap::new(),
571        }
572    }
573
574    /// Analyze sensitivity of tensor elements
575    fn analyze_sensitivity(
576        &mut self,
577        tensor: &Tensor,
578        layer_name: &str,
579        _config: &LayerQuantConfig,
580    ) -> Result<Vec<f32>> {
581        // Check cache first
582        if let Some(cached_scores) = self.sensitivity_cache.get(layer_name) {
583            return Ok(cached_scores.clone());
584        }
585
586        let data = tensor.data()?;
587        let mut sensitivity_scores = vec![0.0; data.len()];
588
589        // Analyze each configured metric
590        for metric in &self.config.metrics {
591            let metric_scores = self.compute_metric_scores(tensor, metric)?;
592
593            // Combine metrics (simple averaging for now)
594            for (i, score) in metric_scores.iter().enumerate() {
595                sensitivity_scores[i] += score / self.config.metrics.len() as f32;
596            }
597        }
598
599        // Cache the results
600        self.sensitivity_cache
601            .insert(layer_name.to_string(), sensitivity_scores.clone());
602
603        Ok(sensitivity_scores)
604    }
605
606    /// Compute sensitivity scores for a specific metric
607    fn compute_metric_scores(
608        &self,
609        tensor: &Tensor,
610        metric: &SensitivityMetric,
611    ) -> Result<Vec<f32>> {
612        let data = tensor.data()?;
613
614        match metric {
615            SensitivityMetric::WeightVariance => {
616                // Compute local variance as a sensitivity measure
617                let mean = data.iter().sum::<f32>() / data.len() as f32;
618                let variance: Vec<f32> = data.iter().map(|&x| (x - mean).powi(2)).collect();
619                Ok(variance)
620            },
621            SensitivityMetric::GradientMagnitude => {
622                // Approximate gradient magnitude using weight magnitude
623                Ok(data.iter().map(|&x| x.abs()).collect())
624            },
625            SensitivityMetric::ActivationVariance => {
626                // For weights, use magnitude as proxy for activation impact
627                Ok(data.iter().map(|&x| x.abs()).collect())
628            },
629            SensitivityMetric::HessianDiagonal => {
630                // Simplified hessian approximation
631                let hessian_approx: Vec<f32> = data.iter().map(|&x| x.powi(2)).collect();
632                Ok(hessian_approx)
633            },
634            SensitivityMetric::OutputSensitivity => {
635                // Use weight magnitude as proxy for output sensitivity
636                Ok(data.iter().map(|&x| x.abs()).collect())
637            },
638        }
639    }
640}
641
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // Construction stores the default auto-allocation strategy.
    #[test]
    fn test_mixed_bit_quantizer_creation() {
        let config = MixedBitConfig::default();
        let quantizer = MixedBitQuantizer::new(config);
        assert!(quantizer.config.auto_bit_allocation.is_some());
    }

    // End-to-end quantization preserves shape and produces at least one block.
    #[test]
    fn test_mixed_bit_quantization() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[4, 4])?;

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        assert_eq!(quantized.shape, vec![4, 4]);
        assert!(!quantized.quantized_data.is_empty());

        Ok(())
    }

    // Round-trip only checks shape — values are not compared because the
    // dequantization index mapping is simplified.
    #[test]
    fn test_mixed_bit_dequantization() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[2, 2])?;

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        let dequantized = quantized.dequantize()?;

        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    // Average bit width must be positive and bounded by the 8-bit maximum.
    #[test]
    fn test_average_bit_width() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[8])?;

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        let avg_bits = quantized.average_bit_width();

        assert!(avg_bits > 0.0);
        assert!(avg_bits <= 8.0);
        Ok(())
    }

    // Compression ratio compares element count to payload bytes, so it can
    // sit at exactly 1.0 with the current byte-per-element storage.
    #[test]
    fn test_compression_ratio() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[1024])?; // Use larger tensor to overcome metadata overhead

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        let ratio = quantizer.compression_ratio(tensor.size(), &quantized);

        assert!(ratio >= 1.0); // Current implementation stores as bytes, so ratio may be 1.0
        Ok(())
    }

    // Analyzer returns one non-negative score per element.
    #[test]
    fn test_sensitivity_analysis() -> Result<()> {
        let config = SensitivityConfig::default();
        let mut analyzer = SensitivityAnalyzer::new(config);
        let tensor = Tensor::randn(&[4, 4])?;

        let layer_config = LayerQuantConfig::default();
        let scores = analyzer.analyze_sensitivity(&tensor, "test_layer", &layer_config)?;

        assert_eq!(scores.len(), 16);
        assert!(scores.iter().all(|&score| score >= 0.0));
        Ok(())
    }
}