1use crate::config::{QScheme, QuantConfig};
15use torsh_core::{
16 dtype::DType,
17 error::{Result as TorshResult, TorshError},
18};
19use torsh_tensor::Tensor;
20
21#[cfg(not(feature = "std"))]
22extern crate alloc;
23
24#[cfg(not(feature = "std"))]
25use alloc::vec::Vec;
26
27use scirs2_core::parallel_ops::*;
28
/// Quantize `tensor` according to `config`, dispatching on the configured scheme.
///
/// Validates `config` first, then routes to the per-tensor, per-channel,
/// group-wise, int4, binary, or ternary implementation.
///
/// Returns the quantized tensor plus a single `(scale, zero_point)` pair.
///
/// NOTE(review): for the per-channel schemes only the FIRST channel's scale
/// and zero point survive this return type — the remaining channels' params
/// are computed and then dropped. Callers that need all channels should use
/// `quantize_per_channel_auto` directly; confirm this truncation is intended.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when validation fails or when the
/// scheme is `MixedPrecision` (which requires the specialized API).
pub fn quantize_with_config(
    tensor: &Tensor,
    config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    config.validate()?;

    match config.scheme {
        QScheme::PerTensorAffine | QScheme::PerTensorSymmetric => {
            quantize_tensor_auto(tensor, config.dtype, config.scheme)
        }
        QScheme::PerChannelAffine | QScheme::PerChannelSymmetric => {
            // Default to channel axis 0 when unspecified.
            let axis = config.ch_axis.unwrap_or(0);
            let (quantized, scales, zero_points) =
                quantize_per_channel_auto(tensor, axis, config.dtype, config.scheme)?;
            // See NOTE above: only channel 0's params are returned here.
            Ok((quantized, scales[0], zero_points[0]))
        }
        QScheme::GroupWise => {
            let axis = config.ch_axis.unwrap_or(0);
            // Default group size of 32 when unspecified.
            let group_size = config.group_size.unwrap_or(32);
            crate::specialized::quantize_group_wise(tensor, axis, group_size, config)
        }
        QScheme::Int4PerTensor => crate::specialized::quantize_int4_per_tensor(tensor, config),
        QScheme::Int4PerChannel => {
            let axis = config.ch_axis.unwrap_or(0);
            crate::specialized::quantize_int4_per_channel(tensor, axis, config)
        }
        QScheme::Binary => crate::specialized::quantize_binary(tensor),
        QScheme::Ternary => crate::specialized::quantize_ternary(tensor),
        QScheme::MixedPrecision => {
            Err(TorshError::InvalidArgument(
                "Mixed precision quantization requires specialized API".to_string(),
            ))
        }
    }
}
67
68pub fn quantize_per_tensor(
70 tensor: &Tensor,
71 scale: f32,
72 zero_point: i32,
73 _dtype: DType,
74) -> TorshResult<Tensor> {
75 let (quantized, _, _) = quantize_per_tensor_affine(tensor, scale, zero_point)?;
76 Ok(quantized)
77}
78
/// Dequantize a per-tensor affine quantized tensor back to floating point.
///
/// Thin convenience alias for [`dequantize_per_tensor_affine`].
pub fn dequantize(tensor: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
    dequantize_per_tensor_affine(tensor, scale, zero_point)
}
83
84pub fn quantize_tensor_auto(
86 tensor: &Tensor,
87 dtype: DType,
88 scheme: QScheme,
89) -> TorshResult<(Tensor, f32, i32)> {
90 let data = tensor.data()?;
91
92 if data.is_empty() {
93 return Err(TorshError::InvalidArgument(
94 "Cannot quantize empty tensor".to_string(),
95 ));
96 }
97
98 let (min_val, max_val) = if data.len() > 64 && crate::simd_ops::is_simd_available() {
100 crate::simd_ops::find_min_max_simd(&data)?
101 } else {
102 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
104 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
105 (min_val, max_val)
106 };
107
108 let (scale, zero_point) = match scheme {
110 QScheme::PerTensorAffine => calculate_affine_quantization_params(min_val, max_val, dtype)?,
111 QScheme::PerTensorSymmetric => {
112 calculate_symmetric_quantization_params(min_val, max_val, dtype)?
113 }
114 _ => {
115 return Err(TorshError::InvalidArgument(format!(
116 "Unsupported scheme for auto quantization: {:?}",
117 scheme
118 )));
119 }
120 };
121
122 quantize_per_tensor_affine(tensor, scale, zero_point)
123}
124
125pub fn quantize_per_channel_auto(
127 tensor: &Tensor,
128 axis: usize,
129 dtype: DType,
130 scheme: QScheme,
131) -> TorshResult<(Tensor, Vec<f32>, Vec<i32>)> {
132 let binding = tensor.shape();
133 let shape = binding.dims();
134
135 if axis >= shape.len() {
136 return Err(TorshError::InvalidArgument(
137 "Axis out of bounds".to_string(),
138 ));
139 }
140
141 let num_channels = shape[axis];
142 let data = tensor.data()?;
143
144 let mut strides = vec![1; shape.len()];
146 for i in (0..shape.len() - 1).rev() {
147 strides[i] = strides[i + 1] * shape[i + 1];
148 }
149
150 let mut scales = Vec::with_capacity(num_channels);
151 let mut zero_points = Vec::with_capacity(num_channels);
152 let mut quantized_data = vec![0.0f32; data.len()];
153
154 for ch in 0..num_channels {
156 let mut channel_min = f32::INFINITY;
157 let mut channel_max = f32::NEG_INFINITY;
158
159 let mut indices = vec![0; shape.len()];
161 let channel_size = data.len() / num_channels;
162
163 for i in 0..channel_size {
164 let mut temp_i = i;
166 for dim in (0..shape.len()).rev() {
167 if dim == axis {
168 indices[dim] = ch;
169 } else {
170 let other_dim_size = if dim == axis { 1 } else { shape[dim] };
171 indices[dim] = temp_i % other_dim_size;
172 temp_i /= other_dim_size;
173 }
174 }
175
176 let flat_idx = indices
178 .iter()
179 .zip(strides.iter())
180 .map(|(idx, stride)| idx * stride)
181 .sum::<usize>();
182
183 if flat_idx < data.len() {
184 let val = data[flat_idx];
185 channel_min = channel_min.min(val);
186 channel_max = channel_max.max(val);
187 }
188 }
189
190 let (scale, zero_point) = match scheme {
192 QScheme::PerChannelAffine => {
193 calculate_affine_quantization_params(channel_min, channel_max, dtype)?
194 }
195 QScheme::PerChannelSymmetric => {
196 calculate_symmetric_quantization_params(channel_min, channel_max, dtype)?
197 }
198 _ => {
199 return Err(TorshError::InvalidArgument(format!(
200 "Unsupported scheme for per-channel quantization: {:?}",
201 scheme
202 )));
203 }
204 };
205
206 scales.push(scale);
207 zero_points.push(zero_point);
208
209 for i in 0..channel_size {
211 let mut temp_i = i;
212 for dim in (0..shape.len()).rev() {
213 if dim == axis {
214 indices[dim] = ch;
215 } else {
216 let other_dim_size = if dim == axis { 1 } else { shape[dim] };
217 indices[dim] = temp_i % other_dim_size;
218 temp_i /= other_dim_size;
219 }
220 }
221
222 let flat_idx = indices
223 .iter()
224 .zip(strides.iter())
225 .map(|(idx, stride)| idx * stride)
226 .sum::<usize>();
227
228 if flat_idx < data.len() {
229 let val = data[flat_idx];
230 let quantized = ((val / scale).round() + zero_point as f32).clamp(
231 get_dtype_range(dtype).0 as f32,
232 get_dtype_range(dtype).1 as f32,
233 );
234 quantized_data[flat_idx] = quantized;
235 }
236 }
237 }
238
239 let quantized_tensor = Tensor::from_data(quantized_data, shape.to_vec(), tensor.device())?;
240
241 Ok((quantized_tensor, scales, zero_points))
242}
243
244pub fn quantize_per_tensor_affine_i8(
246 tensor: &Tensor,
247 scale: f32,
248 zero_point: i32,
249) -> TorshResult<(Tensor<i8>, f32, i32)> {
250 let data = tensor.data()?;
251
252 if scale <= 0.0 {
253 return Err(TorshError::InvalidArgument(
254 "Scale must be positive".to_string(),
255 ));
256 }
257
258 let quantized_data: Vec<i8> = data
259 .iter()
260 .map(|&x| {
261 let quantized = (x / scale).round() + zero_point as f32;
262 quantized.clamp(-128.0, 127.0) as i8
264 })
265 .collect();
266
267 let quantized_tensor = Tensor::from_data(
268 quantized_data,
269 tensor.shape().dims().to_vec(),
270 tensor.device(),
271 )?;
272
273 Ok((quantized_tensor, scale, zero_point))
274}
275
276pub fn quantize_per_tensor_affine(
278 tensor: &Tensor,
279 scale: f32,
280 zero_point: i32,
281) -> TorshResult<(Tensor, f32, i32)> {
282 let data = tensor.data()?;
283
284 if scale <= 0.0 {
285 return Err(TorshError::InvalidArgument(
286 "Scale must be positive".to_string(),
287 ));
288 }
289
290 let mut quantized_data = vec![0.0f32; data.len()];
292 if data.len() > 64 && crate::simd_ops::is_simd_available() {
293 crate::simd_ops::quantize_per_tensor_affine_simd(
295 &data,
296 scale,
297 zero_point,
298 &mut quantized_data,
299 )?;
300 } else {
301 for (i, &x) in data.iter().enumerate() {
303 let quantized = (x / scale).round() + zero_point as f32;
304 quantized_data[i] = quantized.clamp(-128.0, 127.0);
305 }
306 }
307
308 let quantized_tensor = Tensor::from_data(
309 quantized_data,
310 tensor.shape().dims().to_vec(),
311 tensor.device(),
312 )?;
313
314 Ok((quantized_tensor, scale, zero_point))
315}
316
317pub fn dequantize_per_tensor_affine(
319 tensor: &Tensor,
320 scale: f32,
321 zero_point: i32,
322) -> TorshResult<Tensor> {
323 let data = tensor.data()?;
324
325 if scale <= 0.0 {
326 return Err(TorshError::InvalidArgument(
327 "Scale must be positive".to_string(),
328 ));
329 }
330
331 let mut dequantized_data = vec![0.0f32; data.len()];
333 if data.len() > 64 && crate::simd_ops::is_simd_available() {
334 crate::simd_ops::dequantize_per_tensor_affine_simd(
336 &data,
337 scale,
338 zero_point,
339 &mut dequantized_data,
340 )?;
341 } else {
342 for (i, &x) in data.iter().enumerate() {
344 dequantized_data[i] = (x - zero_point as f32) * scale;
345 }
346 }
347
348 let dequantized_tensor = Tensor::from_data(
349 dequantized_data,
350 tensor.shape().dims().to_vec(),
351 tensor.device(),
352 )?;
353
354 Ok(dequantized_tensor)
355}
356
357pub fn calculate_affine_quantization_params(
359 min_val: f32,
360 max_val: f32,
361 dtype: DType,
362) -> TorshResult<(f32, i32)> {
363 if !min_val.is_finite() || !max_val.is_finite() {
364 return Err(TorshError::InvalidArgument(
365 "Min and max values must be finite".to_string(),
366 ));
367 }
368
369 if min_val > max_val {
370 return Err(TorshError::InvalidArgument(
371 "Min value must be <= max value".to_string(),
372 ));
373 }
374
375 let (qmin, qmax) = get_dtype_range(dtype);
376 let qmin = qmin as f32;
377 let qmax = qmax as f32;
378
379 if (max_val - min_val).abs() < f32::EPSILON {
381 let scale = 1.0;
382 let zero_point = qmin as i32;
383 return Ok((scale, zero_point));
384 }
385
386 let scale = (max_val - min_val) / (qmax - qmin);
388
389 let zero_point_fp = qmin - min_val / scale;
391 let zero_point = zero_point_fp.round().clamp(qmin, qmax) as i32;
392
393 Ok((scale, zero_point))
394}
395
396pub fn calculate_symmetric_quantization_params(
398 min_val: f32,
399 max_val: f32,
400 dtype: DType,
401) -> TorshResult<(f32, i32)> {
402 if !min_val.is_finite() || !max_val.is_finite() {
403 return Err(TorshError::InvalidArgument(
404 "Min and max values must be finite".to_string(),
405 ));
406 }
407
408 let (_qmin, qmax) = get_dtype_range(dtype);
409 let abs_max = min_val.abs().max(max_val.abs());
410
411 if abs_max < f32::EPSILON {
413 return Ok((1.0, 0));
414 }
415
416 let scale = abs_max / qmax as f32;
419 let zero_point = 0; Ok((scale, zero_point))
422}
423
424pub fn get_dtype_range(dtype: DType) -> (i32, i32) {
426 match dtype {
427 DType::I8 => (-128, 127),
428 DType::U8 => (0, 255),
429 DType::I16 => (-32768, 32767),
430 DType::I32 => (i32::MIN, i32::MAX),
431 _ => (-128, 127), }
433}
434
/// Convenience alias for [`quantize_with_config`].
pub fn quantize_auto(tensor: &Tensor, config: &QuantConfig) -> TorshResult<(Tensor, f32, i32)> {
    quantize_with_config(tensor, config)
}
439
/// Tuning knobs for the cache-aware quantization routines in this module.
#[derive(Debug, Clone)]
pub struct CacheAwareParams {
    /// Cache line size in bytes (typically 64 on x86-64).
    pub cache_line_size: usize,
    /// L1 data-cache capacity in bytes.
    pub l1_cache_size: usize,
    /// L2 cache capacity in bytes; drives chunk/tile sizing.
    pub l2_cache_size: usize,
    /// L3 cache capacity in bytes; used for sizing recommendations.
    pub l3_cache_size: usize,
    /// How many elements ahead the streaming path touches as a prefetch hint.
    pub prefetch_distance: usize,
    /// When false, the chunked/parallel path is skipped entirely.
    pub enable_chunking: bool,
}
458
459impl Default for CacheAwareParams {
460 fn default() -> Self {
461 Self {
462 cache_line_size: 64,
463 l1_cache_size: 32 * 1024, l2_cache_size: 256 * 1024, l3_cache_size: 8 * 1024 * 1024, prefetch_distance: 16,
467 enable_chunking: true,
468 }
469 }
470}
471
472pub fn quantize_per_tensor_affine_cache_aware(
474 input: &[f32],
475 scale: f32,
476 zero_point: i32,
477 output: &mut [f32],
478 cache_params: &CacheAwareParams,
479) -> TorshResult<()> {
480 if input.len() != output.len() {
481 return Err(TorshError::InvalidArgument(
482 "Input and output length mismatch".to_string(),
483 ));
484 }
485
486 if scale <= 0.0 {
487 return Err(TorshError::InvalidArgument(
488 "Scale must be positive".to_string(),
489 ));
490 }
491
492 let inv_scale = 1.0 / scale;
493 let zero_point_f32 = zero_point as f32;
494
495 if !cache_params.enable_chunking || input.len() < cache_params.cache_line_size {
496 for (inp, out) in input.iter().zip(output.iter_mut()) {
498 let quantized = (*inp * inv_scale).round() + zero_point_f32;
499 *out = quantized.clamp(-128.0, 127.0);
500 }
501 return Ok(());
502 }
503
504 let _elements_per_cache_line = cache_params.cache_line_size / std::mem::size_of::<f32>();
506 let optimal_chunk_size =
507 (cache_params.l2_cache_size / std::mem::size_of::<f32>() / 4).min(input.len());
508
509 input
511 .par_chunks(optimal_chunk_size)
512 .zip(output.par_chunks_mut(optimal_chunk_size))
513 .for_each(|(input_chunk, output_chunk)| {
514 for (inp, out) in input_chunk.iter().zip(output_chunk.iter_mut()) {
516 let quantized = (*inp * inv_scale).round() + zero_point_f32;
517 *out = quantized.clamp(-128.0, 127.0);
518 }
519 });
520
521 Ok(())
522}
523
524pub fn calculate_tensor_stats_cache_optimized(
526 data: &[f32],
527 cache_params: &CacheAwareParams,
528) -> TorshResult<(f32, f32, f32, f32)> {
529 if data.is_empty() {
530 return Err(TorshError::InvalidArgument(
531 "Cannot calculate stats of empty tensor".to_string(),
532 ));
533 }
534
535 let optimal_block_size = cache_params.l2_cache_size / std::mem::size_of::<f32>();
536 let block_size = optimal_block_size.min(data.len());
537
538 let results: Vec<(f32, f32, f64, f64)> = data
540 .par_chunks(block_size)
541 .map(|chunk| {
542 let mut local_min = f32::INFINITY;
543 let mut local_max = f32::NEG_INFINITY;
544 let mut local_sum = 0.0f64;
545 let mut local_sum_sq = 0.0f64;
546
547 for &val in chunk {
549 local_min = local_min.min(val);
550 local_max = local_max.max(val);
551 let val_f64 = val as f64;
552 local_sum += val_f64;
553 local_sum_sq += val_f64 * val_f64;
554 }
555
556 (local_min, local_max, local_sum, local_sum_sq)
557 })
558 .collect();
559
560 let mut min_val = f32::INFINITY;
562 let mut max_val = f32::NEG_INFINITY;
563 let mut total_sum = 0.0f64;
564 let mut total_sum_sq = 0.0f64;
565
566 for (local_min, local_max, local_sum, local_sum_sq) in results {
567 min_val = min_val.min(local_min);
568 max_val = max_val.max(local_max);
569 total_sum += local_sum;
570 total_sum_sq += local_sum_sq;
571 }
572
573 let n = data.len() as f64;
574 let mean = (total_sum / n) as f32;
575 let variance = ((total_sum_sq / n) - (mean as f64).powi(2)) as f32;
576
577 Ok((min_val, max_val, mean, variance.sqrt()))
578}
579
580pub fn quantize_matrix_cache_friendly(
582 matrix: &[f32],
583 rows: usize,
584 cols: usize,
585 scale: f32,
586 zero_point: i32,
587 output: &mut [f32],
588 cache_params: &CacheAwareParams,
589) -> TorshResult<()> {
590 if matrix.len() != rows * cols || output.len() != rows * cols {
591 return Err(TorshError::InvalidArgument(
592 "Matrix dimensions don't match buffer sizes".to_string(),
593 ));
594 }
595
596 if scale <= 0.0 {
597 return Err(TorshError::InvalidArgument(
598 "Scale must be positive".to_string(),
599 ));
600 }
601
602 let inv_scale = 1.0 / scale;
603 let zero_point_f32 = zero_point as f32;
604
605 let elements_per_cache_line = cache_params.cache_line_size / std::mem::size_of::<f32>();
607 let l2_elements = cache_params.l2_cache_size / std::mem::size_of::<f32>();
608
609 let max_tile_size = (l2_elements / 4).min(1024); let tile_rows = (max_tile_size / cols).max(1).min(rows);
612 let tile_cols = (max_tile_size / tile_rows)
613 .max(elements_per_cache_line)
614 .min(cols);
615
616 for row_start in (0..rows).step_by(tile_rows) {
618 let row_end = (row_start + tile_rows).min(rows);
619
620 for col_start in (0..cols).step_by(tile_cols) {
621 let col_end = (col_start + tile_cols).min(cols);
622
623 for row in row_start..row_end {
625 for col in col_start..col_end {
626 let idx = row * cols + col;
627 let quantized = (matrix[idx] * inv_scale).round() + zero_point_f32;
628 output[idx] = quantized.clamp(-128.0, 127.0);
629 }
630 }
631 }
632 }
633
634 Ok(())
635}
636
/// Sequentially quantize `input` into `output`, touching an element
/// `prefetch_distance` positions ahead as a best-effort prefetch hint.
///
/// NOTE(review): the `_prefetch_addr` borrow below is a no-op in safe Rust —
/// it emits no actual prefetch instruction and the optimizer is free to drop
/// it entirely. Treat it as a placeholder until an arch-specific intrinsic
/// (e.g. x86 `_mm_prefetch`) is wired in — confirm whether that is intended.
///
/// # Errors
/// `InvalidArgument` on length mismatch or non-positive `scale`.
pub fn quantize_streaming_with_prefetch(
    input: &[f32],
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
    cache_params: &CacheAwareParams,
) -> TorshResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::InvalidArgument(
            "Input and output length mismatch".to_string(),
        ));
    }

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    let inv_scale = 1.0 / scale;
    let zero_point_f32 = zero_point as f32;
    let prefetch_distance = cache_params.prefetch_distance;

    for i in 0..input.len() {
        // Touch a lookahead element (see NOTE above: currently a no-op hint).
        if i + prefetch_distance < input.len() {
            let _prefetch_addr = &input[i + prefetch_distance];
        }

        // q = clamp(round(x / scale) + zero_point, -128, 127)
        let quantized = (input[i] * inv_scale).round() + zero_point_f32;
        output[i] = quantized.clamp(-128.0, 127.0);
    }

    Ok(())
}
675
676pub fn get_cache_optimization_recommendations(
678 tensor_size: usize,
679 element_size: usize,
680 cache_params: &CacheAwareParams,
681) -> Vec<String> {
682 let mut recommendations = Vec::new();
683 let total_bytes = tensor_size * element_size;
684
685 if total_bytes <= cache_params.l1_cache_size {
686 recommendations.push("Tensor fits in L1 cache - use simple sequential access".to_string());
687 } else if total_bytes <= cache_params.l2_cache_size {
688 recommendations.push("Tensor fits in L2 cache - consider blocked algorithms".to_string());
689 } else if total_bytes <= cache_params.l3_cache_size {
690 recommendations
691 .push("Tensor fits in L3 cache - use tiled processing with medium blocks".to_string());
692 } else {
693 recommendations
694 .push("Large tensor - use streaming algorithms with prefetching".to_string());
695 recommendations
696 .push("Consider parallel processing to utilize multiple cache hierarchies".to_string());
697 }
698
699 let elements_per_cache_line = cache_params.cache_line_size / element_size;
700 if tensor_size % elements_per_cache_line != 0 {
701 recommendations.push(format!(
702 "Consider padding to align with cache lines ({}B boundaries)",
703 cache_params.cache_line_size
704 ));
705 }
706
707 recommendations
708}
709
710pub fn quantize_with_cache_optimization(
712 input: &[f32],
713 scale: f32,
714 zero_point: i32,
715 output: &mut [f32],
716 cache_params: Option<&CacheAwareParams>,
717) -> TorshResult<()> {
718 let default_params = CacheAwareParams::default();
719 let params = cache_params.unwrap_or(&default_params);
720 let total_bytes = std::mem::size_of_val(input);
721
722 if total_bytes <= params.l1_cache_size {
723 quantize_streaming_with_prefetch(input, scale, zero_point, output, params)
725 } else if total_bytes <= params.l2_cache_size {
726 quantize_per_tensor_affine_cache_aware(input, scale, zero_point, output, params)
728 } else {
729 quantize_per_tensor_affine_cache_aware(input, scale, zero_point, output, params)
731 }
732}
733
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{QScheme, QuantConfig};

    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_calculate_affine_quantization_params() {
        // A range straddling zero yields a positive scale and in-range zp.
        let (scale, zero_point) =
            calculate_affine_quantization_params(-1.0, 1.0, DType::I8).unwrap();
        assert!(scale > 0.0);
        assert!(zero_point >= -128 && zero_point <= 127);

        // Degenerate range (min == max) falls back to unit scale at qmin.
        let (scale, zero_point) =
            calculate_affine_quantization_params(1.0, 1.0, DType::I8).unwrap();
        assert_eq!(scale, 1.0);
        assert_eq!(zero_point, -128);

        // min > max must be rejected.
        let res = calculate_affine_quantization_params(2.0, 1.0, DType::I8);
        assert!(res.is_err());
    }

    #[test]
    fn test_calculate_symmetric_quantization_params() {
        // Symmetric quantization pins the zero point at 0.
        let (scale, zero_point) =
            calculate_symmetric_quantization_params(-2.0, 1.0, DType::I8).unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);

        // All-zero range falls back to unit scale.
        let (scale, zero_point) =
            calculate_symmetric_quantization_params(0.0, 0.0, DType::I8).unwrap();
        assert_eq!(scale, 1.0);
        assert_eq!(zero_point, 0);
    }

    #[test]
    fn test_get_dtype_range() {
        assert_eq!(get_dtype_range(DType::I8), (-128, 127));
        assert_eq!(get_dtype_range(DType::U8), (0, 255));
        assert_eq!(get_dtype_range(DType::I16), (-32768, 32767));
    }

    #[test]
    fn test_quantize_per_tensor_affine() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();

        let (quantized, scale, zero_point) = quantize_per_tensor_affine(&tensor, 0.1, 0).unwrap();
        let quantized_data = quantized.data().unwrap();

        // 1.0 / 0.1 = 10, 2.0 / 0.1 = 20.
        assert_eq!(quantized_data[0], 10.0);
        assert_eq!(quantized_data[1], 20.0);
        assert_eq!(scale, 0.1);
        assert_eq!(zero_point, 0);
    }

    #[test]
    fn test_dequantize_per_tensor_affine() {
        let quantized_data = vec![10.0, 20.0, 30.0, 40.0];
        let quantized_tensor = tensor_1d(&quantized_data).unwrap();

        let dequantized = dequantize_per_tensor_affine(&quantized_tensor, 0.1, 0).unwrap();
        let dequantized_data = dequantized.data().unwrap();

        // 10 * 0.1 = 1.0, 20 * 0.1 = 2.0.
        assert!((dequantized_data[0] - 1.0).abs() < 1e-6);
        assert!((dequantized_data[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_quantize_tensor_auto() {
        let data = vec![-1.0, 0.0, 1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let (quantized, scale, zero_point) =
            quantize_tensor_auto(&tensor, DType::I8, QScheme::PerTensorAffine).unwrap();

        assert!(scale > 0.0);
        assert!(zero_point >= -128 && zero_point <= 127);

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());
    }

    #[test]
    fn test_quantize_with_config() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        let res = quantize_with_config(&tensor, &config);
        assert!(res.is_ok());

        let (quantized, scale, zero_point) = res.unwrap();
        assert!(scale > 0.0);
        assert!(zero_point >= -128 && zero_point <= 127);
        assert_eq!(quantized.shape().dims(), tensor.shape().dims());
    }

    #[test]
    fn test_dequantize() {
        let quantized_data = vec![64.0, 128.0, -64.0, 0.0];
        let quantized_tensor = tensor_1d(&quantized_data).unwrap();

        let dequantized = dequantize(&quantized_tensor, 0.5, 0).unwrap();
        let dequantized_data = dequantized.data().unwrap();

        // q * 0.5 for each element.
        assert!((dequantized_data[0] - 32.0).abs() < 1e-6);
        assert!((dequantized_data[1] - 64.0).abs() < 1e-6);
        assert!((dequantized_data[2] + 32.0).abs() < 1e-6);
        assert!((dequantized_data[3] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_quantize_auto() {
        let data = vec![0.5, 1.0, 1.5, 2.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        let res = quantize_auto(&tensor, &config);
        assert!(res.is_ok());

        let (quantized, scale, _zero_point) = res.unwrap();
        assert!(scale > 0.0);
        assert_eq!(quantized.shape().dims(), tensor.shape().dims());
    }

    #[test]
    fn test_error_cases() {
        let data = vec![1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        // Non-positive scale is rejected by both directions.
        assert!(quantize_per_tensor_affine(&tensor, -1.0, 0).is_err());
        assert!(dequantize_per_tensor_affine(&tensor, 0.0, 0).is_err());

        // Auto-quantizing an empty tensor is rejected.
        let empty_data: Vec<f32> = vec![];
        let empty_tensor = tensor_1d(&empty_data).unwrap();
        assert!(quantize_tensor_auto(&empty_tensor, DType::I8, QScheme::PerTensorAffine).is_err());
    }

    #[test]
    fn test_cache_aware_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = vec![0.0; 8];
        let params = CacheAwareParams::default();

        let res = quantize_per_tensor_affine_cache_aware(&input, 0.1, 0, &mut output, &params);

        assert!(res.is_ok());
        assert_eq!(output[0], 10.0);
        assert_eq!(output[7], 80.0);
    }

    #[test]
    fn test_cache_optimized_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let params = CacheAwareParams::default();

        let res = calculate_tensor_stats_cache_optimized(&data, &params);
        assert!(res.is_ok());

        let (min_val, max_val, mean, std_dev) = res.unwrap();
        assert_eq!(min_val, 1.0);
        assert_eq!(max_val, 10.0);
        assert!((mean - 5.5).abs() < 0.001);
        assert!(std_dev > 0.0);
    }

    #[test]
    fn test_matrix_cache_friendly_quantization() {
        let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut output = vec![0.0; 9];
        let params = CacheAwareParams::default();

        let res = quantize_matrix_cache_friendly(&matrix, 3, 3, 0.1, 0, &mut output, &params);

        assert!(res.is_ok());
        assert_eq!(output[0], 10.0);
        assert_eq!(output[8], 90.0);
    }

    #[test]
    fn test_streaming_with_prefetch() {
        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let mut output = vec![0.0; 5];
        let params = CacheAwareParams::default();

        let res = quantize_streaming_with_prefetch(&input, 0.01, 10, &mut output, &params);

        assert!(res.is_ok());
        // round(0.1 / 0.01) + 10 = 20.
        assert_eq!(output[0], 20.0);
    }

    #[test]
    fn test_cache_optimization_recommendations() {
        let params = CacheAwareParams::default();

        // Small tensor: flagged as L1-resident.
        let recommendations = get_cache_optimization_recommendations(1000, 4, &params);
        assert!(!recommendations.is_empty());
        assert!(recommendations[0].contains("L1 cache"));

        // Huge tensor: should suggest streaming.
        let large_recommendations =
            get_cache_optimization_recommendations(10_000_000, 4, &params);
        assert!(large_recommendations
            .iter()
            .any(|r| r.contains("streaming")));
    }

    #[test]
    fn test_auto_cache_optimization() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];

        let res = quantize_with_cache_optimization(&input, 0.1, 0, &mut output, None);

        assert!(res.is_ok());
        assert_eq!(output[0], 10.0);
        assert_eq!(output[3], 40.0);
    }

    #[test]
    fn test_cache_params_default() {
        let params = CacheAwareParams::default();

        assert_eq!(params.cache_line_size, 64);
        assert_eq!(params.l1_cache_size, 32 * 1024);
        assert_eq!(params.l2_cache_size, 256 * 1024);
        assert_eq!(params.l3_cache_size, 8 * 1024 * 1024);
        assert_eq!(params.prefetch_distance, 16);
        assert!(params.enable_chunking);
    }

    #[test]
    fn test_cache_aware_error_cases() {
        let input = vec![1.0, 2.0];
        let params = CacheAwareParams::default();

        // Mismatched output length is rejected.
        let mut output = vec![0.0; 3];
        let res = quantize_per_tensor_affine_cache_aware(&input, 0.1, 0, &mut output, &params);
        assert!(res.is_err());

        // Negative scale is rejected even with matching lengths.
        let mut output_correct = vec![0.0; 2];
        let res = quantize_per_tensor_affine_cache_aware(
            &input,
            -0.1,
            0,
            &mut output_correct,
            &params,
        );
        assert!(res.is_err());

        // Stats of an empty slice is rejected.
        let empty_data: Vec<f32> = vec![];
        let res = calculate_tensor_stats_cache_optimized(&empty_data, &params);
        assert!(res.is_err());
    }
}