1use crate::config::{ObserverType, QuantConfig};
17
18#[cfg(not(feature = "std"))]
19extern crate alloc;
20
21#[cfg(not(feature = "std"))]
22use alloc::{collections::BTreeMap as HashMap, format, string::String, vec::Vec};
23
24use torsh_core::{
25 dtype::DType,
26 error::{Result as TorshResult, TorshError},
27};
28use torsh_tensor::Tensor;
29
/// Quality metrics comparing an original tensor against its quantized
/// (round-tripped) counterpart.
#[derive(Debug, Clone)]
pub struct QuantizationMetrics {
    /// Mean squared error between the two tensors.
    pub mse: f32,
    /// Peak signal-to-noise ratio in dB (infinite for a lossless round-trip).
    pub psnr: f32,
    /// Signal-to-noise ratio in dB (infinite for a lossless round-trip).
    pub snr: f32,
    /// Mean absolute error between the two tensors.
    pub mae: f32,
    /// Largest absolute per-element error observed.
    pub max_error: f32,
    /// Percentage of elements reproduced (numerically) exactly, in [0, 100].
    pub zero_error_percentage: f32,
    /// Cosine similarity of the flattened tensors (1.0 means identical direction).
    pub cosine_similarity: f32,
    /// Bit-width ratio original/quantized (e.g. 32-bit -> 8-bit gives 4.0).
    pub compression_ratio: f32,
}

impl Default for QuantizationMetrics {
    /// The metrics of a perfect, lossless "quantization": zero error,
    /// 100% exact matches, identical direction, and no compression.
    fn default() -> Self {
        QuantizationMetrics {
            mse: 0.0,
            psnr: 0.0,
            snr: 0.0,
            mae: 0.0,
            max_error: 0.0,
            zero_error_percentage: 100.0,
            cosine_similarity: 1.0,
            compression_ratio: 1.0,
        }
    }
}
65
66pub fn calculate_quantization_metrics(
68 original: &Tensor,
69 quantized: &Tensor,
70 original_bits: u32,
71 quantized_bits: u32,
72) -> TorshResult<QuantizationMetrics> {
73 if original.shape() != quantized.shape() {
74 return Err(TorshError::InvalidArgument(format!(
75 "Shape mismatch: expected {:?}, got {:?}",
76 original.shape(),
77 quantized.shape()
78 )));
79 }
80
81 let original_data = original.data()?;
82 let quantized_data = quantized.data()?;
83
84 if original_data.len() != quantized_data.len() {
85 return Err(TorshError::InvalidArgument(
86 "Data length mismatch between tensors".to_string(),
87 ));
88 }
89
90 if original_data.is_empty() {
91 return Ok(QuantizationMetrics::default());
92 }
93
94 let mse = original_data
96 .iter()
97 .zip(quantized_data.iter())
98 .map(|(a, b)| (a - b).powi(2))
99 .sum::<f32>()
100 / original_data.len() as f32;
101
102 let mae = original_data
104 .iter()
105 .zip(quantized_data.iter())
106 .map(|(a, b)| (a - b).abs())
107 .sum::<f32>()
108 / original_data.len() as f32;
109
110 let max_error = original_data
112 .iter()
113 .zip(quantized_data.iter())
114 .map(|(a, b)| (a - b).abs())
115 .fold(0.0, f32::max);
116
117 let zero_errors = original_data
119 .iter()
120 .zip(quantized_data.iter())
121 .filter(|(a, b)| (*a - *b).abs() < 1e-7)
122 .count();
123 let zero_error_percentage = (zero_errors as f32 / original_data.len() as f32) * 100.0;
124
125 let signal_power =
127 original_data.iter().map(|x| x.powi(2)).sum::<f32>() / original_data.len() as f32;
128
129 let max_signal = original_data
131 .iter()
132 .fold(0.0f32, |acc, &x| acc.max(x.abs()));
133 let psnr = if mse > 0.0 {
134 20.0 * (max_signal / mse.sqrt()).log10()
135 } else {
136 f32::INFINITY
137 };
138
139 let snr = if mse > 0.0 && signal_power > 0.0 {
141 10.0 * (signal_power / mse).log10()
142 } else {
143 f32::INFINITY
144 };
145
146 let dot_product = original_data
148 .iter()
149 .zip(quantized_data.iter())
150 .map(|(a, b)| a * b)
151 .sum::<f32>();
152
153 let original_norm = original_data.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
154 let quantized_norm = quantized_data.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
155
156 let cosine_similarity = if original_norm > 0.0 && quantized_norm > 0.0 {
157 dot_product / (original_norm * quantized_norm)
158 } else {
159 0.0
160 };
161
162 let compression_ratio = original_bits as f32 / quantized_bits as f32;
164
165 Ok(QuantizationMetrics {
166 mse,
167 psnr,
168 snr,
169 mae,
170 max_error,
171 zero_error_percentage,
172 cosine_similarity,
173 compression_ratio,
174 })
175}
176
/// Runs every candidate `QuantConfig` against `tensor` and ranks the results.
///
/// Each config is quantized, dequantized, and scored with
/// [`calculate_quantization_metrics`]. Configs whose quantization fails are
/// kept in the output with sentinel worst-case metrics so callers still see
/// them. Returns `(config, metrics, quantize_seconds)` triples sorted by
/// descending PSNR (best first).
pub fn compare_quantization_configs(
    tensor: &Tensor,
    configs: &[QuantConfig],
) -> TorshResult<Vec<(QuantConfig, QuantizationMetrics, f64)>> {
    let mut results = Vec::new();

    for config in configs {
        // Wall-clock timing covers the quantize step only.
        let start = std::time::Instant::now();

        let quantize_result = crate::algorithms::quantize_with_config(tensor, config);

        let duration = start.elapsed().as_secs_f64();

        match quantize_result {
            Ok((quantized, scale, zero_point)) => {
                // Round-trip back to float so metrics compare like with like.
                let dequantized = crate::algorithms::dequantize(&quantized, scale, zero_point)?;

                // Bit width of the source dtype; non-float dtypes fall back to 8.
                let original_bits = match tensor.dtype() {
                    DType::F32 => 32,
                    DType::F16 => 16,
                    _ => 8,
                };

                // Bit width of the target dtype requested by the config.
                let quantized_bits = match config.dtype {
                    DType::I8 | DType::U8 => 8,
                    DType::I16 => 16,
                    DType::I32 => 32,
                    DType::F16 => 16,
                    DType::F32 => 32,
                    _ => 8,
                };

                let metrics = calculate_quantization_metrics(
                    tensor,
                    &dequantized,
                    original_bits,
                    quantized_bits,
                )?;

                results.push((config.clone(), metrics, duration));
            }
            Err(_) => {
                // Quantization failed: record "worst possible" metrics so the
                // config ranks last instead of silently disappearing.
                let worst_metrics = QuantizationMetrics {
                    mse: f32::INFINITY,
                    psnr: f32::NEG_INFINITY,
                    snr: f32::NEG_INFINITY,
                    mae: f32::INFINITY,
                    max_error: f32::INFINITY,
                    zero_error_percentage: 0.0,
                    cosine_similarity: 0.0,
                    compression_ratio: 1.0,
                };

                results.push((config.clone(), worst_metrics, duration));
            }
        }
    }

    // Sort best PSNR first; incomparable values (NaN) are treated as equal so
    // the comparator stays total and the sort cannot panic.
    results.sort_by(|a, b| {
        b.1.psnr
            .partial_cmp(&a.1.psnr)
            .unwrap_or(core::cmp::Ordering::Equal)
    });

    Ok(results)
}
250
251pub fn auto_calibrate_quantization(
253 calibration_tensors: &[&Tensor],
254 target_accuracy_threshold: f32,
255 max_compression_ratio: f32,
256) -> TorshResult<QuantConfig> {
257 if calibration_tensors.is_empty() {
258 return Err(TorshError::InvalidArgument(
259 "No calibration tensors provided".to_string(),
260 ));
261 }
262
263 let candidate_configs = vec![
265 QuantConfig::int8(),
266 QuantConfig::int8().with_observer(ObserverType::Histogram),
267 QuantConfig::per_channel(0),
268 QuantConfig::per_channel(1),
269 QuantConfig::group_wise(0, 8),
270 QuantConfig::group_wise(1, 16),
271 QuantConfig::int4(),
272 QuantConfig::ternary(),
273 ];
274
275 let mut best_config = None;
276 let mut best_score = f32::NEG_INFINITY;
277
278 for config in candidate_configs {
280 let mut total_metrics = QuantizationMetrics::default();
281 let mut successful_tests = 0;
282
283 for tensor in calibration_tensors {
284 if let Ok(comparison) =
285 compare_quantization_configs(tensor, std::slice::from_ref(&config))
286 {
287 if let Some((_, metrics, _)) = comparison.first() {
288 if metrics.psnr.is_finite() {
289 total_metrics.mse += metrics.mse;
290 total_metrics.psnr += metrics.psnr;
291 total_metrics.snr += metrics.snr;
292 total_metrics.mae += metrics.mae;
293 total_metrics.max_error = total_metrics.max_error.max(metrics.max_error);
294 total_metrics.zero_error_percentage += metrics.zero_error_percentage;
295 total_metrics.cosine_similarity += metrics.cosine_similarity;
296 total_metrics.compression_ratio += metrics.compression_ratio;
297 successful_tests += 1;
298 }
299 }
300 }
301 }
302
303 if successful_tests > 0 {
304 let avg_metrics = QuantizationMetrics {
306 mse: total_metrics.mse / successful_tests as f32,
307 psnr: total_metrics.psnr / successful_tests as f32,
308 snr: total_metrics.snr / successful_tests as f32,
309 mae: total_metrics.mae / successful_tests as f32,
310 max_error: total_metrics.max_error,
311 zero_error_percentage: total_metrics.zero_error_percentage
312 / successful_tests as f32,
313 cosine_similarity: total_metrics.cosine_similarity / successful_tests as f32,
314 compression_ratio: total_metrics.compression_ratio / successful_tests as f32,
315 };
316
317 let score = if avg_metrics.psnr >= target_accuracy_threshold
319 && avg_metrics.compression_ratio <= max_compression_ratio
320 {
321 avg_metrics.compression_ratio + avg_metrics.psnr / 100.0
323 } else {
324 avg_metrics.psnr / avg_metrics.compression_ratio
326 };
327
328 if score > best_score {
329 best_score = score;
330 best_config = Some(config.clone());
331 }
332 }
333 }
334
335 best_config
336 .ok_or_else(|| TorshError::InvalidArgument("No suitable configuration found".to_string()))
337}
338
/// Renders a Markdown report comparing every config in `configs` on
/// `original`: tensor statistics, a ranked comparison table, per-config
/// detailed metrics, and a recommendation for the top-ranked config.
pub fn generate_quantization_report(
    original: &Tensor,
    configs: &[QuantConfig],
) -> TorshResult<String> {
    let mut report = String::new();

    // Header with basic tensor facts.
    report.push_str("# Quantization Analysis Report\n\n");
    report.push_str(&format!(
        "**Original Tensor Shape:** {:?}\n",
        original.shape()
    ));
    report.push_str(&format!(
        "**Original Tensor DType:** {:?}\n",
        original.dtype()
    ));
    report.push_str(&format!(
        "**Number of Elements:** {}\n\n",
        original.shape().numel()
    ));

    // Summary statistics of the raw data.
    // NOTE(review): an empty tensor would print inf/NaN stats here — presumably
    // callers never pass one; confirm before hardening.
    let data = original.data()?;
    let min_val = data.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
    let max_val = data.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let mean = data.iter().sum::<f32>() / data.len() as f32;
    let std_dev = (data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32).sqrt();

    report.push_str("**Original Tensor Statistics:**\n");
    report.push_str(&format!("- Min: {min_val:.6}\n"));
    report.push_str(&format!("- Max: {max_val:.6}\n"));
    report.push_str(&format!("- Mean: {mean:.6}\n"));
    report.push_str(&format!("- Std Dev: {std_dev:.6}\n"));
    report.push_str(&format!("- Dynamic Range: {:.6}\n\n", max_val - min_val));

    // Results arrive already sorted by descending PSNR.
    let comparison_results = compare_quantization_configs(original, configs)?;

    // Ranked comparison table.
    report.push_str("## Quantization Configuration Comparison\n\n");
    report.push_str(
        "| Rank | Scheme | Observer | PSNR (dB) | SNR (dB) | MAE | Compression | Time (ms) |\n",
    );
    report.push_str(
        "|------|--------|----------|-----------|----------|-----|-------------|----------|\n",
    );

    for (rank, (config, metrics, duration)) in comparison_results.iter().enumerate() {
        report.push_str(&format!(
            "| {} | {:?} | {:?} | {:.2} | {:.2} | {:.6} | {:.1}x | {:.2} |\n",
            rank + 1,
            config.scheme,
            config.observer_type,
            metrics.psnr,
            metrics.snr,
            metrics.mae,
            metrics.compression_ratio,
            // Duration is stored in seconds; the table shows milliseconds.
            duration * 1000.0
        ));
    }

    // One detailed section per config, in ranking order.
    report.push_str("\n## Detailed Metrics\n\n");

    for (rank, (config, metrics, _)) in comparison_results.iter().enumerate() {
        report.push_str(&format!(
            "### Configuration #{} - {:?}\n",
            rank + 1,
            config.scheme
        ));
        report.push_str(&format!("- **MSE:** {:.8}\n", metrics.mse));
        report.push_str(&format!("- **PSNR:** {:.2} dB\n", metrics.psnr));
        report.push_str(&format!("- **SNR:** {:.2} dB\n", metrics.snr));
        report.push_str(&format!("- **MAE:** {:.6}\n", metrics.mae));
        report.push_str(&format!("- **Max Error:** {:.6}\n", metrics.max_error));
        report.push_str(&format!(
            "- **Zero Error %:** {:.2}%\n",
            metrics.zero_error_percentage
        ));
        report.push_str(&format!(
            "- **Cosine Similarity:** {:.6}\n",
            metrics.cosine_similarity
        ));
        report.push_str(&format!(
            "- **Compression Ratio:** {:.1}x\n\n",
            metrics.compression_ratio
        ));
    }

    // Recommendation: first entry is the highest-PSNR config (if any ran).
    report.push_str("## Recommendations\n\n");

    if let Some((best_config, best_metrics, _)) = comparison_results.first() {
        report.push_str(&format!(
            "**Best Configuration:** {:?} with {:?} observer\n",
            best_config.scheme, best_config.observer_type
        ));
        report.push_str(&format!(
            "- Achieves {:.2} dB PSNR with {:.1}x compression\n",
            best_metrics.psnr, best_metrics.compression_ratio
        ));

        // Qualitative verdict from PSNR bands.
        if best_metrics.psnr > 40.0 {
            report.push_str("- ✅ Excellent quality preservation\n");
        } else if best_metrics.psnr > 30.0 {
            report.push_str("- ✅ Good quality preservation\n");
        } else if best_metrics.psnr > 20.0 {
            report.push_str("- ⚠️ Moderate quality loss\n");
        } else {
            report.push_str("- ❌ Significant quality loss\n");
        }
    }

    Ok(report)
}
451
452pub fn generate_optimization_hints(
454 tensor: &Tensor,
455 config: &QuantConfig,
456) -> TorshResult<Vec<String>> {
457 let mut hints = Vec::new();
458 let shape = tensor.shape();
459 let data = tensor.data()?;
460
461 if !data.is_empty() {
463 let min_val = data.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
464 let max_val = data.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
465 let mean = data.iter().sum::<f32>() / data.len() as f32;
466 let std_dev =
467 (data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32).sqrt();
468
469 let dynamic_range = max_val - min_val;
471 if dynamic_range > 100.0 {
472 hints.push("Large dynamic range detected. Consider using Histogram or Percentile observer for better quantization parameters.".to_string());
473 }
474
475 let zero_count = data.iter().filter(|&&x| x.abs() < 1e-6).count();
477 let sparsity = zero_count as f32 / data.len() as f32;
478 if sparsity > 0.5 {
479 hints.push(
480 "High sparsity detected. Sparse quantization schemes may be more efficient."
481 .to_string(),
482 );
483 }
484
485 let outlier_threshold = mean + 3.0 * std_dev;
487 let outlier_count = data
488 .iter()
489 .filter(|&&x| x.abs() > outlier_threshold)
490 .count();
491 if outlier_count > 0 {
492 hints.push("Outliers detected. Percentile-based observers may provide better quantization parameters.".to_string());
493 }
494
495 if data.len() > 1_000_000 {
497 hints.push("For large tensors, Histogram observer may be more memory-efficient than Percentile observer.".to_string());
498 }
499 }
500
501 if shape.dims().len() >= 2 && shape.dims().iter().any(|&dim| dim > 16) {
503 hints.push("Multi-channel tensor detected. Per-channel or group-wise quantization may provide better accuracy.".to_string());
504 }
505
506 match config.scheme {
508 crate::config::QScheme::PerChannelAffine | crate::config::QScheme::PerChannelSymmetric => {
509 if let Some(axis) = config.ch_axis {
510 if axis >= shape.dims().len() {
511 hints.push(
512 "Channel axis is out of bounds. This will cause an error.".to_string(),
513 );
514 } else if shape.dims()[axis] < 4 {
515 hints.push(
516 "Few channels detected. Per-tensor quantization might be sufficient."
517 .to_string(),
518 );
519 }
520 }
521 }
522 crate::config::QScheme::GroupWise => {
523 if let (Some(axis), Some(group_size)) = (config.ch_axis, config.group_size) {
524 if axis < shape.dims().len() {
525 let num_channels = shape.dims()[axis];
526 let num_groups = num_channels.div_ceil(group_size);
527 if num_groups == 1 {
528 hints.push("Only one group will be created. Consider per-tensor quantization instead.".to_string());
529 } else if num_groups == num_channels {
530 hints.push("Each channel forms its own group. Consider per-channel quantization instead.".to_string());
531 }
532 }
533 }
534 }
535 _ => {}
536 }
537
538 Ok(hints)
539}
540
541pub fn benchmark_quantization_performance(
543 tensor: &Tensor,
544 configs: &[QuantConfig],
545 num_iterations: usize,
546) -> TorshResult<Vec<(QuantConfig, f64, f64)>> {
547 let mut results = Vec::new();
548
549 for config in configs {
550 let mut total_quantize_time = 0.0;
551 let mut total_dequantize_time = 0.0;
552 let mut successful_runs = 0;
553
554 for _ in 0..num_iterations {
555 let quantize_start = std::time::Instant::now();
557 let quantize_result = crate::algorithms::quantize_with_config(tensor, config);
558 let quantize_time = quantize_start.elapsed().as_secs_f64();
559
560 if let Ok((quantized, scale, zero_point)) = quantize_result {
561 let dequantize_start = std::time::Instant::now();
563 let _dequantized = crate::algorithms::dequantize(&quantized, scale, zero_point)?;
564 let dequantize_time = dequantize_start.elapsed().as_secs_f64();
565
566 total_quantize_time += quantize_time;
567 total_dequantize_time += dequantize_time;
568 successful_runs += 1;
569 }
570 }
571
572 if successful_runs > 0 {
573 let avg_quantize_time = total_quantize_time / successful_runs as f64;
574 let avg_dequantize_time = total_dequantize_time / successful_runs as f64;
575 results.push((config.clone(), avg_quantize_time, avg_dequantize_time));
576 }
577 }
578
579 Ok(results)
580}
581
#[cfg(test)]
mod tests {
    use super::*;

    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_calculate_quantization_metrics() {
        let original_data = vec![1.0, 2.0, 3.0, 4.0];
        let quantized_data = vec![1.1, 2.1, 2.9, 3.9];

        let original = tensor_1d(&original_data).unwrap();
        let quantized = tensor_1d(&quantized_data).unwrap();

        let metrics = calculate_quantization_metrics(&original, &quantized, 32, 8).unwrap();

        // Small perturbations -> small but non-zero error metrics.
        assert!(metrics.mse > 0.0);
        assert!(metrics.mse < 1.0);
        assert!(metrics.mae > 0.0);
        assert!(metrics.mae < 1.0);
        assert!(metrics.psnr > 0.0);
        assert!(metrics.snr > 0.0);
        assert!(metrics.max_error >= 0.0);
        assert!(metrics.zero_error_percentage >= 0.0);
        assert!(metrics.zero_error_percentage <= 100.0);
        assert!(metrics.cosine_similarity > 0.8);
        // 32-bit -> 8-bit is a 4x compression.
        assert_eq!(metrics.compression_ratio, 4.0);

        // Comparing a tensor with itself must report a lossless round trip.
        let metrics_perfect =
            calculate_quantization_metrics(&original, &original, 32, 16).unwrap();
        assert_eq!(metrics_perfect.mse, 0.0);
        assert_eq!(metrics_perfect.mae, 0.0);
        assert_eq!(metrics_perfect.max_error, 0.0);
        assert_eq!(metrics_perfect.zero_error_percentage, 100.0);
        assert!((metrics_perfect.cosine_similarity - 1.0).abs() < 1e-6);
        assert!(metrics_perfect.psnr.is_infinite());
        assert!(metrics_perfect.snr.is_infinite());
        // 32-bit -> 16-bit is a 2x compression.
        assert_eq!(metrics_perfect.compression_ratio, 2.0);
    }

    #[test]
    fn test_compare_quantization_configs() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![
            QuantConfig::int8(),
            QuantConfig::binary(),
            QuantConfig::ternary(),
        ];

        let results = compare_quantization_configs(&tensor, &configs).unwrap();

        // Every config appears in the output, even ones that failed.
        assert_eq!(results.len(), 3);

        for (config, metrics, duration) in &results {
            assert!(configs.iter().any(|c| c.scheme == config.scheme));
            assert!(duration >= &0.0);

            // Sanity-check metrics for configs that quantized successfully.
            if metrics.psnr.is_finite() {
                assert!(metrics.psnr > 0.0);
                assert!(metrics.compression_ratio >= 1.0);
                assert!(metrics.mae >= 0.0);
                assert!(metrics.mse >= 0.0);
            }
        }

        // Results must be sorted by descending PSNR.
        for i in 1..results.len() {
            let prev_psnr = results[i - 1].1.psnr;
            let curr_psnr = results[i].1.psnr;
            if prev_psnr.is_finite() && curr_psnr.is_finite() {
                assert!(prev_psnr >= curr_psnr);
            }
        }
    }

    #[test]
    fn test_auto_calibrate_quantization() {
        let tensor1 = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let tensor2 = tensor_1d(&[2.0, 3.0, 4.0, 5.0]).unwrap();
        let tensor3 = tensor_1d(&[0.5, 1.5, 2.5, 3.5]).unwrap();

        let calibration_tensors = vec![&tensor1, &tensor2, &tensor3];

        // Lenient constraints: calibration should find a valid config.
        let result = auto_calibrate_quantization(&calibration_tensors, 20.0, 10.0);
        assert!(result.is_ok());

        let config = result.unwrap();
        assert!(config.validate().is_ok());

        // Unreachable constraints still yield the best fallback, not an error.
        let result_strict = auto_calibrate_quantization(&calibration_tensors, 100.0, 1.1);
        assert!(result_strict.is_ok());

        // No calibration data is an error.
        let empty_tensors = vec![];
        let result_empty = auto_calibrate_quantization(&empty_tensors, 20.0, 10.0);
        assert!(result_empty.is_err());
    }

    #[test]
    fn test_generate_quantization_report() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![QuantConfig::int8(), QuantConfig::binary()];

        let report_result = generate_quantization_report(&tensor, &configs);
        assert!(report_result.is_ok());

        let report = report_result.unwrap();

        // All major sections are present.
        assert!(report.contains("# Quantization Analysis Report"));
        assert!(report.contains("**Original Tensor Shape:**"));
        assert!(report.contains("**Original Tensor Statistics:**"));
        assert!(report.contains("## Quantization Configuration Comparison"));
        assert!(report.contains("## Detailed Metrics"));
        assert!(report.contains("## Recommendations"));

        // Both schemes appear in the report.
        assert!(report.contains("PerTensorAffine"));
        assert!(report.contains("Binary"));

        // Tensor statistics are rendered.
        assert!(report.contains("Min:"));
        assert!(report.contains("Max:"));
        assert!(report.contains("Mean:"));
        assert!(report.contains("Std Dev:"));
        assert!(report.contains("Dynamic Range:"));

        // Comparison table headers are rendered.
        assert!(report.contains("PSNR (dB)"));
        assert!(report.contains("SNR (dB)"));
        assert!(report.contains("MAE"));
        assert!(report.contains("Compression"));
        assert!(report.contains("Time (ms)"));

        assert!(report.contains("**Best Configuration:**"));
    }

    #[test]
    fn test_generate_optimization_hints() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        // FIX: the old assertions (`hints.is_empty() || !hints.is_empty()`)
        // were tautologies. Assert something falsifiable instead: every hint
        // produced must be a non-empty message.
        let hints = generate_optimization_hints(&tensor, &config).unwrap();
        assert!(hints.iter().all(|h| !h.is_empty()));

        let per_channel_config = QuantConfig::per_channel(0);
        let hints = generate_optimization_hints(&tensor, &per_channel_config).unwrap();
        assert!(hints.iter().all(|h| !h.is_empty()));
    }

    #[test]
    fn test_benchmark_quantization_performance() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![QuantConfig::int8(), QuantConfig::binary()];

        let results = benchmark_quantization_performance(&tensor, &configs, 3).unwrap();

        // Failed configs are omitted, so at most one entry per config.
        assert!(results.len() <= configs.len());

        for (config, quantize_time, dequantize_time) in &results {
            assert!(configs.iter().any(|c| c.scheme == config.scheme));
            assert!(quantize_time >= &0.0);
            assert!(dequantize_time >= &0.0);
        }
    }

    #[test]
    fn test_quantization_metrics_edge_cases() {
        // Mismatched shapes must be rejected.
        let tensor1 = tensor_1d(&[1.0, 2.0]).unwrap();
        let tensor2 = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();

        let result = calculate_quantization_metrics(&tensor1, &tensor2, 32, 8);
        assert!(result.is_err());

        // All-zero tensors: zero error, but cosine similarity is defined as 0
        // because both norms are zero.
        let zero_tensor = tensor_1d(&[0.0, 0.0, 0.0]).unwrap();
        let metrics = calculate_quantization_metrics(&zero_tensor, &zero_tensor, 32, 8).unwrap();

        assert_eq!(metrics.mse, 0.0);
        assert_eq!(metrics.mae, 0.0);
        assert_eq!(metrics.max_error, 0.0);
        assert_eq!(metrics.zero_error_percentage, 100.0);
        assert!(metrics.psnr.is_infinite());
        assert_eq!(metrics.cosine_similarity, 0.0);

        // Nearly-identical tensors: tiny errors, very high PSNR.
        let original = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();
        let almost_same = tensor_1d(&[1.0000001, 2.0000001, 3.0000001]).unwrap();

        let metrics = calculate_quantization_metrics(&original, &almost_same, 32, 8).unwrap();
        assert!(metrics.mse < 1e-12);
        assert!(metrics.mae < 1e-6);
        assert!(metrics.cosine_similarity > 0.999999);
        assert!(metrics.psnr > 100.0);
    }

    #[test]
    fn test_metrics_default() {
        let default_metrics = QuantizationMetrics::default();
        assert_eq!(default_metrics.mse, 0.0);
        assert_eq!(default_metrics.psnr, 0.0);
        assert_eq!(default_metrics.snr, 0.0);
        assert_eq!(default_metrics.mae, 0.0);
        assert_eq!(default_metrics.max_error, 0.0);
        assert_eq!(default_metrics.zero_error_percentage, 100.0);
        assert_eq!(default_metrics.cosine_similarity, 1.0);
        assert_eq!(default_metrics.compression_ratio, 1.0);
    }
}