// torsh_quantization/metrics.rs

1//! Quantization quality metrics and analysis tools
2//!
3//! This module provides comprehensive tools for measuring and analyzing the quality
4//! of quantization operations, including performance metrics, benchmarking utilities,
5//! and automated analysis reports.
6//!
7//! # Features
8//!
9//! - **Quality Metrics**: MSE, PSNR, SNR, MAE, cosine similarity calculations
10//! - **Performance Analysis**: Timing, compression ratio, and efficiency metrics
11//! - **Configuration Comparison**: Side-by-side comparison of quantization schemes
12//! - **Auto-calibration**: Automated optimal configuration selection
13//! - **Report Generation**: Comprehensive analysis reports in Markdown format
14//! - **Outlier Detection**: Statistical analysis and recommendation systems
15
16use crate::config::{ObserverType, QuantConfig};
17
18#[cfg(not(feature = "std"))]
19extern crate alloc;
20
21#[cfg(not(feature = "std"))]
22use alloc::{collections::BTreeMap as HashMap, format, string::String, vec::Vec};
23
24use torsh_core::{
25    dtype::DType,
26    error::{Result as TorshResult, TorshError},
27};
28use torsh_tensor::Tensor;
29
/// Quantization quality metrics for measuring accuracy loss
///
/// Produced by [`calculate_quantization_metrics`], which compares a
/// dequantized tensor element-wise against its float original.
#[derive(Debug, Clone)]
pub struct QuantizationMetrics {
    /// Mean Squared Error between original and quantized tensors
    pub mse: f32,
    /// Peak Signal-to-Noise Ratio (PSNR) in dB (`f32::INFINITY` when MSE is 0)
    pub psnr: f32,
    /// Signal-to-Noise Ratio (SNR) in dB (`f32::INFINITY` when MSE is 0)
    pub snr: f32,
    /// Mean Absolute Error between original and quantized tensors
    pub mae: f32,
    /// Maximum absolute error over all elements
    pub max_error: f32,
    /// Percentage of values with (near-)zero error (|diff| < 1e-7)
    pub zero_error_percentage: f32,
    /// Cosine similarity between original and quantized tensors (0.0 if either has zero norm)
    pub cosine_similarity: f32,
    /// Compression ratio achieved (original bit width / quantized bit width)
    pub compression_ratio: f32,
}
50
51impl Default for QuantizationMetrics {
52    fn default() -> Self {
53        Self {
54            mse: 0.0,
55            psnr: 0.0,
56            snr: 0.0,
57            mae: 0.0,
58            max_error: 0.0,
59            zero_error_percentage: 100.0,
60            cosine_similarity: 1.0,
61            compression_ratio: 1.0,
62        }
63    }
64}
65
66/// Calculate comprehensive quantization quality metrics
67pub fn calculate_quantization_metrics(
68    original: &Tensor,
69    quantized: &Tensor,
70    original_bits: u32,
71    quantized_bits: u32,
72) -> TorshResult<QuantizationMetrics> {
73    if original.shape() != quantized.shape() {
74        return Err(TorshError::InvalidArgument(format!(
75            "Shape mismatch: expected {:?}, got {:?}",
76            original.shape(),
77            quantized.shape()
78        )));
79    }
80
81    let original_data = original.data()?;
82    let quantized_data = quantized.data()?;
83
84    if original_data.len() != quantized_data.len() {
85        return Err(TorshError::InvalidArgument(
86            "Data length mismatch between tensors".to_string(),
87        ));
88    }
89
90    if original_data.is_empty() {
91        return Ok(QuantizationMetrics::default());
92    }
93
94    // Calculate MSE
95    let mse = original_data
96        .iter()
97        .zip(quantized_data.iter())
98        .map(|(a, b)| (a - b).powi(2))
99        .sum::<f32>()
100        / original_data.len() as f32;
101
102    // Calculate MAE
103    let mae = original_data
104        .iter()
105        .zip(quantized_data.iter())
106        .map(|(a, b)| (a - b).abs())
107        .sum::<f32>()
108        / original_data.len() as f32;
109
110    // Calculate max error
111    let max_error = original_data
112        .iter()
113        .zip(quantized_data.iter())
114        .map(|(a, b)| (a - b).abs())
115        .fold(0.0, f32::max);
116
117    // Calculate zero error percentage
118    let zero_errors = original_data
119        .iter()
120        .zip(quantized_data.iter())
121        .filter(|(a, b)| (*a - *b).abs() < 1e-7)
122        .count();
123    let zero_error_percentage = (zero_errors as f32 / original_data.len() as f32) * 100.0;
124
125    // Calculate signal power for SNR/PSNR
126    let signal_power =
127        original_data.iter().map(|x| x.powi(2)).sum::<f32>() / original_data.len() as f32;
128
129    // Calculate PSNR (assuming signal range is [0, 1])
130    let max_signal = original_data
131        .iter()
132        .fold(0.0f32, |acc, &x| acc.max(x.abs()));
133    let psnr = if mse > 0.0 {
134        20.0 * (max_signal / mse.sqrt()).log10()
135    } else {
136        f32::INFINITY
137    };
138
139    // Calculate SNR
140    let snr = if mse > 0.0 && signal_power > 0.0 {
141        10.0 * (signal_power / mse).log10()
142    } else {
143        f32::INFINITY
144    };
145
146    // Calculate cosine similarity
147    let dot_product = original_data
148        .iter()
149        .zip(quantized_data.iter())
150        .map(|(a, b)| a * b)
151        .sum::<f32>();
152
153    let original_norm = original_data.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
154    let quantized_norm = quantized_data.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
155
156    let cosine_similarity = if original_norm > 0.0 && quantized_norm > 0.0 {
157        dot_product / (original_norm * quantized_norm)
158    } else {
159        0.0
160    };
161
162    // Calculate compression ratio
163    let compression_ratio = original_bits as f32 / quantized_bits as f32;
164
165    Ok(QuantizationMetrics {
166        mse,
167        psnr,
168        snr,
169        mae,
170        max_error,
171        zero_error_percentage,
172        cosine_similarity,
173        compression_ratio,
174    })
175}
176
177/// Compare multiple quantization configurations and return ranked results
178pub fn compare_quantization_configs(
179    tensor: &Tensor,
180    configs: &[QuantConfig],
181) -> TorshResult<Vec<(QuantConfig, QuantizationMetrics, f64)>> {
182    let mut results = Vec::new();
183
184    for config in configs {
185        // Time the quantization process
186        let start = std::time::Instant::now();
187
188        // Quantize the tensor
189        let quantize_result = crate::algorithms::quantize_with_config(tensor, config);
190
191        let duration = start.elapsed().as_secs_f64();
192
193        match quantize_result {
194            Ok((quantized, scale, zero_point)) => {
195                // Dequantize back to original precision
196                let dequantized = crate::algorithms::dequantize(&quantized, scale, zero_point)?;
197
198                // Calculate metrics
199                let original_bits = match tensor.dtype() {
200                    DType::F32 => 32,
201                    DType::F16 => 16,
202                    _ => 8,
203                };
204
205                let quantized_bits = match config.dtype {
206                    DType::I8 | DType::U8 => 8,
207                    DType::I16 => 16,
208                    DType::I32 => 32,
209                    DType::F16 => 16,
210                    DType::F32 => 32,
211                    _ => 8,
212                };
213
214                let metrics = calculate_quantization_metrics(
215                    tensor,
216                    &dequantized,
217                    original_bits,
218                    quantized_bits,
219                )?;
220
221                results.push((config.clone(), metrics, duration));
222            }
223            Err(_) => {
224                // If quantization fails, create a worst-case metrics entry
225                let worst_metrics = QuantizationMetrics {
226                    mse: f32::INFINITY,
227                    psnr: f32::NEG_INFINITY,
228                    snr: f32::NEG_INFINITY,
229                    mae: f32::INFINITY,
230                    max_error: f32::INFINITY,
231                    zero_error_percentage: 0.0,
232                    cosine_similarity: 0.0,
233                    compression_ratio: 1.0,
234                };
235
236                results.push((config.clone(), worst_metrics, duration));
237            }
238        }
239    }
240
241    // Sort by PSNR (higher is better)
242    results.sort_by(|a, b| {
243        b.1.psnr
244            .partial_cmp(&a.1.psnr)
245            .unwrap_or(core::cmp::Ordering::Equal)
246    });
247
248    Ok(results)
249}
250
251/// Automatic calibration assistant to find optimal quantization configuration
252pub fn auto_calibrate_quantization(
253    calibration_tensors: &[&Tensor],
254    target_accuracy_threshold: f32,
255    max_compression_ratio: f32,
256) -> TorshResult<QuantConfig> {
257    if calibration_tensors.is_empty() {
258        return Err(TorshError::InvalidArgument(
259            "No calibration tensors provided".to_string(),
260        ));
261    }
262
263    // Define candidate configurations to test
264    let candidate_configs = vec![
265        QuantConfig::int8(),
266        QuantConfig::int8().with_observer(ObserverType::Histogram),
267        QuantConfig::per_channel(0),
268        QuantConfig::per_channel(1),
269        QuantConfig::group_wise(0, 8),
270        QuantConfig::group_wise(1, 16),
271        QuantConfig::int4(),
272        QuantConfig::ternary(),
273    ];
274
275    let mut best_config = None;
276    let mut best_score = f32::NEG_INFINITY;
277
278    // Test each configuration with all calibration tensors
279    for config in candidate_configs {
280        let mut total_metrics = QuantizationMetrics::default();
281        let mut successful_tests = 0;
282
283        for tensor in calibration_tensors {
284            if let Ok(comparison) =
285                compare_quantization_configs(tensor, std::slice::from_ref(&config))
286            {
287                if let Some((_, metrics, _)) = comparison.first() {
288                    if metrics.psnr.is_finite() {
289                        total_metrics.mse += metrics.mse;
290                        total_metrics.psnr += metrics.psnr;
291                        total_metrics.snr += metrics.snr;
292                        total_metrics.mae += metrics.mae;
293                        total_metrics.max_error = total_metrics.max_error.max(metrics.max_error);
294                        total_metrics.zero_error_percentage += metrics.zero_error_percentage;
295                        total_metrics.cosine_similarity += metrics.cosine_similarity;
296                        total_metrics.compression_ratio += metrics.compression_ratio;
297                        successful_tests += 1;
298                    }
299                }
300            }
301        }
302
303        if successful_tests > 0 {
304            // Average the metrics
305            let avg_metrics = QuantizationMetrics {
306                mse: total_metrics.mse / successful_tests as f32,
307                psnr: total_metrics.psnr / successful_tests as f32,
308                snr: total_metrics.snr / successful_tests as f32,
309                mae: total_metrics.mae / successful_tests as f32,
310                max_error: total_metrics.max_error,
311                zero_error_percentage: total_metrics.zero_error_percentage
312                    / successful_tests as f32,
313                cosine_similarity: total_metrics.cosine_similarity / successful_tests as f32,
314                compression_ratio: total_metrics.compression_ratio / successful_tests as f32,
315            };
316
317            // Calculate a composite score (higher is better)
318            let score = if avg_metrics.psnr >= target_accuracy_threshold
319                && avg_metrics.compression_ratio <= max_compression_ratio
320            {
321                // Prioritize compression ratio if accuracy threshold is met
322                avg_metrics.compression_ratio + avg_metrics.psnr / 100.0
323            } else {
324                // Otherwise prioritize accuracy
325                avg_metrics.psnr / avg_metrics.compression_ratio
326            };
327
328            if score > best_score {
329                best_score = score;
330                best_config = Some(config.clone());
331            }
332        }
333    }
334
335    best_config
336        .ok_or_else(|| TorshError::InvalidArgument("No suitable configuration found".to_string()))
337}
338
/// Generate a comprehensive quantization report
///
/// Builds a Markdown document containing the original tensor's statistics, a
/// ranked comparison table for `configs` (via
/// [`compare_quantization_configs`]), per-config metric details, and a short
/// recommendation section based on the best configuration's PSNR.
pub fn generate_quantization_report(
    original: &Tensor,
    configs: &[QuantConfig],
) -> TorshResult<String> {
    let mut report = String::new();

    // Header: shape, dtype and element count of the input tensor.
    report.push_str("# Quantization Analysis Report\n\n");
    report.push_str(&format!(
        "**Original Tensor Shape:** {:?}\n",
        original.shape()
    ));
    report.push_str(&format!(
        "**Original Tensor DType:** {:?}\n",
        original.dtype()
    ));
    report.push_str(&format!(
        "**Number of Elements:** {}\n\n",
        original.shape().numel()
    ));

    // Get tensor statistics.
    // NOTE(review): an empty tensor makes mean/std_dev NaN and min/max stay at
    // ±infinity here — confirm empty inputs are excluded by callers.
    let data = original.data()?;
    let min_val = data.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
    let max_val = data.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let mean = data.iter().sum::<f32>() / data.len() as f32;
    let std_dev = (data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32).sqrt();

    report.push_str("**Original Tensor Statistics:**\n");
    report.push_str(&format!("- Min: {min_val:.6}\n"));
    report.push_str(&format!("- Max: {max_val:.6}\n"));
    report.push_str(&format!("- Mean: {mean:.6}\n"));
    report.push_str(&format!("- Std Dev: {std_dev:.6}\n"));
    report.push_str(&format!("- Dynamic Range: {:.6}\n\n", max_val - min_val));

    // Run every config through the quantize/dequantize comparison; the result
    // comes back sorted by descending PSNR, so index order == rank order.
    let comparison_results = compare_quantization_configs(original, configs)?;

    // Summary table, one row per config in rank order.
    report.push_str("## Quantization Configuration Comparison\n\n");
    report.push_str(
        "| Rank | Scheme | Observer | PSNR (dB) | SNR (dB) | MAE | Compression | Time (ms) |\n",
    );
    report.push_str(
        "|------|--------|----------|-----------|----------|-----|-------------|----------|\n",
    );

    for (rank, (config, metrics, duration)) in comparison_results.iter().enumerate() {
        report.push_str(&format!(
            "| {} | {:?} | {:?} | {:.2} | {:.2} | {:.6} | {:.1}x | {:.2} |\n",
            rank + 1,
            config.scheme,
            config.observer_type,
            metrics.psnr,
            metrics.snr,
            metrics.mae,
            metrics.compression_ratio,
            duration * 1000.0
        ));
    }

    // Expanded per-config metrics below the table.
    report.push_str("\n## Detailed Metrics\n\n");

    for (rank, (config, metrics, _)) in comparison_results.iter().enumerate() {
        report.push_str(&format!(
            "### Configuration #{} - {:?}\n",
            rank + 1,
            config.scheme
        ));
        report.push_str(&format!("- **MSE:** {:.8}\n", metrics.mse));
        report.push_str(&format!("- **PSNR:** {:.2} dB\n", metrics.psnr));
        report.push_str(&format!("- **SNR:** {:.2} dB\n", metrics.snr));
        report.push_str(&format!("- **MAE:** {:.6}\n", metrics.mae));
        report.push_str(&format!("- **Max Error:** {:.6}\n", metrics.max_error));
        report.push_str(&format!(
            "- **Zero Error %:** {:.2}%\n",
            metrics.zero_error_percentage
        ));
        report.push_str(&format!(
            "- **Cosine Similarity:** {:.6}\n",
            metrics.cosine_similarity
        ));
        report.push_str(&format!(
            "- **Compression Ratio:** {:.1}x\n\n",
            metrics.compression_ratio
        ));
    }

    report.push_str("## Recommendations\n\n");

    // The first entry is the best (highest PSNR) configuration.
    if let Some((best_config, best_metrics, _)) = comparison_results.first() {
        report.push_str(&format!(
            "**Best Configuration:** {:?} with {:?} observer\n",
            best_config.scheme, best_config.observer_type
        ));
        report.push_str(&format!(
            "- Achieves {:.2} dB PSNR with {:.1}x compression\n",
            best_metrics.psnr, best_metrics.compression_ratio
        ));

        // Qualitative PSNR banding: >40 dB excellent, >30 good, >20 moderate.
        if best_metrics.psnr > 40.0 {
            report.push_str("- ✅ Excellent quality preservation\n");
        } else if best_metrics.psnr > 30.0 {
            report.push_str("- ✅ Good quality preservation\n");
        } else if best_metrics.psnr > 20.0 {
            report.push_str("- ⚠️ Moderate quality loss\n");
        } else {
            report.push_str("- ❌ Significant quality loss\n");
        }
    }

    Ok(report)
}
451
452/// Generate optimization hints for tensor and configuration
453pub fn generate_optimization_hints(
454    tensor: &Tensor,
455    config: &QuantConfig,
456) -> TorshResult<Vec<String>> {
457    let mut hints = Vec::new();
458    let shape = tensor.shape();
459    let data = tensor.data()?;
460
461    // Data distribution analysis
462    if !data.is_empty() {
463        let min_val = data.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
464        let max_val = data.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
465        let mean = data.iter().sum::<f32>() / data.len() as f32;
466        let std_dev =
467            (data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32).sqrt();
468
469        // Range analysis
470        let dynamic_range = max_val - min_val;
471        if dynamic_range > 100.0 {
472            hints.push("Large dynamic range detected. Consider using Histogram or Percentile observer for better quantization parameters.".to_string());
473        }
474
475        // Sparsity analysis
476        let zero_count = data.iter().filter(|&&x| x.abs() < 1e-6).count();
477        let sparsity = zero_count as f32 / data.len() as f32;
478        if sparsity > 0.5 {
479            hints.push(
480                "High sparsity detected. Sparse quantization schemes may be more efficient."
481                    .to_string(),
482            );
483        }
484
485        // Outlier detection
486        let outlier_threshold = mean + 3.0 * std_dev;
487        let outlier_count = data
488            .iter()
489            .filter(|&&x| x.abs() > outlier_threshold)
490            .count();
491        if outlier_count > 0 {
492            hints.push("Outliers detected. Percentile-based observers may provide better quantization parameters.".to_string());
493        }
494
495        // Memory efficiency hints
496        if data.len() > 1_000_000 {
497            hints.push("For large tensors, Histogram observer may be more memory-efficient than Percentile observer.".to_string());
498        }
499    }
500
501    // Shape-based hints
502    if shape.dims().len() >= 2 && shape.dims().iter().any(|&dim| dim > 16) {
503        hints.push("Multi-channel tensor detected. Per-channel or group-wise quantization may provide better accuracy.".to_string());
504    }
505
506    // Scheme-specific hints
507    match config.scheme {
508        crate::config::QScheme::PerChannelAffine | crate::config::QScheme::PerChannelSymmetric => {
509            if let Some(axis) = config.ch_axis {
510                if axis >= shape.dims().len() {
511                    hints.push(
512                        "Channel axis is out of bounds. This will cause an error.".to_string(),
513                    );
514                } else if shape.dims()[axis] < 4 {
515                    hints.push(
516                        "Few channels detected. Per-tensor quantization might be sufficient."
517                            .to_string(),
518                    );
519                }
520            }
521        }
522        crate::config::QScheme::GroupWise => {
523            if let (Some(axis), Some(group_size)) = (config.ch_axis, config.group_size) {
524                if axis < shape.dims().len() {
525                    let num_channels = shape.dims()[axis];
526                    let num_groups = num_channels.div_ceil(group_size);
527                    if num_groups == 1 {
528                        hints.push("Only one group will be created. Consider per-tensor quantization instead.".to_string());
529                    } else if num_groups == num_channels {
530                        hints.push("Each channel forms its own group. Consider per-channel quantization instead.".to_string());
531                    }
532                }
533            }
534        }
535        _ => {}
536    }
537
538    Ok(hints)
539}
540
541/// Benchmark quantization performance for different configurations
542pub fn benchmark_quantization_performance(
543    tensor: &Tensor,
544    configs: &[QuantConfig],
545    num_iterations: usize,
546) -> TorshResult<Vec<(QuantConfig, f64, f64)>> {
547    let mut results = Vec::new();
548
549    for config in configs {
550        let mut total_quantize_time = 0.0;
551        let mut total_dequantize_time = 0.0;
552        let mut successful_runs = 0;
553
554        for _ in 0..num_iterations {
555            // Benchmark quantization
556            let quantize_start = std::time::Instant::now();
557            let quantize_result = crate::algorithms::quantize_with_config(tensor, config);
558            let quantize_time = quantize_start.elapsed().as_secs_f64();
559
560            if let Ok((quantized, scale, zero_point)) = quantize_result {
561                // Benchmark dequantization
562                let dequantize_start = std::time::Instant::now();
563                let _dequantized = crate::algorithms::dequantize(&quantized, scale, zero_point)?;
564                let dequantize_time = dequantize_start.elapsed().as_secs_f64();
565
566                total_quantize_time += quantize_time;
567                total_dequantize_time += dequantize_time;
568                successful_runs += 1;
569            }
570        }
571
572        if successful_runs > 0 {
573            let avg_quantize_time = total_quantize_time / successful_runs as f64;
574            let avg_dequantize_time = total_dequantize_time / successful_runs as f64;
575            results.push((config.clone(), avg_quantize_time, avg_dequantize_time));
576        }
577    }
578
579    Ok(results)
580}
581
#[cfg(test)]
mod tests {
    use super::*;

    use torsh_tensor::creation::tensor_1d;

    /// Checks the metric math on a small hand-verifiable example, then on a
    /// perfect (zero-error) match.
    #[test]
    fn test_calculate_quantization_metrics() {
        let original_data = vec![1.0, 2.0, 3.0, 4.0];
        let quantized_data = vec![1.1, 2.1, 2.9, 3.9];

        let original = tensor_1d(&original_data).unwrap();
        let quantized = tensor_1d(&quantized_data).unwrap();

        let metrics = calculate_quantization_metrics(&original, &quantized, 32, 8).unwrap();

        // Verify basic metrics
        assert!(metrics.mse > 0.0);
        assert!(metrics.mse < 1.0); // Should be small error
        assert!(metrics.mae > 0.0);
        assert!(metrics.mae < 1.0);
        assert!(metrics.psnr > 0.0);
        assert!(metrics.snr > 0.0);
        assert!(metrics.max_error >= 0.0);
        assert!(metrics.zero_error_percentage >= 0.0);
        assert!(metrics.zero_error_percentage <= 100.0);
        assert!(metrics.cosine_similarity > 0.8); // Should be high similarity
        assert_eq!(metrics.compression_ratio, 4.0); // 32-bit to 8-bit

        // Test perfect match (zero error)
        let metrics_perfect = calculate_quantization_metrics(&original, &original, 32, 16).unwrap();
        assert_eq!(metrics_perfect.mse, 0.0);
        assert_eq!(metrics_perfect.mae, 0.0);
        assert_eq!(metrics_perfect.max_error, 0.0);
        assert_eq!(metrics_perfect.zero_error_percentage, 100.0);
        assert!((metrics_perfect.cosine_similarity - 1.0).abs() < 1e-6);
        assert!(metrics_perfect.psnr.is_infinite());
        assert!(metrics_perfect.snr.is_infinite());
        assert_eq!(metrics_perfect.compression_ratio, 2.0); // 32-bit to 16-bit
    }

    /// All configs should produce a ranked entry, sorted by descending PSNR.
    #[test]
    fn test_compare_quantization_configs() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![
            QuantConfig::int8(),
            QuantConfig::binary(),
            QuantConfig::ternary(),
        ];

        let results = compare_quantization_configs(&tensor, &configs).unwrap();

        // Should return results for all configs
        assert_eq!(results.len(), 3);

        // Verify all results have valid metrics
        for (config, metrics, duration) in &results {
            assert!(configs.iter().any(|c| c.scheme == config.scheme));
            assert!(duration >= &0.0);

            // Metrics should be reasonable (not infinity for successful quantization)
            if metrics.psnr.is_finite() {
                assert!(metrics.psnr > 0.0);
                assert!(metrics.compression_ratio >= 1.0);
                assert!(metrics.mae >= 0.0);
                assert!(metrics.mse >= 0.0);
            }
        }

        // Results should be sorted by PSNR (higher is better)
        for i in 1..results.len() {
            let prev_psnr = results[i - 1].1.psnr;
            let curr_psnr = results[i].1.psnr;
            if prev_psnr.is_finite() && curr_psnr.is_finite() {
                assert!(prev_psnr >= curr_psnr);
            }
        }
    }

    /// Calibration should return a valid config for reasonable and even for
    /// impossible thresholds, and error on an empty tensor list.
    #[test]
    fn test_auto_calibrate_quantization() {
        let tensor1 = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let tensor2 = tensor_1d(&[2.0, 3.0, 4.0, 5.0]).unwrap();
        let tensor3 = tensor_1d(&[0.5, 1.5, 2.5, 3.5]).unwrap();

        let calibration_tensors = vec![&tensor1, &tensor2, &tensor3];

        // Test with reasonable thresholds
        let result = auto_calibrate_quantization(&calibration_tensors, 20.0, 10.0);
        assert!(result.is_ok());

        let config = result.unwrap();
        assert!(config.validate().is_ok());

        // Test with impossible thresholds (should still return a config)
        let result_strict = auto_calibrate_quantization(&calibration_tensors, 100.0, 1.1);
        assert!(result_strict.is_ok());

        // Test empty calibration tensors
        let empty_tensors = vec![];
        let result_empty = auto_calibrate_quantization(&empty_tensors, 20.0, 10.0);
        assert!(result_empty.is_err());
    }

    /// Smoke-tests the Markdown report: every expected section, column and
    /// statistic label must appear.
    #[test]
    fn test_generate_quantization_report() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![QuantConfig::int8(), QuantConfig::binary()];

        let report_result = generate_quantization_report(&tensor, &configs);
        assert!(report_result.is_ok());

        let report = report_result.unwrap();

        // Verify report contains expected sections
        assert!(report.contains("# Quantization Analysis Report"));
        assert!(report.contains("**Original Tensor Shape:**"));
        assert!(report.contains("**Original Tensor Statistics:**"));
        assert!(report.contains("## Quantization Configuration Comparison"));
        assert!(report.contains("## Detailed Metrics"));
        assert!(report.contains("## Recommendations"));

        // Verify it contains data about our configs
        assert!(report.contains("PerTensorAffine"));
        assert!(report.contains("Binary"));

        // Verify it contains statistical information
        assert!(report.contains("Min:"));
        assert!(report.contains("Max:"));
        assert!(report.contains("Mean:"));
        assert!(report.contains("Std Dev:"));
        assert!(report.contains("Dynamic Range:"));

        // Verify it contains metrics columns
        assert!(report.contains("PSNR (dB)"));
        assert!(report.contains("SNR (dB)"));
        assert!(report.contains("MAE"));
        assert!(report.contains("Compression"));
        assert!(report.contains("Time (ms)"));

        // Verify it contains recommendations
        assert!(report.contains("**Best Configuration:**"));
    }

    /// Hint generation may legitimately produce zero hints; what we can check
    /// is that the call succeeds and any produced hint is non-empty text.
    /// (The previous assertions `hints.is_empty() || !hints.is_empty()` were
    /// tautologies and verified nothing.)
    #[test]
    fn test_generate_optimization_hints() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        let hints = generate_optimization_hints(&tensor, &config).unwrap();
        for hint in &hints {
            assert!(!hint.is_empty());
        }

        // Test with per-channel config
        let per_channel_config = QuantConfig::per_channel(0);
        let hints = generate_optimization_hints(&tensor, &per_channel_config).unwrap();
        for hint in &hints {
            assert!(!hint.is_empty());
        }
    }

    /// Benchmark results must only cover the supplied configs and report
    /// non-negative timings.
    #[test]
    fn test_benchmark_quantization_performance() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![QuantConfig::int8(), QuantConfig::binary()];

        let results = benchmark_quantization_performance(&tensor, &configs, 3).unwrap();

        // Should return timing results for successful configs
        assert!(results.len() <= configs.len());

        for (config, quantize_time, dequantize_time) in &results {
            assert!(configs.iter().any(|c| c.scheme == config.scheme));
            assert!(quantize_time >= &0.0);
            assert!(dequantize_time >= &0.0);
        }
    }

    /// Edge cases: shape mismatch errors, all-zero tensors, and near-identical
    /// inputs.
    #[test]
    fn test_quantization_metrics_edge_cases() {
        // Test with different shaped tensors (should fail)
        let tensor1 = tensor_1d(&[1.0, 2.0]).unwrap();
        let tensor2 = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();

        let result = calculate_quantization_metrics(&tensor1, &tensor2, 32, 8);
        assert!(result.is_err());

        // Test with zero tensors
        let zero_tensor = tensor_1d(&[0.0, 0.0, 0.0]).unwrap();
        let metrics = calculate_quantization_metrics(&zero_tensor, &zero_tensor, 32, 8).unwrap();

        assert_eq!(metrics.mse, 0.0);
        assert_eq!(metrics.mae, 0.0);
        assert_eq!(metrics.max_error, 0.0);
        assert_eq!(metrics.zero_error_percentage, 100.0);
        assert!(metrics.psnr.is_infinite());
        assert_eq!(metrics.cosine_similarity, 0.0); // Both vectors are zero

        // Test with very small differences
        let original = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();
        let almost_same = tensor_1d(&[1.0000001, 2.0000001, 3.0000001]).unwrap();

        let metrics = calculate_quantization_metrics(&original, &almost_same, 32, 8).unwrap();
        assert!(metrics.mse < 1e-12);
        assert!(metrics.mae < 1e-6);
        assert!(metrics.cosine_similarity > 0.999999);
        assert!(metrics.psnr > 100.0); // Very high PSNR for very small error
    }

    /// The `Default` impl must describe a lossless baseline.
    #[test]
    fn test_metrics_default() {
        let default_metrics = QuantizationMetrics::default();
        assert_eq!(default_metrics.mse, 0.0);
        assert_eq!(default_metrics.psnr, 0.0);
        assert_eq!(default_metrics.snr, 0.0);
        assert_eq!(default_metrics.mae, 0.0);
        assert_eq!(default_metrics.max_error, 0.0);
        assert_eq!(default_metrics.zero_error_percentage, 100.0);
        assert_eq!(default_metrics.cosine_similarity, 1.0);
        assert_eq!(default_metrics.compression_ratio, 1.0);
    }
}