Skip to main content

ruvector_cnn/quantize/
tensor.rs

1//! Quantized tensor types with metadata.
2//!
3//! This module provides type-safe INT8 tensors with quantization metadata
4//! for efficient neural network inference.
5
6use crate::error::{CnnError, CnnResult};
7use super::params::QuantizationParams;
8use serde::{Deserialize, Serialize};
9
10/// Metadata for a quantized tensor.
11///
12/// Stores the quantization parameters and shape information needed
13/// to correctly interpret and dequantize the tensor data.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct QuantizationMetadata {
16    /// Quantization scale factor.
17    pub scale: f32,
18
19    /// Zero point for asymmetric quantization.
20    pub zero_point: i32,
21
22    /// Tensor shape (e.g., [batch, height, width, channels]).
23    pub shape: Vec<usize>,
24}
25
26impl QuantizationMetadata {
27    /// Create new quantization metadata.
28    pub fn new(scale: f32, zero_point: i32, shape: Vec<usize>) -> Self {
29        Self {
30            scale,
31            zero_point,
32            shape,
33        }
34    }
35
36    /// Total number of elements in the tensor.
37    pub fn numel(&self) -> usize {
38        self.shape.iter().product()
39    }
40
41    /// Validate metadata consistency.
42    pub fn validate(&self) -> CnnResult<()> {
43        if self.scale <= 0.0 {
44            return Err(CnnError::QuantizationError(format!(
45                "scale must be positive, got {}",
46                self.scale
47            )));
48        }
49
50        if self.shape.is_empty() {
51            return Err(CnnError::QuantizationError(
52                "shape cannot be empty".to_string()
53            ));
54        }
55
56        if self.shape.iter().any(|&d| d == 0) {
57            return Err(CnnError::QuantizationError(
58                "shape dimensions must be positive".to_string()
59            ));
60        }
61
62        Ok(())
63    }
64}
65
66/// Quantized tensor with INT8 data and metadata.
67///
68/// Stores quantized values along with the information needed to
69/// dequantize them back to FP32.
70///
71/// ## Invariants (Enforced at Construction)
72///
73/// - **INV-1**: `data.len() == metadata.numel()`
74/// - **INV-2**: `metadata.scale > 0.0`
75/// - **INV-3**: All values in `data` are in range `[qmin, qmax]`
76///
77/// ## Example
78///
79/// ```rust,ignore
80/// use ruvector_cnn::quantize::{QuantizedTensor, QuantizationParams, QuantizationMode};
81///
82/// let fp32_data = vec![1.0, 2.0, -1.0, 0.5];
83/// let shape = vec![4];
84/// let params = QuantizationParams::from_minmax(-2.0, 2.0, QuantizationMode::Symmetric)?;
85///
86/// // Quantize
87/// let quantized = QuantizedTensor::<i8>::quantize(&fp32_data, &shape, &params)?;
88///
89/// // Dequantize
90/// let dequantized = quantized.dequantize()?;
91/// ```
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct QuantizedTensor<T> {
94    /// Quantized data (INT8).
95    data: Vec<T>,
96
97    /// Quantization metadata.
98    metadata: QuantizationMetadata,
99}
100
101impl QuantizedTensor<i8> {
102    /// Create a new quantized tensor with validation.
103    ///
104    /// # Arguments
105    ///
106    /// * `data` - Quantized INT8 values
107    /// * `metadata` - Quantization metadata (scale, zero_point, shape)
108    ///
109    /// # Errors
110    ///
111    /// - If `data.len() != metadata.numel()` (INV-1)
112    /// - If metadata is invalid (INV-2)
113    pub fn new(data: Vec<i8>, metadata: QuantizationMetadata) -> CnnResult<Self> {
114        metadata.validate()?;
115
116        if data.len() != metadata.numel() {
117            return Err(CnnError::InvalidShape {
118                expected: format!("data length {}", metadata.numel()),
119                got: format!("{}", data.len()),
120            });
121        }
122
123        Ok(Self { data, metadata })
124    }
125
126    /// Quantize FP32 data to INT8.
127    ///
128    /// # Arguments
129    ///
130    /// * `fp32_data` - Input FP32 values
131    /// * `shape` - Tensor shape
132    /// * `params` - Quantization parameters
133    ///
134    /// # Returns
135    ///
136    /// Quantized INT8 tensor.
137    ///
138    /// # Example
139    ///
140    /// ```rust,ignore
141    /// let fp32 = vec![1.0, 2.0, -1.0];
142    /// let shape = vec![3];
143    /// let params = QuantizationParams::from_minmax(-2.0, 2.0, QuantizationMode::Symmetric)?;
144    /// let quantized = QuantizedTensor::quantize(&fp32, &shape, &params)?;
145    /// ```
146    pub fn quantize(
147        fp32_data: &[f32],
148        shape: &[usize],
149        params: &QuantizationParams,
150    ) -> CnnResult<Self> {
151        params.validate()?;
152
153        let expected_numel: usize = shape.iter().product();
154        if fp32_data.len() != expected_numel {
155            return Err(CnnError::InvalidShape {
156                expected: format!("data length {}", expected_numel),
157                got: format!("{}", fp32_data.len()),
158            });
159        }
160
161        // Quantize each value
162        let data: Vec<i8> = fp32_data
163            .iter()
164            .map(|&val| params.quantize_value(val))
165            .collect();
166
167        let metadata = QuantizationMetadata::new(
168            params.scale,
169            params.zero_point,
170            shape.to_vec(),
171        );
172
173        Ok(Self { data, metadata })
174    }
175
176    /// Dequantize INT8 data to FP32.
177    ///
178    /// # Returns
179    ///
180    /// FP32 values with the same shape.
181    ///
182    /// # Example
183    ///
184    /// ```rust,ignore
185    /// let dequantized = quantized.dequantize()?;
186    /// assert_eq!(dequantized.len(), quantized.data().len());
187    /// ```
188    pub fn dequantize(&self) -> CnnResult<Vec<f32>> {
189        self.metadata.validate()?;
190
191        let params = QuantizationParams {
192            scale: self.metadata.scale,
193            zero_point: self.metadata.zero_point,
194            qmin: -127,
195            qmax: 127,
196        };
197
198        let fp32_data: Vec<f32> = self.data
199            .iter()
200            .map(|&val| params.dequantize_value(val))
201            .collect();
202
203        Ok(fp32_data)
204    }
205
206    /// Get reference to quantized data.
207    pub fn data(&self) -> &[i8] {
208        &self.data
209    }
210
211    /// Get mutable reference to quantized data.
212    pub fn data_mut(&mut self) -> &mut [i8] {
213        &mut self.data
214    }
215
216    /// Get reference to metadata.
217    pub fn metadata(&self) -> &QuantizationMetadata {
218        &self.metadata
219    }
220
221    /// Get tensor shape.
222    pub fn shape(&self) -> &[usize] {
223        &self.metadata.shape
224    }
225
226    /// Get scale factor.
227    pub fn scale(&self) -> f32 {
228        self.metadata.scale
229    }
230
231    /// Get zero point.
232    pub fn zero_point(&self) -> i32 {
233        self.metadata.zero_point
234    }
235
236    /// Check bounds invariant: all values in `[qmin, qmax]`.
237    ///
238    /// This is a sanity check to ensure data hasn't been corrupted.
239    /// Should always return `true` for properly constructed tensors.
240    pub fn check_bounds(&self, qmin: i8, qmax: i8) -> bool {
241        self.data.iter().all(|&val| val >= qmin && val <= qmax)
242    }
243
244    /// Validate all invariants.
245    ///
246    /// # Invariants
247    ///
248    /// - **INV-1**: `data.len() == metadata.numel()`
249    /// - **INV-2**: `metadata.scale > 0.0`
250    /// - **INV-3**: All values in `[-127, 127]`
251    pub fn validate(&self) -> CnnResult<()> {
252        // INV-1: Length check
253        if self.data.len() != self.metadata.numel() {
254            return Err(CnnError::QuantizationError(format!(
255                "INV-1 violation: data length {} != metadata.numel() {}",
256                self.data.len(),
257                self.metadata.numel()
258            )));
259        }
260
261        // INV-2: Metadata validation (includes scale > 0)
262        self.metadata.validate()?;
263
264        // INV-3: Bounds check
265        if !self.check_bounds(-127, 127) {
266            return Err(CnnError::QuantizationError(
267                "INV-3 violation: some values outside [-127, 127]".to_string()
268            ));
269        }
270
271        Ok(())
272    }
273
274    /// Reshape the tensor to a new shape.
275    ///
276    /// # Arguments
277    ///
278    /// * `new_shape` - New shape (must have same total elements)
279    ///
280    /// # Errors
281    ///
282    /// If `new_shape.iter().product() != self.data.len()`.
283    pub fn reshape(&mut self, new_shape: Vec<usize>) -> CnnResult<()> {
284        let new_numel: usize = new_shape.iter().product();
285        if new_numel != self.data.len() {
286            return Err(CnnError::InvalidShape {
287                expected: format!("numel {}", self.data.len()),
288                got: format!("numel {}", new_numel),
289            });
290        }
291
292        self.metadata.shape = new_shape;
293        Ok(())
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::QuantizationMode;

    /// Symmetric parameters built from the range `[-10, 10]`, shared by most tests.
    fn symmetric_params() -> QuantizationParams {
        QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap()
    }

    #[test]
    fn test_metadata_creation() {
        let metadata = QuantizationMetadata::new(0.1, 0, vec![2, 3, 4]);

        assert_eq!(metadata.scale, 0.1);
        assert_eq!(metadata.zero_point, 0);
        assert_eq!(metadata.shape, vec![2, 3, 4]);
        assert_eq!(metadata.numel(), 24);
    }

    #[test]
    fn test_metadata_validation() {
        // A positive scale with a non-empty, all-positive shape is accepted.
        assert!(QuantizationMetadata::new(0.1, 0, vec![2, 3]).validate().is_ok());

        // Negative scale, empty shape and zero-sized dimension are all rejected.
        let rejected = [
            QuantizationMetadata::new(-0.1, 0, vec![2, 3]),
            QuantizationMetadata::new(0.1, 0, vec![]),
            QuantizationMetadata::new(0.1, 0, vec![2, 0, 3]),
        ];
        for metadata in &rejected {
            assert!(metadata.validate().is_err());
        }
    }

    #[test]
    fn test_quantize_dequantize() {
        let input = vec![1.0, 2.0, -1.0, 0.5, -5.0, 5.0];
        let dims = vec![6];
        let params = symmetric_params();

        let tensor = QuantizedTensor::quantize(&input, &dims, &params).unwrap();
        assert_eq!(tensor.data().len(), 6);
        assert_eq!(tensor.shape(), &[6]);

        let roundtrip = tensor.dequantize().unwrap();
        assert_eq!(roundtrip.len(), 6);

        // The round trip must stay within a small absolute error.
        for (expected, actual) in input.iter().zip(roundtrip.iter()) {
            assert!((expected - actual).abs() < 0.2);
        }
    }

    #[test]
    fn test_quantize_shape_mismatch() {
        // Shape [2, 2] implies 4 elements, but only 3 values are supplied.
        let input = vec![1.0, 2.0, 3.0];
        let mismatched_dims = vec![2, 2];
        let params = symmetric_params();

        assert!(QuantizedTensor::quantize(&input, &mismatched_dims, &params).is_err());
    }

    #[test]
    fn test_new_with_invalid_length() {
        // Metadata describes 4 elements while the buffer holds 3.
        let metadata = QuantizationMetadata::new(0.1, 0, vec![2, 2]);
        let values = vec![1i8, 2, 3];

        assert!(QuantizedTensor::new(values, metadata).is_err());
    }

    #[test]
    fn test_bounds_check() {
        let metadata = QuantizationMetadata::new(0.1, 0, vec![5]);
        let values = vec![0i8, 50, -50, 127, -127];
        let tensor = QuantizedTensor::new(values, metadata).unwrap();

        assert!(tensor.check_bounds(-127, 127));
        assert!(!tensor.check_bounds(-50, 50));
    }

    #[test]
    fn test_validate_invariants() {
        let input = vec![1.0, 2.0, 3.0];
        let dims = vec![3];
        let params = symmetric_params();

        let tensor = QuantizedTensor::quantize(&input, &dims, &params).unwrap();

        // A freshly quantized tensor satisfies every invariant.
        assert!(tensor.validate().is_ok());
    }

    #[test]
    fn test_reshape() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let dims = vec![6];
        let params = symmetric_params();

        let mut tensor = QuantizedTensor::quantize(&input, &dims, &params).unwrap();

        // 6 elements fit a 2x3 layout.
        tensor.reshape(vec![2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);

        // 2x2 holds only 4 elements and must be refused.
        assert!(tensor.reshape(vec![2, 2]).is_err());
    }

    #[test]
    fn test_zero_value() {
        let input = vec![0.0, 0.0, 0.0];
        let dims = vec![3];
        let params = symmetric_params();

        let roundtrip = QuantizedTensor::quantize(&input, &dims, &params)
            .unwrap()
            .dequantize()
            .unwrap();

        for &value in &roundtrip {
            assert!((value).abs() < 0.01);
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let input = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let dims = vec![6];
        let params =
            QuantizationParams::from_minmax(0.0, 5.0, QuantizationMode::Asymmetric).unwrap();

        let tensor = QuantizedTensor::quantize(&input, &dims, &params).unwrap();
        assert!(tensor.validate().is_ok());

        let roundtrip = tensor.dequantize().unwrap();
        for (i, (original, restored)) in input.iter().zip(roundtrip.iter()).enumerate() {
            let error = (original - restored).abs();
            // Each restored value must lie within 0.6 of its original.
            assert!(
                error < 0.6,
                "Value mismatch at index {}: original={}, restored={}, error={}",
                i, original, restored, error
            );
        }
    }

    #[test]
    fn test_getters() {
        let input = vec![1.0, 2.0];
        let dims = vec![2];
        let params = symmetric_params();

        let tensor = QuantizedTensor::quantize(&input, &dims, &params).unwrap();

        assert_eq!(tensor.data().len(), 2);
        assert_eq!(tensor.shape(), &[2]);
        assert!(tensor.scale() > 0.0);
        assert_eq!(tensor.zero_point(), 0); // Symmetric mode
    }
}