1use crate::error::{CnnError, CnnResult};
7use super::params::QuantizationParams;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct QuantizationMetadata {
16 pub scale: f32,
18
19 pub zero_point: i32,
21
22 pub shape: Vec<usize>,
24}
25
26impl QuantizationMetadata {
27 pub fn new(scale: f32, zero_point: i32, shape: Vec<usize>) -> Self {
29 Self {
30 scale,
31 zero_point,
32 shape,
33 }
34 }
35
36 pub fn numel(&self) -> usize {
38 self.shape.iter().product()
39 }
40
41 pub fn validate(&self) -> CnnResult<()> {
43 if self.scale <= 0.0 {
44 return Err(CnnError::QuantizationError(format!(
45 "scale must be positive, got {}",
46 self.scale
47 )));
48 }
49
50 if self.shape.is_empty() {
51 return Err(CnnError::QuantizationError(
52 "shape cannot be empty".to_string()
53 ));
54 }
55
56 if self.shape.iter().any(|&d| d == 0) {
57 return Err(CnnError::QuantizationError(
58 "shape dimensions must be positive".to_string()
59 ));
60 }
61
62 Ok(())
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct QuantizedTensor<T> {
94 data: Vec<T>,
96
97 metadata: QuantizationMetadata,
99}
100
101impl QuantizedTensor<i8> {
102 pub fn new(data: Vec<i8>, metadata: QuantizationMetadata) -> CnnResult<Self> {
114 metadata.validate()?;
115
116 if data.len() != metadata.numel() {
117 return Err(CnnError::InvalidShape {
118 expected: format!("data length {}", metadata.numel()),
119 got: format!("{}", data.len()),
120 });
121 }
122
123 Ok(Self { data, metadata })
124 }
125
126 pub fn quantize(
147 fp32_data: &[f32],
148 shape: &[usize],
149 params: &QuantizationParams,
150 ) -> CnnResult<Self> {
151 params.validate()?;
152
153 let expected_numel: usize = shape.iter().product();
154 if fp32_data.len() != expected_numel {
155 return Err(CnnError::InvalidShape {
156 expected: format!("data length {}", expected_numel),
157 got: format!("{}", fp32_data.len()),
158 });
159 }
160
161 let data: Vec<i8> = fp32_data
163 .iter()
164 .map(|&val| params.quantize_value(val))
165 .collect();
166
167 let metadata = QuantizationMetadata::new(
168 params.scale,
169 params.zero_point,
170 shape.to_vec(),
171 );
172
173 Ok(Self { data, metadata })
174 }
175
176 pub fn dequantize(&self) -> CnnResult<Vec<f32>> {
189 self.metadata.validate()?;
190
191 let params = QuantizationParams {
192 scale: self.metadata.scale,
193 zero_point: self.metadata.zero_point,
194 qmin: -127,
195 qmax: 127,
196 };
197
198 let fp32_data: Vec<f32> = self.data
199 .iter()
200 .map(|&val| params.dequantize_value(val))
201 .collect();
202
203 Ok(fp32_data)
204 }
205
206 pub fn data(&self) -> &[i8] {
208 &self.data
209 }
210
211 pub fn data_mut(&mut self) -> &mut [i8] {
213 &mut self.data
214 }
215
216 pub fn metadata(&self) -> &QuantizationMetadata {
218 &self.metadata
219 }
220
221 pub fn shape(&self) -> &[usize] {
223 &self.metadata.shape
224 }
225
226 pub fn scale(&self) -> f32 {
228 self.metadata.scale
229 }
230
231 pub fn zero_point(&self) -> i32 {
233 self.metadata.zero_point
234 }
235
236 pub fn check_bounds(&self, qmin: i8, qmax: i8) -> bool {
241 self.data.iter().all(|&val| val >= qmin && val <= qmax)
242 }
243
244 pub fn validate(&self) -> CnnResult<()> {
252 if self.data.len() != self.metadata.numel() {
254 return Err(CnnError::QuantizationError(format!(
255 "INV-1 violation: data length {} != metadata.numel() {}",
256 self.data.len(),
257 self.metadata.numel()
258 )));
259 }
260
261 self.metadata.validate()?;
263
264 if !self.check_bounds(-127, 127) {
266 return Err(CnnError::QuantizationError(
267 "INV-3 violation: some values outside [-127, 127]".to_string()
268 ));
269 }
270
271 Ok(())
272 }
273
274 pub fn reshape(&mut self, new_shape: Vec<usize>) -> CnnResult<()> {
284 let new_numel: usize = new_shape.iter().product();
285 if new_numel != self.data.len() {
286 return Err(CnnError::InvalidShape {
287 expected: format!("numel {}", self.data.len()),
288 got: format!("numel {}", new_numel),
289 });
290 }
291
292 self.metadata.shape = new_shape;
293 Ok(())
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::QuantizationMode;

    /// Symmetric parameters covering [-10, 10], shared by most tests below.
    fn create_test_params() -> QuantizationParams {
        QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap()
    }

    #[test]
    fn test_metadata_creation() {
        let meta = QuantizationMetadata::new(0.1, 0, vec![2, 3, 4]);
        assert_eq!(meta.scale, 0.1);
        assert_eq!(meta.zero_point, 0);
        assert_eq!(meta.shape, vec![2, 3, 4]);
        assert_eq!(meta.numel(), 24);
    }

    #[test]
    fn test_metadata_validation() {
        let meta = QuantizationMetadata::new(0.1, 0, vec![2, 3]);
        assert!(meta.validate().is_ok());

        let invalid = QuantizationMetadata::new(-0.1, 0, vec![2, 3]);
        assert!(invalid.validate().is_err());

        let empty_shape = QuantizationMetadata::new(0.1, 0, vec![]);
        assert!(empty_shape.validate().is_err());

        let zero_dim = QuantizationMetadata::new(0.1, 0, vec![2, 0, 3]);
        assert!(zero_dim.validate().is_err());
    }

    #[test]
    fn test_quantize_dequantize() {
        let fp32_data = vec![1.0, 2.0, -1.0, 0.5, -5.0, 5.0];
        let shape = vec![6];
        let params = create_test_params();

        let quantized = QuantizedTensor::quantize(&fp32_data, &shape, &params).unwrap();
        assert_eq!(quantized.data().len(), 6);
        assert_eq!(quantized.shape(), &[6]);

        let dequantized = quantized.dequantize().unwrap();
        assert_eq!(dequantized.len(), 6);

        for (original, restored) in fp32_data.iter().zip(dequantized.iter()) {
            // The round-trip should stay within a small tolerance of the original.
            assert!((original - restored).abs() < 0.2);
        }
    }

    #[test]
    fn test_quantize_shape_mismatch() {
        let fp32_data = vec![1.0, 2.0, 3.0];
        let wrong_shape = vec![2, 2];
        let params = create_test_params();

        let result = QuantizedTensor::quantize(&fp32_data, &wrong_shape, &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_with_invalid_length() {
        let data = vec![1i8, 2, 3];
        let metadata = QuantizationMetadata::new(0.1, 0, vec![2, 2]);

        let result = QuantizedTensor::new(data, metadata);
        assert!(result.is_err());
    }

    #[test]
    fn test_bounds_check() {
        let data = vec![0i8, 50, -50, 127, -127];
        let metadata = QuantizationMetadata::new(0.1, 0, vec![5]);
        let tensor = QuantizedTensor::new(data, metadata).unwrap();

        assert!(tensor.check_bounds(-127, 127));
        assert!(!tensor.check_bounds(-50, 50));
    }

    #[test]
    fn test_validate_invariants() {
        let fp32_data = vec![1.0, 2.0, 3.0];
        let shape = vec![3];
        let params = create_test_params();

        let tensor = QuantizedTensor::quantize(&fp32_data, &shape, &params).unwrap();

        // A freshly quantized tensor must satisfy every invariant.
        assert!(tensor.validate().is_ok());
    }

    #[test]
    fn test_reshape() {
        let fp32_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![6];
        let params = create_test_params();

        let mut tensor = QuantizedTensor::quantize(&fp32_data, &shape, &params).unwrap();

        // Same element count: accepted.
        tensor.reshape(vec![2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);

        // Different element count: rejected.
        assert!(tensor.reshape(vec![2, 2]).is_err());
    }

    #[test]
    fn test_zero_value() {
        let fp32_data = vec![0.0, 0.0, 0.0];
        let shape = vec![3];
        let params = create_test_params();

        let quantized = QuantizedTensor::quantize(&fp32_data, &shape, &params).unwrap();
        let dequantized = quantized.dequantize().unwrap();

        for &val in &dequantized {
            assert!(val.abs() < 0.01);
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let fp32_data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![6];
        let params =
            QuantizationParams::from_minmax(0.0, 5.0, QuantizationMode::Asymmetric).unwrap();

        let quantized = QuantizedTensor::quantize(&fp32_data, &shape, &params).unwrap();
        assert!(quantized.validate().is_ok());

        let dequantized = quantized.dequantize().unwrap();
        for (i, (original, restored)) in fp32_data.iter().zip(dequantized.iter()).enumerate() {
            let error = (original - restored).abs();
            assert!(
                error < 0.6,
                "Value mismatch at index {}: original={}, restored={}, error={}",
                i,
                original,
                restored,
                error
            );
        }
    }

    #[test]
    fn test_getters() {
        let fp32_data = vec![1.0, 2.0];
        let shape = vec![2];
        let params = create_test_params();

        let tensor = QuantizedTensor::quantize(&fp32_data, &shape, &params).unwrap();

        assert_eq!(tensor.data().len(), 2);
        assert_eq!(tensor.shape(), &[2]);
        assert!(tensor.scale() > 0.0);
        assert_eq!(tensor.zero_point(), 0);
    }
}