Skip to main content

tensorlogic_scirs_backend/
quantization.rs

//! Quantization infrastructure for TensorLogic.
//!
//! This module provides utilities for quantizing tensors to lower-precision
//! formats (INT8, INT4, FP16, BF16) for improved memory efficiency and performance.
//! Quantized values are simulated: they are still stored as `f64`. Full quantized
//! execution requires backend support; this infrastructure prepares the framework
//! for future quantization-aware training and inference.
7
8use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
9use scirs2_core::ndarray;
10use serde::{Deserialize, Serialize};
11
12/// Quantization data type.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub enum QuantizationType {
15    /// 8-bit signed integer quantization
16    Int8,
17    /// 16-bit floating point (IEEE 754 half precision)
18    Fp16,
19    /// 16-bit brain floating point (truncated FP32)
20    BFloat16,
21    /// 4-bit integer quantization (experimental)
22    Int4,
23    /// No quantization (full precision)
24    None,
25}
26
27impl QuantizationType {
28    /// Get the number of bits used by this quantization type.
29    pub fn bits(&self) -> usize {
30        match self {
31            QuantizationType::Int4 => 4,
32            QuantizationType::Int8 => 8,
33            QuantizationType::Fp16 | QuantizationType::BFloat16 => 16,
34            QuantizationType::None => 64, // Assuming f64 for full precision
35        }
36    }
37
38    /// Get the memory compression ratio compared to FP64.
39    pub fn compression_ratio(&self) -> f64 {
40        64.0 / self.bits() as f64
41    }
42
43    /// Check if this is a floating-point quantization.
44    pub fn is_float(&self) -> bool {
45        matches!(
46            self,
47            QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None
48        )
49    }
50
51    /// Check if this is an integer quantization.
52    pub fn is_integer(&self) -> bool {
53        matches!(self, QuantizationType::Int8 | QuantizationType::Int4)
54    }
55}
56
57/// Quantization scheme (symmetric vs asymmetric).
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum QuantizationScheme {
60    /// Symmetric quantization: range is [-max, max]
61    Symmetric,
62    /// Asymmetric quantization: range is [min, max]
63    Asymmetric,
64}
65
66/// Quantization granularity (per-tensor vs per-channel).
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
68pub enum QuantizationGranularity {
69    /// Single scale and zero-point for entire tensor
70    PerTensor,
71    /// Separate scale and zero-point per output channel
72    PerChannel,
73}
74
75/// Quantization parameters for a tensor.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct QuantizationParams {
78    /// Quantization data type
79    pub qtype: QuantizationType,
80
81    /// Quantization scheme
82    pub scheme: QuantizationScheme,
83
84    /// Quantization granularity
85    pub granularity: QuantizationGranularity,
86
87    /// Scale factor(s) for dequantization
88    pub scale: Vec<f64>,
89
90    /// Zero point(s) for asymmetric quantization
91    pub zero_point: Vec<i32>,
92
93    /// Minimum value(s) in original tensor
94    pub min_val: Vec<f64>,
95
96    /// Maximum value(s) in original tensor
97    pub max_val: Vec<f64>,
98}
99
100impl QuantizationParams {
101    /// Create symmetric per-tensor quantization parameters.
102    pub fn symmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
103        let abs_max = tensor.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);
104
105        let scale = match qtype {
106            QuantizationType::Int8 => abs_max / 127.0,
107            QuantizationType::Int4 => abs_max / 7.0,
108            QuantizationType::Fp16 | QuantizationType::BFloat16 => 1.0,
109            QuantizationType::None => 1.0,
110        };
111
112        Self {
113            qtype,
114            scheme: QuantizationScheme::Symmetric,
115            granularity: QuantizationGranularity::PerTensor,
116            scale: vec![scale],
117            zero_point: vec![0],
118            min_val: vec![-abs_max],
119            max_val: vec![abs_max],
120        }
121    }
122
123    /// Create asymmetric per-tensor quantization parameters.
124    pub fn asymmetric_per_tensor(qtype: QuantizationType, tensor: &Scirs2Tensor) -> Self {
125        let min_val = tensor.iter().fold(f64::INFINITY, |a, &b| a.min(b));
126        let max_val = tensor.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
127
128        let (scale, zero_point) = match qtype {
129            QuantizationType::Int8 => {
130                let scale = (max_val - min_val) / 255.0;
131                let zero_point = (-min_val / scale).round() as i32;
132                (scale, zero_point)
133            }
134            QuantizationType::Int4 => {
135                let scale = (max_val - min_val) / 15.0;
136                let zero_point = (-min_val / scale).round() as i32;
137                (scale, zero_point)
138            }
139            QuantizationType::Fp16 | QuantizationType::BFloat16 | QuantizationType::None => {
140                (1.0, 0)
141            }
142        };
143
144        Self {
145            qtype,
146            scheme: QuantizationScheme::Asymmetric,
147            granularity: QuantizationGranularity::PerTensor,
148            scale: vec![scale],
149            zero_point: vec![zero_point],
150            min_val: vec![min_val],
151            max_val: vec![max_val],
152        }
153    }
154
155    /// Get the dynamic range of this quantization.
156    pub fn dynamic_range(&self) -> f64 {
157        self.max_val[0] - self.min_val[0]
158    }
159
160    /// Get the quantization error bound.
161    pub fn quantization_error_bound(&self) -> f64 {
162        self.scale[0] / 2.0
163    }
164}
165
166/// Simulated quantized tensor (stored as f64 but representing quantized values).
167#[derive(Debug, Clone)]
168pub struct QuantizedTensor {
169    /// The quantized data (stored as f64 for compatibility)
170    pub data: Scirs2Tensor,
171
172    /// Quantization parameters
173    pub params: QuantizationParams,
174}
175
176impl QuantizedTensor {
177    /// Quantize a tensor using the given parameters.
178    pub fn quantize(tensor: &Scirs2Tensor, params: QuantizationParams) -> Self {
179        let quantized_data = match params.qtype {
180            QuantizationType::Int8 => quantize_int8(tensor, &params),
181            QuantizationType::Int4 => quantize_int4(tensor, &params),
182            QuantizationType::Fp16 => quantize_fp16(tensor),
183            QuantizationType::BFloat16 => quantize_bf16(tensor),
184            QuantizationType::None => tensor.clone(),
185        };
186
187        Self {
188            data: quantized_data,
189            params,
190        }
191    }
192
193    /// Dequantize the tensor back to full precision.
194    pub fn dequantize(&self) -> Scirs2Tensor {
195        match self.params.qtype {
196            QuantizationType::Int8 | QuantizationType::Int4 => {
197                dequantize_integer(&self.data, &self.params)
198            }
199            QuantizationType::Fp16 | QuantizationType::BFloat16 => {
200                // Already in f64, just return
201                self.data.clone()
202            }
203            QuantizationType::None => self.data.clone(),
204        }
205    }
206
207    /// Get the memory size reduction ratio.
208    pub fn memory_reduction(&self) -> f64 {
209        self.params.qtype.compression_ratio()
210    }
211
212    /// Calculate the quantization error (MSE).
213    pub fn quantization_error(&self, original: &Scirs2Tensor) -> f64 {
214        let dequantized = self.dequantize();
215        let diff = &dequantized - original;
216        let squared_error: f64 = diff.iter().map(|&x| x * x).sum();
217        squared_error / original.len() as f64
218    }
219}
220
221/// Quantize tensor to INT8, respecting per-channel granularity.
222///
223/// For `PerTensor` granularity, `params.scale[0]` / `params.zero_point[0]` are used
224/// uniformly. For `PerChannel`, each output channel (row in a 2D tensor, outermost
225/// axis in nD) uses its own `scale[c]` / `zero_point[c]`.
226fn quantize_int8(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
227    match params.granularity {
228        QuantizationGranularity::PerTensor => {
229            let scale = params.scale[0];
230            let zero_point = params.zero_point[0] as f64;
231            tensor.mapv(|x| ((x / scale).round() + zero_point).clamp(-128.0, 127.0))
232        }
233        QuantizationGranularity::PerChannel => {
234            let n_channels = tensor.shape()[0];
235            let mut out = tensor.clone();
236            for (c, mut slab) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
237                if c >= params.scale.len() {
238                    // Safety: fall back to first element if params under-specified.
239                    break;
240                }
241                let s = params.scale[c];
242                let zp = params.zero_point[c] as f64;
243                slab.mapv_inplace(|x| ((x / s).round() + zp).clamp(-128.0, 127.0));
244            }
245            let _ = n_channels; // used implicitly via axis_iter_mut
246            out
247        }
248    }
249}
250
251/// Quantize tensor to INT4, respecting per-channel granularity.
252///
253/// For `PerTensor` granularity, `params.scale[0]` / `params.zero_point[0]` are used
254/// uniformly. For `PerChannel`, each output channel (row in a 2D tensor, outermost
255/// axis in nD) uses its own `scale[c]` / `zero_point[c]`.
256fn quantize_int4(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
257    match params.granularity {
258        QuantizationGranularity::PerTensor => {
259            let scale = params.scale[0];
260            let zero_point = params.zero_point[0] as f64;
261            tensor.mapv(|x| ((x / scale).round() + zero_point).clamp(-8.0, 7.0))
262        }
263        QuantizationGranularity::PerChannel => {
264            let n_channels = tensor.shape()[0];
265            let mut out = tensor.clone();
266            for (c, mut slab) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
267                if c >= params.scale.len() {
268                    break;
269                }
270                let s = params.scale[c];
271                let zp = params.zero_point[c] as f64;
272                slab.mapv_inplace(|x| ((x / s).round() + zp).clamp(-8.0, 7.0));
273            }
274            let _ = n_channels;
275            out
276        }
277    }
278}
279
280/// Simulate FP16 quantization (with rounding to FP16 precision).
281fn quantize_fp16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
282    tensor.mapv(|x| {
283        // Simulate FP16 by limiting mantissa precision
284        // FP16 has 10 mantissa bits vs FP64's 52 bits
285        let scaled = x * (1024.0f64).powi(2);
286        (scaled.round() / (1024.0f64).powi(2)).clamp(-65504.0, 65504.0)
287    })
288}
289
290/// Simulate BFloat16 quantization.
291fn quantize_bf16(tensor: &Scirs2Tensor) -> Scirs2Tensor {
292    tensor.mapv(|x| {
293        // BF16 has 7 mantissa bits vs FP64's 52 bits
294        let scaled = x * (128.0f64).powi(2);
295        scaled.round() / (128.0f64).powi(2)
296    })
297}
298
299/// Dequantize integer-quantized tensor, respecting per-channel granularity.
300///
301/// For `PerTensor` granularity, `params.scale[0]` / `params.zero_point[0]` are used
302/// uniformly. For `PerChannel`, each output channel (outermost axis) uses its own
303/// `scale[c]` / `zero_point[c]`.
304fn dequantize_integer(tensor: &Scirs2Tensor, params: &QuantizationParams) -> Scirs2Tensor {
305    match params.granularity {
306        QuantizationGranularity::PerTensor => {
307            let scale = params.scale[0];
308            let zero_point = params.zero_point[0] as f64;
309            tensor.mapv(|q| (q - zero_point) * scale)
310        }
311        QuantizationGranularity::PerChannel => {
312            let mut out = tensor.clone();
313            for (c, mut slab) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
314                if c >= params.scale.len() {
315                    break;
316                }
317                let s = params.scale[c];
318                let zp = params.zero_point[c] as f64;
319                slab.mapv_inplace(|q| (q - zp) * s);
320            }
321            out
322        }
323    }
324}
325
326/// Quantization-aware training configuration.
327#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct QatConfig {
329    /// Target quantization type
330    pub target_qtype: QuantizationType,
331
332    /// Quantization scheme
333    pub scheme: QuantizationScheme,
334
335    /// Number of warmup epochs before enabling quantization
336    pub warmup_epochs: usize,
337
338    /// Whether to use straight-through estimator for gradients
339    pub use_ste: bool,
340
341    /// Whether to learn scale and zero-point parameters
342    pub learnable_params: bool,
343}
344
345impl Default for QatConfig {
346    fn default() -> Self {
347        Self {
348            target_qtype: QuantizationType::Int8,
349            scheme: QuantizationScheme::Symmetric,
350            warmup_epochs: 2,
351            use_ste: true,
352            learnable_params: false,
353        }
354    }
355}
356
357/// Quantization statistics for analysis.
358#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct QuantizationStats {
360    /// Number of quantized tensors
361    pub num_tensors: usize,
362
363    /// Total memory saved (in bytes)
364    pub memory_saved: u64,
365
366    /// Average quantization error (MSE)
367    pub avg_error: f64,
368
369    /// Maximum quantization error
370    pub max_error: f64,
371
372    /// Distribution of quantization types used
373    pub type_distribution: Vec<(QuantizationType, usize)>,
374}
375
376impl QuantizationStats {
377    /// Create empty statistics.
378    pub fn new() -> Self {
379        Self {
380            num_tensors: 0,
381            memory_saved: 0,
382            avg_error: 0.0,
383            max_error: 0.0,
384            type_distribution: Vec::new(),
385        }
386    }
387
388    /// Update statistics with a new quantized tensor.
389    pub fn update(&mut self, original_size: u64, compression_ratio: f64, error: f64) {
390        self.num_tensors += 1;
391        self.memory_saved += (original_size as f64 * (1.0 - 1.0 / compression_ratio)) as u64;
392
393        // Update running average error
394        let n = self.num_tensors as f64;
395        self.avg_error = (self.avg_error * (n - 1.0) + error) / n;
396        self.max_error = self.max_error.max(error);
397    }
398
399    /// Get memory reduction percentage.
400    pub fn memory_reduction_pct(&self, total_memory: u64) -> f64 {
401        if total_memory == 0 {
402            0.0
403        } else {
404            (self.memory_saved as f64 / total_memory as f64) * 100.0
405        }
406    }
407}
408
409impl Default for QuantizationStats {
410    fn default() -> Self {
411        Self::new()
412    }
413}
414
415/// Calibrate quantization parameters using sample data.
416pub fn calibrate_quantization(
417    samples: &[Scirs2Tensor],
418    qtype: QuantizationType,
419    scheme: QuantizationScheme,
420) -> TlBackendResult<QuantizationParams> {
421    if samples.is_empty() {
422        return Err(TlBackendError::GraphError(
423            "Cannot calibrate with empty samples".to_string(),
424        ));
425    }
426
427    // Collect statistics across all samples
428    let mut global_min = f64::INFINITY;
429    let mut global_max = f64::NEG_INFINITY;
430    let mut global_abs_max = 0.0f64;
431
432    for sample in samples {
433        let sample_min = sample.iter().fold(f64::INFINITY, |a, &b| a.min(b));
434        let sample_max = sample.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
435        let sample_abs_max = sample.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);
436
437        global_min = global_min.min(sample_min);
438        global_max = global_max.max(sample_max);
439        global_abs_max = global_abs_max.max(sample_abs_max);
440    }
441
442    let params = match scheme {
443        QuantizationScheme::Symmetric => {
444            let scale = match qtype {
445                QuantizationType::Int8 => global_abs_max / 127.0,
446                QuantizationType::Int4 => global_abs_max / 7.0,
447                _ => 1.0,
448            };
449
450            QuantizationParams {
451                qtype,
452                scheme,
453                granularity: QuantizationGranularity::PerTensor,
454                scale: vec![scale],
455                zero_point: vec![0],
456                min_val: vec![-global_abs_max],
457                max_val: vec![global_abs_max],
458            }
459        }
460        QuantizationScheme::Asymmetric => {
461            let (scale, zero_point) = match qtype {
462                QuantizationType::Int8 => {
463                    let scale = (global_max - global_min) / 255.0;
464                    let zero_point = (-global_min / scale).round() as i32;
465                    (scale, zero_point)
466                }
467                QuantizationType::Int4 => {
468                    let scale = (global_max - global_min) / 15.0;
469                    let zero_point = (-global_min / scale).round() as i32;
470                    (scale, zero_point)
471                }
472                _ => (1.0, 0),
473            };
474
475            QuantizationParams {
476                qtype,
477                scheme,
478                granularity: QuantizationGranularity::PerTensor,
479                scale: vec![scale],
480                zero_point: vec![zero_point],
481                min_val: vec![global_min],
482                max_val: vec![global_max],
483            }
484        }
485    };
486
487    Ok(params)
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// 1-D tensor holding `values`.
    fn vector(values: &[f64]) -> ArrayD<f64> {
        ArrayD::from_shape_vec(vec![values.len()], values.to_vec()).expect("unwrap")
    }

    /// 2×3 tensor holding `values` in row-major order.
    fn matrix_2x3(values: &[f64]) -> ArrayD<f64> {
        ArrayD::from_shape_vec(vec![2, 3], values.to_vec()).expect("build tensor")
    }

    /// Element at a row-major flat position, if present.
    fn element(t: &ArrayD<f64>, flat: usize) -> Option<f64> {
        t.iter().nth(flat).copied()
    }

    #[test]
    fn test_quantization_type_properties() {
        let expected_bits = [
            (QuantizationType::Int8, 8),
            (QuantizationType::Int4, 4),
            (QuantizationType::Fp16, 16),
            (QuantizationType::BFloat16, 16),
        ];
        for &(qtype, bits) in expected_bits.iter() {
            assert_eq!(qtype.bits(), bits);
        }

        assert_eq!(QuantizationType::Int8.compression_ratio(), 8.0);
        assert_eq!(QuantizationType::Int4.compression_ratio(), 16.0);

        assert!(QuantizationType::Int8.is_integer());
        assert!(QuantizationType::Fp16.is_float());
    }

    #[test]
    fn test_symmetric_quantization_int8() {
        let tensor = vector(&[-10.0, -5.0, 0.0, 5.0, 10.0]);
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
        assert_eq!(params.zero_point[0], 0);
        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_asymmetric_quantization_int8() {
        let tensor = vector(&[0.0, 2.0, 4.0, 6.0, 8.0]);
        let params = QuantizationParams::asymmetric_per_tensor(QuantizationType::Int8, &tensor);

        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
        assert!(params.zero_point[0] >= 0);
        assert!(params.scale[0] > 0.0);
    }

    #[test]
    fn test_quantize_dequantize_int8() {
        let tensor = vector(&[-10.0, -5.0, 0.0, 5.0, 10.0]);
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let restored = QuantizedTensor::quantize(&tensor, params).dequantize();

        // Every value must survive the round trip to within 0.1.
        for (orig, deq) in tensor.iter().zip(restored.iter()) {
            let gap = (orig - deq).abs();
            assert!(gap < 0.1, "Original: {}, Dequantized: {}", orig, deq);
        }
    }

    #[test]
    fn test_quantization_error() {
        let tensor = vector(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let mse = QuantizedTensor::quantize(&tensor, params).quantization_error(&tensor);

        // Non-negative, and small for this simple case.
        assert!((0.0..1.0).contains(&mse));
    }

    #[test]
    fn test_memory_reduction() {
        let tensor = vector(&[1.0; 100]);
        let params = QuantizationParams::symmetric_per_tensor(QuantizationType::Int8, &tensor);
        let quantized = QuantizedTensor::quantize(&tensor, params);

        // 64-bit to 8-bit = 8x compression
        assert_eq!(quantized.memory_reduction(), 8.0);
    }

    #[test]
    fn test_calibrate_quantization() {
        let samples = vec![vector(&[-10.0, 0.0, 10.0]), vector(&[-8.0, 2.0, 12.0])];
        let params = calibrate_quantization(
            &samples,
            QuantizationType::Int8,
            QuantizationScheme::Symmetric,
        )
        .expect("unwrap");

        assert!(params.scale[0] > 0.0);
        assert_eq!(params.zero_point[0], 0); // Symmetric
    }

    #[test]
    fn test_quantization_stats() {
        let mut stats = QuantizationStats::new();
        for &(size, error) in &[(1000_u64, 0.01), (2000, 0.02)] {
            stats.update(size, 8.0, error);
        }

        assert_eq!(stats.num_tensors, 2);
        assert!(stats.memory_saved > 0);
        assert!(stats.avg_error > 0.0);
        assert_eq!(stats.max_error, 0.02);
    }

    #[test]
    fn test_fp16_quantization() {
        let tensor = vector(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let quantized = quantize_fp16(&tensor);

        // Small integers are represented (essentially) exactly in FP16.
        let all_close = tensor
            .iter()
            .zip(quantized.iter())
            .all(|(orig, quant)| (orig - quant).abs() < 0.001);
        assert!(all_close);
    }

    #[test]
    fn test_qat_config_default() {
        let config = QatConfig::default();

        assert_eq!(config.target_qtype, QuantizationType::Int8);
        assert_eq!(config.scheme, QuantizationScheme::Symmetric);
        assert!(config.use_ste);
    }

    // ------------------------------------------------------------------
    // Per-channel quantization correctness tests
    // ------------------------------------------------------------------

    /// Two-channel symmetric INT8 parameters: channel 0 covers [-100, 100] and
    /// channel 1 covers [-1, 1], so `scale[0]` is 100× `scale[1]`.
    fn make_per_channel_params_int8() -> QuantizationParams {
        let abs_max = [100.0_f64, 1.0];
        QuantizationParams {
            qtype: QuantizationType::Int8,
            scheme: QuantizationScheme::Symmetric,
            granularity: QuantizationGranularity::PerChannel,
            scale: abs_max.iter().map(|m| m / 127.0).collect(),
            zero_point: vec![0; abs_max.len()],
            min_val: abs_max.iter().map(|m| -m).collect(),
            max_val: abs_max.to_vec(),
        }
    }

    #[test]
    fn test_per_channel_uses_different_scales() {
        let params = make_per_channel_params_int8();
        let (s0, s1) = (params.scale[0], params.scale[1]);

        // scales must be meaningfully different (ratio ≈ 100×)
        assert!(
            (s0 - s1).abs() > 0.1,
            "scale[0]={} scale[1]={} should differ",
            s0,
            s1
        );
    }

    #[test]
    fn test_per_channel_quantize_int8_uses_channel_scale() {
        // Row 0 holds large magnitudes, row 1 small ones.
        let tensor = matrix_2x3(&[100.0, -100.0, 50.0, 1.0, -1.0, 0.5]);
        let quantized_tensor = QuantizedTensor::quantize(&tensor, make_per_channel_params_int8());

        // With its own scale, each row's leading value maps to (≈) the top code 127:
        // 100 / 0.787 ≈ 127 for row 0, and 1 / 0.00787 ≈ 127 for row 1.
        let row0_q_first = element(&quantized_tensor.data, 0).unwrap_or(f64::NAN);
        let row1_q_first = element(&quantized_tensor.data, 3).unwrap_or(f64::NAN);
        assert!(
            (row0_q_first - 127.0).abs() < 2.0,
            "row0[0]={row0_q_first} expected ≈127"
        );
        assert!(
            (row1_q_first - 127.0).abs() < 2.0,
            "row1[0]={row1_q_first} expected ≈127"
        );

        // Each row must round-trip within its own channel's resolution.
        let dequantized = quantized_tensor.dequantize();

        let deq_r0_c0 = element(&dequantized, 0).unwrap_or(f64::NAN);
        assert!(
            (100.0_f64 - deq_r0_c0).abs() < 1.0,
            "round-trip row0[0]: orig={} deq={}",
            100.0_f64,
            deq_r0_c0
        );

        let deq_r1_c0 = element(&dequantized, 3).unwrap_or(f64::NAN);
        assert!(
            (1.0_f64 - deq_r1_c0).abs() < 0.02,
            "round-trip row1[0]: orig={} deq={}",
            1.0_f64,
            deq_r1_c0
        );
    }

    #[test]
    fn test_per_channel_roundtrip_preserves_row_fidelity() {
        // Quantizing row 1 with row 0's scale would round its small values to 0;
        // per-channel scales must recover them with fine precision.
        let tensor = matrix_2x3(&[100.0, -100.0, 50.0, 1.0, -1.0, 0.5]);
        let dequantized =
            QuantizedTensor::quantize(&tensor, make_per_channel_params_int8()).dequantize();

        let expected_row1 = [1.0_f64, -1.0, 0.5];
        for (col, &expected) in expected_row1.iter().enumerate() {
            let got = element(&dequantized, 3 + col).expect("element");
            assert!(
                (expected - got).abs() < 0.02,
                "row1 col{}: expected={} got={}",
                col,
                expected,
                got
            );
        }
    }
}