use crate::config::{MixedPrecisionConfig, QuantConfig};

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;

// In no_std builds, `ToString` and the `vec!` macro also need explicit imports.
#[cfg(not(feature = "std"))]
use alloc::{
    collections::BTreeMap as HashMap,
    string::{String, ToString},
    vec,
    vec::Vec,
};

use torsh_core::{
    dtype::DType,
    error::{Result as TorshResult, TorshError},
};
use torsh_tensor::Tensor;

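/// Quantizes a tensor to INT4 ([-8, 7]) with a single scale and zero point
/// derived from the global min/max range (asymmetric affine quantization).
///
/// A minimal usage sketch, mirroring the unit test below (`tensor_1d` is the
/// 1-D constructor from `torsh_tensor::creation`):
///
/// ```ignore
/// let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
/// let (quantized, scale, zero_point) =
///     quantize_int4_per_tensor(&tensor, &QuantConfig::int4()).unwrap();
/// assert!((-8..=7).contains(&zero_point));
/// ```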
pub fn quantize_int4_per_tensor(
    tensor: &Tensor,
    _config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;
    // Include 0.0 in the range so that zero is exactly representable.
    let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b)).min(0.0);
    let max_val = data
        .iter()
        .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
        .max(0.0);

    // INT4 has 16 levels; guard against a zero scale for constant tensors.
    let scale = (max_val - min_val) / 15.0;
    let scale = if scale == 0.0 { 1.0 } else { scale };

    let zero_point = (-8.0 - min_val / scale).round().clamp(-8.0, 7.0) as i32;

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            let quantized = (x / scale).round() + zero_point as f32;
            quantized.clamp(-8.0, 7.0)
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, zero_point))
}

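/// Quantizes a tensor to INT4 with a separate scale and zero point per channel
/// along `axis`. The returned scale and zero point are averages over channels,
/// since the API surfaces only a single pair.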
pub fn quantize_int4_per_channel(
    tensor: &Tensor,
    axis: usize,
    _config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    let binding = tensor.shape();
    let shape = binding.dims();

    if axis >= shape.len() {
        return Err(TorshError::InvalidArgument(
            "Axis out of bounds".to_string(),
        ));
    }

    let num_channels = shape[axis];
    let data = tensor.data()?;

    // Guard against empty tensors, which would otherwise yield NaN averages below.
    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Row-major strides, used to recover the channel coordinate of each flat index.
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut scales = Vec::with_capacity(num_channels);
    let mut zero_points = Vec::with_capacity(num_channels);
    let mut quantized_data = vec![0.0f32; data.len()];

    for ch in 0..num_channels {
        // First pass: find the min/max of this channel's elements.
        let mut channel_min = f32::INFINITY;
        let mut channel_max = f32::NEG_INFINITY;

        for (i, &val) in data.iter().enumerate() {
            let mut ch_idx = 0;
            let mut remaining = i;

            for (dim, &stride) in strides.iter().enumerate() {
                let coord = remaining / stride;
                remaining %= stride;
                if dim == axis {
                    ch_idx = coord;
                }
            }

            if ch_idx == ch {
                channel_min = channel_min.min(val);
                channel_max = channel_max.max(val);
            }
        }

        // Include 0.0 so that zero is exactly representable.
        channel_min = channel_min.min(0.0);
        channel_max = channel_max.max(0.0);

        let scale = (channel_max - channel_min) / 15.0;
        let scale = if scale == 0.0 { 1.0 } else { scale };
        let zero_point = (-8.0 - channel_min / scale).round().clamp(-8.0, 7.0) as i32;

        scales.push(scale);
        zero_points.push(zero_point);

        // Second pass: quantize this channel's elements with its scale/zero-point.
        for (i, &val) in data.iter().enumerate() {
            let mut ch_idx = 0;
            let mut remaining = i;

            for (dim, &stride) in strides.iter().enumerate() {
                let coord = remaining / stride;
                remaining %= stride;
                if dim == axis {
                    ch_idx = coord;
                }
            }

            if ch_idx == ch {
                let quantized = (val / scale).round() + zero_point as f32;
                quantized_data[i] = quantized.clamp(-8.0, 7.0);
            }
        }
    }

    let quantized_tensor = Tensor::from_data(quantized_data, shape.to_vec(), tensor.device())?;

    // The API returns a single (scale, zero_point) pair, so report per-channel averages.
    let avg_scale = scales.iter().sum::<f32>() / scales.len() as f32;
    let avg_zero_point =
        (zero_points.iter().sum::<i32>() as f32 / zero_points.len() as f32).round() as i32;

    Ok((quantized_tensor, avg_scale, avg_zero_point))
}

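/// Binarizes a tensor to {-1.0, +1.0} by sign, with the mean absolute value
/// as the scale and a zero point of 0.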
pub fn quantize_binary(tensor: &Tensor) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Scale is the mean absolute value of the tensor.
    let scale = data.iter().map(|&x| x.abs()).sum::<f32>() / data.len() as f32;
    let scale = if scale == 0.0 { 1.0 } else { scale };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| if x >= 0.0 { 1.0 } else { -1.0 })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    // Binary quantization is symmetric, so the zero point is always 0.
    Ok((quantized_tensor, scale, 0))
}

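/// Ternarizes a tensor to {-1.0, 0.0, +1.0} using a fixed threshold of 70% of
/// the maximum magnitude; the scale is the mean magnitude of the values that
/// survive the threshold.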
pub fn quantize_ternary(tensor: &Tensor) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Fixed threshold at 70% of the maximum magnitude; values at or below it map to 0.
    let max_abs = data.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
    let threshold = max_abs * 0.7;

    // Scale is the mean magnitude of the values above the threshold.
    let non_zero_sum: f32 = data
        .iter()
        .filter(|&&x| x.abs() > threshold)
        .map(|&x| x.abs())
        .sum();
    let non_zero_count = data.iter().filter(|&&x| x.abs() > threshold).count();

    let scale = if non_zero_count > 0 {
        non_zero_sum / non_zero_count as f32
    } else {
        1.0
    };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x.abs() <= threshold {
                0.0
            } else if x > 0.0 {
                1.0
            } else {
                -1.0
            }
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    // Ternary quantization is symmetric, so the zero point is always 0.
    Ok((quantized_tensor, scale, 0))
}

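/// Quantizes a tensor group-wise along `axis`: channels are partitioned into
/// groups of `group_size`, and each group gets its own scale and zero point
/// within the integer range from `config.get_qint_range()`. The returned scale
/// and zero point are averages over groups.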
pub fn quantize_group_wise(
    tensor: &Tensor,
    axis: usize,
    group_size: usize,
    config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    let binding = tensor.shape();
    let shape = binding.dims();

    if axis >= shape.len() {
        return Err(TorshError::InvalidArgument(
            "Axis out of bounds".to_string(),
        ));
    }

    if group_size == 0 {
        return Err(TorshError::InvalidArgument(
            "Group size must be greater than 0".to_string(),
        ));
    }

    let num_channels = shape[axis];
    let num_groups = num_channels.div_ceil(group_size);
    let data = tensor.data()?;
    let mut quantized_data = vec![0.0f32; data.len()];

    // Row-major strides, used to recover the channel coordinate of each flat index.
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut group_scales = Vec::new();
    let mut group_zero_points = Vec::new();

    for group_idx in 0..num_groups {
        let start_ch = group_idx * group_size;
        let end_ch = (start_ch + group_size).min(num_channels);

        // Gather all values belonging to the channels in this group.
        let mut group_data = Vec::new();
        for ch in start_ch..end_ch {
            for (i, &val) in data.iter().enumerate() {
                let mut ch_idx = 0;
                let mut remaining = i;

                for (dim, &stride) in strides.iter().enumerate() {
                    let coord = remaining / stride;
                    remaining %= stride;
                    if dim == axis {
                        ch_idx = coord;
                    }
                }

                if ch_idx == ch {
                    group_data.push(val);
                }
            }
        }

        if group_data.is_empty() {
            continue;
        }

        // Include 0.0 so that zero is exactly representable.
        let min_val = group_data
            .iter()
            .fold(f32::INFINITY, |a, &b| a.min(b))
            .min(0.0);
        let max_val = group_data
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
            .max(0.0);

        let (qmin, qmax) = config.get_qint_range();
        let scale = (max_val - min_val) / (qmax - qmin) as f32;
        let scale = if scale == 0.0 { 1.0 } else { scale };

        let zero_point = (qmin as f32 - min_val / scale)
            .round()
            .clamp(qmin as f32, qmax as f32) as i32;

        group_scales.push(scale);
        group_zero_points.push(zero_point);

        // Quantize every element of this group with the shared scale/zero-point.
        for ch in start_ch..end_ch {
            for i in 0..data.len() {
                let mut ch_idx = 0;
                let mut remaining = i;

                for (dim, &stride) in strides.iter().enumerate() {
                    let coord = remaining / stride;
                    remaining %= stride;
                    if dim == axis {
                        ch_idx = coord;
                    }
                }

                if ch_idx == ch {
                    let quantized = (data[i] / scale).round() + zero_point as f32;
                    quantized_data[i] = quantized.clamp(qmin as f32, qmax as f32);
                }
            }
        }
    }

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    // The API returns a single (scale, zero_point) pair, so report per-group averages.
    let avg_scale = if group_scales.is_empty() {
        1.0
    } else {
        group_scales.iter().sum::<f32>() / group_scales.len() as f32
    };
    let avg_zero_point = if group_zero_points.is_empty() {
        0
    } else {
        (group_zero_points.iter().sum::<i32>() as f32 / group_zero_points.len() as f32).round()
            as i32
    };

    Ok((quantized_tensor, avg_scale, avg_zero_point))
}

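/// Quantizes a named set of tensors, choosing a per-layer precision from
/// `config` (substring patterns against layer names, with a default fallback)
/// and dispatching each tensor through `quantize_with_config`.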
pub fn quantize_mixed_precision(
    tensors: &HashMap<String, Tensor>,
    config: &MixedPrecisionConfig,
) -> TorshResult<HashMap<String, (Tensor, f32, i32)>> {
    let mut results = HashMap::new();

    for (layer_name, tensor) in tensors {
        // Pick the precision for this layer, then quantize with a matching config.
        let precision = determine_layer_precision(layer_name, config);
        let layer_config = create_precision_config(precision);

        let result = crate::algorithms::quantize_with_config(tensor, &layer_config)?;
        results.insert(layer_name.clone(), result);
    }

    Ok(results)
}

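/// Resolves the precision for a layer: the first pattern in
/// `config.layer_precision` contained in `layer_name` wins, otherwise
/// `config.default_precision` is used.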
pub fn determine_layer_precision(layer_name: &str, config: &MixedPrecisionConfig) -> DType {
    // First matching substring pattern wins.
    for (pattern, precision) in &config.layer_precision {
        if layer_name.contains(pattern) {
            return *precision;
        }
    }

    config.default_precision
}

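/// Builds a `QuantConfig` for the given precision. Float precisions pass
/// through with fake quantization disabled; unsupported dtypes fall back to INT8.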
pub fn create_precision_config(precision: DType) -> QuantConfig {
    match precision {
        DType::I8 => QuantConfig::int8(),
        DType::U8 => QuantConfig::uint8(),
        DType::F16 => QuantConfig {
            dtype: DType::F16,
            enable_fake_quant: false,
            ..Default::default()
        },
        DType::F32 => QuantConfig {
            dtype: DType::F32,
            enable_fake_quant: false,
            ..Default::default()
        },
        // Fall back to INT8 for any unsupported dtype.
        _ => QuantConfig::int8(),
    }
}

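/// Binary quantization with an explicit or learned threshold. Values with
/// magnitude at or below the threshold map to 0.0, so the output alphabet is
/// {-1.0, 0.0, +1.0}; when no threshold is given, the mean absolute value is
/// used. Also returns the threshold that was applied.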
pub fn quantize_binary_learned_threshold(
    tensor: &Tensor,
    threshold: Option<f32>,
) -> TorshResult<(Tensor, f32, i32, f32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Default threshold: the mean absolute value of the tensor.
    let threshold = threshold.unwrap_or_else(|| {
        let abs_sum: f32 = data.iter().map(|&x| x.abs()).sum();
        abs_sum / data.len() as f32
    });

    let above_threshold: Vec<f32> = data
        .iter()
        .filter(|&&x| x.abs() > threshold)
        .copied()
        .collect();

    // Scale is the mean magnitude of the values above the threshold.
    let scale = if above_threshold.is_empty() {
        1.0
    } else {
        above_threshold.iter().map(|&x| x.abs()).sum::<f32>() / above_threshold.len() as f32
    };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x.abs() <= threshold {
                0.0
            } else if x >= 0.0 {
                1.0
            } else {
                -1.0
            }
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, 0, threshold))
}

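/// Ternary quantization with an adaptively chosen threshold: sweeps thresholds
/// at 10%..=100% of the maximum magnitude and keeps the one minimizing the
/// (unscaled) mean squared ternarization error. Also returns that threshold.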
pub fn quantize_ternary_adaptive(tensor: &Tensor) -> TorshResult<(Tensor, f32, i32, f32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Sweep candidate thresholds at 10%, 20%, ..., 100% of the maximum
    // magnitude and keep the one with the lowest ternarization error.
    let max_abs = data.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
    let mut best_threshold = 0.0;
    let mut best_error = f32::INFINITY;

    for i in 1..=10 {
        let threshold = max_abs * (i as f32 * 0.1);
        let error = calculate_ternary_error(&data, threshold);
        if error < best_error {
            best_error = error;
            best_threshold = threshold;
        }
    }

    // Scale is the mean magnitude of the values above the chosen threshold.
    let non_zero_sum: f32 = data
        .iter()
        .filter(|&&x| x.abs() > best_threshold)
        .map(|&x| x.abs())
        .sum();
    let non_zero_count = data.iter().filter(|&&x| x.abs() > best_threshold).count();

    let scale = if non_zero_count > 0 {
        non_zero_sum / non_zero_count as f32
    } else {
        1.0
    };

    let quantized_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x.abs() <= best_threshold {
                0.0
            } else if x > 0.0 {
                1.0
            } else {
                -1.0
            }
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, 0, best_threshold))
}

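/// Mean squared error between `data` and its unscaled ternary projection at
/// the given threshold; used by `quantize_ternary_adaptive` to pick a threshold.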
fn calculate_ternary_error(data: &[f32], threshold: f32) -> f32 {
    data.iter()
        .map(|&x| {
            let quantized = if x.abs() <= threshold {
                0.0
            } else if x > 0.0 {
                1.0
            } else {
                -1.0
            };
            (x - quantized).powi(2)
        })
        .sum::<f32>()
        / data.len() as f32
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_quantize_int4_per_tensor() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int4();

        let result = quantize_int4_per_tensor(&tensor, &config);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert!((-8..=7).contains(&zero_point));

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        for &val in &quantized_data {
            assert!((-8.0..=7.0).contains(&val));
        }
    }

    #[test]
    fn test_quantize_binary() {
        let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_binary(&tensor);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        for &val in &quantized_data {
            assert!(val == -1.0 || val == 1.0);
        }
    }

    #[test]
    fn test_quantize_ternary() {
        let data = vec![-3.0, -1.0, 0.1, 1.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_ternary(&tensor);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        for &val in &quantized_data {
            assert!(val == -1.0 || val == 0.0 || val == 1.0);
        }
    }

    #[test]
    fn test_quantize_group_wise() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu).unwrap();
        let config = QuantConfig::group_wise(1, 2);

        let result = quantize_group_wise(&tensor, 1, 2, &config);
        assert!(result.is_ok());

        let (quantized, scale, _zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(quantized.shape().dims(), tensor.shape().dims());
    }

    #[test]
    fn test_mixed_precision() {
        let mut tensors = HashMap::new();
        tensors.insert(
            "embedding".to_string(),
            tensor_1d(&[1.0, 2.0, 3.0]).unwrap(),
        );
        tensors.insert(
            "attention".to_string(),
            tensor_1d(&[4.0, 5.0, 6.0]).unwrap(),
        );

        let config = MixedPrecisionConfig::default();

        let result = quantize_mixed_precision(&tensors, &config);
        assert!(result.is_ok());

        let results = result.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("embedding"));
        assert!(results.contains_key("attention"));
    }

    #[test]
    fn test_determine_layer_precision() {
        let config = MixedPrecisionConfig::default();

        let embedding_precision = determine_layer_precision("layer.embedding.weight", &config);
        assert_eq!(embedding_precision, DType::I8);

        let attention_precision = determine_layer_precision("layer.attention.query", &config);
        assert_eq!(attention_precision, DType::F16);

        // Unknown layers fall back to the default precision.
        let unknown_precision = determine_layer_precision("layer.unknown.weight", &config);
        assert_eq!(unknown_precision, DType::I8);
    }

    #[test]
    fn test_binary_learned_threshold() {
        let data = vec![-2.0, -0.1, 0.1, 0.5, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_binary_learned_threshold(&tensor, Some(0.3));
        assert!(result.is_ok());

        let (quantized, scale, zero_point, threshold) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
        assert_eq!(threshold, 0.3);

        let quantized_data = quantized.data().unwrap();

        for (i, &original) in data.iter().enumerate() {
            let expected = if original.abs() <= 0.3 {
                0.0
            } else if original >= 0.0 {
                1.0
            } else {
                -1.0
            };
            assert_eq!(quantized_data[i], expected);
        }
    }

    #[test]
    fn test_ternary_adaptive() {
        let data = vec![-3.0, -0.5, 0.0, 0.5, 3.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_ternary_adaptive(&tensor);
        assert!(result.is_ok());

        let (quantized, scale, zero_point, threshold) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(zero_point, 0);
        assert!(threshold > 0.0);

        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());

        for &val in &quantized_data {
            assert!(val == -1.0 || val == 0.0 || val == 1.0);
        }
    }

    #[test]
    fn test_error_cases() {
        let empty_data: Vec<f32> = vec![];
        let empty_tensor = tensor_1d(&empty_data).unwrap();

        assert!(quantize_binary(&empty_tensor).is_err());
        assert!(quantize_ternary(&empty_tensor).is_err());

        let data = vec![1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::group_wise(0, 2);

        // Out-of-bounds axis.
        let result = quantize_group_wise(&tensor, 5, 2, &config);
        assert!(result.is_err());

        // Zero group size.
        let result = quantize_group_wise(&tensor, 0, 0, &config);
        assert!(result.is_err());
    }
}