1use super::super::{QuantizationMethod, QuantizationParams};
8use crate::error::LinalgResult;
9use scirs2_core::ndarray::{ArrayView1, ArrayView2};
10use std::fmt::Debug;
11
12#[allow(dead_code)]
18pub fn find_min_max<F>(matrix: &ArrayView2<F>) -> (f32, f32)
19where
20 F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
21{
22 let mut min_val = f32::MAX;
23 let mut max_val = f32::MIN;
24
25 for &val in matrix.iter() {
26 let val_f32 = val.as_();
27 if val_f32.is_finite() {
28 min_val = min_val.min(val_f32);
29 max_val = max_val.max(val_f32);
30 }
31 }
32
33 if !min_val.is_finite() || !max_val.is_finite() {
35 min_val = 0.0;
36 max_val = 1.0;
37 }
38
39 if min_val == max_val {
40 min_val -= 1.0;
41 max_val += 1.0;
42 }
43
44 (min_val, max_val)
45}
46
47#[allow(dead_code)]
49pub fn find_min_max_vec<F>(vector: &ArrayView1<F>) -> (f32, f32)
50where
51 F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
52{
53 let mut min_val = f32::MAX;
54 let mut max_val = f32::MIN;
55
56 for &val in vector.iter() {
57 let val_f32 = val.as_();
58 if val_f32.is_finite() {
59 min_val = min_val.min(val_f32);
60 max_val = max_val.max(val_f32);
61 }
62 }
63
64 if !min_val.is_finite() || !max_val.is_finite() {
66 min_val = 0.0;
67 max_val = 1.0;
68 }
69
70 if min_val == max_val {
71 min_val -= 1.0;
72 max_val += 1.0;
73 }
74
75 (min_val, max_val)
76}
77
78#[allow(dead_code)]
80pub fn create_histogram<F>(
81 matrix: &ArrayView2<F>,
82 min_val: f32,
83 max_val: f32,
84 num_bins: usize,
85) -> Vec<usize>
86where
87 F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
88{
89 let mut histogram = vec![0; num_bins];
90 let bin_width = (max_val - min_val) / num_bins as f32;
91
92 if bin_width == 0.0 {
93 histogram[num_bins / 2] = matrix.len();
95 return histogram;
96 }
97
98 for &_val in matrix.iter() {
99 let val_f32 = _val.as_();
100 if val_f32.is_finite() {
101 let bin_idx = ((val_f32 - min_val) / bin_width).floor() as usize;
102 let bin_idx = bin_idx.min(num_bins - 1); histogram[bin_idx] += 1;
104 }
105 }
106
107 histogram
108}
109
110#[allow(dead_code)]
112pub fn create_histogram_vec<F>(
113 vector: &ArrayView1<F>,
114 min_val: f32,
115 max_val: f32,
116 num_bins: usize,
117) -> Vec<usize>
118where
119 F: scirs2_core::numeric::Float + scirs2_core::numeric::AsPrimitive<f32>,
120{
121 let mut histogram = vec![0; num_bins];
122 let bin_width = (max_val - min_val) / num_bins as f32;
123
124 if bin_width == 0.0 {
125 histogram[num_bins / 2] = vector.len();
127 return histogram;
128 }
129
130 for &_val in vector.iter() {
131 let val_f32 = _val.as_();
132 if val_f32.is_finite() {
133 let bin_idx = ((val_f32 - min_val) / bin_width).floor() as usize;
134 let bin_idx = bin_idx.min(num_bins - 1); histogram[bin_idx] += 1;
136 }
137 }
138
139 histogram
140}
141
142#[allow(dead_code)]
144pub fn optimize_thresholds_kl_divergence(
145 histogram: &[usize],
146 min_val: f32,
147 max_val: f32,
148 bits: u8,
149 symmetric: bool,
150) -> (f32, f32) {
151 let num_bins = histogram.len();
152 let bin_width = (max_val - min_val) / num_bins as f32;
153
154 let total_count = histogram.iter().sum::<usize>() as f32;
156 let distribution: Vec<f32> = histogram
157 .iter()
158 .map(|&count| count as f32 / total_count)
159 .collect();
160
161 let levels = if symmetric {
163 (1 << (bits - 1)) as usize } else {
165 (1 << bits) as usize };
167
168 if symmetric {
170 let mut best_abs_max = max_val.abs().max(min_val.abs());
172 let mut min_kl = f32::MAX;
173
174 let step = (best_abs_max / 20.0).max(1e-6);
176 for i in 0..40 {
177 let abs_max = best_abs_max - 20.0 * step + i as f32 * step;
178 if abs_max <= 0.0 {
179 continue;
180 }
181
182 let quantization_step = abs_max / (levels - 1) as f32;
184
185 let kl = calculate_kl_divergence_symmetric(
187 &distribution,
188 min_val,
189 max_val,
190 bin_width,
191 abs_max,
192 quantization_step,
193 );
194
195 if kl < min_kl {
196 min_kl = kl;
197 best_abs_max = abs_max;
198 }
199 }
200
201 (-best_abs_max, best_abs_max)
203 } else {
204 let mut best_min = min_val;
206 let mut best_max = max_val;
207 let mut min_kl = f32::MAX;
208
209 let min_step = (max_val - min_val) / 40.0;
211 let max_step = min_step;
212
213 for i in 0..10 {
214 let trial_min = min_val + i as f32 * min_step;
215
216 for j in 0..10 {
217 let trial_max = max_val - j as f32 * max_step;
218
219 if trial_min >= trial_max {
220 continue;
221 }
222
223 let quantization_step = (trial_max - trial_min) / (levels - 1) as f32;
225
226 let kl = calculate_kl_divergence_asymmetric(
228 &distribution,
229 min_val,
230 max_val,
231 bin_width,
232 trial_min,
233 trial_max,
234 quantization_step,
235 );
236
237 if kl < min_kl {
238 min_kl = kl;
239 best_min = trial_min;
240 best_max = trial_max;
241 }
242 }
243 }
244
245 (best_min, best_max)
246 }
247}
248
/// Scores a symmetric clipping threshold with a KL-divergence style measure.
///
/// Each histogram bin is represented by its centre. The centre is clipped to
/// `[-abs_max, abs_max]` or, when inside that interval, rounded to the nearest
/// multiple of `quantization_step`. The bin's probability mass is then moved to
/// the bin that contains the resulting value; mass that lands outside the
/// histogram is dropped. The score is the sum of `p * ln(p / q)` over bins with
/// `p > 0`, where `q` is the moved mass floored at `1e-10`.
///
/// # Arguments
///
/// * `distribution` - Per-bin probabilities.
/// * `min_val` - Lower edge of the first bin.
/// * `_max_val` - Unused.
/// * `bin_width` - Width of every bin.
/// * `abs_max` - Symmetric clipping threshold.
/// * `quantization_step` - Spacing of the quantization grid.
#[allow(dead_code)]
fn calculate_kl_divergence_symmetric(
    distribution: &[f32],
    min_val: f32,
    _max_val: f32,
    bin_width: f32,
    abs_max: f32,
    quantization_step: f32,
) -> f32 {
    let bin_count = distribution.len();

    // Clip to +/-abs_max, otherwise snap onto the quantization grid.
    let snap_to_grid = |centre: f32| -> f32 {
        if centre > abs_max {
            abs_max
        } else if centre < -abs_max {
            -abs_max
        } else {
            (centre / quantization_step).round() * quantization_step
        }
    };

    // Redistribute every bin's mass to the bin its quantized centre falls into.
    let mut projected = vec![0.0f32; bin_count];
    for (source_bin, &mass) in distribution.iter().enumerate() {
        let centre = min_val + (source_bin as f32 + 0.5) * bin_width;
        let target_bin = ((snap_to_grid(centre) - min_val) / bin_width).floor() as i32;
        if (0..bin_count as i32).contains(&target_bin) {
            projected[target_bin as usize] += mass;
        }
    }

    // Accumulate p * ln(p / q) in bin order, skipping empty source bins.
    distribution
        .iter()
        .zip(projected.iter())
        .fold(0.0f32, |kl, (&p, &q)| {
            if p > 0.0 {
                kl + p * (p / q.max(1e-10)).ln()
            } else {
                kl
            }
        })
}
297
/// Scores an asymmetric clipping range with a KL-divergence style measure.
///
/// Each histogram bin is represented by its centre. The centre is clipped to
/// `[quant_min, quant_max]` or, when inside that interval, rounded to the
/// nearest grid point `quant_min + k * quantization_step`. The bin's
/// probability mass is then moved to the bin that contains the resulting value;
/// mass that lands outside the histogram is dropped. The score is the sum of
/// `p * ln(p / q)` over bins with `p > 0`, where `q` is the moved mass floored
/// at `1e-10`.
///
/// # Arguments
///
/// * `distribution` - Per-bin probabilities.
/// * `min_val` - Lower edge of the first bin.
/// * `_max_val` - Unused.
/// * `bin_width` - Width of every bin.
/// * `quant_min`, `quant_max` - Clipping range under evaluation.
/// * `quantization_step` - Spacing of the quantization grid.
#[allow(dead_code)]
fn calculate_kl_divergence_asymmetric(
    distribution: &[f32],
    min_val: f32,
    _max_val: f32,
    bin_width: f32,
    quant_min: f32,
    quant_max: f32,
    quantization_step: f32,
) -> f32 {
    let bin_count = distribution.len();

    // Clip to the trial range, otherwise snap onto the grid anchored at quant_min.
    let snap_to_grid = |centre: f32| -> f32 {
        if centre > quant_max {
            quant_max
        } else if centre < quant_min {
            quant_min
        } else {
            let grid_steps = ((centre - quant_min) / quantization_step).round();
            quant_min + grid_steps * quantization_step
        }
    };

    // Redistribute every bin's mass to the bin its quantized centre falls into.
    let mut projected = vec![0.0f32; bin_count];
    for (source_bin, &mass) in distribution.iter().enumerate() {
        let centre = min_val + (source_bin as f32 + 0.5) * bin_width;
        let target_bin = ((snap_to_grid(centre) - min_val) / bin_width).floor() as i32;
        if (0..bin_count as i32).contains(&target_bin) {
            projected[target_bin as usize] += mass;
        }
    }

    // Accumulate p * ln(p / q) in bin order, skipping empty source bins.
    distribution
        .iter()
        .zip(projected.iter())
        .fold(0.0f32, |kl, (&p, &q)| {
            if p > 0.0 {
                kl + p * (p / q.max(1e-10)).ln()
            } else {
                kl
            }
        })
}
348
349#[allow(dead_code)]
351pub fn optimize_symmetric_scale<F>(matrix: &ArrayView2<F>, bits: u8, basescale: f32) -> f32
352where
353 F: scirs2_core::numeric::Float
354 + Debug
355 + scirs2_core::numeric::AsPrimitive<f32>
356 + scirs2_core::numeric::FromPrimitive,
357 f32: scirs2_core::numeric::AsPrimitive<F>,
358{
359 let num_trials = 20;
360 let scales: Vec<f32> = (0..num_trials)
361 .map(|i| {
362 let factor = 0.5 + 1.5 * (i as f32 / (num_trials - 1) as f32);
363 basescale * factor
364 })
365 .collect();
366
367 let mut best_scale = basescale;
368 let mut min_mse = f32::MAX;
369
370 for &scale in &scales {
372 let abs_max = matrix
374 .mapv(|x| x.as_().abs())
375 .fold(0.0, |a: f32, &b| a.max(b));
376 let params = QuantizationParams {
377 bits,
378 scale,
379 zero_point: 0,
380 min_val: -abs_max,
381 max_val: abs_max,
382 method: if bits == 4 {
383 QuantizationMethod::Int4
384 } else {
385 QuantizationMethod::Symmetric
386 },
387 data_type: determine_data_type(bits),
388 channel_scales: None,
389 channel_zero_points: None,
390 };
391
392 let matrix_f32 = matrix.mapv(|x| x.as_());
394 let current_scale = params.scale;
395 let dequantized = matrix_f32.mapv(|x| {
396 let quantized = (x / scale)
397 .round()
398 .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32);
399 quantized * current_scale
400 });
401
402 let mse = (&matrix_f32 - &dequantized).mapv(|x| x * x).sum() / matrix.len() as f32;
404
405 if mse < min_mse {
406 min_mse = mse;
407 best_scale = scale;
408 }
409 }
410
411 best_scale
412}
413
414#[allow(dead_code)]
416pub fn optimize_symmetric_scale_vec<F>(_vector: &ArrayView1<F>, bits: u8, basescale: f32) -> f32
417where
418 F: scirs2_core::numeric::Float
419 + Debug
420 + scirs2_core::numeric::AsPrimitive<f32>
421 + scirs2_core::numeric::FromPrimitive,
422 f32: scirs2_core::numeric::AsPrimitive<F>,
423{
424 let num_trials = 20;
425 let scales: Vec<f32> = (0..num_trials)
426 .map(|i| {
427 let factor = 0.5 + 1.5 * (i as f32 / (num_trials - 1) as f32);
428 basescale * factor
429 })
430 .collect();
431
432 let mut best_scale = basescale;
433 let mut min_mse = f32::MAX;
434
435 for &scale in &scales {
437 let abs_max = _vector
439 .mapv(|x| x.as_().abs())
440 .fold(0.0, |a: f32, &b| a.max(b));
441 let params = QuantizationParams {
442 bits,
443 scale,
444 zero_point: 0,
445 min_val: -abs_max,
446 max_val: abs_max,
447 method: if bits == 4 {
448 QuantizationMethod::Int4
449 } else {
450 QuantizationMethod::Symmetric
451 },
452 data_type: determine_data_type(bits),
453 channel_scales: None,
454 channel_zero_points: None,
455 };
456
457 let vector_f32 = _vector.mapv(|x| x.as_());
459 let current_scale = params.scale;
460 let dequantized = vector_f32.mapv(|x| {
461 let quantized = (x / scale)
462 .round()
463 .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32);
464 quantized * current_scale
465 });
466
467 let mse = (&vector_f32 - &dequantized).mapv(|x| x * x).sum() / _vector.len() as f32;
469
470 if mse < min_mse {
471 min_mse = mse;
472 best_scale = scale;
473 }
474 }
475
476 best_scale
477}
478
479#[allow(dead_code)]
481pub fn optimize_affine_params<F>(
482 matrix: &ArrayView2<F>,
483 bits: u8,
484 base_scale: f32,
485 base_zero_point: i32,
486) -> (f32, i32)
487where
488 F: scirs2_core::numeric::Float
489 + Debug
490 + scirs2_core::numeric::AsPrimitive<f32>
491 + scirs2_core::numeric::FromPrimitive,
492 f32: scirs2_core::numeric::AsPrimitive<F>,
493{
494 let num_scale_trials = 10;
495 let num_zp_trials = 5;
496
497 let scales: Vec<f32> = (0..num_scale_trials)
498 .map(|i| {
499 let factor = 0.8 + 0.4 * (i as f32 / (num_scale_trials - 1) as f32);
500 base_scale * factor
501 })
502 .collect();
503
504 let zero_points: Vec<i32> = (0..num_zp_trials)
505 .map(|i| {
506 let offset = -2 + i;
507 base_zero_point + offset
508 })
509 .collect();
510
511 let mut best_scale = base_scale;
512 let mut best_zero_point = base_zero_point;
513 let mut min_mse = f32::MAX;
514
515 for &_scale in &scales {
517 for &zero_point in &zero_points {
518 let mut params = QuantizationParams {
520 bits,
521 scale: _scale,
522 zero_point,
523 min_val: 0.0, max_val: 0.0, method: QuantizationMethod::Affine,
526 data_type: determine_data_type(bits),
527 channel_scales: None,
528 channel_zero_points: None,
529 };
530
531 let matrix_f32 = matrix.mapv(|x| x.as_());
533 let scale = params.scale;
534 let zero_point = params.zero_point;
535
536 let mut min_val = f32::MAX;
538 let mut max_val = f32::MIN;
539 for &val in matrix_f32.iter() {
540 if val.is_finite() {
541 min_val = min_val.min(val);
542 max_val = max_val.max(val);
543 }
544 }
545 params.min_val = min_val;
546 params.max_val = max_val;
547
548 let dequantized = matrix_f32.mapv(|x| {
549 let quantized = ((x / scale) + zero_point as f32)
550 .round()
551 .clamp(0.0, ((1 << bits) - 1) as f32);
552 (quantized - zero_point as f32) * scale
553 });
554
555 let mse = (&matrix_f32 - &dequantized).mapv(|x| x * x).sum() / matrix.len() as f32;
557
558 if mse < min_mse {
559 min_mse = mse;
560 best_scale = scale;
561 best_zero_point = zero_point;
562 }
563 }
564 }
565
566 (best_scale, best_zero_point)
567}
568
569#[allow(dead_code)]
571pub fn optimize_affine_params_vec<F>(
572 vector: &ArrayView1<F>,
573 bits: u8,
574 base_scale: f32,
575 base_zero_point: i32,
576) -> (f32, i32)
577where
578 F: scirs2_core::numeric::Float
579 + Debug
580 + scirs2_core::numeric::AsPrimitive<f32>
581 + scirs2_core::numeric::FromPrimitive,
582 f32: scirs2_core::numeric::AsPrimitive<F>,
583{
584 let num_scale_trials = 10;
585 let num_zp_trials = 5;
586
587 let scales: Vec<f32> = (0..num_scale_trials)
588 .map(|i| {
589 let factor = 0.8 + 0.4 * (i as f32 / (num_scale_trials - 1) as f32);
590 base_scale * factor
591 })
592 .collect();
593
594 let zero_points: Vec<i32> = (0..num_zp_trials)
595 .map(|i| {
596 let offset = -2 + i;
597 base_zero_point + offset
598 })
599 .collect();
600
601 let mut best_scale = base_scale;
602 let mut best_zero_point = base_zero_point;
603 let mut min_mse = f32::MAX;
604
605 for &_scale in &scales {
607 for &zero_point in &zero_points {
608 let mut params = QuantizationParams {
610 bits,
611 scale: _scale,
612 zero_point,
613 min_val: 0.0, max_val: 0.0, method: QuantizationMethod::Affine,
616 data_type: determine_data_type(bits),
617 channel_scales: None,
618 channel_zero_points: None,
619 };
620
621 let vector_f32 = vector.mapv(|x| x.as_());
623 let scale = params.scale;
624 let zero_point = params.zero_point;
625
626 let mut min_val = f32::MAX;
628 let mut max_val = f32::MIN;
629 for &val in vector_f32.iter() {
630 if val.is_finite() {
631 min_val = min_val.min(val);
632 max_val = max_val.max(val);
633 }
634 }
635 params.min_val = min_val;
636 params.max_val = max_val;
637
638 let dequantized = vector_f32.mapv(|x| {
639 let quantized = ((x / scale) + zero_point as f32)
640 .round()
641 .clamp(0.0, ((1 << bits) - 1) as f32);
642 (quantized - zero_point as f32) * scale
643 });
644
645 let mse = (&vector_f32 - &dequantized).mapv(|x| x * x).sum() / vector.len() as f32;
647
648 if mse < min_mse {
649 min_mse = mse;
650 best_scale = scale;
651 best_zero_point = zero_point;
652 }
653 }
654 }
655
656 (best_scale, best_zero_point)
657}
658
659#[allow(dead_code)]
661pub fn create_params_from_range(
662 bits: u8,
663 min_val: f32,
664 max_val: f32,
665 symmetric: bool,
666) -> LinalgResult<QuantizationParams> {
667 let (method, scale, zero_point) = if symmetric {
668 let abs_max = max_val.abs().max(min_val.abs());
669 let scale = abs_max / ((1 << (bits - 1)) - 1) as f32;
670 (QuantizationMethod::Symmetric, scale, 0)
671 } else {
672 let method = QuantizationMethod::Affine;
673 let scale = (max_val - min_val) / ((1 << bits) - 1) as f32;
674 let zero_point = (-min_val / scale).round() as i32;
675 (method, scale, zero_point)
676 };
677
678 Ok(QuantizationParams {
679 bits,
680 scale,
681 zero_point,
682 min_val,
683 max_val,
684 method,
685 data_type: determine_data_type(bits),
686 channel_scales: None,
687 channel_zero_points: None,
688 })
689}
690
691#[allow(dead_code)]
693pub fn determine_data_type(bits: u8) -> super::super::QuantizedDataType {
694 use super::super::QuantizedDataType;
695
696 match bits {
697 4 => QuantizedDataType::Int4, 8 => QuantizedDataType::Int8, 16 => QuantizedDataType::Float16, _ => QuantizedDataType::Int8, }
702}
703
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// The extrema of a 3x3 matrix holding 1..=9 are reported unchanged.
    #[test]
    fn test_find_min_max() {
        let data = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let (lowest, highest) = find_min_max(&data.view());
        assert_eq!(lowest, 1.0);
        assert_eq!(highest, 9.0);
    }

    /// The extrema of a vector holding 1..=5 are reported unchanged.
    #[test]
    fn test_find_min_max_vec() {
        let data = array![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let (lowest, highest) = find_min_max_vec(&data.view());
        assert_eq!(lowest, 1.0);
        assert_eq!(highest, 5.0);
    }

    /// Six values over five bins: every value is counted and no bin holds
    /// more than two of them.
    #[test]
    fn test_create_histogram() {
        let data = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let bins = create_histogram(&data.view(), 1.0, 6.0, 5);

        let counted: usize = bins.iter().sum();
        assert_eq!(counted, 6);

        for &count in bins.iter() {
            assert!(count <= 2);
        }
    }

    /// Checks the scale and zero point for both symmetric and affine modes.
    #[test]
    fn test_create_params_from_range() {
        // Symmetric, 8 bits: 127 positive levels and a zero point of 0.
        let symmetric = create_params_from_range(8, -5.0, 5.0, true).unwrap();
        assert_eq!(symmetric.method, QuantizationMethod::Symmetric);
        assert_eq!(symmetric.zero_point, 0);
        assert_relative_eq!(symmetric.scale, 5.0 / 127.0, epsilon = 1e-6);

        // Affine, 8 bits: 255 steps across the range, zero point from -min / scale.
        let affine = create_params_from_range(8, 1.0, 9.0, false).unwrap();
        assert_eq!(affine.method, QuantizationMethod::Affine);
        assert_relative_eq!(affine.scale, 8.0 / 255.0, epsilon = 1e-6);

        let expected_zero_point = (-1.0 / affine.scale).round() as i32;
        assert_eq!(affine.zero_point, expected_zero_point);
    }
}